4

有人能帮我优化 Python 中的绘图函数吗?我使用 Matplotlib 绘制金融数据。下面是一个绘制 OHLC 数据的小函数。如果再添加指标或其他数据,耗时会显著增加。

import numpy as np
import datetime
from matplotlib.collections import LineCollection
from pylab import *
import urllib2

def test_plot(OHLCV):
    """Draw an OHLC bar chart of OHLCV and save it to 'test.png'.

    OHLCV is a 2-D array whose columns are date, open, high, low, close
    (and volume), as built by the __main__ section below.  Each row is
    drawn as a vertical high-low bar, a tick to the left at the open price
    and a tick to the right at the close price; the bar is green when
    open < close and red otherwise.
    """
    bar_width = 1.3     # line width of every bar segment
    date_offset = 0.5   # horizontal length (in date units) of the open/close ticks
    fig = figure(figsize=(50, 20), facecolor='w')
    ax = fig.add_subplot(1, 1, 1)
    labels = ax.get_xmajorticklabels()
    setp(labels, rotation=0)

    month = MonthLocator()
    day   = DayLocator()
    timeFmt = DateFormatter('%Y-%m-%d')

    opens = OHLCV[:, 1]
    highs = OHLCV[:, 2]
    lows = OHLCV[:, 3]
    closes = OHLCV[:, 4]
    # Plain str colours; a '|S5' array would hold bytes on Python 3.
    color = np.where(opens < closes, 'green', 'red').tolist()
    dates = np.asarray(date2num(OHLCV[:, 0]))

    def segments(x0, y0, x1, y1):
        """Return one ((x0, y0), (x1, y1)) line segment per row."""
        return list(zip(zip(x0, y0), zip(x1, y1)))

    lines_hl = LineCollection(segments(dates, highs, dates, lows))
    # Both tick collections previously read the global ``parsed_table``
    # instead of the OHLCV argument, so any other input was plotted wrongly.
    lines_op = LineCollection(segments(dates - date_offset, opens, dates, opens))
    lines_cl = LineCollection(segments(dates + date_offset, closes, dates, closes))
    # Same add order as before: high-low bars, then close ticks, then open ticks.
    for lines in (lines_hl, lines_cl, lines_op):
        lines.set_color(color)
        lines.set_linewidth(bar_width)
        ax.add_collection(lines, autolim=True)

    ax.xaxis.set_major_locator(month)
    ax.xaxis.set_major_formatter(timeFmt)
    ax.xaxis.set_minor_locator(day)

    ax.autoscale_view()

    ax.xaxis.grid(True, 'major')
    ax.grid(True)

    ax.set_title('EOD test plot')
    ax.set_xlabel('Date')
    ax.set_ylabel('Price , $')
    fig.savefig('test.png', dpi = 50, bbox_inches='tight')
    # Close exactly the figure created here, not whatever happens to be current.
    close(fig)

if __name__=='__main__':

    response = urllib2.urlopen(r"http://ichart.finance.yahoo.com/table.csv?s=IBM&a=00&b=1&c=2012&d=00&e=15&f=2013&g=d&ignore=.csv")
    try:
        # Drop the CSV header row and reverse the row order (the feed
        # presumably lists newest rows first -- confirm against the service).
        data_table = response.readlines()[1:][::-1]
    finally:
        response.close()

    #Format:  Date, Open, High, Low, Close, Volume
    # One converter per kept column; ``[:-1]`` below drops the last CSV column.
    converters = (lambda x: datetime.datetime.strptime(x, '%Y-%m-%d').date(),float, float, float, float, int)

    parsed_table = []
    for row in data_table:
        field = row.strip().split(',')[:-1]
        parsed_table.append([convert(value) for convert, value in zip(converters, field)])

    parsed_table = np.array(parsed_table)

    import time
    bf = time.time()
    count = 100
    for i in xrange(count):
        test_plot(parsed_table)
    # '%' binds tighter than '/', so the division must be parenthesized;
    # otherwise the formatted string itself is divided by count (TypeError).
    print('Plot time: %s' % ((time.time() - bf) / count))

结果如下:每张图的平均执行时间约为 2.6 秒。R 中绘制图表要快得多(不过我没有实际测量其性能),而且我不想使用 Rpy,所以我认为是我的代码效率低下。(此处原为一张结果图片。)

4

1 个回答

4

此解决方案复用同一个 Figure 实例,并异步保存图像。您也可以改为创建与处理器数量相同的 Figure,异步地同时绘制多张图,这样应该会更快。实际测试中,在我的机器上每张图约需 1 秒,低于原来的 2.6 秒。

import numpy as np
import datetime
import urllib2
import time
import multiprocessing as mp
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from pylab import *
from matplotlib.collections import LineCollection

class AsyncPlotter():
    """Save matplotlib figures in background processes.

    Every call to ``save`` starts a new ``multiprocessing.Process``; a shared
    counter limits how many of those children are inside ``savefig`` at the
    same time to ``processes``.
    """

    # NOTE: the ``processes`` default is evaluated once, when the class is defined.
    def __init__(self, processes=mp.cpu_count()):
        self.manager = mp.Manager()
        # Number of children currently saving a figure (shared via the manager).
        self.nc = self.manager.Value('i', 0)
        # Guards ``nc``: reading and writing a manager Value are separate
        # round trips, so without a lock several children could all see a
        # free slot and claim it at once.
        self.lock = self.manager.Lock()
        self.pids = []  # started Process objects (not numeric pids)
        self.processes = processes

    def async_plotter(self, nc, fig, filename, processes, lock=None):
        """Child-process body: wait for a free slot, then save and close ``fig``.

        ``lock`` protects ``nc``; when omitted, ``self.lock`` is used.
        """
        if lock is None:
            lock = self.lock
        # Check for a free slot and claim it atomically; poll until one opens.
        while True:
            with lock:
                if nc.value < processes:
                    nc.value += 1
                    break
            time.sleep(0.1)
        try:
            print("Plotting " + filename)
            fig.savefig(filename)
            plt.close(fig)
        finally:
            # Always give the slot back, even if savefig raised; otherwise the
            # remaining children would wait forever.
            with lock:
                nc.value -= 1

    def save(self, fig, filename):
        """Start a child process that saves ``fig`` to ``filename``."""
        p = mp.Process(target=self.async_plotter,
                       args=(self.nc, fig, filename, self.processes, self.lock))
        p.start()
        self.pids.append(p)

    def join(self):
        """Block until every process started by ``save`` has finished."""
        for p in self.pids:
            p.join()

class FinanceChart():
    """Reusable OHLC bar chart: one Figure/Axes whose data is swapped per plot.

    The figure, axes, line collections, labels and date locators are built
    once in ``__init__``; ``test_plot`` only replaces segment data, colours
    and limits, then hands the figure to ``async_plotter.save``.
    NOTE(review): reusing the same figure while children save it presumably
    relies on fork-style process start (each child gets a snapshot) -- confirm.
    """

    def __init__(self, async_plotter):
        self.async_plotter = async_plotter
        self.bar_width = 1.3    # line width of every bar segment
        self.date_offset = 0.5  # horizontal length (in date units) of the open/close ticks
        self.fig = plt.figure(figsize=(50, 20), facecolor='w')
        self.ax = self.fig.add_subplot(1, 1, 1)
        self.labels = self.ax.get_xmajorticklabels()
        setp(self.labels, rotation=0)

        # Placeholder segments; real data is installed by test_plot via set_segments.
        placeholder = [[(734881,1), (734882,5), (734883,9), (734889,5)]]
        self.lines_hl = LineCollection(placeholder)  # vertical high-low bars
        self.lines_op = LineCollection(placeholder)  # left ticks at the open price
        self.lines_cl = LineCollection(placeholder)  # right ticks at the close price
        # Keep our own references rather than relying on add_collection's return value.
        for lines in (self.lines_hl, self.lines_op, self.lines_cl):
            self.ax.add_collection(lines, autolim=True)

        self.ax.set_title('EOD test plot')
        self.ax.set_xlabel('Date')
        self.ax.set_ylabel('Price , $')

        month = MonthLocator()
        day   = DayLocator()
        timeFmt = DateFormatter('%Y-%m-%d')
        self.ax.xaxis.set_major_locator(month)
        self.ax.xaxis.set_major_formatter(timeFmt)
        self.ax.xaxis.set_minor_locator(day)

    @staticmethod
    def _segments(x0, y0, x1, y1):
        """Return one ((x0, y0), (x1, y1)) line segment per row."""
        return list(zip(zip(x0, y0), zip(x1, y1)))

    def test_plot(self, OHLCV, i):
        """Redraw the chart from OHLCV and queue it to be saved as '%04i.png' % i.

        OHLCV columns, as built in the __main__ section: date, open, high,
        low, close, volume.  Bars are green when open < close, red otherwise.
        """
        opens = OHLCV[:,1].astype(float)
        highs = OHLCV[:,2].astype(float)
        lows = OHLCV[:,3].astype(float)
        closes = OHLCV[:,4].astype(float)
        # Plain str colours; a '|S5' array would hold bytes on Python 3.
        color = np.where(opens < closes, 'green', 'red').tolist()
        dates = np.asarray(date2num(OHLCV[:,0]), dtype=float)

        self.lines_hl.set_segments(self._segments(dates, highs, dates, lows))
        self.lines_op.set_segments(self._segments(dates - self.date_offset, opens, dates, opens))
        self.lines_cl.set_segments(self._segments(dates + self.date_offset, closes, dates, closes))
        for lines in (self.lines_hl, self.lines_op, self.lines_cl):
            lines.set_color(color)
            lines.set_linewidth(self.bar_width)

        if dates.size:
            # Show whole bars: the open/close ticks extend date_offset past the
            # first/last date, and prices span from the lowest low to the highest
            # high (the open column alone clipped the high-low bars).
            self.ax.set_xlim(dates.min() - self.date_offset, dates.max() + self.date_offset)
            self.ax.set_ylim(lows.min(), highs.max())

        self.ax.xaxis.grid(True, 'major')
        self.ax.grid(True)
        self.async_plotter.save(self.fig, '%04i.png'%i)

if __name__=='__main__':
    print("Starting")
    response = urllib2.urlopen(r"http://ichart.finance.yahoo.com/table.csv?s=IBM&a=00&b=1&c=2012&d=00&e=15&f=2013&g=d&ignore=.csv")
    try:
        # Drop the CSV header row and reverse the row order (the feed
        # presumably lists newest rows first -- confirm against the service).
        data_table = response.readlines()[1:][::-1]
    finally:
        response.close()

    #Format:  Date, Open, High, Low, Close, Volume
    # One converter per kept column; ``[:-1]`` below drops the last CSV column.
    converters = (lambda x: datetime.datetime.strptime(x, '%Y-%m-%d').date(),float, float, float, float, int)

    parsed_table = []
    for row in data_table:
        field = row.strip().split(',')[:-1]
        parsed_table.append([convert(value) for convert, value in zip(converters, field)])

    parsed_table = np.array(parsed_table)
    bf = time.time()
    count = 10

    a = AsyncPlotter()
    _chart = FinanceChart(a)

    print("Done with startup tasks")
    for i in xrange(count):
        _chart.test_plot(parsed_table, i)

    # These two lines used to sit outside the __main__ guard, so merely
    # importing the module (e.g. a multiprocessing child re-importing it)
    # raised NameError on ``a``.
    a.join()
    print('Plot time: %.2f' %(float(time.time() - bf) / float(count)))
回答于 2013-01-16T01:53:07.617