我正在尝试使用 tf.keras.utils.Sequence 对象作为我的 keras 模型的输入,以便我可以使用almentations 库应用在 tensorflow 中不可用的增强。但是这样做时我遇到了错误。(这里提到的图像预处理操作只是为了清楚起见)
import albumentations as A
from tensorflow.keras.utils import Sequence
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dense, Conv2D, Flatten, MaxPool2D, Dropout
from tensorflow.keras.models import Sequential
TRAIN_DIR = os.path.join('..', 'Data', 'PetImages')
def load_data():
list_of_fpaths = glob.glob('../Data/PetImages/Cat/*')
labels = [1] * len(list_of_fpaths)
temp = glob.glob('../Data/PetImages/Dog/*')
list_of_fpaths.extend(temp)
labels.extend([0] * len(temp))
return list_of_fpaths, labels
# Now list of fpaths contain the list of file paths and labels contain
# corresponding labels
class DataSequence(Sequence):
def __init__(self, x_set, y_set, batch_size, augmentations):
self.x, self.y = x_set, y_set
self.batch_size = batch_size
self.augment = augmentations
def __len__(self):
return int(np.ceil(len(self.x) / float(self.batch_size)))
def __getitem__(self, idx):
batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
a = np.array([
self.augment(image=plt.imread(file_name))["image"] for file_name in
batch_x
])
b = np.array(batch_y)
return a,b
def get_model(input_shape):
model = Sequential([
Conv2D(8, 3, activation='relu', input_shape=input_shape),
MaxPool2D(2),
Conv2D(16, 3, activation='relu'),
MaxPool2D(2),
Conv2D(32, 3, activation='relu'),
MaxPool2D(2),
Conv2D(32, 3, activation='relu'),
MaxPool2D(2),
Conv2D(32, 3, activation='relu'),
MaxPool2D(2),
Flatten(),
Dense(1024, activation='relu'),
Dropout(0.3),
Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy']
)
return model
ALBUMENTATIONS_TRAIN = A.Compose([
A.Resize(256, 256),
# A.Resize(512, 512),
A.ToFloat(),
# A.RandomCrop(384, 384, p=0.5),
])
ALBUMENTATIONS_TEST = A.Compose([
A.ToFloat(),
A.Resize(256, 256)
])
X, Y = load_data()
train_gen = DataSequence(X, Y, 16, ALBUMENTATIONS_TRAIN)
model = get_model(input_shape=(256,256,3))
model.fit(train_gen,epochs=100)
我得到的错误是
17/748 [..............................] - ETA: 1:06 - loss: 0.4304 - accuracy: 0.92282020-07-08 13:25:47.751964: W tensorflow/core/framework/op_kernel.cc:1741] Invalid argument: ValueError: could not broadcast input array from shape (256,256,3) into shape (256,256)
Traceback (most recent call last):
File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\ops\script_ops.py", line 243, in __call__
ret = func(*args)
File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\autograph\impl\api.py", line 309, in wrapper
return func(*args, **kwargs)
File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\data\ops\dataset_ops.py", line 785, in generator_py_func
values = next(generator_state.get_iterator(iterator_id))
File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\data_adapter.py", line 801, in wrapped_generator
for data in generator_fn():
File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\data_adapter.py", line 932, in generator_fn
yield x[i]
File "D:/ACAD/TENSORFLOW/Rough/data_aug_pipeline.py", line 40, in __getitem__
a = np.array([
ValueError: could not broadcast input array from shape (256,256,3) into shape (256,256)
Traceback (most recent call last):
File "D:/ACAD/TENSORFLOW/Rough/data_aug_pipeline.py", line 89, in <module>
model.fit(train_gen,epochs=100)
File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\training.py", line 66, in _method_wrapper
return method(self, *args, **kwargs)
File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\training.py", line 848, in fit
tmp_logs = train_function(iterator)
File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\eager\def_function.py", line 580, in __call__
result = self._call(*args, **kwds)
File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\eager\def_function.py", line 611, in _call
return self._stateless_fn(*args, **kwds) # pylint: disable=not-callable
File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\eager\function.py", line 2420, in __call__
return graph_function._filtered_call(args, kwargs) # pylint: disable=protected-access
File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\eager\function.py", line 1661, in _filtered_call
return self._call_flat(
File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\eager\function.py", line 1745, in _call_flat
return self._build_call_outputs(self._inference_function.call(
File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\eager\function.py", line 593, in call
outputs = execute.execute(
File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\eager\execute.py", line 59, in quick_execute
tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
(0) Invalid argument: ValueError: could not broadcast input array from shape (256,256,3) into shape (256,256)
Traceback (most recent call last):
File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\ops\script_ops.py", line 243, in __call__
ret = func(*args)
File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\autograph\impl\api.py", line 309, in wrapper
return func(*args, **kwargs)
File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\data\ops\dataset_ops.py", line 785, in generator_py_func
values = next(generator_state.get_iterator(iterator_id))
File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\data_adapter.py", line 801, in wrapped_generator
for data in generator_fn():
File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\data_adapter.py", line 932, in generator_fn
yield x[i]
File "D:/ACAD/TENSORFLOW/Rough/data_aug_pipeline.py", line 40, in __getitem__
a = np.array([
ValueError: could not broadcast input array from shape (256,256,3) into shape (256,256)
[[{{node PyFunc}}]]
[[IteratorGetNext]]
[[IteratorGetNext/_4]]
(1) Invalid argument: ValueError: could not broadcast input array from shape (256,256,3) into shape (256,256)
Traceback (most recent call last):
File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\ops\script_ops.py", line 243, in __call__
ret = func(*args)
File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\autograph\impl\api.py", line 309, in wrapper
return func(*args, **kwargs)
File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\data\ops\dataset_ops.py", line 785, in generator_py_func
values = next(generator_state.get_iterator(iterator_id))
File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\data_adapter.py", line 801, in wrapped_generator
for data in generator_fn():
File "C:\Users\aksha\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\data_adapter.py", line 932, in generator_fn
yield x[i]
File "D:/ACAD/TENSORFLOW/Rough/data_aug_pipeline.py", line 40, in __getitem__
a = np.array([
ValueError: could not broadcast input array from shape (256,256,3) into shape (256,256)
[[{{node PyFunc}}]]
[[IteratorGetNext]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_1195]
Function call stack:
train_function -> train_function
Process finished with exit code 1
请帮助我理解我犯了什么错误。