Paul R 的回答很棒!(@Paul R - 如果你读到了,谢谢!)我只是想解释一下它对像我这样的 SSE 新手的实际工作原理。当然,我可能在某个地方错了,所以欢迎任何更正!
如何_mm_shuffle_ps
工作?
首先,SSE 寄存器的索引与您的预期相反,如下所示:
[6, 9, 8, 5] // values
3 2 1 0 // indexes
这种索引顺序使向量左移将数据从低索引移动到高索引,就像左移整数中的位一样。最重要的元素在左边。
_mm_shuffle_ps
可以混合两个寄存器的内容:
// __m128 a : (a3, a2, a1, a0)
// __m128 b : (b3, b2, b1, b0)
__m128 two_from_a_and_two_from_b = _mm_shuffle_ps(b, a, _MM_SHUFFLE(3, 2, 1, 0));
// ^ ^ ^ ^
// indexes into second operand indexes into first operand
// two_from_a_and_two_from_b : (a3, a2, b1, b0)
在这里,我们只想打乱一个寄存器的值,而不是两个。我们可以通过传递 v 作为两个参数来做到这一点,就像这样(你可以在 Paul R 的函数中看到这一点):
// __m128 v : (v3, v2, v1, v0)
__m128 v_rotated_left_by_1 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(2, 1, 0, 3));
// v_rotated_left_by_1 : (v2, v1, v0, v3) // i.e. move all elements left by 1 with wraparound
我将把它包装在一个宏中以提高可读性:
#define mm_shuffle_one(v, pattern) _mm_shuffle_ps(v, v, pattern)
(它不能是一个函数,因为pattern
参数_mm_shuffle_ps
在编译时必须是常量。)
这是实际函数的略微修改版本——我添加了中间名称以提高可读性,因为编译器无论如何都会优化它们:
inline __m128 _mm_hmin_ps(__m128 v){
__m128 v_rotated_left_by_1 = mm_shuffle_one(v, _MM_SHUFFLE(2, 1, 0, 3));
__m128 v2 = _mm_min_ps(v, v_rotated_left_by_1);
__m128 v2_rotated_left_by_2 = mm_shuffle_one(v2, _MM_SHUFFLE(1, 0, 3, 2));
__m128 v3 = _mm_min_ps(v2, v2_rotated_left_by_2);
return v3;
}
为什么要以我们现在的方式改组元素?我们如何通过两个min
操作找到四个元素中最小的一个?
min
我在了解如何仅使用两个矢量化操作就可以实现 4 个浮点数时遇到了一些麻烦min
,但是当我min
逐步手动跟踪将哪些值组合在一起时,我理解了这一点。(虽然自己做可能比阅读更有趣)
假设我们有v
:
[7,6,9,5] v
首先,我们和min
的值:v
v_rotated_left_by_1
[7,6,9,5] v
3 2 1 0 // (just the indices of the elements)
[6,9,5,7] v_rotated_left_by_1
2 1 0 3 // (the indexes refer to v, and we rotated it left by 1, so the indices are shifted)
--------- min
[6,6,5,5] v2
3 2 1 0 // (explained
2 1 0 3 // below )
v2
轨道元素下的每一列的索引v
被min
组合在一起以获得该元素。因此,按列从左到右:
v2[3] == 6 == min(v[3], v[2])
v2[2] == 6 == min(v[2], v[1])
v2[1] == 5 == min(v[1], v[0])
v2[0] == 5 == min(v[0], v[3])
现在是第二个min
:
[6,6,5,5] v2
3 2 1 0
2 1 0 3
[5,5,6,6] v2_rotated_left_by_2
1 0 3 2
0 3 2 1
--------- min
[5,5,5,5] v3
3 2 1 0
2 1 0 3
1 0 3 2
0 3 2 1
瞧!v3
contains下的每一列(3,2,1,0)
- 的每个元素v3
都min
与 的所有元素相乘v
- 所以每个元素都包含整个向量的最小值v
。
使用该函数后,您可以提取最小值float _mm_cvtss_f32(__m128)
:
__m128 min_vector = _mm_hmin_ps(my_vector);
float minval = _mm_cvtss_f32(min_vector);
***
这只是一个切线的想法,但我发现有趣的是,这种方法可以扩展到任意长度的序列,1, 2, 4, 8, ... 2**ceil(log2(len(v)))
在每一步旋转上一步的结果(我认为)。从理论的角度来看,这很酷 - 如果您可以同时比较两个序列元素,您可以在对数时间内找到序列的最小/最大1 !
1这延伸到所有水平折叠/减少,如 sum。同样的洗牌,不同的垂直操作。
然而,AVX(256 位向量)使 128 位边界变得特别,并且更难跨越。如果您只想要一个标量结果,请提取高半部分,以便每一步都将矢量宽度缩小一半。(就像在x86 上进行水平浮点向量求和的最快方式shufps
一样,对于 128 位向量,它比 2x 具有更有效的洗牌,movaps
在没有 AVX 的情况下编译时避免了一些指令。)
但是,如果您希望像@PaulR 的答案一样将结果广播到每个元素,您需要进行通道内随机播放(即在每个通道的 4 个元素内旋转),然后交换一半,或旋转 128 位通道。