Implementing self-attention in TensorFlow Keras, with some modifications (e.g., a residual/add connection).
I have the following input shape:
My input: KerasTensor(type_spec=TensorSpec(shape=(None, 8, 6, 64), dtype=tf.float32, name=None), name='multiply/mul:0', description="created by layer 'multiply'")
My goal is to run self-attention over TensorSpec(shape=(None, 8, 6, 64)) one timestamp at a time (8 timestamps, each of shape (6, 64)), obtain a self-attention feature map for each timestamp, and then concatenate the maps back into an output tensor of shape (None, 8, 6, 64).
My code implementing self-attention:
import tensorflow as tf
from tensorflow.keras.layers import Permute

def conv1d(x, channels, ks=1, strides=1, padding='same'):
    # pointwise Conv1D projection: (None, 6, 64) -> (None, 6, channels)
    conv = tf.keras.layers.Conv1D(channels, ks, strides, padding, activation='relu', use_bias=False,
                                  kernel_initializer='HeNormal')(x)
    return conv
# Self-attention applied to one timestamp of shape (None, 6, 64)
def my_self_attention(x, channels):
    # query / key / value projections
    f = conv1d(x, channels)  # KerasTensor shape=(None, 6, 64)
    g = conv1d(x, channels)  # KerasTensor shape=(None, 6, 64)
    h = conv1d(x, channels)  # KerasTensor shape=(None, 6, 64)
    # attention map over the 6 positions
    attention_weights = tf.keras.activations.softmax(
        tf.matmul(g, Permute((2, 1))(f)))  # KerasTensor shape=(None, 6, 6)
    sensor_att_fm = tf.matmul(attention_weights, h)  # KerasTensor shape=(None, 6, 64)
    # learnable residual scale, initialized to 0
    gamma = tf.compat.v1.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))
    # refined feature map: add the scaled attention feature map to the input (residual connection)
    o = gamma * sensor_att_fm + x  # KerasTensor shape=(None, 6, 64)
    return o
# Calling the function: apply self-attention to each of the 8 timestamps
sa = [my_self_attention(my_input[:, t, :, :], channels) for t in range(my_input.shape[1])]
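A side question about this code: I am not sure that tf.compat.v1.get_variable is the right way to create gamma inside a Keras functional model, since I suspect that variable will not be tracked as a trainable weight of the model. A rough, untested alternative I have been considering is wrapping the whole block in a small custom layer that creates gamma via add_weight (the class name SensorSelfAttention and its exact structure are just my own sketch, not verified code):

class SensorSelfAttention(tf.keras.layers.Layer):
    def __init__(self, channels, **kwargs):
        super().__init__(**kwargs)
        # pointwise projections for query / key / value
        self.f = tf.keras.layers.Conv1D(channels, 1, padding='same', activation='relu',
                                        use_bias=False, kernel_initializer='HeNormal')
        self.g = tf.keras.layers.Conv1D(channels, 1, padding='same', activation='relu',
                                        use_bias=False, kernel_initializer='HeNormal')
        self.h = tf.keras.layers.Conv1D(channels, 1, padding='same', activation='relu',
                                        use_bias=False, kernel_initializer='HeNormal')

    def build(self, input_shape):
        # learnable residual scale, initialized to 0 and tracked by the model
        self.gamma = self.add_weight(name='gamma', shape=(1,),
                                     initializer='zeros', trainable=True)
        super().build(input_shape)

    def call(self, x):
        # attention over the 6 positions, then residual add (channels must equal
        # the input channel dimension, 64, for the add to work)
        attention_weights = tf.keras.activations.softmax(
            tf.matmul(self.g(x), tf.transpose(self.f(x), perm=[0, 2, 1])))
        sensor_att_fm = tf.matmul(attention_weights, self.h(x))
        return self.gamma * sensor_att_fm + x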
Where I am stuck
I am trying to concatenate these per-timestamp self-attention feature maps back into an output tensor of shape (None, 8, 6, 64), but have not been able to get it working. Could you please help me with this? Also, could you verify that the functions I implemented (my_self_attention and conv1d) are correct?
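The kind of thing I have been attempting looks roughly like the sketch below (tf.stack along axis=1, possibly wrapped in a Lambda layer, is my guess at the right operation; it is untested in my full model):

# stack the 8 per-timestamp outputs, each (None, 6, 64), along a new time axis
output = tf.stack(sa, axis=1)  # expected shape: (None, 8, 6, 64)
# or, keeping it as an explicit Keras layer:
output = tf.keras.layers.Lambda(lambda t: tf.stack(t, axis=1))(sa)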
Thanks