tflite_convert
是一个 python 脚本,用于调用 TOCO(TensorFlow Lite 优化转换器)将文件从 Tensorflow 的格式转换为与 tflite 兼容的文件。
我正在尝试从我用Estimator
. 训练代码非常简单,我根据定点量化指南中的要求添加了微调模型所需的修改:
def input_fn(mode, num_classes, batch_size=1):
#[...]
return {'images': images}, labels
def model_fn(features, labels, num_classes, mode):
images = features['images']
with tf.contrib.slim.arg_scope(net_arg_scope()):
logits, end_points = build_net(...)
if FLAGS.with_quantization:
tf.logging.info("Applying quantization to the graph.")
if mode == tf.estimator.ModeKeys.EVAL:
tf.contrib.quantize.create_eval_graph()
tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
total_loss = tf.losses.get_total_loss() #obtain the regularization losses as well
if FLAGS.with_quantization:
tf.logging.info("Applying quantization to the graph.")
if mode == tf.estimator.ModeKeys.TRAIN:
tf.contrib.quantize.create_training_graph()
# Configure the training op, etc [...]
return tf.estimator.EstimatorSpec(...)
def main(unused_argv):
regex = FINETUNE_LAYER_RE if not FLAGS.with_quantization else '^((?!_quant).)*$'
ws_settings = tf.estimator.WarmStartSettings(FLAGS.pretrained_checkpoint, regex)
# Create the Estimator
estimator = tf.estimator.Estimator(
model_fn=lambda features, labels, mode: model_fn(features, labels, NUM_CLASSES, mode),
model_dir=FLAGS.model_dir,
#config=run_config,
warm_start_from=ws_settings)
# Set up input functions for training and evaluation
train_input_fn = lambda : input_fn(tf.estimator.ModeKeys.TRAIN, NUM_CLASSES, FLAGS.batch_size)
eval_input_fn = lambda : input_fn(tf.estimator.ModeKeys.EVAL, NUM_CLASSES, FLAGS.batch_size)
#[...]
train_spec = tf.estimator.TrainSpec(...)
eval_spec = tf.estimator.EvalSpec(...)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
我遇到的第一个问题是,在添加量化操作后,不可能简单地使用最新的检查点继续训练。这是因为量化添加了在检查点中找不到的额外变量。我解决了编写一个热启动规范,该规范按名称过滤掉所有新变量,并将训练中的最新检查点用作热启动检查点。
现在,我想生成一个评估图以保存(使用相关变量),然后通过tflite_convert
脚本将其提供给 TOCO。我尝试SavedModel
在每次评估后转换导出的 s 之一,但这会引发以下错误:
数组 conv0_bn/FusedBatchNorm 是 Relu 算子的输入,产生输出数组 cell_stem_0/Relu,但缺少量化所需的最小/最大数据。以非量化输出格式为目标,或者更改输入图以包含最小/最大信息,或者如果您不关心结果的准确性,则传递 --default_ranges_min= 和 --default_ranges_max=。中止(核心转储)
我不知道如何获得正确的SavedModel
或一对GraphDef
+检查点(尽管 SavedModel 是可取的)有没有人尝试量化估计器模型?如何生成量化评估图?