我有一个计算时间步长平均值并支持屏蔽的层。我的问题是,可能存在掩码为空(没有填充时间步)的情况,但我不知道在使用张量时如何检查零。
我有一些掩码为空的训练示例,因此我得到了 NaN 损失并且程序崩溃了。
这是我的图层:
class MeanOverTime(Layer):
def __init__(self, **kwargs):
self.supports_masking = True
super(MeanOverTime, self).__init__(**kwargs)
def call(self, x, mask=None):
if mask is not None:
return K.cast(x.sum(axis=1) / mask.sum(axis=1, keepdims=True), K.floatx()) # this may result to division by zero
else:
return K.mean(x, axis=1)
def get_output_shape_for(self, input_shape):
return input_shape[0], input_shape[-1]
def compute_mask(self, input, input_mask=None):
return None
这mask.sum(axis=1, keepdims=True)
变成了零。为了绕过这个,我增加了 input_length,所以它涵盖了我所有的训练示例,但这不是一个解决方案。我也尝试添加一个 try/except 但这也没有用。