3

我正在构建一个使用 Keras 在不平衡数据集上训练的 CNN 模型。我正在使用 imblearn 提供的 imblearn.keras.balanced_batch_generator 进行数据重新采样。

我的 x_train 数组的形状为 (n_samples, 32, 32, 1),而 balance_batch_generator 的 fit_generator 采用形状为 (n_samples, n_features) 的 x_train 的输入。

如何将图像(32、32、1)的尺寸合并为一个 dem n_features?

train = np.array(train,dtype="float32") #as mnist
train_labels = np.array(train_labels,dtype="float32") #as mnist
train = np.reshape(train,(-1,64,64,1))

输出:

(9098, 64, 64, 1)
(9098, 1)

为 CNN 上传数据:

x_train = np.load(open(r'C:\...\train.npy', 'rb'))
y_train = np.load(open(r'C:\...\train_labels.npy', 'rb'))
y_train = keras.utils.to_categorical(y_train, num_classes = 5)

from imblearn.keras import BalancedBatchGenerator
from imblearn.keras import balanced_batch_generator

from imblearn.under_sampling import NearMiss


training_set_generator = balanced_batch_generator( # Create Training set
        x_train, y_train,
        sampler=NearMiss())



validation_set_generator = balanced_batch_generator( # Create Testing set
        x_valid, y_valid,
        sampler=NearMiss())


#STEP_SIZE_TRAIN=training_set_generator.n//training_set_generator.batch_size
#STEP_SIZE_VALID=validation_set_generator.n//validation_set_generator.batch_size

history = classifier.fit_generator(generator=training_set_generator, # Fit it to the training set and tested it on the testing set
        validation_data=validation_set_generator,
        epochs=10)

错误按摩:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-63-32bb382e1b13> in <module>
     11 training_set_generator = balanced_batch_generator( # Create Training set
     12         x_train, y_train,
---> 13         sampler=NearMiss())
     14 
     15 

~\Anaconda3\lib\site-packages\imblearn\keras\_generator.py in balanced_batch_generator(X, y, sample_weight, sampler, batch_size, keep_sparse, random_state)
    233     return tf_bbg(X=X, y=y, sample_weight=sample_weight,
    234                   sampler=sampler, batch_size=batch_size,
--> 235                   keep_sparse=keep_sparse, random_state=random_state)

~\Anaconda3\lib\site-packages\imblearn\tensorflow\_generator.py in balanced_batch_generator(X, y, sample_weight, sampler, batch_size, keep_sparse, random_state)
    129         if sampler_.__class__.__name__ not in DONT_HAVE_RANDOM_STATE:
    130             set_random_state(sampler_, random_state)
--> 131     sampler_.fit_resample(X, y)
    132     if not hasattr(sampler_, 'sample_indices_'):
    133         raise ValueError("'sampler' needs to have an attribute "

~\Anaconda3\lib\site-packages\imblearn\base.py in fit_resample(self, X, y)
     78         self._deprecate_ratio()
     79 
---> 80         X, y, binarize_y = self._check_X_y(X, y)
     81 
     82         self.sampling_strategy_ = check_sampling_strategy(

~\Anaconda3\lib\site-packages\imblearn\base.py in _check_X_y(X, y)
    136     def _check_X_y(X, y):
    137         y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
--> 138         X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
    139         return X, y, binarize_y
    140 

~\Anaconda3\lib\site-packages\sklearn\utils\validation.py in check_X_y(X, y, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, multi_output, ensure_min_samples, ensure_min_features, y_numeric, warn_on_dtype, estimator)
    754                     ensure_min_features=ensure_min_features,
    755                     warn_on_dtype=warn_on_dtype,
--> 756                     estimator=estimator)
    757     if multi_output:
    758         y = check_array(y, 'csr', force_all_finite=True, ensure_2d=False,

~\Anaconda3\lib\site-packages\sklearn\utils\validation.py in check_array(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, warn_on_dtype, estimator)
    568         if not allow_nd and array.ndim >= 3:
    569             raise ValueError("Found array with dim %d. %s expected <= 2."
--> 570                              % (array.ndim, estimator_name))
    571         if force_all_finite:
    572             _assert_all_finite(array,

ValueError: Found array with dim 4. Estimator expected <= 2.

4

0 回答 0