I have run into a problem I have not seen before. I work in Bayesian machine learning, so I make heavy use of the distributions in PyTorch. A common thing to do is to define some of a distribution's parameters in terms of their log, so that during optimisation they cannot go negative (e.g. the standard deviation of a Normal distribution).
However, to stay distribution-agnostic I do not want to have to recompute this parameter transformation by hand. To demonstrate by example:
The following code will NOT run. After the first backward pass, the part of the graph that computes the exponential of the parameter is automatically freed and is not re-added, so the second call to cost.backward() raises a RuntimeError about trying to backward through the graph a second time.
import torch
import torch.nn as nn
import torch.distributions as dd
log_std = nn.Parameter(torch.Tensor([1])) # Define the log of the parameter as an nn.Parameter, this is what we want to optimise
std = torch.exp(log_std) # Define the transformation we want to apply to the parameter in order to use it in the distribution
mean = nn.Parameter(torch.Tensor([1])) # A normal parameter
dist = dd.Normal(loc=mean, scale=std) # Define the distribution. From here I want to ONLY refer to this, not the other variables
optim = torch.optim.SGD([log_std, mean], lr=0.01) # Standard optimiser
target = dd.Normal(5,5) # Target distribution to match
for i in range(50):
    optim.zero_grad()
    samples = dist.rsample((1000,)) # Sample our model, note no reference to log_std
    cost = -(target.log_prob(samples) - dist.log_prob(samples)).sum() # KL divergence cost metric
    cost.backward()
    optim.step()
    print(i)
    print(log_std, mean, cost)
    print()
The next set of code WILL run, but I have to explicitly reference the log_std parameter in the loop and recreate the distribution each iteration. If I wanted to change the distribution type, this would be impossible without special-casing each distribution.
import torch
import torch.nn as nn
import torch.distributions as dd
log_std = nn.Parameter(torch.Tensor([1])) # Define the log of the parameter as an nn.Parameter, this is what we want to optimise
mean = nn.Parameter(torch.Tensor([1])) # A normal parameter
optim = torch.optim.SGD([log_std, mean], lr=0.001) # Standard optimiser
target = dd.Normal(5,5) # Target distribution to match
for i in range(50):
    optim.zero_grad()
    std = torch.exp(log_std) # Apply the transformation inside the loop so the graph is rebuilt each iteration
    dist = dd.Normal(loc=mean, scale=std) # Recreate the distribution
    samples = dist.rsample((1000,)) # Sample our model, note no reference to log_std
    cost = -(target.log_prob(samples) - dist.log_prob(samples)).sum() # KL divergence cost metric
    cost.backward()
    optim.step()
    print(i)
    print(mean, std, cost)
    print()
The first example does work in TensorFlow, however, because the graph there is static. Does anyone have any ideas on how I might fix this? If it were possible to keep only the part of the graph that defines the relationship std = torch.exp(log_std), then this could work. I have also tried playing with backward gradient hooks, but unfortunately, to compute the new gradient properly you need access to the parameter value and the learning rate.
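To make the hook idea concrete, here is roughly the direction I was exploring (just a sketch with hypothetical names, not something that works for me): optimise std directly and register a hook that rewrites its gradient so that an SGD step on std mimics an SGD step on log_std. Notice that the rewritten gradient needs both the current value of std and the learning rate, which is what kills the idea.

import torch
import torch.nn as nn

lr = 0.01
std = nn.Parameter(torch.exp(torch.Tensor([1]))) # optimise std itself instead of log_std

def log_space_grad(grad):
    # SGD in log space would do: log_std <- log_std - lr * (grad * std),
    # i.e. std <- std * exp(-lr * grad * std). Rewriting that update as
    # std <- std - lr * g gives the modified gradient g below.
    with torch.no_grad():
        return std * (1 - torch.exp(-lr * grad * std)) / lr

std.register_hook(log_space_grad) # the hook needs both lr and the current value of std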
Thanks in advance! Michael
EDIT
I have been asked for an example of how I might want to change the distribution. Take the code that currently does NOT work, and change the distribution to a Gamma:
import torch
import torch.nn as nn
import torch.distributions as dd
log_rate = nn.Parameter(torch.Tensor([1])) # Define the log of the parameter as an nn.Parameter, this is what we want to optimise
rate = torch.exp(log_rate) # Define the transformation we want to apply to the parameter in order to use it in the distribution
concentration = nn.Parameter(torch.Tensor([1])) # A normal parameter
dist = dd.Gamma(concentration=concentration, rate=rate) # Define the distribution. From here I want to ONLY refer to this, not the other variables
optim = torch.optim.SGD([log_rate, concentration], lr=0.01) # Standard optimiser
target = dd.Gamma(5,5) # Target distribution to match
for i in range(50):
    optim.zero_grad()
    samples = dist.rsample((1000,)) # Sample our model, note no reference to log_rate
    cost = -(target.log_prob(samples) - dist.log_prob(samples)).sum() # KL divergence cost metric
    cost.backward()
    optim.step()
    print(i)
    print(log_rate, concentration, cost)
    print()
But looking at the code that currently works:
import torch
import torch.nn as nn
import torch.distributions as dd
log_rate = nn.Parameter(torch.Tensor([1])) # Define the log of the parameter as an nn.Parameter, this is what we want to optimise
concentration = nn.Parameter(torch.Tensor([1])) # A normal parameter
optim = torch.optim.SGD([log_rate, concentration], lr=0.001) # Standard optimiser
target = dd.Gamma(5,5) # Target distribution to match
for i in range(50):
    optim.zero_grad()
    rate = torch.exp(log_rate) # Apply the transformation inside the loop so the graph is rebuilt each iteration
    dist = dd.Gamma(concentration=concentration, rate=rate) # Recreate the distribution
    samples = dist.rsample((1000,)) # Sample our model, note no reference to log_rate
    cost = -(target.log_prob(samples) - dist.log_prob(samples)).sum() # KL divergence cost metric
    cost.backward()
    optim.step()
    print(i)
    print(concentration, rate, cost)
    print()
You can see that we have to change the code inside the loop to make the algorithm work. In this small example that is not a big issue, but this is only a demonstration for much larger algorithms, where not having to worry about it would be hugely beneficial.
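For completeness, here is a sketch of the kind of generic workaround I can write today (make_dist is a hypothetical helper of my own, not part of PyTorch): store the distribution class, the raw parameters and their transforms, and rebuild the distribution every iteration. It works for any distribution, but the rebuild still lives inside the loop, which is exactly what I would like to avoid.

import torch
import torch.nn as nn
import torch.distributions as dd

def make_dist(dist_cls, params, transforms):
    # Rebuild the distribution from the raw parameters, applying a transform where one is given
    kwargs = {name: transforms.get(name, lambda x: x)(p) for name, p in params.items()}
    return dist_cls(**kwargs)

log_std = nn.Parameter(torch.Tensor([1]))
mean = nn.Parameter(torch.Tensor([1]))
params = {'loc': mean, 'scale': log_std} # raw parameters keyed by constructor argument
transforms = {'scale': torch.exp} # only the scale needs a transform

optim = torch.optim.SGD([log_std, mean], lr=0.001)
target = dd.Normal(5, 5)

for i in range(50):
    optim.zero_grad()
    dist = make_dist(dd.Normal, params, transforms) # still has to be rebuilt inside the loop
    samples = dist.rsample((1000,))
    cost = -(target.log_prob(samples) - dist.log_prob(samples)).sum()
    cost.backward()
    optim.step()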