我昨天发布了一个关于最简化代码的问题。有人建议我发送完整的场景。以下是完整场景的示例代码和分析。所有数据均为演示数据,代码只是我原始代码的简化示例代码。
老问题:如何更快地比较 df 各列之间的值?
如果能解决这两个问题中的任何一个,就太好了。
在我的代码中,需要优化的三个方法是:`BackTestStrategy.get_status()`、`BackTestStrategy.calc_df()` 和 `BackTestStrategy.calc_trade_sig()`。
根据 snakeviz 对生成的 prof 文件的分析:
我的原始代码运行大约需要 27 分钟。
- 以 2018 年 1 月 1 日为回测起始日期,在我的电脑上总运行时间约为 21 分钟。
- `BackTestStrategy.get_status()` 的运行时间为 442 秒。
- `BackTestStrategy.calc_df()` 的运行时间为 341 秒。
- `BackTestStrategy.calc_trade_sig()` 的运行时间为 58 秒。
- `BackTestStrategy.buy_or_sell()` 的运行时间虽然为 135 秒,但其中只有日志调用,暂时不需要优化。
我的代码如下:
import io
import cProfile
import pstats
from numba import njit
import sys
import os
import datetime
import pandas as pd
import numpy as np
from time import strftime
import backtrader as bt
# Module-level handle to the MyCerebro engine; main() assigns it via
# ``global cerebro`` so the engine object outlives the call to main().
cerebro = None
class MyPandasFeed(bt.feeds.PandasData):
    # PandasData feed extended with one extra data line, ``ema5``.
    # ``lines`` declares the extra line; ``params`` gives its column mapping.
    # NOTE(review): in backtrader's PandasData an int value presumably means a
    # column index (-1 = autodetect by name); the default of 1 is overridden by
    # callers that pass ema5='ema5' explicitly — confirm the default is intended.
    lines = ('ema5',)
    params = (('ema5', 1),)
class BackTestStrategy(bt.Strategy):
    """Backtrader strategy comparing the latest close with the EMA5 of every
    configured bar duration and emitting a trade signal when it changes.

    ``first_duration`` and ``durations`` are injected as class attributes by
    ``MyCerebro.__init__``; ``start_time`` is set by the script entry point.
    """
    # Backtrader core class
    # Retained for backward compatibility only: each instance gets its own
    # dict in create_code_duration_reference().
    my_lines = {}
    # Per-bar DataFrame: duration, ema5, last_price (+ ema5_long/ema5_short).
    status = None
    # Per-bar memoization flags, reset by prepare_data().
    status_changed = False
    calc_df_flag = False
    calc_trade_sig_flag = None
    long_short = None  # 'long' or 'short', set by get_base_status()
    start_time = None  # set externally before the run starts
    last_price = None  # close of the first_duration feed at the current bar
    max_ema5_duration = None  # result of calc_trade_sig()
    log_file = None  # path of the text log file
    trade_sig = ''
    last_trade_sig = ''

    def __init__(self):
        self.init_log_file()
        pid = os.getpid()
        self.log('')
        self.log(f'PID: {pid}')
        self.log('')
        end_time = datetime.datetime.now()
        self.log(f'Time to prepare data:{end_time - self.start_time}')
        self.create_code_duration_reference()

    def init_log_file(self):
        """Set ``self.log_file`` to ``./<script name>_print_log_<YYYYmmdd>.log``.

        The date comes from ``self.start_time``.
        """
        date_str = datetime.datetime.strftime(self.start_time, "%Y%m%d")
        # basename() uses the platform's path separators (e.g. '\\' on
        # Windows), unlike the former split('/').
        script_name = os.path.basename(sys.argv[0])
        self.log_file = f'./{script_name}_print_log_{date_str}.log'

    def log(self, txt, do_print=True, dt=None):
        """Optionally print, and append to ``self.log_file``, one log line.

        The line is ``<wall-clock time> - <dt> - <txt>``; ``dt`` defaults to
        the current bar datetime of the first data feed.  A failure to write
        the file is reported through the module-level ``log`` function instead
        of propagating.
        """
        dt = dt or self.datas[0].datetime.datetime(0)
        dt1 = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        text = '%s - %s - %s' % (dt1, dt, txt)
        if do_print:
            print(text)
        # NOTE(review): the file is reopened on every call; for hot paths
        # (e.g. buy_or_sell) keeping one handle open until stop() is cheaper.
        try:
            with open(self.log_file, "a+") as foo:
                foo.write(text + '\n')
        except Exception as e:
            log(e)

    def create_code_duration_reference(self):
        """Map each data feed's duration (its integer ``_name``) to the feed."""
        # A fresh per-instance dict: filling the class-level ``my_lines`` would
        # share (and overwrite) entries across all BackTestStrategy instances.
        self.my_lines = {int(data._name): data for data in self.datas}

    def next(self):
        """Backtrader per-bar callback: refresh state and act on signal changes."""
        self.prepare_data()
        if self.get_base_status():
            self.get_status()
            self.calc_df()
            self.calc_trade_sig()
            if self.trade_sig != self.last_trade_sig:
                if self.max_ema5_duration >= 1800:
                    self.buy_or_sell()
                # NOTE(review): the original indentation was lost; this assumes
                # last_trade_sig updates on every signal change, not only after
                # buy_or_sell() — confirm.
                self.last_trade_sig = self.trade_sig

    def buy_or_sell(self):
        # Execute buying and selling action, omitted
        self.log(f'{self.long_short} something here by {self.max_ema5_duration}')

    def calc_df(self):
        """Add ``ema5_long`` (last_price >= ema5) and its negation ``ema5_short``
        to ``self.status``, at most once per bar."""
        if not self.calc_df_flag:
            df = self.status
            # Compute the boolean array once and reuse it for both columns.
            is_long = self.last_price >= df['ema5'].to_numpy()
            df['ema5_long'] = is_long
            df['ema5_short'] = ~is_long
            self.calc_df_flag = True

    def prepare_data(self):
        """Reset the per-bar memoization flags at the start of each bar."""
        self.status_changed = False
        self.calc_df_flag = False
        self.calc_trade_sig_flag = False

    def stop(self):
        """Log the total running time at the end of the run."""
        end_time = datetime.datetime.now()
        self.log('')
        self.log(f'Total running time:{end_time - self.start_time}')
        self.log('')

    def get_status(self):
        """Build ``self.status`` for the current bar, at most once per bar.

        One row per entry of ``self.durations``, with columns ``duration``,
        ``ema5`` (that feed's current EMA5) and ``last_price`` (current close
        of the ``first_duration`` feed).
        """
        if not self.status_changed:
            feeds = self.my_lines
            durations = self.durations
            # Column-wise construction is far cheaper than one dict per row.
            # The new frame already has a fresh RangeIndex, so no reset_index
            # is needed.
            self.status = pd.DataFrame({
                'duration': durations,
                'ema5': [feeds[d].ema5[0] for d in durations],
                'last_price': feeds[self.first_duration].close[0],
            })
            self.status_changed = True

    @staticmethod
    @njit
    def find_last_true_of_first_group(a, b):
        """Return ``b[k]``, where ``a[0..k]`` is the leading run of True in *a*.

        *a* (booleans) and *b* are parallel 1-D arrays.  Returns 0 when *a* is
        empty or ``a[0]`` is False.
        """
        n = len(a)
        # Explicit empty check: njit does not bounds-check a[0].
        if n == 0 or not a[0]:
            return 0
        for i in range(1, n):
            if not a[i]:
                return b[i - 1]
        return b[n - 1]

    def calc_trade_sig(self):
        """Set ``max_ema5_duration`` and ``trade_sig`` for the current
        direction, at most once per bar."""
        if not self.calc_trade_sig_flag:
            df = self.status
            long_short = self.long_short
            max_ema5_duration = self.find_last_true_of_first_group(
                df[f'ema5_{long_short}'].to_numpy(), df['duration'].to_numpy())
            self.max_ema5_duration = max_ema5_duration
            self.trade_sig = f' {long_short} {max_ema5_duration}'
            self.calc_trade_sig_flag = True

    def get_base_status(self):
        """Record the current close and the long/short direction.

        Sets ``self.last_price`` and ``self.long_short`` ('long' when close >=
        EMA5 of the first_duration feed, else 'short').  Returns whether the
        bar should be processed further.
        """
        return_value = False
        feed = self.my_lines[self.first_duration]
        close = feed.close[0]
        self.last_price = close
        ema5 = feed.ema5[0]
        long = close >= ema5
        short = not long
        if long | short:
            # The actual conditions are more complicated than this.
            return_value = True
            self.long_short = 'long' if long else 'short'
        return return_value
class MyCerebro(bt.Cerebro):
    """Cerebro that generates random demo minute closes, resamples them to
    every entry of ``durations``, adds an EMA5 column and feeds each series
    to BackTestStrategy.

    ``start_dt`` and ``end_dt`` must be set on the class before instantiation.
    """
    durations: list = [60, 300, 900, 1800, 3600, 7200, 86400, 604800, 2592000]
    # NOTE(review): 'T', 'H' and 'BM' are deprecated aliases since pandas 2.2
    # ('min', 'h', 'BME'); kept unchanged for compatibility with older pandas.
    duration_to_freq: dict = {60: 'T', 120: '2T', 300: '5T', 900: '15T', 1800: '30T', 3600: 'H',
                              7200: '2H', 14400: '4H', 86400: 'D', 172800: '2D', 259200: '3D',
                              604800: 'W', 2592000: 'BM'}
    # Retained for backward compatibility only; __init__ gives every instance
    # its own dict.
    klines = {}
    cash = 100
    start_dt = None
    end_dt = None

    def __init__(self):
        super().__init__()
        # Per-instance storage: filling the class-level dict would leak data
        # between MyCerebro instances.
        self.klines = {}
        self.first_duration = self.durations[0]
        self.broker.setcash(self.cash * 10000.0)
        self.addstrategy(BackTestStrategy)
        self.generate_data()
        self.factor_data()
        self.read_data()
        BackTestStrategy.first_duration = self.first_duration
        BackTestStrategy.durations = self.durations

    def factor_data(self):
        """Add an ``ema5`` column (span-5 EMA of ``close``) to every kline frame."""
        for duration in self.durations:
            df = self.klines[duration]
            df['ema5'] = self.get_ema(df['close'], 5)

    @staticmethod
    def get_ema(s, span):
        """Return the exponential moving average of series *s* (``adjust=False``)."""
        return s.ewm(span=span, adjust=False).mean()

    def generate_data(self):
        """Fill ``self.klines`` with random minute closes and their resamplings.

        Minute bars between ``start_dt`` and ``end_dt`` are filtered to two
        daily windows and certain weekdays.  They are stored as-is for
        ``first_duration`` and resampled for every other duration.
        """
        minutes = pd.date_range(start=self.start_dt, end=self.end_dt, freq='T')
        df = pd.DataFrame(minutes, columns=['datetime'])
        # Explicit zero-padded 24-hour format instead of the locale-dependent
        # '%X', because the filter below compares these strings lexically.
        df['time'] = df['datetime'].dt.strftime('%H:%M:%S')
        df['weekday'] = df['datetime'].dt.weekday
        in_session = (((df['time'] > '08:00:00') & (df['time'] <= '10:00:00'))
                      | ((df['time'] > '12:00:00') & (df['time'] <= '14:00:00')))
        df = df[in_session]
        # NOTE(review): drops Monday (0) and Sunday (6) but keeps Saturday (5)
        # — confirm this is intended for the demo data.
        df = df[(df['weekday'] != 0) & (df['weekday'] != 6)].reset_index(drop=True)
        df['close'] = np.random.normal(100, 10, len(df))
        df = df.drop(columns=['time', 'weekday'])
        for duration in self.durations:
            if duration == self.first_duration:
                self.klines[duration] = df.copy()
            else:
                period = self.duration_to_freq[duration]
                # set_index already returns a new frame, so no extra copy.
                resampled = self.resample_data(df.set_index('datetime'), period,
                                               closed='right', label='right')
                self.klines[duration] = resampled.reset_index()

    @staticmethod
    def resample_data(df, freq, *args, **kwargs):
        """Resample *df* to *freq*, keeping each bin's last row.

        Bins without a ``close`` value are dropped.
        """
        return df.resample(freq, *args, **kwargs).last().dropna(subset=['close'])

    def read_data(self):
        """Wrap each kline frame in a MyPandasFeed named after its duration."""
        for duration in self.durations:
            data = MyPandasFeed(dataname=self.klines[duration],
                                datetime='datetime',
                                ema5='ema5')
            self.adddata(data, name=str(duration))
def log(txt, dt=None):
    """Print *txt* prefixed by a timestamp.

    A truthy *dt* is used verbatim as the prefix.  Otherwise the current local
    time, formatted ``YYYY-mm-dd HH:MM:SS``, is used.
    """
    stamp = dt if dt else datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f'{stamp} - {txt}')
def get_cprofile(func):
    """Decorator that profiles each call of *func* with cProfile.

    The stats, sorted by cumulative time, are dumped to
    ``./<script name>_cprofile_log_<YYYYmmdd_HHMMSS>.prof`` (viewable with
    snakeviz or ``pstats``), and then the current time is printed.  The
    wrapped function's return value is passed through.  The profile is still
    written if the function raises.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kw):
        pr = cProfile.Profile()
        pr.enable()
        try:
            return func(*args, **kw)
        finally:
            # Always stop profiling and dump, even when func raised.
            pr.disable()
            now = strftime("%Y%m%d_%H%M%S")
            # basename() uses the platform's path separators (e.g. '\\' on
            # Windows), unlike split('/').
            file_name = os.path.basename(sys.argv[0])
            stats = pstats.Stats(pr).sort_stats('cumtime')
            stats.dump_stats(f"./{file_name}_cprofile_log_{now}.prof")
            print(strftime('%Y-%m-%d %H:%M:%S'))

    return wrapper
@get_cprofile
def main():
    """Create the back-test engine, publish it as the module-level
    ``cerebro`` and run it.  The whole call is profiled by ``get_cprofile``."""
    global cerebro
    # MyCerebro.__init__ funds the broker with cash * 10000.0.
    MyCerebro.cash = 100
    engine = MyCerebro()
    cerebro = engine
    engine.run()
if __name__ == '__main__':
    # Configure the back-test window on the class before main() builds the engine.
    MyCerebro.end_dt = datetime.datetime(2021, 12, 17) + datetime.timedelta(hours=24)
    MyCerebro.start_dt = datetime.datetime(2018, 1, 1)
    # Reference point for BackTestStrategy's "Time to prepare data" and
    # "Total running time" log lines.
    BackTestStrategy.start_time = datetime.datetime.now()
    main()