我有一个用于构建 LSH 表的工具类,如下所示(引用自此处):
class BuildLSHTable:
    """Populate an ``LSH`` table from embedded images and pickle it to disk.

    The constructor creates ``LSH(hash_size, dim, num_tables)``; ``train``
    feeds it one entry per training example and then writes the table to
    ``lsh_file``.

    Args:
        hash_size: Forwarded to ``LSH`` as its first argument.
        dim: Forwarded to ``LSH`` as its second argument.
        num_tables: Forwarded to ``LSH`` as its third argument.
        lsh_file: Path the populated table is pickled to by ``train``.
    """

    def __init__(self, hash_size=8, dim=2048, num_tables=10, lsh_file="lsh_table.pkl"):
        self.hash_size = hash_size
        self.dim = dim
        self.num_tables = num_tables
        self.lsh = LSH(self.hash_size, self.dim, self.num_tables)

        # `embedding_model` is not a constructor parameter: it is looked up
        # in the enclosing (module) scope and must exist before instantiation.
        # NOTE(review): the model is stored on the instance, so it travels
        # with the instance (and its bound methods) whenever they are pickled.
        # The TypeError quoted in this file suggests it is not picklable --
        # confirm.
        self.embedding_model = embedding_model
        self.lsh_file = lsh_file

    def train(self, training_files):
        """Embed every example, add it to the LSH table, then persist the table.

        Args:
            training_files: Iterable of ``(image, label)`` pairs. ``image``
                must expose ``.shape`` and support ``image[None, ...]``.

        The table is written to ``self.lsh_file`` once, after all examples
        have been added (also when ``training_files`` is empty).
        """
        # `index` instead of `id` so the builtin `id()` is not shadowed.
        for index, (image, label) in enumerate(training_files):
            # Fewer than 4 axes: prepend one (presumably the batch axis the
            # model expects -- confirm against the model's input signature).
            if len(image.shape) < 4:
                image = image[None, ...]
            features = self.embedding_model.predict(image)
            self.lsh.add(index, features, label)

        with open(self.lsh_file, "wb") as handle:
            pickle.dump(self.lsh, handle, protocol=pickle.HIGHEST_PROTOCOL)
然后我执行以下命令来构建我的 LSH 表:
# Materialize the pairs: a bare `zip` object is a one-shot iterator, so the
# `train()` call below would exhaust it and any later reuse of
# `training_files` would iterate over nothing.
# Assumes `images` and `labels` are finite and already in memory -- confirm.
training_files = list(zip(images, labels))
lsh_builder = BuildLSHTable()
lsh_builder.train(training_files)
现在,当我尝试通过 Apache Beam(代码见下文)运行它时,会抛出以下错误:
TypeError: can't pickle tensorflow.python._pywrap_tf_session.TF_Operation objects
用于 Beam 的代码:
def generate_lsh_table(args):
    """Run ``lsh_builder.train`` over the training files inside a Beam pipeline.

    Args:
        args: Mapping that must contain ``"runner"``, ``"lsh_builder"`` and
            ``"training_files"``. The whole mapping is also forwarded to
            ``PipelineOptions`` as keyword overrides.
    """
    options = beam.options.pipeline_options.PipelineOptions(**args)
    # Re-expose the mapping with attribute access (args.runner, ...).
    args = namedtuple("options", args.keys())(*args.values())

    with beam.Pipeline(args.runner, options=options) as pipeline:
        (
            pipeline
            # A Map needs an input PCollection. The whole collection of
            # training files is emitted as ONE element, so `train` receives
            # it as its single `training_files` argument and is called once.
            # The element must be picklable and still un-consumed here.
            | 'Create Training Files' >> beam.Create([args.training_files])
            # NOTE(review): Beam pickles the callable given to Map. This bound
            # method carries `lsh_builder` and whatever it holds; if that
            # includes a TensorFlow model, pickling is expected to fail.
            # Building the model on the worker (e.g. in `DoFn.setup`) is the
            # usual way around this -- confirm.
            | 'Build LSH Table' >> beam.Map(args.lsh_builder.train)
        )
下面是我调用 Beam 运行器(runner)的方式:
# Settings handed to generate_lsh_table(). Key order is kept as
# runner, lsh_builder, training_files.
args = dict(
    runner="DirectRunner",
    lsh_builder=lsh_builder,
    training_files=training_files,
)
generate_lsh_table(args)