我有一个有很多绘图功能的类。我的意图是使用 matplotlib 的 subplot 方法将所有图分组到一张图像中,具体取决于我调用了多少个函数。
我尝试了类似以下的方法(这是我的程序的简短版本),但我不知道为什么不起作用。
任何帮助表示赞赏。提前致谢。
import itertools
import numpy as np
from matplotlib import pyplot as plt
class Base(object):
def __init__(self, a, multiPlot=True, numColGraph=None, numRowGraph=None,
figSize=None, DPI=None, num=None):
self.a = a
self.x = np.linspace(0, 5)
if multiPlot:
self.nCG = numColGraph
self.nRG = numRowGraph
else:
self.nCG = 1
self.nRG = 1
if figSize and DPI:
self.thePlot = plt.figure(figsize=figSize, dpi=DPI)
if num == 0:
self.plotId = itertools.count(1)
def createPlot1(self):
y = self.x**(a/2)
self.thePlot.add_subplot(self.nRG, self.nCG, next(self.plotId))
plt.plot(self.x, y, label=str(self.a)+'/2')
def createPlot2(self):
y = self.x**a
self.thePlot.add_subplot(self.nRG, self.nCG, next(self.plotId))
plt.plot(self.x, y, label=self.a)
def createPlot3(self):
y = self.x**(2*a)
self.thePlot.add_subplot(self.nRG, self.nCG, next(self.plotId))
plt.plot(self.x, y, label=str(self.a)+'*2')
if __name__ == "__main__":
A = np.linspace(0, 2, 5)
for i, a in enumerate(A):
Instance = Base(a, numColGraph=3, numRowGraph=len(A),
figSize=(12,10), DPI=100, num=i)
Instance.createPlot1()
Instance.createPlot2()
Instance.createPlot3()
plt.show()