I implemented a function approximator in TensorFlow 2 (it works). However, I hadn't realized that I was supposed to implement it in TensorFlow 1.14. So now I'm at a crossroads, because the two versions are completely different and I can't find a tutorial for 1.x that is as detailed as the official TensorFlow 2 ones.
I was wondering whether someone could help me translate this very simple example and/or point me toward a good, detailed tutorial.
'''
import numpy as np
import tensorflow as tf

from algo import ValueFunctionWithApproximation


class ValueFunctionWithNN(ValueFunctionWithApproximation):
    def __init__(self, state_dims):
        """
        state_dims: the number of dimensions of state space
        """
        self.model = tf.keras.models.Sequential(
            [
                tf.keras.layers.Dense(32, input_shape=(state_dims,), activation='relu'),
                tf.keras.layers.Dense(32, activation='relu'),
                tf.keras.layers.Dense(1)
            ]
        )
        self.flag_first_iter = True

    def __call__(self, s) -> float:
        s = self.state_formatter(s)  # my input is a numpy array
        return float(self.v_hat(s))

    def update(self, alpha, G, s_tau):
        s_tau = self.state_formatter(s_tau)  # my input is a numpy array
        if self.flag_first_iter:
            self.flag_first_iter = False
            # Instantiate an optimizer.
            self.optimizer = tf.keras.optimizers.Adam(
                learning_rate=alpha,
                beta_1=0.9,
                beta_2=0.999,
                epsilon=1e-07,
                amsgrad=False,
                name="Adam"
            )
            # Instantiate a loss function.
            self.loss_fn = tf.keras.losses.MeanSquaredError()
        self.trainANN(G, s_tau)
        return None

    @tf.function
    def v_hat(self, s):
        return self.model(s)

    @tf.function
    def trainANN(self, G, s_tau):
        # Open a GradientTape to record the operations run during the
        # forward pass, which enables auto-differentiation.
        with tf.GradientTape() as tape:
            # Run the forward pass; the operations the model applies to its
            # inputs are recorded on the GradientTape.
            v_hat = self.model(s_tau, training=True)  # estimate of the value function
            # Compute the loss value.
            loss_value = self.loss_fn(G, v_hat)
            loss_value /= 2
        # Use the gradient tape to automatically retrieve the gradients of
        # the trainable variables with respect to the loss.
        grads = tape.gradient(loss_value, self.model.trainable_weights)
        # Run one step of gradient descent by updating the value of the
        # variables to minimize the loss.
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
        return None

    def state_formatter(self, s):
        return s.reshape((-1, len(s)))
'''
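For context, here is a rough sketch of what I think the TensorFlow 1.14 equivalent would look like: build the graph once with placeholders in `__init__`, then run the forward pass and the training step through a `tf.Session`. The 32-32-1 architecture and the MSE/2 loss are copied from my TF2 code above; the `alpha` placeholder is just my guess at how to keep the step size as an argument of `update`. I am not sure this is the right approach, so corrections are very welcome:
'''
import numpy as np
import tensorflow as tf  # assuming TensorFlow 1.14, graph mode (no eager execution)

from algo import ValueFunctionWithApproximation


class ValueFunctionWithNN(ValueFunctionWithApproximation):
    def __init__(self, state_dims):
        """
        state_dims: the number of dimensions of state space
        """
        # Placeholders for the state batch, the target G, and the step size alpha.
        self.s_ph = tf.placeholder(tf.float32, shape=(None, state_dims), name="state")
        self.G_ph = tf.placeholder(tf.float32, shape=(None, 1), name="target")
        self.alpha_ph = tf.placeholder(tf.float32, shape=(), name="alpha")

        # Same architecture as the Keras model: two ReLU hidden layers of 32 units,
        # then a single linear output.
        h1 = tf.layers.dense(self.s_ph, 32, activation=tf.nn.relu)
        h2 = tf.layers.dense(h1, 32, activation=tf.nn.relu)
        self.v_hat_op = tf.layers.dense(h2, 1)

        # MSE / 2, as in the TF2 version, and one Adam step per update() call.
        self.loss_op = tf.losses.mean_squared_error(self.G_ph, self.v_hat_op) / 2
        self.train_op = tf.train.AdamOptimizer(
            learning_rate=self.alpha_ph,
            beta1=0.9,
            beta2=0.999,
            epsilon=1e-07
        ).minimize(self.loss_op)

        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

    def __call__(self, s) -> float:
        s = self.state_formatter(s)  # my input is a numpy array
        return float(self.sess.run(self.v_hat_op, feed_dict={self.s_ph: s}))

    def update(self, alpha, G, s_tau):
        s_tau = self.state_formatter(s_tau)  # my input is a numpy array
        G = np.reshape(G, (-1, 1)).astype(np.float32)
        self.sess.run(self.train_op,
                      feed_dict={self.s_ph: s_tau, self.G_ph: G, self.alpha_ph: alpha})
        return None

    def state_formatter(self, s):
        return s.reshape((-1, len(s)))
'''
The main thing I am unsure about is whether feeding `alpha` through a placeholder on every call is the idiomatic TF1 way to do this, or whether the learning rate should simply be fixed when the `AdamOptimizer` is built (which is effectively what my TF2 code does, since it only creates the optimizer on the first update).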
Thanks