4

我正在为一个简单的 CNN 实现 DistributedDataParallel 训练,用于同时在 3 个分布式节点上运行的 torchvision.datasets.MNIST。我想将数据集划分为 3 个不重叠的子集(A、B、C),每个子集应包含 20000 张图像。单个子集应进一步分为训练和测试分区,即 0.7% 训练和 0.3% 测试。我计划将每个子集分别提供给每个分布式节点,以便它们可以以 DistributedDataParallel 方式进行训练和测试。

如下所示的基本代码,从 torchvision.datasets.MNIST 下载 MNIST 数据集,然后使用 torch.utils.data.distributed.DistributedSampler 和 torch.utils.data.DataLoader 在单个节点上创建用于训练和测试的数据批次。


# TRAINING DATA

train_dataset = datasets.MNIST('data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))]))
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=3, rank=dist.get_rank())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=False, num_workers=3, pin_memory=True, sampler=True)


# TESTING DATA

test_dataset = datasets.MNIST('data', train=False, download=False, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=3, pin_memory=True)

我希望答案应该创建 train_dataset_a、train_dataset_b 和 train_dataset_c,以及 test_dataset_a、test_dataset_b 和 test_dataset_c。

4

0 回答 0