我正在试验 TensorFlow 2.0 (alpha)。我想实现一个简单的前馈网络,它有两个用于二进制分类的输出节点(它是这个模型的 2.0 版本)。
这是脚本的简化版本。在我定义了一个简单的Sequential()
模型之后,我设置了:
# import layers + dropout & activation
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.activations import elu, softmax
# Neural Network Architecture
n_input = X_train.shape[1]
n_hidden1 = 15
n_hidden2 = 10
n_output = y_train.shape[1]
model = tf.keras.models.Sequential([
Dense(n_input, input_shape = (n_input,), activation = elu), # Input layer
Dropout(0.2),
Dense(n_hidden1, activation = elu), # hidden layer 1
Dropout(0.2),
Dense(n_hidden2, activation = elu), # hidden layer 2
Dropout(0.2),
Dense(n_output, activation = softmax) # Output layer
])
# define loss and accuracy
bce_loss = tf.keras.losses.BinaryCrossentropy()
accuracy = tf.keras.metrics.BinaryAccuracy()
# define optimizer
optimizer = tf.optimizers.Adam(learning_rate = 0.001)
# save training progress in lists
loss_history = []
accuracy_history = []
# loop over 1000 epochs
for epoch in range(1000):
with tf.GradientTape() as tape:
# take binary cross-entropy (bce_loss)
current_loss = bce_loss(model(X_train), y_train)
# Update weights based on the gradient of the loss function
gradients = tape.gradient(current_loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
# save in history vectors
current_loss = current_loss.numpy()
loss_history.append(current_loss)
accuracy.update_state(model(X_train), y_train)
current_accuracy = accuracy.result().numpy()
accuracy_history.append(current_accuracy)
# print loss and accuracy scores each 100 epochs
if (epoch+1) % 100 == 0:
print(str(epoch+1) + '.\tTrain Loss: ' + str(current_loss) + ',\tAccuracy: ' + str(current_accuracy))
accuracy.reset_states()
print('\nTraining complete.')
训练没有错误,但是奇怪的事情发生了:
- 有时,网络不会学到任何东西。所有损失和准确度分数在所有时期都是恒定的。
- 其他时候,网络正在学习,但非常非常糟糕。准确度从未超过 0.4(而在 TensorFlow 1.x 中,我毫不费力地获得了 0.95+)。如此低的表现表明我在训练中出了点问题。
- 其他时候,准确性会非常缓慢地提高,而损失始终保持不变。
什么会导致这些问题?请帮助我理解我的错误。
更新:经过一些更正后,我可以让网络学习。但是,它的性能极差。在 1000 个 epoch 之后,它达到了大约 %40 的准确率,这显然意味着仍然有问题。任何帮助表示赞赏。