1

我正在使用 Jupyter 笔记本尝试生成一长串 TF Learn DNN 对象以进行一些蛮力试错测试(我知道这不是最有效的方法,只是试图展示一个示例)。数据遵循泰坦尼克号快速入门教程。

我有一个函数,给定一堆参数,应该返回一个 tflearn.DNN() 对象:

def make_fully_connected(input_shape, output_shape, activation, layers, nodes, dropout, optimizer, loss):
    tflearn.init_graph()
    net = tflearn.input_data(shape=[None, input_shape])
    for l in range(layers):
        net = tflearn.fully_connected(net, nodes)
        if (dropout != 0) and (l%2==1):
            net = tflearn.dropout(net, dropout)
    net = tflearn.fully_connected(net, output_shape, activation=activation)
    net = tflearn.regression(net, optimizer=optimizer, loss=loss)
    return tflearn.DNN(net)

然后我使用该函数生成一个特定的模型:

model = make_fully_connected(6, 2, 'softmax', 2, 32, 0, 'adam', 'categorical_crossentropy')
model.fit(data, labels, n_epoch=10, batch_size=16, show_metric=True)
score = model.evaluate(data, labels)

但我收到一条可爱的错误消息,将我带入 TF Learn 代码,我很快就迷路了:

IndexError                                Traceback (most recent call last)
<ipython-input-15-79e1d2acc8bf> in <module>()
      1 model = make_fully_connected(6, 2, 'softmax', 2, 32, 0, 'adam', 'categorical_crossentropy')
----> 2 model.fit(data, labels, n_epoch=10, batch_size=16, show_metric=True)
      3 score = model.evaluate(data, labels)
      4 print('| Score: %.4f' % score, end='')

/usr/local/lib/python3.5/dist-packages/tflearn/models/dnn.py in fit(self, X_inputs, Y_targets, n_epoch, validation_set, show_metric, batch_size, shuffle, snapshot_epoch, snapshot_step, excl_trainops, validation_batch_size, run_id, callbacks)
    181         # TODO: check memory impact for large data and multiple optimizers
    182         feed_dict = feed_dict_builder(X_inputs, Y_targets, self.inputs,
--> 183                                       self.targets)
    184         feed_dicts = [feed_dict for i in self.train_ops]
    185         val_feed_dicts = None

/usr/local/lib/python3.5/dist-packages/tflearn/utils.py in feed_dict_builder(X, Y, net_inputs, net_targets)
    287                 X = [X]
    288             for i, x in enumerate(X):
--> 289                 feed_dict[net_inputs[i]] = x
    290         else:
    291             # If a dict is provided

IndexError: list index out of range

从函数返回模型是否超出了 TF Learn 的范围?或者还有其他障碍?

4

1 回答 1

0

尝试重新启动内核并清除输出并再次运行它。我也遇到了这个问题,这个解决方案对我有用。这是因为你已经多次运行模型并且它已经崩溃了。

于 2018-05-04T04:18:50.833 回答