我正在尝试在体育比赛中运行稳定的基线,但不断收到以下错误
Traceback (most recent call last):
File "/home/dev/Desktop/Projects/AI/NBA2/stable_baselines_run.py", line 35, in <module>
model.learn(total_timesteps=10000)
File "/home/dev/anaconda3/envs/sb/lib/python3.9/site-packages/stable_baselines3/a2c/a2c.py", line 189, in learn
return super(A2C, self).learn(
File "/home/dev/anaconda3/envs/sb/lib/python3.9/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 234, in learn
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
File "/home/dev/anaconda3/envs/sb/lib/python3.9/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 166, in collect_rollouts
actions, values, log_probs = self.policy.forward(obs_tensor)
File "/home/dev/anaconda3/envs/sb/lib/python3.9/site-packages/stable_baselines3/common/policies.py", line 566, in forward
distribution = self._get_action_dist_from_latent(latent_pi, latent_sde=latent_sde)
File "/home/dev/anaconda3/envs/sb/lib/python3.9/site-packages/stable_baselines3/common/policies.py", line 607, in _get_action_dist_from_latent
return self.action_dist.proba_distribution(action_logits=mean_actions)
File "/home/dev/anaconda3/envs/sb/lib/python3.9/site-packages/stable_baselines3/common/distributions.py", line 326, in proba_distribution
self.distribution = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)]
File "/home/dev/anaconda3/envs/sb/lib/python3.9/site-packages/stable_baselines3/common/distributions.py", line 326, in <listcomp>
self.distribution = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)]
File "/home/dev/anaconda3/envs/sb/lib/python3.9/site-packages/torch/distributions/categorical.py", line 64, in __init__
super(Categorical, self).__init__(batch_shape, validate_args=validate_args)
File "/home/dev/anaconda3/envs/sb/lib/python3.9/site-packages/torch/distributions/distribution.py", line 53, in __init__
raise ValueError("The parameter {} has invalid values".format(param))
ValueError: The parameter logits has invalid values
我已经删除了所有 NaN(替换为 0)并对数据进行了规范化,以便所有数据都在 0 和 1 之间,但仍然找不到无效值。
这是我的自定义环境:
import gym
from gym import spaces
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
INITIAL_BALANCE = 100
class BettingEnv(gym.Env):
# metadata = {'render.modes': ['human']}
def __init__(self, df, results, INITIAL_BALANCE=100):
self.df = df
self.results = results
self.initial_balance = INITIAL_BALANCE
self.balance = INITIAL_BALANCE
self.profit = 0
self.starting_point = np.random.randint(len(self.df) - len(self.df) * 0.1) # Start anywhere but in the end 10%
self.timestep = 0
self.games_won = 0
self.game_bets = []
self.game_number = self.starting_point + self.timestep
self.action_space = spaces.MultiDiscrete([3,10])
self.observation_space = spaces.Box(
low = self.df.min().min(), # Lowest value found in df
high = self.df.max().max(), # Search the df for the max value (this may change with different data)
shape = (df.shape[1],), # shape of one row of the df
# dtype = np.float16
)
print('First ob: ',self.df.loc[self.game_number])
def _next_obs(self):
print('Get next obs')
# Get next game row
obs = self.df.loc[self.timestep]
print('next obs success')
return obs
def _print_bet_csv(self):
# Create bet_info_df
bet_info_df = pd.DataFrame(self.game_bets)
results_df = self.results.reset_index()
# #Merge dfs
self.merged_df = pd.merge(bet_info_df, results_df, on=['index', 'Home Odds', 'Vis Odds', 'Home Win'])
self.merged_df.set_index('index', inplace=True)
# #Print df
self.merged_df.to_csv('./temp/MLB Bot Betting DF.csv', index=True)
def _print_bet_chart(self):
x_axis = [i for i in range(self.timestep)]
plt.plot(x_axis, self.merged_df['Bankroll'])
plt.title('Bankroll')
plt.ylabel('Dollars')
plt.xlabel('Games')
plt.savefig('./temp/NBA_Bot_Betting.png')
def _take_action(self, action):
print('Start action')
# Init
action_type = action[0]
amount = action[1] + 1
self.game_number = self.starting_point + self.timestep
game_result = self.results['Home Win'][self.game_number]
odds = 0
bet_on = 'NA'
# VISITOR BET
if action_type == 0:
bet_on = 'False'
# Find vis odds
odds = self.results['Vis Odds'][self.game_number]
if odds == 0:
amount = 0
# Place bet
self.balance -= amount
# Check if win
if game_result == False:
self.balance += round(amount * odds, 2)
self.games_won += 1
# NO BET
if action_type == 1:
bet_on = 'No bet'
# HOME BET
if action_type == 2:
bet_on = 'True'
# Find home odds
odds = self.results['Home Odds'][self.game_number]
if odds == 0:
amount = 0
# Place bet
self.balance -= amount
# Check win
if game_result == True:
self.balance += round(amount * odds, 2)
self.games_won += 1
self.balance = round(self.balance, 2)
bet_info = {
'index': self.game_number,
'Home Odds': self.results['Home Odds'][self.game_number],
'Vis Odds': self.results['Vis Odds'][self.game_number],
'Bet on': bet_on,
'Home Win': game_result,
'Amount': amount,
'Odds': odds,
'Bankroll': self.balance
}
self.game_bets.append(bet_info)
print('Finish action')
return bet_info
def step(self, action):
print('Start step')
info = self._take_action(action)
self.timestep += 1
# Reward
gamma = (self.timestep / len(self.df)) # time discount
self.profit = self.balance - self.initial_balance
reward = self.profit * gamma
# Done
done = self.balance <= 0
# Obs
obs = self._next_obs()
# If last game, print results and start from beginning
#test the printing of csv
if self.timestep == 2500:
self._print_bet_csv()
self._print_bet_chart()
self.game_bets = []
print('Starting point: ',self.starting_point)
print('Chart printed')
print('Finished Step')
return obs, reward, done, info
def reset(self):
self.initial_balance = INITIAL_BALANCE
self.balance = INITIAL_BALANCE
self.profit = 0
self.starting_point = np.random.randint(len(self.df) - len(self.df) * 0.1) # Start anywhere but in the end 10%
self.timestep = 0
self.games_won = 0
self.game_bets = []
def render(self, mode='human', close=False):
print('Timestep: ', self.timestep)
print('Profit: ', self.profit)
print('Games Won: ', self.games_won)
print('Balance: ', self.balance)
这是我运行环境的文件:
import time
start_time = time.time()
import os
import random
import json
import gym
from gym import spaces
import pandas as pd
import numpy as np
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3 import PPO, A2C
from Betting_env import BettingEnv
data = pd.read_csv('Scraping/Games and Stats.csv')
df = data.drop(['Date', 'Home', 'Visitor', 'Home PTS', 'Vis PTS', 'Home Points Dif', 'Home Win'], axis=1)
df = df.astype(float)
normed = (df-df.min())/(df.max()-df.min())
normed = normed.round(10)
env = DummyVecEnv([lambda: BettingEnv(normed, data, INITIAL_BALANCE=100)])
model = A2C('MlpPolicy', env, verbose=0)
model.learn(total_timesteps=10000)
save_path = os.path.join('Training', 'Saved Models', 'Betting_Model_A2C')
model.save(save_path)
end_time = time.time()
total_time = end_time - start_time
print(round(total_time / 60 / 60), ' Hours ', round(total_time / 60), ' Minutes')
更新:通过 stable_baselines3 使用 VecCheckNan() 和 check_env() 函数后,我收到以下错误消息。VecCheckNan() 给出:
Traceback (most recent call last):
File "/home/dev/Desktop/Projects/AI/NBA2/stable_baselines_run.py", line 51, in <module>
model.learn(total_timesteps=10000)
File "/home/dev/anaconda3/envs/sb/lib/python3.9/site-packages/stable_baselines3/ppo/ppo.py", line 299, in learn
return super(PPO, self).learn(
File "/home/dev/anaconda3/envs/sb/lib/python3.9/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 226, in learn
total_timesteps, callback = self._setup_learn(
File "/home/dev/anaconda3/envs/sb/lib/python3.9/site-packages/stable_baselines3/common/base_class.py", line 420, in _setup_learn
self._last_obs = self.env.reset() # pytype: disable=annotation-type-mismatch
File "/home/dev/anaconda3/envs/sb/lib/python3.9/site-packages/stable_baselines3/common/vec_env/vec_check_nan.py", line 46, in reset
self._check_val(async_step=False, observations=observations)
File "/home/dev/anaconda3/envs/sb/lib/python3.9/site-packages/stable_baselines3/common/vec_env/vec_check_nan.py", line 84, in _check_val
raise ValueError(msg)
ValueError: found nan in observations.
Originated from the environment observation (at reset)
我已经打印出第一个观察结果,那里没有 NaN。
check_env() 给出:
Traceback (most recent call last):
File "/home/dev/Desktop/Projects/AI/NBA2/stable_baselines_run.py", line 42, in <module>
check_env(env)
File "/home/dev/anaconda3/envs/sb/lib/python3.9/site-packages/stable_baselines3/common/env_checker.py", line 245, in check_env
assert isinstance(
AssertionError: Your environment must inherit from the gym.Env class cf https://github.com/openai/gym/blob/master/gym/core.py
我的 Betting_Env 课程中有gym.Env。