I ran some code in Google Colab. My model `MyModel()` inherits from `tf.keras.Model`. The model class and some helper functions are not shown here because they are too long.
```python
import tensorflow as tf

# MyModel and the input Ein are defined elsewhere (not shown)
save_model_path = './models'          # path to save trained model
save_mat_folder = './results'         # path to save reconstruction examples
log_path = './tensorboard_log'        # path to log training process
load_model_path = save_model_path
model = MyModel()
summary_writer = tf.summary.create_file_writer(log_path)
tf.summary.trace_on(graph=True, profiler=False)
variables = [model.phi1, model.phi2]  # write variables in a list
# define optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
for i in tf.range(50):
    # print(i)
    # below for TF 1.x:
    # loss,summary,_=sess.run([L,merged,train_op],feed_dict) #run(fetches, feed_dict=None, options=None, run_metadata=None)
    # model1_writer.add_summary(summary,global_step = i)
    # below for TF2.x:
    with tf.GradientTape() as tape:
        # loss function
        loss = model.call(Ein)
    # The tape is automatically erased immediately after you call its gradient() method
    grads = tape.gradient(loss, variables)  ## auto-differentiation, powerful !!
    # TensorFlow will update parameters automatically
    optimizer.apply_gradients(grads_and_vars=zip(grads, variables))
    # train_op = optimizer.minimize(L) # calculates gradients automatically
    with summary_writer.as_default():
        tf.summary.scalar('loss', loss, step=tf.cast(i, tf.int64))
    if i % 10 == 0:
        print(loss)
# export trace
with summary_writer.as_default():
    tf.summary.trace_export(name='model_trace', step=0)  #, profiler_outdir = log_path)
tf.saved_model.save(model, save_model_path)
# save_path=saver.save(sess,save_model_path)
```
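Because `MyModel` and `Ein` are not shown, the loop above cannot be run as posted. The following is a minimal sketch with a hypothetical stand-in: the class body, the toy loss, and the values of `Ein` are all assumptions, chosen only so the same `GradientTape` / `apply_gradients` pattern over an explicit variable list can be executed end to end. It keeps the direct `model.call(Ein)` from the question.

```python
import tensorflow as tf


class MyModel(tf.keras.Model):
    """Hypothetical stand-in for the unshown MyModel: two trainable scalars."""

    def __init__(self):
        super().__init__()
        # Assumption: phi1/phi2 are trainable scalars, as the variable list suggests.
        self.phi1 = self.add_weight(name="phi1", shape=(), initializer="ones")
        self.phi2 = self.add_weight(name="phi2", shape=(), initializer="ones")

    def call(self, inputs):
        # Toy scalar loss standing in for whatever the real model computes.
        return tf.reduce_mean(tf.square(self.phi1 * inputs + self.phi2))


Ein = tf.linspace(-1.0, 1.0, 16)  # hypothetical fixed input
model = MyModel()
variables = [model.phi1, model.phi2]
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

losses = []
for i in range(50):
    with tf.GradientTape() as tape:
        loss = model.call(Ein)  # direct .call(), exactly as in the question
    grads = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(grads, variables))
    losses.append(float(loss))

print(losses[0], losses[-1])
```

With this toy loss both gradients point the same way on every step, so the recorded loss should shrink over the 50 Adam steps; the point is the loop structure, not the numbers.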
The code appears to work, but it produces unexpected warnings. Can anyone tell me where these warnings come from?
Here is the output:

```
tf.Tensor(-8.2480165e-06, shape=(), dtype=float32)
tf.Tensor(-8.653108e-06, shape=(), dtype=float32)
tf.Tensor(-9.343687e-06, shape=(), dtype=float32)
tf.Tensor(-1.0216764e-05, shape=(), dtype=float32)
tf.Tensor(-1.1233077e-05, shape=(), dtype=float32)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.MyModel object at 0x7fea4a9e9e48>, because it is not built.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: ./models/assets
```
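The first warning says the model "is not built". In `tf.keras`, a layer's `built` flag is set inside `Layer.__call__`, that is, when the model is invoked as `model(x)`; calling `model.call(x)` directly, as the training loop does, runs the forward pass without going through that step. The sketch below uses a hypothetical `TinyModel` and assumes TF 2.x, and lets you inspect the flag both ways.

```python
import tensorflow as tf


class TinyModel(tf.keras.Model):
    """Hypothetical model used only to inspect the `built` flag."""

    def __init__(self):
        super().__init__()
        self.phi = self.add_weight(name="phi", shape=(), initializer="ones")

    def call(self, inputs):
        return tf.reduce_sum(self.phi * inputs)


x = tf.ones([4])

model = TinyModel()
model.call(x)            # runs the forward pass but bypasses Layer.__call__
built_after_call_method = model.built

model(x)                 # goes through Layer.__call__, which builds the model
built_after_dunder_call = model.built

print(built_after_call_method, built_after_dunder_call)
```

Comparing the two printed flags on your TensorFlow version shows whether the direct `.call()` in the loop leaves the model unbuilt at the time `tf.saved_model.save` runs.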