OK, so the Google Groups posts I linked actually explain nicely why it does not work. AdvancedSubtensor is the most generic form, which handles all the crazy kinds of indexing variants. Then there is AdvancedSubtensor1, which only works for a certain subset. A GPU version exists only for AdvancedSubtensor1, not for AdvancedSubtensor. I do not fully understand the reason, but there is an ongoing discussion about it.
AdvancedSubtensor1 can be used when there is a single list of indices. However, that is not the case in my example. The common workaround you see (also in other examples in those Google Groups posts) is to flatten the array first and compute indices into the flattened array.
Most of the examples work with some nonzero() or similar, where you also flatten the base arguments and then get the indices for the flattened version.
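To make the idea concrete (my addition, using plain NumPy and made-up array names), the flatten workaround boils down to converting the broadcasted index arrays into indices of the flattened base array:

import numpy

A = numpy.arange(12).reshape(3, 4)
rows = numpy.array([[0], [2]])       # index arrays that broadcast together
cols = numpy.array([[1, 3]])

direct = A[rows, cols]               # ordinary fancy indexing
flat_idx = rows * A.shape[1] + cols  # the same indices, but into A.flatten()
via_flat = A.flatten()[flat_idx]

assert numpy.array_equal(direct, via_flat)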
So, the question is: how do I apply this to my code?
Actually, there is a simpler solution which will use AdvancedSubtensor1 and which I did not realize at first:
meminkeyP = meminkey[:, P] # (batch,n_copies,n_cells)
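As a quick sanity check (my addition, not from the original post), the same equivalence is easy to see in plain NumPy: indexing only the second axis with a 2D index array already broadcasts over the batch axis, so no explicit batch index is needed:

import numpy

A = numpy.random.rand(3, 4)                      # (batch, n_cells)
P = numpy.random.randint(0, 4, size=(5, 4))      # (n_copies, n_cells)

v1 = A[:, P]                                     # (batch, n_copies, n_cells)
v2 = A[numpy.arange(3)[:, None, None], P[None]]  # explicit broadcast batch indices
assert numpy.allclose(v1, v2)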
However, before I realized that, I had come up with a generic solution which also works in other cases. I transform my index tuple (batches, P_bc) into indices into the flattened version. That is done with this function:
def indices_in_flatten_array(ndim, shape, *args):
    """
    We expect that all args can be broadcasted together.
    So, if we have some array A with ndim&shape as given,
    A[args] would give us a subtensor.
    We return the indices so that A[args].flatten()
    and A.flatten()[indices] are the same.
    """
    assert ndim > 0
    assert len(args) == ndim
    indices_per_axis = [args[i] for i in range(ndim)]
    for i in range(ndim):
        # Multiply the index along axis i by the row-major stride of that axis.
        for j in range(i + 1, ndim):
            indices_per_axis[i] *= shape[j]
    # Sum up the per-axis contributions; broadcasting yields the final index array.
    indices = indices_per_axis[0]
    for i in range(1, ndim):
        indices += indices_per_axis[i]
    return indices
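For reference (my addition, not part of the original answer): this is essentially a symbolic, broadcasting-friendly counterpart of numpy.ravel_multi_index. A small cross-check against NumPy, feeding the function Theano constants since it builds a symbolic expression:

import numpy
import theano.tensor as T

shape = (3, 4)
rows = numpy.array([[0], [2]])
cols = numpy.array([[1, 3]])

flat1 = indices_in_flatten_array(2, shape, T.constant(rows), T.constant(cols)).eval()
flat2 = numpy.ravel_multi_index(numpy.broadcast_arrays(rows, cols), shape)
assert numpy.array_equal(flat1, flat2)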
Then I use it like this:
meminkeyP = meminkey.flatten()[indices_in_flatten_array(meminkey.ndim, meminkey.shape, batches, P_bc)]
And that seems to work.
I get this output:
Using gpu device 0: GeForce GTX TITAN (CNMeM is disabled, CuDNN not available)
GpuReshape{3} [id A] '' 11
|GpuAdvancedSubtensor1 [id B] '' 10
| |GpuReshape{1} [id C] '' 2
| | |<CudaNdarrayType(float32, matrix)> [id D]
| | |TensorConstant{(1,) of -1} [id E]
| |Reshape{1} [id F] '' 9
| |Elemwise{second,no_inplace} [id G] '' 8
| | |TensorConstant{(1, 5, 10) of 0} [id H]
| | |Elemwise{Mul}[(0, 0)] [id I] '' 7
| | |InplaceDimShuffle{0,x,x} [id J] '' 6
| | | |ARange{dtype='int64'} [id K] '' 4
| | | |TensorConstant{0} [id L]
| | | |Shape_i{0} [id M] '' 0
| | | | |<CudaNdarrayType(float32, matrix)> [id D]
| | | |TensorConstant{1} [id N]
| | |InplaceDimShuffle{x,x,x} [id O] '' 5
| | |Shape_i{1} [id P] '' 1
| | |<CudaNdarrayType(float32, matrix)> [id D]
| |TensorConstant{(1,) of -1} [id E]
|MakeVector{dtype='int64'} [id Q] '' 3
|Shape_i{0} [id M] '' 0
|TensorConstant{5} [id R]
|TensorConstant{10} [id S]
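(A note from me, not part of the original output: the interesting node is GpuAdvancedSubtensor1 near the top, which shows the indexing now stays on the GPU. The dump looks like Theano's graph printer; something along these lines should reproduce it, assuming meminkeyP is built from constants as in the test case below:)

import theano

f = theano.function([], meminkeyP)  # compiling triggers the GPU optimizations
theano.printing.debugprint(f)       # prints the optimized graph, as shown above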
A small test case:
import numpy
import theano.tensor as T

def test_indices_in_flatten_array():
    n_copies, n_cells = 5, 4
    n_complex_cells = n_cells // 2
    n_batch = 3
    static_rng = numpy.random.RandomState(1234)
    def make_permut():
        p = numpy.zeros((n_copies, n_cells), dtype="int32")
        for i in range(n_copies):
            p[i, :n_complex_cells] = static_rng.permutation(n_complex_cells)
            # Same permutation for imaginary part.
            p[i, n_complex_cells:] = p[i, :n_complex_cells] + n_complex_cells
        return T.constant(p)
    P = make_permut()  # (n_copies,n_cells) -> list of indices
    meminkey = T.as_tensor_variable(static_rng.rand(n_batch, n_cells).astype("float32"))
    i_t = T.ones((meminkey.shape[0],))  # (batch,)
    n_batch = i_t.shape[0]
    batches = T.arange(0, n_batch).dimshuffle(0, 'x', 'x')  # (batch,n_copies,n_cells)
    P_bc = P.dimshuffle('x', 0, 1)  # (batch,n_copies,n_cells)
    meminkeyP1 = meminkey[batches, P_bc]  # (batch,n_copies,n_cells)
    meminkeyP2 = meminkey.flatten()[indices_in_flatten_array(meminkey.ndim, meminkey.shape, batches, P_bc)]
    numpy.testing.assert_allclose(meminkeyP1.eval(), meminkeyP2.eval())