I am implementing a DCGAN in LibTorch/PyTorch, following the official example at https://github.com/pytorch/examples/blob/master/cpp/dcgan/dcgan.cpp. The only differences between my setup and the example are:
- my dataset consists of RGB images (the CelebA dataset), while the example uses grayscale images (MNIST);
- my images are 64x64, while MNIST images are 28x28.

Here is my code:
#include <torch/torch.h>
#include <cmath>
#include <cstdio>
#include <iostream>
#include "CustomDataset.h"
#include "parameters.h"
// The size of the noise vector fed to the generator.
const int64_t kNoiseSize = 100;
// The batch size for training.
const int64_t kBatchSize = 64;
// The number of epochs to train.
const int64_t kNumberOfEpochs = 30;
// Where to find the dataset.
const char* kDataFolder = "./data";
// After how many batches to create a new checkpoint periodically.
const int64_t kCheckpointEvery = 20;
// How many images to sample at every checkpoint.
const int64_t kNumberOfSamplesPerCheckpoint = 10;
// After how many batches to log a new update with the loss value.
const int64_t kLogInterval = 10;
using namespace torch;
struct DCGANGeneratorImpl : nn::Module {
  DCGANGeneratorImpl(int kNoiseSize)
      : conv1(nn::ConvTranspose2dOptions(kNoiseSize, 256, 4)
                  .bias(false)),
        batch_norm1(256),
        conv2(nn::ConvTranspose2dOptions(256, 128, 4)
                  .stride(2)
                  .padding(1)
                  .bias(false)),
        batch_norm2(128),
        conv3(nn::ConvTranspose2dOptions(128, 64, 4)
                  .stride(2)
                  .padding(1)
                  .bias(false)),
        batch_norm3(64),
        conv4(nn::ConvTranspose2dOptions(64, 32, 4)
                  .stride(2)
                  .padding(1)
                  .bias(false)),
        batch_norm4(32),
        conv5(nn::ConvTranspose2dOptions(32, 3, 4)
                  .stride(2)
                  .padding(1)
                  .bias(false))
  {
    register_module("conv1", conv1);
    register_module("conv2", conv2);
    register_module("conv3", conv3);
    register_module("conv4", conv4);
    register_module("conv5", conv5);
    register_module("batch_norm1", batch_norm1);
    register_module("batch_norm2", batch_norm2);
    register_module("batch_norm3", batch_norm3);
    register_module("batch_norm4", batch_norm4);
  }

  torch::Tensor forward(torch::Tensor x)
  {
    x = torch::relu(batch_norm1(conv1(x)));  // output is 4x4
    x = torch::relu(batch_norm2(conv2(x)));  // output is 8x8
    x = torch::relu(batch_norm3(conv3(x)));  // output is 16x16
    x = torch::relu(batch_norm4(conv4(x)));  // output is 32x32
    x = torch::tanh(conv5(x));               // output is 3x64x64
    return x;
  }

  nn::ConvTranspose2d conv1, conv2, conv3, conv4, conv5;
  nn::BatchNorm2d batch_norm1, batch_norm2, batch_norm3, batch_norm4;
};

TORCH_MODULE(DCGANGenerator);
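To double-check the layer arithmetic, I derived each spatial size from the transposed-convolution formula out = (in - 1) * stride - 2 * padding + kernel, and the chain does work out to 64x64. This is just my own sanity-check snippet, not part of the training code:

// Sanity check of the generator's spatial sizes, starting from a 1x1 input:
// conv1: (1 - 1) * 1 - 2 * 0 + 4 = 4    ->   4x4, 256 channels
// conv2: (4 - 1) * 2 - 2 * 1 + 4 = 8    ->   8x8, 128 channels
// conv3: (8 - 1) * 2 - 2 * 1 + 4 = 16   -> 16x16,  64 channels
// conv4: (16 - 1) * 2 - 2 * 1 + 4 = 32  -> 32x32,  32 channels
// conv5: (32 - 1) * 2 - 2 * 1 + 4 = 64  -> 64x64,   3 channels
DCGANGenerator check(kNoiseSize);
torch::Tensor probe = check->forward(torch::randn({1, kNoiseSize, 1, 1}));
std::cout << probe.sizes() << std::endl;  // prints [1, 3, 64, 64]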
int main(int argc, const char* argv[]) {
  torch::manual_seed(1);

  // Create the device we pass around based on whether CUDA is available.
  torch::Device device(torch::kCPU);
  if (torch::cuda::is_available()) {
    std::cout << "CUDA is available! Training on GPU." << std::endl;
    device = torch::Device(torch::kCUDA);
  }

  DCGANGenerator generator(kNoiseSize);
  generator->to(device);

  nn::Sequential discriminator(
      // Layer 1
      nn::Conv2d(
          nn::Conv2dOptions(3, 64, 4).stride(2).padding(1).bias(false)),
      nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
      // output is 32x32
      // Layer 2
      nn::Conv2d(
          nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)),
      nn::BatchNorm2d(128),
      nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
      // output is 16x16
      // Layer 3
      nn::Conv2d(
          nn::Conv2dOptions(128, 64, 4).stride(2).padding(1).bias(false)),
      nn::BatchNorm2d(64),
      nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
      // output is 8x8
      // Layer 4
      nn::Conv2d(
          nn::Conv2dOptions(64, 32, 5).stride(1).padding(0).bias(false)),
      nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
      // output is 4x4
      // Layer 5
      nn::Conv2d(
          nn::Conv2dOptions(32, 1, 4).stride(1).padding(0).bias(false)),
      nn::Sigmoid());
  discriminator->to(device);

  // Where all my pictures are.
  std::string file_location{"dataset/img_align_celeba/*.jpg"};
  auto dataset = CustomDataset(file_location).map(data::transforms::Stack<>());

  const int64_t batches_per_epoch =
      std::ceil(dataset.size().value() / static_cast<double>(kBatchSize));

  auto data_loader = torch::data::make_data_loader(
      std::move(dataset),
      torch::data::DataLoaderOptions().batch_size(kBatchSize).workers(2));

  torch::optim::Adam generator_optimizer(
      generator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
  torch::optim::Adam discriminator_optimizer(
      discriminator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
  int64_t checkpoint_counter = 1;
  for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
    int64_t batch_index = 0;
    for (torch::data::Example<>& batch : *data_loader) {
      // Train discriminator with real images.
      discriminator->zero_grad();
      torch::Tensor real_images = batch.data.to(device);
      torch::Tensor real_labels =
          torch::empty(batch.data.size(0), device).uniform_(0.8, 1.0);
      torch::Tensor real_output = discriminator->forward(real_images);
      torch::Tensor d_loss_real =
          torch::binary_cross_entropy(real_output, real_labels);
      d_loss_real.backward();

      // Train discriminator with fake images.
      torch::Tensor noise =
          torch::randn({batch.data.size(0), kNoiseSize, 1, 1}, device);
      torch::Tensor fake_images = generator->forward(noise);
      torch::Tensor fake_labels = torch::zeros(batch.data.size(0), device);
      torch::Tensor fake_output = discriminator->forward(fake_images.detach());
      torch::Tensor d_loss_fake =
          torch::binary_cross_entropy(fake_output, fake_labels);
      d_loss_fake.backward();

      torch::Tensor d_loss = d_loss_real + d_loss_fake;
      discriminator_optimizer.step();

      // Train generator.
      generator->zero_grad();
      fake_labels.fill_(1);
      fake_output = discriminator->forward(fake_images);
      torch::Tensor g_loss =
          torch::binary_cross_entropy(fake_output, fake_labels);
      g_loss.backward();
      generator_optimizer.step();

      batch_index++;
      if (batch_index % kCheckpointEvery == 0) {
        // Checkpoint the model and optimizer state.
        torch::save(generator, "generator-checkpoint.pt");
        torch::save(generator_optimizer, "generator-optimizer-checkpoint.pt");
        torch::save(discriminator, "discriminator-checkpoint.pt");
        torch::save(
            discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
        // Sample the generator and save the images.
        torch::Tensor samples = generator->forward(torch::randn(
            {kNumberOfSamplesPerCheckpoint, kNoiseSize, 1, 1}, device));
        torch::save(
            samples,
            torch::str("dcgan-sample-", checkpoint_counter, ".pt"));
        std::cout << "\n-> checkpoint " << ++checkpoint_counter << '\n';
      }
    }
  }

  std::cout << "Training complete!" << std::endl;
}
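For completeness, CustomDataset is my own class. A trimmed sketch of what its get() does (the member names here are illustrative, and I am assuming OpenCV for the decoding; the real class may differ in details):

#include <opencv2/opencv.hpp>

// Hypothetical reconstruction of CustomDataset::get; files_ stands in for
// the member that holds the expanded list of jpg paths.
torch::data::Example<> get(size_t index) {
  cv::Mat image = cv::imread(files_[index]);           // BGR, HWC, uint8
  cv::resize(image, image, cv::Size(64, 64));
  torch::Tensor tensor =
      torch::from_blob(image.data, {64, 64, 3}, torch::kUInt8).clone();
  tensor = tensor.permute({2, 0, 1}).to(torch::kFloat32);  // HWC -> CHW
  tensor = tensor.div(127.5).sub(1.0);                     // scale to [-1, 1]
  return {tensor, torch::zeros(1)};                        // label unused
}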
Every so often I save a minibatch and plot the result of feeding noise to the generator. The problem is that with the MNIST example the results are correct, whereas in my case, in every output picture I see 9 smaller, almost identical faces instead of a single one (see the attached image).

How can the generator output the correct shape, yet contain 9 nearly identical faces instead of one?
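To narrow this down, here is the shape check I run right after sampling. It is only a sketch; the comment about channel interleaving reflects my suspicion, not a confirmed cause:

// Check the raw tensor coming out of the generator: it should be one
// 3-channel 64x64 image per noise vector, laid out as CHW.
torch::NoGradGuard no_grad;
torch::Tensor sample =
    generator->forward(torch::randn({1, kNoiseSize, 1, 1}, device));
std::cout << sample.sizes() << std::endl;  // expect [1, 3, 64, 64]
// If the plotting side reads this buffer as HWC (height x width x channel)
// instead of CHW, the three color planes get re-tiled across the image,
// which could plausibly produce a mosaic of small repeated faces.
// Converting explicitly before plotting removes the ambiguity:
torch::Tensor hwc = sample.squeeze(0)        // [3, 64, 64]
                        .permute({1, 2, 0})  // [64, 64, 3], HWC
                        .contiguous();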