0

我尝试运行以下代码，但运行失败了。该如何解决？顺便说一句，代码是在 Google Colab 上运行的，我使用的是 Python 3.9。

import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import NMF
from sklearn.datasets import fetch_olivetti_faces
from sklearn.model_selection import train_test_split
# Load the Olivetti faces as a flat (n_samples, n_pixels) data array; the
# identity labels (second element of the returned pair) are not needed here.
faces = fetch_olivetti_faces(shuffle=True, return_X_y=True)[0]

# Hold out 10% of the faces: NMF is fitted on the remaining 90% and later used
# to reconstruct damaged copies of the held-out faces.
train_faces, test_faces = train_test_split(faces, test_size=0.1)


def show_faces(faces, num_rows=2, num_cols=3, shape=(64, 64)):
    """Plot the first ``num_rows * num_cols`` faces as a grid of grayscale images.

    Args:
        faces: Sequence or 2-D array of flattened images, one image per row.
            Each row must contain ``shape[0] * shape[1]`` values.
        num_rows: Number of rows in the subplot grid (default 2).
        num_cols: Number of columns in the subplot grid (default 3).
        shape: ``(height, width)`` that each flattened row is reshaped to
            before plotting (default ``(64, 64)``).

    If fewer faces than grid cells are supplied, only the available faces are
    drawn; the remaining cells are left empty.
    """
    plt.figure()
    # Cap at len(faces) so a short input fills part of the grid instead of
    # raising IndexError.
    num_shown = min(num_rows * num_cols, len(faces))
    for i in range(num_shown):
        # matplotlib subplot indices are 1-based.
        plt.subplot(num_rows, num_cols, i + 1)
        plt.imshow(np.reshape(faces[i], shape), cmap=plt.cm.gray)

        
# Create a damaged copy of every held-out face by zeroing 1024 randomly chosen
# pixel positions. np.random.choice samples WITH replacement by default, so
# positions may repeat and slightly fewer than 1024 distinct pixels can end up
# zeroed.
damaged_faces = []
for face in test_faces:
    idx = np.random.choice(face.size, size=1024)
    damaged_face = face.copy()
    damaged_face[idx] = 0.
    damaged_faces.append(damaged_face)

# Stack the damaged faces into one array with the SAME dtype as the training
# data. Passing the plain Python list to nmf.transform made scikit-learn
# convert it to float64, while nmf.components_ (learned from the float32
# train_faces) stays float32, which raised
# "TypeError: H should have the same dtype as X".
damaged_faces = np.asarray(damaged_faces, dtype=train_faces.dtype)

# Learn a 10-component NMF basis from the undamaged training faces.
nmf = NMF(n_components=10)
nmf.fit(train_faces)

# Encode the damaged faces in the learned basis (matrix1 = W) and reconstruct
# them as W @ H, where H = nmf.components_.
matrix1 = nmf.transform(damaged_faces)
matrix2 = nmf.components_
show_faces(matrix1 @ matrix2)

错误信息如下：

TypeError                                 Traceback (most recent call last)
/var/folders/c_/xfclsffs46x_pr2y70z2qfsm0000gn/T/ipykernel_7259/2889088273.py in <module>
----> 1 matrix1 = nmf.transform(damaged_faces)
      2 matrix2 = nmf.components_
      3 show_faces(matrix1 @ matrix2)

/opt/homebrew/Caskroom/miniforge/base/envs/ps/lib/python3.9/site-packages/sklearn/decomposition/_nmf.py in transform(self, X)
   1688 
   1689         with config_context(assume_finite=True):
-> 1690             W, *_ = self._fit_transform(X, H=self.components_, update_H=False)
   1691 
   1692         return W

/opt/homebrew/Caskroom/miniforge/base/envs/ps/lib/python3.9/site-packages/sklearn/decomposition/_nmf.py in _fit_transform(self, X, y, W, H, update_H)
   1595 
   1596         # initialize or check W and H
-> 1597         W, H = self._check_w_h(X, W, H, update_H)
   1598 
   1599         # scale the regularization terms

/opt/homebrew/Caskroom/miniforge/base/envs/ps/lib/python3.9/site-packages/sklearn/decomposition/_nmf.py in _check_w_h(self, X, W, H, update_H)
   1470             _check_init(H, (self._n_components, n_features), "NMF (input H)")
   1471             if H.dtype != X.dtype:
-> 1472                 raise TypeError(
   1473                     "H should have the same dtype as X. Got H.dtype = {}.".format(
   1474                         H.dtype

TypeError: H should have the same dtype as X. Got H.dtype = float32.
（即：H 应该与 X 具有相同的 dtype，但得到的 H.dtype = float32。）

4

0 回答 0