
I posted this question a few days ago, "Haskell performance using dynamic programming", and it was recommended that I use ByteStrings instead of Strings. After reimplementing the algorithm with ByteStrings, the program crashes because it exceeds the memory limit.

import Control.Monad
import Data.Array.IArray
import qualified Data.ByteString as B

main = do
  n <- readLn
  pairs <- replicateM n $ do
    s1 <- B.getLine
    s2 <- B.getLine
    return (s1,s2)
  mapM_ (print . editDistance) pairs

editDistance :: (B.ByteString, B.ByteString) -> Int
editDistance (s1, s2) = dynamic editDistance' (B.length s1, B.length s2)
  where
    editDistance' table (i,j)
      | min i j == 0 = max i j
      | otherwise = min' (table!((i-1),j) + 1) (table!(i,(j-1)) + 1) (table!((i-1),(j-1)) + cost)
      where
        cost =  if B.index s1 (i-1) == B.index s2 (j-1) then 0 else 1
        min' a b = min (min a b)

dynamic :: (Array (Int,Int) Int -> (Int,Int) -> Int) -> (Int,Int) -> Int
dynamic compute (xBnd, yBnd) = table!(xBnd,yBnd)
  where
    table = newTable $ map (\coord -> (coord, compute table coord)) [(x,y) | x<-[0..xBnd], y<-[0..yBnd]]
    newTable xs = array ((0,0),fst (last xs)) xs

The memory consumption appears to scale with n, and the input strings are 1000 characters long. I would expect Haskell to free all the memory used by editDistance after each result is printed. Is this not the case? If not, how can I force it?

The only other real computation I see is cost, but forcing it with seq made no difference.


2 Answers


Your memory use will certainly grow with n if you read all n inputs before computing any results or printing any output. You could try interleaving the input and output operations:

main = do
  n <- readLn
  replicateM_ n $ do
    s1 <- B.getLine
    s2 <- B.getLine
    print (editDistance (s1,s2))
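As a sketch, here is the interleaved main combined with the question's editDistance into one self-contained program. It differs from the question in three small ways: it imports Data.ByteString.Char8 (so B.pack accepts String literals), it passes the table bounds straight to array instead of deriving them with last, and it writes the three-way minimum with minimum instead of min'.

```haskell
import Control.Monad (replicateM_)
import Data.Array.IArray (Array, array, (!))
import qualified Data.ByteString.Char8 as B

-- Read n, then handle one pair of lines at a time, printing each result
-- before the next pair is read, so earlier pairs can be garbage collected.
main :: IO ()
main = do
  n <- readLn
  replicateM_ n $ do
    s1 <- B.getLine
    s2 <- B.getLine
    print (editDistance (s1, s2))

editDistance :: (B.ByteString, B.ByteString) -> Int
editDistance (s1, s2) = dynamic step (B.length s1, B.length s2)
  where
    step table (i, j)
      | min i j == 0 = max i j
      | otherwise =
          minimum
            [ table ! (i - 1, j) + 1         -- deletion
            , table ! (i, j - 1) + 1         -- insertion
            , table ! (i - 1, j - 1) + cost  -- match or substitution
            ]
      where
        cost = if B.index s1 (i - 1) == B.index s2 (j - 1) then 0 else 1

-- Lazily memoised table: each cell is a thunk that may refer to other cells.
-- The bounds are given directly, so no extra pass over the list is needed.
dynamic :: (Array (Int, Int) Int -> (Int, Int) -> Int) -> (Int, Int) -> Int
dynamic compute (xBnd, yBnd) = table ! (xBnd, yBnd)
  where
    table = array ((0, 0), (xBnd, yBnd))
      [ ((x, y), compute table (x, y)) | x <- [0 .. xBnd], y <- [0 .. yBnd] ]
```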

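Alternatively, using lazy IO (untested). Each test case spans two lines, so 2*n lines are needed, and reading with the lazy ByteString functions avoids going through String:

{-# LANGUAGE LambdaCase #-}
import Data.List (unfoldr)
import qualified Data.ByteString.Lazy.Char8 as BL

main = do
  n <- readLn
  cont <- BL.getContents
  let lns = take (2 * n) (BL.lines cont)   -- two lines per test case
      pairs = unfoldr (\case (x:y:rs) -> Just ((BL.toStrict x, BL.toStrict y), rs) ; _ -> Nothing) lns
  mapM_ (print . editDistance) pairs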
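The unfoldr step that groups consecutive lines into pairs can be checked on its own. This is a small sketch with a hypothetical helper name, pairUp, written with a named local function instead of \case so it needs no LambdaCase extension:

```haskell
import Data.List (unfoldr)

-- Group a list into consecutive (first, second) pairs.
-- A trailing unpaired element is dropped.
pairUp :: [a] -> [(a, a)]
pairUp = unfoldr step
  where
    step (x : y : rest) = Just ((x, y), rest)
    step _              = Nothing

main :: IO ()
main = print (pairUp ["a", "b", "c", "d", "e"])  -- [("a","b"),("c","d")]
```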

EDIT: Other possible savings include using an unboxed array, and not forcing the whole list of strLen^2 elements via last during array construction. Consider array ((0,0),(xBnd,yBnd)) xs instead.
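One caveat with the unboxed-array suggestion: a UArray cannot simply replace Array inside dynamic, because unboxed elements are evaluated as the array is built, so a table defined in terms of itself cannot be constructed lazily. A common workaround, sketched here and not taken from the answer, is to fill an STUArray in row-major order and freeze it with runSTUArray. The name editDistanceU is hypothetical. For two 1000-character strings the table holds 1001 × 1001 Ints, roughly 8 MB on a 64-bit machine, with no thunks per cell.

```haskell
import Control.Monad (forM_)
import Data.Array.ST (newArray, readArray, writeArray, runSTUArray)
import Data.Array.Unboxed (UArray, (!))
import qualified Data.ByteString.Char8 as B

-- Edit distance with an unboxed table filled in row-major order.
-- Every cell is written as a plain Int, so no thunks build up in the table.
editDistanceU :: B.ByteString -> B.ByteString -> Int
editDistanceU s1 s2 = table ! (m, n)
  where
    m = B.length s1
    n = B.length s2
    table :: UArray (Int, Int) Int
    table = runSTUArray $ do
      t <- newArray ((0, 0), (m, n)) 0
      forM_ [0 .. m] $ \i -> writeArray t (i, 0) i   -- first column
      forM_ [0 .. n] $ \j -> writeArray t (0, j) j   -- first row
      forM_ [1 .. m] $ \i ->
        forM_ [1 .. n] $ \j -> do
          up   <- readArray t (i - 1, j)
          left <- readArray t (i, j - 1)
          diag <- readArray t (i - 1, j - 1)
          let cost = if B.index s1 (i - 1) == B.index s2 (j - 1) then 0 else 1
          writeArray t (i, j) (minimum [up + 1, left + 1, diag + cost])
      return t

main :: IO ()
main = print (editDistanceU (B.pack "kitten") (B.pack "sitting"))
```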

answered 2016-10-07T21:20:10.750

My feeling is that the problem is that your min' is not strict enough. Because it doesn't force its arguments, it simply builds up a thunk for each array element. This uses more memory, increases GC time, and so on.

I would try:

{-# LANGUAGE BangPatterns #-}

...
min' !a !b !c = min a (min b c)
answered 2016-10-07T20:58:06.200
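To push the strictness idea further, here is a sketch that swaps the full table for a row-by-row computation (a standard technique, not what either answer proposed). Only the previous row is kept alive, each row is fully forced before the next one is built, and min3 mirrors the answer's strict min'. The names min3, editDistanceRows, and forceList are hypothetical. Memory per pair is then proportional to the length of s2 rather than to the product of the two lengths.

```haskell
{-# LANGUAGE BangPatterns #-}
import qualified Data.ByteString.Char8 as B

-- Strict three-way minimum, in the spirit of the answer's min'.
min3 :: Int -> Int -> Int -> Int
min3 !a !b !c = min a (min b c)

-- Row-by-row edit distance: only the previous row is kept alive.
editDistanceRows :: B.ByteString -> B.ByteString -> Int
editDistanceRows s1 s2 = last (B.foldl' nextRow firstRow s1)
  where
    -- Row 0: distance from the empty prefix of s1 to each prefix of s2.
    firstRow = [0 .. B.length s2]
    -- Build row i from row i-1; c is the i-th character of s1.
    nextRow prev@(p0 : _) c = forceList row
      where
        row = scanl step (p0 + 1) (zip3 (B.unpack s2) prev (tail prev))
        step left (d, diag, up) =
          min3 (left + 1) (up + 1) (diag + (if d == c then 0 else 1))
    nextRow [] _ = []  -- unreachable: every row has at least one element
    -- Force every element so no chain of thunks survives into the next row.
    forceList xs = foldr seq () xs `seq` xs

main :: IO ()
main = print (editDistanceRows (B.pack "kitten") (B.pack "sitting"))
```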