训练神经网络时出现 `StopIteration`（停止迭代）错误。
以下是拟合（训练）模型的代码：
model.fit(
    train_generator,
    # Ceiling division: data_generator yields a smaller final batch on each
    # pass over the data, so round the step counts up. With floor division
    # (n // batch_size) a split with fewer than batch_size samples would get
    # 0 steps, and epochs would drift out of step with full data passes.
    steps_per_epoch=(num_train_samples + batch_size - 1) // batch_size,
    epochs=10,
    validation_data=validation_generator,
    validation_steps=(num_val_samples + batch_size - 1) // batch_size,
)
报错信息如下：
---------------------------------------------------------------------------
StopIteration Traceback (most recent call last)
<ipython-input-33-d4541a7a4ae1> in <module>()
4 epochs = 10,
5 validation_data = validation_generator,
----> 6 validation_steps = num_val_samples // batch_size)
3 frames
/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
1145 use_multiprocessing=use_multiprocessing,
1146 shuffle=shuffle,
-> 1147 initial_epoch=initial_epoch)
1148
1149 # Case 2: Symbolic tensors or Numpy array-like.
/usr/local/lib/python3.6/dist-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
89 warnings.warn('Update your `' + object_name + '` call to the ' +
90 'Keras 2 API: ' + signature, stacklevel=2)
---> 91 return func(*args, **kwargs)
92 wrapper._original_function = func
93 return wrapper
/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
1730 use_multiprocessing=use_multiprocessing,
1731 shuffle=shuffle,
-> 1732 initial_epoch=initial_epoch)
1733
1734 @interfaces.legacy_generator_methods_support
/usr/local/lib/python3.6/dist-packages/keras/engine/training_generator.py in fit_generator(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
183 batch_index = 0
184 while steps_done < steps_per_epoch:
--> 185 generator_output = next(output_generator)
186
187 if not hasattr(generator_output, '__len__'):
StopIteration:
同样的代码我在另一个模型上运行过，可以正常工作。补充一点：`num_val_samples`
是一个整数，定义如下：
# Number of validation records; combined with batch_size to compute the
# validation_steps passed to model.fit.
num_val_samples = len(val_samples)
编辑：以下是 `train_generator` 和 `validation_generator` 的定义：
batch_size = 32
# Use the shared batch_size variable (not a second literal 32) so the
# generators always agree with the step counts that model.fit derives from
# batch_size; changing one without the other would silently desynchronise them.
train_generator = data_generator(train_samples, batch_size=batch_size)
validation_generator = data_generator(val_samples, batch_size=batch_size)
其中，样本列表的加载方式如下：
# load_samples is defined outside this snippet. data_generator reads item 0 of
# each record as an image file name and item 1 as its label, so presumably
# load_samples returns a sequence of (image_name, label) records — TODO confirm.
train_samples = load_samples(train_data_path)
val_samples = load_samples(val_data_path)
以及 `data_generator` 的定义：
def data_generator(samples, batch_size, shuffle_data=True, resize=224):
    """Yield ``(images, labels)`` batches from *samples* indefinitely.

    The generator never finishes on its own. After each full pass over the
    samples it optionally reshuffles them and starts a new pass, which is the
    form ``model.fit`` expects when ``steps_per_epoch`` is given.

    Args:
        samples: Sequence of records. Item 0 is an image file name, joined
            onto ``root_dir`` (a name defined outside this snippet), and
            item 1 is its label. The sequence is copied, so the caller's
            list is never reordered.
        batch_size: Maximum number of samples per batch. The last batch of
            each pass may be smaller. Must be >= 1.
        shuffle_data: If true, shuffle the samples before every pass.
        resize: Value passed as both ``new_height`` and ``new_width`` to
            ``preprocessing`` (defined outside this snippet).

    Yields:
        ``(images, labels)``, each built with ``np.array`` from the batch.

    Raises:
        ValueError: If *samples* is empty or *batch_size* < 1. Either would
            otherwise make the outer loop spin forever without yielding.
        FileNotFoundError: If ``cv2.imread`` cannot read an image file.
    """
    num_samples = len(samples)
    if num_samples == 0:
        raise ValueError("data_generator() needs at least one sample")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
    # Private copy: random.shuffle works in place, and shuffling the
    # caller's list would silently reorder e.g. val_samples.
    samples = list(samples)
    while True:
        if shuffle_data:
            random.shuffle(samples)
        for offset in range(0, num_samples, batch_size):
            batch_samples = samples[offset:offset + batch_size]
            images = []
            labels = []
            for batch_sample in batch_samples:
                img_name = batch_sample[0]
                label = batch_sample[1]
                img_path = os.path.join(root_dir, img_name)
                img = cv2.imread(img_path)
                # cv2.imread signals failure by returning None instead of
                # raising; fail here with the offending path rather than
                # letting a None propagate into preprocessing. An exception
                # raised deep inside a training generator can otherwise be
                # hard to trace back to its cause.
                if img is None:
                    raise FileNotFoundError(f"could not read image: {img_path}")
                img = preprocessing(img, new_height=resize, new_width=resize)
                images.append(img)
                labels.append(label)
            yield np.array(images), np.array(labels)