您可以使用 pickle 序列化您的 DGL 图形对象并将生成的字节字符串转换为整数向量(字符串中的每个 char 对应一个 int)。
import dgl
import numpy as np
import pickle
def serialize_graph(graph: dgl.DGLGraph):
as_byte_string = pickle.dumps(graph)
as_int_list = [_ for _ in as_byte_string] # we get ints for free without explicitly casting
as_float_array = np.array(as_int_list, dtype=np.float32)
return as_float_array
然后,您可以反向应用相同的操作来反序列化自定义特征提取器中图形的向量表示。
import dgl
import pickle
import torch as th
def deserialize_graph(observation: th.Tensor):
as_int_tensor = observation.to(dtype=th.int32)
as_char_list = [chr(_) for _ in observation]
as_byte_string = bytearray(''.join(as_char_list), encoding='latin')
as_dgl_graph = pickle.loads(as_byte_string)
return as_dgl_graph