1

你好吗?我正在尝试使用 sklearn.svm 在 keras 成本函数中实现 SVM。但是,我总是遇到错误。我相信问题是将 y_true 和 y_pred 张量转换为要在 sklearn.svm 中使用的 numpy 数组。然后我需要将预测结果转换为张量以用于 keras 的成本函数(categorical_hinge)。

有谁能够帮我?

model_input = Input(shape = (img_width, img_height, channel_axis))
x = Convolution2D_bn(model_input, 32, 3, 3, strides=(2, 2), padding='valid')
x = Convolution2D_bn(x, 32, 3, 3, padding='valid')
x = Convolution2D_bn(x, 64, 3, 3)
x = MaxPooling2D((3, 3), strides=(2, 2))(x)

x = Convolution2D_bn(x, 80, 1, 1, padding='valid')
x = Convolution2D_bn(x, 192, 3, 3, padding='valid')
more model
# Classification block
x = GlobalAveragePooling2D()(x)
x = Dense(4096, kernel_regularizer=l2(1e-4), name='Dense_1')(x)
x = Activation('relu', name='relu1')(x)
x = Dropout(DROPOUT)(x)
x = Dense(4096, kernel_regularizer=l2(1e-4), name='Dense_2')(x)
x = Activation('relu', name='relu2')(x)
model_output = Dropout(DROPOUT)(x)
model = Model(model_input, model_output)
model.summary()

import tensorflow as tf
from keras import backend as K
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import StandardScaler
from keras.losses import categorical_hinge
def custom_loss_value(y_true, y_pred):
    X = K.eval(y_pred)
    print(X)
    Y = np.ravel(K.eval(y_true))
    Predict = []
    Prob = []
    scaler = StandardScaler()
    X = scaler.fit_transform(X)
    param_grid = {'C': [0.1, 1, 8, 10], 'gamma': [0.001, 0.01, 0.1, 1]}
    SVM = GridSearchCV(SVC(kernel='rbf',probability=True), cv=3, param_grid=param_grid, scoring='auc', verbose=1)
    SVM.fit(X, Y)
    Final_Model = SVM.best_estimator_
    Predict = Final_Model.predict(X)
    Prob = Final_Model.predict_proba(X)
    return categorical_hinge(tf.convert_to_tensor(Y, dtype=tf.float32), tf.convert_to_tensor(Predict, dtype=tf.float32))

sgd = SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True) 
model.compile(loss=custom_loss_value, optimizer=sgd, metrics=['accuracy'])
4

1 回答 1

0

尝试这个

y_test = np.argmax(y_test , axis=1)
y_pred = np.argmax(y_pred , axis=1)
于 2020-05-30T13:26:21.730 回答