2

我正在研究Euler #14项目,并有一个解决方案来获得答案,但是当我尝试运行代码时出现堆栈空间溢出错误。该算法在交互式 GHCI 中运行良好(在低数字上),但当我向它抛出一个非常大的数字并尝试编译它时,它就不起作用了。

这是它在交互式 GHCI 中的作用的粗略概念。在我的电脑上计算“答案 50000”大约需要 10 秒。

让 GHCI 运行几分钟后,它会吐出正确的答案。

*Euler System.IO> answer 1000000
    (525,837799)

但这并不能解决编译程序以本地运行时的堆栈溢出错误。

*Euler System.IO> answer 10
    (20,9)
*Euler System.IO> answer 100
    (119,97)
*Euler System.IO> answer 1000
    (179,871)
*Euler System.IO> answer 10000
    (262,6171)
*Euler System.IO> answer 50000
    (324,35655)

我应该怎么做才能得到“答案 1000000”的答案?我想我的算法需要微调一下,但我不知道如何去做。

代码:

module Main
    where

import System.IO
import Control.Monad

main = print (answer 1000000)

-- Count the length of the sequences
-- count' creates a tuple with the second value
-- being the starting number of the game
-- and the first value being the total 
-- length of the chain
count' n = (cSeq n, n)
cSeq n = length $ game n

-- Find the maximum chain value of the game
answer n = maximum $ map count' [1..n]

-- Working game. 
-- game 13 = [13,40,20,10,5,16,8,4,2,1]
game n = n : play n
play x
    | x <= 0 = []                               -- is negative or 0
    | x == 1 = []                               -- is 1
    | even x = doEven x : play ((doEven x))     -- even
    | otherwise = doOdd x : play ((doOdd x))    -- odd
  where doOdd x = (3 * x) + 1
        doEven  x = (x `div` 2)
4

3 回答 3

4

这里的问题是maximum太懒了。它没有跟踪最大的元素,而是建立了一个巨大的maxthunk 树。这是因为maximum是根据 定义的foldl,所以评估如下:

maximum [1, 2, 3, 4, 5]
foldl max 1 [2, 3, 4, 5]
foldl max (max 1 2) [3, 4, 5]
foldl max (max (max 1 2) 3) [4, 5]
foldl max (max (max (max 1 2) 3) 4) [5]
foldl max (max (max (max (max 1 2) 3) 4) 5) []
max (max (max (max 1 2) 3) 4) 5  -- this expression will be huge for large lists

试图评估太多这些嵌套max调用会导致堆栈溢出。

解决方案是通过使用严格的版本来强制它评估这些foldl',(或者,在这种情况下,它的表亲foldl1')。max这可以通过在每个步骤中减少它们来防止's 堆积:

foldl1' max [1, 2, 3, 4, 5]
foldl' max 1 [2, 3, 4, 5]
foldl' max 2 [3, 4, 5]
foldl' max 3 [4, 5]
foldl' max 4 [5]
foldl' max 5 []
5

GHC 通常可以自行解决这些类型的问题,如果您使用-O2它(除其他外)对您的程序运行严格性分析进行编译。但是,我认为编写不需要依赖优化来工作的程序是一种很好的做法。

注意:修复此问题后,生成的程序仍然很慢。您可能想考虑使用 memoization 来解决这个问题。

于 2012-10-05T08:12:49.407 回答
4

@hammar 已经指出了太懒的问题maximum,以及如何解决这个问题(使用foldl1'的严格版本foldl1)。

但是代码中还有进一步的低效率。

cSeq n = length $ game n

cSeq让我们game构造一个列表,只计算它的长度。不幸的是,length不是一个“好消费者”,所以中间列表的构建并没有被融合掉。这是相当多的不必要的分配和花费时间。消除这些列表

cSeq n = coll (1 :: Int) n
  where
    coll acc 1 = acc
    coll acc m
      | even m    = coll (acc + 1) (m `div` 2)
      | otherwise = coll (acc + 1) (3*m+1)

减少了大约 65% 的分配和大约 20% 的运行时间(仍然很慢)。下一点,您正在使用div,它在正常除法之外执行符号检查。由于所涉及的所有数字都是正数,因此使用quotinstead 确实会加快速度(这里不多,但稍后会变得很重要)。

下一个重点是,由于您没有给出类型签名,因此数字的类型(除了在我的重写中由使用length或由表达式类型签名确定的地方)是. 上的操作比 上的相应操作慢得多,因此如果可能,您应该使用(or ) 而不是在速度很重要的时候。如果您有 64 位 GHC,则足以进行这些计算,使用 时,运行时间减少约一半,使用时减少约 70% ,使用本机代码生成器时,使用 LLVM 后端时,运行时间使用时减少约 70%,使用时减少约 95% 。(1 :: Int)IntegerIntegerIntIntWordIntegerIntdivquotdivquot

本机代码生成器和 LLVM 后端之间的差异主要是由于一些基本的低级优化。

even并且odd被定义

even, odd       :: (Integral a) => a -> Bool
even n          =  n `rem` 2 == 0
odd             =  not . even

GHC.Real. 当类型为Int时,LLVM 知道将用于确定模数的除以 2 替换为按位和 ( n .&. 1 == 0)。本机代码生成器(尚未)执行许多这些低级优化。如果您手动执行此操作,则 NCG 和 LLVM 后端生成的代码执行几乎相同。

使用div时,NCG 和 LLVM 都无法用短的移位和加法序列替换除法,因此您会通过符号测试获得相对较慢的机器除法指令。使用quot,两者都可以做到这一点Int,因此您可以获得更快的代码。

所有出现的数字都是正数的知识允许我们用简单的右移替换除以 2,而无需任何代码来纠正负参数,这将 LLVM 后端生成的代码再加速约 33%,奇怪的是它没有对 NCG 没有影响。

因此,从最初花费了 8 秒加/减一点(NCG 少一点,LLVM 后端多一点)的原始版本,我们已经去了

module Main (main)
    where

import Data.List
import Data.Bits

main = print (answer (1000000 :: Int))

-- Count the length of the sequences
-- count' creates a tuple with the second value
-- being the starting number of the game
-- and the first value being the total 
-- length of the chain
count' n = (cSeq n, n)
cSeq n = go (1 :: Int) n
  where
    go !acc 1 = acc
    go acc m
        | even' m   = go (acc+1) (m `shiftR` 1)
        | otherwise = go (acc+1) (3*m+1)

even' :: Int -> Bool
even' m = m .&. 1 == 0

-- Find the maximum chain value of the game
answer n = foldl1' max $ map count' [1..n]

在我的设置中,使用 NCG 需要 0.37 秒,使用 LLVM 后端需要 0.27 秒。

运行时间有微小的改进,但可以通过foldl1' max手动递归替换 来获得巨大的分配减少,

answer n = go 1 1 2
  where
    go ml mi i
        | n < i     = (ml,mi)
        | l > ml    = go l i (i+1)
        | otherwise = go ml mi (i+1)
          where
            l = cSeq i

这使其分别为0.35。0.25 秒(并产生一个微小的52,936 bytes allocated in the heap)。

现在,如果这仍然太慢,您可以考虑一个好的记忆策略。我知道的最好的(1)是使用未装箱的数组来存储不超过限制的数字的链长度,

{-# LANGUAGE BangPatterns #-}
module Main (main) where

import System.Environment (getArgs)
import Data.Array.ST
import Data.Array.Base
import Control.Monad.ST
import Data.Bits

main :: IO ()
main = do
    args <- getArgs
    let bd = case args of
               a:_ -> read a
               _   -> 100000
    print $ mxColl bd

mxColl :: Int -> (Int,Int)
mxColl bd = runST $ do
    arr <- newArray (0,bd) 0
    unsafeWrite arr 1 1
    goColl arr bd 1 1 2

goColl :: STUArray s Int Int -> Int -> Int -> Int -> Int -> ST s (Int,Int)
goColl arr bd ms ml i
    | bd < i    = return (ms,ml)
    | otherwise = do
        nln <- collatzLength arr bd i
        if ml < nln
          then goColl arr bd i nln (i+1)
          else goColl arr bd ms ml (i+1)

collatzLength :: STUArray s Int Int -> Int -> Int -> ST s Int
collatzLength arr bd n = go 1 n
  where
    go !l 1 = return l
    go l m
        | bd < m    = go (l+1) $ case m .&. 1 of
                                   0 -> m `shiftR` 1
                                   _ -> 3*m+1
        | otherwise = do
            l' <- unsafeRead arr m
            case l' of
              0 -> do
                  l'' <- go 1 $ case m .&. 1 of
                                  0 -> m `shiftR` 1
                                  _ -> 3*m+1
                  unsafeWrite arr m (l''+1)
                  return (l + l'')
              _ -> return (l+l'-1)

当使用 NCG 编译时,它在 0.04 秒内完成了 1000000 的限制,使用 LLVM 后端编译为 0.05 (显然,这在优化STUArray代码方面不如 NCG 好)。

如果您没有 64 位 GHC,则不能简单地使用Int,因为对于某些输入会溢出。但是计算的绝大部分仍然是在Int范围内执行的,所以你应该尽可能使用它,并且只移动到Integer需要的地方。

switch :: Int
switch = (maxBound - 1) `quot` 3

back :: Integer
back = 2 * fromIntegral (maxBound :: Int)

cSeq :: Int -> Int
cSeq n = goInt 1 n
  where
    goInt acc 1      = acc
    goInt acc m
      | m .&. 1 == 0 = goInt (acc+1) (m `shiftR` 1)
      | m > switch   = goInteger (acc+1) (3*toInteger m + 1)
      | otherwise    = goInt (acc+1) (3*m+1)
    goInteger acc m
      | fromInteger m .&. (1 :: Int) == 1 = goInteger (acc+1) (3*m+1)
      | m > back  = goInteger (acc+1) (m `quot` 2)  -- yup, quot is faster than shift for Integer here
      | otherwise = goInt (acc + 1) (fromInteger $ m `quot` 2)

使得优化循环变得更加困难,因此它比使用 的单循环慢Int,但仍然不错。在这里(循环从不运行),使用 NCG 需要 0.42 秒,使用 LLVM 后端需要 0.37 秒(这与在纯版本Integer中使用几乎相同)。quotInt

对记忆版本使用类似的技巧会产生类似的结果,它比纯Int版本慢得多,但与未记忆版本相比仍然快得惊人。


(1)对于这个特殊的(类型)问题,您需要记住一系列连续参数的结果。对于其他问题,一个Map或其他一些数据结构将是更好的选择。

于 2012-10-05T13:41:29.427 回答
0

正如已经指出的那样,该函数似乎是罪魁祸首,但是如果您使用标志maximum编译程序,则不必担心它。-O2

该程序仍然很慢,这是因为该问题应该教您有关记忆的知识。这样做的一种好方法是 haskell 是使用Data.Memocombinators

import Data.MemoCombinators
import Control.Arrow
import Data.List
import Data.Ord
import System.Environment

play m = maximumBy (comparing snd) . map (second threeNPuzzle) $ zip [1..] [1..m]
  where
    threeNPuzzle = arrayRange (1,m) memoized
    memoized n 
      | n == 1 = 1
      | odd n  = 1 + threeNPuzzle (3*n + 1)
      | even n = 1 + threeNPuzzle (n `div` 2)

main = getArgs >>= print . play . read . head

上面的程序-O2在我的机器上编译时运行不到一秒钟。

请注意,在这种情况下,记住 threeNPuzzle 找到的所有值并不是一个好主意,上面的程序会记住这些值直到限制(问题中为 1000000)。

于 2012-10-05T11:52:54.083 回答