I am loading audio files from my local machine, so I am using torchaudio. I wrote a collate_fn that pads the sequences in each batch, as follows:
import torch

def pad_sequence(batch):
    # pad_sequence pads along the first dimension, so put time first
    batch = [item.t() for item in batch]
    batch = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0.)
    # back to (batch, channels, time)
    return batch.permute(0, 2, 1)

def collate_fn(batch):
    tensors, emotion_targets = [], []
    emotion_intensity_targets, gender_targets = [], []
    for waveform, e, ei, g in batch:
        # apply the transform, which downsamples the waveform from sample_rate 16000 to 8000
        tensors += [transform(waveform)]
        emotion_targets += [e]
        emotion_intensity_targets += [ei]
        gender_targets += [g]
    tensors = pad_sequence(tensors)
    emotion_targets = torch.stack(emotion_targets)
    emotion_intensity_targets = torch.stack(emotion_intensity_targets)
    gender_targets = torch.stack(gender_targets)
    return tensors, emotion_targets, emotion_intensity_targets, gender_targets
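To make the intended behaviour of the padding helper concrete, here is a minimal sketch on synthetic data. The random tensors stand in for real waveforms, the name `pad_waveforms` is hypothetical, and every clip is assumed to be mono with shape (1, time); none of this comes from my actual dataset.

```python
import torch
from torch.nn.utils.rnn import pad_sequence as torch_pad_sequence


def pad_waveforms(batch):
    """Pad a list of (channels, time) tensors to a common length."""
    # pad_sequence pads along the first dimension, so move time to the front.
    batch = [item.t() for item in batch]
    batch = torch_pad_sequence(batch, batch_first=True, padding_value=0.0)
    # Back to (batch, channels, time).
    return batch.permute(0, 2, 1)


# Synthetic mono waveforms of different lengths (assumed shapes, not real audio).
waveforms = [torch.randn(1, n) for n in (8000, 6500, 7200)]
padded = pad_waveforms(waveforms)
print(padded.shape)  # torch.Size([3, 1, 8000])
```

With inputs that all share the same channel count, the shorter clips are zero-padded up to the longest one in the batch.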
Then I create the data loaders as follows:
BATCH_SIZE = 16

train_loader = torch.utils.data.DataLoader(
    train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)
test_loader = torch.utils.data.DataLoader(
    test,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
)
I then fetch a batch by running a cell with this code:
next(iter(train_loader))
However, when I keep re-running the cell, a random batch suddenly raises the following error:
RuntimeError Traceback (most recent call last)
<ipython-input-53-0fd2476fd2ff> in <module>()
----> 1 next(iter(train_loader))
5 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/utils/rnn.py in pad_sequence(sequences, batch_first, padding_value)
361 # assuming trailing dimensions and type of all the Tensors
362 # in sequences are same and fetching those from sequences[0]
--> 363 return torch._C._nn.pad_sequence(sequences, batch_first, padding_value)
364
365
RuntimeError: output with shape [39239, 1] doesn't match the broadcast shape [39239, 2]
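The shapes in the message ([39239, 1] versus [39239, 2]) might mean that some clips have one channel and others two. That is only a guess, and I have not confirmed it for my files. A minimal sketch of how the channel counts in a batch could be checked, and how clips could be averaged down to mono if the guess is right, is below; `channel_counts` and `to_mono` are hypothetical helper names and the tensors are synthetic stand-ins.

```python
import torch


def channel_counts(waveforms):
    """Return the set of channel counts in a list of (channels, time) tensors."""
    return {w.shape[0] for w in waveforms}


def to_mono(waveform):
    """Average over channels, keeping a leading channel dimension of size 1."""
    return waveform.mean(dim=0, keepdim=True)


# Synthetic batch: one mono clip and one stereo clip (assumed, not real files).
batch = [torch.randn(1, 39239), torch.randn(2, 30000)]
print(channel_counts(batch))  # more than one value means a mixed batch

# Assumption: averaging channels is acceptable for this task.
mono_batch = [to_mono(w) for w in batch]
print(channel_counts(mono_batch))
```

If the check reports more than one channel count on the real data, calling something like `to_mono` on each waveform inside collate_fn, before padding, would give every clip the same trailing dimension.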
What could be going wrong here?