gen.state_dict()
我试图使用 MNIST 数据集训练 DCGAN 模型,但在完成训练后无法加载。
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
import torchvision
import os
from torch.autograd import Variable
workspace_dir = '/content/drive/My Drive/practice'
device=torch.device('cuda' if torch.cuda.is_available else 'cpu')
print(device)
img_size=64
channel_img=1
lr=2e-4
batch_size=128
z_dim=100
epochs=10
features_gen=64
features_disc=64
save_dir = os.path.join(workspace_dir, 'logs')
os.makedirs(save_dir, exist_ok=True)
import matplotlib.pyplot as plt
transforms=transforms.Compose([transforms.Resize(img_size),transforms.ToTensor(),transforms.Normalize(mean=(0.5,),std=(0.5,))])
train_data=datasets.MNIST(root='dataset/',train=True,transform=transforms,download=True)
train_loader=torch.utils.data.DataLoader(train_data,batch_size=batch_size,shuffle=True)
count=0
for x,y in train_loader:
if count==5:
break
print(x.shape,y.shape)
count+=1
class Discriminator(nn.Module):
def __init__(self,channels_img,features_d):
super(Discriminator,self).__init__()
self.disc=nn.Sequential(
#input:N * channels_img * 64 *64
nn.Conv2d(channels_img,features_d,4,2,1),#paper didn't use batchnorm in the early layers in the discriminator features_d* 32 *32
nn.LeakyReLU(0.2),
self._block(features_d,features_d*2,4,2,1),#features_d*2 *16 *16
self._block(features_d*2,features_d*4,4,2,1),#features_d*4 *8 *8
self._block(features_d*4,features_d*8,4,2,1), #features_d*8 *4 *4
nn.Conv2d(features_d*8,1,4,2,0),#1 * 1 *1
nn.Sigmoid()
)
def _block(self,in_channels,out_channels,kernel_size,stride,padding):
return nn.Sequential(
nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,bias=False),
nn.BatchNorm2d(out_channels),
nn.LeakyReLU(0.2)
)
def forward(self,x):
return self.disc(x)
class Generator(nn.Module):
def __init__(self,Z_dim,channels_img,features_g):
super(Generator,self).__init__()
self.gen=nn.Sequential(
#input :n * z_dim * 1 *1
self._block(Z_dim,features_g*16,4,1,0),#features_g*16 * 4 * 4
self._block(features_g*16,features_g*8,4,2,1),#features_g*8 * 8 * 8
self._block(features_g*8,features_g*4,4,2,1),#features_g*4 * 16 * 16
self._block(features_g*4,features_g*2,4,2,1),#features_g*2 * 32 * 32
nn.ConvTranspose2d(features_g*2,channels_img,4,2,1), #
nn.Tanh()# [-1 to 1] normalize the image
)
def _block(self,in_channels,out_channels,kernel_size,stride,padding):
return nn.Sequential(
nn.ConvTranspose2d(in_channels,out_channels,kernel_size,stride,padding,bias=False),#w'=(w-1)*s-2p+k
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
def forward(self,x):
return self.gen(x)
def initialize_weights(model):
for m in model.modules():
if isinstance(m,(nn.Conv2d,nn.ConvTranspose2d,nn.BatchNorm2d)):
nn.init.normal_(m.weight.data,0.0,0.02)
gen=Generator(z_dim,channel_img,features_gen).to(device)
disc=Discriminator(channel_img,features_disc).to(device)
initialize_weights(gen)
initialize_weights(disc)
opt_gen=torch.optim.Adam(gen.parameters(),lr=lr,betas=(0.5,0.999))
opt_disc=torch.optim.Adam(disc.parameters(),lr=lr,betas=(0.5,0.999))
criterion=nn.BCELoss()
#fixed_noise=torch.randn(32,z_dim,1,1).to(device)
#writer_real=SummaryWriter(f"logs/real")
#writer_fake=SummaryWriter(f"logs/fake")
step=0
gen.train()
disc.train()
z_sample = Variable(torch.randn(100, z_dim,1,1)).cuda()
for epoch in range(2):
for batch_idx,(real,_) in enumerate(train_loader):
real=real.to(device)
noise=torch.randn((batch_size,z_dim,1,1)).to(device)
fake=gen(noise)
#Train Discriminator max log(D(x)) + log(1-D(G(z)))
disc_real=disc(real).reshape(-1)
loss_disc_real=criterion(disc_real,torch.ones_like(disc_real))
disc_fake=disc(fake).reshape(-1)
loss_disc_fake=criterion(disc_fake,torch.zeros_like(disc_fake))
loss_disc=(loss_disc_fake+loss_disc_real)/2
disc.zero_grad()
loss_disc.backward(retain_graph=True)
opt_disc.step()
#Train Generator min log(1-D(G(z))) <--> max log(D(G(z)))
output=disc(fake).reshape(-1)
loss_gen=criterion(output,torch.ones_like(output))
gen.zero_grad()
loss_gen.backward()
opt_gen.step()
print(f'\rEpoch [{epoch+1}/{3}] {batch_idx+1}/{len(train_loader)} Loss_D: {loss_disc.item():.4f} Loss_G: {loss_gen.item():.4f}', end='')
gen.eval()
f_imgs_sample = (gen(z_sample).data + 1) / 2.0
filename = os.path.join(save_dir, f'Epoch_{epoch+1:03d}.jpg')
torchvision.utils.save_image(f_imgs_sample, filename, nrow=10)
print(f' | Save some samples to {filename}.')
# show generated image
grid_img = torchvision.utils.make_grid(f_imgs_sample.cpu(), nrow=10)
plt.figure(figsize=(10,10))
plt.imshow(grid_img.permute(1, 2, 0))
plt.show()
gen.train()
torch.save(gen.state_dict(), os.path.join(workspace_dir, f'dcgan_d.pth'))
torch.save(disc.state_dict(), os.path.join(workspace_dir, f'dcgan_g.pth'))
我无法在此步骤中加载 gen state_dict:
# load pretrained model
#gen = Generator(z_dim,1,64)
gen=Generator(z_dim,channel_img,features_gen).to(device)
gen.load_state_dict(torch.load('/content/drive/My Drive/practice/dcgan_g.pth'))
gen.eval()
gen.cuda()
这是错误:
RuntimeError Traceback (most recent call last)
<ipython-input-18-4bda27faa444> in <module>()
5
6 #gen.load_state_dict(torch.load('/content/drive/My Drive/practice/dcgan_g.pth'))
----> 7 gen.load_state_dict(torch.load(os.path.join(workspace_dir, 'dcgan_g.pth')))
8 #/content/drive/My Drive/practice/dcgan_g.pth
9 gen.eval()
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
1050 if len(error_msgs) > 0:
1051 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1052 self.__class__.__name__, "\n\t".join(error_msgs)))
1053 return _IncompatibleKeys(missing_keys, unexpected_keys)
1054
***RuntimeError: Error(s) in loading state_dict for Generator:
Missing key(s) in state_dict***: "gen.0.0.weight", "gen.0.1.weight", "gen.0.1.bias", "gen.0.1.running_mean", "gen.0.1.running_var", "gen.1.0.weight", "gen.1.1.weight", "gen.1.1.bias", "gen.1.1.running_mean", "gen.1.1.running_var", "gen.2.0.weight", "gen.2.1.weight", "gen.2.1.bias", "gen.2.1.running_mean", "gen.2.1.running_var", "gen.3.0.weight", "gen.3.1.weight", "gen.3.1.bias", "gen.3.1.running_mean", "gen.3.1.running_var", "gen.4.weight", "gen.4.bias".
Unexpected key(s) in state_dict: "disc.0.weight", "disc.0.bias", "disc.2.0.weight", "disc.2.1.weight", "disc.2.1.bias", "disc.2.1.running_mean", "disc.2.1.running_var", "disc.2.1.num_batches_tracked", "disc.3.0.weight", "disc.3.1.weight", "disc.3.1.bias", "disc.3.1.running_mean", "disc.3.1.running_var", "disc.3.1.num_batches_tracked", "disc.4.0.weight", "disc.4.1.weight", "disc.4.1.bias", "disc.4.1.running_mean", "disc.4.1.running_var", "disc.4.1.num_batches_tracked", "disc.5.weight", "disc.5.bias".