我正在尝试使用 TorchDyn 将使用 PyTorch Geometric 编写的 GCN 转换为图神经 ODE(Graph Neural ODE)。GCN 本身可以正常工作,但我尝试使用以下代码转换为 GDE:
t_span = torch.linspace(0, 1, 2)
model_sub = nn.Sequential(DataControl(), MyGCN(...))
ode = NeuralODE(model_sub, sensitivity='adjoint', solver='rk4', solver_adjoint='dopri5', atol_adjoint=1e-4, rtol_adjoint=1e-4).to('cuda:0')
model = Learner(t_span, ode)
trainer = pl.Trainer(gpus=1, precision=32, limit_train_batches=0.5,
auto_lr_find=True, logger=logger)
trainer.fit(model, train_loader, val_loader)
这会产生以下错误:
Traceback (most recent call last):
File "/home/ubuntu/Desktop/model/main.py", line 30, in <module>
trainer.fit(model, train_loader, val_loader)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in fit
self._call_and_handle_interrupt(
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1193, in _run
self._dispatch()
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1272, in _dispatch
self.training_type_plugin.start_training(self)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
self._results = trainer.run_stage()
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1282, in run_stage
return self._run_train()
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1304, in _run_train
self._run_sanity_check(self.lightning_module)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1368, in _run_sanity_check
self._evaluation_loop.run()
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 109, in advance
dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 123, in advance
output = self._evaluation_step(batch, batch_idx, dataloader_idx)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 215, in _evaluation_step
output = self.trainer.accelerator.validation_step(step_kwargs)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py", line 236, in validation_step
return self.training_type_plugin.validation_step(*step_kwargs.values())
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 219, in validation_step
return self.model.validation_step(*args, **kwargs)
File "/home/ubuntu/Desktop/model/model_ode.py", line 102, in validation_step
return self.model_step(val_batch, batch_idx, 'val')
File "/home/ubuntu/Desktop/model/model_ode.py", line 54, in model_step
t_eval, y_hat = self(batch)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/ubuntu/Desktop/model/model_ode.py", line 28, in forward
return self.model(x)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torchdyn/core/neuralde.py", line 92, in forward
x, t_span = self._prep_integration(x, t_span)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torchdyn/core/neuralde.py", line 88, in _prep_integration
module.u = x[:, excess_dims:].detach()
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torch_geometric/data/batch.py", line 154, in __getitem__
return self.index_select(idx)
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torch_geometric/data/batch.py", line 142, in index_select
return [self.get_example(i) for i in idx]
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torch_geometric/data/batch.py", line 142, in <listcomp>
return [self.get_example(i) for i in idx]
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torch_geometric/data/batch.py", line 96, in get_example
data = separate(
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torch_geometric/data/separate.py", line 40, in separate
data_store[attr] = _separate(attr, batch_store[attr], idx, slices,
File "/home/ubuntu/anaconda3/envs/storage-env/lib/python3.9/site-packages/torch_geometric/data/separate.py", line 85, in _separate
start, end = slices[idx], slices[idx + 1]
TypeError: unsupported operand type(s) for +: 'slice' and 'int'
Learner 是按照 TorchDyn 快速入门中描述的格式定义的;train_loader 和 val_loader 来自 torch_geometric 的 DataLoader,其中包含 torch_geometric 的 Batch 对象。这些对象在 GCN 的 forward 函数中被解包,如下所示:
node_feats = data.x # torch.Tensor
edge_index = data.edge_index # torch.Tensor
graph_feats = data.graph_feats # torch.Tensor
node_feats = gcn_layers(node_feats, edge_index, graph_feats)
return node_feats
尽管该错误是由 torch_geometric 抛出的,但它显然源于 torchdyn(见回溯中的 `torchdyn/core/neuralde.py` 的 `_prep_integration`,它对 Batch 对象做了 `x[:, excess_dims:]` 切片),因为 GCN 单独运行时完全正常。对于调试此错误的任何帮助,我将不胜感激。