2

我正在尝试使用stax.GeneralConv() ( https://jax.readthedocs.io/en/latest/_modules/jax/experimental/stax.html#GeneralConv ) 在Google Jax中实现一维卷积神经网络。我有一个包含 18 个条目的一维输入数组和一个包含 6 个条目的输出数组。我想实现一个内核宽度为 3 的 CNN,如下所示:

init_random_params, conv_net = stax.serial(
    GeneralConv(('NC','IO','NC'),1,(3,),padding='SAME'), # dimension_numbers = ('NC','IO','NC')
    LogSoftmax,
    Dense(6),
)

具有初始网络参数:

rng = jax.random.PRNGKey(0)
_, init_params = init_random_params(rng, (18,))

但我收到以下错误:

stax.py", line 75, in <listcomp>
    next(filter_shape_iter) for c in rhs_spec]

IndexError: tuple index out of range

stax 要求维度编号rhs_spec至少为 2 个字符长,但我使用一维过滤器。有人知道如何解决这个问题吗?

4

1 回答 1

1

我自己没有尝试过,但我希望一维卷积仍然需要一个方向来进行卷积,例如

Conv2d = functools.partial(GeneralConv, ('NHWC', 'HWIO', 'NHWC'))
Conv1d = functools.partial(GeneralConv, ('NHC', 'HIO', 'NHC'))

换句话说,放弃W轴从 2d 到 1d 卷积。

对应的输入形状NHC(batch_size, sequence_length, num_channels)

请注意,即使通道数可能为 1,您仍然需要包含该轴,因为GeneralConv会沿着num_channels = input_shape['NHC'.index('C')].

于 2020-07-09T23:56:43.750 回答