
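I am trying to use TorchDyn to convert a GCN written with PyTorch Geometric into a graph neural ODE (GDE). The GCN works correctly on its own, but when I try to turn it into a GDE with the following code: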

import torch
import torch.nn as nn
import pytorch_lightning as pl
from torchdyn.core import NeuralODE
from torchdyn.nn import DataControl

t_span = torch.linspace(0, 1, 2)
model_sub = nn.Sequential(DataControl(), MyGCN(...))
ode = NeuralODE(model_sub, sensitivity='adjoint', solver='rk4', solver_adjoint='dopri5', atol_adjoint=1e-4, rtol_adjoint=1e-4).to('cuda:0')
model = Learner(t_span, ode)
trainer = pl.Trainer(gpus=1, precision=32, limit_train_batches=0.5,
                     auto_lr_find=True, logger=logger)
trainer.fit(model, train_loader, val_loader)

This produces the following error:

Traceback (most recent call last):
  File "/home/ubuntu/Desktop/model/main.py", line 30, in <module>
    trainer.fit(model, train_loader, val_loader)
  File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in fit
    self._call_and_handle_interrupt(
  File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1193, in _run
    self._dispatch()
  File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1272, in _dispatch
    self.training_type_plugin.start_training(self)
  File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
    self._results = trainer.run_stage()
  File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1282, in run_stage
    return self._run_train()
  File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1304, in _run_train
    self._run_sanity_check(self.lightning_module)
  File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1368, in _run_sanity_check
    self._evaluation_loop.run()
  File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 109, in advance
    dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
  File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 123, in advance
    output = self._evaluation_step(batch, batch_idx, dataloader_idx)
  File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 215, in _evaluation_step
    output = self.trainer.accelerator.validation_step(step_kwargs)
  File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py", line 236, in validation_step
    return self.training_type_plugin.validation_step(*step_kwargs.values())
  File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 219, in validation_step
    return self.model.validation_step(*args, **kwargs)
  File "/home/ubuntu/Desktop/model/model_ode.py", line 102, in validation_step
    return self.model_step(val_batch, batch_idx, 'val')
  File "/home/ubuntu/Desktop/model/model_ode.py", line 54, in model_step
    t_eval, y_hat = self(batch)
  File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/Desktop/model/model_ode.py", line 28, in forward
    return self.model(x)
  File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torchdyn/core/neuralde.py", line 92, in forward
    x, t_span = self._prep_integration(x, t_span)
  File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torchdyn/core/neuralde.py", line 88, in _prep_integration
    module.u = x[:, excess_dims:].detach()
  File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torch_geometric/data/batch.py", line 154, in __getitem__
    return self.index_select(idx)
  File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torch_geometric/data/batch.py", line 142, in index_select
    return [self.get_example(i) for i in idx]
  File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torch_geometric/data/batch.py", line 142, in <listcomp>
    return [self.get_example(i) for i in idx]
  File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torch_geometric/data/batch.py", line 96, in get_example
    data = separate(
  File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torch_geometric/data/separate.py", line 40, in separate
    data_store[attr] = _separate(attr, batch_store[attr], idx, slices,
  File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torch_geometric/data/separate.py", line 85, in _separate
    start, end = slices[idx], slices[idx + 1]
TypeError: unsupported operand type(s) for +: 'slice' and 'int'
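The last frames of the traceback hint at the mechanism. TorchDyn's _prep_integration evaluates x[:, excess_dims:] to store the initial condition for DataControl, which implicitly assumes x is a tensor. When x is a PyG Batch, Python hands the tuple (slice(None), slice(excess_dims, None)) to Batch.__getitem__, which (judging from the traceback) routes it to index_select and then treats each slice as if it were an example index. The mock below is a hypothetical, heavily simplified BatchLike class, not PyG's actual code, that reproduces the same TypeError with no dependencies:

```python
class BatchLike:
    """Hypothetical stand-in for a PyG Batch (NOT the real implementation).

    Mirrors the behaviour seen in the traceback: any index that is not an int
    is treated as a sequence of example indices, and each example is located
    via slices[i] and slices[i + 1].
    """

    def __init__(self, slices):
        self.slices = slices

    def __getitem__(self, idx):
        if isinstance(idx, int):
            return self.get_example(idx)
        # A tuple such as (slice(None), slice(2, None)) is iterated element by element.
        return [self.get_example(i) for i in idx]

    def get_example(self, idx):
        # Fails when idx is a slice: slice + int is not defined.
        start, end = self.slices[idx], self.slices[idx + 1]
        return (start, end)


batch = BatchLike([0, 3, 5])
print(batch[1])  # (3, 5)

try:
    batch[:, 2:]  # the same kind of indexing as x[:, excess_dims:] in _prep_integration
except TypeError as e:
    print(e)  # unsupported operand type(s) for +: 'slice' and 'int'
```

In other words, the GCN itself is probably fine; the failure would come from the ODE wrapper indexing the Batch object as if it were a feature tensor.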

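Learner is defined following the format described in the TorchDyn quickstart. train_loader and val_loader come from torch_geometric.dataloader and yield torch_geometric.batch objects, which the GCN's forward function unpacks as follows: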

def forward(self, data):
    node_feats = data.x    # torch.Tensor
    edge_index = data.edge_index    # torch.Tensor
    graph_feats = data.graph_feats    # torch.Tensor
    node_feats = gcn_layers(node_feats, edge_index, graph_feats)
    return node_feats
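One possible restructuring (an assumption, not something confirmed by the TorchDyn documentation) is to unpack the Batch before the ODE, so the solver only ever sees the node-feature tensor data.x, and to store edge_index (and, in the same way, graph_feats) on the vector field as attributes set per batch. The sketch below runs with plain PyTorch: the hypothetical GraphVectorField uses a simple mean-aggregation layer in place of MyGCN, concatenating x0 roughly imitates what DataControl does, and the hand-written fixed-step rk4_solve stands in for NeuralODE.

```python
import torch
import torch.nn as nn


class GraphVectorField(nn.Module):
    """Hypothetical ODE vector field over node features.

    The graph structure is stored as an attribute, so forward() only receives
    a plain (num_nodes, num_features) tensor, never a PyG Batch.
    """

    def __init__(self, num_features: int):
        super().__init__()
        # Input is [x, x0] concatenated, hence 2 * num_features.
        self.lin = nn.Linear(2 * num_features, num_features)
        self.edge_index = None  # set per batch before integrating
        self.x0 = None          # initial node features (DataControl-style input)

    def propagate(self, x):
        # Mean of incoming neighbour features: a simple stand-in for a GCN layer.
        src, dst = self.edge_index
        out = torch.zeros_like(x).index_add_(0, dst, x[src])
        deg = torch.zeros(x.size(0), dtype=x.dtype, device=x.device)
        deg.index_add_(0, dst, torch.ones(dst.size(0), dtype=x.dtype, device=x.device))
        return out / deg.clamp(min=1).unsqueeze(-1)

    def forward(self, x):
        h = torch.cat([x, self.x0], dim=-1)
        return torch.tanh(self.propagate(self.lin(h)))


def rk4_solve(f, x0, t_span):
    """Fixed-step classical RK4; returns the states at every point of t_span."""
    xs, x = [x0], x0
    for t0, t1 in zip(t_span[:-1], t_span[1:]):
        h = t1 - t0
        k1 = f(x)
        k2 = f(x + 0.5 * h * k1)
        k3 = f(x + 0.5 * h * k2)
        k4 = f(x + h * k3)
        x = x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        xs.append(x)
    return torch.stack(xs)


torch.manual_seed(0)
num_nodes, num_features = 4, 3
x = torch.randn(num_nodes, num_features)                  # stands in for data.x
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])   # stands in for data.edge_index

f = GraphVectorField(num_features)
f.edge_index, f.x0 = edge_index, x  # unpack the Batch outside the ODE
sol = rk4_solve(f, x, torch.linspace(0, 1, 2))
print(sol.shape)  # torch.Size([2, 4, 3])
```

In the real model, the analogous step would presumably be to pass data.x rather than the whole Batch to the NeuralODE inside Learner, after setting the graph attributes on the vector field; whether that fits the rest of the Learner code is an assumption.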

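Although the error is raised inside torch_geometric, it evidently originates from torchdyn, since the GCN works correctly on its own. I would appreciate any help debugging this error.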

