In [331]: A=np.random.rand(100,200,300)
In [332]: B=A
The suggested einsum, working directly from the
C[i,j,k] = np.dot(A[i,k,:], B[j,k,:])
expression:
In [333]: np.einsum( 'ikm, jkm-> ijk', A, B).shape
Out[333]: (100, 100, 200)
In [334]: timeit np.einsum( 'ikm, jkm-> ijk', A, B).shape
800 ms ± 25.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
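As a sanity check (a sketch with small arrays, not part of the timings above), the einsum does reproduce the explicit dot-product definition:

```python
import numpy as np

# small arrays so the triple loop is cheap
A = np.random.rand(4, 5, 6)
B = np.random.rand(3, 5, 6)

C = np.einsum('ikm,jkm->ijk', A, B)

# explicit form of C[i,j,k] = np.dot(A[i,k,:], B[j,k,:])
C_loop = np.empty((4, 3, 5))
for i in range(4):
    for j in range(3):
        for k in range(5):
            C_loop[i, j, k] = np.dot(A[i, k, :], B[j, k, :])

assert np.allclose(C, C_loop)
```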
matmul performs a dot product over the last two dimensions and treats the leading dimension(s) as batch. In your case 'k' is the batch dimension, and 'm' is the sum dimension, which has to be the last axis of A and the second-to-last axis of B. So rewriting the ikm,jkm->ijk to fit, and transposing A and B accordingly:
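To illustrate that batch rule with throwaway small arrays (just a sketch, unrelated to the shapes above):

```python
import numpy as np

a = np.random.rand(5, 4, 6)   # batch k=5, then (i=4, m=6)
b = np.random.rand(5, 6, 3)   # batch k=5, then (m=6, j=3)

out = a @ b                   # matmul over the last 2 axes, batched over axis 0
assert out.shape == (5, 4, 3)

# each batch slice is an ordinary 2d matrix product
assert np.allclose(out[0], a[0] @ b[0])
```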
In [335]: np.einsum('kim,kmj->kij', A.transpose(1,0,2), B.transpose(1,2,0)).shape
Out[335]: (200, 100, 100)
In [336]: timeit np.einsum('kim,kmj->kij',A.transpose(1,0,2), B.transpose(1,2,0)).shape
774 ms ± 22.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Not much difference in performance. But now use matmul:
In [337]: (A.transpose(1,0,2)@B.transpose(1,2,0)).transpose(1,2,0).shape
Out[337]: (100, 100, 200)
In [338]: timeit (A.transpose(1,0,2)@B.transpose(1,2,0)).transpose(1,2,0).shape
64.4 ms ± 1.17 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
and verify that the values match (though more often than not, if the shapes match, the values do too).
In [339]: np.allclose((A.transpose(1,0,2)@B.transpose(1,2,0)).transpose(1,2,0), np.einsum('ikm,jkm->ijk', A, B))
Out[339]: True
I won't try to measure memory usage, but the time improvement suggests it too is better.
In some cases einsum is optimized to use matmul. Here that doesn't seem to be happening, though we could play with its parameters. I'm a little surprised that matmul is doing so much better.
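One such parameter is einsum's optimize flag (present in this numpy version), which lets it search for a contraction order and potentially dispatch parts of the work to BLAS; whether it helps for this particular two-operand contraction would have to be timed. A sketch with smaller arrays, just to show the call:

```python
import numpy as np

# smaller shapes than the post's (100, 200, 300), same contraction
A = np.random.rand(20, 30, 40)
B = A

C_default = np.einsum('ikm,jkm->ijk', A, B)
C_opt = np.einsum('ikm,jkm->ijk', A, B, optimize=True)

# optimize changes only the evaluation strategy, not the result
assert np.allclose(C_default, C_opt)
```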
===
I vaguely recall another SO question about matmul taking a shortcut when the two arrays are the same thing, A@A. I used B=A in these tests.
In [350]: timeit (A.transpose(1,0,2)@B.transpose(1,2,0)).transpose(1,2,0).shape
60.6 ms ± 1.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [352]: B2=np.random.rand(100,200,300)
In [353]: timeit (A.transpose(1,0,2)@B2.transpose(1,2,0)).transpose(1,2,0).shape
97.4 ms ± 164 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
But that only made a modest difference.
In [356]: np.__version__
Out[356]: '1.16.4'
My BLAS etc. is the standard Linux build, nothing special.