I am trying to create a neural network for policy-based RL. I wrote the following class to build the network and generate actions:
import numpy as np
from keras.layers import Input, Dense
from keras.models import Model
from keras import optimizers
from keras import backend as K


class Oracle(object):
    def __init__(self, input_dim, output_dim, hidden_dims=None):
        if hidden_dims is None:
            hidden_dims = [32, 32]
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.__build_network(input_dim, output_dim, hidden_dims)
        self.__build_train_fn()

    def __build_network(self, input_dim, output_dim, hidden_dims):
        """Create the base network."""
        inputs = Input(shape=(input_dim,))
        net = inputs
        # a layer instance is callable on a tensor, and returns a tensor
        for h_dim in hidden_dims:
            net = Dense(h_dim, activation='relu',
                        kernel_initializer='RandomNormal',
                        bias_initializer='zeros')(net)
        net = Dense(output_dim, activation='softmax',
                    kernel_initializer='RandomNormal',
                    bias_initializer='zeros')(net)
        # This creates a model that includes
        # the Input layer and three Dense layers
        self.model = Model(inputs=inputs, outputs=net)
        return self.model

    def __build_train_fn(self):
        """Create a train function.

        It replaces `model.fit(X, y)` because the model's own output is part of
        the loss. For example, we need an action placeholder called
        `action_onehot` that stores which action was taken at state `s`, so
        that the probability of that same action can be updated.

        This method creates
        `self.train_fn([state, action_onehot, discount_reward])`,
        which trains the model.
        """
        action_prob_placeholder = self.model.output
        action_onehot_placeholder = K.placeholder(shape=(None, self.output_dim),
                                                  name="action_onehot")
        discount_reward_placeholder = K.placeholder(shape=(None,),
                                                    name="discount_reward")

        # probability assigned to the action that was actually taken
        action_prob = K.sum(action_prob_placeholder * action_onehot_placeholder, axis=1)
        log_action_prob = K.log(action_prob)

        # REINFORCE loss: -log pi(a|s) * discounted return
        loss = -log_action_prob * discount_reward_placeholder
        loss = K.mean(loss)

        adam = optimizers.Adam()
        updates = adam.get_updates(params=self.model.trainable_weights,
                                   constraints=[],
                                   loss=loss)

        self.train_fn = K.function(inputs=[self.model.input,
                                           action_onehot_placeholder,
                                           discount_reward_placeholder],
                                   outputs=[],
                                   updates=updates)

    def get_action(self, state):
        """Return an action for the given `state`.

        Args:
            state (1-D or 2-D array): either a 1-D array of shape (state_dimension,)
                or a 2-D array of shape (n_samples, state_dimension)

        Returns:
            action: an integer action value ranging from 0 to (n_actions - 1)
        """
        shape = state.shape
        if len(shape) == 1:
            assert shape == (self.input_dim,), "{} != {}".format(shape, self.input_dim)
            state = np.expand_dims(state, axis=0)
        elif len(shape) == 2:
            assert shape[1] == self.input_dim, "{} != {}".format(shape, self.input_dim)
        else:
            raise TypeError("Wrong state shape is given: {}".format(state.shape))

        action_prob = np.squeeze(self.model.predict(state))
        assert len(action_prob) == self.output_dim, "{} != {}".format(len(action_prob), self.output_dim)

        # debug output: current state and all weight arrays
        print(state)
        print(state.shape)
        weights = self.model.get_weights()
        print(weights)

        return np.random.choice(np.arange(self.output_dim), p=action_prob)
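For context, this is roughly how I intend to drive the class in a REINFORCE-style loop. It is only a simplified sketch: the gym CartPole-v0 environment, the episode collection, and the discounted/normalized returns are placeholders standing in for my actual setup, not part of the class above.

import numpy as np
import gym

# Placeholder environment; any discrete-action environment would do.
env = gym.make("CartPole-v0")
agent = Oracle(input_dim=env.observation_space.shape[0],
               output_dim=env.action_space.n)

def discount_rewards(rewards, gamma=0.99):
    """Discounted returns, normalized (a common REINFORCE choice, not taken from the class above)."""
    returns = np.zeros_like(rewards, dtype=np.float32)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    returns -= returns.mean()
    returns /= (returns.std() + 1e-8)
    return returns

for episode in range(10):
    states, actions, rewards = [], [], []
    s = env.reset()                      # old gym API: reset() returns the observation
    done = False
    while not done:
        a = agent.get_action(s)
        states.append(s)
        actions.append(a)
        s, r, done, _ = env.step(a)      # old gym API: 4-tuple return
        rewards.append(r)

    # one-hot encode the actions and call the custom train function
    action_onehot = np.eye(agent.output_dim)[np.array(actions)]
    agent.train_fn([np.array(states, dtype=np.float32),
                    action_onehot,
                    discount_rewards(np.array(rewards, dtype=np.float32))])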
I want to use this for policy-based RL. The problem is that even though I initialize the weights with RandomNormal (I have tried other initializers as well), many of the weights come out as nan, and action_prob comes out as nan too. A representative dump of the weights is given below (a small NaN-scanning helper is sketched after the dump). Can anyone tell me how to fix this?
[array([[ 1.97270699e-02, nan, -1.53264655e-02,
nan, nan, 9.83271226e-02,
nan, 1.67111661e-02, nan,
-5.40489666e-02, nan, -3.19434591e-02,
nan, -8.62319861e-03, nan,
3.90832238e-02, nan, nan,
nan, -3.34417708e-02, nan,
4.17598374e-02, 1.23961531e-02, 1.13383524e-01,
1.52971387e-01, -7.35234842e-02, 4.81316447e-03,
nan, nan, 9.02018696e-02,
-5.64984754e-02, nan],
[ 3.42946462e-02, nan, -2.32576765e-02,
nan, nan, -1.62454545e-02,
nan, 7.62931630e-02, nan,
7.09382221e-02, nan, -9.45277140e-02,
nan, 6.81431815e-02, nan,
5.43346964e-02, nan, nan,
nan, -5.25366806e-04, nan,
-3.03930230e-02, 1.90449376e-02, -6.84814155e-02,
-4.24950942e-02, -4.82842028e-02, 3.00289365e-03,
nan, nan, 1.14762083e-01,
-1.53483404e-02, nan],
[ 1.11763954e-01, nan, -2.40741558e-02,
nan, nan, -2.25515720e-02,
nan, 8.37199837e-02, nan,
8.01791809e-03, nan, 4.11959179e-02,
nan, -8.09677169e-02, nan,
1.09827537e-02, nan, nan,
nan, 3.24306265e-03, nan,
-4.61481474e-02, -4.44600247e-02, 5.97798042e-02,
-2.80357362e-03, 4.99138907e-02, -3.16888206e-02,
nan, nan, 4.79343869e-02,
-3.04902103e-02, nan],
[ 9.96000832e-04, nan, 7.03881904e-02,
nan, nan, 3.29129435e-02,
nan, 2.59399302e-02, nan,
3.94702554e-02, nan, 5.41977606e-05,
nan, -8.05872083e-02, nan,
7.35593066e-02, nan, nan,
nan, -3.20138596e-02, nan,
-4.88653146e-02, -3.05510052e-02, 1.61004122e-02,
3.60239707e-02, -2.89578568e-02, -8.55704099e-02,
nan, nan, -4.69469689e-02,
5.44301942e-02, nan],
[ 2.39880346e-02, nan, 1.02485856e-02,
nan, nan, -3.28975841e-02,
nan, 3.20423655e-02, nan,
7.26358453e-03, nan, -3.04405931e-02,
nan, 1.31638274e-02, nan,
-6.58982694e-02, nan, nan,
nan, -8.48279800e-03, nan,
5.07000796e-02, -3.43187563e-02, 1.69583317e-02,
5.02665602e-02, 6.59292564e-02, 5.91163523e-03,
nan, nan, 1.64841004e-02,
1.03674673e-01, nan],
[ 2.22617369e-02, nan, -9.83130708e-02,
nan, nan, -8.62144455e-02,
nan, -1.24993315e-03, nan,
-3.39315496e-02, nan, -3.71638462e-02,
nan, -2.51251217e-02, nan,
-3.30121554e-02, nan, nan,
nan, 6.95239231e-02, nan,
3.96330692e-02, -7.67886639e-02, 3.19798961e-02,
-7.02575818e-02, 5.36917103e-03, -7.84784183e-02,
nan, nan, -1.12238321e-02,
5.90852983e-02, nan],
[ -1.23783462e-02, nan, 8.54373630e-03,
nan, nan, 2.71492247e-02,
nan, -4.39056493e-02, nan,
1.54177221e-02, nan, 8.08294937e-02,
nan, -2.47991290e-02, nan,
-4.90374281e-04, nan, nan,
nan, -2.03785431e-02, nan,
-2.94432435e-02, -4.85701524e-02, -5.98664656e-02,
5.03640659e-02, -1.06101505e-01, -5.01858108e-02,
nan, nan, 1.59794372e-02,
-5.52875735e-03, nan],
[ -6.50038645e-02, nan, -2.88410280e-02,
nan, nan, 5.70952846e-03,
nan, 2.29494330e-02, nan,
2.96308636e-03, nan, -1.30019784e-02,
nan, 1.38891954e-02, nan,
9.82243866e-02, nan, nan,
nan, -4.53725718e-02, nan,
7.28782360e-03, -1.97060239e-02, 1.30356764e-02,
-1.77630689e-02, -5.27498014e-02, -5.70283793e-02,
nan, nan, -4.40920331e-03,
-8.47700890e-03, nan],
[ -7.09274644e-03, nan, -2.85792332e-02,
nan, nan, 1.90456193e-02,
nan, 2.33339947e-02, nan,
-7.10851625e-02, nan, -2.07360443e-02,
nan, -8.23910628e-03, nan,
1.53461788e-02, nan, nan,
nan, 8.74896254e-03, nan,
-1.04130013e-02, -8.23952537e-03, 3.29020806e-02,
-8.53802171e-03, -5.38858548e-02, 2.94392351e-02,
nan, nan, 2.28152424e-03,
3.86046581e-02, nan],
[ 6.32084534e-02, nan, 1.79775548e-03,
nan, nan, -5.96092641e-02,
nan, 1.74504239e-03, nan,
9.05414373e-02, nan, -3.55534554e-02,
nan, -3.89753282e-02, nan,
8.71098042e-03, nan, nan,
nan, 7.47531727e-02, nan,
5.26362322e-02, 1.46157984e-02, 3.21042910e-03,
-7.87475239e-03, 4.22325032e-03, 1.58537421e-02,
nan, nan, 3.45352525e-03,
9.88092553e-03, nan],
[ 8.60697851e-02, nan, 7.76077956e-02,
nan, nan, 1.35996595e-01,
nan, 7.12691769e-02, nan,
-2.70256456e-02, nan, 9.95257962e-03,
nan, -2.21844148e-02, nan,
4.18028049e-02, nan, nan,
nan, 6.15538433e-02, nan,
-3.34422104e-02, 7.96959698e-02, 3.36392457e-03,
-9.79953539e-03, 1.52911739e-02, -9.56133530e-02,
nan, nan, 3.26185785e-02,
-5.18142292e-03, nan],
[ -7.14878365e-02, nan, 3.30364555e-02,
nan, nan, -7.56359026e-02,
nan, -8.38122815e-02, nan,
3.50784622e-02, nan, 6.51308149e-02,
nan, -8.44882503e-02, nan,
1.97267421e-02, nan, nan,
nan, -4.02851999e-02, nan,
-3.84002179e-02, 3.23568434e-02, 9.30055231e-03,
2.97283176e-02, -3.93995969e-03, 1.24160219e-02,
nan, nan, -5.86424842e-02,
-5.61306179e-02, nan],
[ 5.52838258e-02, nan, -2.10575890e-02,
nan, nan, -1.46265700e-02,
nan, -6.19944222e-02, nan,
-4.26368900e-02, nan, -1.77203845e-02,
nan, 7.23404884e-02, nan,
1.19749429e-02, nan, nan,
nan, -1.97013188e-02, nan,
-9.93668661e-03, -1.43543081e-02, -1.89676192e-02,
-3.46484780e-02, -2.41095871e-02, 2.64016148e-02,
nan, nan, 3.39512643e-03,
-2.40868814e-02, nan],
[ 4.85769324e-02, nan, -2.96661835e-02,
nan, nan, -1.16411140e-02,
nan, -9.32439044e-03, nan,
-2.47888379e-02, nan, -2.11149845e-02,
nan, 1.55771989e-02, nan,
-3.60703245e-02, nan, nan,
nan, -8.21380615e-02, nan,
7.12675974e-02, 3.52902263e-02, 5.15214726e-03,
4.55725230e-02, -3.67484652e-02, -1.13544762e-02,
nan, nan, -3.86700444e-02,
-3.91620398e-02, nan],
[ -5.83947077e-03, nan, 5.90741597e-02,
nan, nan, -4.57256138e-02,
nan, -8.41458961e-02, nan,
-7.60969743e-02, nan, 2.50754189e-02,
nan, 2.75974572e-02, nan,
2.27455739e-02, nan, nan,
nan, -1.64209884e-02, nan,
-2.64473110e-02, -1.31150903e-02, 3.04512922e-02,
-5.81411598e-03, 1.68283712e-02, -1.44851422e-02,
nan, nan, -2.56322809e-02,
1.11139610e-01, nan],
[ 8.34780037e-02, nan, 6.61360845e-03,
nan, nan, -1.08085848e-01,
nan, -1.87303626e-03, nan,
-2.97805574e-02, nan, -4.96098958e-02,
nan, -2.47526560e-02, nan,
5.78494631e-02, nan, nan,
nan, 9.74192936e-03, nan,
-4.88330796e-02, 1.02368537e-02, -2.99407393e-02,
-3.94638889e-02, -1.45375028e-01, -8.38985574e-03,
nan, nan, -2.59864815e-02,
-5.39724007e-02, nan],
[ 2.34477259e-02, nan, 6.47758618e-02,
nan, nan, -2.06562635e-02,
nan, -1.50227742e-02, nan,
-4.99106087e-02, nan, -8.75398964e-02,
nan, -1.91738885e-02, nan,
9.81663391e-02, nan, nan,
nan, 8.30503032e-02, nan,
-6.02204986e-02, -5.43463342e-02, -2.73545366e-02,
-3.97464111e-02, -1.08450698e-03, 1.27358735e-02,
nan, nan, -6.65350258e-02,
-7.63151273e-02, nan],
[ -1.75849702e-02, nan, 5.18983677e-02,
nan, nan, 2.52664816e-02,
nan, -7.14112073e-03, nan,
2.89890468e-02, nan, -3.46427821e-02,
nan, 1.85990240e-02, nan,
-4.50296048e-03, nan, nan,
nan, -5.50862215e-02, nan,
1.02454759e-01, 9.34040993e-02, 1.45452050e-02,
2.90963929e-02, 3.19026299e-02, 1.89037640e-02,
nan, nan, -1.68684160e-03,
9.94853582e-03, nan],
[ -9.39413719e-03, nan, -3.46053950e-03,
nan, nan, 3.13128680e-02,
nan, -2.45536752e-02, nan,
4.08208035e-02, nan, 2.67537422e-02,
nan, 8.34849998e-02, nan,
-2.65908819e-02, nan, nan,
nan, -2.63154972e-03, nan,
4.54281829e-02, 1.24697601e-02, 5.25561944e-02,
5.75856939e-02, -8.61058664e-03, 2.86082458e-02,
nan, nan, -4.48538922e-02,
6.58497736e-02, nan],
[ -4.35961820e-02, nan, 5.22863083e-02,
nan, nan, -8.59688129e-03,
nan, -5.25927730e-02, nan,
7.24843144e-02, nan, -4.00458984e-02,
nan, -2.85069328e-02, nan,
2.43122727e-02, nan, nan,
nan, 1.57326814e-02, nan,
4.99758229e-04, 1.23931235e-02, 1.90575924e-02,
-4.64425469e-03, 5.54191284e-02, 2.38004271e-02,
nan, nan, -7.39056617e-03,
3.59723084e-02, nan],
[ 6.80808276e-02, nan, -1.49172200e-02,
nan, nan, -1.84247848e-02,
nan, 7.11160824e-02, nan,
4.74170335e-02, nan, -8.48565064e-03,
nan, 6.96734041e-02, nan,
1.07453577e-01, nan, nan,
nan, 3.21782194e-02, nan,
3.53086367e-02, -2.57775784e-02, -3.70149538e-02,
8.49922895e-02, 4.88188267e-02, 4.43161186e-03,
nan, nan, 7.35458219e-03,
-4.75145914e-02, nan],
[ -1.23953104e-01, nan, -4.27762084e-02,
nan, nan, 2.04169434e-02,
nan, 5.78987077e-02, nan,
-6.60712123e-02, nan, -2.07597148e-02,
nan, 3.00809499e-02, nan,
1.40863642e-01, nan, nan,
nan, -4.05914113e-02, nan,
-4.87232655e-02, 1.49445562e-02, 3.01859360e-02,
2.01087426e-02, 7.96428975e-03, 2.58545913e-02,
nan, nan, -3.26734572e-03,
2.30945610e-02, nan]], dtype=float32), array([ 0., nan, 0., nan, nan, 0., nan, 0., nan, 0., nan,
0., nan, 0., nan, 0., nan, nan, nan, 0., nan, 0.,
0., 0., 0., 0., 0., nan, nan, 0., 0., nan], dtype=float32), array([[ nan, nan, nan, ..., nan,
nan, 0.08562656],
[ nan, nan, nan, ..., nan,
nan, -0.03227361],
[ nan, nan, nan, ..., nan,
nan, -0.1371294 ],
...,
[ nan, nan, nan, ..., nan,
nan, 0.01600872],
[ nan, nan, nan, ..., nan,
nan, -0.0156843 ],
[ nan, nan, nan, ..., nan,
nan, -0.036583 ]], dtype=float32), array([ nan, nan, nan, nan, nan, nan, 0., 0., nan, 0., 0.,
0., 0., 0., nan, nan, nan, 0., nan, 0., 0., 0.,
nan, 0., nan, nan, nan, nan, nan, nan, nan, 0.], dtype=float32), array([[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan]], dtype=float32), array([ nan, nan, nan], dtype=float32)]
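The NaN-scanning helper mentioned above is sketched below; it is only a sketch, `agent` refers to the Oracle instance from the usage loop earlier, and it does nothing beyond inspecting the weight arrays after an update.

import numpy as np

def find_nan_weights(model):
    """Return (weight_index, nan_count) pairs for every weight array containing NaN."""
    report = []
    for i, w in enumerate(model.get_weights()):
        n_nan = int(np.isnan(w).sum())
        if n_nan:
            report.append((i, n_nan))
    return report

# Called right after each agent.train_fn(...) update, for example:
#     bad = find_nan_weights(agent.model)
#     if bad:
#         print("NaN weights after this update:", bad)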