我正在尝试构建一个注意力模型,但默认情况下 Relu 和 ShiftRight 层嵌套在串行组合器内。这进一步给了我训练中的错误。
layer_block = tl.Serial(
tl.Relu(),
tl.LayerNorm(), )
x = np.array([[-2, -1, 0, 1, 2],
[-20, -10, 0, 10, 20]]).astype(np.float32)
layer_block.init(shapes.signature(x)) y = layer_block(x)
print(f'layer_block: {layer_block}')
输出
layer_block: Serial[
Serial[
Relu
]
LayerNorm
]
预期产出
layer_block: Serial[
Relu
LayerNorm
]
tl.ShiftRight() 也会出现同样的问题
以上代码取自官方文档Example 5
提前致谢