0

我从MNIST 图像数据集编写 tfrecord并将 tfrecord 转换为 tf.data.dataset。运行“python3 tfrecord1.py”是正常的。但是,运行“mpirun -np 2 python3 tfrecord1.py”时发生了 DataLossError

也许我的代码有问题。

我的计算环境:ubuntu 20.04, tensorflow 2.6.0, horovod 0.23, 32 CPUs, No GPU

threcord1.py

import tensorflow as tf
import horovod.tensorflow as hvd
from PIL import Image
import os, glob

hvd.init()

path = os.path.expanduser('~') # home directory
train_path = os.path.join(path, 'mnist_png/training') # mnist_png/training'
images = glob.glob(train_path + '/*/*.png')
tfrecord_filename = '/disfs/mnist.tfrecord'

# ----------- write tfrecord ---------------

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

writer = tf.io.TFRecordWriter(tfrecord_filename)

for image in images:
    img = open(image, 'rb').read()
    img_ = tf.image.decode_png(img)
    img_ = bytes(img)
    label = int (image.split('/')[5])
    feature = {
            'image' : _bytes_feature(img),
            'label' : _int64_feature(label)
            }
    example = tf.train.Example(features = tf.train.Features(feature=feature))
    writer.write(example.SerializeToString())
writer.close()

# --------- read tfrecord ----------
reader = tf.data.TFRecordDataset(tfrecord_filename)
feature_set = {
        'image' : tf.io.FixedLenFeature([], tf.string),
        'label' : tf.io.FixedLenFeature([], tf.int64)
        }

def _parse_function(exam_proto):
    feature_dict = tf.io.parse_single_example(exam_proto, feature_set)
    raw_image = tf.io.decode_jpeg(feature_dict['image'])
    raw_image = tf.image.resize(raw_image, [28,28])/255.0
    raw_image = tf.reshape(raw_image,[28,28,1])
    label = feature_dict['label']
    return (raw_image, label)

dataset = reader.map(_parse_function)
dataset = dataset.repeat().shuffle(len(images)).batch(128)
    
for batch, (images, labels) in enumerate(dataset.take(5)): # DataLossError
    print(batch,'\n',images,'\n',  labels)

在此处输入图像描述

4

0 回答 0