3

gen.state_dict()我试图使用 MNIST 数据集训练 DCGAN 模型,但在完成训练后无法加载。

import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
import torchvision
import os
from torch.autograd import Variable

workspace_dir = '/content/drive/My Drive/practice'
device=torch.device('cuda' if torch.cuda.is_available else 'cpu')
print(device)

img_size=64
channel_img=1
lr=2e-4
batch_size=128
z_dim=100
epochs=10
features_gen=64
features_disc=64
save_dir = os.path.join(workspace_dir, 'logs')
os.makedirs(save_dir, exist_ok=True)
import matplotlib.pyplot as plt
transforms=transforms.Compose([transforms.Resize(img_size),transforms.ToTensor(),transforms.Normalize(mean=(0.5,),std=(0.5,))])
train_data=datasets.MNIST(root='dataset/',train=True,transform=transforms,download=True)
train_loader=torch.utils.data.DataLoader(train_data,batch_size=batch_size,shuffle=True)
count=0
for x,y in train_loader:
  if count==5:
    break
  print(x.shape,y.shape)
  count+=1

class Discriminator(nn.Module):
  def __init__(self,channels_img,features_d):
    super(Discriminator,self).__init__()
    
    self.disc=nn.Sequential(
        #input:N * channels_img * 64 *64
        nn.Conv2d(channels_img,features_d,4,2,1),#paper didn't use batchnorm in the early layers in the discriminator features_d* 32 *32
        nn.LeakyReLU(0.2),
        self._block(features_d,features_d*2,4,2,1),#features_d*2 *16 *16
        self._block(features_d*2,features_d*4,4,2,1),#features_d*4 *8 *8
        self._block(features_d*4,features_d*8,4,2,1), #features_d*8 *4 *4
        nn.Conv2d(features_d*8,1,4,2,0),#1 * 1 *1
        nn.Sigmoid()

    )
  def _block(self,in_channels,out_channels,kernel_size,stride,padding):
    return nn.Sequential(
        nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,bias=False),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.2)
    )


  def forward(self,x):
    return self.disc(x)

class Generator(nn.Module):
  def __init__(self,Z_dim,channels_img,features_g):
    super(Generator,self).__init__()
    
    self.gen=nn.Sequential(
        #input :n * z_dim * 1 *1
        self._block(Z_dim,features_g*16,4,1,0),#features_g*16 * 4 * 4
        self._block(features_g*16,features_g*8,4,2,1),#features_g*8 * 8 * 8
        self._block(features_g*8,features_g*4,4,2,1),#features_g*4 * 16 * 16
        self._block(features_g*4,features_g*2,4,2,1),#features_g*2 * 32 * 32
        nn.ConvTranspose2d(features_g*2,channels_img,4,2,1), #
        nn.Tanh()# [-1 to 1] normalize the image
    )
  def _block(self,in_channels,out_channels,kernel_size,stride,padding):
    return nn.Sequential(
          nn.ConvTranspose2d(in_channels,out_channels,kernel_size,stride,padding,bias=False),#w'=(w-1)*s-2p+k
          nn.BatchNorm2d(out_channels),
          nn.ReLU()
      )
    
  def forward(self,x):
      return self.gen(x)


def initialize_weights(model):
  for m in model.modules():
    if isinstance(m,(nn.Conv2d,nn.ConvTranspose2d,nn.BatchNorm2d)):
      nn.init.normal_(m.weight.data,0.0,0.02)

gen=Generator(z_dim,channel_img,features_gen).to(device)
disc=Discriminator(channel_img,features_disc).to(device)
initialize_weights(gen)
initialize_weights(disc)
opt_gen=torch.optim.Adam(gen.parameters(),lr=lr,betas=(0.5,0.999))
opt_disc=torch.optim.Adam(disc.parameters(),lr=lr,betas=(0.5,0.999))
criterion=nn.BCELoss()
#fixed_noise=torch.randn(32,z_dim,1,1).to(device)
#writer_real=SummaryWriter(f"logs/real")
#writer_fake=SummaryWriter(f"logs/fake")
step=0
gen.train()
disc.train()


z_sample = Variable(torch.randn(100, z_dim,1,1)).cuda()
for epoch in range(2):
  for batch_idx,(real,_) in enumerate(train_loader):
    real=real.to(device)
    noise=torch.randn((batch_size,z_dim,1,1)).to(device)
    fake=gen(noise)
    
    #Train Discriminator max log(D(x)) + log(1-D(G(z)))
    disc_real=disc(real).reshape(-1)
    loss_disc_real=criterion(disc_real,torch.ones_like(disc_real))
    disc_fake=disc(fake).reshape(-1)
    loss_disc_fake=criterion(disc_fake,torch.zeros_like(disc_fake))
    loss_disc=(loss_disc_fake+loss_disc_real)/2
    disc.zero_grad()
    loss_disc.backward(retain_graph=True)
    opt_disc.step()

    #Train Generator  min log(1-D(G(z))) <--> max log(D(G(z)))
    output=disc(fake).reshape(-1)
    loss_gen=criterion(output,torch.ones_like(output))
    gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()

    
    print(f'\rEpoch [{epoch+1}/{3}] {batch_idx+1}/{len(train_loader)} Loss_D: {loss_disc.item():.4f} Loss_G: {loss_gen.item():.4f}', end='')
  gen.eval()
  f_imgs_sample = (gen(z_sample).data + 1) / 2.0
  filename = os.path.join(save_dir, f'Epoch_{epoch+1:03d}.jpg')
  torchvision.utils.save_image(f_imgs_sample, filename, nrow=10)
  print(f' | Save some samples to {filename}.')
  # show generated image
  grid_img = torchvision.utils.make_grid(f_imgs_sample.cpu(), nrow=10)
  plt.figure(figsize=(10,10))
  plt.imshow(grid_img.permute(1, 2, 0))
  plt.show()
  gen.train()
  
  torch.save(gen.state_dict(), os.path.join(workspace_dir, f'dcgan_d.pth'))
  torch.save(disc.state_dict(), os.path.join(workspace_dir, f'dcgan_g.pth'))
  

我无法在此步骤中加载 gen state_dict:

# load pretrained model
#gen = Generator(z_dim,1,64)
gen=Generator(z_dim,channel_img,features_gen).to(device)
gen.load_state_dict(torch.load('/content/drive/My Drive/practice/dcgan_g.pth'))
gen.eval()
gen.cuda()

这是错误:

RuntimeError                              Traceback (most recent call last)
<ipython-input-18-4bda27faa444> in <module>()
      5 
      6 #gen.load_state_dict(torch.load('/content/drive/My Drive/practice/dcgan_g.pth'))
----> 7 gen.load_state_dict(torch.load(os.path.join(workspace_dir, 'dcgan_g.pth')))
      8 #/content/drive/My Drive/practice/dcgan_g.pth
      9 gen.eval()

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1050         if len(error_msgs) > 0:
   1051             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1052                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1053         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1054 

***RuntimeError: Error(s) in loading state_dict for Generator:
    Missing key(s) in state_dict***: "gen.0.0.weight", "gen.0.1.weight", "gen.0.1.bias", "gen.0.1.running_mean", "gen.0.1.running_var", "gen.1.0.weight", "gen.1.1.weight", "gen.1.1.bias", "gen.1.1.running_mean", "gen.1.1.running_var", "gen.2.0.weight", "gen.2.1.weight", "gen.2.1.bias", "gen.2.1.running_mean", "gen.2.1.running_var", "gen.3.0.weight", "gen.3.1.weight", "gen.3.1.bias", "gen.3.1.running_mean", "gen.3.1.running_var", "gen.4.weight", "gen.4.bias". 
    Unexpected key(s) in state_dict: "disc.0.weight", "disc.0.bias", "disc.2.0.weight", "disc.2.1.weight", "disc.2.1.bias", "disc.2.1.running_mean", "disc.2.1.running_var", "disc.2.1.num_batches_tracked", "disc.3.0.weight", "disc.3.1.weight", "disc.3.1.bias", "disc.3.1.running_mean", "disc.3.1.running_var", "disc.3.1.num_batches_tracked", "disc.4.0.weight", "disc.4.1.weight", "disc.4.1.bias", "disc.4.1.running_mean", "disc.4.1.running_var", "disc.4.1.num_batches_tracked", "disc.5.weight", "disc.5.bias".
4

1 回答 1

3

你用错误的名字保存了权重。也就是说,您将生成器的权重dcgan_d.pth保存为,同样,将鉴别器的权重保存为dcgan_g.pth

  torch.save(gen.state_dict(), os.path.join(workspace_dir, f'dcgan_d.pth')) # should have been dcgan_g.pth
  torch.save(disc.state_dict(), os.path.join(workspace_dir, f'dcgan_g.pth')) # should have been dcgan_d.pth

因此在加载时,您尝试加载错误的权重:

gen.load_state_dict(torch.load('/content/drive/My Drive/practice/dcgan_g.pth'))

dcgan_g.pth包含鉴别器权重而不是您的生成器。保存时首先修复错误的名称。其次,只需相应地重命名它们就可以了。

于 2020-11-09T05:12:15.010 回答