/tensorflow/core/user_ops
我已经为我的自定义 Op 实现了一个内核,并将其放入custom_op.cc
. 在 Op 中我做所有的注册工作,比如REGISTER_OP
和REGISTER_KERNEL_BUILDER
.
然后我在 Python 中为这个 Op 实现了渐变,我把它和custom_op_grad.py
. 我也在这里完成了所有注册(@ops.RegisterGradient
)。
我创建了 BUILD 文件,内容如下:
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
tf_custom_op_library(
name = "custom_op.so",
srcs = ["custom_op.cc"],
)
py_library(
name = "custom_op_grad",
srcs = ["custom_op_grad.py"],
srcs_version = "PY2",
deps = [
":custom_op_grad",
"//tensorflow:tensorflow_py",
],
)
之后,我重建 TensorFlow:
pip uninstall tensorflow
bazel clean
bazel build -c opt //tensorflow/tools/pip_package:build_pip_package
cp -r bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/__main__/* bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/
bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
pip install /tmp/tensorflow_pkg/tensorflow-0.8.0-py2-none-any.whl
当我在这一切之后尝试使用我的 Op 时,通过调用tf.user_ops.custom_op
它告诉我该模块没有它。
也许我还需要做一些额外的步骤?BUILD
或者我对文件做错了什么?