0

我正在尝试使用 tf.keras.utils.Sequence 对象作为我的 keras 模型的输入,以便我可以使用almentations 库应用在 tensorflow 中不可用的增强。但是这样做时我遇到了错误。(这里提到的图像预处理操作只是为了清楚起见)

import albumentations as A
from tensorflow.keras.utils import Sequence
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dense, Conv2D, Flatten, MaxPool2D, Dropout
from tensorflow.keras.models import Sequential



TRAIN_DIR = os.path.join('..', 'Data', 'PetImages')


def load_data():
    list_of_fpaths = glob.glob('../Data/PetImages/Cat/*')
    labels = [1] * len(list_of_fpaths)
    temp = glob.glob('../Data/PetImages/Dog/*')
    list_of_fpaths.extend(temp)
    labels.extend([0] * len(temp))
    return list_of_fpaths, labels


# Now list of fpaths contain the list of file paths and labels contain
# corresponding labels

class DataSequence(Sequence):
    def __init__(self, x_set, y_set, batch_size, augmentations):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size
        self.augment = augmentations

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

        a =  np.array([
            self.augment(image=plt.imread(file_name))["image"] for file_name in
            batch_x
        ])
        b = np.array(batch_y)
        return a,b


def get_model(input_shape):
    model = Sequential([
        Conv2D(8, 3, activation='relu', input_shape=input_shape),
        MaxPool2D(2),
        Conv2D(16, 3, activation='relu'),
        MaxPool2D(2),
        Conv2D(32, 3, activation='relu'),
        MaxPool2D(2),
        Conv2D(32, 3, activation='relu'),
        MaxPool2D(2),
        Conv2D(32, 3, activation='relu'),
        MaxPool2D(2),
        Flatten(),
        Dense(1024, activation='relu'),
        Dropout(0.3),
        Dense(1, activation='sigmoid')

    ])
    model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy']
                  )
    return model


ALBUMENTATIONS_TRAIN = A.Compose([
    A.Resize(256, 256),
   # A.Resize(512, 512),
    A.ToFloat(),
  #  A.RandomCrop(384, 384, p=0.5),

])

ALBUMENTATIONS_TEST = A.Compose([
    A.ToFloat(),
    A.Resize(256, 256)
])

X, Y = load_data()
train_gen = DataSequence(X, Y, 16, ALBUMENTATIONS_TRAIN)
model = get_model(input_shape=(256,256,3))
model.fit(train_gen,epochs=100)

我得到的错误是

 17/748 [..............................] - ETA: 1:06 - loss: 0.4304 - accuracy: 0.92282020-07-08 13:25:47.751964: W tensorflow/core/framework/op_kernel.cc:1741] Invalid argument: ValueError: could not broadcast input array from shape (256,256,3) into shape (256,256)
Traceback (most recent call last):

  File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\ops\script_ops.py", line 243, in __call__
    ret = func(*args)

  File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\autograph\impl\api.py", line 309, in wrapper
    return func(*args, **kwargs)

  File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\data\ops\dataset_ops.py", line 785, in generator_py_func
    values = next(generator_state.get_iterator(iterator_id))

  File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\data_adapter.py", line 801, in wrapped_generator
    for data in generator_fn():

  File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\data_adapter.py", line 932, in generator_fn
    yield x[i]

  File "D:/ACAD/TENSORFLOW/Rough/data_aug_pipeline.py", line 40, in __getitem__
    a =  np.array([

ValueError: could not broadcast input array from shape (256,256,3) into shape (256,256)


Traceback (most recent call last):
  File "D:/ACAD/TENSORFLOW/Rough/data_aug_pipeline.py", line 89, in <module>
    model.fit(train_gen,epochs=100)
  File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\training.py", line 66, in _method_wrapper
    return method(self, *args, **kwargs)
  File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\training.py", line 848, in fit
    tmp_logs = train_function(iterator)
  File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\eager\def_function.py", line 580, in __call__
    result = self._call(*args, **kwds)
  File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\eager\def_function.py", line 611, in _call
    return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
  File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\eager\function.py", line 2420, in __call__
    return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
  File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\eager\function.py", line 1661, in _filtered_call
    return self._call_flat(
  File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\eager\function.py", line 1745, in _call_flat
    return self._build_call_outputs(self._inference_function.call(
  File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\eager\function.py", line 593, in call
    outputs = execute.execute(
  File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\eager\execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument:  ValueError: could not broadcast input array from shape (256,256,3) into shape (256,256)
Traceback (most recent call last):

  File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\ops\script_ops.py", line 243, in __call__
    ret = func(*args)

  File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\autograph\impl\api.py", line 309, in wrapper
    return func(*args, **kwargs)

  File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\data\ops\dataset_ops.py", line 785, in generator_py_func
    values = next(generator_state.get_iterator(iterator_id))

  File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\data_adapter.py", line 801, in wrapped_generator
    for data in generator_fn():

  File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\data_adapter.py", line 932, in generator_fn
    yield x[i]

  File "D:/ACAD/TENSORFLOW/Rough/data_aug_pipeline.py", line 40, in __getitem__
    a =  np.array([

ValueError: could not broadcast input array from shape (256,256,3) into shape (256,256)


     [[{{node PyFunc}}]]
     [[IteratorGetNext]]
     [[IteratorGetNext/_4]]
  (1) Invalid argument:  ValueError: could not broadcast input array from shape (256,256,3) into shape (256,256)
Traceback (most recent call last):

  File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\ops\script_ops.py", line 243, in __call__
    ret = func(*args)

  File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\autograph\impl\api.py", line 309, in wrapper
    return func(*args, **kwargs)

  File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\data\ops\dataset_ops.py", line 785, in generator_py_func
    values = next(generator_state.get_iterator(iterator_id))

  File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\data_adapter.py", line 801, in wrapped_generator
    for data in generator_fn():

  File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\data_adapter.py", line 932, in generator_fn
    yield x[i]

  File "D:/ACAD/TENSORFLOW/Rough/data_aug_pipeline.py", line 40, in __getitem__
    a =  np.array([

ValueError: could not broadcast input array from shape (256,256,3) into shape (256,256)


     [[{{node PyFunc}}]]
     [[IteratorGetNext]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_1195]

Function call stack:
train_function -> train_function


Process finished with exit code 1

请帮助我理解我犯了什么错误。

4

1 回答 1

1

根据错误消息,您的数据集中至少有一张灰度图像已调整为 256x256,因此无法适应您的网络。

于 2020-07-19T07:42:44.837 回答