3

我正在尝试用 Haskell解决Project Euler 问题#92 。我最近开始学习 Haskell。这是我尝试用 Haskell 解决的第一个 Project Euler 问题,但我的代码即使在 10 分钟内也不会终止。我知道你没有直接给我答案,但我再次警告我用 c++ 找到答案没有给出欧拉的答案或解决欧拉的新逻辑。我只是好奇为什么那个家伙工作不快,我应该怎么做才能让它更快?

{--EULER 92--}
import Data.List


myFirstFunction 1 = 0
myFirstFunction 89 = 1
myFirstFunction x= myFirstFunction (giveResult x)

giveResult 0 = 0
giveResult x = (square (mod x 10)) + (giveResult (div x 10))

square x = x*x


a=[1..10000000]


main = putStrLn(show (sum (map myFirstFunction a))) 
4

2 回答 2

23

最大的加速当然可以通过使用更好的算法来获得。不过,我不会在这里深入探讨。

原始算法调整

因此,让我们专注于改进使用的算法,而不是真正改变它。

  1. 您永远不会给出任何类型签名,因此类型默认为任意精度Integer。这里的所有内容都可以轻松放入 中Int,没有溢出的危险,所以让我们使用它。添加类型签名会有所myFirstFunction :: Int -> Int帮助:时间从Total time 13.77s ( 13.79s elapsed)to下降Total time 6.24s ( 6.24s elapsed),总分配下降约 15 倍。对于这样一个简单的更改来说还不错。

  2. 您使用divmod。这些总是计算一个非负余数和相应的商,所以他们需要一些额外的检查,以防涉及一些负数。功能quotrem映射到机器除法指令,它们不涉及此类检查,因此速度更快。如果您通过 LLVM 后端-fllvm(现在时间:Total time 1.56s ( 1.56s elapsed).

  3. 代替单独使用quotand rem,让我们使用quotRem同时​​计算两者的函数,这样我们就不会重复除法(即使乘法+移位需要一点时间):

    giveResult x = case x `quotRem` 10 of
                     (q,r) -> r*r + giveResult q
    

    这并没有太大的收获,而是一点点:Total time 1.49s ( 1.49s elapsed)

  4. 您正在使用 list a = [1 .. 10000000],以及map该列表上的函数,然后sum是结果列表。这是惯用的,简洁的和简短的,但不是超级快,因为分配所有这些列表单元并垃圾收集它们也需要时间 - 不是很多,因为 GHC非常擅长,但将其转换为循环

    main = print $ go 0 1
        where
            go acc n
                | n > 10000000 = acc
                | otherwise    = go (acc + myFirstFunction n) (n+1)
    

    仍然让我们有一点收获:Total time 1.34s ( 1.34s elapsed)分配从880,051,856 bytes allocated in the heap最后一个列表版本下降到51,840 bytes allocated in the heap.

  5. giveResult是递归的,因此不能内联。同样适用于myFirstFunction,因此每次计算都需要两个函数调用(至少)。我们可以通过重写giveResult非递归包装器和递归本地循环来避免这种情况,

    giveResult x = go 0 x
        where
            go acc 0 = acc
            go acc n = case n `quotRem` 10 of
                         (q,r) -> go (acc + r*r) q
    

    这样就可以内联:Total time 1.04s ( 1.04s elapsed).

这些是最明显的观点,进一步的改进——除了哈马尔在评论中提到的记忆——需要一些思考。

我们现在在

module Main (main) where

myFirstFunction :: Int -> Int
myFirstFunction 1 = 0
myFirstFunction 89 = 1
myFirstFunction x= myFirstFunction (giveResult x)

giveResult :: Int -> Int
giveResult x = go 0 x
    where
        go acc 0 = acc
        go acc n = case n `quotRem` 10 of
                     (q,r) -> go (acc + r*r) q

main :: IO ()
main = print $ go 0 1
    where
        go acc n
            | n > 10000000 = acc
            | otherwise    = go (acc + myFirstFunction n) (n+1)

使用-O2 -fllvm, 在这里运行需要 1.04 秒,但使用本机代码生成器(仅-O2)需要 3.5 秒。这种差异是由于 GHC 本身不会将除法转换为乘法和位移。如果我们手动完成,我们可以从本机代码生成器中获得几乎相同的性能。

因为我们知道编译器不知道的事情,即我们在这里从不处理负数,并且数字不会变大,所以我们甚至可以生成更好的乘法和移位(这会产生负数或负数的错误结果)大红利)比编译器花费时间减少到 0.9 秒的本地代码生成器和 0.73 秒的 LLVM 后端:

import Data.Bits

qr10 :: Int -> (Int, Int)
qr10 n = (q, r)
  where
    q = (n * 0x66666667) `unsafeShiftR` 34
    r = n - 10 * q

注意:这要求Int是 64 位类型,它不适用于 32 位Ints,它会产生错误的结果为负数n,并且乘法将溢出为大n。我们正在进入肮脏的黑客领域。Word我们可以通过使用而不是来减轻肮脏Int,只留下溢出(这不会发生在n <= 10737418236Wordresp n <= 5368709118for 中Int,所以在这里我们很舒服地处于安全区域)。时间不受影响。

对应的C程序

#include <stdio.h>

unsigned int myFirstFunction(unsigned int i);
unsigned int giveResult(unsigned int i);

int main(void) {
    unsigned int sum = 0;
    for(unsigned int i = 1; i <= 10000000; ++i) {
        sum += myFirstFunction(i);
    }
    printf("%u\n",sum);
    return 0;
}

unsigned int myFirstFunction(unsigned int i) {
    if (i == 1) return 0;
    if (i == 89) return 1;
    return myFirstFunction(giveResult(i));
}

unsigned int giveResult(unsigned int i) {
    unsigned int acc = 0, r, q;
    while(i) {
        q = (i*0x66666667UL) >> 34;
        r = i - q*10;
        i = q;
        acc += r*r;
    }
    return acc;
}

执行类似,使用 编译gcc -O3,运行时间为 0.78 秒,使用clang -O30.71 运行。

在不改变算法的情况下,这几乎是结束了。


记忆

现在,算法的一个小变化是记忆。如果我们为数字建立一个查找表<= 7*9²,我们只需要对每个数字的数字平方和进行一次计算,而不是迭代直到我们达到 1 或 89,所以让我们记住,

module Main (main) where

import Data.Array.Unboxed
import Data.Array.IArray
import Data.Array.Base (unsafeAt)
import Data.Bits

qr10 :: Int -> (Int, Int)
qr10 n = (q, r)
  where
    q = (n * 0x66666667) `unsafeShiftR` 34
    r = n - 10 * q

digitSquareSum :: Int -> Int
digitSquareSum = go 0
  where
    go acc 0 = acc
    go acc n = case qr10 n of
                 (q,r) -> go (acc + r*r) q

table :: UArray Int Int
table = array (0,567) $ assocs helper
  where
    helper :: Array Int Int
    helper = array (0,567) [(i, f i) | i <- [0 .. 567]]
    f 0 = 0
    f 1 = 0
    f 89 = 1
    f n = helper ! digitSquareSum n

endPoint :: Int -> Int
endPoint n = table `unsafeAt` digitSquareSum n

main :: IO ()
main = print $ go 0 1
  where
    go acc n
        | n > 10000000 = acc
        | otherwise    = go (acc + endPoint n) (n+1)

手动进行记忆而不是使用库会使代码更长,但我们可以根据需要对其进行调整。我们可以使用未装箱的数组,并且可以省略对数组访问的边界检查。两者都显着加快了计算速度。本机代码生成器的时间现在是 0.18 秒,而 LLVM 后端的时间是 0.13 秒。对应的 C 程序用 0.16 秒编译,用gcc -O30.145 秒编译clang -O3(Haskell 比 C 好,w00t!)。


缩放和更好算法的提示

然而,所使用的算法并不能很好地扩展,比线性算法差一点,并且对于 10 8的上限(具有适当调整的记忆限制),它在 1.5 秒(ghc -O2 -fllvm)内运行,分别。1.64 秒 ( clang -O3) 和 1.87 秒 ( gcc -O3) [本机代码生成器为 2.02 秒]。

使用另一种算法,通过将这些数字划分为数字平方和来计算序列以 1 结尾的数字(唯一直接产生 1 的数字是 10 的幂。我们可以写

10 = 1×3² + 1×1²
10 = 2×2² + 2×1²
10 = 1×2² + 6×1²
10 = 10×1²

从第一个,我们得到 13, 31, 103, 130, 301, 310, 1003, 1030, 1300, 3001, 3010, 3100, ... 从第二个,我们得到 1122, 1212, 1221, 2112, 2121, 2211 , 11022, 11202, ... 从第三个 1111112, 1111121, ...

只有 13, 31, 103, 130, 301, 310 是可能的数字的平方和<= 10^10,所以只有那些需要进一步研究。我们可以写

100 = 1×9² + 1×4² + 3×1²
...
100 = 1×8² + 1×6²
...

这些分区中的第一个不生成孩子,因为它需要五个非零数字,另一个明确给出的生成两个孩子 68 和 86(如果限制是 10 8,也是 608 ,更大的限制更多)),我们可以获得更好的缩放和更快的算法。

我在解决这个问题时写的相当未优化的程序运行(输入是极限的 10 的指数)

$ time ./problem92 7
8581146

real    0m0.010s
user    0m0.008s
sys     0m0.002s
$ time ./problem92 8
85744333

real    0m0.022s
user    0m0.018s
sys     0m0.003s
$ time ./problem92 9
854325192

real    0m0.040s
user    0m0.033s
sys     0m0.006s
$ time ./problem92 10
8507390852

real    0m0.074s
user    0m0.069s
sys     0m0.004s

在不同的联赛。

于 2013-05-01T18:53:54.960 回答
9

首先,我冒昧地清理了您的代码:

endsAt89 1  = 0
endsAt89 89 = 1
endsAt89 n  = endsAt89 (sumOfSquareDigits n)

sumOfSquareDigits 0 = 0
sumOfSquareDigits n = (n `mod` 10)^2 + sumOfSquareDigits (n `div` 10)    

main = print . sum $ map endsAt89 [1..10^7]

在我蹩脚的上网本上是 1 分 13 秒。让我们看看我们是否可以改进它。

由于数字很小,我们可以从使用 machine-sizedInt而不是 absolute-size开始Integer。这只是添加类型签名的问题,例如

sumOfSquareDigits :: Int -> Int

这将运行时间大幅缩短至 20 秒。

由于数字都是正数,我们可以用稍快的and替换divand ,或者甚至同时用替换两者:modquotremquotRem

sumOfSquareDigits :: Int -> Int
sumOfSquareDigits 0 = 0
sumOfSquareDigits n = r^2 + sumOfSquareDigits q
  where (q, r) = quotRem x 10

运行时间现在为 17 秒。使其尾递归再减少一秒钟:

sumOfSquareDigits :: Int -> Int
sumOfSquareDigits n = loop n 0
  where
    loop 0 !s = s
    loop n !s = loop q (s + r^2)
      where (q, r) = quotRem n 10

为了进一步改进,我们可以注意到对于给定的输入数字sumOfSquareDigits最多返回567 = 7 * 9^2,因此我们可以记忆小数字以减少所需的迭代次数。这是我的最终版本(使用data-memocombinators包进行记忆):

{-# LANGUAGE BangPatterns #-}
import qualified Data.MemoCombinators as Memo

endsAt89 :: Int -> Int
endsAt89 = Memo.arrayRange (1, 7*9^2) endsAt89'
  where
    endsAt89' 1  = 0
    endsAt89' 89 = 1
    endsAt89' n  = endsAt89 (sumOfSquareDigits n)

sumOfSquareDigits :: Int -> Int
sumOfSquareDigits n = loop n 0
  where
    loop 0 !s = s
    loop n !s = loop q (s + r^2)
      where (q, r) = quotRem n 10

main = print . sum $ map endsAt89 [1..10^7]

这在我的机器上运行不到 9 秒。

于 2013-05-01T18:37:24.637 回答