与 Java 一样高效。具体而言,假设我正在谈论一个简单的三重循环、单精度、连续列主要布局(float[],而不是 float[][])和大小为 1000x1000 的矩阵,以及一个单核 CPU。(如果您每个周期进行 0.5-2 次浮点运算,那么您可能就在球场上)
所以像
public class MatrixProd {
static float[] matProd(float[] a, int ra, int ca, float[] b, int rb, int cb) {
if (ca != rb) {
throw new IllegalArgumentException("Matrices not fitting");
}
float[] c = new float[ra*cb];
for(int i = 0; i < ra; ++i) {
for(int j = 0; j < cb; ++j) {
float sum = 0;
for(int k = 0; k < ca; ++k) {
sum += a[i*ca+k]*b[k*cb+j];
}
c[i*cb+j] = sum;
}
}
return c;
}
static float[] mkMat(int rs, int cs, float x, float d) {
float[] arr = new float[rs*cs];
for(int i = 0; i < rs; ++i) {
for(int j = 0; j < cs; ++j) {
arr[i*cs+j] = x;
x += d;
}
}
return arr;
}
public static void main(String[] args) {
int sz = 100;
float strt = -32, del = 0.0625f;
if (args.length > 0) {
sz = Integer.parseInt(args[0]);
}
if (args.length > 1) {
strt = Float.parseFloat(args[1]);
}
if (args.length > 2) {
del = Float.parseFloat(args[2]);
}
float[] a = mkMat(sz,sz,strt,del);
float[] b = mkMat(sz,sz,strt-16,del);
System.out.println(a[sz*sz-1]);
System.out.println(b[sz*sz-1]);
long t0 = System.currentTimeMillis();
float[] c = matProd(a,sz,sz,b,sz,sz);
System.out.println(c[sz*sz-1]);
long t1 = System.currentTimeMillis();
double dur = (t1-t0)*1e-3;
System.out.println(dur);
}
}
我想?(我在编码之前没有正确阅读规范,所以布局是行主要的,但由于访问模式是相同的,这不会像混合布局那样产生影响,所以我假设没关系。)
我没有花任何时间思考一个聪明的算法或低级优化技巧(无论如何我在 Java 中不会有太多收获)。我只是写了简单的循环,因为
我不希望这听起来像是一个挑战,但请注意,Java 可以轻松满足上述所有要求
这就是 Java轻松提供的东西,所以我会接受它。
(如果您每个周期进行 0.5-2 次浮点运算,那么您可能就在球场上)
恐怕在 Java 和 Haskell 中都没有。通过简单的三重循环,太多的缓存未命中无法达到该吞吐量。
在 Haskell 中做同样的事情,再次没有考虑聪明,一个简单直接的三重循环:
{-# LANGUAGE BangPatterns #-}
module MatProd where
import Data.Array.ST
import Data.Array.Unboxed
matProd :: UArray Int Float -> Int -> Int -> UArray Int Float -> Int -> Int -> UArray Int Float
matProd a ra ca b rb cb =
let (al,ah) = bounds a
(bl,bh) = bounds b
{-# INLINE getA #-}
getA i j = a!(i*ca + j)
{-# INLINE getB #-}
getB i j = b!(i*cb + j)
{-# INLINE idx #-}
idx i j = i*cb + j
in if al /= 0 || ah+1 /= ra*ca || bl /= 0 || bh+1 /= rb*cb || ca /= rb
then error $ "Matrices not fitting: " ++ show (ra,ca,al,ah,rb,cb,bl,bh)
else runSTUArray $ do
arr <- newArray (0,ra*cb-1) 0
let outer i j
| ra <= i = return arr
| cb <= j = outer (i+1) 0
| otherwise = do
!x <- inner i j 0 0
writeArray arr (idx i j) x
outer i (j+1)
inner i j k !y
| ca <= k = return y
| otherwise = inner i j (k+1) (y + getA i k * getB k j)
outer 0 0
mkMat :: Int -> Int -> Float -> Float -> UArray Int Float
mkMat rs cs x d = runSTUArray $ do
let !r = rs - 1
!c = cs - 1
{-# INLINE idx #-}
idx i j = cs*i + j
arr <- newArray (0,rs*cs-1) 0
let outer i j y
| r < i = return arr
| c < j = outer (i+1) 0 y
| otherwise = do
writeArray arr (idx i j) y
outer i (j+1) (y + d)
outer 0 0 x
和调用模块
module Main (main) where
import System.Environment (getArgs)
import Data.Array.Unboxed
import System.CPUTime
import Text.Printf
import MatProd
main :: IO ()
main = do
args <- getArgs
let (sz, strt, del) = case args of
(a:b:c:_) -> (read a, read b, read c)
(a:b:_) -> (read a, read b, 0.0625)
(a:_) -> (read a, -32, 0.0625)
_ -> (100, -32, 0.0625)
a = mkMat sz sz strt del
b = mkMat sz sz (strt - 16) del
print (a!(sz*sz-1))
print (b!(sz*sz-1))
t0 <- getCPUTime
let c = matProd a sz sz b sz sz
print $ c!(sz*sz-1)
t1 <- getCPUTime
printf "%.6f\n" (fromInteger (t1-t0)*1e-12 :: Double)
所以我们用两种语言做几乎完全相同的事情。用 编译 Haskell -O2
,用 javac 编译 Java
$ java MatrixProd 1000 "-13.7" 0.013
12915.623
12899.999
8.3592897E10
8.193
$ ./vmmult 1000 "-13.7" 0.013
12915.623
12899.999
8.35929e10
8.558699
结果时间非常接近。
如果我们将 Java 代码编译为本机,使用gcj -O3 -Wall -Wextra --main=MatrixProd -fno-bounds-check -fno-store-check -o jmatProd MatrixProd.java
,
$ ./jmatProd 1000 "-13.7" 0.013
12915.623
12899.999
8.3592896512E10
8.215
仍然没有太大的区别。
作为特殊奖励,C 中的相同算法(gcc -O3):
$ ./cmatProd 1000 "-13.7" 0.013
12915.623047
12899.999023
8.35929e+10
8.079759
因此,当涉及使用浮点数的计算密集型任务时,这表明简单的 Java 和简单的 Haskell 之间没有根本区别(在处理中到大数的整数运算时,GHC 对 GMP 的使用使得 Haskell 的性能大大优于 Java 的 BigInteger对于许多任务,但这当然是库问题,而不是语言问题),并且使用这种算法,两者都接近 C。
不过,平心而论,这是因为访问模式每隔纳秒就会导致一次缓存未命中,所以在所有三种语言中,这种计算都是受内存限制的。
如果我们通过将行优先矩阵与列优先矩阵相乘来改进访问模式,一切都会变得更快,gcc 编译的 C 完成它 1.18 秒,java 需要 1.23 秒,ghc 编译的 Haskell 大约需要 5.8 秒,这可以通过使用 llvm 后端减少到 3 秒。
在这里,数组库的范围检查真的很痛苦。使用未经检查的数组访问(应该在检查错误之后,因为检查已经在控制循环的代码中完成),GHC 的本机后端在 2.4 秒内完成,通过 llvm 后端让计算在 1.55 秒内完成,这是不错的,虽然比 C 和 Java 慢得多。使用原语GHC.Prim
而不是数组库,llvm 后端生成的代码可以在 1.16 秒内运行(同样,每次访问都没有边界检查,但是在这种情况下,之前可以很容易地证明在计算过程中只生成有效的索引,所以在这里,没有牺牲内存安全¹;检查每次访问使时间达到 1.96 秒,仍然明显优于数组的边界检查图书馆)。
底线:GHC 需要(更多)更快的边界检查分支,优化器还有改进的空间,但原则上,“Haskell 的方法(在类型系统中编码的纯度)与效率、内存安全和简单性兼容“,我们只是还没到那里。就目前而言,必须决定自己愿意牺牲多少。
¹ 是的,这是一种特殊情况,通常省略边界检查确实会牺牲内存安全,或者至少很难证明它没有。