我编写了几个类,用于在 Java 中读取 MNIST 手写数字数据集。这些类读取的是从下载站点获取的文件经过解压缩(unzip)之后得到的文件。能够直接读取原始(压缩)文件的类则包含在小型项目 MnistReader 中。
下面这些类是自包含的(不依赖任何第三方库),并且基本上属于公有领域——也就是说,可以直接复制到你自己的项目中使用(欢迎注明出处,但并非必需):
MnistDecompressedReader 类:
import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Path;
import java.util.Objects;
import java.util.function.Consumer;
/**
 * A class for reading the MNIST data set from the <b>decompressed</b>
 * (unzipped) files that are published at
 * <a href="http://yann.lecun.com/exdb/mnist/">
 * http://yann.lecun.com/exdb/mnist/</a>.
 */
public class MnistDecompressedReader
{
    /**
     * The magic header value that is expected at the start of an
     * images file
     */
    private static final int IMAGES_MAGIC = 0x803;

    /**
     * The magic header value that is expected at the start of a
     * labels file
     */
    private static final int LABELS_MAGIC = 0x801;

    /**
     * Default constructor
     */
    public MnistDecompressedReader()
    {
        // Default constructor
    }

    /**
     * Read the MNIST training data from the given directory. The data is
     * assumed to be located in files with their default names,
     * <b>decompressed</b> from the original files:
     * <code>train-images.idx3-ubyte</code> and
     * <code>train-labels.idx1-ubyte</code>.
     *
     * @param inputDirectoryPath The input directory
     * @param consumer The consumer that will receive the resulting
     * {@link MnistEntry} instances
     * @throws IOException If an IO error occurs
     */
    public void readDecompressedTraining(Path inputDirectoryPath,
        Consumer<? super MnistEntry> consumer) throws IOException
    {
        Path imagesFilePath =
            inputDirectoryPath.resolve("train-images.idx3-ubyte");
        Path labelsFilePath =
            inputDirectoryPath.resolve("train-labels.idx1-ubyte");
        readDecompressed(imagesFilePath, labelsFilePath, consumer);
    }

    /**
     * Read the MNIST test data from the given directory. The data is
     * assumed to be located in files with their default names,
     * <b>decompressed</b> from the original files:
     * <code>t10k-images.idx3-ubyte</code> and
     * <code>t10k-labels.idx1-ubyte</code>.
     *
     * @param inputDirectoryPath The input directory
     * @param consumer The consumer that will receive the resulting
     * {@link MnistEntry} instances
     * @throws IOException If an IO error occurs
     */
    public void readDecompressedTesting(Path inputDirectoryPath,
        Consumer<? super MnistEntry> consumer) throws IOException
    {
        Path imagesFilePath =
            inputDirectoryPath.resolve("t10k-images.idx3-ubyte");
        Path labelsFilePath =
            inputDirectoryPath.resolve("t10k-labels.idx1-ubyte");
        readDecompressed(imagesFilePath, labelsFilePath, consumer);
    }

    /**
     * Read the MNIST data from the specified (decompressed) files.
     *
     * @param imagesFilePath The path of the images file
     * @param labelsFilePath The path of the labels file
     * @param consumer The consumer that will receive the resulting
     * {@link MnistEntry} instances
     * @throws IOException If an IO error occurs
     */
    public void readDecompressed(Path imagesFilePath, Path labelsFilePath,
        Consumer<? super MnistEntry> consumer) throws IOException
    {
        // The labels are read with one readByte() call per entry. Without
        // buffering, each of these calls would go directly to the file.
        try (InputStream decompressedImagesInputStream =
                new BufferedInputStream(
                    new FileInputStream(imagesFilePath.toFile()));
            InputStream decompressedLabelsInputStream =
                new BufferedInputStream(
                    new FileInputStream(labelsFilePath.toFile())))
        {
            readDecompressed(
                decompressedImagesInputStream,
                decompressedLabelsInputStream,
                consumer);
        }
    }

    /**
     * Read the MNIST data from the given (decompressed) input streams.
     * The caller is responsible for closing the given streams.
     *
     * @param decompressedImagesInputStream The decompressed input stream
     * containing the image data
     * @param decompressedLabelsInputStream The decompressed input stream
     * containing the label data
     * @param consumer The consumer that will receive the resulting
     * {@link MnistEntry} instances
     * @throws NullPointerException If the consumer is <code>null</code>
     * @throws IOException If an IO error occurs, if a magic header does not
     * match, if the number of images and labels differ, if the header
     * contains negative or too large sizes, or if a stream ends before all
     * entries have been read
     */
    public void readDecompressed(
        InputStream decompressedImagesInputStream,
        InputStream decompressedLabelsInputStream,
        Consumer<? super MnistEntry> consumer) throws IOException
    {
        Objects.requireNonNull(consumer, "The consumer may not be null");
        DataInputStream imagesDataInputStream =
            new DataInputStream(decompressedImagesInputStream);
        DataInputStream labelsDataInputStream =
            new DataInputStream(decompressedLabelsInputStream);

        checkMagic(imagesDataInputStream.readInt(), IMAGES_MAGIC, "images");
        checkMagic(labelsDataInputStream.readInt(), LABELS_MAGIC, "labels");

        int numberOfImages = imagesDataInputStream.readInt();
        int numberOfLabels = labelsDataInputStream.readInt();
        if (numberOfImages != numberOfLabels)
        {
            throw new IOException("Found " + numberOfImages
                + " images but " + numberOfLabels + " labels");
        }
        if (numberOfImages < 0)
        {
            throw new IOException(
                "Invalid number of entries: " + numberOfImages);
        }

        int numRows = imagesDataInputStream.readInt();
        int numCols = imagesDataInputStream.readInt();
        int imageSize = computeImageSize(numRows, numCols);
        for (int n = 0; n < numberOfImages; n++)
        {
            byte label = labelsDataInputStream.readByte();
            // A fresh array is required for each entry, because the
            // MnistEntry stores a reference to it instead of a copy
            byte[] imageData = new byte[imageSize];
            read(imagesDataInputStream, imageData);
            consumer.accept(
                new MnistEntry(n, label, numRows, numCols, imageData));
        }
    }

    /**
     * Make sure that the given magic header value matches the expected one.
     *
     * @param actual The value that was read from the stream
     * @param expected The expected value
     * @param kind A description of the data, for the error message
     * @throws IOException If the values do not match
     */
    private static void checkMagic(int actual, int expected, String kind)
        throws IOException
    {
        if (actual != expected)
        {
            throw new IOException("Expected magic header of 0x"
                + Integer.toHexString(expected) + " for " + kind
                + ", but found 0x" + Integer.toHexString(actual));
        }
    }

    /**
     * Computes the number of pixels of an image with the given size.
     * The result is validated, so that corrupted headers cannot cause
     * a silent integer overflow or a negative array size.
     *
     * @param numRows The number of rows
     * @param numCols The number of columns
     * @return The number of pixels
     * @throws IOException If either size is negative, or the product
     * does not fit into an <code>int</code>
     */
    private static int computeImageSize(int numRows, int numCols)
        throws IOException
    {
        if (numRows < 0 || numCols < 0)
        {
            throw new IOException(
                "Invalid image size: " + numRows + " x " + numCols);
        }
        long size = (long) numRows * numCols;
        if (size > Integer.MAX_VALUE)
        {
            throw new IOException(
                "Image size too large: " + numRows + " x " + numCols);
        }
        return (int) size;
    }

    /**
     * Read bytes from the given input stream, filling the given array
     *
     * @param inputStream The input stream
     * @param data The array to be filled
     * @throws IOException If the input stream does not contain enough bytes
     * to fill the array, or any other IO error occurs
     */
    private static void read(InputStream inputStream, byte[] data)
        throws IOException
    {
        int offset = 0;
        while (offset < data.length)
        {
            int read = inputStream.read(data, offset, data.length - offset);
            if (read < 0)
            {
                throw new IOException("Tried to read " + data.length
                    + " bytes, but only found " + offset);
            }
            offset += read;
        }
    }
}
MnistEntry 类:
import java.awt.image.BufferedImage;
import java.awt.image.DataBuffer;
import java.awt.image.DataBufferByte;
/**
 * An entry of the MNIST data set. Instances of this class will be passed
 * to the consumer that is given to the {@link MnistCompressedReader} and
 * {@link MnistDecompressedReader} reading methods.
 */
public class MnistEntry
{
    /**
     * The index of the entry
     */
    private final int index;

    /**
     * The class label of the entry
     */
    private final byte label;

    /**
     * The number of rows of the image data
     */
    private final int numRows;

    /**
     * The number of columns of the image data
     */
    private final int numCols;

    /**
     * The image data, with the pixels stored row by row
     */
    private final byte[] imageData;

    /**
     * Creates a new entry. The given image data array is stored
     * as a reference. It is not copied.
     *
     * @param index The index
     * @param label The label
     * @param numRows The number of rows
     * @param numCols The number of columns
     * @param imageData The image data, row by row, with at least
     * <code>numRows * numCols</code> elements
     */
    MnistEntry(int index, byte label, int numRows, int numCols,
        byte[] imageData)
    {
        this.index = index;
        this.label = label;
        this.numRows = numRows;
        this.numCols = numCols;
        this.imageData = imageData;
    }

    /**
     * Returns the index of the entry
     *
     * @return The index
     */
    public int getIndex()
    {
        return index;
    }

    /**
     * Returns the class label of the entry. This is a value in [0,9],
     * indicating which digit is shown in the entry
     *
     * @return The class label
     */
    public byte getLabel()
    {
        return label;
    }

    /**
     * Returns the number of rows of the image data.
     * This will usually be 28.
     *
     * @return The number of rows
     */
    public int getNumRows()
    {
        return numRows;
    }

    /**
     * Returns the number of columns of the image data.
     * This will usually be 28.
     *
     * @return The number of columns
     */
    public int getNumCols()
    {
        return numCols;
    }

    /**
     * Returns a <i>reference</i> to the image data. This will be an array
     * of length <code>numRows * numCols</code>, with the pixels stored row
     * by row. The pixel brightness values are in [0,255]. Because Java
     * bytes are signed, values above 127 appear as negative numbers.
     * Use <code>(b &amp; 0xFF)</code> to obtain the unsigned brightness
     * value of a byte <code>b</code>.
     *
     * @return The image data
     */
    public byte[] getImageData()
    {
        return imageData;
    }

    /**
     * Creates a new buffered image from the image data that is stored
     * in this entry. The result is an image of type
     * <code>BufferedImage.TYPE_BYTE_GRAY</code> with a width of
     * <code>numCols</code> and a height of <code>numRows</code>.
     *
     * @return The image
     */
    public BufferedImage createImage()
    {
        BufferedImage image = new BufferedImage(
            numCols, numRows, BufferedImage.TYPE_BYTE_GRAY);
        // The pixels are written through the raster instead of into the
        // array returned by DataBufferByte#getData(). That method is
        // documented to possibly make the buffer incompatible with
        // performance optimizations, such as caching the image in video
        // memory.
        image.getRaster().setDataElements(0, 0, numCols, numRows, imageData);
        return image;
    }

    @Override
    public String toString()
    {
        String indexString = String.format("%05d", index);
        return "MnistEntry["
            + "index=" + indexString + ","
            + "label=" + label + "]";
    }
}
可以使用该读取器来读取解压后的文件。读取结果以 MnistEntry 实例的形式传递给消费者(consumer):
// Example: read the MNIST training data from the "./data" directory. That
// directory must contain the decompressed files train-images.idx3-ubyte
// and train-labels.idx1-ubyte. readDecompressedTraining throws IOException,
// so this code must be placed where that exception is caught or declared.
// It also needs imports of java.nio.file.Paths and
// java.awt.image.BufferedImage.
MnistDecompressedReader mnistReader = new MnistDecompressedReader();
mnistReader.readDecompressedTraining(Paths.get("./data"), mnistEntry ->
{
System.out.println("Read entry " + mnistEntry);
BufferedImage image = mnistEntry.createImage();
// Placeholder: further processing of the entry/image goes here
...
});
MnistReader 项目中包含若干示例,演示如何使用这些类读取压缩或解压后的数据,以及如何根据 MNIST 条目生成 PNG 图像。