
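After training the VAE model, I have a trained model. The images are supposed to be saved into each epoch's folder, but no images are saved. Where is the problem, and how can I fix it? Running the cell below fails with this error: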

    ValueError                                Traceback (most recent call last)
    <ipython-input-35-32d496ed4864> in <module>()
          1 with chainer.no_backprop_mode():
          2     x1 = model(x)
    ----> 3 save_images(x.data, 'train')
          4 save_images(x1.data, 'train_reconstructed')

    <ipython-input-31-3c651ae7f2a2> in save_images(x, filename)
          6     fig, ax = plt.subplots(3, 3, figsize=(9, 9), dpi=100)
          7     for ai, xi in zip(ax.flatten(), x):
    ----> 8         ai.imshow(xi.reshape(512,512,4))
          9     fig.savefig(filename)

    ~/anaconda3/envs/chainer_p36/lib/python3.6/site-packages/matplotlib/__init__.py in inner(ax, data, *args, **kwargs)
       1808                         "the Matplotlib list!)" % (label_namer, func.__name__),
       1809                         RuntimeWarning, stacklevel=2)
    -> 1810         return func(ax, *args, **kwargs)
       1811 
       1812         inner.__doc__ = _add_data_doc(inner.__doc__,

    ~/anaconda3/envs/chainer_p36/lib/python3.6/site-packages/matplotlib/axes/_axes.py in imshow(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, shape, filternorm, filterrad, imlim, resample, url, **kwargs)
       5492                               resample=resample, **kwargs)
       5493 
    -> 5494         im.set_data(X)
       5495         im.set_alpha(alpha)
       5496         if im.get_clip_path() is None:

    ~/anaconda3/envs/chainer_p36/lib/python3.6/site-packages/matplotlib/image.py in set_data(self, A)
        628         A : array-like
        629         """
    --> 630         self._A = cbook.safe_masked_invalid(A, copy=True)
        631 
        632         if (self._A.dtype != np.uint8 and

    ~/anaconda3/envs/chainer_p36/lib/python3.6/site-packages/matplotlib/cbook/__init__.py in safe_masked_invalid(x, copy)
        782 
        783 def safe_masked_invalid(x, copy=False):
    --> 784     x = np.array(x, subok=True, copy=copy)
        785     if not x.dtype.isnative:
        786         # Note that the argument to `byteswap` is 'inplace',

    ValueError: object __array__ method not producing an array

The function that saves the images:

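    import matplotlib.pyplot as plt

    def save_images(x, filename):
        # 3x3 grid: one subplot per sample in x
        fig, ax = plt.subplots(3, 3, figsize=(9, 9), dpi=100)
        for ai, xi in zip(ax.flatten(), x):
            ai.imshow(xi.reshape(512, 512, 4))
        fig.savefig(filename)
        plt.close(fig)  # release the figure once it is written to disk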
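One plausible cause, assuming the model and `x` live on the GPU: `x.data` is then a CuPy array, and the `np.array(...)` call inside Matplotlib cannot convert it implicitly, which matches the `ValueError` above. A second, separate assumption is that each sample is stored channel-first as (4, 512, 512), the layout Chainer convolution layers use; in that case `reshape(512, 512, 4)` scrambles the pixels and a transpose is needed instead. A minimal sketch of a hypothetical helper `to_image` that handles both cases, using only NumPy and duck typing:

```python
import numpy as np


def to_image(xi, shape=(4, 512, 512)):
    """Hypothetical helper: turn one sample into an (H, W, C) NumPy array.

    Assumptions (not stated in the question): the sample is channel-first
    (C, H, W) and its float values are meant to lie in [0, 1].
    """
    # Chainer Variable -> its underlying array (Variable.array).
    if hasattr(xi, "array"):
        xi = xi.array
    # CuPy ndarray (GPU) -> NumPy ndarray (CPU); NumPy arrays have no .get().
    if hasattr(xi, "get"):
        xi = xi.get()
    img = np.asarray(xi).reshape(shape)  # (C, H, W)
    img = img.transpose(1, 2, 0)         # (H, W, C), the order imshow expects
    return np.clip(img, 0.0, 1.0)        # float RGB(A) input must be in [0, 1]
```

With this, the loop body in `save_images` would become `ai.imshow(to_image(xi))`. Alternatively, the whole batch can be moved to the CPU once with `chainer.backends.cuda.to_cpu(x.data)` (`chainer.cuda.to_cpu` in older Chainer versions) before calling `save_images`.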

After encoding and decoding, I expect the output to be the same image.
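A VAE generally returns an approximation of its input rather than an identical copy, so a numeric comparison can say more than eyeballing the two image grids. A minimal sketch, assuming `x.data` and `x1.data` have already been copied to the CPU as NumPy arrays of the same shape, e.g. (N, C, H, W); `reconstruction_mse` is a hypothetical helper name:

```python
import numpy as np


def reconstruction_mse(x, x_rec):
    """Per-sample mean squared error between inputs and reconstructions.

    Assumes both arguments are CPU (NumPy) arrays of identical shape whose
    first axis indexes samples.
    """
    x = np.asarray(x, dtype=np.float64)
    x_rec = np.asarray(x_rec, dtype=np.float64)
    if x.shape != x_rec.shape:
        raise ValueError(f"shape mismatch: {x.shape} vs {x_rec.shape}")
    # Flatten everything except the sample axis, then average per sample.
    diff = (x - x_rec).reshape(len(x), -1)
    return (diff ** 2).mean(axis=1)
```

Values near zero mean the reconstruction is close to the input; how close counts as "the same image" depends on the pixel scale of the data.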
