-4

我需要构建一个以两个潜在变量的值作为输入的深度神经网络,并生成一个灰度图像。

我知道这类似于 GAN 中的生成器网络,但是是否有任何已发表的研究工作或任何专门用于此类学习任务的 // 代码PythonTensorflowKeras

4

1 回答 1

1

所以这可能是 GAN 的一项任务,但不一定,取决于你手头的数据。然而来了

使用 GAN 生成 MNIST 样本的玩具问题的代码:

# define variables
g_input_shape = 100 
d_input_shape = (28, 28) 
hidden_1_num_units = 500 
hidden_2_num_units = 500 
g_output_num_units = 784 
d_output_num_units = 1 
epochs = 25 
batch_size = 128

# generator
model_1 = Sequential([
    Dense(units=hidden_1_num_units, input_dim=g_input_shape, activation='relu', kernel_regularizer=L1L2(1e-5, 1e-5)),
    Dense(units=hidden_2_num_units, activation='relu', kernel_regularizer=L1L2(1e-5, 1e-5)),   
    Dense(units=g_output_num_units, activation='sigmoid', kernel_regularizer=L1L2(1e-5, 1e-5)),
    Reshape(d_input_shape),
])

# discriminator
model_2 = Sequential([
    InputLayer(input_shape=d_input_shape),
    Flatten(),   
    Dense(units=hidden_1_num_units, activation='relu', kernel_regularizer=L1L2(1e-5, 1e-5)),
    Dense(units=hidden_2_num_units, activation='relu', kernel_regularizer=L1L2(1e-5, 1e-5)),    
    Dense(units=d_output_num_units, activation='sigmoid', kernel_regularizer=L1L2(1e-5, 1e-5)),
])


from keras_adversarial import AdversarialModel, simple_gan, gan_targets
from keras_adversarial import AdversarialOptimizerSimultaneous, normal_latent_sampling

# Let us compile our GAN and start the training
gan = simple_gan(model_1, model_2, normal_latent_sampling((100,)))
model = AdversarialModel(base_model=gan,player_params=[model_1.trainable_weights, model_2.trainable_weights])
model.adversarial_compile(adversarial_optimizer=AdversarialOptimizerSimultaneous(), player_optimizers=['adam', 'adam'], loss='binary_crossentropy')

history = model.fit(x=train_x, y=gan_targets(train_x.shape[0]), epochs=10, batch_size=batch_size)

# We get a graph like after training for 10 epochs.
plt.plot(history.history['player_0_loss'])
plt.plot(history.history['player_1_loss'])
plt.plot(history.history['loss'])

# After training for 100 epochs, we can now generate images
zsamples = np.random.normal(size=(10, 100))
pred = model_1.predict(zsamples)
for i in range(pred.shape[0]):
    plt.imshow(pred[i, :], cmap='gray')
plt.show()

即使在对此有所了解之后,您也应该真正开始阅读围绕 GAN 及其改编的研究。

笔记:

当您拥有这么好的锤子时,很容易将您的所有任务视为钉子。

但这不一定很漂亮。此外,当您提供有关您的问题的更多详细信息时,回答您的问题会容易得多。

  1. 潜变量是怎样的?
  2. 它们是否与灰度图像配对?
  3. 你有多少数据,规格是什么?
于 2018-12-21T21:52:47.330 回答