0

我想通过不同的梯度下降优化方法来动画化寻找函数最小点的过程。为此,我使用 matplotlib 和赛璐珞包。问题是无法修复动画中的情节图例,并且在每个循环中,都会在前一个图例下方添加一个新图例,如下图所示。有没有办法修复传说并避免这个问题?

from celluloid import Camera
fig,ax = plt.subplots(1, 1,figsize=(10, 10))
camera = Camera(fig)
for i in range(path1.shape[1])
  ax.contour(x_mesh, y_mesh, z, levels=np.logspace(0, 5, 35), norm=LogNorm(), cmap=plt.cm.jet)
  ax.plot(*minima_, 'r*', markersize=18)

  line, = ax.plot([], [], 'k', label='Simple SGD', lw=2)
  point, = ax.plot([], [], 'ko')
  line.set_data(path1[::,:i])
  point.set_data(path1[::,i-1:i])

  line, = ax.plot([], [], 'r', label='SGD with momentum', lw=2)
  point, = ax.plot([], [], 'ro')
  line.set_data(*path2[::,:i])
  point.set_data(*path2[::,i-1:i])

  line, = ax.plot([], [], 'g', label='SGD with Nesterov', lw=2)
  point, = ax.plot([], [], 'go')
  line.set_data(*path3[::,:i])
  point.set_data(*path3[::,i-1:i])

  line, = ax.plot([], [], 'b', label='SGD with Adagrad', lw=2)
  point, = ax.plot([], [], 'bo')
  line.set_data(*path4[::,:i])
  point.set_data(*path4[::,i-1:i])

  line, = ax.plot([], [], 'c', label='SGD with Adadelta', lw=2)
  point, = ax.plot([], [], 'co')
  line.set_data(*path5[::,:i])
  point.set_data(*path5[::,i-1:i]) 

  line, = ax.plot([], [], 'm', label='SGD with RMSprob', lw=2)
  point, = ax.plot([], [], 'mo')
  line.set_data(*path6[::,:i])
  point.set_data(*path6[::,i-1:i])

  line, = ax.plot([], [], 'y', label='SGD with Adam', lw=2)
  point, = ax.plot([], [], 'yo')
  line.set_data(*path7[::,:i])
  point.set_data(*path7[::,i-1:i])

  line, = ax.plot([], [], 'y', label='SGD with Adamax', lw=2)
  point, = ax.plot([], [], 'y*')
  line.set_data(*path8[::,:i])
  point.set_data(*path8[::,i-1:i])

  line, = ax.plot([], [], 'k', label='SGD with Nadam', lw=2)
  point, = ax.plot([], [], 'kp')
  line.set_data(*path9[::,:i])
  point.set_data(*path9[::,i-1:i])

  line, = ax.plot([], [], 'r', label='SGD with AMSGrad', lw=2)
  point, = ax.plot([], [], 'rD')
  line.set_data(*path10[::,:i])
  point.set_data(*path10[::,i-1:i])

  ax.legend(loc='upper left') 
  camera.snap()
animation = camera.animate()
animation.save('2D_animation_overlap.gif', writer='imagemagick')

在此处输入图像描述

4

1 回答 1

0

这里的最佳做法是创建自定义图例而不是自动生成图例,在这种情况下,可以通过

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

labels = ['Single SGD', 'SGD with momentum', 'SGD with Nesterov', 
          'SGD with Adagrad', 'SGD with Adadelta', 'SGD with RMSprob', 'SGD with Adam', 
          'SGD with Adamax', 'SGD with Nadam', 'SGD with AMSgrad']
colors = ['k', 'r', 'g', 'b', 'c', 'm', 'y', 'y', 'k', 'r']
handles = []
for c, l in zip(colors, labels):
    handles.append(Line2D([0], [0], color = c, label = l))

plt.legend(handles = handles, loc = 'upper left')

这会给你一个这样的传说:

在此处输入图像描述

您不需要在循环中包含任何这些,您可以在之前或之后执行它,它仍然可以工作。它也可以在循环中工作,但每次都不需要重新绘制图例。

使用 if 语句简单地保护图例创建而不是手动创建图例也足够了。IE

    # ...
    if i == 0:
        ax.legend(loc = 'upper left')

但我建议不要鼓励自动生成图例,而是直接创建图例。

于 2020-01-24T08:16:37.240 回答