我有一个简单的循环网络示例,其中保存了andtf.Saver
和weight
变量。bias
state
当示例在没有选项的情况下运行时,它将初始化状态向量以包含零,但我想传递一个load_model
选项并将其使用状态向量的最后一个值作为session.run
调用的提要。
我看到的所有文档都坚持必须调用session.run
以从变量中检索存储的值,但在这种情况下,我想检索这些值,以便我可以初始化状态变量。我需要做一个单独的图表来检索初始化值吗?
下面的示例代码:
import tensorflow as tf
import math
import numpy as np
INPUTS = 10
HIDDEN_1 = 2
BATCH_SIZE = 3
def batch_vm2(m, x):
[input_size, output_size] = m.get_shape().as_list()
input_shape = tf.shape(x)
batch_rank = input_shape.get_shape()[0].value - 1
batch_shape = input_shape[:batch_rank]
output_shape = tf.concat(0, [batch_shape, [output_size]])
x = tf.reshape(x, [-1, input_size])
y = tf.matmul(x, m)
y = tf.reshape(y, output_shape)
return y
def get_weight_and_biases():
with tf.variable_scope(network_scope, reuse = True) as scope:
weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0 / math.sqrt(float(INPUTS))))
biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0))
return weights, biases
def get_saver():
with tf.variable_scope('h1') as scope:
weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0 / math.sqrt(float(INPUTS))))
biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0))
state = tf.get_variable('state', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0), trainable=False)
saver = tf.train.Saver([weights, biases, state])
return saver, scope
def load(sess, saver, checkpoint_dir = './'):
print("loading a session")
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
else:
raise Exception("no checkpoint found")
return
iteration = None
def iterate_state(prev_state_tuple, input):
with tf.variable_scope(network_scope, reuse = True) as scope:
weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0 / math.sqrt(float(INPUTS))))
biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0))
state = tf.get_variable('state', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0), trainable=False)
print("input: ",input.get_shape())
matmuladd = batch_vm2(weights, input) + biases
matmulpri = tf.Print(matmuladd,[matmuladd, weights], message=" malmul -> %i, weights " % iteration)
print("prev state: ",prev_state_tuple.get_shape())
unpacked_state, unpacked_out = tf.split(0,2,prev_state_tuple)
prev_state = 0.99* unpacked_state
prev_state = tf.Print(prev_state, [unpacked_state, matmuladd], message=" -> prevstate, matmulpri ")
state = state.assign( prev_state + 0.01*matmulpri )
#output = tf.nn.relu(state)
output = tf.nn.tanh(state)
state = tf.Print(state, [state], message=" state -> ")
output = tf.Print(output, [output], message=" output -> ")
print(" state: ", state.get_shape())
print(" output: ", output.get_shape())
concat_result = tf.concat(0,[state, output])
print (" concat return: ", concat_result.get_shape())
return concat_result
def data_iter():
while True:
idxs = np.random.rand(BATCH_SIZE, INPUTS)
yield idxs
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_boolean('load_model', False, 'If true, uses model files '
'to restore.')
network_scope = None
with tf.Graph().as_default():
inputs = tf.placeholder(tf.float32, shape=(BATCH_SIZE, INPUTS))
iteration = -1
saver, network_scope = get_saver()
initial_state = tf.placeholder(tf.float32, shape=(HIDDEN_1))
initial_out = tf.zeros([HIDDEN_1],
name='initial_out')
concat_tensor = tf.concat(0,[initial_state, initial_out])
print(" init state: ",initial_state.get_shape())
print(" init out: ",initial_out.get_shape())
print(" concat: ",concat_tensor.get_shape())
scanout = tf.scan(iterate_state, inputs, initializer=concat_tensor, name='state_scan')
print ("scanout shape: ", scanout.get_shape())
state, output = tf.split(1,2,scanout, name='split_scan_output')
print(" end state: ",state.get_shape())
print(" end out: ",output.get_shape())
sess = tf.Session()
# Run the Op to initialize the variables.
sess.run(tf.initialize_all_variables())
tf.train.write_graph(sess.graph_def, './tenIrisSave/logsd','graph.pbtxt')
tf_weight, tf_bias = get_weight_and_biases()
tf.histogram_summary('weights', tf_weight)
tf.histogram_summary('bias', tf_bias)
tf.histogram_summary('state', state)
tf.histogram_summary('out', output)
summary_op = tf.merge_all_summaries()
summary_writer = tf.train.SummaryWriter('./tenIrisSave/summary',sess.graph_def)
if FLAGS.load_model:
load(sess, saver)
# HOW DO I LOAD restored state values??????
#st = state[BATCH_SIZE - 1,:]
#st = sess.run([state], feed_dict={})
print("LOADED last state vec: ", st)
else:
st = np.array([0.0 , 0.0])
iter_ = data_iter()
for i in xrange(0, 1):
print ("iteration: ",i)
iteration = i
input_data = iter_.next()
out,st,so,summary_str = sess.run([output,state,scanout,summary_op], feed_dict={ inputs: input_data, initial_state: st })
saver.save(sess, 'my-model', global_step=1+i)
summary_writer.add_summary(summary_str, i)
summary_writer.flush()
print("input vec: ", input_data)
print("state vec: ", st)
st = st[-1]
print("last state vec: ", st)
print("output vec: ", out)
print(" end state (runtime): ",st.shape)
print(" end out (runtime): ",out.shape)
print(" end scanout (runtime): ",so.shape)
请注意第 124-126 行的注释行,说明我尝试初始化提要字典值的方式。它们都不起作用。