0

我正在使用 tensorflow 进行测试。我将数据集放入两个文件夹中。我配置了batch_size,heightwidthtrain_data但是我看不到它们matplotlib或在模型中使用它。

#Import dataset
import pathlib
import os

data_dir = pathlib.Path(r'C:\Users\vion1\Ele\Engie\Exercices\DL\Pikachu\dataset')
image_count = len(list(data_dir.glob('*/*')))
print(image_count)
#374

batch_size = 32
img_height = 256
img_width = 256

train_data = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=42,
  image_size=(img_height, img_width),
  batch_size=batch_size,
  )

class_names = train_data.class_names
print(train_data)
#Found 374 files belonging to 2 classes.
#Using 300 files for training.
#<BatchDataset shapes: ((None, 256, 256, 3), (None,)), types: (tf.float32, tf.int32)>

plt.figure(figsize=(10, 10))
for images, labels in train_data.take(1):
  for i in range(3):
    ax = plt.subplot(1, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.axis("off")

错误是:

InvalidArgumentError: Unknown image file format. One of JPEG, PNG, GIF, BMP required.
     [[{{node decode_image/DecodeImage}}]] [Op:IteratorGetNext]

我认为这train_date.take(1)不会占用文件,但我不明白为什么以及如何修复它,知道吗?

4

1 回答 1

1

您提到的代码看起来很正确,失败的主要原因可能是错误是您中的一个或多个文件tf.data.Dataset不属于任何提到的文件扩展名。要检查损坏的文件,您可以参考以下代码。这里我采用文档中提到的示例数据集

import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
import tensorflow as tf

from tensorflow import keras

import pathlib
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)

roses = list(data_dir.glob('roses/*'))

现在,leet 检查 roses 目录中的唯一文件名。

file_names = [str(i) for i in roses]
unique_files = set(i.split('.')[-1] for i in file_names)
print(unique_files)

Output:
{'jpg'}

在输出目录中,如果您获得允许的文件类型以外的任何文件类型,则需要重新检查您的数据。否则,您可以按照文档执行相同的程序。

于 2021-08-25T16:08:09.080 回答