我正在尝试构建一个基本的 GAN 来熟悉 Pytorch。我对 Keras 有一些(有限的)经验,但由于我必须在 Pytorch 中做一个更大的项目,所以我想首先使用“基本”网络进行探索。
我正在使用 Pytorch 闪电。我想我已经添加了所有必要的组件。我尝试分别通过生成器和鉴别器传递一些噪声,我认为输出具有预期的形状。尽管如此,当我尝试训练 GAN 时出现运行时错误(下面的完整回溯):
RuntimeError: mat1 and mat2 shapes cannot be multiplied (7x9 and 25x1)
我注意到 7 是批次的大小(通过打印出批次尺寸),即使我将 batch_size 指定为 64。除此之外,老实说,我不知道从哪里开始:错误回溯没有帮我。
很可能,我犯了多个错误。但是,我希望你们中的一些人能够从代码中发现当前的错误,因为乘法错误似乎指向某个地方的维度问题。这是代码。
import os
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from skimage import io
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.utils import make_grid
from torchvision.transforms import Resize, ToTensor, ToPILImage, Normalize
class DoppelDataset(Dataset):
"""
Dataset class for face data
"""
def __init__(self, face_dir: str, transform=None):
self.face_dir = face_dir
self.face_paths = os.listdir(face_dir)
self.transform = transform
def __len__(self):
return len(self.face_paths)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
face_path = os.path.join(self.face_dir, self.face_paths[idx])
face = io.imread(face_path)
sample = {'image': face}
if self.transform:
sample = self.transform(sample['image'])
return sample
class DoppelDataModule(pl.LightningDataModule):
def __init__(self, data_dir='../data/faces', batch_size: int = 64, num_workers: int = 0):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.transforms = transforms.Compose([
ToTensor(),
Resize(100),
Normalize(mean=(123.26290927634774, 95.90498110733365, 86.03763122875182),
std=(63.20679012922922, 54.86211954409834, 52.31266645797249))
])
def setup(self, stage=None):
# Initialize dataset
doppel_data = DoppelDataset(face_dir=self.data_dir, transform=self.transforms)
# Train/val/test split
n = len(doppel_data)
train_size = int(.8 * n)
val_size = int(.1 * n)
test_size = n - (train_size + val_size)
self.train_data, self.val_data, self.test_data = random_split(dataset=doppel_data,
lengths=[train_size, val_size, test_size])
def train_dataloader(self) -> DataLoader:
return DataLoader(dataset=self.test_data, batch_size=self.batch_size, num_workers=self.num_workers)
def val_dataloader(self) -> DataLoader:
return DataLoader(dataset=self.val_data, batch_size=self.batch_size, num_workers=self.num_workers)
def test_dataloader(self) -> DataLoader:
return DataLoader(dataset=self.test_data, batch_size=self.batch_size, num_workers=self.num_workers)
class DoppelGenerator(nn.Sequential):
"""
Generator network that produces images based on latent vector
"""
def __init__(self, latent_dim: int):
super().__init__()
def block(in_channels: int, out_channels: int, padding: int = 1, stride: int = 2, bias=False):
return nn.Sequential(
nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=stride,
padding=padding, bias=bias),
nn.BatchNorm2d(num_features=out_channels),
nn.ReLU(True)
)
self.model = nn.Sequential(
block(latent_dim, 512, padding=0, stride=1),
block(512, 256),
block(256, 128),
block(128, 64),
block(64, 32),
nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1, bias=False),
nn.Tanh()
)
def forward(self, input):
return self.model(input)
class DoppelDiscriminator(nn.Sequential):
"""
Discriminator network that classifies images in two categories
"""
def __init__(self):
super().__init__()
def block(in_channels: int, out_channels: int):
return nn.Sequential(
nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1,
bias=False),
nn.BatchNorm2d(num_features=out_channels),
nn.LeakyReLU(0.2, inplace=True),
)
self.model = nn.Sequential(
block(3, 64),
block(64, 128),
block(128, 256),
block(256, 512),
nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
nn.Flatten(),
nn.Linear(25, 1),
nn.Sigmoid()
)
def forward(self, input):
return self.model(input)
class DoppelGAN(pl.LightningModule):
def __init__(self,
channels: int,
width: int,
height: int,
lr: float = 0.0002,
b1: float = 0.5,
b2: float = 0.999,
batch_size: int = 64,
**kwargs):
super().__init__()
# Save all keyword arguments as hyperparameters, accessible through self.hparams.X)
self.save_hyperparameters()
# Initialize networks
# data_shape = (channels, width, height)
self.generator = DoppelGenerator(latent_dim=self.hparams.latent_dim, )
self.discriminator = DoppelDiscriminator()
self.validation_z = torch.randn(8, self.hparams.latent_dim,1,1)
def forward(self, input):
return self.generator(input)
def adversarial_loss(self, y_hat, y):
return F.binary_cross_entropy(y_hat, y)
def training_step(self, batch, batch_idx, optimizer_idx):
images = batch
# Sample noise (batch_size, latent_dim,1,1)
z = torch.randn(images.size(0), self.hparams.latent_dim,1,1)
# Train generator
if optimizer_idx == 0:
# Generate images (call generator -- see forward -- on latent vector)
self.generated_images = self(z)
# Log sampled images (visualize what the generator comes up with)
sample_images = self.generated_images[:6]
grid = make_grid(sample_images)
self.logger.experiment.add_image('generated_images', grid, 0)
# Ground truth result (ie: all fake)
valid = torch.ones(images.size(0), 1)
# Adversarial loss is binary cross-entropy
generator_loss = self.adversarial_loss(self.discriminator(self(z)), valid)
tqdm_dict = {'gen_loss': generator_loss}
output = {
'loss': generator_loss,
'progress_bar': tqdm_dict,
'log': tqdm_dict
}
return output
# Train discriminator: classify real from generated samples
if optimizer_idx == 1:
# How well can it label as real?
valid = torch.ones(images.size(0), 1)
real_loss = self.adversarial_loss(self.discriminator(images), valid)
# How well can it label as fake?
fake = torch.zeros(images.size(0), 1)
fake_loss = self.adversarial_loss(
self.discriminator(self(z).detach()), fake)
# Discriminator loss is the average of these
discriminator_loss = (real_loss + fake_loss) / 2
tqdm_dict = {'d_loss': discriminator_loss}
output = {
'loss': discriminator_loss,
'progress_bar': tqdm_dict,
'log': tqdm_dict
}
return output
def configure_optimizers(self):
lr = self.hparams.lr
b1 = self.hparams.b1
b2 = self.hparams.b2
# Optimizers
opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
# Return optimizers/schedulers (currently no scheduler)
return [opt_g, opt_d], []
def on_epoch_end(self):
# Log sampled images
sample_images = self(self.validation_z)
grid = make_grid(sample_images)
self.logger.experiment.add_image('generated_images', grid, self.current_epoch)
if __name__ == '__main__':
# Global parameter
image_dim = 128
latent_dim = 100
batch_size = 64
# Initialize dataset
tfs = transforms.Compose([
ToPILImage(),
Resize(image_dim),
ToTensor()
])
doppel_dataset = DoppelDataset(face_dir='../data/faces', transform=tfs)
# Initialize data module
doppel_data_module = DoppelDataModule(batch_size=batch_size)
# Build models
generator = DoppelGenerator(latent_dim=latent_dim)
discriminator = DoppelDiscriminator()
# Test generator
x = torch.rand(batch_size, latent_dim, 1, 1)
y = generator(x)
print(f'Generator: x {x.size()} --> y {y.size()}')
# Test discriminator
x = torch.rand(batch_size, 3, 128, 128)
y = discriminator(x)
print(f'Discriminator: x {x.size()} --> y {y.size()}')
# Build GAN
doppelgan = DoppelGAN(batch_size=batch_size, channels=3, width=image_dim, height=image_dim, latent_dim=latent_dim)
# Fit GAN
trainer = pl.Trainer(gpus=0, max_epochs=5, progress_bar_refresh_rate=1)
trainer.fit(model=doppelgan, datamodule=doppel_data_module)
完整追溯:
Traceback (most recent call last):
File "/usr/local/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3437, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-2-28805d67d74b>", line 1, in <module>
runfile('/Users/wouter/Documents/OneDrive/Hardnose/Projects/Coding/0002_DoppelGANger/doppelganger/gan.py', wdir='/Users/wouter/Documents/OneDrive/Hardnose/Projects/Coding/0002_DoppelGANger/doppelganger')
File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
pydev_imports.execfile(filename, global_vars, local_vars) # execute the script
File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "/Users/wouter/Documents/OneDrive/Hardnose/Projects/Coding/0002_DoppelGANger/doppelganger/gan.py", line 298, in <module>
trainer.fit(model=doppelgan, datamodule=doppel_data_module)
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 510, in fit
results = self.accelerator_backend.train()
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py", line 57, in train
return self.train_or_test()
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py", line 74, in train_or_test
results = self.trainer.train()
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 561, in train
self.train_loop.run_training_epoch()
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py", line 550, in run_training_epoch
batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py", line 718, in run_training_batch
self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py", line 485, in optimizer_step
model_ref.optimizer_step(
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/core/lightning.py", line 1298, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py", line 286, in step
self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py", line 144, in __optimizer_step
optimizer.step(closure=closure, *args, **kwargs)
File "/usr/local/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
return func(*args, **kwargs)
File "/usr/local/lib/python3.9/site-packages/torch/optim/adam.py", line 66, in step
loss = closure()
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py", line 708, in train_step_and_backward_closure
result = self.training_step_and_backward(
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py", line 806, in training_step_and_backward
result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py", line 319, in training_step
training_step_output = self.trainer.accelerator_backend.training_step(args)
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/accelerators/cpu_accelerator.py", line 62, in training_step
return self._step(self.trainer.model.training_step, args)
File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/accelerators/cpu_accelerator.py", line 58, in _step
output = model_step(*args)
File "/Users/wouter/Documents/OneDrive/Hardnose/Projects/Coding/0002_DoppelGANger/doppelganger/gan.py", line 223, in training_step
real_loss = self.adversarial_loss(self.discriminator(images), valid)
File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/Users/wouter/Documents/OneDrive/Hardnose/Projects/Coding/0002_DoppelGANger/doppelganger/gan.py", line 154, in forward
return self.model(input)
File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/container.py", line 117, in forward
input = module(input)
File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 93, in forward
return F.linear(input, self.weight, self.bias)
File "/usr/local/lib/python3.9/site-packages/torch/nn/functional.py", line 1690, in linear
ret = torch.addmm(bias, input, weight.t())
RuntimeError: mat1 and mat2 shapes cannot be multiplied (7x9 and 25x1)