I am loading audio files from my local machine, so I am using torchaudio. I wrote a collate_fn that pads the sequences in each batch, as follows:

def pad_sequence(batch):
  batch = [item.t() for item in batch]
  batch = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0., )
  return batch.permute(0, 2, 1)

def collate_fn(batch):
  tensors, emotion_targets = [], []
  emotion_intensity_targets, gender_targets = [], []

  for waveform, e, ei, g  in batch:
    # apply the transform, which downsamples the waveform from a sample rate of 16000 to 8000
    tensors += [transform(waveform)]
    emotion_targets += [e]
    emotion_intensity_targets += [ei]
    gender_targets += [g]
  tensors = pad_sequence(tensors)
  emotion_targets = torch.stack(emotion_targets)
  emotion_intensity_targets = torch.stack(emotion_intensity_targets)
  gender_targets = torch.stack(gender_targets)
  return tensors, emotion_targets, emotion_intensity_targets, gender_targets

I then create the data loaders as follows:

BATCH_SIZE = 16
train_loader = torch.utils.data.DataLoader(
    train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn
)

test_loader = torch.utils.data.DataLoader(
    test,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
)

Running a cell with the following code works:

next(iter(train_loader))

However, when I keep re-running the cell, some random batches suddenly fail with the following error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-53-0fd2476fd2ff> in <module>()
----> 1 next(iter(train_loader))

5 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/utils/rnn.py in pad_sequence(sequences, batch_first, padding_value)
    361     # assuming trailing dimensions and type of all the Tensors
    362     # in sequences are same and fetching those from sequences[0]
--> 363     return torch._C._nn.pad_sequence(sequences, batch_first, padding_value)
    364 
    365 

RuntimeError: output with shape [39239, 1] doesn't match the broadcast shape [39239, 2]

What could I be doing wrong here?
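One possible explanation, offered as an assumption rather than a confirmed diagnosis: the shapes [39239, 1] and [39239, 2] in the error suggest that some files are mono (1 channel) and others stereo (2 channels), so after `.t()` the trailing dimension differs between items and padding fails only for batches that happen to mix the two. The sketch below shows one way to handle that by averaging the channels before padding. The names `to_mono` and `pad_mono_batch` are hypothetical helpers, not part of torchaudio, and the random tensors stand in for real waveforms.

```python
import torch
from torch.nn.utils.rnn import pad_sequence as torch_pad_sequence


def to_mono(waveform):
    # Hypothetical helper: average over the channel dimension so that
    # every waveform has shape [1, time], whether it was mono or stereo.
    return waveform.mean(dim=0, keepdim=True)


def pad_mono_batch(batch):
    # Downmix each item, then transpose to [time, 1] so that
    # pad_sequence pads along the time dimension.
    batch = [to_mono(w).t() for w in batch]
    padded = torch_pad_sequence(batch, batch_first=True, padding_value=0.0)
    # Back to [batch, channel, time].
    return padded.permute(0, 2, 1)


# Stand-in data (assumed shapes): one mono clip and one shorter stereo clip.
mono = torch.randn(1, 100)
stereo = torch.randn(2, 80)

out = pad_mono_batch([mono, stereo])
print(out.shape)
```

If this is indeed the cause, printing `waveform.shape` for each item in the dataset should reveal which files have two channels.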
