我正在做一些机器学习自学,目前我正在实施反向模式自动微分作为实践。
该程序的工作方式本质上是重载常用表达式,如乘法、加法等,并构建一棵树,其节点随后将在从上到下的反向传递中被调用。这可能是我第一次在 F# 中使用闭包,我非常喜欢它们,但不幸的是,我怀疑它们会阻塞 GC,尽管我不知道如何验证这一点。我宁愿先在这里问,也不愿重新设计算法,所以它不使用它们,因为它们非常方便。
所述树在程序运行期间建立了很多次,我指望垃圾收集器来处理它。
程序的底部是在 XOR 问题上训练 2 层网络的主循环。在循环之前或循环之后调用 GC.Collect() 然后再次运行它会导致它无限期中断。
这是我正在使用的数据结构。我想要一些建议,我的预感是否正确,我应该重新设计算法以不使用闭包或其他东西是否有问题。
type dMatrix(num_rows:int,num_cols,dArray: DeviceMemory<float32>) =
inherit DisposableObject()
new(num_rows,num_cols) =
new dMatrix(num_rows,num_cols,worker.Malloc<float32>(num_rows*num_cols))
member t.num_rows = num_rows
member t.num_cols = num_cols
member t.dArray = dArray
override net.Dispose(disposing:bool) =
if disposing then
dArray.Dispose()
type Df_rec = {
P: float32
mutable c : int
mutable A : float32
}
type DM_rec = {
P: dMatrix
mutable c : int
mutable A : dMatrix
}
type Rf =
| DfR_Df_DM of Df_rec * (float32 -> dMatrix) * RDM
| DfR_Df_Df of Df_rec * (float32 -> float32) * Rf
and RDM =
| DM of DM_rec
| DMRb of DM_rec * (dMatrix -> dMatrix) * (dMatrix -> dMatrix) * RDM * RDM // Outside node * left derivative function * right derivative func * prev left node * prev right node.
| DMRu of DM_rec * (dMatrix -> dMatrix) * RDM
它使用了许多功能,如下面的一个。sgemm 是 cuBLAS sgemm 包装器。fl out 和 fr out 是闭包。它们很方便,但 GC 可能会发现清理树具有挑战性。
let matmult (a: RDM) (b:RDM) =
let mm va vb =
let c = sgemm nT nT 1.0f va vb
let fl out = sgemm nT T 1.0f out vb // The derivative with respect to the left. So the above argument gets inserted from the right left. Usually error * input.
let fr out = sgemm T nT 1.0f va out // The derivative with respect to the right. So the above argument gets inserted from the right side. Usually weights * error.
DMRb(DM_rec.create c,fl,fr,a,b)
let va = a.r.P
let vb = b.r.P
mm va vb
理想情况下,我会重用树,但因为 AD 可以通过循环和分支传播错误梯度,这不是一个选项,并且程序必须进行大量动态分配,即使它很慢。关于如何有效处理这些树的任何建议?
谢谢。
这是带有完整程序的Github 页面的链接。
编辑:我已将代码转换为不使用闭包,但它仍然在垃圾收集时崩溃。我不确定引擎盖下发生了什么。
Edit2:我终于想到使用调试器,我看到 Alea 库抛出了“System.AccessViolationException”。这可能与我在 sum 模块中报告的早期错误有关。事实上,我之前看到过一两次未对齐的访问,但我的大脑出于某种原因忽略了它。明天我会尝试隔离错误。
Edit3:解决了这个问题。这是由于错误的 Alea reduce 模块造成的,与我使用的闭包或树结构无关。
此外,我已经验证没有理由不使用闭包,因为与显式传递参数相比,它们不涉及任何类型的性能损失。这很好,也很令人惊讶。