0

我运行以下代码来生成一个图表,其中根据https://gpflow.readthedocs.io/en/stable/notebooks/basics/regression.html绘制平均函数、95% 置信区间和来自后验的 10 个样本:

# GPflow
Y = np.asarray(p.metrics, dtype=float).reshape(-1, 1)
X = np.asarray(p.x, dtype=float).reshape(-1, 1)

 

# scaling X and Y to magnify effects of variance
X *= .01
Y *= .01

 

plt.plot(X, Y, 'kx', mew=2)

 

k = gpflow.kernels.Matern52(lengthscales=0.05)
# k = gpflow.kernels.Polynomial(degree=4, variance=1)

 

m = gpflow.models.GPR((X, Y), k)
m.likelihood.variance.assign(0.01)

 

# code below is from the docs:
# https://gpflow.readthedocs.io/en/docupdate/notebooks/regression.html
def plot(m):
    xx = np.linspace(min(X), max(X), 150)#[:,None]
    mean, var = m.predict_y(xx)
    plt.figure(figsize=(12, 6))
    plt.plot(X, Y, 'kx', mew=2)
    plt.plot(xx, mean, 'b', lw=2)
    plt.fill_between(xx[:,0], mean[:,0] - 2*np.sqrt(var[:,0]), mean[:,0] + 2*np.sqrt(var[:,0]), color='blue', alpha=0.2)

 

plot(m)
plt.show()

 


## generate test points for prediction
xx = np.linspace(min(X), max(X), 100).reshape(-1, 1)  # test points must be of shape (N, D)

 

## predict mean and variance of latent GP at test points
mean, var = m.predict_f(xx)

 

## generate 10 samples from posterior
tf.random.set_seed(1)  # for reproducibility
samples = m.predict_f_samples(xx, 10)  # shape (10, 100, 1)

 


## plot
plt.figure(figsize=(12, 6))
plt.plot(X, Y, "kx", mew=2)
plt.plot(xx, mean, "C0", lw=2)
plt.fill_between(
    xx[:, 0],
    mean[:, 0] - 1.96 * np.sqrt(var[:, 0]),
    mean[:, 0] + 1.96 * np.sqrt(var[:, 0]),
    color="C0",
    alpha=0.2,
)

 

plt.plot(xx, samples[:, :, 0].numpy().T, "C0", linewidth=0.5)
plt.show()

运行它时,我得到以下错误:

AttributeError                            Traceback (most recent call last)
<ipython-input-28-e1bf5ef03fcf> in <module>
     22 plt.figure(figsize=(12, 6))
     23 plt.plot(X, Y, "kx", mew=2)
---> 24 plt.plot(xx, mean, "C0", lw=2)
     25 plt.fill_between(
     26     xx[:, 0],

~/miniconda3/lib/python3.7/site-packages/matplotlib/pyplot.py in plot(scalex, scaley, data, *args, **kwargs)
   2787     return gca().plot(
   2788         *args, scalex=scalex, scaley=scaley, **({"data": data} if data
-> 2789         is not None else {}), **kwargs)
   2790 
   2791 

~/miniconda3/lib/python3.7/site-packages/matplotlib/axes/_axes.py in plot(self, scalex, scaley, data, *args, **kwargs)
   1663         """
   1664         kwargs = cbook.normalize_kwargs(kwargs, mlines.Line2D._alias_map)
-> 1665         lines = [*self._get_lines(*args, data=data, **kwargs)]
   1666         for line in lines:
   1667             self.add_line(line)

~/miniconda3/lib/python3.7/site-packages/matplotlib/axes/_base.py in __call__(self, *args, **kwargs)
    223                 this += args[0],
    224                 args = args[1:]
--> 225             yield from self._plot_args(this, kwargs)
    226 
    227     def get_next_color(self):

~/miniconda3/lib/python3.7/site-packages/matplotlib/axes/_base.py in _plot_args(self, tup, kwargs)
    385         if len(tup) == 2:
    386             x = _check_1d(tup[0])
--> 387             y = _check_1d(tup[-1])
    388         else:
    389             x, y = index_of(tup[-1])

~/miniconda3/lib/python3.7/site-packages/matplotlib/cbook/__init__.py in _check_1d(x)
   1400     else:
   1401         try:
-> 1402             ndim = x[:, None].ndim
   1403             # work around https://github.com/pandas-dev/pandas/issues/27775
   1404             # which mean the shape is not as expected. That this ever worked

AttributeError: 'Tensor' object has no attribute 'ndim'

X 和 Y 是两个给定的数组,例如:

Y = array([[14.13],
       [14.01],
       [13.59],
       [12.63],
       [11.44],
       [10.34],
       [ 9.3 ],
       [ 8.38],
       [ 7.49],
       [ 6.9 ],
       [ 6.72],
       [ 6.87],
       [ 7.07],
       [ 7.24],
       [ 8.36],
       [ 9.78],
       [10.64],
       [12.21],
       [12.88],
       [13.37],
       [13.57],
       [13.44],
       [13.2 ]])

X = array([[0.01],
       [0.02],
       [0.03],
       [0.04],
       [0.05],
       [0.06],
       [0.07],
       [0.08],
       [0.09],
       [0.1 ],
       [0.11],
       [0.12],
       [0.13],
       [0.14],
       [0.15],
       [0.16],
       [0.17],
       [0.18],
       [0.19],
       [0.2 ],
       [0.21],
       [0.22],
       [0.23]])
4

0 回答 0