我已经阅读了一些 CNTK Python 教程,并且正在尝试编写一个可以计算逻辑与的非常基本的一层神经网络。我有正常运行的代码,但网络没有学习——事实上,随着每个 minibatch 的训练,损失变得越来越严重。
import numpy as np
from cntk import Trainer
from cntk.learner import sgd
from cntk import ops
from cntk.utils import get_train_eval_criterion, get_train_loss
input_dimensions = 2
# Define the training set
input_data = np.array([
[0, 0],
[0, 1],
[1, 0],
[1, 1]], dtype=np.float32)
# Each index matches with an index in input data
correct_answers = np.array([[0], [0], [0], [1]])
# Create the input layer
net_input = ops.input_variable(2, np.float32)
weights = ops.parameter(shape=(2, 1))
bias = ops.parameter(shape=(1))
network_output = ops.times(net_input, weights) + bias
# Set up training
expected_output = ops.input_variable((1), np.float32)
loss_function = ops.cross_entropy_with_softmax(network_output, expected_output)
eval_error = ops.classification_error(network_output, expected_output)
learner = sgd(network_output.parameters, lr=0.02)
trainer = Trainer(network_output, loss_function, eval_error, [learner])
minibatch_size = 4
num_samples_to_train = 1000
num_minibatches_to_train = int(num_samples_to_train/minibatch_size)
training_progress_output_freq = 20
def print_training_progress(trainer, mb, frequency, verbose=1):
training_loss, eval_error = "NA", "NA"
if mb % frequency == 0:
training_loss = get_train_loss(trainer)
eval_error = get_train_eval_criterion(trainer)
if verbose:
print("Minibatch: {0}, Loss: {1:.4f}, Error: {2:.2f}".format(
mb, training_loss, eval_error))
return mb, training_loss, eval_error
for i in range(0, num_minibatches_to_train):
trainer.train_minibatch({net_input: input_data, expected_output: correct_answers})
batchsize, loss, error = print_training_progress(trainer, i, training_progress_output_freq, verbose=1)
样本训练输出
Minibatch: 0, Loss: -164.9998, Error: 0.75
Minibatch: 20, Loss: -166.0998, Error: 0.75
Minibatch: 40, Loss: -167.1997, Error: 0.75
Minibatch: 60, Loss: -168.2997, Error: 0.75
Minibatch: 80, Loss: -169.3997, Error: 0.75
Minibatch: 100, Loss: -170.4996, Error: 0.75
Minibatch: 120, Loss: -171.5996, Error: 0.75
Minibatch: 140, Loss: -172.6996, Error: 0.75
Minibatch: 160, Loss: -173.7995, Error: 0.75
Minibatch: 180, Loss: -174.8995, Error: 0.75
Minibatch: 200, Loss: -175.9995, Error: 0.75
Minibatch: 220, Loss: -177.0994, Error: 0.75
Minibatch: 240, Loss: -178.1993, Error: 0.75
我不太确定这里发生了什么。错误停留在 0.75,我认为这意味着网络的性能与偶然情况相同。我不确定我是否误解了 ANN 架构的要求,或者我是否在滥用该库。
任何帮助,将不胜感激。