0

我有一个非常大的数据集,不适合内存我将它拆分为文件,我想在数据生成器中使用它们进行训练我使用以下代码

def csv_image_generator(i,inputPath1, bs, lb, mode="train", aug=None):

    # open the CSV file for reading
    # loop indefinitely
    while True:
        f = open('mnist_1D_train_'+str(i)+'.csv', "r")

        # initialize our batches of images and labels
        print(i)
        print('mnist_1D_train_'+str(i)+'.csv')
        images = []
        labels = []
        # keep looping until we reach our batch size
        while len(images) < bs:
            # attempt to read the next line of the CSV file
            line = f.readline()
            # check to see if the line is empty, indicating we have
            # reached the end of the file
            if line == "":
                # reset the file pointer to the beginning of the file
                # and re-read the line
                f.seek(0)
                line = f.readline()

                # if we are evaluating we should now break from our
                # loop to ensure we don't continue to fill up the
                # batch from samples at the beginning of the file
#               if mode == "eval":
#                   break
            # extract the label and construct the image
            line = line.strip().split(",")
            label = line[0]
            image = np.array([float(x) for x in line[1:]], dtype="float")
            image = image.reshape((1, 28, 28))
            image = image.T

            # update our corresponding batches lists
            images.append(image)
            labels.append(label)
        # one-hot encode the labels
        labels = lb.transform(np.array(labels))
        # if the data augmentation object is not None, apply it
        if aug is not None:
            (images, labels) = next(aug.flow(np.array(images),
                labels, batch_size=bs))
        # yield the batch to the calling function
        yield (np.array(images), labels)

为 one-hot 编码标签创建标签二值化器,然后对测试标签进行编码 构建用于数据增强的训练图像生成器,初始化训练和测试图像生成器

lb = LabelBinarizer()
lb.fit(list(labels))
testLabels = lb.transform(testLabels)

trainGen = csv_image_generator(TRAIN_CSV, BS, lb,   mode="train", aug=aug)
testGen  = csv_image_generator_test(TEST_CSV,  BS, lb,  mode="train", aug=None)

然后我用

H = model.fit_generator(
    trainGen,
    steps_per_epoch=NUM_TRAIN_IMAGES // (BS*2),
    validation_data=testGen,
    validation_steps=NUM_TEST_IMAGES // (BS*2),
    epochs=NUM_EPOCHS)

但 fit_generator 只读取第一个文件

4

0 回答 0