I'm new to stable-baselines3 and am trying to work up to a toy graph neural network problem. I previously had a bit-flipping example that used an array. The problem is: given a list of 10 random bits and an operation that flips one bit, find a sequence of flips that sets all bits to 1. Obviously you can do this by flipping exactly the bits that are currently 0, and the point is for the system to learn that.
I'd like to do the same thing where the input is a simple linear graph with node weights instead of an array, but I don't know how to go about it. The following snippet builds a linear graph with 10 nodes, attaches a weight to each node, and converts it to a dgl graph:
import networkx as nx
import random
import dgl

# Create edges to add
edges = []
N = 10
for i in range(N-1):
    edges.append((i, i+1))

# Create graph and convert it into a dgl graph
G = nx.DiGraph()
G.add_edges_from(edges)
for i in range(len(G.nodes)):
    G.nodes[i]['weight'] = random.choice([0, 1])

dgl_graph = dgl.from_networkx(G, node_attrs=["weight"])
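For reference, the node weights end up in the graph's ndata; a minimal sketch of how they can be read and toggled (assuming dgl's default PyTorch backend, and with the node index i chosen arbitrarily here):

# Node weights are stored as a 1-D integer tensor on the dgl graph
weights = dgl_graph.ndata['weight']      # e.g. tensor([0, 1, 1, 0, ...])

# Flipping the "bit" of node i just toggles its weight in place
i = 3   # hypothetical node index
dgl_graph.ndata['weight'][i] = 1 - dgl_graph.ndata['weight'][i]

# A flat numpy view of the weights, if a plain array is ever needed
flat = dgl_graph.ndata['weight'].numpy()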
When I used a linear array in the bit-flipping example, my environment looked like this:
import numpy as np
import random
import gym
from gym import spaces

class BitFlipEnv(gym.Env):

    def __init__(self, array_length=10):
        super(BitFlipEnv, self).__init__()
        # Size of the 1D-grid
        self.array_length = array_length
        # Initialize the array of bits to be random
        self.agent_pos = random.choices([0, 1], k=array_length)
        # Define action and observation space
        # They must be gym.spaces objects
        # One discrete action per position: flip that bit
        self.action_space = spaces.Discrete(array_length)
        # The observation is the current state of all bits
        self.observation_space = spaces.Box(low=0, high=1,
                                            shape=(array_length,), dtype=np.uint8)

    def reset(self):
        # Re-initialize the array to random values
        self.time = 0
        print(self.agent_pos)
        self.agent_pos = random.choices([0, 1], k=self.array_length)
        return np.array(self.agent_pos)

    def step(self, action):
        self.time += 1
        if not 0 <= action < self.array_length:
            raise ValueError("Received invalid action={} which is not part of the action space".format(action))
        self.agent_pos[action] ^= 1  # flip the bit
        if self.agent_pos[action] == 1:
            reward = 1
        else:
            reward = -1
        done = all(self.agent_pos)
        info = {}
        return np.array(self.agent_pos), reward, done, info

    def render(self, mode='console'):
        print(self.agent_pos)

    def close(self):
        pass
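For completeness, the environment can be sanity-checked against the gym interface with stable-baselines3's built-in checker before training (a minimal sketch):

from stable_baselines3.common.env_checker import check_env

env = BitFlipEnv(array_length=10)
check_env(env, warn=True)  # validates observation/action spaces and the reset/step signatures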
Finishing off the array version takes only a few more lines:
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
env = make_vec_env(lambda: BitFlipEnv(array_length=50), n_envs=12)
# Train the agent
model = PPO('MlpPolicy', env, verbose=1).learn(500000)
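After training, the learned policy can be rolled out in the usual stable-baselines3 way, e.g. (sketch):

obs = env.reset()
for _ in range(100):
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)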
I can't use the built-in spaces with stable-baselines any more once the observation is a graph, so for this toy problem, what is the proper way to get stable-baselines to interact with my dgl graph?
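The only workaround I've come up with so far (which seems to defeat the purpose of using a graph at all) is to flatten the node weights back into a Box observation, roughly:

# Flatten the dgl node weights back into a fixed-size Box observation
# (this throws away the graph structure, which is exactly what I'd like to avoid)
obs = dgl_graph.ndata['weight'].numpy().astype(np.uint8)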