我正在尝试实现二进制分类。我有 100K(3 通道,224 x 224px 预调整大小)图像数据集,我正在尝试训练模型以判断图片是否可以安全工作。我是具有统计学家背景的数据工程师,所以我正在研究这个模型,就像过去 5-10 天一样。我试图根据建议实施解决方案,但不幸的是损失并没有减少。
这是使用 PyTorch Lightning 实现的类,
from .dataset import CloudDataset
from .split import DatasetSplit
from pytorch_lightning import LightningModule
from pytorch_lightning.metrics import Accuracy
from torch import stack
from torch.nn import BCEWithLogitsLoss, Conv2d, Dropout, Linear, MaxPool2d, ReLU
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torchvision.transforms import ToTensor
from util import logger
from util.config import config
class ClassifyModel(LightningModule):
def __init__(self):
super(ClassifyModel, self).__init__()
# custom dataset split class
ds = DatasetSplit(config.s3.bucket, config.train.ratio)
# split records for train, validation and test
self._train_itr, self._valid_itr, self._test_itr = ds.split()
self.conv1 = Conv2d(3, 32, 3, padding=1)
self.conv2 = Conv2d(32, 64, 3, padding=1)
self.conv3 = Conv2d(64, 64, 3, padding=1)
self.pool = MaxPool2d(2, 2)
self.fc1 = Linear(7 * 28 * 64, 512)
self.fc2 = Linear(512, 16)
self.fc3 = Linear(16, 4)
self.fc4 = Linear(4, 1)
self.dropout = Dropout(0.25)
self.relu = ReLU(inplace=True)
self.accuracy = Accuracy()
def forward(self, x):
# comments are shape before execution
# [32, 3, 224, 224]
x = self.pool(self.relu(self.conv1(x)))
# [32, 32, 112, 112]
x = self.pool(self.relu(self.conv2(x)))
# [32, 64, 56, 56]
x = self.pool(self.relu(self.conv3(x)))
# [32, 64, 28, 28]
x = self.pool(self.relu(self.conv3(x)))
# [32, 64, 14, 14]
x = self.dropout(x)
# [32, 64, 14, 14]
x = x.view(-1, 7 * 28 * 64)
# [32, 12544]
x = self.relu(self.fc1(x))
# [32, 512]
x = self.relu(self.fc2(x))
# [32, 16]
x = self.relu(self.fc3(x))
# [32, 4]
x = self.dropout(self.fc4(x))
# [32, 1]
x = x.squeeze(1)
# [32]
return x
def configure_optimizers(self):
return Adam(self.parameters(), lr=0.001)
def training_step(self, batch, batch_idx):
image, target = batch
target = target.float()
output = self.forward(image)
loss = BCEWithLogitsLoss()
output = loss(output, target)
logits = self(image)
self.accuracy(logits, target)
return {'loss': output}
def validation_step(self, batch, batch_idx):
image, target = batch
target = target.float()
output = self.forward(image)
loss = BCEWithLogitsLoss()
output = loss(output, target)
return {'val_loss': output}
def collate_fn(self, batch):
batch = list(filter(lambda x: x is not None, batch))
return default_collate(batch)
def train_dataloader(self):
transform = ToTensor()
workers = 0 if config.train.test else config.train.workers
# custom data set class that read files from s3
cds = CloudDataset(config.s3.bucket, self._train_itr, transform)
return DataLoader(
dataset=cds,
batch_size=32,
shuffle=True,
num_workers=workers,
collate_fn=self.collate_fn,
)
def val_dataloader(self):
transform = ToTensor()
workers = 0 if config.train.test else config.train.workers
# custom data set class that read files from s3
cds = CloudDataset(config.s3.bucket, self._valid_itr, transform)
return DataLoader(
dataset=cds,
batch_size=32,
num_workers=workers,
collate_fn=self.collate_fn,
)
def test_dataloader(self):
transform = ToTensor()
workers = 0 if config.train.test else config.train.workers
# custom data set class that read files from s3
cds = CloudDataset(config.s3.bucket, self._test_itr, transform)
return DataLoader(
dataset=cds,
batch_size=32,
shuffle=True,
num_workers=workers,
collate_fn=self.collate_fn,
)
def validation_epoch_end(self, outputs):
avg_loss = stack([x['val_loss'] for x in outputs]).mean()
logger.info(f'Validation loss is {avg_loss}')
def training_epoch_end(self, outs):
accuracy = self.accuracy.compute()
logger.info(f'Training accuracy is {accuracy}')
这是自定义日志输出,
epoch 0
Validation loss is 0.5988735556602478
Training accuracy is 0.4441356360912323
epoch 1
Validation loss is 0.6406065225601196
Training accuracy is 0.4441356360912323
epoch 2
Validation loss is 0.621654748916626
Training accuracy is 0.443579763174057
epoch 3
Validation loss is 0.5089989304542542
Training accuracy is 0.4580322504043579
epoch 4
Validation loss is 0.5484663248062134
Training accuracy is 0.4886047840118408
epoch 5
Validation loss is 0.5552918314933777
Training accuracy is 0.6142301559448242
epoch 6
Validation loss is 0.661466121673584
Training accuracy is 0.625903308391571
该问题可能与优化器或损失函数有关,但我无法弄清楚。