我正在基于 cifar10 示例构建 tfx 管道:[https://github.com/tensorflow/tfx/tree/master/tfx/examples/cifar10]
不同之处在于我不想将其转换为 tf_lite 模型,而是使用基于 keras 的常规 tensorflow 模型。
一切都按预期工作,直到我到达 Evaluator 组件,因为它失败并出现以下错误:
ValueError: Missing data for input "input_1". You passed a data dictionary with keys ['image_xf']. Expected the following keys: ['input_1']
[while running 'Run[Trainer]']
不知道我做错了什么,但到目前为止我调试/修改了代码如下:
[1] preprocessing_fn 输出正在输出密钥image_xf
:
_IMAGE_KEY = 'image'
_LABEL_KEY = 'label'
def _transformed_name(key):
return key + '_xf'
def preprocessing_fn(inputs):
"""tf.transform's callback function for preprocessing inputs.
Args:
inputs: map from feature keys to raw not-yet-transformed features.
Returns:
Map from string feature key to transformed feature operations.
"""
outputs = {}
# tf.io.decode_png function cannot be applied on a batch of data.
# We have to use tf.map_fn
image_features = tf.map_fn(
lambda x: tf.io.decode_png(x[0], channels=3),
inputs[_IMAGE_KEY],
dtype=tf.uint8)
# image_features = tf.cast(image_features, tf.float32)
image_features = tf.image.resize(image_features, [224, 224])
image_features = tf.keras.applications.mobilenet.preprocess_input(
image_features)
outputs[_transformed_name(_IMAGE_KEY)] = image_features
#outputs["input_1"] = image_features
# TODO(b/157064428): Support label transformation for Keras.
# Do not apply label transformation as it will result in wrong evaluation.
outputs[_transformed_name(_LABEL_KEY)] = inputs[_LABEL_KEY]
return outputs
[2] 当我构建模型时,我使用了带有同名 inputLayer 的迁移学习image_xf
。
def _build_keras_model() -> tf.keras.Model:
"""Creates a Image classification model with MobileNet backbone.
Returns:
The image classifcation Keras Model and the backbone MobileNet model
"""
# We create a MobileNet model with weights pre-trained on ImageNet.
# We remove the top classification layer of the MobileNet, which was
# used for classifying ImageNet objects. We will add our own classification
# layer for CIFAR10 later. We use average pooling at the last convolution
# layer to get a 1D vector for classifcation, which is consistent with the
# origin MobileNet setup
base_model = tf.keras.applications.MobileNet(
input_shape=(224, 224, 3),
include_top=False,
weights='imagenet',
pooling='avg')
base_model.input_spec = None
# We add a Dropout layer at the top of MobileNet backbone we just created to
# prevent overfiting, and then a Dense layer to classifying CIFAR10 objects
model = tf.keras.Sequential([
tf.keras.layers.InputLayer(
input_shape=(224, 224, 3), name=_transformed_name(_IMAGE_KEY)),
base_model,
tf.keras.layers.Dropout(0.1),
tf.keras.layers.Dense(10, activation='softmax')
])
[3] 相应地创建模型签名:
def _get_serve_image_fn(model, tf_transform_output):
"""Returns a function that feeds the input tensor into the model."""
model.tft_layer = tf_transform_output.transform_features_layer()
@tf.function
def serve_image_fn(serialized_tf_examples):
feature_spec = tf_transform_output.raw_feature_spec()
feature_spec.pop(_LABEL_KEY)
parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec)
transformed_features = model.tft_layer(parsed_features)
return model(transformed_features)
return serve_image_fn
def run_fn(fn_args: FnArgs):
tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)
signatures = {
'serving_default':
_get_serve_image_fn(model,tf_transform_output).get_concrete_function(
tf.TensorSpec(
shape=[None],
dtype=tf.string,
name=_IMAGE_KEY))
}
temp_saving_model_dir = os.path.join(fn_args.serving_model_dir)
model.save(temp_saving_model_dir, save_format='tf', signatures=signatures)
现在,我怀疑 tensorflow 没有正确保存模型,因为当我导出保存的模型时,输入层input_1
不是image_xf
.
import tensorflow as tf
import numpy as np
import tensorflow.python.ops.numpy_ops.np_config as np_config
np_config.enable_numpy_behavior()
path = './model/Format-Serving/'
imported = tf.saved_model.load(path)
model = tf.keras.models.load_model(path)
print(model.summary())
print(list(imported.signatures.keys()))
print(model.get_layer('mobilenet_1.00_224').layers[0].name)
这里要注意的是(1)我在Sequential
上面的模型中添加的输入层丢失了,(2)mobilenet 第一层是input_1
,所以我得到一个不匹配的原因是有道理的。
2021-10-15 08:33:40.683034: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
mobilenet_1.00_224 (Function (None, 1024) 3228864
_________________________________________________________________
dropout (Dropout) (None, 1024) 0
_________________________________________________________________
dense (Dense) (None, 10) 10250
=================================================================
Total params: 3,239,114
Trainable params: 1,074,186
Non-trainable params: 2,164,928
_________________________________________________________________
None
['serving_default']
input_1
那么我怎样才能让模型正确保存正确的输入呢?
这是完整的代码:
管道.py
# Lint as: python2, python3
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CIFAR10 image classification example using TFX.
This example demonstrates how to do data augmentation, transfer learning,
and inserting TFLite metadata with TFX.
The trained model can be pluged into MLKit for object detection.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import os
from typing import List, Text
import absl
from tfx import v1 as tfx
import tensorflow_model_analysis as tfma
from tfx.components import Evaluator
from tfx.components import ExampleValidator
from tfx.components import ImportExampleGen
from tfx.components import Pusher
from tfx.components import SchemaGen
from tfx.components import StatisticsGen
from tfx.components import Trainer
from tfx.components import Transform
from tfx.dsl.components.common import resolver
from tfx.dsl.experimental import latest_blessed_model_resolver
from tfx.orchestration import metadata
from tfx.orchestration import pipeline
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner
from tfx.proto import example_gen_pb2
from tfx.proto import pusher_pb2
from tfx.proto import trainer_pb2
from tfx.types import Channel
from tfx.types.standard_artifacts import Model
from tfx.types.standard_artifacts import ModelBlessing
_pipeline_name = 'cifar10_native_keras'
# This example assumes that CIFAR10 train set data is stored in
# ~/cifar10/data/train, test set data is stored in ~/cifar10/data/test, and
# the utility function is in ~/cifar10. Feel free to customize as needed.
_cifar10_root = os.path.join(os.getcwd())
_data_root = os.path.join(_cifar10_root, 'data')
# Python module files to inject customized logic into the TFX components. The
# Transform and Trainer both require user-defined functions to run successfully.
_module_file = os.path.join(_cifar10_root, 'cifar10_utils_native_keras.py')
# Path which can be listened to by the model server. Pusher will output the
# trained model here.
_serving_model_dir_lite = os.path.join(_cifar10_root, 'serving_model_lite',
_pipeline_name)
# Directory and data locations. This example assumes all of the images,
# example code, and metadata library is relative to $HOME, but you can store
# these files anywhere on your local filesystem.
_tfx_root = os.path.join(os.getcwd(), 'tfx')
_pipeline_root = os.path.join(_tfx_root, 'pipelines', _pipeline_name)
# Sqlite ML-metadata db path.
_metadata_path = os.path.join(_tfx_root, 'metadata', _pipeline_name,
'metadata.db')
# Path to labels file for mapping model outputs.
_labels_path = os.path.join(_data_root, 'labels.txt')
# Pipeline arguments for Beam powered Components.
_beam_pipeline_args = [
'--direct_running_mode=multi_processing',
'--direct_num_workers=0',
]
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
module_file: Text, serving_model_dir_lite: Text,
metadata_path: Text,
labels_path: Text,
beam_pipeline_args: List[Text]) -> pipeline.Pipeline:
"""Implements the CIFAR10 image classification pipeline using TFX."""
# This is needed for datasets with pre-defined splits
# Change the pattern argument to train_whole/* and test_whole/* to train
# on the whole CIFAR-10 dataset
input_config = example_gen_pb2.Input(splits=[
example_gen_pb2.Input.Split(name='train', pattern='train/*'),
example_gen_pb2.Input.Split(name='eval', pattern='test/*')
])
# Brings data into the pipeline.
example_gen = ImportExampleGen(
input_base=data_root, input_config=input_config)
# Computes statistics over data for visualization and example validation.
statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])
# Generates schema based on statistics files.
schema_gen = SchemaGen(
statistics=statistics_gen.outputs['statistics'], infer_feature_shape=True)
# Performs anomaly detection based on statistics and data schema.
example_validator = ExampleValidator(
statistics=statistics_gen.outputs['statistics'],
schema=schema_gen.outputs['schema'])
# Performs transformations and feature engineering in training and serving.
transform = Transform(
examples=example_gen.outputs['examples'],
schema=schema_gen.outputs['schema'],
module_file=module_file)
model_resolver = resolver.Resolver(
#instance_name='latest_model_resolver',
strategy_class=tfx.dsl.experimental.LatestArtifactStrategy,
model=Channel(type=Model)).with_id('latest_blessed_model_resolver')
# Uses user-provided Python function that trains a model.
# When traning on the whole dataset, use 18744 for train steps, 156 for eval
# steps. 18744 train steps correspond to 24 epochs on the whole train set, and
# 156 eval steps correspond to 1 epoch on the whole test set. The
# configuration below is for training on the dataset we provided in the data
# folder, which has 128 train and 128 test samples. The 160 train steps
# correspond to 40 epochs on this tiny train set, and 4 eval steps correspond
# to 1 epoch on this tiny test set.
trainer = Trainer(
module_file=module_file,
examples=transform.outputs['transformed_examples'],
transform_graph=transform.outputs['transform_graph'],
schema=schema_gen.outputs['schema'],
base_model=model_resolver.outputs['model'],
train_args=trainer_pb2.TrainArgs(num_steps=160),
eval_args=trainer_pb2.EvalArgs(num_steps=4),
custom_config={'labels_path': labels_path})
# Get the latest blessed model for model validation.
# model_resolver = resolver.Resolver(
# strategy_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
# model=Channel(type=Model),
# model_blessing=Channel(
# type=ModelBlessing)).with_id('latest_blessed_model_resolver')
# Uses TFMA to compute evaluation statistics over features of a model and
# perform quality validation of a candidate model (compare to a baseline).
eval_config = tfma.EvalConfig(
model_specs=[tfma.ModelSpec(label_key='label')],
slicing_specs=[tfma.SlicingSpec()],
metrics_specs=[
tfma.MetricsSpec(metrics=[
tfma.MetricConfig(
class_name='SparseCategoricalAccuracy',
threshold=tfma.MetricThreshold(
value_threshold=tfma.GenericValueThreshold(
lower_bound={'value': 0.55}),
# Change threshold will be ignored if there is no
# baseline model resolved from MLMD (first run).
change_threshold=tfma.GenericChangeThreshold(
direction=tfma.MetricDirection.HIGHER_IS_BETTER,
absolute={'value': -1e-3})))
])
])
# Uses TFMA to compute the evaluation statistics over features of a model.
# We evaluate using the materialized examples that are output by Transform
# because
# 1. the decoding_png function currently performed within Transform are not
# compatible with TFLite.
# 2. MLKit requires deserialized (float32) tensor image inputs
# Note that for deployment, the same logic that is performed within Transform
# must be reproduced client-side.
evaluator = Evaluator(
examples=example_gen.outputs['examples'],
model=trainer.outputs['model'],
#baseline_model=model_resolver.outputs['model'],
eval_config=eval_config)
# Checks whether the model passed the validation steps and pushes the model
# to a file destination if check passed.
pusher = Pusher(
model=trainer.outputs['model'],
model_blessing=evaluator.outputs['blessing'],
push_destination=pusher_pb2.PushDestination(
filesystem=pusher_pb2.PushDestination.Filesystem(
base_directory=serving_model_dir_lite)))
components = [
example_gen, statistics_gen, schema_gen, example_validator, transform,
trainer, model_resolver, evaluator, pusher
]
return pipeline.Pipeline(
pipeline_name=pipeline_name,
pipeline_root=pipeline_root,
components=components,
enable_cache=True,
metadata_connection_config=metadata.sqlite_metadata_connection_config(
metadata_path),
beam_pipeline_args=beam_pipeline_args)
# To run this pipeline from the python CLI:
# $python cifar_pipeline_native_keras.py
if __name__ == '__main__':
loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
for logger in loggers:
logger.setLevel(logging.INFO)
logging.getLogger().setLevel(logging.INFO)
absl.logging.set_verbosity(absl.logging.FATAL)
BeamDagRunner().run(
_create_pipeline(
pipeline_name=_pipeline_name,
pipeline_root=_pipeline_root,
data_root=_data_root,
module_file=_module_file,
serving_model_dir_lite=_serving_model_dir_lite,
metadata_path=_metadata_path,
labels_path=_labels_path,
beam_pipeline_args=_beam_pipeline_args))
实用程序文件:
# Lint as: python2, python3
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Python source file includes CIFAR10 utils for Keras model.
The utilities in this file are used to build a model with native Keras.
This module file will be used in Transform and generic Trainer.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from typing import List, Text
import absl
import tensorflow as tf
import tensorflow_transform as tft
from tfx.components.trainer.fn_args_utils import DataAccessor
from tfx.components.trainer.fn_args_utils import FnArgs
from tfx.components.trainer.rewriting import converters
from tfx.components.trainer.rewriting import rewriter
from tfx.components.trainer.rewriting import rewriter_factory
from tfx.dsl.io import fileio
from tfx_bsl.tfxio import dataset_options
# import flatbuffers
# from tflite_support import metadata_schema_py_generated as _metadata_fb
# from tflite_support import metadata as _metadata
# When training on the whole dataset use following constants instead.
# This setting should give ~91% accuracy on the whole test set
# _TRAIN_DATA_SIZE = 50000
# _EVAL_DATA_SIZE = 10000
# _TRAIN_BATCH_SIZE = 64
# _EVAL_BATCH_SIZE = 64
# _CLASSIFIER_LEARNING_RATE = 3e-4
# _FINETUNE_LEARNING_RATE = 5e-5
# _CLASSIFIER_EPOCHS = 12
_TRAIN_DATA_SIZE = 128
_EVAL_DATA_SIZE = 128
_TRAIN_BATCH_SIZE = 32
_EVAL_BATCH_SIZE = 32
_CLASSIFIER_LEARNING_RATE = 1e-3
_FINETUNE_LEARNING_RATE = 7e-6
_CLASSIFIER_EPOCHS = 30
_IMAGE_KEY = 'image'
_LABEL_KEY = 'label'
_TFLITE_MODEL_NAME = 'tflite'
def _transformed_name(key):
return key + '_xf'
def _get_serve_image_fn(model, tf_transform_output):
"""Returns a function that feeds the input tensor into the model."""
model.tft_layer = tf_transform_output.transform_features_layer()
@tf.function
def serve_image_fn(serialized_tf_examples):
feature_spec = tf_transform_output.raw_feature_spec()
feature_spec.pop(_LABEL_KEY)
parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec)
transformed_features = model.tft_layer(parsed_features)
return model(transformed_features)
return serve_image_fn
def _image_augmentation(image_features):
"""Perform image augmentation on batches of images .
Args:
image_features: a batch of image features
Returns:
The augmented image features
"""
batch_size = tf.shape(image_features)[0]
image_features = tf.image.random_flip_left_right(image_features)
image_features = tf.image.resize_with_crop_or_pad(image_features, 250, 250)
image_features = tf.image.random_crop(image_features,
(batch_size, 224, 224, 3))
return image_features
def _data_augmentation(feature_dict):
"""Perform data augmentation on batches of data.
Args:
feature_dict: a dict containing features of samples
Returns:
The feature dict with augmented features
"""
image_features = feature_dict[_transformed_name(_IMAGE_KEY)]
image_features = _image_augmentation(image_features)
feature_dict[_transformed_name(_IMAGE_KEY)] = image_features
return feature_dict
def _input_fn(file_pattern: List[Text],
data_accessor: DataAccessor,
tf_transform_output: tft.TFTransformOutput,
is_train: bool = False,
batch_size: int = 200) -> tf.data.Dataset:
"""Generates features and label for tuning/training.
Args:
file_pattern: List of paths or patterns of input tfrecord files.
data_accessor: DataAccessor for converting input to RecordBatch.
tf_transform_output: A TFTransformOutput.
is_train: Whether the input dataset is train split or not.
batch_size: representing the number of consecutive elements of returned
dataset to combine in a single batch
Returns:
A dataset that contains (features, indices) tuple where features is a
dictionary of Tensors, and indices is a single Tensor of label indices.
"""
dataset = data_accessor.tf_dataset_factory(
file_pattern,
dataset_options.TensorFlowDatasetOptions(
batch_size=batch_size, label_key=_transformed_name(_LABEL_KEY)),
tf_transform_output.transformed_metadata.schema)
# Apply data augmentation. We have to do data augmentation here because
# we need to apply data agumentation on-the-fly during training. If we put
# it in Transform, it will only be applied once on the whole dataset, which
# will lose the point of data augmentation.
if is_train:
dataset = dataset.map(lambda x, y: (_data_augmentation(x), y))
return dataset
def _freeze_model_by_percentage(model: tf.keras.Model, percentage: float):
"""Freeze part of the model based on specified percentage.
Args:
model: The keras model need to be partially frozen
percentage: the percentage of layers to freeze
Raises:
ValueError: Invalid values.
"""
if percentage < 0 or percentage > 1:
raise ValueError('Freeze percentage should between 0.0 and 1.0')
if not model.trainable:
raise ValueError(
'The model is not trainable, please set model.trainable to True')
num_layers = len(model.layers)
num_layers_to_freeze = int(num_layers * percentage)
for idx, layer in enumerate(model.layers):
if idx < num_layers_to_freeze:
layer.trainable = False
else:
layer.trainable = True
def _build_keras_model() -> tf.keras.Model:
"""Creates a Image classification model with MobileNet backbone.
Returns:
The image classifcation Keras Model and the backbone MobileNet model
"""
# We create a MobileNet model with weights pre-trained on ImageNet.
# We remove the top classification layer of the MobileNet, which was
# used for classifying ImageNet objects. We will add our own classification
# layer for CIFAR10 later. We use average pooling at the last convolution
# layer to get a 1D vector for classifcation, which is consistent with the
# origin MobileNet setup
base_model = tf.keras.applications.MobileNet(
input_shape=(224, 224, 3),
include_top=False,
weights='imagenet',
pooling='avg')
base_model.input_spec = None
# We add a Dropout layer at the top of MobileNet backbone we just created to
# prevent overfiting, and then a Dense layer to classifying CIFAR10 objects
model = tf.keras.Sequential([
tf.keras.layers.InputLayer(
input_shape=(224, 224, 3), name=_transformed_name(_IMAGE_KEY)),
base_model,
tf.keras.layers.Dropout(0.1),
tf.keras.layers.Dense(10, activation='softmax')
])
# Freeze the whole MobileNet backbone to first train the top classifer only
_freeze_model_by_percentage(base_model, 1.0)
model.compile(
loss='sparse_categorical_crossentropy',
optimizer=tf.keras.optimizers.RMSprop(lr=_CLASSIFIER_LEARNING_RATE),
metrics=['sparse_categorical_accuracy'])
model.summary(print_fn=absl.logging.info)
return model, base_model
# TFX Transform will call this function.
def preprocessing_fn(inputs):
"""tf.transform's callback function for preprocessing inputs.
Args:
inputs: map from feature keys to raw not-yet-transformed features.
Returns:
Map from string feature key to transformed feature operations.
"""
outputs = {}
# tf.io.decode_png function cannot be applied on a batch of data.
# We have to use tf.map_fn
image_features = tf.map_fn(
lambda x: tf.io.decode_png(x[0], channels=3),
inputs[_IMAGE_KEY],
dtype=tf.uint8)
# image_features = tf.cast(image_features, tf.float32)
image_features = tf.image.resize(image_features, [224, 224])
image_features = tf.keras.applications.mobilenet.preprocess_input(
image_features)
outputs[_transformed_name(_IMAGE_KEY)] = image_features
#outputs["input_1"] = image_features
# TODO(b/157064428): Support label transformation for Keras.
# Do not apply label transformation as it will result in wrong evaluation.
outputs[_transformed_name(_LABEL_KEY)] = inputs[_LABEL_KEY]
return outputs
# TFX Trainer will call this function.
def run_fn(fn_args: FnArgs):
"""Train the model based on given args.
Args:
fn_args: Holds args used to train the model as name/value pairs.
Raises:
ValueError: if invalid inputs.
"""
tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)
baseline_path = fn_args.base_model
if baseline_path is not None:
model = tf.keras.models.load_model(os.path.join(baseline_path))
else:
train_dataset = _input_fn(
fn_args.train_files,
fn_args.data_accessor,
tf_transform_output,
is_train=True,
batch_size=_TRAIN_BATCH_SIZE)
eval_dataset = _input_fn(
fn_args.eval_files,
fn_args.data_accessor,
tf_transform_output,
is_train=False,
batch_size=_EVAL_BATCH_SIZE)
model, base_model = _build_keras_model()
absl.logging.info('Tensorboard logging to {}'.format(fn_args.model_run_dir))
# Write logs to path
tensorboard_callback = tf.keras.callbacks.TensorBoard(
log_dir=fn_args.model_run_dir, update_freq='batch')
# Our training regime has two phases: we first freeze the backbone and train
# the newly added classifier only, then unfreeze part of the backbone and
# fine-tune with classifier jointly.
steps_per_epoch = int(_TRAIN_DATA_SIZE / _TRAIN_BATCH_SIZE)
total_epochs = int(fn_args.train_steps / steps_per_epoch)
if _CLASSIFIER_EPOCHS > total_epochs:
raise ValueError('Classifier epochs is greater than the total epochs')
absl.logging.info('Start training the top classifier')
model.fit(
train_dataset,
epochs=_CLASSIFIER_EPOCHS,
steps_per_epoch=steps_per_epoch,
validation_data=eval_dataset,
validation_steps=fn_args.eval_steps,
callbacks=[tensorboard_callback])
absl.logging.info('Start fine-tuning the model')
# Unfreeze the top MobileNet layers and do joint fine-tuning
_freeze_model_by_percentage(base_model, 0.9)
# We need to recompile the model because layer properties have changed
model.compile(
loss='sparse_categorical_crossentropy',
optimizer=tf.keras.optimizers.RMSprop(lr=_FINETUNE_LEARNING_RATE),
metrics=['sparse_categorical_accuracy'])
model.summary(print_fn=absl.logging.info)
model.fit(
train_dataset,
initial_epoch=_CLASSIFIER_EPOCHS,
epochs=total_epochs,
steps_per_epoch=steps_per_epoch,
validation_data=eval_dataset,
validation_steps=fn_args.eval_steps,
callbacks=[tensorboard_callback])
# Prepare the TFLite model used for serving in MLKit
signatures = {
'serving_default':
_get_serve_image_fn(model,tf_transform_output).get_concrete_function(
tf.TensorSpec(
shape=[None],
dtype=tf.string,
name=_IMAGE_KEY))
}
temp_saving_model_dir = os.path.join(fn_args.serving_model_dir)
model.save(temp_saving_model_dir, save_format='tf', signatures=signatures)
# tfrw = rewriter_factory.create_rewriter(
# rewriter_factory.TFLITE_REWRITER,
# name='tflite_rewriter')
# converters.rewrite_saved_model(temp_saving_model_dir,
# fn_args.serving_model_dir, tfrw,
# rewriter.ModelType.TFLITE_MODEL)
# # Add necessary TFLite metadata to the model in order to use it within MLKit
# # TODO(dzats@): Handle label map file path more properly, currently
# # hard-coded.
# tflite_model_path = os.path.join(fn_args.serving_model_dir,
# _TFLITE_MODEL_NAME)
# # TODO(dzats@): Extend the TFLite rewriter to be able to add TFLite metadata
# #@ to the model.
# _write_metadata(
# model_path=tflite_model_path,
# label_map_path=fn_args.custom_config['labels_path'],
# mean=[127.5],
# std=[127.5])
# fileio.rmtree(temp_saving_model_dir)