5

我想使用 GADT 在 Haskell 中对张量演算进行类型安全的实现,所以规则是:

  1. 张量是具有“楼上”或“楼下”的不定值的 n 维度量,例如:在此处输入图像描述- 是没有定值的张量(标量),在此处输入图像描述是具有一个“楼上”索引在此处输入图像描述的张量,是具有一堆 ' 的张量楼上和楼下的猥亵
  2. 您可以添加相同类型的张量,这意味着它们具有相同的 indecies 签名。第一个张量的第 0 个索引与第二个张量的第 0 个索引的类型相同(楼上或楼下),依此类推...

    在此处输入图像描述 ~~~~ 好的

    在此处输入图像描述 ~~~~ 不好

  3. 您可以将张量相乘并获得更大的张量,并将这些不定值连接起来:在此处输入图像描述

所以我希望 Haskell 的类型检查器不允许我编写不遵循这些规则的代码,否则它不会编译。

这是我使用 GADT 的尝试:

{-# LANGUAGE GADTs #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE TypeOperators #-}

data Direction = T | X | Y | Z
data Index = Zero | Up Index | Down Index deriving (Eq, Show)

plus :: Index -> Index -> Index
plus Zero x = x
plus (Up x) y = Up (plus x y)
plus (Down x) y = Down (plus x y)

data Tensor a = (a ~ Zero) => Scalar Double | 
                forall b. (a ~ Up b) => Cov (Direction -> Tensor b) |
                forall b. (a ~ Down b) => Con (Direction -> Tensor b) 

add :: Tensor a -> Tensor a -> Tensor a
add (Scalar x) (Scalar y) = (Scalar (x + y))
add (Cov f) (Cov g) = (Cov (\d -> add (f d) (g d)))
add (Con f) (Con g) = (Con (\d -> add (f d) (g d)))

mul :: Tensor a -> Tensor b -> Tensor (plus a b)
mul (Scalar x) (Scalar y) = (Scalar (x*y))
mul (Scalar x) (Cov f) = (Cov (\d -> mul (Scalar x) (f d)))
mul (Scalar x) (Con f) = (Con (\d -> mul (Scalar x) (f d)))
mul (Cov f) y = (Cov (\d -> mul (f d) y))
mul (Con f) y = (Con (\d -> mul (f d) y))

但我得到:

Couldn't match type 'Down with `plus ('Down b1)'                                                                                                                                                                                                    
    Expected type: Tensor (plus a b)                                                                                                                                                                                                                    
      Actual type: Tensor ('Down b)                                                                                                                                                                                                                     
    Relevant bindings include                                                                                                                                                                                                                           
      f :: Direction -> Tensor b1 (bound at main.hs:28:10)                                                                                                                                                                                              
      mul :: Tensor a -> Tensor b -> Tensor (plus a b)                                                                                                                                                                                                  
        (bound at main.hs:24:1)                                                                                                                                                                                                                         
    In the expression: (Con (\ d -> mul (f d) y))                                                                                                                                                                                                       
    In an equation for `mul':                                                                                                                                                                                                                           
        mul (Con f) y = (Con (\ d -> mul (f d) y)) 

问题是什么?

4

1 回答 1

4

plus只是类型值的函数Index

>>> plus Zero Zero
Zero
>>> plus Zero (Up Zero)
Up Zero

所以它不能像现在一样出现在类型签名中。您想使用 'promoted' 类型,其中ZeroUp Zero是类型。然后你可以编写一个类型函数,一切都会编译。

{-# LANGUAGE GADTs #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}

data Direction = T | X | Y | Z
data Index = Zero | Up Index | Down Index deriving (Eq, Show)

-- type function Plus
type family Plus (i :: Index) (j :: Index) :: Index where
  Plus Zero x = x
  Plus (Up x) y  = Up (Plus x y)
  Plus (Down x) y = Down (Plus x y)

-- value fuction plus
plus :: Index -> Index -> Index
plus Zero x = x
plus (Up x) y = Up (plus x y)
plus (Down x) y = Down (plus x y)

data Tensor (a :: Index) where
  Scalar :: Double -> Tensor Zero
  Cov :: (Direction -> Tensor b) -> Tensor (Up b)
  Con :: (Direction -> Tensor b) -> Tensor (Down b)

add :: Tensor a -> Tensor a -> Tensor a
add (Scalar x) (Scalar y) = (Scalar (x + y))
add (Cov f) (Cov g) = (Cov (\d -> add (f d) (g d)))
add (Con f) (Con g) = (Con (\d -> add (f d) (g d)))

mul :: Tensor a -> Tensor b -> Tensor (Plus a b)
mul (Scalar x) (Scalar y) = (Scalar (x*y))
mul (Scalar x) (Cov f) = (Cov (\d -> mul (Scalar x) (f d)))
mul (Scalar x) (Con f) = (Con (\d -> mul (Scalar x) (f d)))
mul (Cov f) y = (Cov (\d -> mul (f d) y))
mul (Con f) y = (Con (\d -> mul (f d) y))

没有歧义,Plus但我可以使用消除歧义的勾号'来表示我正在处理类型级别ZeroUp

type family Plus (i :: Index) (j :: Index) :: Index where
  Plus 'Zero x = x
  Plus ('Up x) y  = 'Up (Plus x y)
  Plus ('Down x) y = 'Down (Plus x y)

TypeOperators将允许你写a + b而不是Plus a b上面。

type family (i :: Index) + (j :: Index) :: Index where
  Zero + x = x
  Up x + y  = Up (x + y)
  Down x + y = Down (x + y) 
于 2017-04-01T13:21:59.483 回答