1

我需要使用 PyMC3 拟合多级线性模型,我非常喜欢glm api,因为它提供了简洁性。我想问一下是否以及如何做到这一点。我发现的这篇文提到:

glm() 还不能很好地处理分层模型

所以我有点怀疑这实际上是否可以做到,但它是几年前写的,所以这可能已经改变了。举个例子,下面是我想用 glm api 重写的模型

import numpy as np
import pymc3 as pm


def generate_data():
    n, beta_0, beta_1, sd_eps = 100, 1.2, 0.6, 0.2
    b_group = np.array([0.05, 0.14, -0.23])
    x = np.random.randn(n)
    group_index = np.random.choice([0, 1, 2], n)
    y = 1.2 + 0.6 * x + sd_eps * np.random.randn(n) + b_group[group_index]
    return x, group_index, y


if __name__ == '__main__':
    x, group_index, y = generate_data()

    with pm.Model() as multi_level_model:
        sd_b_group = pm.HalfNormal("sd_b_group", sigma=100)

        b_group = pm.Normal("b_group", mu=0, sigma=sd_b_group, shape=3)

        beta_0 = pm.Normal("beta_0", mu=0, sigma=100)
        beta_1 = pm.Normal("beta_1", mu=0, sigma=100)

        sd_eps = pm.HalfNormal("sd_eps", sigma=100)
        pm.Normal("y", beta_0 + beta_1 * x + b_group[group_index],
                  sigma=sd_eps, observed=y)

我认为它看起来像这样:

with pm.Model():
    mu_b_group = pm.Normal("mu_b_group", mu=0, sigma=100)
    sd_b_group = pm.HalfNormal("sd_b_group", sigma=100)

    b_group = pm.Normal("b_group", mu=mu_b_group, sigma=sd_b_group, shape=3)

    pm.glm.GLM.from_formula(formula="y ~ 1 + x",
                            vars={"Intercept": b_group[group_index]},
                            data={"y": y, "x": x})

但是,在尝试堆叠系数时会在此处产生错误

TypeError: Join() can only join tensors with the same number of dimensions.
4

1 回答 1

0

经过一些试验,我想出了这个(不理想,但至少有点用)的解决方案:

with pm.Model():
    sd_b_group = pm.HalfNormal("sd_b_group", sigma=100)
    b_group = pm.Normal("b_group", mu=0, sigma=sd_b_group, shape=3)

    lm = pm.glm.LinearComponent.from_formula(formula="y ~ 1 + x",
                                             data={"y": y, "x": x})

    sd_eps = pm.HalfNormal("sd_eps", sigma=100)

    likelihood = pm.Normal("y", mu=lm.y_est + b_group[group_index], 
                           sigma=sd_eps, observed=y)        

还在github上创建了一个示例以供将来参考。

于 2020-10-27T07:38:15.087 回答