I am training a model with Keras to predict the digit at position i of an n-digit decimal number. For example, if the input is 123456 (so n = 6) and i = 2, the program should find the value 4. I store the inputs in x_values, and the corresponding target values for the elements of x_values (used for training, testing, and validation) in y_values.
# Generate x_values by reading the first column from the csv
import pandas as pd

df = pd.read_csv(dst_path, skipinitialspace=True)
x_values = df['Input']
y_values = df['OutputA'] # OutputA indicates 0th element (i.e., i = 0)
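For reference, the target digit can be computed directly from the input; a minimal sketch, assuming i counts from the right (so i = 2 of 123456 is 4) and using a hypothetical helper name not taken from the original code:

```python
# Hypothetical helper: returns the digit of n at position i, counting
# from the right (i = 0 is the last digit).
def digit_at(n: int, i: int) -> int:
    return (n // 10 ** i) % 10

print(digit_at(123456, 2))  # -> 4
```

If i is meant to count from the left instead, indexing the number as a string (`int(str(n)[i])`) is the simpler form.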
I have two models. The first is a simple 'relu' model: it takes a scalar input, feeds it through 8 "neurons", and its final layer is a single neuron.
# Using Keras to create a simple model architecture
import tensorflow as tf
from tensorflow import keras

model_1 = tf.keras.Sequential()
# First layer takes a scalar input and feeds it through 8 "neurons". The
# neurons decide whether to activate based on the 'relu' activation function.
model_1.add(keras.layers.Dense(8, activation='relu', input_shape=(1,)))
# Final layer is a single neuron, since we want to output a single value
model_1.add(keras.layers.Dense(1))
# Compile the model using the standard 'adam' optimizer and the mean squared error or 'mse' loss function for regression.
model_1.compile(optimizer='adam', loss='mse', metrics=['mae'])
The second model is 16 neurons wide by 32 layers deep (much larger than the first).
model = tf.keras.Sequential()
# First layer takes a scalar input and feeds it through 16 "neurons". The
# neurons decide whether to activate based on the 'relu' activation function.
NUM_NEURONS = 16
model.add(keras.layers.Dense(NUM_NEURONS, activation='relu', input_shape=(1,)))
# The new layers will help the network learn more complex representations
NUM_LAYERS = 30
for layer in range(NUM_LAYERS):
    model.add(keras.layers.Dense(NUM_NEURONS, activation='relu'))
# Final layer is a single neuron, since we want to output a single value
model.add(keras.layers.Dense(1))
# Compile the model using the standard 'adam' optimizer and the mean squared error or 'mse' loss function for regression.
model.compile(optimizer='adam', loss="mse", metrics=["mae"])
My whole dataset has 15,000 samples. I split the inputs into a training set, a test set, and a validation set.
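The split itself is not shown in the post; a minimal sketch of one common approach (a 60/20/20 split, with hypothetical stand-in data, since the real ratios and csv columns are an assumption here):

```python
import numpy as np

# Hypothetical stand-in for the csv data: 15,000 six-digit inputs and
# their digit at i = 2, counting from the right.
rng = np.random.default_rng(0)
x_values = rng.integers(100_000, 1_000_000, size=15_000)
y_values = (x_values // 100) % 10

# 60% train, 20% validation, 20% test
n = len(x_values)
train_end, validate_end = int(0.6 * n), int(0.8 * n)
x_train, y_train = x_values[:train_end], y_values[:train_end]
x_validate, y_validate = (x_values[train_end:validate_end],
                          y_values[train_end:validate_end])
x_test, y_test = x_values[validate_end:], y_values[validate_end:]

print(len(x_train), len(x_validate), len(x_test))  # -> 9000 3000 3000
```

Shuffling before slicing (or `sklearn.model_selection.train_test_split`) gives the same effect when the csv rows are ordered.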
# Train the model on our training data while validating on our validation set
history_1 = model_1.fit(x_train, y_train, epochs=500, batch_size=64,
                        validation_data=(x_validate, y_validate))
After training, when I compute the loss and the mean absolute error, both are very high (about 8.3 and 2.5, respectively).
What is the best way to fix this? Playing with the network's width, depth, or the number of epochs has no noticeable effect. Any hints would be appreciated.