Here is an equivalent subclassed implementation, although I haven't tested it.
import tensorflow as tf
# your config
config = {
    'learning_rate': 0.001,
    'lstm_neurons': 32,
    'lstm_activation': 'tanh',
    'dropout_rate': 0.08,
    'batch_size': 128,
    'dense_layers': [
        {'neurons': 32, 'activation': 'relu'},
        {'neurons': 32, 'activation': 'relu'},
    ]
}
# Subclassed API Model
class MySubClassed(tf.keras.Model):
    def __init__(self, output_size):
        super(MySubClassed, self).__init__()
        self.lstm = tf.keras.layers.LSTM(config['lstm_neurons'],
                                         activation=config['lstm_activation'])
        self.bn = tf.keras.layers.BatchNormalization()

        if 'dropout_rate' in config:
            self.dp1 = tf.keras.layers.Dropout(config['dropout_rate'])
            self.dp2 = tf.keras.layers.Dropout(config['dropout_rate'])
            self.dp3 = tf.keras.layers.Dropout(config['dropout_rate'])

        # One Dense + BatchNorm block per entry in config['dense_layers'].
        self.dense1 = tf.keras.layers.Dense(config['dense_layers'][0]['neurons'],
                                            activation=config['dense_layers'][0]['activation'])
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.dense2 = tf.keras.layers.Dense(config['dense_layers'][1]['neurons'],
                                            activation=config['dense_layers'][1]['activation'])
        self.bn2 = tf.keras.layers.BatchNormalization()

        self.out = tf.keras.layers.Dense(output_size,
                                         activation='sigmoid')
    def call(self, inputs, training=False, **kwargs):
        x = self.lstm(inputs)
        x = self.bn(x, training=training)
        if 'dropout_rate' in config:
            x = self.dp1(x, training=training)

        x = self.dense1(x)
        x = self.bn1(x, training=training)
        if 'dropout_rate' in config:
            x = self.dp2(x, training=training)

        x = self.dense2(x)
        x = self.bn2(x, training=training)
        if 'dropout_rate' in config:
            x = self.dp3(x, training=training)

        return self.out(x)
    # A convenient way to get a model summary
    # and plot with the subclassed API.
    def build_graph(self, raw_shape):
        x = tf.keras.layers.Input(shape=(None, raw_shape),
                                  ragged=True)
        return tf.keras.Model(inputs=[x],
                              outputs=self.call(x))
Build and compile the model.
s = MySubClassed(output_size=1)
s.compile(
    loss='mse',
    metrics=['mse'],
    optimizer=tf.keras.optimizers.Adam(learning_rate=config['learning_rate']))
Pass some tensors through the model to create the weights (sanity check). A subclassed model has no defined input shape until it is called on data, which is also why build_graph is needed for summary() and plot_model.
raw_input = (16, 16, 16)   # (batch, timesteps, features)
y = s(tf.ones(shape=raw_input))

print("weights:", len(s.weights))
print("trainable weights:", len(s.trainable_weights))
weights: 21
trainable weights: 15
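These counts come from the tracked sub-layers: the LSTM holds 3 weight tensors (kernel, recurrent kernel, bias), each Dense layer 2, and each BatchNormalization 4, of which only gamma and beta are trainable; Dropout layers hold none. A quick per-layer check:

# Per-layer weight-tensor counts (verification sketch).
for layer in s.layers:
    print(layer.name,
          "-> weights:", len(layer.weights),
          "trainable:", len(layer.trainable_weights))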
Summary and plot
Summarize and visualize the model graph.
s.build_graph(16).summary()
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, None, 16)] 0
_________________________________________________________________
lstm (LSTM) (None, 32) 6272
_________________________________________________________________
batch_normalization (BatchNo (None, 32) 128
_________________________________________________________________
dropout (Dropout) (None, 32) 0
_________________________________________________________________
dense_2 (Dense) (None, 32) 1056
_________________________________________________________________
batch_normalization_3 (Batch (None, 32) 128
_________________________________________________________________
dropout_1 (Dropout) (None, 32) 0
_________________________________________________________________
dense_3 (Dense) (None, 32) 1056
_________________________________________________________________
batch_normalization_4 (Batch (None, 32) 128
_________________________________________________________________
dropout_2 (Dropout) (None, 32) 0
_________________________________________________________________
dense_4 (Dense) (None, 1) 33
=================================================================
Total params: 8,801
Trainable params: 8,609
Non-trainable params: 192
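The numbers in the summary can be reproduced by hand: the LSTM uses 4 weight matrices over the concatenated input and state, each hidden Dense layer maps 32 features to 32, and each BatchNormalization carries 4 parameters per feature (half of them the non-trainable moving statistics, which is where the 192 non-trainable parameters come from). A quick arithmetic check, assuming the 16-feature input used above:

# Sanity check of the parameter counts shown in the summary.
lstm_params  = 4 * ((16 + 32) * 32 + 32)   # 6272
bn_params    = 4 * 32                      # 128 per BatchNormalization (64 non-trainable)
dense_params = 32 * 32 + 32                # 1056 per hidden Dense layer
out_params   = 32 * 1 + 1                  # 33
print(lstm_params + 3 * bn_params + 2 * dense_params + out_params)  # 8801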
tf.keras.utils.plot_model(
    s.build_graph(16),
    show_shapes=True,
    show_dtype=True,
    show_layer_names=True,
    rankdir="TB",
)
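Finally, because build_graph declares a ragged input, the model can also be trained directly on a tf.RaggedTensor of variable-length sequences. Below is a minimal, untested sketch; the sequence lengths, labels, and names (ragged_x, dummy_y) are made up purely for illustration:

import numpy as np

# Three toy sequences of different lengths, each timestep with 16 features
# (illustrative data only, not from the original question).
ragged_x = tf.ragged.constant(
    [np.random.rand(5, 16).tolist(),
     np.random.rand(9, 16).tolist(),
     np.random.rand(3, 16).tolist()],
    ragged_rank=1)
dummy_y = tf.constant([[0.], [1.], [0.]])

s.fit(ragged_x, dummy_y, epochs=1)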