以下是我想用我的操作做的事情: - 生成 python 包装器 - 也将操作添加到 pip 包 - 将我的操作链接到 tensorflow,以便 tensorflow-serving 可以执行操作
我将我的操作放在 tensorflow/contrib/foo 中。这是源代码树的样子
.
├── BUILD
├── LICENSE
├── __init__.py
├── foo_op.cc
├── foo_op_gpu.cu.cc
└── foo_op.h
我的__init__.py
文件包含生成的包装器的导入
from tensorflow.contrib.sampling.ops.gen_foo import *
我在tensorflow/contrib/__init__.py
from tensorflow.contrib import foo
这是我的tensorflow/contrib/foo/BUILD
文件:
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
package(default_visibility = ["//visibility:public"])
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs")
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
tf_kernel_library(
name = "foo_op_kernels",
prefix = "foo",
alwayslink = 1,
)
tf_gen_op_libs(
op_lib_names = ["foo"],
)
tf_gen_op_wrapper_py(
name = "foo",
visibility = ["//visibility:public"],
deps = [
":foo_op_kernels",
],
)
tf_custom_op_py_library(
name = "foo_py",
srcs = [
"__init__.py",
],
kernels = [
":foo_op_kernels",
],
srcs_version = "PY2AND3",
deps = [
":foo",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:common_shapes",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform",
"//tensorflow/python:util",
],
)
这是我必须触摸才能使其正常工作的 tensorflow bazel 文件。
tensorflow/contrib/BUILD
- 添加
foo_op_kernels
到contrib_kernels
部门
- 添加
foo_op_lib
到contrib_ops_op_lib
部门
- 添加
foo
到contrib_py
部门
tensorflow/tools/pip_package/BUILD
- 将我的 python 目标添加到
COMMON_PIP_DEPS
tensorflow/core/BUILD
- 将我的内核添加到
all_kernels_statically_linked
. 我可能在这个方面做得过火了,但它奏效了。
以下是 tensorflow 服务 bazel 文件:
WORKSPACE
- 更改
org_tensorflow
为local_repository
指向我的张量流而不是谷歌的tensorflow_http_archive
然后我修改:tensorflow_serving/tools/docker/Dockerfile.devel-gpu
克隆我的 tensorflow 和 tensorflow-serving 版本。