0

给定一个长度为 15 的 pytorch 张量t,是否有一种“好”的方法来获取该张量的不相交子集的最大值?具体来说,给定一个列表,l=(5,4,6)我想要 的前 5 个元素的t最大值,然后是接下来 4 个元素的最大值,最后 6 个元素的最大值。

如果 的元素l相等,则可以对张量进行整形,并且可以在单个步骤中找到每行的最大值。但是当 l 的元素不同时,我看不出有一种方法可以很好地做到这一点,而不诉诸循环。很高兴找到一种可并行的方式来做到这一点。

对于上下文,我正在研究一个强化学习问题,其中我的状态是图表。动作只是图的一个顶点。我正在使用 DQN。当我从缓冲区中采样一批时,对于缓冲区中的每个图,我为每个顶点生成一个 Q 值,并且对于每个图,我想在这些 Q 值中找到最大值。但是,当图具有不同数量的顶点时,我会遇到上述问题:我看不到无需遍历批次即可获得每个图的最大 Q 值的方法。

以下是如何使用循环执行此操作的示例:

# Case when the subset lengths are unbalanced
q_values_tensor = torch.randint(10,(15,1))
lengths_list = [5,4,6]
max_q_values_unbalanced = torch.zeros((1,len(lengths_list)))
tensor_index=0
for list_index, length in enumerate(lengths_list):
    max_q_values_unbalanced[0, list_index] = torch.max(q_values_tensor[tensor_index:tensor_index+length])
    tensor_index += length

还有一个子集长度都相等的例子(我希望可能有一个与此类似的解决方案):

# Case when the subset lengths are balanced (what I'd like to replicate)
q_values_tensor = torch.randint(10,(15,1))
lengths_list = [5,5,5]
reshaped_q_values_tensor = q_values_tensor.view(-1,lengths_list[0])
max_q_values_balanced = reshaped_q_values_tensor.max(1)[0]
4

0 回答 0