0
    # 80/20 split of train_dataset indices: the first 80% go to training,
    # the remaining 20% to validation. The index ranges are contiguous —
    # the indices are not shuffled before partitioning.
    split = int(0.8 * len(train_dataset))
    print(split)
    index_list = [*range(len(train_dataset))]
    train_idx = index_list[:split]
    valid_idx = index_list[split:]
    print(len(train_idx), len(valid_idx))
48000 12000

我得到的 train_idx 和 valid_idx 分别包含 48000 和 12000 个索引，
然后我用这些索引来创建数据加载器：

# Both loaders wrap the full train_dataset; each SubsetRandomSampler
# restricts which indices are drawn (in random order): train_idx for
# training, valid_idx for validation.
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=SubsetRandomSampler(train_idx))
valid_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=SubsetRandomSampler(valid_idx))

# `loader.dataset` is the object passed to DataLoader, which is the full
# train_dataset for both loaders. That is why both lengths equal the full
# dataset size. The subset lives in the sampler:
# len(train_loader.sampler) and len(valid_loader.sampler) give the number of
# training and validation samples.
print(len(train_loader.dataset), len(valid_loader.dataset))
60000 60000

但打印出的两个长度都是 60000，而不是预期的 48000 和 12000，似乎不对。

# The submission loader depends on neither the epoch nor the fold, so build it
# once rather than on every inner iteration.
# NOTE(review): shuffle=True on a loader used for submission presumably
# scrambles prediction order relative to the test samples. Confirm this;
# shuffle=False is the usual choice for inference.
submit_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

# Suppose `splits` is a one-shot iterator (e.g. the generator returned by
# sklearn's KFold.split). It would then be exhausted after the first epoch,
# and later epochs would do no work. Materialize it once so every epoch
# sees every fold; for an existing list this is just a copy.
splits = list(splits)

for epoch in range(EPOCHS):
    for i, (train_idx, valid_idx) in enumerate(splits):
        # NOTE(review): the same `model`/`optimizer` keep training across
        # folds. If `splits` are K-fold partitions (presumably), each fold's
        # validation indices were already trained on in other folds, so
        # valid_loss is optimistic. For genuine K-fold cross-validation,
        # re-create the model/optimizer per fold and run the epoch loop
        # inside the fold loop.

        # Create iterator objects over the train/valid subsets of train_dataset.
        train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=SubsetRandomSampler(train_idx))
        valid_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=SubsetRandomSampler(valid_idx))

        train_loss, train_acc = train(model, device, train_loader, optimizer, criterion)
        valid_loss, valid_acc = evaluate(model, device, valid_loader, criterion)

        # Checkpoint whenever validation loss beats the best seen so far.
        # NOTE(review): best_valid_loss must be initialized before this loop
        # (e.g. float('inf')); it is not set in this snippet.
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), MODEL_SAVE_PATH)

        # Include the fold number. Each epoch logs one line per fold, and
        # these lines would otherwise be indistinguishable.
        print(f'| Epoch: {epoch+1:02} | Fold: {i+1:02} | Train Loss: {train_loss:.3f} | Train Acc: {train_acc*100:05.2f}% | Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:05.2f}% |')
4

0 回答 0