在分布式 DGL 中,网络通信发生在采样和反向传播阶段。我想分别了解网络通信和计算开销。如何获得分布式 DGL 中的网络开销?
if __name__ == "__main__":
    # Entry point for one distributed-DGL trainer process.
    # Network communication happens (a) inside the sampler / feature fetches
    # against remote partitions and (b) in DDP's gradient all-reduce during
    # backward — the two phases the surrounding question asks about.
    dgl.distributed.initialize("ip_config.txt")
    th.distributed.init_process_group(backend="gloo")

    g = dgl.distributed.DistGraph("ogbn-products")
    pb = g.get_partition_book()
    # Split train/val node IDs so each trainer only iterates its own share.
    train_nid = dgl.distributed.node_split(g.ndata["train_mask"], pb)
    valid_nid = dgl.distributed.node_split(g.ndata["val_mask"], pb)
    device = th.device("cuda:0")

    # Model / training hyperparameters.
    num_hidden = 16
    num_labels = len(th.unique(g.ndata["labels"][0:g.number_of_nodes()]))
    num_layers = 2
    lr = 0.003
    epochs = 20

    model = SAGE(g.ndata["feat"].shape[1], num_hidden, num_labels, num_layers)
    loss_fcn = nn.CrossEntropyLoss()
    # BUGFIX: move the model to the target device BEFORE wrapping it in
    # DistributedDataParallel. The original wrapped first and called
    # .to(device) on the wrapper afterwards; DDP sets up its gradient
    # buckets/broadcast buffers against the parameters' device at
    # construction time, so the model must already live on `device`.
    model = model.to(device)
    model = th.nn.parallel.DistributedDataParallel(model)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    sampler = NeighborSampler(g)
    train_dataloader = dgl.distributed.DistDataLoader(
        dataset=train_nid.numpy(),
        batch_size=1000,
        collate_fn=sampler.sample,
        shuffle=True,
    )

    for epoch in range(epochs):
        for step, blocks in enumerate(train_dataloader):
            # Pulling features/labels for the sampled block triggers the
            # remote KVStore reads (network); the .to(device) copies to GPU.
            batch_inputs = g.ndata["feat"][blocks[0].srcdata[dgl.NID]].to(
                device)
            batch_labels = g.ndata["labels"][blocks[-1].dstdata[dgl.NID]
                                             ].to(device)
            blocks = [block.int().to(device) for block in blocks]

            batch_pred = model(blocks, batch_inputs)
            loss = loss_fcn(batch_pred, batch_labels)
            print("loss: " + str(loss.item()))

            optimizer.zero_grad()
            loss.backward()  # DDP all-reduces gradients here (network)
            optimizer.step()