0

我在使用 Keras Tuner 为这个 Keras 函数式 API（Functional API）模型寻找最佳的层数、每层神经元数量和学习率时遇到了困难，不知道该如何实现。

我的代码:

from sklearn.model_selection import train_test_split
import numpy as np
from keras.utils.vis_utils import plot_model
from keras.models import Model
from keras.layers import Input
from keras.layers import Dense
import pandas as pb
from keras.layers.merge import concatenate


# Load the raw measurements and show the first 30 rows for a quick sanity check.
df = pb.read_csv('data3.csv')
training_data = df  # alias of the full table; the train/test split happens below
print(training_data.values[0:30])

# Column positions in data3.csv, per the original script's comments
# (NOTE(review): confirm against the CSV header):
#   features: 5 = output temperature, 4 = output flow rate
#   targets:  0 = cold temperature, 1 = hot temperature,
#             2 = cold flow rate,   3 = hot flow rate
# The values are used as-is: no min-max (or any other) scaling is applied.
INPUT_COLUMNS = [5, 4]
OUTPUT_COLUMNS = [0, 1, 2, 3]

# One 1-D array per model input/output, each keeping its own column's dtype.
X1, X2 = (df.iloc[:, col].to_numpy() for col in INPUT_COLUMNS)
Y1, Y2, Y3, Y4 = (df.iloc[:, col].to_numpy() for col in OUTPUT_COLUMNS)

# Feature / target matrices with rows aligned to df:
# inputs is (n, 2) [temp, flow], outputs is (n, 4) in OUTPUT_COLUMNS order.
inputs = df.iloc[:, INPUT_COLUMNS].to_numpy()
outputs = df.iloc[:, OUTPUT_COLUMNS].to_numpy()

# Hold out 20% of the rows for testing; random_state fixes the shuffle so
# repeated runs produce the same split.
inputs, inputs_test, outputs, outputs_test = train_test_split(
    inputs, outputs, test_size=0.2, random_state=0)

def build_model(hp=None):
    """Build and compile the two-input / four-output regression network.

    Usable directly as a Keras Tuner model-building function: pass the
    tuner's ``HyperParameters`` object as ``hp`` to search over the number
    of hidden layers, the units per layer and the Adam learning rate, e.g.
    ``keras_tuner.RandomSearch(build_model, objective='val_loss', ...)``.
    With ``hp=None`` the original fixed architecture is built: ReLU hidden
    layers of 12 and 8 units, Adam with learning rate 1e-3.

    Args:
        hp: Optional hyperparameter container exposing
            ``Int(name, min_value, max_value, step=...)`` and
            ``Choice(name, values)`` (as Keras Tuner's ``HyperParameters``
            does), or ``None`` to use the defaults above.

    Returns:
        A compiled ``Model`` with inputs ``output_temp`` / ``output_flow``
        and outputs ``cold_temp``, ``hot_temp``, ``cold_flow``,
        ``hot_flow``, each with a mean-squared-error loss.
    """
    from keras.optimizers import Adam  # only needed here

    if hp is None:
        hidden_units = [12, 8]
        learning_rate = 1e-3  # same as Adam's default, i.e. optimizer='adam'
    else:
        num_layers = hp.Int('num_layers', min_value=1, max_value=4)
        hidden_units = [
            hp.Int(f'units_{i}', min_value=4, max_value=64, step=4)
            for i in range(num_layers)
        ]
        learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])

    # Each input is a single scalar feature. The tuple form of `shape` is the
    # documented argument; newer Keras versions reject a bare int here.
    temp_in = Input(shape=(1,), name='output_temp')
    flow_in = Input(shape=(1,), name='output_flow')
    x = concatenate([temp_in, flow_in], axis=-1)
    for units in hidden_units:
        x = Dense(units, activation='relu')(x)

    # Four linear regression heads sharing the same hidden trunk.
    output_names = ['cold_temp', 'hot_temp', 'cold_flow', 'hot_flow']
    heads = [Dense(1, activation='linear', name=name)(x) for name in output_names]

    net = Model([temp_in, flow_in], heads)
    net.compile(optimizer=Adam(learning_rate=learning_rate),
                loss={name: 'mse' for name in output_names})
    return net


model = build_model()

def _feed(array, names):
    """Map column ``i`` of the 2-D ``array`` to the model input/output ``names[i]``."""
    return {name: array[:, i] for i, name in enumerate(names)}


# Column order matches how `inputs` / `outputs` were assembled before the split:
# inputs = [output temperature, output flow rate],
# outputs = [cold temp, hot temp, cold flow, hot flow].
FEED_INPUT_NAMES = ['output_temp', 'output_flow']
FEED_OUTPUT_NAMES = ['cold_temp', 'hot_temp', 'cold_flow', 'hot_flow']

# Train on the 80% training split only, and report loss on the held-out 20%.
# Feeding the full, unsplit column arrays would leak the test rows into
# training and leave no honest validation signal (which a tuner also needs).
history = model.fit(
    _feed(inputs, FEED_INPUT_NAMES),
    _feed(outputs, FEED_OUTPUT_NAMES),
    validation_data=(_feed(inputs_test, FEED_INPUT_NAMES),
                     _feed(outputs_test, FEED_OUTPUT_NAMES)),
    epochs=10,
    batch_size=1,
)


# summary() prints the layer table itself and returns None, so it is called
# directly rather than wrapped in print(), which would emit a stray "None".
model.summary()
# Render the architecture diagram to disk, then report where it was written.
# Printing `plot_model` itself would only show the function object's repr.
plot_model(model, to_file='multiple_inputs.png')
print('Model diagram written to multiple_inputs.png')
4

0 回答 0