2

我正在训练一个火炬模型,我想在其中冻结(然后解冻)某些参数。我的印象是,简单的设置 param.requires_grad = False 就可以做到这一点。对于有动力的优化器来说,情况似乎并非如此。我知道我可以实例化一个新的优化器或更改现有优化器的参数,但两者都不允许我(轻松地)解冻参数并且不保留对优化器最初更改的所有参数的额外引用。

我认为可以通过将优化器状态下的momentum_buffer设置为零来实现预期的结果,但我不知道如何做到这一点,因为它不容易访问。

下面的代码可用于重现效果,两个已知的“解决方案”都被注释掉了。

import torch
import torch.nn as nn
import torchvision
from tqdm import tqdm


class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        return x.view((x.size()[0], -1))

def main():
    data = torchvision.datasets.MNIST("./data",download=True,
                                       transform=torchvision.transforms.Compose([
                                       torchvision.transforms.ToTensor(),
                                       torchvision.transforms.Normalize((0.1307,), (0.3081,))
                                        ]))
    data_loader = torch.utils.data.DataLoader(data,
                                              batch_size=1000,
                                              shuffle=True)

    net=nn.Sequential(*[Flatten(),
                    nn.Linear(28*28,100),
                    nn.ReLU(),
                    nn.Linear(100,10)])
    opt=torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

    for e in range(2):
        old_params = [p.clone() for p in net.parameters()]
        if e == 1:
            for j,p in enumerate(net.parameters()):
                if j<2:
                    p.requires_grad = False
            # opt.param_groups[0]['params'] = opt.param_groups[0]['params'][2:]

        # opt = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

        for data, label in tqdm(data_loader):
            loss=torch.nn.functional.cross_entropy(net(data),label)
            opt.zero_grad()
            loss.backward()
            opt.step()
        print(loss)

        new_params=[p.clone() for p in net.parameters()]
        change = [(~(p1 == p2).all()).item() for p1, p2 in zip(old_params, new_params)]
        print("Epoch: %d \t params changed: %s" % (e, change))
        print([p.requires_grad for p in net.parameters()])


if __name__ == '__main__':
    main()
4

0 回答 0