
I have a situation where, for each minibatch, there are multiple nested pieces of data on which the model needs to be trained.

for idx, batch in enumerate(train_dataloader):
    data = batch.get("data").squeeze(0)
    op = torch.zeros(size)  # zero initialization
    for i in range(data.shape[0]):
        optimizer.zero_grad()
        current_data = data[i, ...]
        start_to_current_data = data[:i + 1, ...]
        target = some_transformation_func(start_to_current_data)
        op = model(current_data, op)
        loss = criterion(op, target)
        loss.backward()
        optimizer.step()

But when I start training, I get the following error: RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time. Setting retain_graph=True increases memory usage, and I cannot train the model that way. How can I solve this problem?
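As far as I can tell, the error seems to come from op being fed back into model(current_data, op): it still references the graph built in the previous iteration, which the earlier backward() has already freed. Below is a sketch of the loop with op detached after each step (same names as in the snippet above; size is still assumed to be defined elsewhere), in case that is the right direction:

for idx, batch in enumerate(train_dataloader):
    data = batch.get("data").squeeze(0)
    op = torch.zeros(size)  # zero initialization; `size` assumed to be defined elsewhere
    for i in range(data.shape[0]):
        optimizer.zero_grad()
        current_data = data[i, ...]
        start_to_current_data = data[:i + 1, ...]
        target = some_transformation_func(start_to_current_data)
        op = model(current_data, op)
        loss = criterion(op, target)
        loss.backward()
        optimizer.step()
        op = op.detach()  # cut the graph so the next backward() stays within this iteration

I am not sure whether detaching op like this is acceptable, since it also stops gradients from flowing across iterations of the inner loop.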
