I want to build a custom GRU-RNN cell in TensorFlow, but I don't know which function of TensorFlow's standard GRU I would need to change. I want to modify the GRU cell to implement the Optimized Gated Recurrent Unit (OGRU) architecture. The change that paper makes to the standard GRU is that the input at time step t is multiplied by the reset gate inside the update gate. There is also another paper that uses the OGRU approach.
I tried to modify the kaustubhhiware/LSTM-GRU-from-scratch code, which works with the image dataset used as its default example. However, I need to use the OGRU approach to predict time-series data with a differently shaped input. I know the problem is with `_inputs` for the time-series dataset, but I don't know what the best practice is.
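To make the change concrete: as I understand it, only the update gate differs from a standard GRU. A minimal NumPy sketch of the two gates (my own illustration, not the paper's notation; `Wz` and `bz` are the usual update-gate parameters and `r` is the reset-gate activation):

import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_update_gate(x, Wz, bz):
    # Standard GRU: z_t = sigmoid(x_t Wz + bz)
    return sigmoid(x @ Wz + bz)

def ogru_update_gate(x, r, Wz, bz):
    # OGRU: the input projection is scaled elementwise by the reset gate
    # before the sigmoid: z_t = sigmoid((x_t Wz) * r_t + bz)
    return sigmoid((x @ Wz) * r + bz)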
import os

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def get_states(model, processed_input, initial_hidden):
    # Run the recurrent step function over the time axis with tf.scan.
    all_hidden_states = tf.scan(model, processed_input, initializer=initial_hidden, name='states')
    # The LSTM step in the original repo stacks (hidden, cell) states;
    # keep only the hidden states.
    all_hidden_states = all_hidden_states[:, 0, :, :]
    return all_hidden_states


def get_output(Wo, bo, hidden_state):
    # Project a hidden state to the output dimension.
    output = tf.nn.relu(tf.matmul(hidden_state, Wo) + bo)
    return output
class OGRU_cell(object):
    def __init__(self, input_nodes, hidden_unit, output_nodes):
        self.input_nodes = input_nodes
        self.hidden_unit = hidden_unit
        self.output_nodes = output_nodes

        # Weight and bias initialization. Small random values instead of
        # zeros: all-zero weight matrices start every gate at the same
        # constant and make the initial candidate state identically zero.
        self.Wx = tf.Variable(tf.truncated_normal([self.input_nodes, self.hidden_unit], stddev=.01))
        self.Wr = tf.Variable(tf.truncated_normal([self.input_nodes, self.hidden_unit], stddev=.01))
        self.br = tf.Variable(tf.truncated_normal([self.hidden_unit], mean=1))
        self.Wz = tf.Variable(tf.truncated_normal([self.input_nodes, self.hidden_unit], stddev=.01))
        self.bz = tf.Variable(tf.truncated_normal([self.hidden_unit], mean=1))
        self.Wh = tf.Variable(tf.truncated_normal([self.hidden_unit, self.hidden_unit], stddev=.01))
        self.Wo = tf.Variable(tf.truncated_normal([self.hidden_unit, self.output_nodes], mean=1, stddev=.01))
        self.bo = tf.Variable(tf.truncated_normal([self.output_nodes], mean=1, stddev=.01))

        # Input placeholder: [batch, time, features].
        self._inputs = tf.placeholder(tf.float32, shape=[None, None, self.input_nodes], name='inputs')
        # Reorder to [time, batch, features] so tf.scan iterates over time
        # steps (one transpose instead of the original double transpose).
        self.processed_input = tf.transpose(self._inputs, perm=[1, 0, 2])
        # Zero initial hidden state of shape [batch, hidden_unit]; the matmul
        # trick keeps the batch dimension dynamic.
        self.initial_hidden = tf.matmul(self._inputs[:, 0, :], tf.zeros([input_nodes, hidden_unit]))
    def Gru(self, previous_hidden_state, x):
        # Reset gate.
        r = tf.sigmoid(tf.matmul(x, self.Wr) + self.br)
        # OGRU update gate: the input projection is scaled elementwise by the
        # reset gate before the sigmoid.
        z = tf.sigmoid(tf.multiply(tf.matmul(x, self.Wz), r) + self.bz)
        # Candidate hidden state; the recurrent term is gated by r.
        h_ = tf.tanh(tf.matmul(x, self.Wx) +
                     tf.matmul(previous_hidden_state, self.Wh) * r)
        # Interpolate between the candidate and the previous hidden state.
        current_hidden_state = tf.multiply(1 - z, h_) + tf.multiply(previous_hidden_state, z)
        return current_hidden_state
    def get_states(self):
        # Iterate Gru over the time-major input; returns [time, batch, hidden].
        all_hidden_states = tf.scan(self.Gru, self.processed_input,
                                    initializer=self.initial_hidden, name='states')
        return all_hidden_states

    def get_output(self, hidden_state):
        output = tf.nn.relu(tf.matmul(hidden_state, self.Wo) + self.bo)
        return output

    def get_outputs(self):
        # Apply the output projection at every time step.
        all_hidden_states = self.get_states()
        all_outputs = tf.map_fn(self.get_output, all_hidden_states)
        return all_outputs
These are my time-series input shapes; batch size = 12:
trainX shape (3214, 14, 3)
trainY shape (3214, 1)
testX shape (794, 14, 3)
testY shape (794, 1)
batch_x shape: (267, 12, 14, 3)
batch_y shape: (267, 12, 1)
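So each minibatch fed to the network is `batch_x[i]` with shape (12, 14, 3), i.e. (batch, time, features), which means the cell has to be built with `input_nodes=3`. A quick shape check I would expect to pass (my own sketch; `hidden_unit=32` is an arbitrary choice):

import numpy as np

cell = OGRU_cell(input_nodes=3, hidden_unit=32, output_nodes=1)
outputs = cell.get_outputs()  # time-major: [time, batch, output_nodes]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    minibatch = np.random.rand(12, 14, 3).astype(np.float32)  # one batch_x[i]
    out = sess.run(outputs, feed_dict={cell._inputs: minibatch})
    print(out.shape)  # expected (14, 12, 1): 14 steps, batch of 12, 1 output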
And these are the input shapes for the image dataset; batch size = 100:
trainX shape (60000, 28, 28)
trainY shape (60000, 10)
testX shape (10000, 28, 28)
testY shape (10000, 10)
batch_x shape: (600, 100, 28, 28)
batch_y shape: (600, 100, 10)
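The same `[None, None, input_nodes]` placeholder covers both datasets as long as `input_nodes` matches the last axis of the input; only the constructor arguments differ. Hypothetical calls for the two shape sets above (again, `hidden_unit=32` is arbitrary):

rnn_ts  = OGRU_cell(input_nodes=3,  hidden_unit=32, output_nodes=1)   # time series: 14 steps of 3 features
rnn_img = OGRU_cell(input_nodes=28, hidden_unit=32, output_nodes=10)  # MNIST: 28 rows of 28 pixels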
Here is my training code:
def start_training(train, test, hidden_unit, model, alpha=learning_rate, isTrain=False,
                   num_iterations=num_iterations, batch_size=100, opt='sgd'):
    # learning_rate, num_iterations, seed, input_nodes, output_nodes, data_opt,
    # weights_folder and the data module are globals defined in the full script.
    np.random.seed(seed)
    tf.set_random_seed(seed)  # TF v2.x: tf.random.set_seed(seed)
    (trainX, trainY) = train
    (testX, testY) = test
    (n_x, m, m2) = trainX.T.shape

    # TensorFlow v1.x target placeholder (renamed from 'inputs', which clashed
    # with the RNN input placeholder's name). A TF v2.x version would pass the
    # targets as a function argument instead of using a placeholder.
    Y = tf.placeholder(tf.float32, shape=[None, output_nodes], name='targets')

    if model == 'lstm':
        rnn = LSTM_cell(input_nodes, hidden_unit, output_nodes)
    elif model == 'gru':
        rnn = GRU_cell(input_nodes, hidden_unit, output_nodes)
    else:
        rnn = OGRU_cell(input_nodes, hidden_unit, output_nodes)

    outputs = rnn.get_outputs()          # [time, batch, output_nodes]
    print('Output layer:', outputs)
    print('outputs[-1] :', outputs[-1])  # last time step
    prediction = tf.nn.softmax(outputs[-1])

    if data_opt == 'image':
        # Cross-entropy against one-hot labels; the epsilon guards log(0).
        cost = -tf.reduce_sum(Y * tf.log(prediction + 1e-8))
    elif data_opt == 'timeseries':
        # TODO: placeholder. Softmax cross-entropy assumes one-hot class
        # labels; the (batch, 1) regression targets need a different head and
        # loss (see the sketch after this function).
        cost = -tf.reduce_sum(Y * tf.log(prediction + 1e-8))

    saver = tf.train.Saver(max_to_keep=10)
    if opt == 'sgd':
        optimizer = tf.train.GradientDescentOptimizer(alpha).minimize(cost)
    elif opt == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=alpha).minimize(cost)
    print('First cost :', cost)

    correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) * 100

    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)

    if not os.path.isdir(os.getcwd() + weights_folder):
        print('Missing folder made')
        os.makedirs(os.getcwd() + weights_folder)

    if isTrain:
        num_minibatches = len(trainX) / batch_size
        for iteration in range(num_iterations):
            iter_cost = 0.
            batch_x, batch_y = data.create_batches(trainX, trainY, batch_size=batch_size)
            for (minibatch_X, minibatch_Y) in zip(batch_x, batch_y):
                _, minibatch_cost, acc = sess.run([optimizer, cost, accuracy],
                                                  feed_dict={rnn._inputs: minibatch_X, Y: minibatch_Y})
                iter_cost += minibatch_cost * 1.0 / num_minibatches
            # acc holds the accuracy of the last minibatch in this iteration.
            print("Iteration {iter_num}, Cost: {cost}, Accuracy: {accuracy}".format(
                iter_num=iteration, cost=iter_cost, accuracy=acc))
        Train_accuracy = str(sess.run(accuracy, feed_dict={rnn._inputs: trainX, Y: trainY}))
        save_path = saver.save(sess, "." + weights_folder + data_opt + '/' +
                               "model_" + model + "_" + str(hidden_unit) + ".ckpt")
        print("Parameters have been trained and saved!")
        print("\rTrain Accuracy: %s" % Train_accuracy)
    else:  # test mode
        saver.restore(sess, "." + weights_folder + data_opt + '/' +
                      "model_" + model + "_" + str(hidden_unit) + ".ckpt")
        acc = sess.run(accuracy, feed_dict={rnn._inputs: testX, Y: testY})
        print("Test Accuracy: " + "{:.3f}".format(acc))
    sess.close()
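Given the (batch, 1) targets, what I think the time-series branch needs is a regression head instead of the softmax classifier. A hedged sketch of what I have in mind (my own guess, not code from the repo; note that `get_output` applies a ReLU, so either the targets must be non-negative or the readout should be made linear):

# Sketch of a regression head for the time-series task.
prediction = outputs[-1]   # last time step, shape [batch, output_nodes] = [12, 1]
Y = tf.placeholder(tf.float32, shape=[None, 1], name='targets')
cost = tf.reduce_mean(tf.square(prediction - Y))             # mean squared error
optimizer = tf.train.AdamOptimizer(learning_rate=1e-3).minimize(cost)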
See the full code here.