我正在尝试使用stax.GeneralConv() ( https://jax.readthedocs.io/en/latest/_modules/jax/experimental/stax.html#GeneralConv ) 在Google Jax中实现一维卷积神经网络。我有一个包含 18 个条目的一维输入数组和一个包含 6 个条目的输出数组。我想实现一个内核宽度为 3 的 CNN,如下所示:
init_random_params, conv_net = stax.serial(
GeneralConv(('NC','IO','NC'),1,(3,),padding='SAME'), # dimension_numbers = ('NC','IO','NC')
LogSoftmax,
Dense(6),
)
具有初始网络参数:
rng = jax.random.PRNGKey(0)
_, init_params = init_random_params(rng, (18,))
但我收到以下错误:
stax.py", line 75, in <listcomp>
next(filter_shape_iter) for c in rhs_spec]
IndexError: tuple index out of range
stax 要求维度编号rhs_spec至少为 2 个字符长,但我使用一维过滤器。有人知道如何解决这个问题吗?