我正在尝试构建具有多个离散随机变量和条件概率的贝叶斯网络的最简单示例(科勒书中的“学生网络”,请参见1)
虽然有点笨拙,但我设法使用 pymc3 构建了这个网络。特别是,在 pymc3 中创建 CPD 并不是那么简单,请参见下面的代码片段:
import pymc3 as pm
...
with pm.Model() as basic_model:
# parameters for categorical are indexed as [0, 1, 2, ...]
difficulty = pm.Categorical(name='difficulty', p=[0.6, 0.4])
intelligence = pm.Categorical(name='intelligence', p=[0.7, 0.3])
grade = pm.Categorical(name='grade',
p=pm.math.switch(
theano.tensor.eq(intelligence, 0),
pm.math.switch(
theano.tensor.eq(difficulty, 0),
[0.3, 0.4, 0.3], # I=0, D=0
[0.05, 0.25, 0.7] # I=0, D=1
),
pm.math.switch(
theano.tensor.eq(difficulty, 0),
[0.9, 0.08, 0.02], # I=1, D=0
[0.5, 0.3, 0.2] # I=1, D=1
)
)
)
letter = pm.Categorical(name='letter', p=pm.math.switch(
...
但我不知道如何使用 tensoflow-probability (versions: tfp-nightly==0.7.0.dev20190517
, tf-nightly-2.0-preview==2.0.0.dev20190517
)构建这个网络
对于无条件的二元变量,可以使用分类分布,例如
from tensorflow_probability import distributions as tfd
from tensorflow_probability import edward2 as ed
difficulty = ed.RandomVariable(
tfd.Categorical(
probs=[0.6, 0.4],
name='difficulty'
)
)
但是如何构建 CPD?
tensorflow-probability 中可能相关的类/方法很少(在tensorflow_probability/python/distributions/deterministic.py
或已弃用ConditionalDistribution
),但文档相当稀疏(需要深入了解 tfp)。
--- 更新问题 ---
克里斯的回答是一个很好的起点。然而,即使对于一个非常简单的二变量模型,事情仍然有点不清楚。
这很好用:
jdn = tfd.JointDistributionNamed(dict(
dist_x=tfd.Categorical([0.2, 0.8], validate_args=True),
dist_y=lambda dist_x: tfd.Bernoulli(probs=tf.gather([0.1, 0.9], indices=dist_x), validate_args=True)
))
print(jdn.sample(10))
但是这个失败了
jdn = tfd.JointDistributionNamed(dict(
dist_x=tfd.Categorical([0.2, 0.8], validate_args=True),
dist_y=lambda dist_x: tfd.Categorical(probs=tf.gather_nd([[0.1, 0.9], [0.5, 0.5]], indices=[dist_x]))
))
print(jdn.sample(10))
(我试图在第二个例子中明确地建模分类只是为了学习目的)
-- 更新:已解决 ---
显然,最后一个示例使用错误,tf.gather_nd
而不是tf.gather
因为我们只想根据dist_x
outome 选择第一行或第二行。此代码现在有效:
jdn = tfd.JointDistributionNamed(dict(
dist_x=tfd.Categorical([0.2, 0.8], validate_args=True),
dist_y=lambda dist_x: tfd.Categorical(probs=tf.gather([[0.1, 0.9], [0.5, 0.5]], indices=[dist_x]))
))
print(jdn.sample(10))