注意:我已经在此处发布了有关此问题的信息。我正在创建一个新问题,因为: 1. 我认为这个问题特别与在我的自定义层中重塑我的蒙版有关,但我不确定是否完全忽略了我在原始帖子中写的另一个错误。2. 有很多关于重塑Keras图层或添加Masking图层的帖子,但是我找不到任何关于在图层内重新塑造蒙版的帖子,所以我希望这篇文章可以更普遍地有用。
问题:
我有一个自定义 Keras 层,它接受 2D 输入并返回 3D 输出(batch_size、max_length、1024),然后将其传递给 BiLSTM,然后是 CRF。
自定义 Keras 层是从此存储库中复制的。不同之处在于我从 Elmo 模型中采用 'elmo' 而不是 'default' 输出,因此输出是 BiLSTM 要求的 3D:
result = self.elmo(K.squeeze(K.cast(x, tf.string), axis=1),
as_dict=True,
signature='default',
)['elmo'] # The original code used 'default'
但是,compute_mask 函数不适合我的架构,因为它的输出是 2D。因此我得到错误:
InvalidArgumentError: Incompatible shapes: [32,47] vs. [32,0] [[{{node loss/crf_1_loss/mul_6}}]]
其中 32 是批量大小,47 比我指定的 max_length 小一。
我确定我需要重塑面具,但我无法找到任何方法。
如果需要的话,很高兴用整个东西和/或完整的堆栈跟踪来制作一个 git repo。
自定义 ELMo 层:
class ElmoEmbeddingLayer(Layer):
def __init__(self, **kwargs):
self.dimensions = 1024
self.trainable = True
super(ElmoEmbeddingLayer, self).__init__(**kwargs)
def build(self, input_shape):
self.elmo = hub.Module('https://tfhub.dev/google/elmo/2', trainable=self.trainable, name="{}_module".format(self.name))
self.trainable_weights += K.tf.trainable_variables(scope="^{}_module/.*".format(self.name))
super(ElmoEmbeddingLayer, self).build(input_shape)
def call(self, x, mask=None):
result = self.elmo(K.squeeze(K.cast(x, tf.string), axis=1),
as_dict=True, signature='default',)['elmo']
return result
# Original compute_mask function. Raises;
# InvalidArgumentError: Incompatible shapes: [32,47] vs. [32,0] [[{{node loss/crf_1_loss/mul_6}}]]
def compute_mask(self, inputs, mask=None):
return K.not_equal(inputs, '__PAD__')
def compute_output_shape(self, input_shape):
return input_shape[0], 48, self.dimensions
模型构建如下:
def build_model(): # uses crf from keras_contrib
input = layers.Input(shape=(1,), dtype=tf.string)
model = ElmoEmbeddingLayer(name='ElmoEmbeddingLayer')(input)
model = Bidirectional(LSTM(units=512, return_sequences=True))(model)
crf = CRF(num_tags)
out = crf(model)
model = Model(input, out)
model.compile(optimizer="rmsprop", loss=crf_loss, metrics=[crf_accuracy, categorical_accuracy, mean_squared_error])
model.summary()
return model