3

我想要一种类型来以类型安全的方式表示多维数组(张量)。所以我可以写例如:zero :: Tensor (5,3,2) Integer 这将表示一个具有 5 个元素的多维数组,每个元素有 3 个元素,每个元素有 2 个元素,其中所有元素都是Integers

您将如何使用类型级编程来定义这种类型?

编辑

在 Alec 的精彩回答之后,它使用GADTs 实现了这个,

我想知道您是否可以更进一步,并支持 aclass Tensor和张量操作的多种实现以及张量的序列化

这样您就可以拥有例如:

  • GPUCPU使用的实现C
  • Haskell实现
  • 只打印计算图而不计算任何东西的实现
  • 将结果缓存在磁盘上的实现
  • 并行或分布式计算
  • ETC...

所有类型都安全且易于使用。

我的目的是使用自动微分ad library)和精确实数算术exact-real library)在 Haskell 中创建一个非常相似tensor-flow但类型安全且更具可扩展性的库

我认为像这样的函数式语言Haskell比以某种方式萌芽的 python 生态系统更适合这些事情(我认为适合所有事情)。

  • Haskell 是纯函数式的,比 python 更适合计算编程
  • Haskell 比 python 效率高很多,可以编译成二进制
  • Haskell 的懒惰(可以说)消除了优化计算图的需要,并使代码更简单
  • Haskell 中更强大的抽象

虽然我看到了潜力,但我对这种类型级编程还不够精通(或不够聪明),所以我不知道如何在 Haskell 中实现这样的东西并让它编译。

这就是我需要你帮助的地方。

4

1 回答 1

3

这是一种方法(这是一个完整的要点)。我们坚持使用 Peano 数字而不是 GHC 的类型级别Nat,只是因为归纳对它们更有效。

{-# LANGUAGE GADTs, PolyKinds, DataKinds, TypeOperators, FlexibleInstances, FlexibleContexts #-}

import Data.Foldable
import Text.PrettyPrint.HughesPJClass

data Nat = Z | S Nat

-- Some type synonyms that simplify uses of 'Nat'
type N0 = Z
type N1 = S N0
type N2 = S N1
type N3 = S N2
type N4 = S N3
type N5 = S N4
type N6 = S N5
type N7 = S N6
type N8 = S N7
type N9 = S N8

-- Similar to lists, but indexed over their length
data Vector (dim :: Nat) a where
  Nil    :: Vector Z a
  (:-)   :: a -> Vector n a -> Vector (S n) a

infixr 5 :-

data Tensor (dim :: [Nat]) a where
  Scalar :: a -> Tensor '[] a
  Tensor :: Vector d (Tensor ds a) -> Tensor (d : ds) a

为了显示这些类型,我们将使用pretty包(GHC 已经自带)。

instance (Foldable (Vector n), Pretty a) => Pretty (Vector n a) where
  pPrint = braces . sep . punctuate (text ",") . map pPrint . toList

instance Pretty a => Pretty (Tensor '[] a) where
  pPrint (Scalar x) = pPrint x

instance (Pretty (Tensor ds a), Pretty a, Foldable (Vector d)) => Pretty (Tensor (d : ds) a) where
  pPrint (Tensor xs) = pPrint xs

然后这里是Foldable我们的数据类型的实例(这里没什么奇怪的 - 我包括这个只是因为你需要它Pretty来编译实例):

instance Foldable (Vector Z) where
  foldMap f Nil = mempty

instance Foldable (Vector n) => Foldable (Vector (S n)) where
  foldMap f (x :- xs) = f x `mappend` foldMap f xs


instance Foldable (Tensor '[]) where
  foldMap f (Scalar x) = f x

instance (Foldable (Vector d), Foldable (Tensor ds)) => Foldable (Tensor (d : ds)) where
  foldMap f (Tensor xs) = foldMap (foldMap f) xs

最后,回答您问题的部分:我们可以定义Applicative (Vector n)Applicative (Tensor ds)类似于如何Applicative ZipList定义(除了pure不返回和空列表 - 它返回正确长度的列表)。

instance Applicative (Vector Z) where
  pure _ = Nil
  Nil <*> Nil = Nil

instance Applicative (Vector n) => Applicative (Vector (S n)) where
  pure x = x :- pure x
  (x :- xs) <*> (y :- ys) = x y :- (xs <*> ys)


instance Applicative (Tensor '[]) where
  pure = Scalar
  Scalar x <*> Scalar y = Scalar (x y)

instance (Applicative (Vector d), Applicative (Tensor ds)) => Applicative (Tensor (d : ds)) where
  pure x = Tensor (pure (pure x))
  Tensor xs <*> Tensor ys = Tensor ((<*>) <$> xs <*> ys)

然后,在 GHCi 中,创建函数非常简单zero

ghci> :set -XDataKinds
ghci> zero = pure 0
ghci> pPrint (zero :: Tensor [N5,N3,N2] Integer)
{{{0, 0}, {0, 0}, {0, 0}},
 {{0, 0}, {0, 0}, {0, 0}},
 {{0, 0}, {0, 0}, {0, 0}},
 {{0, 0}, {0, 0}, {0, 0}},
 {{0, 0}, {0, 0}, {0, 0}}}
于 2017-05-08T02:59:47.473 回答