We can also simply use nn.Sequential(), for example with the following code snippet:
import torch
encoded_dim = 32
encoder = torch.nn.Sequential(       # 28x28 image -> 32-dim latent code
    torch.nn.Flatten(),              # (N, 28, 28) -> (N, 784)
    torch.nn.Linear(28*28, 256),
    torch.nn.Sigmoid(),
    torch.nn.Linear(256, 64),
    torch.nn.Sigmoid(),
    torch.nn.Linear(64, encoded_dim)
)
decoder = torch.nn.Sequential(       # 32-dim latent code -> 28x28 image
    torch.nn.Linear(encoded_dim, 64),
    torch.nn.Sigmoid(),
    torch.nn.Linear(64, 256),
    torch.nn.Sigmoid(),
    torch.nn.Linear(256, 28*28),
    torch.nn.Sigmoid(),              # pixel outputs in (0, 1), matching BCELoss below
    torch.nn.Unflatten(1, (28,28))   # (N, 784) -> (N, 28, 28)
)
autoencoder = torch.nn.Sequential(encoder, decoder)
autoencoder  # evaluating the name in a REPL/notebook prints the architecture:
# Sequential(
#   (0): Sequential(
#     (0): Flatten(start_dim=1, end_dim=-1)
#     (1): Linear(in_features=784, out_features=256, bias=True)
#     (2): Sigmoid()
#     (3): Linear(in_features=256, out_features=64, bias=True)
#     (4): Sigmoid()
#     (5): Linear(in_features=64, out_features=32, bias=True)
#   )
#   (1): Sequential(
#     (0): Linear(in_features=32, out_features=64, bias=True)
#     (1): Sigmoid()
#     (2): Linear(in_features=64, out_features=256, bias=True)
#     (3): Sigmoid()
#     (4): Linear(in_features=256, out_features=784, bias=True)
#     (5): Sigmoid()
#     (6): Unflatten(dim=1, unflattened_size=(28, 28))
#   )
# )
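As a quick sanity check (a minimal sketch, not part of the original snippet), we can push a random batch through the model and confirm that the bottleneck produces encoded_dim features and that the output recovers the input shape:

dummy = torch.rand(8, 28, 28)  # a hypothetical batch of eight 28x28 "images"
z = encoder(dummy)             # latent codes
x_hat = autoencoder(dummy)     # reconstructions
print(z.shape)      # torch.Size([8, 32]), i.e. (batch, encoded_dim)
print(x_hat.shape)  # torch.Size([8, 28, 28]), same shape as the input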
Example training with the MNIST data
Load the (MNIST) data with torchvision:
import torchvision

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data', train=True, download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            # ...
        ])),
    batch_size=64, shuffle=True)
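Each batch the loader yields has shape (batch, channel, height, width); a quick illustrative check (assuming the loader above):

x, y = next(iter(train_loader))
print(x.shape)  # torch.Size([64, 1, 28, 28]) -- ToTensor adds a channel dimension
print(y.shape)  # torch.Size([64]) -- the digit labels, unused by the autoencoder

This is why the training loop below squeezes out the channel dimension before feeding images to the model.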
Now let's train the autoencoder model. The optimizer used is Adam, although SGD would also work:
loss_fn = torch.nn.BCELoss()
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=1e-3, weight_decay=1e-5)
for epoch in range(10):
    for idx, (x, _) in enumerate(train_loader):
        x = x.squeeze(1)         # drop the channel dim: (N, 1, 28, 28) -> (N, 28, 28)
        x = x / x.max()          # ensure pixel values lie in [0, 1] for BCELoss
        x_pred = autoencoder(x)  # forward pass
        loss = loss_fn(x_pred, x)
        if idx % 1024 == 0:      # with ~938 batches per epoch, this prints once per epoch
            print(epoch, loss.item())
        optimizer.zero_grad()
        loss.backward()          # backward pass
        optimizer.step()
# epoch loss
# 0 0.702496349811554
# 1 0.24611620604991913
# 2 0.20603498816490173
# 3 0.1827092468738556
# 4 0.1805819869041443
# 5 0.16927748918533325
# 6 0.17275433242321014
# 7 0.15827134251594543
# 8 0.1635081171989441
# 9 0.15693898499011993
The animation below shows the autoencoder's reconstructions of a few randomly selected images at different epochs; note how the reconstructions of the MNIST digits improve as more and more epochs pass.
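A minimal sketch of how such reconstruction snapshots could be produced (matplotlib is assumed here; it is not part of the original example):

import matplotlib.pyplot as plt

autoencoder.eval()
with torch.no_grad():
    x, _ = next(iter(train_loader))
    x = x.squeeze(1)[:8]       # eight sample images, (8, 28, 28)
    x_pred = autoencoder(x)    # their reconstructions
fig, axes = plt.subplots(2, 8, figsize=(12, 3))
for i in range(8):
    axes[0, i].imshow(x[i], cmap='gray')       # original in the top row
    axes[1, i].imshow(x_pred[i], cmap='gray')  # reconstruction below it
    axes[0, i].axis('off')
    axes[1, i].axis('off')
plt.show()

Saving one such figure per epoch and stitching the frames together yields an animation like the one above.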