我尝试为 CIFAR10 模型修改 inception_export.py,但出现错误:

    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors.InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [18,384] rhs shape= [2304,384]
     [[Node: save/Assign_5 = Assign[T=DT_FLOAT, _class=["loc:@local3/weights"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/cpu:0"](local3/weights, save/restore_slice_5)]]
Caused by op u'save/Assign_5', defined at:


EDIT1:这是我的代码。我还没有安装 TensorFlow Serving,所以与导出相关的代码块被注释掉了。我还将 image_size 改为 24,以匹配 CIFAR10 模型的输入尺寸。

#!/usr/bin/env python2.7
"""Modified for CIFAR10 model from
https://github.com/tensorflow/serving/blob/master/tensorflow_serving/example/inception_export.py
"""

import os.path
import sys

# This is a placeholder for a Google-internal import.

import tensorflow as tf
from tensorflow.models.image.cifar10 import cifar10
#from inception import inception_model

#from tensorflow_serving.session_bundle import exporter

# Command-line flags; tf.app.run() parses them before main() is called.
tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/cifar10_train',
                           """Directory where to read training checkpoints.""")
tf.app.flags.DEFINE_string('export_dir', '/tmp/cifar10_export',
                           """Directory where to export inference model.""")
tf.app.flags.DEFINE_integer('image_size', 24,
                            """Needs to provide same value as in training.""")
FLAGS = tf.app.flags.FLAGS


# Number of highest-scoring classes export() reports per image. export() reads
# this name but no definition is visible in this file; 5 mirrors the upstream
# inception_export.py example and is valid for the 10 CIFAR-10 classes.
# NOTE(review): drop this line if the constant is already defined elsewhere.
NUM_TOP_CLASSES = 5

WORKING_DIR = os.path.dirname(os.path.realpath(__file__))
# Paths to ImageNet label files expected next to this script. export() does not
# currently read them because the CIFAR-10 class names are hard-coded there.
SYNSET_FILE = os.path.join(WORKING_DIR, 'imagenet_lsvrc_2015_synsets.txt')
METADATA_FILE = os.path.join(WORKING_DIR, 'imagenet_metadata.txt')

def export():
  """Builds a single-image CIFAR-10 inference graph and restores its weights.

  The graph takes one JPEG-encoded string, decodes it, centrally crops it,
  resizes it to FLAGS.image_size x FLAGS.image_size, rescales it to [-1, 1]
  and feeds it to cifar10.inference(). The NUM_TOP_CLASSES best logits are
  mapped to human-readable class names. Variables are then restored from the
  most recent checkpoint listed in FLAGS.checkpoint_dir.

  The TensorFlow Serving export step is kept commented out at the end of the
  function because the exporter import at the top of the file is disabled.
  """
  # The class names are hard-coded further down, so the ImageNet synset and
  # metadata files read by the upstream Inception exporter are not needed.
  with tf.Graph().as_default():
    # Build inference model.

    # The graph below feeds exactly one image. The reported restore failure
    # (lhs [18,384] vs rhs [2304,384], and 2304 / 128 == 18) is consistent with
    # cifar10.inference() flattening its features with the training-time
    # FLAGS.batch_size of 128, so force a batch size of 1 before the model is
    # built to give local3/weights the shape stored in the checkpoint.
    FLAGS.batch_size = 1

    # Input transformation.
    # TODO(b/27776734): Add batching support.
    jpegs = tf.placeholder(tf.string, shape=(1))
    image_buffer = tf.squeeze(jpegs, [0])
    # Decode the string as an RGB JPEG.
    # Note that the resulting image contains an unknown height and width
    # that is set dynamically by decode_jpeg. In other words, the height
    # and width of image is unknown at compile-time.
    image = tf.image.decode_jpeg(image_buffer, channels=3)
    # After this point, all image pixels reside in [0,1)
    # until the very end, when they're rescaled to (-1, 1).  The various
    # adjust_* ops all require this range for dtype float.
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    # Crop the central region of the image with an area containing 87.5% of
    # the original image.
    image = tf.image.central_crop(image, central_fraction=0.875)
    # Resize the image to the height and width the model expects;
    # resize_bilinear works on a batch, hence the expand_dims/squeeze pair.
    image = tf.expand_dims(image, 0)
    image = tf.image.resize_bilinear(image,
                                     [FLAGS.image_size, FLAGS.image_size],
                                     align_corners=False)
    image = tf.squeeze(image, [0])
    # Finally, rescale to [-1,1] instead of [0, 1)
    # NOTE(review): this scaling is inherited from the Inception exporter;
    # confirm it matches the preprocessing used when the CIFAR-10 model was
    # trained, otherwise predictions will be poor even though restore works.
    image = tf.sub(image, 0.5)
    image = tf.mul(image, 2.0)
    images = tf.expand_dims(image, 0)

    # Run inference.
    logits = cifar10.inference(images)

    # Transform output to topK result.
    values, indices = tf.nn.top_k(logits, NUM_TOP_CLASSES)

    # Create a constant string Tensor where the i'th element is
    # the human readable class description for the i'th index.
    class_descriptions = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                          'dog', 'frog', 'horse', 'ship', 'truck']
    class_tensor = tf.constant(class_descriptions)

    # `classes` and `values` are only consumed by the export step below.
    classes = tf.contrib.lookup.index_to_string(tf.to_int64(indices),
                                                mapping=class_tensor)

    # Restore variables from training checkpoint, mapping each variable to
    # its moving-average shadow value where one exists.
    variable_averages = tf.train.ExponentialMovingAverage(
        cifar10.MOVING_AVERAGE_DECAY)
    variables_to_restore = variable_averages.variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)
    with tf.Session() as sess:
      # Restore variables from training checkpoints.
      ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
      if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        # Assuming model_checkpoint_path looks something like:
        #   /my-favorite-path/cifar10_train/model.ckpt-0,
        # extract global_step from it.
        global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        print('Successfully loaded model from %s at step=%s.' %
              (ckpt.model_checkpoint_path, global_step))
      else:
        print('No checkpoint file found at %s' % FLAGS.checkpoint_dir)
        # Nothing was restored, so there is nothing to export.
        return

      # Not exporting yet because TensorFlow Serving is not installed.
      # TODO: re-enable together with the `exporter` import at the top.
      #
      # # Export inference model.
      # init_op = tf.group(tf.initialize_all_tables(), name='init_op')
      # model_exporter = exporter.Exporter(saver)
      # signature = exporter.classification_signature(
      #     input_tensor=jpegs, classes_tensor=classes, scores_tensor=values)
      # model_exporter.init(default_graph_signature=signature,
      #                     init_op=init_op)
      # model_exporter.export(FLAGS.export_dir, tf.constant(global_step),
      #                       sess)
      # print('Successfully exported model to %s' % FLAGS.export_dir)

def main(unused_argv=None):
  """Script entry point: builds the inference graph and restores the model.

  Args:
    unused_argv: Ignored; accepted so the function can be used as a
      tf.app.run()-style entry point.
  """
  export()

if __name__ == '__main__':
  # tf.app.run() parses the command-line flags and then calls main().
  tf.app.run()

