我正在使用vmap对部分代码进行矢量化处理。这是一个最小的例子,在矢量化之前:
dim = 2
def sum(x):
a = np.ones((dim,))
return np.dot(x, a)
num_samples = 100
samples = np.ones((num_samples, dim))
sum(samples[0]) # 2
使用 vmap:
sum = vmap(sum)
sum(samples) # DeviceArray of shape (100,), all entries are 2
但这可能会出错,在矢量化之后:
sum(samples[0]) # DeviceArray of shape (2,2), all entries are 1
这里发生的是samples[0]
具有形状的(2,)
。矢量化函数调用沿第一个轴拆分其输入参数,因此输入 2 个 shape 数组(1,)
。由于使用 广播a
,结果输出再次具有形状(2,)
并堆叠到(2,2)
数组中。
这对我来说似乎很危险。代码看起来很正常,生成的输出很容易被其他一些隐藏其损坏形状的广播规则所消耗。
是否可以强制执行正确的形状?