0

在分布式 DGL 中，网络通信发生在采样、特征拉取（从分布式 KVStore 读取节点特征/标签）以及反向传播（梯度同步）阶段。我想分别统计网络通信开销和计算开销。如何获得分布式 DGL 中的网络开销？

if __name__ == "__main__":
    # Distributed GraphSAGE training on ogbn-products, instrumented to report
    # per-epoch wall-clock time spent in each phase (sampling, feature fetch,
    # forward, backward + gradient sync) so network-bound phases can be told
    # apart from local computation.
    import time  # local import keeps the instrumentation self-contained

    dgl.distributed.initialize("ip_config.txt")
    th.distributed.init_process_group(backend="gloo")
    g = dgl.distributed.DistGraph("ogbn-products")
    pb = g.get_partition_book()
    train_nid = dgl.distributed.node_split(g.ndata["train_mask"], pb)
    # NOTE(review): valid_nid is computed but never used below.
    valid_nid = dgl.distributed.node_split(g.ndata["val_mask"], pb)
    # NOTE(review): every trainer process uses cuda:0; if several trainers
    # run on one machine, choose the GPU from the local rank instead.
    device = th.device("cuda:0")

    # Define model and optimizer
    num_hidden = 16
    num_labels = len(th.unique(g.ndata["labels"][0: g.number_of_nodes()]))
    num_layers = 2
    lr = 0.003
    epochs = 20
    model = SAGE(g.ndata["feat"].shape[1], num_hidden, num_labels, num_layers)
    loss_fcn = nn.CrossEntropyLoss()
    # Move the module to its device BEFORE wrapping it: DDP configures its
    # gradient buckets for the parameters it sees at construction time, and
    # PyTorch warns against changing parameters after wrapping.
    model = model.to(device)
    model = th.nn.parallel.DistributedDataParallel(
        model, device_ids=[device], output_device=device
    )
    optimizer = optim.Adam(model.parameters(), lr=lr)
    sampler = NeighborSampler(g)
    train_dataloader = dgl.distributed.DistDataLoader(
        dataset=train_nid.numpy(),
        batch_size=1000,
        collate_fn=sampler.sample,
        shuffle=True,
    )

    def _sync_device():
        """Block until queued CUDA work finishes so timings are accurate."""
        if device.type == "cuda":
            th.cuda.synchronize(device)

    for epoch in range(epochs):
        # Per-phase wall-clock totals for this epoch, in seconds.
        sample_time = fetch_time = forward_time = backward_time = 0.0
        tic = time.perf_counter()
        for step, blocks in enumerate(train_dataloader):
            # 1) Sampling: time spent waiting for the next mini-batch.
            #    Presumably sampler.sample uses DGL's distributed sampling,
            #    which can contact remote partitions (network time).
            t_sampled = time.perf_counter()
            sample_time += t_sampled - tic

            # 2) Feature/label fetch: indexing DistGraph.ndata reads rows
            #    from the distributed KV store (rows owned by other machines
            #    travel over the network), then copies them to the GPU.
            batch_inputs = g.ndata["feat"][blocks[0].srcdata[dgl.NID]].to(
                device)
            batch_labels = g.ndata["labels"][blocks[-1].dstdata[dgl.NID]
                                             ].to(device)
            blocks = [block.int().to(device) for block in blocks]
            _sync_device()
            t_fetched = time.perf_counter()
            fetch_time += t_fetched - t_sampled

            # 3) Forward pass: local computation only.
            batch_pred = model(blocks, batch_inputs)
            loss = loss_fcn(batch_pred, batch_labels)
            _sync_device()
            t_forward = time.perf_counter()
            forward_time += t_forward - t_fetched

            # 4) Backward + optimizer step. DDP all-reduces gradients during
            #    backward and overlaps that communication with gradient
            #    computation, so this phase mixes compute and network time.
            #    To isolate the all-reduce, compare against a run using
            #    model.no_sync() or inspect it with torch.profiler.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            _sync_device()
            backward_time += time.perf_counter() - t_forward

            # The loss tensor is not modified by the optimizer step, so
            # printing it here reports the same value as before the step.
            print("loss: " + str(loss.item()))
            # Restart the clock after logging so print I/O is not counted
            # as sampling time for the next mini-batch.
            tic = time.perf_counter()

        print(
            f"epoch {epoch}: sample {sample_time:.3f}s"
            f" | fetch {fetch_time:.3f}s"
            f" | forward {forward_time:.3f}s"
            f" | backward+allreduce {backward_time:.3f}s"
        )
     
4

0 回答 0