
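I'm having a problem with DQN when I use a diagonal line and a sine wave as the price series. When the price goes up there is a reward, and it shows as green in the chart. When the price goes down and is marked red, the reward also goes up. Please see this link: the DQN in this link learns much better than the stable-baselines DQN.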

Even with a diagonal line, I'm struggling to get DQN to learn.

[Figure: DQN on the diagonal line]

Sine wave: it would be great if the result were the other way around. Green should mark the rises and red the falls.

[Figure: sine wave]
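To make "green for up, red for down" concrete, the ideal label at each step can be computed directly from the price series: go long (Buy) when the next price is higher, go short (Sell) otherwise. This is a minimal sketch to check an agent's output against. `oracle_actions` is a hypothetical helper, and the Sell = 0 / Buy = 1 encoding is an assumption about gym-anytrading's `Actions` enum that is worth checking against the installed version.

```python
import numpy as np

# Assumed encoding (check gym_anytrading's Actions enum): Sell = 0, Buy = 1.
SELL, BUY = 0, 1


def oracle_actions(prices):
    """Ideal action per step: Buy (go long) if the next price is higher,
    Sell (go short) otherwise. The last step has no next price, so the
    result is one element shorter than `prices`."""
    prices = np.asarray(prices, dtype=float)
    next_move = np.diff(prices)  # prices[t + 1] - prices[t]
    return np.where(next_move > 0, BUY, SELL)


# The two toy price series from the question.
diagonal = np.arange(200 / 10.0)                            # 0.0, 1.0, ..., 19.0
sine = np.array([np.sin(x / 10) + 1 for x in range(200)])   # shifted to stay >= 0

print(oracle_actions(diagonal))   # every step is Buy on a rising line
print(oracle_actions(sine)[:20])  # Buy while the wave rises, then Sell
```

On the diagonal line the best policy is simply "always Buy", which is why it makes a useful first sanity check before trying the sine wave.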

What I tried was changing the learning rate from 0.01 to 10 and setting epsilon to 1.
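For reference, epsilon-greedy exploration with a linear schedule can be sketched as below. The function names are hypothetical, and the default values are illustrative (they are believed to match stable-baselines' DQN defaults for `exploration_fraction` and `exploration_final_eps`, but check the docs). What it shows: with `exploration_final_eps=1`, as in the commented-out DQN call in the code, epsilon never drops below 1, so every action taken during training is random.

```python
import random


def linear_epsilon(step, total_steps, exploration_fraction=0.1,
                   initial_eps=1.0, final_eps=0.02):
    """Anneal epsilon linearly from initial_eps to final_eps over the first
    exploration_fraction * total_steps steps, then keep it at final_eps."""
    decay_steps = max(1, int(exploration_fraction * total_steps))
    fraction = min(1.0, step / decay_steps)
    return initial_eps + fraction * (final_eps - initial_eps)


def epsilon_greedy(q_values, eps, rng=random):
    """With probability eps pick a random action, otherwise the greedy one."""
    if rng.random() < eps:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])


# Halfway through a 100000-step run:
print(linear_epsilon(50000, 100000))                  # ~0.02: mostly greedy
print(linear_epsilon(50000, 100000, final_eps=1.0))   # 1.0: always random

rng = random.Random(0)
eps = linear_epsilon(50000, 100000, final_eps=1.0)
picks = [epsilon_greedy([0.1, 0.5], eps, rng) for _ in range(1000)]
print(picks.count(0), picks.count(1))  # roughly even: the Q-values are ignored
```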

With PPO2 I can get a decent result. For the sine wave:

model = PPO2(MlpPolicy, env, verbose=1,learning_rate=.01)
model.learn(total_timesteps=500000)

[Figure: sine wave with PPO2]

It works for the diagonal line too!

[Figure: diagonal line with PPO2]

Here is my code. Just comment/uncomment the relevant lines to test PPO2 vs. DQN.

from copy import deepcopy
import math as m

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import gym
import gym_anytrading  # registers the 'stocks-v0' environment

from stable_baselines import A2C, ACKTR, DQN, PPO2
from stable_baselines.common import make_vec_env
from stable_baselines.common.vec_env import DummyVecEnv
# This MlpPolicy is for A2C/ACKTR/PPO2. DQN needs the deepq policies instead
# (FeedForwardPolicy below); importing deepq's MlpPolicy under the same name
# would be silently shadowed by this import.
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.deepq.policies import FeedForwardPolicy

class CustomDQNPolicy(FeedForwardPolicy):
    def __init__(self, *args, **kwargs):
        super(CustomDQNPolicy, self).__init__(*args, **kwargs,
                                              layers=[64,64,64],
                                              layer_norm=True,
                                              feature_extraction="mlp")

def main():
    n_cpu = 16
    # df = gym_anytrading.datasets.STOCKS_GOOGL.copy()
    # print(df)

    # Sine-wave prices (shifted up by 1 so they stay non-negative):
    # arraysin = []
    # for x in range(0, 200, 1):
    #     arraysin = np.append(arraysin, (m.sin(x / 10) + 1))
    # print(arraysin)

    arraysin = np.arange(200 / 10.0)  # linearly increasing prices: 0.0, 1.0, ..., 19.0 (20 values)

    # # convert the column (it's a string) to datetime type
    # datetime_series = pd.to_datetime(df['date_of_birth'])
    # # create datetime index passing the datetime series
    # datetime_index = pd.DatetimeIndex(datetime_series.values)

    df = pd.DataFrame(arraysin)
    df.columns = ['Close']  # the stocks-v0 env reads prices from the 'Close' column
    # df = df.set_index(datetime_index)
    print(df)

    window_size = 1
    start_index = window_size
    end_index = len(df)

    # Factory so DummyVecEnv can create one independent env per worker
    env_maker = lambda: gym.make(
        'stocks-v0',
        df=df,
        window_size=window_size,
        frame_bound=(start_index, end_index)
    )
    env = DummyVecEnv([env_maker for _ in range(n_cpu)])

    # policy_kwargs = dict(net_arch=[64, 'lstm', dict(vf=[128, 128, 128], pi=[64, 64])])
    # model = A2C('MlpLstmPolicy', env, verbose=1, policy_kwargs=policy_kwargs)

    # model = A2C(MlpPolicy, env, verbose=1, learning_rate=.01)
    # model = ACKTR(MlpPolicy, env, verbose=1, learning_rate=1)
    model = PPO2(MlpPolicy, env, verbose=1, learning_rate=.01)

    # stable-baselines' DQN only accepts a single environment, so set n_cpu = 1
    # before switching to it. Note also that exploration_final_eps=1 keeps
    # epsilon at 1 for the whole run, so every training action is random.
    # model = DQN(policy=CustomDQNPolicy, env=env, verbose=1,
    #             learning_rate=.01,
    #             buffer_size=10000,
    #             double_q=False,
    #             exploration_final_eps=1,
    #             prioritized_replay=True)

    model.learn(total_timesteps=100000)
    # model.save('nzdusdDQN') 
    env = env_maker()
    observation = env.reset()

    while True:
        # observation = observation[np.newaxis, ...]

        # action = env.action_space.sample()
        action, _states = model.predict(observation)
        observation, reward, done, info = env.step(action)

        # env.render()
        if done:
            print("info:", info)
            break

    # for e in env.envs:
    #     plt.figure(figsize=(16, 6))
    #     e.render_all()
    #     plt.show()
    plt.figure(figsize=(16, 6))
    env.render_all()
    plt.show()    


if __name__ == '__main__':
    main()

System info: describe the characteristics of your environment:

  1. Windows 10
  2. TensorFlow 1.15.0
  3. Stable Baselines 2.10.2a0 dev_0
  4. gym-anytrading 1.2.0

Conda list:

PS E:\ML\reinforcementlearning\tradeorig> conda list
# packages in environment at C:\anaconda\envs\gymorig:
#
# Name                    Version                   Build  Channel
_tflow_select             2.2.0                     eigen         
absl-py                   0.11.0           py37haa95532_0         
alabaster                 0.7.12                   py37_0         
apipkg                    1.5                      pypi_0    pypi 
argh                      0.26.2                   py37_0         
asn1crypto                1.4.0                      py_0         
astor                     0.8.1                    py37_0         
astroid                   2.4.2                    py37_0         
async_generator           1.10             py37h28b3542_0         
atari-py                  0.2.6                    pypi_0    pypi
atomicwrites              1.4.0                      py_0
attrs                     20.2.0                     py_0
autopep8                  1.5.4                      py_0
babel                     2.8.0                      py_0
backcall                  0.2.0                      py_0
bcrypt                    3.2.0            py37he774522_0
blas                      1.0                         mkl
bleach                    3.2.1                      py_0
brotlipy                  0.7.0           py37he774522_1000
ca-certificates           2020.10.14                    0
certifi                   2020.6.20        py37haa95532_2
cffi                      1.14.3           py37h7a1dbc1_0
chardet                   3.0.4                 py37_1003
cloudpickle               1.6.0                      py_0
colorama                  0.4.4                      py_0
coverage                  5.3                      pypi_0    pypi
cryptography              2.3.1            py37h74b6da3_0
cycler                    0.10.0                   pypi_0    pypi
decorator                 4.4.2                      py_0
defusedxml                0.6.0                      py_0
diff-match-patch          20200713                   py_0
docutils                  0.16                     py37_1
entrypoints               0.3                      py37_0
execnet                   1.7.1                    pypi_0    pypi
flake8                    3.8.4                      py_0
future                    0.18.2                   py37_1
gast                      0.2.2                    py37_0
google-pasta              0.2.0                      py_0
grpcio                    1.14.1           py37h5c4b210_0
gym                       0.17.3                   pypi_0    pypi
gym-anytrading            1.2.0                    pypi_0    pypi
h5py                      2.10.0           py37h5e291fa_0
hdf5                      1.10.4               h7ebc959_0
icc_rt                    2019.0.0             h0cc432a_1
icu                       58.2                 ha925a31_3
idna                      2.10                       py_0
imagesize                 1.2.0                      py_0
importlab                 0.5.1                    pypi_0    pypi
importlib-metadata        2.0.0                      py_1
importlib_metadata        2.0.0                         1
iniconfig                 1.0.1                    pypi_0    pypi
intel-openmp              2020.2                      254
intervaltree              3.1.0                      py_0
ipykernel                 5.3.4            py37h5ca1d4c_0
ipython                   7.18.1           py37h5ca1d4c_0
ipython_genutils          0.2.0                    py37_0
isort                     5.6.4                      py_0
jedi                      0.17.1                   py37_0
jinja2                    2.11.2                     py_0
joblib                    0.17.0                   pypi_0    pypi
jpeg                      9b                   hb83a4c4_2
jsonschema                3.2.0                      py_2
jupyter_client            6.1.7                      py_0
jupyter_core              4.6.3                    py37_0
jupyterlab_pygments       0.1.2                      py_0
keras-applications        1.0.8                      py_1
keras-base                2.3.1                    py37_0
keras-preprocessing       1.1.0                      py_1
keyring                   21.4.0                   py37_1
kiwisolver                1.2.0                    pypi_0    pypi
lazy-object-proxy         1.4.3            py37he774522_0
libpng                    1.6.37               h2a8f88b_0
libprotobuf               3.13.0.1             h200bbdf_0
libsodium                 1.0.18               h62dcd97_0
libspatialindex           1.9.3                h33f27b4_0
livereload                2.6.3                    pypi_0    pypi
lxml                      4.5.2                    pypi_0    pypi
markdown                  3.3.2                    py37_0
markupsafe                1.1.1            py37hfa6e2cd_1
matplotlib                3.3.2                    pypi_0    pypi
mccabe                    0.6.1                    py37_1
mistune                   0.8.4           py37hfa6e2cd_1001
mkl                       2020.2                      256
mkl-service               2.3.0            py37hb782905_0
mkl_fft                   1.2.0            py37h45dec08_0
mkl_random                1.1.1            py37h47e9c7a_0
mpi4py                    3.0.3                    pypi_0    pypi
msgpack                   1.0.0                    pypi_0    pypi
multitasking              0.0.9                    pypi_0    pypi
nbclient                  0.5.1                      py_0
nbconvert                 6.0.7                    py37_0
nbformat                  5.0.8                      py_0
nest-asyncio              1.4.1                      py_0
networkx                  2.5                      pypi_0    pypi
ninja                     1.10.0.post2             pypi_0    pypi
numpy                     1.19.2           py37hadc3359_0
numpy-base                1.19.2           py37ha3acd2a_0
numpydoc                  1.1.0                      py_0
opencv-python             4.4.0.44                 pypi_0    pypi
openssl                   1.0.2u               he774522_0
opt_einsum                3.1.0                      py_0
packaging                 20.4                       py_0
pandas                    1.1.3            py37ha925a31_0
pandoc                    2.11                 h9490d1a_0
pandocfilters             1.4.2                    py37_1
paramiko                  2.4.2                    py37_0
parso                     0.7.0                      py_0
pathtools                 0.1.2                      py_1
pexpect                   4.8.0                    py37_1
pickleshare               0.7.5                 py37_1001
pillow                    7.2.0                    pypi_0    pypi
pip                       20.2.4                   py37_0
pluggy                    0.13.1                   py37_0
prompt-toolkit            3.0.8                      py_0
protobuf                  3.13.0.1         py37ha925a31_1
psutil                    5.7.2            py37he774522_0
py                        1.9.0                    pypi_0    pypi
pyasn1                    0.4.8                      py_0
pycodestyle               2.6.0                      py_0
pycparser                 2.20                       py_2
pydocstyle                5.1.1                      py_0
pyflakes                  2.2.0                      py_0
pyglet                    1.5.0                    pypi_0    pypi
pygments                  2.7.1                      py_0
pylint                    2.6.0                    py37_0
pynacl                    1.4.0            py37h62dcd97_1
pyopenssl                 19.0.0                   py37_0
pyparsing                 2.4.7                      py_0
pyqt                      5.6.0            py37ha878b3d_6
pyreadline                2.1                      py37_1
pyrsistent                0.17.3           py37he774522_0
pysocks                   1.7.1                    py37_1
pytest                    6.1.1                    pypi_0    pypi
pytest-cov                2.10.1                   pypi_0    pypi
pytest-env                0.6.2                    pypi_0    pypi
pytest-forked             1.3.0                    pypi_0    pypi
pytest-xdist              2.1.0                    pypi_0    pypi
python                    3.7.1                h33f27b4_4
python-dateutil           2.8.1                      py_0
python-jsonrpc-server     0.4.0                      py_0
python-language-server    0.35.1                     py_0
pytype                    2020.9.29                pypi_0    pypi
pytz                      2020.1                     py_0
pywin32                   227              py37he774522_1
pywin32-ctypes            0.2.0                 py37_1001
pyyaml                    5.3.1                    pypi_0    pypi
pyzmq                     19.0.2           py37ha925a31_1
qdarkstyle                2.8.1                      py_0
qt                        5.6.2           vc14h6f8c307_12
qtawesome                 1.0.1                      py_0
qtconsole                 4.7.7                      py_0
qtpy                      1.9.0                      py_0
quantstats                0.0.25                   pypi_0    pypi
requests                  2.24.0                     py_0
rope                      0.18.0                     py_0
rtree                     0.9.4            py37h21ff451_1
ruamel-yaml               0.16.12                  pypi_0    pypi
ruamel-yaml-clib          0.2.2                    pypi_0    pypi
scipy                     1.5.2            py37h9439919_0
seaborn                   0.11.0                   pypi_0    pypi
setuptools                50.3.0           py37h9490d1a_1
sip                       4.18.1           py37h6538335_2
six                       1.15.0                     py_0
snowballstemmer           2.0.0                      py_0
sortedcontainers          2.2.2                      py_0
sphinx                    3.2.1                      py_0
sphinx-autobuild          2020.9.1                 pypi_0    pypi
sphinx-rtd-theme          0.5.0                    pypi_0    pypi
sphinxcontrib-applehelp   1.0.2                      py_0
sphinxcontrib-devhelp     1.0.2                      py_0
sphinxcontrib-htmlhelp    1.0.3                      py_0
sphinxcontrib-jsmath      1.0.1                      py_0
sphinxcontrib-qthelp      1.0.3                      py_0
sphinxcontrib-serializinghtml 1.1.4                      py_0
spyder                    4.1.5                    py37_0
spyder-kernels            1.9.4                    py37_0
sqlite                    3.33.0               h2a8f88b_0
stable-baselines          2.10.2a0                  dev_0    <develop>
tabulate                  0.8.7                    pypi_0    pypi
tensorboard               2.0.0              pyhb38c66f_1
tensorflow                1.15.0          eigen_py37h9f89a44_0
tensorflow-base           1.15.0          eigen_py37h07d2309_0
tensorflow-estimator      1.15.1             pyh2649769_0
termcolor                 1.1.0                    py37_1
testpath                  0.4.4                      py_0
toml                      0.10.1                     py_0
tornado                   6.0.4            py37he774522_1
traitlets                 5.0.5                      py_0
typed-ast                 1.4.1            py37he774522_0
ujson                     4.0.1            py37ha925a31_0
urllib3                   1.25.11                    py_0
vc                        14.1                 h0510ff6_4
vs2015_runtime            14.16.27012          hf0eaf9b_3
watchdog                  0.10.3                   py37_0
wcwidth                   0.2.5                      py_0
webencodings              0.5.1                    py37_1
werkzeug                  0.16.1                     py_0
wheel                     0.35.1                     py_0
win_inet_pton             1.1.0                    py37_0
wincertstore              0.2                      py37_0
wrapt                     1.11.2           py37he774522_0
yaml                      0.2.5                he774522_0
yapf                      0.30.0                     py_0
yfinance                  0.1.55                   pypi_0    pypi
zeromq                    4.3.2                ha925a31_3
zipp                      3.3.1                      py_0
zlib                      1.2.11               h62dcd97_4

1 Answer


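I think the problem is that you are using the default network architecture in stable-baselines. Look at the network in the example you linked: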

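from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.optimizers import RMSprop

model = Sequential()
model.add(Dense(4, kernel_initializer='lecun_uniform', input_shape=(2,)))
model.add(Activation('relu'))
model.add(Dense(4, kernel_initializer='lecun_uniform'))
model.add(Activation('relu'))
model.add(Dense(4, kernel_initializer='lecun_uniform'))
model.add(Activation('linear'))  # linear output layer: raw Q-value estimates
rms = RMSprop()
model.compile(loss='mse', optimizer=rms)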

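So it is a very simple network: 3 layers with 4 neurons each. In stable-baselines you are using the default MlpPolicy, which has two layers of 64 neurons. You can easily specify the network architecture by passing `policy_kwargs` to the model, like this: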

policy_kwargs = dict(
    layers=[4, 4, 4]  # stable-baselines' DQN policies take `layers`; `net_arch` is for A2C/PPO2 policies
)

and your DQN model can then be initialized with:

model = DQN('MlpPolicy', env, policy_kwargs=policy_kwargs, verbose=1)
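To put the size difference in numbers, here is a small helper (hypothetical, not part of stable-baselines) that counts the weights and biases of a plain fully connected network. It assumes 2 inputs (with `window_size = 1`, gym-anytrading's stocks env observes the price and its difference) and 2 actions (Sell, Buy), and it ignores any extra dueling head that stable-baselines' DQN may add.

```python
def mlp_param_count(input_dim, hidden_sizes, output_dim):
    """Number of weights + biases in a plain fully connected network."""
    sizes = [input_dim] + list(hidden_sizes) + [output_dim]
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes, sizes[1:]))


# Keras example: 2 inputs, two hidden layers of 4, an output layer of 4
print(mlp_param_count(2, [4, 4], 4))      # 52
# Default DQN MLP: two hidden layers of 64, 2 actions (dueling head ignored)
print(mlp_param_count(2, [64, 64], 2))    # 4482
# layers=[4, 4, 4] with 2 actions
print(mlp_param_count(2, [4, 4, 4], 2))   # 62
```

The default network has almost two orders of magnitude more parameters than the linked example, which is a lot of capacity for a 20-point diagonal line.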

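Also, in your first example the author builds a simple DQN with a single network. In frameworks such as stable-baselines, however, the DQN algorithm uses two networks with the same structure: an online network that is trained, and a target network that is used to compute the training targets. This helps on more complex problems, but on a simple problem like yours it may not work as well.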
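The two-network setup can be sketched in NumPy as follows. This is a minimal illustration under simplifying assumptions, not stable-baselines' implementation: the "networks" are linear maps from observations to Q-values, the transitions are random toy data, and the target copy is refreshed with a hard update every `sync_every` steps (stable-baselines exposes a similar interval as `target_network_update_freq`).

```python
import numpy as np


def q_values(w, obs):
    """Linear Q-function: obs (batch, obs_dim) @ w (obs_dim, n_actions)."""
    return obs @ w


def dqn_update(online_w, target_w, batch, gamma=0.99, lr=0.01):
    """One DQN-style gradient step on the online weights only.
    The bootstrapped target is computed with the frozen target weights."""
    obs, actions, rewards, next_obs, dones = batch
    next_q = q_values(target_w, next_obs).max(axis=1)
    targets = rewards + gamma * (1.0 - dones) * next_q
    idx = np.arange(len(actions))
    td_error = q_values(online_w, obs)[idx, actions] - targets
    # Gradient of 0.5 * mean(td_error ** 2) with respect to online_w.
    grad = np.zeros_like(online_w)
    for i, a in enumerate(actions):
        grad[:, a] += td_error[i] * obs[i]
    return online_w - lr * grad / len(actions)


rng = np.random.default_rng(0)
obs_dim, n_actions, batch_size, sync_every = 2, 2, 32, 100

online_w = rng.normal(size=(obs_dim, n_actions))
target_w = online_w.copy()  # same structure, starts as an exact copy

for step in range(1, 501):
    # Random toy transitions, only there to exercise the update.
    batch = (
        rng.normal(size=(batch_size, obs_dim)),              # obs
        rng.integers(0, n_actions, size=batch_size),         # actions
        rng.normal(size=batch_size),                         # rewards
        rng.normal(size=(batch_size, obs_dim)),              # next_obs
        (rng.random(batch_size) < 0.1).astype(float),        # dones
    )
    online_w = dqn_update(online_w, target_w, batch)
    if step % sync_every == 0:
        target_w = online_w.copy()  # periodic hard update of the target network
```

Between syncs the target weights stay fixed, so the online network chases a stationary target; a single-network DQN like the linked example bootstraps from the weights it is updating.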

answered 2021-05-21T10:37:54.790