0

我的模型在 CTC 损失下倾向于把所有位置都预测为空白符(blank)。今天我为一个简单的 OCR 任务编写了 CRNN + CTCLoss 代码,但效果很差:模型输出的全部是空白符对应的索引。

(原帖此处为一张图片,图片描述缺失。)

# Train a CRNN text recognizer with CTC loss.
#
# NOTE(review): ImageDataLoader, CRNN, flat_labels, train_test_split, DataLoader,
# TensorDataset, nn, torch and `device` come from elsewhere in this file.

batch_size = 8
epoch = 10
learning_rate = 0.0005
momentum = 0.9  # not used by the Adam optimizer below; kept for compatibility
dataloader = ImageDataLoader()
image, label = dataloader.get_data()
image = torch.Tensor(image)
# CTC targets are class indices, so keep them integral. torch.Tensor(...)
# would silently convert them to float32.
label = torch.tensor(label, dtype=torch.long)
X_train, X_test, y_train, y_test = train_test_split(image, label, test_size=0.1)
train_dataloader = DataLoader(dataset=TensorDataset(X_train, y_train),
                              batch_size=batch_size,
                              shuffle=True,  # reshuffle training samples every epoch
                              num_workers=0,
                              drop_last=True)
test_dataloader = DataLoader(dataset=TensorDataset(X_test, y_test),
                             batch_size=batch_size,
                             shuffle=False,
                             num_workers=0,
                             drop_last=True)

vocab_num = dataloader.vocab_scale
model = CRNN(32, 3, vocab_num, 256)

# CTC loss with the last vocabulary index reserved as the blank symbol.
# zero_infinity=True zeroes out samples whose target cannot be aligned
# (e.g. output sequence shorter than the label); such samples contribute no
# gradient, so check the model's output length if training stalls on blanks.
criticism = nn.CTCLoss(vocab_num-1, reduction="mean", zero_infinity=True)
optimizer = torch.optim.Adam(model.parameters(),
                             lr=learning_rate,
                             betas=(0.9, 0.999),
                             eps=1e-04,
                             weight_decay=0,
                             amsgrad=False)

# parameters to gpu
model = nn.DataParallel(model)
model.to(device)
# NOTE: the DataLoaders above already wrap the CPU tensors, so these device
# copies do NOT affect the batches seen in training; batches are moved to
# `device` inside the loop instead. Kept in case later code uses them.
X_train = X_train.to(device)
X_test = X_test.to(device)
y_train = y_train.to(device)
y_test = y_test.to(device)

# train code
for epo in range(epoch):
    model.train()
    for i, (inputs, labels) in enumerate(train_dataloader):
        optimizer.zero_grad()
        inputs = inputs.to(device)
        outputs = model(inputs)
        # CTCLoss expects log-probabilities shaped (T, N, C). This assumes the
        # model returns (N, T, C) — TODO confirm against CRNN.forward.
        # `outputs` already tracks gradients, so no requires_grad_() is needed.
        log_probs = outputs.permute(1, 0, 2).log_softmax(2)
        seq_len = log_probs.size(0)
        n_samples = len(labels)
        inputs_length = torch.tensor([seq_len] * n_samples, dtype=torch.int32)
        # Assumes every label row has the same, unpadded length — if labels
        # are padded, compute the true per-sample lengths instead.
        target_length = torch.tensor([len(labels[0])] * n_samples, dtype=torch.int32)
        tmp_labels = flat_labels(labels)
        loss = criticism(log_probs, tmp_labels, inputs_length, target_length)
        loss.backward()
        optimizer.step()
        print(f"train: epoch {epo}, batch no.{i}, loss = {loss.item()}")
    print("*"*50)
4

0 回答 0