1

我正在使用以下代码从大约 4000 个图像中提取特征,这些图像分为 30 个类别。

 for i, label in enumerate(train_labels):
        cur_path = train_path + "/" + label
        count = 1
        for image_path in glob.glob(cur_path + "/*.jpg"):
            img = image.load_img(image_path, target_size=image_size)
            x = image.img_to_array(img)
            x = np.expand_dims(x, axis=0)
            x = preprocess_input(x)
            feature = model.predict(x)
            flat = feature.flatten()
            features.append(flat)
            labels.append(label)
            print ("[INFO] processed - " + str(count))
        count += 1
    print ("[INFO] completed label - " + label)

虽然,我的整个数据集要大得多,多达 80,000 张图像。查看我的 GPU 内存时,这段代码在 Keras (2.1.2) 中适用于 4000 张图像,但几乎占用了我的 Tesla K80 的所有 5gig 视频 RAM。我想知道是否可以通过更改 batch_size 来提高性能,或者这段代码的工作方式对我的 GPU 来说太重了,我应该重写它吗?

谢谢!

4

1 回答 1

1

有两种可能的解决方案。

1)我假设您以 Numpy 数组格式存储图像。这是非常占用内存的。相反,将其存储为普通列表。当应用程序需要将其转换为 numpy 数组时。就我而言,它将内存消耗减少了 10 倍。如果您已经将其存储为列表,那么 2 解决方案可能会解决您的问题。

2)将结果存储在块中,并在将其输入另一个模型时使用生成器。

chunk_of_features=[]
chunk_of_labels=[]
i=0
for i, label in enumerate(train_labels):
        cur_path = train_path + "/" + label
        count = 1
        for image_path in glob.glob(cur_path + "/*.jpg"):
            i+=1
            img = image.load_img(image_path, target_size=image_size)
            x = image.img_to_array(img)
            x = np.expand_dims(x, axis=0)
            x = preprocess_input(x)
            feature = model.predict(x)
            flat = feature.flatten()
            chunk_of_features.append(flat)
            chunk_of_labels.append(label)
            if i==4000:
                with open('useSomeCountertoPreventNameConflict','wb') as output_file:
                    pickle.dump(chunk_of_features,output_file)
                with open('useSomeCountertoPreventNameConflict','wb') as output_file:
                    pickle.dump(chunk_of_labels,output_file)
                chunk_of_features=[]
                chunk_of_labels=[]
                i=0

            print ("[INFO] processed - " + str(count))
        count += 1
    print ("[INFO] completed label - " + label)
于 2018-07-16T11:18:30.733 回答