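To train a CNN, I encoded some images in TFRecord format. I'm having trouble reading the labels from the TFRecord files and converting them into one-hot encoded tensors to train my classifier. The labels are strings of 5-digit numbers in the range 10000 to 50000, and they are sparse. Training is done with a custom TensorFlow Estimator. Here is a snippet of the function I use to read the TFRecord files and extract the images: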
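To make the problem concrete, here is a plain-Python model of it that needs no TensorFlow. `tf.one_hot` only sets a 1 for indices in `[0, depth)`, so a raw label such as 12345 with a small depth gives an all-zero row. Remapping each distinct label string to a dense id first avoids that. The helpers `one_hot` and `build_label_index` are hypothetical stand-ins, and building the mapping this way assumes the full set of label strings is known.

```python
def one_hot(index, depth):
    """Plain-Python stand-in for tf.one_hot on one index:
    any index outside [0, depth) produces a row of all zeros."""
    return [1.0 if i == index else 0.0 for i in range(depth)]


def build_label_index(label_strings):
    """Map each distinct label string to a dense id in [0, n_classes)."""
    return {label: i for i, label in enumerate(sorted(set(label_strings)))}


labels = ["12345", "34234", "53453", "12345"]
label_index = build_label_index(labels)
num_classes = len(label_index)

# Using the sparse value directly: 12345 >= num_classes, so no 1 is set.
raw = one_hot(int("12345"), num_classes)
# Remapping to a dense id first gives a proper one-hot row.
dense = one_hot(label_index["12345"], num_classes)
```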
import tensorflow as tf


def imgs_input_fn(filenames, classes, perform_shuffle=False, repeat_count=1, batch_size=1):
    def _parse_function(serialized):
        features = {
            'image/encoded': tf.FixedLenFeature([], tf.string),
            'image/width': tf.FixedLenFeature([], tf.int64),
            'image/height': tf.FixedLenFeature([], tf.int64),
            'image/channels': tf.FixedLenFeature([], tf.int64),
            'image/colorspace': tf.FixedLenFeature([], tf.string),
            'image/class/label': tf.FixedLenFeature([], tf.string),
            'image/class/text_label': tf.FixedLenFeature([], tf.string),
            'image/filename': tf.FixedLenFeature([], tf.string)
        }
        # Parse the serialized data so we get a dict with our data.
        parsed_example = tf.parse_single_example(serialized=serialized,
                                                 features=features)
        # Get the image as raw bytes.
        # For image_shape I can't use parsed_example['image/channels']
        # as read from the file; I have to hard-code 1 as the channel count.
        # How do I get this value from the record?
        channels = parsed_example['image/channels']
        # Decoded images are laid out as [height, width, channels].
        image_shape = tf.stack([parsed_example['image/height'],
                                parsed_example['image/width'], 1])
        image_raw = parsed_example['image/encoded']
        # Labels are strings representing numbers, but the numbers are sparse.
        label = tf.string_to_number(parsed_example['image/class/label'],
                                    out_type=tf.int32)
        # TODO: find out how to pass the value read from the TFRecord file here.
        image = tf.image.decode_image(image_raw)
        image = tf.divide(tf.cast(image, tf.float32),
                          tf.constant(255., dtype=tf.float32))
        image = tf.reshape(image, image_shape)
        num_classes = classes
        # The following operation does not give the expected result:
        # labels are strings like 12345, 34234, 53453, while there are only
        # e.g. 100 classes, so tf.one_hot(10000, 100) returns all zeros.
        # input_name (the Keras model's input name) is assumed to be
        # defined outside this snippet.
        return {input_name: image}, tf.one_hot(label, num_classes)

    dataset = tf.data.TFRecordDataset(filenames=filenames)
    # Parse the serialized data in the TFRecord files.
    # This returns TensorFlow tensors for the image and labels.
    dataset = dataset.map(_parse_function)
    if perform_shuffle:
        # Randomize input using a window of 1024 elements (read into memory).
        dataset = dataset.shuffle(buffer_size=1024)
    dataset = dataset.repeat(repeat_count)  # Repeat the dataset this many times.
    dataset = dataset.batch(batch_size)  # Batch size to use.
    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels
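One way to do the label conversion inside the input pipeline is to replace the `tf.string_to_number` / `tf.one_hot(label, num_classes)` pair with a string-to-id lookup table, then call `tf.one_hot` on the id. This is a minimal sketch assuming TensorFlow 2.x eager mode (`tf.lookup.StaticVocabularyTable`), not the code above. `CLASS_LABELS` and `NUM_OOV_BUCKETS` are hypothetical: the table still needs the list of known label strings up front, but the lookup itself runs lazily as each record is parsed.

```python
import tensorflow as tf

# Hypothetical: the label strings known to be valid classes.
CLASS_LABELS = ["12345", "34234", "53453"]
# Hypothetical: extra buckets for label strings not in CLASS_LABELS.
NUM_OOV_BUCKETS = 2
# The one_hot depth must cover both the known ids and the OOV bucket ids.
NUM_CLASSES = len(CLASS_LABELS) + NUM_OOV_BUCKETS

# Known labels map to 0..len(CLASS_LABELS)-1; unknown labels are hashed
# into len(CLASS_LABELS)..NUM_CLASSES-1.
label_table = tf.lookup.StaticVocabularyTable(
    tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(CLASS_LABELS),
        values=tf.range(len(CLASS_LABELS), dtype=tf.int64)),
    num_oov_buckets=NUM_OOV_BUCKETS)


def label_to_one_hot(label_str):
    """Turn a raw 'image/class/label' string into a one-hot vector."""
    class_id = label_table.lookup(label_str)
    return tf.one_hot(class_id, depth=NUM_CLASSES)


# Inside _parse_function this would become, e.g.:
#   label = label_to_one_hot(parsed_example['image/class/label'])

# Quick check on a dataset of raw label strings (one known, one unknown twice).
ds = tf.data.Dataset.from_tensor_slices(["34234", "99999", "99999"])
one_hots = [v.numpy().tolist() for v in ds.map(label_to_one_hot)]
```

In TF 1.x, `tf.contrib.lookup.index_table_from_tensor(mapping=..., num_oov_buckets=...)` builds a comparable table. A one-shot iterator may refuse a dataset that captures a lookup table, so `make_initializable_iterator()` or returning the dataset itself from the input function may be needed there.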
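So how can I fill a structure such as a lookup table, for example with tf.contrib.lookup.index_table_from_tensor, directly from the information in the TFRecord files as the images are read for training, instead of supplying a label file in advance or reading all the TFRecords beforehand to extract the labels? I'd like to exploit the fact that, when a label is "unknown" to the lookup table, index_table_from_tensor (with num_oov_buckets > 0) uses a hash of the label and so gives a consistent result. The function I wrote is called from the training loop tf.estimator.train_and_evaluate, after defining tf.estimator.TrainSpec and tf.estimator.EvalSpec and using a Keras model.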
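Hashing labels is consistent, but it is not collision-free: two different label strings can land in the same class id. This applies to hashing every label (e.g. `tf.string_to_hash_bucket_fast` in TF 1.x) or to the OOV buckets of `index_table_from_tensor` for labels missing from `mapping`. A rough check, under the simplifying assumption that the hash behaves like a uniform random function, is the birthday-problem product P(no collision) = ∏_{k=0}^{n-1} (1 - k/m) for n labels and m buckets. The helper name `prob_no_collision` is hypothetical.

```python
import math


def prob_no_collision(n_labels, n_buckets):
    """Probability that n distinct labels, hashed uniformly at random into
    n_buckets buckets, all land in different buckets (birthday problem)."""
    if n_labels > n_buckets:
        return 0.0  # pigeonhole: a collision is guaranteed
    # Sum logs to avoid underflow from multiplying many small factors.
    log_p = sum(math.log1p(-k / n_buckets) for k in range(n_labels))
    return math.exp(log_p)


# 100 classes hashed into 100 buckets: a collision is essentially certain.
p_small = prob_no_collision(100, 100)
# Even with 10,000 buckets, the chance of no collision is only about 0.61.
p_large = prob_no_collision(100, 10_000)
```

So for a classifier, hashing only helps with labels that were not known in advance; the known classes still need a one-to-one mapping.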
Is there a way to do this?
Many thanks.
Seba