我正在训练一个火炬模型,我想在其中冻结(然后解冻)某些参数。我的印象是,简单的设置
param.requires_grad = False
就可以做到这一点。对于有动力的优化器来说,情况似乎并非如此。我知道我可以实例化一个新的优化器或更改现有优化器的参数,但两者都不允许我(轻松地)解冻参数并且不保留对优化器最初更改的所有参数的额外引用。
我认为可以通过将优化器状态下的momentum_buffer设置为零来实现预期的结果,但我不知道如何做到这一点,因为它不容易访问。
下面的代码可用于重现效果,两个已知的“解决方案”都被注释掉了。
import torch
import torch.nn as nn
import torchvision
from tqdm import tqdm
class Flatten(nn.Module):
def __init__(self):
super(Flatten, self).__init__()
def forward(self, x):
return x.view((x.size()[0], -1))
def main():
data = torchvision.datasets.MNIST("./data",download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,))
]))
data_loader = torch.utils.data.DataLoader(data,
batch_size=1000,
shuffle=True)
net=nn.Sequential(*[Flatten(),
nn.Linear(28*28,100),
nn.ReLU(),
nn.Linear(100,10)])
opt=torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
for e in range(2):
old_params = [p.clone() for p in net.parameters()]
if e == 1:
for j,p in enumerate(net.parameters()):
if j<2:
p.requires_grad = False
# opt.param_groups[0]['params'] = opt.param_groups[0]['params'][2:]
# opt = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
for data, label in tqdm(data_loader):
loss=torch.nn.functional.cross_entropy(net(data),label)
opt.zero_grad()
loss.backward()
opt.step()
print(loss)
new_params=[p.clone() for p in net.parameters()]
change = [(~(p1 == p2).all()).item() for p1, p2 in zip(old_params, new_params)]
print("Epoch: %d \t params changed: %s" % (e, change))
print([p.requires_grad for p in net.parameters()])
if __name__ == '__main__':
main()