
I have an LSH table builder utility class, shown below (quoted from here):

class BuildLSHTable:
    def __init__(self, hash_size=8, dim=2048, num_tables=10, lsh_file="lsh_table.pkl"):
        self.hash_size = hash_size
        self.dim = dim
        self.num_tables = num_tables
        self.lsh = LSH(self.hash_size, self.dim, self.num_tables)
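        # embedding_model is assumed to be a pre-trained TF/Keras embedding model defined elsewhere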
        self.embedding_model = embedding_model
        self.lsh_file = lsh_file

    def train(self, training_files):
        for id, training_file in enumerate(training_files):
            image, label = training_file
            if len(image.shape) < 4:
                image = image[None, ...]
            features = self.embedding_model.predict(image)
            self.lsh.add(id, features, label)
        
        with open(self.lsh_file, "wb") as handle:
            pickle.dump(self.lsh, 
                        handle, protocol=pickle.HIGHEST_PROTOCOL)    

Then I run the following to build my LSH table:

training_files = zip(images, labels)
lsh_builder = BuildLSHTable()
lsh_builder.train(training_files)

Now, when I try to do this through Apache Beam (code below), it throws:

TypeError: can't pickle tensorflow.python._pywrap_tf_session.TF_Operation objects

The code used for Beam:

def generate_lsh_table(args):
    options = beam.options.pipeline_options.PipelineOptions(**args)
    args = namedtuple("options", args.keys())(*args.values())

    with beam.Pipeline(args.runner, options=options) as pipeline:
        (
            pipeline
            | 'Build LSH Table' >> beam.Map(
                args.lsh_builder.train, args.training_files)
        )

And this is how I invoke the Beam runner:

args = {
    "runner": "DirectRunner",
    "lsh_builder": lsh_builder,
    "training_files": training_files
}

generate_lsh_table(args)

1 Answer


Apache Beam pipelines are converted to a standard (e.g. proto) format before execution. As part of this, certain pipeline objects such as DoFns are serialized (pickled). If your DoFns have instance variables that cannot be serialized, this step cannot proceed.
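
You can reproduce the failure outside Beam by pickling the builder directly, since lsh_builder holds self.embedding_model, a TF/Keras model (a minimal repro using the objects from the question):

import pickle

# lsh_builder carries self.embedding_model (a TF/Keras model); attempting to
# pickle it fails the same way Beam's submission-time serialization does.
pickle.dumps(lsh_builder)  # TypeError: can't pickle ... TF_Operation objects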

One way to work around this is to load or define such instance objects or modules during execution, rather than creating and storing them at pipeline submission time. This may require restructuring your pipeline.
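
For instance (a minimal sketch of the idea, not the exact fix for your code): keep only plain configuration on a DoFn and create the heavy objects in setup(), which runs on the worker after deserialization. Here load_embedding_model is a hypothetical loader, and the LSH class is assumed to be importable on the worker; the pipeline shape mirrors the question.

import apache_beam as beam

class BuildLSHTableDoFn(beam.DoFn):
    """Builds LSH entries on the worker; unpicklable objects are created in
    setup(), so nothing heavy is serialized at pipeline submission."""

    def __init__(self, hash_size=8, dim=2048, num_tables=10):
        # Only plain, picklable configuration is stored here.
        self.hash_size = hash_size
        self.dim = dim
        self.num_tables = num_tables
        self.embedding_model = None
        self.lsh = None

    def setup(self):
        # Called once per worker: safe place to load the TF model and build the LSH object.
        self.embedding_model = load_embedding_model()  # hypothetical loader
        self.lsh = LSH(self.hash_size, self.dim, self.num_tables)

    def process(self, element):
        idx, (image, label) = element
        if len(image.shape) < 4:
            image = image[None, ...]
        features = self.embedding_model.predict(image)
        self.lsh.add(idx, features, label)
        yield idx

with beam.Pipeline("DirectRunner") as pipeline:
    (
        pipeline
        | beam.Create(list(enumerate(zip(images, labels))))
        | 'Build LSH Table' >> beam.ParDo(BuildLSHTableDoFn())
    )

Note that each worker builds its own LSH instance here, so producing a single merged table would still need a combine step (or writing the features out and building the table afterwards); the sketch only illustrates deferring object creation to execution time.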

Answered 2021-07-20T06:33:14.700