8

我需要从更长的列表中随机抽取一个样本而不进行替换(每个元素在样本中只出现一次)。我正在使用下面的代码,但现在我想知道:

  1. 是否有执行此操作的库函数?
  2. 如何改进此代码?(我是 Haskell 初学者,所以即使有库函数也很有用)。

抽样的目的是能够将分析样本的结果推广到总体。

import System.Random

-- | Take a random sample without replacement of size size from a list.
takeRandomSample :: Int -> Int -> [a] -> [a]
takeRandomSample seed size xs
    | size < hi  = subset xs rs
    | otherwise = error "Sample size must be smaller than population."
    where
        rs = randomSample seed size lo hi
        lo = 0
        hi = length xs - 1

getOneRandomV g lo hi = randomR (lo, hi) g

rsHelper size lo hi g x acc
    | x `notElem` acc && length acc < size = rsHelper size lo hi new_g new_x (x:acc)
    | x `elem` acc && length acc < size = rsHelper size lo hi new_g new_x acc
    | otherwise = acc
    where (new_x, new_g) = getOneRandomV g lo hi

-- | Get a random sample without replacement of size size between lo and hi.
randomSample seed size lo hi = rsHelper size lo hi g x [] where
(x, g)  = getOneRandomV (mkStdGen seed) lo hi

subset l = map (l !!) 
4

2 回答 2

6

这是 Daniel Fischer 在评论中建议的快速“粗略”实现,使用我首选的 PRNG(mwc-random):

{-# LANGUAGE BangPatterns #-}

module Sample (sample) where

import Control.Monad.Primitive
import Data.Foldable (toList)
import qualified Data.Sequence as Seq
import System.Random.MWC

sample :: PrimMonad m => [a] -> Int -> Gen (PrimState m) -> m [a]
sample ys size = go 0 (l - 1) (Seq.fromList ys) where
    l = length ys
    go !n !i xs g | n >= size = return $! (toList . Seq.drop (l - size)) xs
                  | otherwise = do
                      j <- uniformR (0, i) g
                      let toI  = xs `Seq.index` j
                          toJ  = xs `Seq.index` i
                          next = (Seq.update i toI . Seq.update j toJ) xs
                      go (n + 1) (i - 1) next g
{-# INLINE sample #-}

这几乎是对 R 的内部 C 版本的(简洁的)功能重写,sample()因为它被称为无需替换。

sample只是一个递归工作函数的包装器,该函数递增地打乱总体,直到达到所需的样本大小,只返回那么多打乱的元素。像这样编写函数可确保 GHC 可以内联它。

它易于使用:

*Main> create >>= sample [1..100] 10
[51,94,58,3,91,70,19,65,24,53]

生产版本可能希望使用可变向量之类的东西,而不是Data.Sequence为了减少花费在 GC 上的时间。

于 2012-12-08T21:39:59.313 回答
2

我认为执行此操作的标准方法是使用前 N 个元素初始化一个固定大小的缓冲区,并且对于每个第 i 个元素,i >= N,执行以下操作:

  1. 选择一个介于 0 和 i 之间的随机数 j。
  2. 如果 j < N 则将缓冲区中的第 j 个元素替换为当前元素。

您可以通过归纳证明正确性:

如果您只有 N 个元素,这显然会生成一个随机样本(我假设顺序无关紧要)。现在假设直到第 i 个元素都是正确的。这意味着任何元素在缓冲区中的概率是 N/(i+1)(我从 0 开始计数)。

选取随机数后,第 i+1 个元素在缓冲区中的概率为 N/(i+2)(j 在 0 和 i+1 之间,其中 N 个元素最终在缓冲区中)。其他人呢?

P(k'th element is in the buffer after processing the i+1'th) =
P(k'th element was in the buffer before)*P(k'th element is not replaced) =
N/(i+1) * (1-1/(i+2)) =
N/(i+2)

这是一些使用标准(慢速)System.Random 在样本大小空间中执行此操作的代码。

import Control.Monad (when)                                                                                                       
import Data.Array                                                                                                                 
import Data.Array.ST                                                                                                              
import System.Random (RandomGen, randomR)                                                                                         

sample :: RandomGen g => g -> Int -> [Int] -> [Int]                                                                               
sample g size xs =                                                                                                                
  if size < length xs                                                                                                             
  then error "sample size must be >= input length"                                                                                
  else elems $ runSTArray $ do                                                                                                    
    arr <- newListArray (0, size-1) pre                                                                                         
    loop arr g size post                                                                                                          
  where                                                                                                                           
    (pre, post) = splitAt size xs                                                                                                 
    loop arr g i [] = return arr                                                                                                  
    loop arr g i (x:xt) = do                                                                                                      
      let (j, g') = randomR (0, i) g                                                                                              
      when (j < size) $ writeArray arr j x                                                                                        
      loop arr g' (i+1) xt                                                                                                        
于 2012-12-10T05:27:29.173 回答