2

我正在关注预测中的 Pyro 入门教程,并在训练模型后尝试访问学习的参数,我对其中一些使用不同的访问方法得到了不同的结果(而对另一些得到相同的结果)。

这是教程中精简的可重现代码:

import torch
import pyro
import pyro.distributions as dist
from pyro.contrib.examples.bart import load_bart_od
from pyro.contrib.forecast import ForecastingModel, Forecaster

pyro.enable_validation(True)
pyro.clear_param_store()

pyro.__version__
# '1.3.1'
torch.__version__
# '1.5.0+cu101'

# import & prepare the data
dataset = load_bart_od()
T, O, D = dataset["counts"].shape
data = dataset["counts"][:T // (24 * 7) * 24 * 7].reshape(T // (24 * 7), -1).sum(-1).log()
data = data.unsqueeze(-1)
T0 = 0              # begining
T2 = data.size(-2)  # end
T1 = T2 - 52        # train/test split

# define the model class
class Model1(ForecastingModel):

    def model(self, zero_data, covariates):
        data_dim = zero_data.size(-1)  
        feature_dim = covariates.size(-1)

        bias = pyro.sample("bias", dist.Normal(0, 10).expand([data_dim]).to_event(1))
        weight = pyro.sample("weight", dist.Normal(0, 0.1).expand([feature_dim]).to_event(1))
        prediction = bias + (weight * covariates).sum(-1, keepdim=True)
        assert prediction.shape[-2:] == zero_data.shape

        noise_scale = pyro.sample("noise_scale", dist.LogNormal(-5, 5).expand([1]).to_event(1))
        noise_dist = dist.Normal(0, noise_scale)

        self.predict(noise_dist, prediction)

# fit the model
pyro.set_rng_seed(1)
pyro.clear_param_store()
time = torch.arange(float(T2)) / 365
covariates = torch.stack([time], dim=-1)
forecaster = Forecaster(Model1(), data[:T1], covariates[:T1], learning_rate=0.1)

到目前为止,一切都很好; 现在,我想检查存储在Paramstore. 似乎有不止一种方法可以做到这一点;使用get_all_param_names()方法:

for name in pyro.get_param_store().get_all_param_names():
    print(name, pyro.param(name).data.numpy())

我明白了

AutoNormal.locs.bias [14.585433]
AutoNormal.scales.bias [0.00631594]
AutoNormal.locs.weight [0.11947815]
AutoNormal.scales.weight [0.00922901]
AutoNormal.locs.noise_scale [-2.0719821]
AutoNormal.scales.noise_scale [0.03469057]

但是使用named_parameters()方法:

pyro.get_param_store().named_parameters()

为 location ( locs) 参数提供相同的值,但为所有参数提供不同scales的值:

dict_items([
('AutoNormal.locs.bias', Parameter containing: tensor([14.5854], requires_grad=True)), 
('AutoNormal.scales.bias', Parameter containing: tensor([-5.0647], requires_grad=True)), 
('AutoNormal.locs.weight', Parameter containing: tensor([0.1195], requires_grad=True)), 
('AutoNormal.scales.weight', Parameter containing: tensor([-4.6854], requires_grad=True)),
('AutoNormal.locs.noise_scale', Parameter containing: tensor([-2.0720], requires_grad=True)), 
('AutoNormal.scales.noise_scale', Parameter containing: tensor([-3.3613], requires_grad=True))
])

这怎么可能?根据文档Paramstore是一个简单的键值存储;里面只有这六个键:

pyro.get_param_store().get_all_param_names() # .keys() method gives identical result
# result
dict_keys([
'AutoNormal.locs.bias',
'AutoNormal.scales.bias', 
'AutoNormal.locs.weight', 
'AutoNormal.scales.weight', 
'AutoNormal.locs.noise_scale', 
'AutoNormal.scales.noise_scale'])

因此,不可能一种方法访问一组项目而另一种访问不同的项目。

我在这里错过了什么吗?

4

2 回答 2

1

pyro.param() 在这种情况下,将转换后的参数返回为 的正实数scales

于 2020-05-10T02:14:07.010 回答
1

这是情况,正如我与这个问题同时打开的Github 线程中所揭示的那样......

Paramstore不再只是一个简单的键值存储——它还执行约束转换;从上面的链接中引用 Pyro 开发人员:

这里有一些历史背景。最初只是一个ParamStore键值存储。然后我们添加了对约束参数的支持;这在面向用户的约束值和内部无约束值之间引入了一个新的分离层。我们创建了一个新的类似 dict 的面向用户的界面,它只公开了受约束的值,但为了保持与旧代码的向后兼容性,我们保留了旧界面。这两个接口在源文件中是有区别的 [...] 但正如您所观察到的,我们似乎忘记将旧接口标记为已弃用。

我想在澄清文档时我们应该:

  1. 阐明 ParamStore 不再是一个简单的键值存储,而是执行约束转换;

  2. 将所有“旧”样式接口方法标记为已弃用;

  3. 从示例和教程中删除“旧”样式的界面使用。

结果,事实证明,虽然pyro.param()在受限(面向用户)空间中返回结果,但旧方法named_parameters()返回不受约束(即仅供内部使用)的值,因此存在明显的差异。

不难验证scales上述两种方法返回的值是否通过对数变换相关:

import numpy as np
items = list(pyro.get_param_store().named_parameters())  # unconstrained space

i = 0
for name in pyro.get_param_store().keys():  
  if 'scales' in name:
    temp = np.log(
                  pyro.param(name).item()  # constrained space
                 )
    print(temp, items[i][1][0].item() , np.allclose(temp, items[i][1][0].item()))
  i+=1

# result:
-5.027793402915326 -5.0277934074401855 True
-4.600319371162187 -4.6003193855285645 True
-3.3920585732532835 -3.3920586109161377 True

为什么这种差异只影响scales参数?这是因为scales(即本质上的方差)根据定义被限制为正数;这不适用于locs(即意味着),它们不受约束,因此这两种表示对它们来说是一致的。

由于上述问题,现在在Paramstore 文档中添加了一个新项目符号,给出了相关提示:

通常,参数与受约束不受约束的值相关联。例如,在幕后,一个被约束为正的参数被表示为对数空间中的一个不受约束的张量。

以及在旧接口方法的文档中:named_parameters()

请注意,如果参数受到约束,则unconstrained_value位于约束隐式使用的无约束空间中。

于 2020-05-11T13:57:30.123 回答