1

Turing.jl 有一个指南,展示了如何编写允许您传递数据或missing参数的模型。在第一种情况下,它将为您提供后验,在第二种情况下,它将从所有变量中提取。

设置

using Turing
iterations = 1000
ϵ = 0.05
τ = 10

工作示例是这样的:

@model gdemo(x, ::Type{T} = Float64) where {T} = begin
    if x === missing
        # Initialize `x` if missing
        x = Vector{T}(undef, 2)
    end
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    for i in eachindex(x)
        x[i] ~ Normal(m, sqrt(s))
    end
end

# Construct a model with x = missing
model = gdemo(missing)
c = sample(model, HMC(0.01, 5), 500);

我试图将此代码改编为简单的硬币翻转。我认为,bernoulli分布应该有数据类型BoolInt64不过我也试过)。我试过这样做:

@model coinflip2(y) = begin
    if y === missing
        y = Vector{Bool}(undef, 4)
    end
    p ~ Beta(1, 1)
    N = length(y)
    for n in 1:N
        y[n] ~ Bernoulli(p)
    end
end;

chain = sample(coinflip2(missing), HMC(ϵ, τ), iterations);

但它给了我一个很长的错误消息,我没有得到,但很可能指向一个类型问题:

MethodError: Bool(::ForwardDiff.Dual{ForwardDiff.Tag{Turing.Core.var"#f#7"{DynamicPPL.VarInfo{NamedTuple{(:p, :y),Tuple{DynamicPPL.Metadata{Dict{DynamicPPL.VarName{:p,Tuple{}},Int64},Array{Beta{Float64},1},

试图从gdemo上面的工作函数中模仿语法(我没有完全理解),给你:

@model coinflip2(y, ::Type{T} = Bool) where {T} = begin
    if y === missing
        y = Vector{T}(undef, 4)
    end
    p ~ Beta(1, 1)
    N = length(y)
    for n in 1:N
        y[n] ~ Bernoulli(p)
    end
end;

chain = sample(coinflip2(missing), HMC(ϵ, τ), iterations);

这也会失败并显示以以下开头的错误消息:

MethodError: Bool(::ForwardDiff.Dual{ForwardDiff.Tag{Turing.Core.var"#f#7"{DynamicPPL.VarInfo{NamedTuple{(:p, :y),Tuple{DynamicPPL.Metadata{Dict{DynamicPPL.VarName{:p,Tuple{}},Int64},Array{Beta{Float64},1}

我该如何正确地写这个?解释我做错了什么的奖励积分:D 谢谢!

4

1 回答 1

2

Hamiltonian Monte Carlo 仅适用于连续分布,因为它可以区分 PDF 的函数。当您从联合模型中采样时(p, y),这也适用于随机变量y,因为是伯努利,它肯定不是连续的。因此,后台使用的 AD 系统(ForwardDiff默认)会报错。

你可以通过指定一个非哈密顿采样器来解决这个y问题Gibbs

chain = sample(coinflip2(missing), Gibbs(HMC(ϵ, τ, :p), MH(:y)), iterations)

MH 可能不是最佳选择,您必须自己考虑,但它确实有效。

请注意,这对于参数估计不是必需的:just pgiveny是连续的并且使 HMC 高兴:

chain = sample(coinflip2([0,1,1,0]), HMC(ϵ, τ), iterations)
于 2020-05-13T07:39:44.343 回答