0

我设置batch_size了 equals 64,但是当我打印出 train_batch 和 val_batch 时,大小不等于 64。

train 数据和 val 数据格式如下: 在此处输入图像描述

首先,我定义TEXTLABEL字段。

tokenize = lambda x: x.split()

TEXT = data.Field(sequential=True, tokenize=tokenize)
LABEL = data.Field(sequential=False)

然后我继续尝试遵循教程,并在下面写了一些东西:

train_data, valid_data = data.TabularDataset.splits(
        path='.',
        train='train_intent.csv', validation='val.csv',
        format='csv',
        fields= {'sentences': ('text', TEXT),
                'labels': ('label',LABEL)}
)

test_data = data.TabularDataset(
        path='test.csv',
        format='csv',
        fields={'sentences': ('text', TEXT)}

)
TEXT.build_vocab(train_data)
LABEL.build_vocab(train_data)

BATCH_SIZE = 64

train_iter, val_iter = data.BucketIterator.splits(
    (train_data, valid_data),
    batch_sizes=(BATCH_SIZE, BATCH_SIZE),
    sort_key=lambda x: len(x.text),
    sort_within_batch=False,
    repeat=False,
    device=device
)

但是当我想知道 iter 是否正常时,我会发现以下奇怪的事情:

train_batch = next(iter(train_iter))
print(train_batch.text.shape)
print(train_batch.label.shape)
[output]
torch.Size([15, 64])
torch.Size([64])

并且训练过程输出错误ValueError: Expected input batch_size (15) to match target batch_size (64).

def train(model, iterator, optimizer, criterion):

    epoch_loss = 0

    model.train()

    for batch in iterator:

        optimizer.zero_grad()

        predictions = model(batch.text)
        loss = criterion(predictions, batch.label)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / len(iterator)

任何人都可以给我一个提示将不胜感激。谢谢!

4

2 回答 2

1

返回的批量大小并不总是等于batch_size. 例如:你有 100 个训练数据,batch_size 是 64。返回的 batch_size 应该是[64, 36].

代码: https ://github.com/pytorch/text/blob/1c2ae32d67f7f7854542212b229cd95c85cf4026/torchtext/data/iterator.py#L255-L271

于 2019-04-09T09:00:56.143 回答
0

我也遇到了这个问题。我认为问题在于 batch_size 不在 shape[0] 位置。在你的问题中:

train_batch = next(iter(train_iter))
print(train_batch.text.shape)
print(train_batch.label.shape)
[output]
torch.Size([15, 64])
torch.Size([64])

15是batch中的max_sequence_length,可以使用Field定义中的fix_length来固定,64是batch_size。我认为你可以重塑你的文本来解决这个问题,但我也在寻找更好的答案。

于 2021-02-23T03:20:33.620 回答