我很困惑标签应该如何在集线器中工作以及在导出时如何使用它们。如何在我的图表的火车部分进行训练并导出服务图表?
我有以下代码:
def user_module_fn(foo, bar):
x = tf.sparse_placeholder(tf.float32, shape[-1, 32], name='name')
y = something(x)
hub.add_signature(name='my_name', input={"x": x}, output={"default", y})
module_spec = hub.create_module_spec(module_spec_fn, tags_and_args=[
(set(), {"foo": foo, "bar": bar}),
({"train"}, {"foo": foo, "bar": baz})
])
m = hub.Module(module_spec, name="my_name", trainable=True, tags={"train"})
hub.register_for_export(m, "my_name")
我的问题如下:由于我将模块实例化为m
with tags={'train'}
,我认为我正在使用正确的模块进行培训。这是否意味着我只导出带有 标记的那个train
?我如何使用train
训练和set()
(默认)服务?