4

我正在做一些机器学习自学,目前我正在实施反向模式自动微分作为实践。

该程序的工作方式本质上是重载常用表达式,如乘法、加法等,并构建一棵树,其节点随后将在从上到下的反向传递中被调用。这可能是我第一次在 F# 中使用闭包,我非常喜欢它们,但不幸的是,我怀疑它们会阻塞 GC,尽管我不知道如何验证这一点。我宁愿先在这里问,也不愿重新设计算法,所以它不使用它们,因为它们非常方便。

所述树在程序运行期间建立了很多次,我指望垃圾收集器来处理它。

程序的底部是在 XOR 问题上训练 2 层网络的主循环。在循环之前或循环之后调用 GC.Collect() 然后再次运行它会导致它无限期中断。

这是我正在使用的数据结构。我想要一些建议,我的预感是否正确,我应该重新设计算法以不使用闭包或其他东西是否有问题。

type dMatrix(num_rows:int,num_cols,dArray: DeviceMemory<float32>) = 
    inherit DisposableObject()

    new(num_rows,num_cols) =
        new dMatrix(num_rows,num_cols,worker.Malloc<float32>(num_rows*num_cols))

    member t.num_rows = num_rows
    member t.num_cols = num_cols
    member t.dArray = dArray

    override net.Dispose(disposing:bool) =
        if disposing then
            dArray.Dispose()

type Df_rec = {
    P: float32 
    mutable c : int 
    mutable A : float32
    }

type DM_rec = {
    P: dMatrix 
    mutable c : int 
    mutable A : dMatrix
    }

type Rf =
    | DfR_Df_DM of Df_rec * (float32 -> dMatrix) * RDM
    | DfR_Df_Df of Df_rec * (float32 -> float32) * Rf

and RDM = 
    | DM of DM_rec
    | DMRb of DM_rec * (dMatrix -> dMatrix) * (dMatrix -> dMatrix) * RDM * RDM // Outside node * left derivative function * right derivative func * prev left node * prev right node.
    | DMRu of DM_rec * (dMatrix -> dMatrix) * RDM

它使用了许多功能,如下面的一个。sgemm 是 cuBLAS sgemm 包装器。fl out 和 fr out 是闭包。它们很方便,但 GC 可能会发现清理树具有挑战性。

let matmult (a: RDM) (b:RDM) =
    let mm va vb =
        let c = sgemm nT nT 1.0f va vb
        let fl out = sgemm nT T 1.0f out vb // The derivative with respect to the left. So the above argument gets inserted from the right left. Usually error * input.
        let fr out = sgemm T nT 1.0f va out // The derivative with respect to the right. So the above argument gets inserted from the right side. Usually weights * error.
        DMRb(DM_rec.create c,fl,fr,a,b)
    let va = a.r.P
    let vb = b.r.P
    mm va vb

理想情况下,我会重用树,但因为 AD 可以通过循环和分支传播错误梯度,这不是一个选项,并且程序必须进行大量动态分配,即使它很慢。关于如何有效处理这些树的任何建议?

谢谢。

这是带有完整程序的Github 页面的链接。

编辑:我已将代码转换为不使用闭包,但它仍然在垃圾收集时崩溃。我不确定引擎盖下发生了什么。

Edit2:我终于想到使用调试器,我看到 Alea 库抛出了“System.AccessViolationException”。这可能与我在 sum 模块中报告的早期错误有关。事实上,我之前看到过一两次未对齐的访问,但我的大脑出于某种原因忽略了它。明天我会尝试隔离错误。

Edit3:解决了这个问题。这是由于错误的 Alea reduce 模块造成的,与我使用的闭包或树结构无关。

此外,我已经验证没有理由不使用闭包,因为与显式传递参数相比,它们不涉及任何类型的性能损失。这很好,也很令人惊讶。

4

0 回答 0