大家好,我有一个问题,我找不到最好的方法
我有一个后端 restful api,我想在其中使用 tensorflow hub 模块,但我有一个问题,那就是每次我想进行计算时,我必须初始化所有变量和表,并且需要很长时间来处理我的问题是:
有没有一种方法可以在会话中一次性初始化所有变量和表并关闭会话,因为解决此问题的一种方法是保持会话打开并进行计算,但我的解决方案的问题是它占用资源.
我把主要代码和我自己的解决方案都放了,以便更好地理解
加载不同模块的函数
def loading_module(path = None, module_url =
'https://tfhub.dev/google/universal-sentence-encoder/2'):
# Import the Universal Sentence Encoder's TF Hub module
graph = tf.get_default_graph()
if path == None:
embed_object = hub.Module(module_url)
else:
embed_object = hub.Module(hub.load_module_spec(path))
return embed_object
在文本上运行嵌入模块的函数
def run_embedding(embed_object, graph, text):
# Reduce logging output.
tf.logging.set_verbosity(tf.logging.ERROR)
with tf.Session(graph = graph) as sess:
sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
similarity_input_placeholder = tf.placeholder(tf.string, shape=(None))
encoding_tensor = embed_object(similarity_input_placeholder)
message_embeddings = sess.run(encoding_tensor, feed_dict = {similarity_input_placeholder:text})
return message_embeddings
embed_object = loading_module()
run_embedding(embed_object, ['sth'])
我的解决方案
def loading_module(path = None, module_url = 'https://tfhub.dev/google/universal-sentence-encoder/2'):
# Import the Universal Sentence Encoder's TF Hub module
g = tf.Graph()
with g.as_default():
if path == None:
embed_object = hub.Module(module_url)
else:
embed_object = hub.Module(hub.load_module_spec(path))
sess = tf.InteractiveSession(graph = g)
sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
return embed_object, g, sess
def run_embedding(embed_object, graph, sess, text):
# Reduce logging output.
tf.logging.set_verbosity(tf.logging.ERROR)
with graph.as_default():
similarity_input_placeholder = tf.placeholder(tf.string, shape=(None))
encoding_tensor = embed_object(similarity_input_placeholder)
message_embeddings = sess.run(encoding_tensor, feed_dict = {similarity_input_placeholder:text})
return message_embeddings