0

我正在使用 const 泛型和宏在编译时构建一个简单的前馈神经网络。这些是一个接一个的一堆矩阵。

我创建了network!宏,它的工作原理是这样的:

network!(2, 4, 1)

第一项是输入的数量,其余的是每层的神经元数量。宏如下所示:

#[macro_export]
macro_rules! network {
    ( $inputs:expr, $($outputs:expr),* ) => {
        {
            Network {
                layers: [
                    $(
                        &Layer::<$inputs, $outputs>::new(),
                    )*
                ]
            }
        }
    };
}

它声明了一个层元素数组,它使用 const 泛型在每一层上拥有一个固定大小的权重数组,第一个类型参数是它期望的输入数量,第二个类型参数是输出数量。

该宏产生以下代码:

Network {
    layers: [
         &Layer::<2, 4>::new(),
         &Layer::<2, 1>::new(),
    ]
}

这是完全错误的,因为对于每一层,输入的数量应该是前一层的输出数量,就像这样(注意 2 -> 4):

Network {
    layers: [
         &Layer::<2, 4>::new(),
         &Layer::<4, 1>::new(),
    ]
}

为此,我需要在每次迭代时将值替换为$inputs的值$outputs,但我不知道该怎么做。

4

2 回答 2

3

您可以匹配两个主要值,然后匹配所有其他值。为这两个值做一些特定的事情并递归调用宏,重用第二个值:

struct Layer<const I: usize, const O: usize>;

macro_rules! example {
    // Do something interesting for a given pair of arguments
    ($a:literal, $b:literal) => {
        Layer::<$a, $b>;
    };

    // Recursively traverse the arguments
    ($a:literal, $b:literal, $($rest:literal),+) => {
        example!($a, $b);
        example!($b, $($rest),*);
    };
}

fn main() {
    example!(1, 2, 3);
}

扩展宏会导致:

fn main() {
    Layer::<1, 2>;
    Layer::<2, 3>;
}
于 2021-03-31T14:16:16.790 回答
0

对于那些感兴趣的人,我终于能够根据@Shepmaster 的回答像这样填充我的网络:

struct Network<'a, const L: usize> {
    layers: [&'a dyn Forward; L],
}

macro_rules! network {
    // Recursively accumulate token tree
    (@accum ($a:literal, $b:literal, $($others:literal),+) $($e:tt)*) => {
        network!(@accum ($b, $($others),*) $($e)*, &Layer::<$a, $b>::new())
    };

    // Latest iteration, convert to expression
    (@accum ($a:literal, $b:literal) $($e:tt)*) => {[$($e)*, &Layer::<$a, $b>::new()]};

    // Entrance
    ($a:literal, $b:literal, $($others:literal),+) => {
        Network {
            layers: network!(@accum ($b, $($others),*) &Layer::<$a, $b>::new())
        }
    };
}

因为network!(2, 3, 4, 5, 1)它转化为:

Network {
     layers:
          [&Layer::<2, 3>::new(),
           &Layer::<3, 4>::new(),
           &Layer::<4, 5>::new(),
           &Layer::<5, 1>::new()]
};
于 2021-04-01T11:29:43.700 回答