I must admit I'm a bit lost with macros. I want to build a macro that does the following task, and I'm not sure how to do it. I want to perform a scalar product of two arrays, say `x` and `y`, which have the same length `N`. The result I want to compute is of the form:
$$z = \sum_{i=0}^{N-1} x[i] \cdot y[i]$$
`x` is const, and its elements, which are known at compile time, are 0, 1, or -1, while the elements of `y` are determined at runtime. Because of the structure of `x`, many computations are useless: terms multiplied by 0 can be removed from the sum, and multiplications of the form `1 * y[i]` and `-1 * y[i]` can be transformed into `y[i]` and `-y[i]`, respectively.
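The three rewriting rules can be written down as a small helper. This is only a sketch: `simplified_term` is a hypothetical name, and it returns `None` for a zero coefficient to express that the term disappears from the sum rather than contributing `0.0`.

```rust
/// Hypothetical helper: rewrites a single term `c * yi` for a coefficient
/// `c` in {0, 1, -1}. Returns `None` when the term can be dropped entirely.
fn simplified_term(c: i32, yi: f64) -> Option<f64> {
    match c {
        0 => None,       // 0 * y[i]: remove the term from the sum
        1 => Some(yi),   // 1 * y[i]  ->  y[i]
        -1 => Some(-yi), // -1 * y[i] -> -y[i]
        _ => panic!("coefficient must be 0, 1 or -1, got {}", c),
    }
}
```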
As an example, if `x = [-1, 1, 0]`, the scalar product above would be

$$z = -1 \cdot y[0] + 1 \cdot y[1] + 0 \cdot y[2]$$
To speed up my computation, I can unroll the loop by hand and rewrite the whole thing without `x[i]`, hard-coding the above formula as

`z = -y[0] + y[1]`
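For the example `x = [-1, 1, 0]`, the generic loop and the hand-unrolled version can be compared directly. Both function names below are made up for illustration; the point is only that the unrolled form contains no multiplications and no reference to `x`.

```rust
/// Generic dot product of an integer coefficient slice with a float slice.
fn dot_generic(x: &[i32], y: &[f64]) -> f64 {
    x.iter().zip(y).map(|(&c, &v)| c as f64 * v).sum()
}

/// Hand-unrolled version for x = [-1, 1, 0]: the zero term is gone and the
/// +/-1 coefficients became a plain addition and negation.
fn dot_unrolled(y: [f64; 3]) -> f64 {
    -y[0] + y[1]
}
```

For `y = [1.5, 2.0, 7.0]` both return `0.5`; the unrolled form simply has to be rewritten by hand whenever `x` changes.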
But this procedure is not elegant, is error-prone, and becomes very tedious when `N` is large.
I'm pretty sure I can do this with a macro, but I don't know where to start (the books I have read don't go very deep into macros, and I'm stuck).
Does anyone have an idea how to solve this problem with macros (if it is possible at all)?
Thank you in advance for your help!
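One possible direction, sketched under the assumption that the coefficients can be written out as literal tokens inside the macro call, is a `macro_rules!` "tt-muncher": it consumes one coefficient at a time, carries the current index along as an expression, and emits only `+ y[i]` or `- y[i]` terms. `sparse_dot!` is a hypothetical name, and this is a minimal sketch rather than a general solution: `y` must be a plain variable, and each coefficient must be written literally as `0`, `1` or `-1`.

```rust
/// Hypothetical macro: `sparse_dot!(y; c0, c1, ..., cN-1)` expands to a sum in
/// which each literal coefficient is resolved at compile time:
/// 0 drops the term, 1 adds y[i], -1 subtracts y[i]. No multiplications remain.
macro_rules! sparse_dot {
    // All coefficients consumed (one trailing comma allowed):
    // the accumulated expression is the result.
    (@munch $y:ident, $idx:expr, $acc:expr; $(,)?) => {
        $acc
    };
    // Coefficient 0: emit nothing, just advance the index.
    (@munch $y:ident, $idx:expr, $acc:expr; 0, $($rest:tt)*) => {
        sparse_dot!(@munch $y, $idx + 1, $acc; $($rest)*)
    };
    // Coefficient 1: add y[idx] without any multiplication.
    (@munch $y:ident, $idx:expr, $acc:expr; 1, $($rest:tt)*) => {
        sparse_dot!(@munch $y, $idx + 1, $acc + $y[$idx]; $($rest)*)
    };
    // Coefficient -1: subtract y[idx] without any multiplication.
    (@munch $y:ident, $idx:expr, $acc:expr; -1, $($rest:tt)*) => {
        sparse_dot!(@munch $y, $idx + 1, $acc - $y[$idx]; $($rest)*)
    };
    // Entry point. A comma is appended so that every coefficient is followed
    // by one, which keeps the rules above uniform. The sum starts at -0.0,
    // the exact additive identity for f64.
    ($y:ident; $($c:tt)*) => {
        sparse_dot!(@munch $y, 0usize, -0.0_f64; $($c)* ,)
    };
}
```

With this sketch, `sparse_dot!(y; 0, 1, -1, 0, 0, 1, 0, -1)` expands to roughly `-0.0 + y[1] - y[2] + y[5] - y[7]`. Starting from `-0.0` instead of `0.0` is deliberate: in IEEE 754 arithmetic `-0.0 + v == v` for every `v` (including `-0.0`), so the optimizer is allowed to drop it, whereas `0.0 + v` turns `-0.0` into `0.0`. The main limitation is that a `macro_rules!` macro only sees tokens, so it cannot read the contents of an existing `const` array; the coefficients have to appear in the call itself.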
Edit: As pointed out in many of the answers, the compiler is smart enough to optimize the loop away when everything is an integer. However, I am not only using integers but also floats (the `x` array holds `i32`s, but in general `y` holds `f64`s), so the compiler is not smart enough (and rightfully so) to optimize the loop. The following code produces the assembly below:
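To see why the compiler keeps the float multiplications, it can help to compare the loop as written with a hand-rewritten version. In IEEE 754 arithmetic, `0.0 * v` is NaN when `v` is infinite or NaN, so removing a `0 * y[i]` term can change the result, and the optimizer is not allowed to apply that rewrite on its own. Both functions below are hypothetical illustrations, not code from the question.

```rust
/// Dot product exactly as written: every coefficient is converted to f64
/// and multiplied, so the 0 * y[i] terms are still evaluated.
fn dot_naive(x: &[i32], y: &[f64]) -> f64 {
    x.iter().zip(y).map(|(&c, &v)| c as f64 * v).sum()
}

/// Dot product with the rewrite applied by hand: zero terms are skipped and
/// +1/-1 become a plain addition/subtraction. The fold starts at -0.0,
/// which is the exact additive identity for f64.
fn dot_rewritten(x: &[i32], y: &[f64]) -> f64 {
    x.iter().zip(y).fold(-0.0_f64, |acc, (&c, &v)| match c {
        0 => acc,
        1 => acc + v,
        -1 => acc - v,
        // Any other coefficient falls back to a real multiplication.
        _ => acc + c as f64 * v,
    })
}
```

With `x = [0, 1, -1]` and `y = [f64::INFINITY, 2.0, 0.5]`, `dot_naive` returns NaN while `dot_rewritten` returns `1.5`; for finite inputs the two return the same value. Skipping the zero terms is a choice the programmer can make, but not one the compiler may make silently.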
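```rust
// Coefficients known at compile time: each one is 0, 1 or -1.
const X: [i32; 8] = [0, 1, -1, 0, 0, 1, 0, -1];

pub fn dot_x(y: [f64; 8]) -> f64 {
    // Naive dot product: every term is multiplied, including the zeros.
    X.iter().zip(y.iter()).map(|(i, j)| (*i as f64) * j).sum()
}
```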
```asm
playground::dot_x:
    xorpd %xmm0, %xmm0          # xmm0 = 0.0
    movsd (%rdi), %xmm1         # load y[0]
    mulsd %xmm0, %xmm1          # 0.0 * y[0] is still computed
    addsd %xmm0, %xmm1
    addsd 8(%rdi), %xmm1        # + y[1]
    subsd 16(%rdi), %xmm1       # - y[2]
    movupd 24(%rdi), %xmm2      # load y[3], y[4]
    xorpd %xmm3, %xmm3
    mulpd %xmm2, %xmm3          # 0.0 * y[3], 0.0 * y[4]
    addsd %xmm3, %xmm1
    unpckhpd %xmm3, %xmm3
    addsd %xmm1, %xmm3
    addsd 40(%rdi), %xmm3       # + y[5]
    mulsd 48(%rdi), %xmm0       # 0.0 * y[6]
    addsd %xmm3, %xmm0
    subsd 56(%rdi), %xmm0       # - y[7]
    retq
```