The code below is the nested loop I use to train a GRU in Python 2.7, but it eats memory: the feats_tensor and dec_padded_text objects are very large, and loading both at the same time gives me an out-of-memory error. Any idea how to optimize this code's RAM usage?

for epoch in xrange(0, 13):
    print ("Starting New Epoch: %d" % epoch)
    np.random.shuffle(order)
    # Drop last epoch's tensors before building the reshuffled copies.
    del feats_tensor, dec_text_tensor
    if cuda:
        torch.cuda.empty_cache()
    # Rebuild full-dataset tensors in the new shuffled order.
    feats_tensor = torch.tensor(feats[order], requires_grad=False)
    dec_text_tensor = torch.tensor(dec_padded_text[order], requires_grad=False)
    if cuda:
        feats_tensor = feats_tensor.cuda(device=device)
        dec_text_tensor = dec_text_tensor.cuda(device=device)

    for i in xrange(num_batches):
        # Bounds of the current batch within the shuffled tensors.
        s = i * BATCH_SIZE
        e = (i+1) * BATCH_SIZE

        enc.zero_grad()
        dec.zero_grad()

        # Encode the features, then decode with teacher forcing:
        # the decoder reads tokens [:-1] and is scored against tokens [1:].
        hid_enc = enc.forward(feats_tensor[s:e]).unsqueeze(0)
        out_dec, hid_dec = dec.forward(dec_text_tensor[s:e,:-1], hid_enc)

        # Move the vocabulary dimension to position 1 for the loss.
        out_perm = out_dec.permute(0, 2, 1)
        loss = lossfunc(out_perm, dec_text_tensor[s:e,1:])
        # Exponentially smoothed loss, used only for logging.
        if sm_loss is None:
            sm_loss = loss.data
        else:
            sm_loss = sm_loss*0.95 + 0.05*loss.data

        loss.backward()
        enc_optim.step()
        dec_optim.step()

        if i % 100 == 0:
            print ("Epoch: %.3f" % (i/float(num_batches) + epoch,), "Loss:", sm_loss)
            #print ("GEN:", untokenize(torch.argmax(out_dec,dim=2)[0,:], dec_idx_to_word))
            #print ("GT:", untokenize(dec_text_tensor[s,:], dec_idx_to_word))
            print ("--------------")

    save_state(enc, dec, enc_optim, dec_optim, dec_idx_to_word, dec_word_to_idx, epoch)
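
For illustration, here is a minimal sketch of one common memory-saving restructuring; it assumes feats and dec_padded_text are NumPy arrays and that enc, dec, the optimizers, and lossfunc stay as above, so treat it as a sketch rather than a drop-in fix. The idea is to never materialize the full shuffled copies: feats[order] allocates a complete reshuffled array in RAM, and torch.tensor(...) copies it again, so each epoch briefly holds the data several times over. Indexing only the current batch keeps a single batch on the GPU at a time:

for epoch in xrange(0, 13):
    print ("Starting New Epoch: %d" % epoch)
    np.random.shuffle(order)

    for i in xrange(num_batches):
        # Shuffled indices of just this batch.
        idx = order[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]

        # Copy only one batch to a tensor; the big arrays stay in NumPy on the CPU.
        feats_batch = torch.tensor(feats[idx], requires_grad=False)
        dec_text_batch = torch.tensor(dec_padded_text[idx], requires_grad=False)
        if cuda:
            feats_batch = feats_batch.cuda(device=device)
            dec_text_batch = dec_text_batch.cuda(device=device)

        enc.zero_grad()
        dec.zero_grad()

        hid_enc = enc.forward(feats_batch).unsqueeze(0)
        out_dec, hid_dec = dec.forward(dec_text_batch[:, :-1], hid_enc)

        loss = lossfunc(out_dec.permute(0, 2, 1), dec_text_batch[:, 1:])
        loss.backward()
        enc_optim.step()
        dec_optim.step()
        # (loss smoothing and logging as in the original loop)

    save_state(enc, dec, enc_optim, dec_optim, dec_idx_to_word, dec_word_to_idx, epoch)

The trade-off is one small host-to-device copy per batch instead of one huge one per epoch, so peak memory drops from multiple full copies of the dataset to a single batch. torch.utils.data.DataLoader can also handle this kind of shuffled batching with less hand-written code.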