I have been struggling with this problem - any help would be greatly appreciated. I'm not sure where to go from here.

I am using Dask to parallelize least-cost-path computations with scikit-image's MCP class.

The class is written in Cython, and since Dask expects intermediate results to be serializable, I implemented a Wrapper that "recreates" the MCP class during deserialization.

When I run the code without Dask, or with Dask's single-threaded scheduler, it takes longer but the results are fine.

However, when I switch to running with processes or threads (still using Dask Distributed), I don't get any errors, but my results contain a bunch of np.inf values.

Furthermore, the results themselves are inconsistent with what I get when running on a single thread.

Adding the relevant code snippets here:

import dask
import numpy as np
import pandas as pd
import rasterio
from dask.distributed import Client, LocalCluster, wait
from skimage import graph

# Create a client locally
if cluster_type == 'local':
    try:
        # Reuse a scheduler that is already listening on this port
        client = Client('127.0.0.1:8786')
    except Exception:
        cluster = LocalCluster(n_workers=8,
                               processes=True,
                               threads_per_worker=8,
                               scheduler_port=8786)
        client = Client(cluster)
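
# Quick sanity check (my addition for debugging, assumes the client above
# connected): confirm how many workers and threads the scheduler actually
# started, since the np.inf problem only appears with processes/threads.
print(client.scheduler_info()['workers'].keys())
print(client.nthreads())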
# Wrapper for the Cython MCP class: Dask needs to pickle what it ships to
# workers, so instead of serializing the MCP object itself we serialize the
# factory function and rebuild the MCP when the wrapper is unpickled.
class Wrapper(object):
    def __init__(self, get_mcp):
        self.get_mcp = get_mcp
        self.mcp = get_mcp()

    def __reduce__(self):
        # https://stackoverflow.com/questions/19855156/whats-the-exact-usage-of-reduce-in-pickler
        # When unpickled, get_mcp() is called again, so the MCP object is rebuilt
        return (self.__class__, (self.get_mcp, ))


def load_mcp():
    print("...loading mcp...")
    inR = rasterio.open(friction_raster_path)
    inD = inR.read()[0, :, :]
    # Important to scale by the pixel size in meters here in order to get correct measurements
    inD = np.array(inD, dtype=np.float128) * 30
    inD = np.array(inD, dtype=np.float32)
    inD = np.nan_to_num(inD)
    mcp = graph.MCP_Geometric(inD)
    return mcp
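
# For reference, this is the find_costs pattern I rely on, shown here on a
# toy array instead of the real friction raster. My understanding is that
# cells the search never reaches are left at np.inf in the returned cost
# surface, which is why I check for huge values further down.
toy = np.ones((5, 5), dtype=np.float32)
toy_mcp = graph.MCP_Geometric(toy)
toy_costs, toy_traceback = toy_mcp.find_costs(starts=[(0, 0)], ends=[(4, 4)])
print(toy_costs[4, 4])  # cumulative geometric cost from (0, 0) to (4, 4)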


# Init the wrapper for MCP
wrapper = Wrapper(load_mcp)

# Only reload inR here to do the crs check
inR = rasterio.open(friction_raster_path)
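
# Sanity check on the wrapper's pickling behaviour (plain pickle round-trip;
# Dask's own serialization may differ, so treat this as a sketch only):
import pickle
restored = pickle.loads(pickle.dumps(wrapper))
print(restored.mcp is wrapper.mcp)  # False - unpickling rebuilds the MCP via load_mcp()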
# Get costs from origin to dests
def get_costs_for_origin(wrapper, origin_id: str, origin_coords: tuple, dests: pd.DataFrame):
    # TODO - dests should be a list of tuples only
    res = []
    origin_coords = [origin_coords]
    ends = dests.MCP_DESTS_COORDS.to_list()
    costs, traceback = wrapper.mcp.find_costs(starts=origin_coords, ends=ends)
    for idx, dest in enumerate(dests.to_dict(orient='records')):
        dest_coords = dest['MCP_DESTS_COORDS']
        tt = costs[dest_coords[0], dest_coords[1]]
        if tt > 9999999999:
            print(dest['id'])
            print(tt)
            raise ValueError("INF")
        res.append(
            {"d_id": dest['id'],
             "d_tt": tt}
        )

    return {"o_id": origin_id, "o_tfan": res}
# Run on distributed scheduler using processes
def run_async(wrapper: Wrapper, origins_d: pd.DataFrame, dests_d: pd.DataFrame):
    # Broadcast the wrapper to all workers
    wrapper = client.scatter(wrapper, broadcast=True)
    wait(wrapper)

    # Broadcast destinations to all workers
    dests_d = client.scatter(dests_d, broadcast=True)
    wait(dests_d)

    # https://docs.dask.org/en/latest/futures.html
    tasks = []
    for idx, origin in enumerate(origins_d.to_dict(orient='records')):
        print(f"Origin {idx} of {len(origins_d)}")
        task = dask.delayed(get_costs_for_origin)(
            wrapper=wrapper,
            origin_id=origin['id'],
            origin_coords=origin['MCP_DESTS_COORDS'],
            dests=dests_d)
        tasks.append(task)
    all_res_dsk = dask.compute(*tasks)
    all_res_dsk = list(all_res_dsk)
    return all_res_dsk
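
# This is roughly how the pipeline is invoked. origins_d / dests_d are the
# origin and destination DataFrames (built elsewhere, not shown); both carry
# an 'id' column and 'MCP_DESTS_COORDS' (row, col) tuples.
all_res = run_async(wrapper, origins_d, dests_d)

# For the single-threaded comparison I skip run_async and compute the same
# delayed tasks with Dask's synchronous scheduler, e.g.:
#     dask.compute(*tasks, scheduler='single-threaded')
# which gives correct (but much slower) results.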

I assume it is related to the MCP class, but I can't figure out what could be causing the INF values.

Thanks in advance, everyone!
