我正在尝试将所有数据增强预处理移到我的模型内部,因此,我创建了一个预处理模型并将其合并到我的 Resnet50 中。
问题是,我的tf.data
管道batch_size
将图像输入到模型中,当输入到预处理管道时会生成:batch_size * 54
图像(每张图像 54 个样本),因此,标签信息与生成的图像无关,我得到错误(batch_size = 16 ):
InvalidArgumentError: logits and labels must be broadcastable: logits_size=[864,516] labels_size=[16,516]
[[node categorical_crossentropy_1/softmax_cross_entropy_with_logits (defined at <ipython-input-26-8e524a3a5e0b>:31) ]]
[Op:__inference_train_function_118686]
关于我应该怎么做才能在 GPU 上保持运行数据增强并将标签与相应生成的图像相关联的任何猜测?
辅助代码:
'''
Data augmentation pipeline: (yields 54 images by sample)
Extract 5 random crops + 1 central crop,
Rotate +-45 deg,
Translate in two random directions, then mirror (vertically)
'''
def preprocessing_model():
input = keras.Input(shape=(224, 224, 3), name="input")
rescaling = tf.keras.layers.experimental.preprocessing.Rescaling(1./255)(input)
central_crop = tf.keras.layers.experimental.preprocessing.CenterCrop(height=112,width=112)(rescaling)
resized_single_crop = tf.keras.layers.experimental.preprocessing.Resizing(224,224)(central_crop)
random_crop = keras.Sequential([tf.keras.layers.experimental.preprocessing.RandomCrop(height=56,width=74)])
random_crop0 = random_crop(rescaling,training=True)
random_crop1 = random_crop(rescaling,training=True)
random_crop2 = random_crop(rescaling,training=True)
random_crop3 = random_crop(rescaling,training=True)
random_crop4 = random_crop(rescaling,training=True)
crops = tf.keras.layers.concatenate([random_crop0,random_crop1,random_crop2,random_crop3,random_crop4],axis=0)
resized_crops = tf.keras.layers.experimental.preprocessing.Resizing(224,224)(crops)
rotate_1 = keras.Sequential([tf.keras.layers.experimental.preprocessing.RandomRotation(factor=[0.125,0.125])])
rotate_2 = keras.Sequential([tf.keras.layers.experimental.preprocessing.RandomRotation(factor=[-0.125,-0.125])])
rotated_a = rotate_1(rescaling,training=True)
rotated_b = rotate_2(rescaling,training=True)
augmented_images = tf.keras.layers.concatenate([rescaling,resized_crops,resized_single_crop,rotated_a,rotated_b],axis=0)
translate_1 = keras.Sequential([keras.layers.experimental.preprocessing.RandomTranslation(height_factor=(-0.25,0.25),width_factor=(0.25,0.25))])
translate_2 = keras.Sequential([keras.layers.experimental.preprocessing.RandomTranslation(height_factor=(-0.25,0.25),width_factor=(-0.25,-0.25))])
translated_a = translate_1(augmented_images,training=True)
translated_b = translate_2(augmented_images,training=True)
augmented_images = tf.keras.layers.concatenate([augmented_images,translated_a,translated_b],axis=0)
mirrored_versions = keras.Sequential([tf.keras.layers.experimental.preprocessing.RandomFlip('vertical')])
mirrored_images = mirrored_versions(augmented_images,training=True)
augmented_images = tf.keras.layers.concatenate([augmented_images,mirrored_images],axis=0)
model = tf.keras.Model(inputs=input,outputs=augmented_images)
return model
将预处理模型合并到 ResNet50 中:
def load_and_configure_model(optimizer, loss, metrics, path):
model = ResNet50V2(include_top=True, weights='imagenet')
transfer_layer = model.get_layer('avg_pool')
resnet_submodel = Model(inputs=model.input,outputs=transfer_layer.output)
augmentation_pipeline = preprocessing_model()
augmentation_model_cfg = augmentation_pipeline.get_config() # Get layer configuration dictionary.
model_config = resnet_submodel.get_config()
submodel = model_config['layers']
submodel.remove(submodel[0]) # Remove the previous input layer
prepr_model_layers = augmentation_model_cfg['layers']
prepr_model_layers.extend(submodel) # Join both models
# Replace the previous input layer with the output from the preprocessing model
# (Connect the preprocessing model to the resnet)
output_name = prepr_model_layers[len(augmentation_pipeline.get_config()['layers'])-1]['name']
prepr_model_layers[len(augmentation_pipeline.get_config()['layers'])]['inbound_nodes'] = [[[output_name, 0, 0, {}]]]
new_model = augmentation_pipeline.__class__.from_config(augmentation_model_cfg, custom_objects={}) # change custom objects if necessary
# Set back pre-trained weights on new model
weights = [layer.get_weights() for layer in resnet_submodel.layers[1:]]
for layer, weight in zip(new_model.layers[15:], weights):
layer.set_weights(weight)
for layer in new_model.layers[15:]:
layer.trainable = False
for layer in new_model.layers[15:]:
trainable = ('conv5_block3' in layer.name)
layer.trainable = trainable
transfer_layer = new_model.get_layer('avg_pool')
class1 = Dense(1000, activation='softmax',name='class_1')(transfer_layer.output)
class2 = Dense(516, activation='softmax',name='class_2')(transfer_layer.output)
class3 = Dense(124,activation='softmax', name='class_3')(transfer_layer.output)
model = keras.Model(
inputs=[new_model.inputs],
outputs=[class1,class2,class3],
)
if not path == None :
model.load_weights(path)
model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
print(model.summary())
return model
tf.data 管道
def train_model(train_path, validation_path, buffer_size, epochs, steps_per_epoch, model):
train_filenames = get_filenames(train_path)
random.shuffle(train_filenames)
validation_filenames = get_filenames(validation_path)
random.shuffle(validation_filenames)
dataset_length = 91758
train_size = dataset_length * 0.7
validation_size = dataset_length - train_size
batch_size = 16
AUTO = tf.data.AUTOTUNE
train_dataset = tf.data.TFRecordDataset(buffer_size=int(1e+8),num_parallel_reads=AUTO,filenames=train_filenames).cache('/cache/train_cache').map(parsing_fn,num_parallel_calls=AUTO)
train_dataset = train_dataset.batch(batch_size)
train_dataset = train_dataset.repeat()
train_dataset = train_dataset.prefetch(AUTO)
# Create a validation dataset
validation_dataset = tf.data.TFRecordDataset(num_parallel_reads=AUTO,filenames=validation_filenames).map(parsing_fn,num_parallel_calls=AUTO)
validation_dataset = validation_dataset.batch(batch_size)
validation_dataset = validation_dataset.prefetch(AUTO)
validation_dataset = validation_dataset.repeat(1)
validation_steps = validation_size / batch_size # "This ensures that the same validation samples are used every time"
history = model.fit(x=train_dataset,
epochs=epochs,
steps_per_epoch=steps_per_epoch,
validation_data=validation_dataset,
validation_steps=validation_steps)
return history