这是一种方法(这是一个完整的要点)。我们坚持使用 Peano 数字而不是 GHC 的类型级别Nat
,只是因为归纳对它们更有效。
{-# LANGUAGE GADTs, PolyKinds, DataKinds, TypeOperators, FlexibleInstances, FlexibleContexts #-}
import Data.Foldable
import Text.PrettyPrint.HughesPJClass
data Nat = Z | S Nat
-- Some type synonyms that simplify uses of 'Nat'
type N0 = Z
type N1 = S N0
type N2 = S N1
type N3 = S N2
type N4 = S N3
type N5 = S N4
type N6 = S N5
type N7 = S N6
type N8 = S N7
type N9 = S N8
-- Similar to lists, but indexed over their length
data Vector (dim :: Nat) a where
Nil :: Vector Z a
(:-) :: a -> Vector n a -> Vector (S n) a
infixr 5 :-
data Tensor (dim :: [Nat]) a where
Scalar :: a -> Tensor '[] a
Tensor :: Vector d (Tensor ds a) -> Tensor (d : ds) a
为了显示这些类型,我们将使用pretty
包(GHC 已经自带)。
instance (Foldable (Vector n), Pretty a) => Pretty (Vector n a) where
pPrint = braces . sep . punctuate (text ",") . map pPrint . toList
instance Pretty a => Pretty (Tensor '[] a) where
pPrint (Scalar x) = pPrint x
instance (Pretty (Tensor ds a), Pretty a, Foldable (Vector d)) => Pretty (Tensor (d : ds) a) where
pPrint (Tensor xs) = pPrint xs
然后这里是Foldable
我们的数据类型的实例(这里没什么奇怪的 - 我包括这个只是因为你需要它Pretty
来编译实例):
instance Foldable (Vector Z) where
foldMap f Nil = mempty
instance Foldable (Vector n) => Foldable (Vector (S n)) where
foldMap f (x :- xs) = f x `mappend` foldMap f xs
instance Foldable (Tensor '[]) where
foldMap f (Scalar x) = f x
instance (Foldable (Vector d), Foldable (Tensor ds)) => Foldable (Tensor (d : ds)) where
foldMap f (Tensor xs) = foldMap (foldMap f) xs
最后,回答您问题的部分:我们可以定义Applicative (Vector n)
并Applicative (Tensor ds)
类似于如何Applicative ZipList
定义(除了pure
不返回和空列表 - 它返回正确长度的列表)。
instance Applicative (Vector Z) where
pure _ = Nil
Nil <*> Nil = Nil
instance Applicative (Vector n) => Applicative (Vector (S n)) where
pure x = x :- pure x
(x :- xs) <*> (y :- ys) = x y :- (xs <*> ys)
instance Applicative (Tensor '[]) where
pure = Scalar
Scalar x <*> Scalar y = Scalar (x y)
instance (Applicative (Vector d), Applicative (Tensor ds)) => Applicative (Tensor (d : ds)) where
pure x = Tensor (pure (pure x))
Tensor xs <*> Tensor ys = Tensor ((<*>) <$> xs <*> ys)
然后,在 GHCi 中,创建函数非常简单zero
:
ghci> :set -XDataKinds
ghci> zero = pure 0
ghci> pPrint (zero :: Tensor [N5,N3,N2] Integer)
{{{0, 0}, {0, 0}, {0, 0}},
{{0, 0}, {0, 0}, {0, 0}},
{{0, 0}, {0, 0}, {0, 0}},
{{0, 0}, {0, 0}, {0, 0}},
{{0, 0}, {0, 0}, {0, 0}}}