3

我正在尝试为 Django 提供通用句子编码器。

代码在开始时被初始化为后台进程(通过使用诸如supervisor之类的程序),然后它使用TCP套接字与Django通信并最终返回编码语句。

import socket
from threading import Thread
import tensorflow as tf
import tensorflow_hub as hub
import atexit

# Pre-loading the variables:
embed = hub.Module("https://tfhub.dev/google/universal-sentence-encoder/2")
session = tf.Session()
session.run(tf.global_variables_initializer())
session.run(tf.tables_initializer())
atexit.register(session.close)  # session closes if the script is halted
...
# Converts string to vector embedding:
def initiate_connection(conn):
    data = conn.recv(1024)
    conn.send(session.run(embed([data])))
    conn.close()

# Process in background, waiting for TCP message from views.py
while True:
    conn, addr = _socket.accept()
    _thread = Thread(target=initiate_connection, args=(conn,))  # new thread for each request (could be limited to n threads later)
    _thread.demon = True
    _thread.start()
    conn.close()

但是我在执行时收到以下错误conn.send(session.run(embed([data])))

RuntimeError: 模块必须应用在为其实例化的图中。


我基本上是在尝试在 tensorflow 中预加载表(因为这需要很多时间),但是 tensorflow 不允许我使用预定义的会话。

我怎样才能解决这个问题?有没有办法预加载这些变量?

PS 我相信这个 Github 问题页面可能会解决我的问题,但我不确定如何实施。

4

1 回答 1

3

使用您创建的图形加载模型并在会话中使用它。

graph = tf.Graph()
with tf.Session(graph = graph) as session:
     embed = hub.Module("https://tfhub.dev/google/universal-sentence-encoder/2")

并在initial_connection函数中使用与会话相同的图形对象

def initiate_connection(conn):
    data = conn.recv(1024)
    with tf.Session(graph = graph) as session:
        session.run([tf.global_variables_initializer(), tf.tables_initializer()])
        conn.send(session.run(embed([data])))
    conn.close()
于 2019-03-28T09:12:06.400 回答