我正在尝试在 Haskell 中编写一个简单的自动微分包。
在 Haskell 中表示类型安全(有向)计算图的有效方法是什么?我知道广告包为此使用了“data-reify”方法,但我不太熟悉。谁能给我一些见解?谢谢!
我正在尝试在 Haskell 中编写一个简单的自动微分包。
在 Haskell 中表示类型安全(有向)计算图的有效方法是什么?我知道广告包为此使用了“data-reify”方法,但我不太熟悉。谁能给我一些见解?谢谢!
正如 Will Ness 的评论所指出的,AD 的正确抽象是一个类别,而不是一个图表。不幸的是,标准Category
类并没有真正做到这一点,因为它需要任何Haskell 类型之间的箭头,但区分只在平滑流形之间有意义。大多数库不了解流形并将其进一步限制为欧几里得向量空间(它们表示为“向量”或“张量”,它们只是数组)。确实没有令人信服的理由来限制- 任何仿射空间都可以用于正向模式 AD;对于反向模式,您还需要对差向量的对偶空间的概念。
data FwAD x y = FwAD (x -> (y, Diff x -> Diff y))
data RvAD x y = RvAD (x -> (y, DualVector (Diff y) -> DualVector (Diff x)))
其中Diff x -> Diff y
函数必须是线性函数。(您可以为此类函数使用专用的箭头类型,或者您可以只使用(->)
恰好是线性的函数。)在反向模式中唯一不同的是表示此线性映射的伴随,而不是映射本身. (在实值矩阵实现中,线性映射是雅可比矩阵,伴随版本是其转置,但不要使用矩阵,它们的效率非常低。)
整洁,对吗?人们一直在谈论的所有图表/遍历/变异/向后传递的废话并不是真正需要的。(详见Conal 的论文。)
因此,要使这在 Haskell 中有用,您需要实现类别组合器。这几乎正是我编写constrained-categories 包的目的。这是您需要的概要实例化:
import qualified Prelude as Hask
import Control.Category.Constrained.Prelude
import Control.Arrow.Constrained
import Data.AffineSpace
import Data.AdditiveGroup
import Data.VectorSpace
instance Category FwAD where
type Object FwAD a
= (AffineSpace a, VectorSpace (Diff a), Scalar (Diff a) ~ Double)
id = FwAD $ \x -> (x, id)
FwAD f . FwAD g = FwAD $ \x -> case g x of
(gx, dg) -> case f gx of
(fgx, df) -> (fgx, df . dg)
instance Cartesian FwAD where
...
instance Morphism FwAD where
...
instance PreArrow FwAD where
...
instance WellPointed FwAD where
...
这些实例都很简单而且几乎没有歧义,让编译器消息引导您(类型化的孔 _
非常有用)。基本上,只要需要范围内类型的变量,就使用它;当需要不在范围内的向量空间类型的变量时,请使用zeroV
.
到那时,您将真正拥有所有基本的可微函数工具,但要实际定义此类函数,您需要使用带有大量.
,&&&
和***
组合子以及硬编码数字基元的无点样式,这看起来非常规,而不是令人困惑。为避免这种情况,您可以使用代理值:这些值基本上伪装成简单的数字变量,但实际上包含来自某个固定域类型的整个类别箭头。(这基本上是练习的“构建图表”部分。)您可以简单地使用提供的GenericAgent
包装器。
instance HasAgent FwAD where
type AgentVal FwAD a v = GenericAgent FwAD a v
alg = genericAlg
($~) = genericAgentMap
instance CartesianAgent FwAD where
alg1to2 = genericAlg1to2
alg2to1 = genericAlg2to1
alg2to2 = genericAlg2to2
instance PointAgent (GenericAgent FwAD) FwAD a x where
point = genericPoint
instance ( Num v, AffineSpace v, Diff v ~ v, VectorSpace v, Scalar v ~ v
, Scalar a ~ v )
=> Num (GenericAgent FwAD a v) where
fromInteger = point . fromInteger
(+) = genericAgentCombine . FwAD $ \(x,y) -> (x+y, \(dx,dy) -> dx+dy)
(*) = genericAgentCombine . FwAD $ \(x,y) -> (x*y, \(dx,dy) -> y*dx+x*dy)
abs = genericAgentMap . FwAD $ \x -> (abs x, \dx -> if x<0 then -dx else dx)
...
instance ( Fractional v, AffineSpace v, Diff v ~ v, VectorSpace v, Scalar v ~ v
, Scalar a ~ v )
=> Fractional (GenericAgent FwAD a v) where
...
instance (...) => Floating (...) where
...
如果您完成了所有这些实例,也许还有一个简单的助手来提取结果
evalWithGrad :: FwAD Double Double -> Double -> (Double, Double)
evalWithGrad (FwAD f) x = case f x of
(fx, df) -> (fx, df 1)
然后您可以编写代码,例如
> evalWithGrad (alg (\x -> x^2 + x) 3)
(12.0, 7.0)
> evalWithGrad (alg sin 0)
(0.0, 1.0)
在底层,这些代数表达式构建了一个FwAD
箭头组合, &&&
“拆分”数据流并***
并行组合,即即使输入和最终结果很简单Double
,中间结果也将通过合适的元组类型提取。[我想,这将是您标题问题的答案:有向图在某种意义上表示为分支组合链,原则上与您在s的那些图表解释中Arrow
找到的相同。]