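I can't import the MNIST dataset correctly. Can you help me figure out what is going wrong? "input_data.py" is in the right place and is imported without errors.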

>>> mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

>>> trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AttributeError: 'NoneType' object has no attribute 'train'

>>> print(mnist)
None

1 Answer

The read_data_sets method first downloads the data from http://yann.lecun.com/exdb/mnist/ and then extracts it:

local_file = maybe_download(TRAIN_IMAGES, train_dir)
train_images = extract_images(local_file)

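That part works for you, as the "Extracting" messages show. After that, though, the object that should hold the collection of DataSet objects comes back as None. Since it works for me and I can't reproduce the error, you could run the calls inside the method one at a time and report where it fails. Something like this: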

>>> local_file = input_data.maybe_download('train-labels-idx1-ubyte.gz', 'MNIST_data/')
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
>>> train_labels = input_data.extract_labels(local_file, one_hot=True)
Extracting MNIST_data/train-labels-idx1-ubyte.gz
>>> local_file = input_data.maybe_download('train-images-idx3-ubyte.gz', 'MNIST_data/')
>>> train_images = input_data.extract_images(local_file)
Extracting MNIST_data/train-images-idx3-ubyte.gz
>>> local_file = input_data.maybe_download('t10k-images-idx3-ubyte.gz', 'MNIST_data/')
>>> test_images = input_data.extract_images(local_file)
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
>>> local_file = input_data.maybe_download('t10k-labels-idx1-ubyte.gz', 'MNIST_data/')
>>> test_labels = input_data.extract_labels(local_file,one_hot=True)
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
>>> VALIDATION_SIZE = 5000
>>> validation_images = train_images[:VALIDATION_SIZE]
>>> validation_labels = train_labels[:VALIDATION_SIZE]
>>> train_images = train_images[VALIDATION_SIZE:]
>>> train_labels = train_labels[VALIDATION_SIZE:]

>>> dtype = 'float32'
>>> data_set_train = input_data.DataSet(train_images, train_labels, dtype=dtype)
>>> data_set_validation = input_data.DataSet(validation_images, validation_labels, dtype=dtype)
>>> data_set_test = input_data.DataSet(test_images, test_labels, dtype=dtype)  
>>> trX = data_set_train.images
>>> print(data_set_train)
<tensorflow.examples.tutorials.mnist.input_data.DataSet object at 0x10508ff98>
>>> print(trX)
[[ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 ..., 
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]]
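
If every step above succeeds by hand, it may be worth checking whether the local copy of read_data_sets actually returns the object it builds. This is only a guess at one possible cause, not a diagnosis of your file: a Python function that falls off the end without a return statement hands back None, which would match the symptom. A minimal sketch, in which DataSets, load_without_return and load_with_return are hypothetical stand-ins and not the TensorFlow code:

```python
# Hypothetical stand-ins for a loader such as read_data_sets.
# None of these names come from TensorFlow.


class DataSets(object):
    """Empty container, used only to hold train/test attributes."""
    pass


def load_without_return(train_dir):
    data_sets = DataSets()
    data_sets.train = "train split from " + train_dir
    data_sets.test = "test split from " + train_dir
    # No return statement here, so the caller receives None.


def load_with_return(train_dir):
    data_sets = DataSets()
    data_sets.train = "train split from " + train_dir
    data_sets.test = "test split from " + train_dir
    return data_sets


broken = load_without_return("MNIST_data/")
fixed = load_with_return("MNIST_data/")

print(broken)       # None
print(fixed.train)  # train split from MNIST_data/
```

Accessing broken.train would raise the same AttributeError as in the question, so comparing the end of read_data_sets in your input_data.py against a known-good copy is a cheap check.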
answered 2016-02-12T01:26:28.520