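I am working on an image classification task for diabetic retinopathy using fundus images. There are 5 classes. The data distribution is 1805 images (class 1), 370 images (class 2), 999 images (class 3), 193 images (class 4), and 295 images (class 5). Here are the steps I ran:
- Preprocessing (resize to 224 * 224)
- Split into training and test data at a ratio of 85% : 15%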
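from sklearn.model_selection import train_test_split

# Stratified split so the test set keeps the class proportions
x_train, xtest, y_train, ytest = train_test_split(
    x_train, y_train,
    test_size=0.15,
    random_state=SEED,
    stratify=y_train
)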
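- Data augmentation
from tensorflow.keras.preprocessing.image import ImageDataGenerator

def create_datagen():
    # Random zoom and flips; pixels exposed by zooming out are filled with black
    return ImageDataGenerator(
        zoom_range=0.15,
        fill_mode='constant',
        cval=0.,
        horizontal_flip=True,
        vertical_flip=True,
    )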
- Training with a ResNet-50 model and cross-validation
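from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model

def getResNet():
    # weights=None: the backbone is trained from scratch, without ImageNet weights
    modelres = ResNet50(weights=None, include_top=False, input_shape=(IMAGE_HEIGHT, IMAGE_HEIGHT, 3))
    x = modelres.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(5, activation='softmax')(x)
    model = Model(inputs=modelres.input, outputs=x)
    return model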
import tensorflow as tf
from sklearn.model_selection import StratifiedKFold
from tensorflow.keras.optimizers import Adam

num_folds = 5
skf = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=2021)
cvscores = []
fold = 1
# y_train is one-hot encoded, so argmax(1) recovers integer labels for stratification
for train, val in skf.split(x_train, y_train.argmax(1)):
    print('Fold: ', fold)
    Xtrain = x_train[train]
    Xval = x_train[val]
    Ytrain = y_train[train]
    Yval = y_train[val]
    data_generator = create_datagen().flow(Xtrain, Ytrain, batch_size=32, seed=2021)
    # A fresh model is built for every fold
    model = getResNet()
    model.compile(loss='categorical_crossentropy',
                  optimizer=Adam(learning_rate=0.0001),
                  metrics=['accuracy'])
    with tf.device('/device:GPU:0'):
        # The batch size is already set by the generator, so it is not passed to fit()
        model_train = model.fit(data_generator,
                                validation_data=(Xval, Yval),
                                epochs=30, verbose=1)
    model_name = 'cnn_keras_aug_Fold_' + str(fold) + '.h5'
    model.save(model_name)
    # Each fold's model is scored on the same held-out test set
    scores = model.evaluate(xtest, ytest, verbose=0)
    print("%s: %.2f%%" % (model.metrics_names[1], scores[1] * 100))
    cvscores.append(scores[1] * 100)
    fold = fold + 1
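The loop above collects one test accuracy per fold in cvscores but never reports them together. Below is a minimal sketch of how the per-fold scores could be summarized, assuming cvscores holds percentages as in the loop. The helper name summarize_scores and the numbers in example_scores are placeholders for illustration, not my actual results.

```python
from statistics import mean, stdev

def summarize_scores(scores):
    """Return (mean, sample standard deviation) of per-fold accuracies in percent."""
    if not scores:
        raise ValueError("scores must not be empty")
    # stdev() needs at least two values, so a single fold reports zero spread
    spread = stdev(scores) if len(scores) > 1 else 0.0
    return mean(scores), spread

# Placeholder values for illustration only, not real results
example_scores = [70.0, 71.0, 69.0, 72.0, 68.0]
avg, spread = summarize_scores(example_scores)
print("%.2f%% (+/- %.2f%%)" % (avg, spread))
```

With the real run, summarize_scores(cvscores) would be called once after the loop finishes.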
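The best result I got from this approach is a training accuracy of 81.2%, a validation accuracy of 72.2%, and a test accuracy of 70.73%. Can anyone give me ideas for improving the model so that I can raise the test accuracy above 90%? Later, I plan to use this model as a pretrained model for training on diabetic retinopathy data from other sources.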
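To put these accuracies in context, here is a quick sketch of the arithmetic behind the class imbalance, using only the image counts stated at the top. It computes each class's share, the accuracy a classifier would get by always predicting the largest class, and inverse-frequency class weights. The weight formula total / (n_classes * count) is one common heuristic and an assumption here, not something used in the runs above.

```python
# Image counts per class, as listed at the top of the post
class_counts = {1: 1805, 2: 370, 3: 999, 4: 193, 5: 295}

total = sum(class_counts.values())  # 3662 images in all
n_classes = len(class_counts)

# Fraction of the dataset that belongs to each class
shares = {c: n / total for c, n in class_counts.items()}

# Accuracy of a trivial model that always predicts the largest class
majority_baseline = max(shares.values())

# Inverse-frequency weights: rarer classes get proportionally larger weights
class_weights = {c: total / (n_classes * n) for c, n in class_counts.items()}

for c in sorted(class_counts):
    print("class %d: share %.1f%%, weight %.2f" % (c, shares[c] * 100, class_weights[c]))
print("majority-class baseline: %.1f%%" % (majority_baseline * 100))
```

If such weights were used, Keras accepts a dictionary like this through the class_weight argument of model.fit, with keys shifted to the 0-based label indices.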
By the way, I also tried replacing my preprocessing with this approach:
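import cv2

def preprocessing(path):
    image = cv2.imread(path)  # OpenCV loads images in BGR order
    # crop_image_from_gray is defined elsewhere in my code (not shown here)
    image = crop_image_from_gray(image)
    # Apply CLAHE contrast enhancement to the green channel only
    green = image[:, :, 1]
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    cl = clahe.apply(green)
    # Blue and red channels are left unchanged
    image[:, :, 1] = cl
    image = cv2.resize(image, (224, 224))
    return image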
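I also tried replacing my model with VGG16 and EfficientNetB0. However, none of these changes had much effect on my results. I am still stuck at around 70% accuracy. I would appreciate any ideas that could help me improve my modeling results.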