0

我昨天发布了一个关于最简化代码的问题。有人建议我发送完整的场景。以下是完整场景的示例代码和分析。所有数据均为演示数据,代码只是我原始代码的简化示例代码。

老问题:如何比较df的列之间的值更快?

能够解决这两个问题之一会很棒。

在我的代码中,需要优化的三个方法是:BackTestStrategy.get_status()、BackTestStrategy.calc_df() 和 BackTestStrategy.calc_trade_sig()。

根据snakeviz生成的prof文件的分析:

我的原始代码运行大约需要 27 分钟。

  1. 从 2018 年 1 月 1 日开始,我电脑上的总运行时间约为 21 分钟。
  2. BackTestStrategy.get_status() 的运行时间为 442 秒。
  3. BackTestStrategy.calc_df() 的运行时间为 341 秒。
  4. BackTestStrategy.calc_trade_sig() 的运行时间为 58 秒。
  5. 虽然 BackTestStrategy.buy_or_sell() 的运行时间是 135 秒,但这里只是日志调用,暂时不需要优化。

我的代码如下:

import io
import cProfile
import pstats
from numba import njit
import sys
import os
import datetime
import pandas as pd
import numpy as np
from time import strftime

import backtrader as bt

# Module-level handle to the engine; main() rebinds it (via `global cerebro`)
# to the MyCerebro instance it creates.
cerebro = None


class MyPandasFeed(bt.feeds.PandasData):
    # PandasData feed extended with one extra line named 'ema5'.
    # The matching 'ema5' parameter defaults to 1 here; MyCerebro.read_data
    # overrides it with ema5='ema5' when it builds each feed.
    # NOTE(review): presumably the parameter selects the DataFrame column that
    # feeds the line (index or name) - confirm against the backtrader docs.
    lines = ('ema5',)
    params = (('ema5', 1),)


class BackTestStrategy(bt.Strategy):
    """Multi-timeframe EMA5 strategy.

    On every bar the latest close of the base (first) timeframe is compared
    with the EMA5 value of every configured timeframe.  The trade signal is
    derived from the leading run of timeframes whose comparison agrees with
    the current direction ('long' or 'short').

    ``first_duration`` and ``durations`` are not declared in this class;
    ``MyCerebro.__init__`` assigns them on the class before the run starts.
    ``start_time`` is assigned by the script entry point.
    """

    # duration (seconds, int) -> data feed; rebuilt per instance in __init__.
    my_lines = {}
    # DataFrame with one row per duration, rebuilt by get_status() on each bar.
    status = None
    # Per-bar guard flags; prepare_data() resets them at the start of next().
    status_changed = False
    calc_df_flag = False
    calc_trade_sig_flag = None
    # Current direction: 'long' or 'short' (set by get_base_status()).
    long_short = None
    # Wall-clock time at which the script started; used for timing logs.
    start_time = None
    # Latest close of the base timeframe.
    last_price = None
    # Result of find_last_true_of_first_group() for the current bar.
    max_ema5_duration = None
    # Path of the text log file (set by init_log_file()).
    log_file = None
    # Signal of the current bar and the last signal that was acted upon.
    trade_sig = ''
    last_trade_sig = ''

    def __init__(self):
        # Use an instance-level dict so strategy instances never share (and
        # never leak) feed references through the class-level default.
        self.my_lines = {}
        self.init_log_file()
        self.log('')
        self.log(f'PID: {os.getpid()}')
        self.log('')
        self.log(f'Time to prepare data:{datetime.datetime.now() - self.start_time}')
        self.create_code_duration_reference()

    def init_log_file(self):
        """Derive the log file name from the script name and the start date."""
        date_tag = self.start_time.strftime("%Y%m%d")
        # basename() also copes with Windows path separators, unlike
        # splitting on '/'.
        script = os.path.basename(sys.argv[0])
        self.log_file = f'./{script}_print_log_{date_tag}.log'

    def log(self, txt, do_print=True, dt=None):
        """Write ``txt`` to the log file and optionally to stdout.

        Args:
            txt: Message to log.
            do_print: Also print the line when true.
            dt: Bar timestamp to show; defaults to the current datetime of
                the first data feed.
        """
        dt = dt or self.datas[0].datetime.datetime(0)
        wall_clock = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        text = ('%s - %s - %s' % (wall_clock, dt, txt))
        if do_print:
            print(text)
        try:
            with open(self.log_file, "a+") as foo:
                foo.write(text + '\n')
        except Exception as e:
            # Logging is best-effort: report the failure through the
            # module-level log() instead of aborting the backtest.
            log(e)

    def create_code_duration_reference(self):
        """Map every feed's duration (its name parsed as int) to the feed."""
        for data in self.datas:
            self.my_lines[int(data._name)] = data

    def next(self):
        """Per-bar entry point: refresh state, then act on a changed signal."""
        self.prepare_data()
        if not self.get_base_status():
            return
        self.get_status()
        self.calc_df()
        self.calc_trade_sig()
        # Only act when the signal changed and the agreeing run reaches at
        # least the 1800-second timeframe.
        if self.trade_sig != self.last_trade_sig and self.max_ema5_duration >= 1800:
            self.buy_or_sell()
            self.last_trade_sig = self.trade_sig

    def buy_or_sell(self):
        """Placeholder for order execution; currently it only logs."""
        # Execute buying and selling action, omitted
        self.log(f'{self.long_short} something here by {self.max_ema5_duration}')

    def calc_df(self):
        """Add the boolean ``ema5_long`` / ``ema5_short`` columns to status."""
        if self.calc_df_flag:
            return
        df = self.status
        is_long = self.last_price >= df.ema5.to_numpy()
        df['ema5_long'] = is_long
        df['ema5_short'] = ~is_long
        # Mark the work as done so a repeated call within the same bar is a
        # no-op (the flag used to be checked but never set).
        self.calc_df_flag = True

    def prepare_data(self):
        """Reset the per-bar guard flags."""
        self.status_changed = False
        self.calc_df_flag = False
        self.calc_trade_sig_flag = False

    def stop(self):
        """Log the total wall-clock running time when the run ends."""
        end_time = datetime.datetime.now()
        self.log('')
        self.log(f'Total running time:{end_time - self.start_time}')
        self.log('')

    def get_status(self):
        """Build the per-duration status frame for the current bar.

        The frame has the columns ``duration``, ``ema5`` and ``last_price``
        and one row per entry of ``self.durations``.
        """
        if self.status_changed:
            return
        durations = list(self.durations)
        # Column-wise construction avoids converting one dict per row and
        # the extra copy made by reset_index() on an already fresh index.
        df = pd.DataFrame(
            {
                'duration': durations,
                'ema5': [self.my_lines[d].ema5[0] for d in durations],
            }
        )
        df['last_price'] = self.my_lines[self.first_duration].close[0]
        self.status = df
        self.status_changed = True

    @staticmethod
    @njit
    def find_last_true_of_first_group(a, b):
        """Return the ``b`` value at the end of the leading truthy run of ``a``.

        Returns 0 when ``a[0]`` is falsy, and ``b[-1]`` when every element
        of ``a`` is truthy.
        """
        if not a[0]:
            return 0
        else:
            for i in range(1, len(a)):
                if not a[i]:
                    return b[i - 1]
            return b[-1]

    def calc_trade_sig(self):
        """Compute ``max_ema5_duration`` and the trade signal for this bar."""
        if self.calc_trade_sig_flag:
            return
        df = self.status
        long_short = self.long_short
        max_ema5_duration = self.find_last_true_of_first_group(
            df[f'ema5_{long_short}'].to_numpy(), df.duration.to_numpy())
        self.max_ema5_duration = max_ema5_duration
        self.trade_sig = f' {long_short} {max_ema5_duration}'
        # Mark the work as done for this bar (the flag used to be checked
        # but never set).
        self.calc_trade_sig_flag = True

    def get_base_status(self):
        """Record the latest close and the direction of the base timeframe.

        Returns:
            True when a direction could be determined.  With the simplified
            condition used here one of long/short always holds.
        """
        base = self.my_lines[self.first_duration]
        close = base.close[0]
        self.last_price = close
        long = close >= base.ema5[0]
        short = not long
        # The actual conditions are more complicated than this.
        if not (long or short):
            return False
        self.long_short = 'long' if long else 'short'
        return True


class MyCerebro(bt.Cerebro):
    """Cerebro engine that generates demo data and wires up the strategy.

    Construction builds random minute bars between ``start_dt`` and
    ``end_dt``, resamples them to every configured duration, adds an EMA5
    column, registers one data feed per duration and publishes
    ``first_duration`` / ``durations`` on ``BackTestStrategy``.
    """

    # Timeframes in seconds; the first entry is the base timeframe.
    durations: list = [60, 300, 900, 1800, 3600, 7200, 86400, 604800, 2592000]
    # Duration in seconds -> pandas frequency alias used for resampling.
    # NOTE(review): newer pandas releases appear to deprecate some of these
    # aliases ('T', 'H', 'BM') - confirm against the installed version.
    duration_to_freq: dict = {60: 'T', 120: '2T', 300: '5T', 900: '15T', 1800: '30T', 3600: 'H',
                              7200: '2H', 14400: '4H', 86400: 'D', 172800: '2D', 259200: '3D', 604800: 'W',
                              2592000: 'BM'}
    # Class-level default only; every instance gets its own dict in __init__.
    klines = {}
    # Starting cash; __init__ multiplies it by 10000.0 before passing it to the broker.
    cash = 100
    # Date range of the generated data; assigned by the script entry point.
    start_dt = None
    end_dt = None

    def __init__(self):
        super(MyCerebro, self).__init__()
        # Per-instance storage so engines never share generated frames
        # through the mutable class attribute.
        self.klines = {}
        self.first_duration = self.durations[0]
        self.broker.setcash(self.cash * 10000.0)
        self.addstrategy(BackTestStrategy)
        self.generate_data()
        self.factor_data()
        self.read_data()
        BackTestStrategy.first_duration = self.first_duration
        BackTestStrategy.durations = self.durations

    def factor_data(self):
        """Add the ``ema5`` column to the frame of every duration."""
        for duration in self.durations:
            df = self.klines[duration]
            # Column assignment mutates the stored frame in place, so no
            # re-assignment into self.klines is needed.
            df['ema5'] = self.get_ema(df.close, 5)

    @staticmethod
    def get_ema(s, span):
        """Return the exponential moving average of ``s`` (adjust=False)."""
        return s.ewm(span=span, adjust=False).mean()

    def generate_data(self):
        """Create random close prices and resample them to every duration."""
        minutes = pd.date_range(start=self.start_dt, end=self.end_dt, freq='T')
        df = pd.DataFrame(minutes, columns=['datetime'])
        # Use an explicit 24-hour format: '%X' is locale-dependent and the
        # string comparisons below rely on 'HH:MM:SS'.
        df['time'] = df['datetime'].dt.strftime('%H:%M:%S')
        df['weekday'] = df['datetime'].dt.weekday
        in_session = ((df.time > '08:00:00') & (df.time <= '10:00:00')) | (
            (df.time > '12:00:00') & (df.time <= '14:00:00'))
        df = df[in_session]
        # NOTE(review): pandas numbers Monday as 0 and Sunday as 6, so this
        # drops Mondays and Sundays - confirm that is the intended calendar.
        df = df[(df.weekday != 0) & (df.weekday != 6)].reset_index(drop=True)
        df['close'] = np.random.normal(100, 10, len(df))
        df = df.drop(columns=['time', 'weekday'])
        for duration in self.durations:
            if duration == self.first_duration:
                self.klines[duration] = df.copy()
                continue
            period = self.duration_to_freq[duration]
            # set_index() returns a new frame, so the base frame is untouched.
            resampled = self.resample_data(
                df.set_index('datetime'), period, closed='right', label='right')
            self.klines[duration] = resampled.reset_index()

    @staticmethod
    def resample_data(df, freq, *args, **kwargs):
        """Resample ``df`` to ``freq`` taking the last value of each bin.

        Bins without a ``close`` value are dropped.  Extra arguments are
        forwarded to ``DataFrame.resample``.
        """
        return df.resample(freq, *args, **kwargs).last().dropna(subset=['close'])

    def read_data(self):
        """Register one MyPandasFeed per duration, named by the duration."""
        for duration in self.durations:
            data = MyPandasFeed(dataname=self.klines[duration],
                                datetime='datetime',
                                ema5='ema5', )
            self.adddata(data, name=str(duration))


def log(txt, dt=None):
    """Print ``txt`` prefixed with ``dt``, or with the current time if ``dt`` is falsy."""
    stamp = dt if dt else datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f'{stamp} - {txt}')


def get_cprofile(func):
    """Decorator that profiles ``func`` with cProfile and dumps the stats.

    After each successful call the collected statistics are written to
    ``./<script>_cprofile_log_<YYYYmmdd_HHMMSS>.prof`` and the current time
    is printed.

    Args:
        func: Callable to profile.

    Returns:
        A wrapper with the same signature that returns ``func``'s result.
    """
    # Local import so this block does not depend on a file-level import.
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kw):
        pr = cProfile.Profile()
        pr.enable()
        try:
            result = func(*args, **kw)
        finally:
            # Always stop profiling, even when func raises.
            pr.disable()

        stamp = strftime("%Y%m%d_%H%M%S")
        # basename() also copes with Windows path separators.
        script = os.path.basename(sys.argv[0])
        # The dump holds the raw stats; sort order is chosen by the viewer.
        pstats.Stats(pr).dump_stats(f"./{script}_cprofile_log_{stamp}.prof")
        print(strftime('%Y-%m-%d %H:%M:%S'))
        return result

    return wrapper


@get_cprofile
def main():
    """Create the MyCerebro engine, publish it as ``cerebro`` and run it."""
    global cerebro
    MyCerebro.cash = 100
    engine = MyCerebro()
    # Publish the engine before running, exactly as the run is started.
    cerebro = engine
    engine.run()


if __name__ == '__main__':
    # Backtest window: from 2018-01-01 up to one day past 2021-12-17.
    first_day = datetime.datetime(2018, 1, 1)
    last_day = datetime.datetime(2021, 12, 17)
    MyCerebro.start_dt = first_day
    MyCerebro.end_dt = last_day + datetime.timedelta(days=1)
    # Wall-clock reference used by the strategy's timing logs.
    BackTestStrategy.start_time = datetime.datetime.now()
    main()

4

0 回答 0