我的损失函数倾向于让模型把所有预测都输出为空白字符。今天，我为一个简单的 OCR 任务写了 CRNN + CTCLoss 的代码，但是效果不是很好。具体来说，模型输出的全部是空白符号（blank）所对应的数字索引。
# --- Hyperparameters -------------------------------------------------------
batch_size = 8
epoch = 10
learning_rate = 0.0005
momentum = 0.9

# --- Data ------------------------------------------------------------------
# Pull every image/label pair from the project loader and wrap them as tensors.
# NOTE(review): torch.Tensor(...) builds tensors of the default dtype (float32
# unless changed), so the labels are floats here; CTC targets must be integer
# typed -- presumably flat_labels() converts them later, confirm.
dataloader = ImageDataLoader()
image, label = dataloader.get_data()
image, label = torch.Tensor(image), torch.Tensor(label)

# Keep 10% of the samples aside as a test set.
X_train, X_test, y_train, y_test = train_test_split(image, label, test_size=0.1)

# Both loaders use the same batching options: fixed order, loading in the
# main process, and the last incomplete batch dropped.
_loader_options = dict(
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    drop_last=True,
)
train_dataloader = DataLoader(TensorDataset(X_train, y_train), **_loader_options)
test_dataloader = DataLoader(TensorDataset(X_test, y_test), **_loader_options)
# Number of output classes, as reported by the project data loader.
vocab_num = dataloader.vocab_scale
model = CRNN(32, 3, vocab_num, 256)

# CTC loss: the blank symbol is the last class index (vocab_num - 1);
# zero_infinity=True replaces infinite losses (and their gradients) with zero.
criticism = nn.CTCLoss(blank=vocab_num - 1, reduction="mean", zero_infinity=True)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    betas=(0.9, 0.999),
    eps=1e-04,
    weight_decay=0,
    amsgrad=False,
)

# Wrap the model for multi-GPU execution and place it on the target device.
model = nn.DataParallel(model).to(device)

# Move the split tensors to the target device as well.
# NOTE(review): Tensor.to() returns a tensor instead of modifying its argument,
# so DataLoaders that were already built from the earlier tensors presumably
# keep yielding the original (un-moved) data -- confirm whether these moves
# are still needed by later code.
X_train, X_test, y_train, y_test = (
    tensor.to(device) for tensor in (X_train, X_test, y_train, y_test)
)
# Training loop: for every epoch, run one optimisation step per batch of
# train_dataloader using the CTC loss, and log the loss of each batch.
# NOTE(review): the pasted source had lost its indentation; the nesting below
# (per-batch body, separator line printed once per epoch) is reconstructed --
# confirm it matches the intended structure.
for epo in range(epoch):
    # Training mode only has to be selected once per epoch, not per batch.
    model.train()

    for i, data in enumerate(train_dataloader):
        inputs, labels = data
        batch = len(labels)

        optimizer.zero_grad()
        outputs = model(inputs)

        # nn.CTCLoss expects log-probabilities laid out as (time, batch, class).
        # The permute presumably converts from (batch, time, class) -- confirm
        # against the CRNN's output layout. The result is already part of the
        # autograd graph, so no explicit requires_grad_() is needed.
        outputs = outputs.permute(1, 0, 2).log_softmax(2)

        # Every sample is declared to use all time steps of the output.
        inputs_length = torch.full((batch,), len(outputs), dtype=torch.int32)
        # Every sample is declared to have the length of the first label.
        # NOTE(review): this assumes all labels in the batch share one length;
        # if labels are padded, the padding is counted as real targets -- confirm.
        target_length = torch.full((batch,), len(labels[0]), dtype=torch.int32)

        tmp_labels = flat_labels(labels)
        loss = criticism(outputs, tmp_labels, inputs_length, target_length)
        loss.backward()
        optimizer.step()

        # .item() extracts the Python number; .data is discouraged.
        print(f"train: epoch {epo}, batch no.{i}, loss = {loss.item()}")

    print("*" * 50)