12

我在 Keras 有一个具有许多输出的网络,但是,我的训练数据一次只提供一个输出的信息。

目前我的训练方法是对有问题的输入进行预测,更改我正在训练的特定输出的值,然后进行单批更新。如果我是对的,这与将所有输出的损失设置为零相同,除了我正在尝试训练的那个。

有没有更好的办法?我已经尝试过班级权重,除了我正在训练的输出之外,我将所有的权重设置为零,但它并没有给我预期的结果?

我正在使用 Theano 后端。

4

2 回答 2

18

输出多个结果并仅优化其中一个

假设您想从多个层返回输出,可能来自一些中间层,但您只需要优化一个目标输出。以下是您的操作方法:

让我们从这个模型开始:

inputs = Input(shape=(784,))
x = Dense(64, activation='relu')(inputs)

# you want to extract these values
useful_info = Dense(32, activation='relu', name='useful_info')(x)

# final output. used for loss calculation and optimization
result = Dense(1, activation='softmax', name='result')(useful_info)

使用多个输出编译,将损失设置None额外输出:

None为您不想用于损失计算和优化的输出提供

model = Model(inputs=inputs, outputs=[result, useful_info])
model.compile(optimizer='rmsprop',
              loss=['categorical_crossentropy', None],
              metrics=['accuracy'])

训练时只提供目标输出。跳过额外的输出:

model.fit(my_inputs, {'result': train_labels}, epochs=.., batch_size=...)

# this also works:
#model.fit(my_inputs, [train_labels], epochs=.., batch_size=...)

一个预测得到他们所有

拥有一个模型,您只能运行predict一次以获得所需的所有输出:

predicted_labels, useful_info = model.predict(new_x)
于 2019-05-16T10:53:36.920 回答
3

为了实现这一点,我最终使用了“功能 API”。您基本上创建多个模型,使用相同的层输入和隐藏层,但不同的输出层。

例如:

https://keras.io/getting-started/functional-api-guide/

from keras.layers import Input, Dense
from keras.models import Model

# This returns a tensor
inputs = Input(shape=(784,))

# a layer instance is callable on a tensor, and returns a tensor
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
predictions_A = Dense(1, activation='softmax')(x)
predictions_B = Dense(1, activation='softmax')(x)

# This creates a model that includes
# the Input layer and three Dense layers
modelA = Model(inputs=inputs, outputs=predictions_A)
modelA.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
modelB = Model(inputs=inputs, outputs=predictions_B)
modelB.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
于 2017-08-21T02:35:10.293 回答