1
4

1 回答 1

1

这就是我所拥有的,它应该解决批处理/预循环问题并作为您实际想要做的事情的骨架。我无法遵循缩放和重新缩放。我也无法理解长度为 1 的 3dr 维度,但我一直保留在那里。

请注意,我正在使用构建在 xarray 之上的 ArviZ(PyMC3 的依赖项)来使用基于标签的索引和自动广播。而且我实际上正在使用 ArviZ(开发版)和 PyMC3 3.11.1 中添加的最新功能。您可以使用pip install git+git://github.com/arviz-devs/arviz.git.

我从模型中的几个更改开始,以定义具有命名尺寸的变量的形状:

coords = {
    "batch": np.arange(n_batch),
    "pred_num": ["pred1", "pred2"],
    "one": [1]
}
with pm.Model(coords=coords) as hierarchical_model1:
    ...

    # Intercept for each batch, distributed around group mean mu_a
    a = pm.Normal('a', mu=mu_a, sd=sigma_a, dims=("batch", "pred_num", "one"))
    # Slope for each batcht, distributed around group mean mu_b
    b = pm.Normal('b', mu=mu_b, sd=sigma_b, dims=("batch", "pred_num", "one"))
    
    ...

其余代码完全相同,除了在我添加了 kwarg 的地方进行采样return_inferencedata=True

采样后,我从pred数组创建了一个 xarray 数据集,以便利用我上面提到的 xarray 提供的自动广播。正如我设置的那样return_inferencedata=True,我已经在 xarray 对象中拥有模型中的所有变量。要创建 xarray 变量,我们必须提供值、维度名称和坐标变量:

import xarray as xr
pred_ds = xr.DataArray(
    pred, dims=("x", "pred_num"), coords={"pred_num": coords["pred_num"]}
).to_dataset(name="pred")
post = hier_trace.posterior
y_ds = (post["mu_a"]+post["a"])+(post["mu_b"]+post["b"])*pred_ds
y_ds # in jupyter you'll see a nice html interactive summary
# here is the plain text output
# <xarray.Dataset>
# Dimensions:   (batch: 9, chain: 4, draw: 1000, one: 1, pred_num: 2, x: 48)
# Coordinates:
#   * pred_num  (pred_num) <U5 'pred1' 'pred2'
#   * chain     (chain) int64 0 1 2 3
#   * draw      (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
#   * batch     (batch) int64 0 1 2 3 4 5 6 7 8
#   * one       (one) int64 1
# Dimensions without coordinates: x
# Data variables:
#     pred      (chain, draw, batch, pred_num, one, x) float64 1.938e+03 ... 67.98

然后,我们可以使用 ArviZ 循环所需的变量,并为批处理和 pred 的每种组合生成一个子图。虽然它看起来很奇怪,因为根本没有重新缩放。

import matplotlib.pyplot as plt
fig, axes = plt.subplots(n_pred, n_batch, figsize=(11,5), sharex=True, sharey="row")
x = np.arange(len(batch_id))

第一步是创建我们将用于绘图的 suplots,以及一个表示 x 的虚拟变量。

from arviz.labels import BaseLabeller
labeller = BaseLabeller()
iterator = az.sel_utils.xarray_var_iter(
    y_ds.squeeze().transpose("chain", "draw", "x", "pred_num", "batch"), 
    combined=True, 
    skip_dims={"x",}
)

然后我们创建这个迭代器对象,贴标器纯粹是美学的(并且在此处有详细的文档)。第一步是转置y_ds。这也是美学的,如果批次维度出现在之前,那么在绘图时我们不会在同一行看到所有 pred1。combined=Trueskip_dims表示我们不想迭代链或x维度中的值(默认情况下,ArviZ 不会迭代绘制。现在我们可以使用此迭代器循环获取具有 chain, draw, x维度的数组,以便我们可以绘制HDI,意思是。 ..并得到一个看起来像问题中的例子的情节:

for ax, (_, sel, isel, values) in zip(axes.ravel(), iterator):
    az.plot_hdi(x, values, ax=ax)
    ax.plot(x, values.mean(axis=(0,1)))
    ax.set_title(labeller.make_label_vert("y", sel, isel))
fig.tight_layout()

阴谋

请注意,y 轴仅在行之间共享,pred1 和 pred2 的比例是一个数量级的不同,即使它们与此绘图布局看起来相似。

旁注:我遇到了很多分歧,大多数rhat值都在推荐的1.01以上。我很确定 NUTS 没有收敛,你可能不应该相信这些结果。

于 2021-02-28T01:18:07.047 回答