我想使用 GADT 在 Haskell 中对张量演算进行类型安全的实现,所以规则是:
- 张量是具有“楼上”或“楼下”的不定值的 n 维度量,例如:
- 是没有定值的张量(标量),
是具有一个“楼上”索引
的张量,是具有一堆 ' 的张量楼上和楼下的猥亵
您可以添加相同类型的张量,这意味着它们具有相同的 indecies 签名。第一个张量的第 0 个索引与第二个张量的第 0 个索引的类型相同(楼上或楼下),依此类推...
所以我希望 Haskell 的类型检查器不允许我编写不遵循这些规则的代码,否则它不会编译。
这是我使用 GADT 的尝试:
{-# LANGUAGE GADTs #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE TypeOperators #-}
data Direction = T | X | Y | Z
data Index = Zero | Up Index | Down Index deriving (Eq, Show)
plus :: Index -> Index -> Index
plus Zero x = x
plus (Up x) y = Up (plus x y)
plus (Down x) y = Down (plus x y)
data Tensor a = (a ~ Zero) => Scalar Double |
forall b. (a ~ Up b) => Cov (Direction -> Tensor b) |
forall b. (a ~ Down b) => Con (Direction -> Tensor b)
add :: Tensor a -> Tensor a -> Tensor a
add (Scalar x) (Scalar y) = (Scalar (x + y))
add (Cov f) (Cov g) = (Cov (\d -> add (f d) (g d)))
add (Con f) (Con g) = (Con (\d -> add (f d) (g d)))
mul :: Tensor a -> Tensor b -> Tensor (plus a b)
mul (Scalar x) (Scalar y) = (Scalar (x*y))
mul (Scalar x) (Cov f) = (Cov (\d -> mul (Scalar x) (f d)))
mul (Scalar x) (Con f) = (Con (\d -> mul (Scalar x) (f d)))
mul (Cov f) y = (Cov (\d -> mul (f d) y))
mul (Con f) y = (Con (\d -> mul (f d) y))
但我得到:
Couldn't match type 'Down with `plus ('Down b1)'
Expected type: Tensor (plus a b)
Actual type: Tensor ('Down b)
Relevant bindings include
f :: Direction -> Tensor b1 (bound at main.hs:28:10)
mul :: Tensor a -> Tensor b -> Tensor (plus a b)
(bound at main.hs:24:1)
In the expression: (Con (\ d -> mul (f d) y))
In an equation for `mul':
mul (Con f) y = (Con (\ d -> mul (f d) y))
问题是什么?