我正在尝试创建一个文件名张量,以便使用 Dataset API 创建一个数据集。我的代码是tf.convert_to_tensor(file_list)
. 我也尝试tf.constant(file_list)
过类似的结果。在这种情况下,file_list
是一个 Python 的一维文件名字符串列表。这将返回一个形状张量,(N, )
其中N
是文件的数量。np.rank
告诉我它是 dtype 的 0 级张量tf.int32
。即使指定 dtype,也会tf.convert_to_tensor(file_list, dtype=tf.string)
产生相同的结果。
Datset.from_tensor_slices
当我使用and call将它传递给 Dataset 对象时dataset.map
,我收到一个 read_file 错误,指出输入的等级是 rank-1,而它应该是 rank-0。在我看来,这似乎意味着转换为张量会以某种方式创建一个奇怪形状的张量,或者Datset.from_tensor_slices
以意想不到的方式读取数据。
代码:
dataset = Dataset.from_tensor_slices(tf.convert_to_tensor(file_list))
dataset = self.dataset.map(_input_parser)
def _input_parser(filename):
filename
image = tf.image.decode_image(tf.read_file(filename))
return image