
I'm new to stable-baselines3 and am trying to solve a toy graph neural network problem. I previously had a working bit-flipping example that used an array. The problem is this: given a list of 10 random bits and an operation that flips one bit, find a way to flip bits so that they are all set to 1. Obviously you can do this by just flipping the bits that are currently 0, but the system has to learn that.
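(As a sanity check on the problem itself: the optimal strategy is easy to write down in plain Python — flip exactly the zero bits — so the point of the exercise is only whether the agent discovers it. The `optimal_flips` helper below is my own shorthand, not part of the environment:)

```python
import random

def optimal_flips(bits):
    """Flip every bit that is 0; return the new bits and how many flips were used."""
    bits = list(bits)
    flips = 0
    for i, b in enumerate(bits):
        if b == 0:
            bits[i] ^= 1  # the single allowed operation: flip one bit
            flips += 1
    return bits, flips

bits, flips = optimal_flips(random.choices([0, 1], k=10))
print(all(b == 1 for b in bits))  # True: every bit is set
```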

I want to do the same thing, except with the input being a simple linear graph with node weights instead of an array, and I don't know how to go about it. The following snippet builds a linear graph with 10 nodes, adds a weight to each node, and converts it into a dgl graph:

import networkx as nx
import random
import dgl
# Create edges to add
edges = []
N = 10
for i in range(N - 1):
    edges.append((i, i + 1))
# Create graph, attach a random 0/1 weight to each node,
# and convert it into a dgl graph
G = nx.DiGraph()
G.add_edges_from(edges)
for i in range(len(G.nodes)):
    G.nodes[i]['weight'] = random.choice([0, 1])
dgl_graph = dgl.from_networkx(G, node_attrs=["weight"])

When I used a linear array in the bit-flip example, my environment looked like this:

import numpy as np
import random
import gym
from gym import spaces
class BitFlipEnv(gym.Env):
    def __init__(self, array_length=10):
        super(BitFlipEnv, self).__init__()
        # Size of the 1D grid
        self.array_length = array_length
        self.time = 0
        # Initialize the array of bits to be random
        self.agent_pos = random.choices([0, 1], k=array_length)

        # Define the action and observation spaces
        # They must be gym.spaces objects
        # A discrete action picks which bit to flip
        self.action_space = spaces.Discrete(array_length)
        # The observation is the current bit array
        self.observation_space = spaces.Box(low=0, high=1,
                                            shape=(array_length,), dtype=np.uint8)

    def reset(self):
        # Re-initialize the array to random values
        self.time = 0
        print(self.agent_pos)
        self.agent_pos = random.choices([0, 1], k=self.array_length)
        return np.array(self.agent_pos)

    def step(self, action):
        self.time += 1
        if not 0 <= action < self.array_length:
            raise ValueError("Received invalid action={} which is not part of the action space".format(action))
        self.agent_pos[action] ^= 1  # flip the bit
        if self.agent_pos[action] == 1:
            reward = 1
        else:
            reward = -1

        done = all(self.agent_pos)

        info = {}

        return np.array(self.agent_pos), reward, done, info

    def render(self, mode='console'):
        print(self.agent_pos)

    def close(self):
        pass
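Before training, I sanity-check the dynamics by hand; this standalone sketch mirrors the environment's step() logic without needing gym (the `step_once` helper is my own shorthand, not part of the environment):

```python
def step_once(state, action):
    """Mirror of the environment's step: flip a bit, reward +1 if it became 1, else -1."""
    state = list(state)
    state[action] ^= 1
    reward = 1 if state[action] == 1 else -1
    done = all(state)
    return state, reward, done

# Drive the dynamics with the known-optimal policy: flip each zero bit once.
state = [0, 1, 0, 1, 1]
total = 0
for i in [i for i, b in enumerate(state) if b == 0]:
    state, reward, done = step_once(state, i)
    total += reward
print(state, total, done)  # [1, 1, 1, 1, 1] 2 True
```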

Finishing off the last few lines of code in the array version is simple:

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
env = make_vec_env(lambda: BitFlipEnv(array_length=50), n_envs=12)
# Train the agent
model = PPO('MlpPolicy', env,  verbose=1).learn(500000)

I can't use the gym spaces for the graph any more, so for this toy problem, what is the right way to get stable-baselines to interact with my dgl graph?
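For what it's worth, the only workaround I can think of is to flatten the node weights back into a fixed-size vector so the Box observation space still works (the topology never changes, only the weights do) — but that throws away the graph structure, which is the part I wanted the GNN for. A sketch of that flattening, with my own (hypothetical) helper name:

```python
import numpy as np

def graph_to_obs(node_weights):
    """Flatten per-node weights into the Box-style observation an MLP policy expects."""
    return np.asarray(node_weights, dtype=np.uint8)

# For the 10-node linear graph the observation is just the weight vector;
# the edges are implicit because the topology is fixed.
obs = graph_to_obs([0, 1, 1, 0, 1, 0, 1, 1, 0, 1])
print(obs.shape, obs.dtype)  # (10,) uint8
```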
