我是 pytorch 的初学者,我有一些需要在网络中实现的功能。
我的问题是:有没有像tf.function这样的方法,或者我应该使用带有变量的“class(nn.Module)”?
例如,设 X 为 10x2 矩阵。在伪代码中:
a = Variable(1.0)
b = Variable(1.0)
Y = a*X[:,0]**2 + b*X[:,1]
我是 pytorch 的初学者,我有一些需要在网络中实现的功能。
我的问题是:有没有像tf.function这样的方法,或者我应该使用带有变量的“class(nn.Module)”?
例如,设 X 为 10x2 矩阵。在伪代码中:
a = Variable(1.0)
b = Variable(1.0)
Y = a*X[:,0]**2 + b*X[:,1]