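I am trying to train a model with Dask-ML. My end goal is to make predictions on a dataset that is larger than memory, so I am using Dask's ParallelPostFit wrapper to train the model on a relatively small dataset (4 GB), expecting to predict on much larger dataframes later. I connect to a YARN cluster with 50 workers, load my data from Parquet into a Dask dataframe, build a pipeline and train it. Training works, but I run into problems when I try to evaluate on the held-out test set. When I use sklearn's LogisticRegression as the classifier, both training and prediction run successfully. However, when I use an sklearn random forest with 100 estimators, training succeeds but prediction fails with the error below. During the prediction compute step I noticed that memory usage on my local machine started to blow up just before the disconnection error. When I reduce the number of RF estimators to 10, the prediction step runs successfully. Can anyone help me understand what is going on?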
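One factor that may be worth checking (a hypothesis, not a confirmed diagnosis): a fitted RandomForestClassifier stores every one of its trees, so its serialized size grows with `n_estimators`, and ParallelPostFit has to serialize the fitted estimator and ship it to the workers before they can predict. The sketch below compares the pickled size of a 10-tree and a 100-tree forest on synthetic data; the data shape and `min_samples_split=50` are arbitrary assumptions chosen only to make the comparison quick, and `pickled_size_mb` is a hypothetical helper.

```python
import pickle

import numpy as np
from sklearn.ensemble import RandomForestClassifier


def pickled_size_mb(model):
    """Return the size of the pickled model in megabytes."""
    return len(pickle.dumps(model)) / 1e6


# Synthetic stand-in data (assumption: the real data is much larger).
rng = np.random.RandomState(0)
X = rng.rand(5000, 20)
y = (X[:, 0] + rng.rand(5000) > 1.0).astype(int)

sizes = {}
for n in (10, 100):
    clf = RandomForestClassifier(n_estimators=n, min_samples_split=50, random_state=0)
    clf.fit(X, y)
    sizes[n] = pickled_size_mb(clf)
    print(f"n_estimators={n}: {sizes[n]:.2f} MB")
```

On a 4 GB training set the individual trees are typically much deeper than here, so running the same measurement on the real fitted pipeline would show whether the model itself is large enough to matter.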

My code (condensed):

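from dask.distributed import Client
from dask_ml.model_selection import train_test_split as train_test_split_dd
from dask_ml.wrappers import ParallelPostFit
from dask_yarn import YarnCluster
from sklearn.base import clone
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline

# preprocessor, X, y, RANDOM_STATE, path_to_packed_conda_env and worker_env
# are defined elsewhere (omitted from this condensed example)
cluster = YarnCluster(environment=path_to_packed_conda_env, 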
                      n_workers=50, 
                      worker_vcores=10,
                      worker_env=worker_env,
                      worker_restarts=10,
                      scheduler_memory='10GiB',
                      scheduler_vcores=5,
                      worker_memory='20GiB')
cluster.adapt(minimum=50, maximum=100)

# connect client
client = Client(cluster)

# instantiate classifier
clf_rfc = RandomForestClassifier(n_estimators=100, 
                                 n_jobs=5, 
                                 criterion='gini',
                                 max_features='auto',
                                 min_samples_split = 50,
                                 class_weight='balanced', 
                                 verbose=1,
                                 random_state=RANDOM_STATE)

# train/test split
X_train, X_test, y_train, y_test = train_test_split_dd(X, y, train_size = 0.7, random_state=RANDOM_STATE)

# build pipeline
pipe = ParallelPostFit(Pipeline(steps=[
    ('preprocessor', preprocessor),
    ('classifier_', clone(clf_rfc))
]))

X_train = X_train.persist()
y_train = y_train.persist()

# Train
pipe.fit(X_train, y_train)

# Evaluate
X_test = X_test.persist()
y_test = y_test.persist()

print('computing ypred')
y_preds = pipe.predict(X_test).compute()

print('computing yprob')
y_probs = pipe.predict_proba(X_test).compute()
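Conceptually, ParallelPostFit does not change how the wrapped estimator predicts; it applies the already-fitted estimator to each partition independently. The following single-process sketch reproduces that idea with plain NumPy and scikit-learn by predicting over fixed-size row blocks. `predict_in_chunks` and `chunk_size` are hypothetical names used for illustration only; they are not part of dask-ml.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier


def predict_in_chunks(model, X, chunk_size):
    """Predict block by block over the rows of X and concatenate the results.

    Only one block of X is passed to the model at a time, mirroring how a
    partition-wise wrapper applies a fitted model to each partition.
    """
    parts = [
        model.predict(X[start:start + chunk_size])
        for start in range(0, X.shape[0], chunk_size)
    ]
    return np.concatenate(parts)


X, y = make_classification(n_samples=2000, n_features=10, random_state=0)
clf = RandomForestClassifier(n_estimators=20, random_state=0).fit(X, y)

y_chunked = predict_in_chunks(clf, X, chunk_size=300)
y_full = clf.predict(X)
print(np.array_equal(y_chunked, y_full))
```

Because each row's prediction depends only on that row, the concatenated result matches a single `predict` call. What chunking cannot shrink is the model itself: every block (or, on a cluster, every worker) needs a full copy of the fitted forest.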

Output:

computing ypred
distributed.batched - INFO - Batched Comm Closed: in <closed TCP>: ConnectionResetError: [Errno 104] Connection reset by peer
---------------------------------------------------------------------------
CancelledError                            Traceback (most recent call last)
<ipython-input-108-23f303f7584c> in <module>
      5 
      6 print('computing ypred')
----> 7 y_preds = [pipe.predict(X_test).compute() for pipe in pipes]
      8 
      9 print('computing yprob')

<ipython-input-108-23f303f7584c> in <listcomp>(.0)
      5 
      6 print('computing ypred')
----> 7 y_preds = [pipe.predict(X_test).compute() for pipe in pipes]
      8 
      9 print('computing yprob')

~/.conda/envs/boa/lib/python3.7/site-packages/dask/base.py in compute(self, **kwargs)
    164         dask.base.compute
    165         """
--> 166         (result,) = compute(self, traverse=False, **kwargs)
    167         return result
    168 

~/.conda/envs/boa/lib/python3.7/site-packages/dask/base.py in compute(*args, **kwargs)
    435     keys = [x.__dask_keys__() for x in collections]
    436     postcomputes = [x.__dask_postcompute__() for x in collections]
--> 437     results = schedule(dsk, keys, **kwargs)
    438     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    439 

~/.conda/envs/boa/lib/python3.7/site-packages/distributed/client.py in get(self, dsk, keys, restrictions, loose_restrictions, resources, sync, asynchronous, direct, retries, priority, fifo_timeout, actors, **kwargs)
   2593                     should_rejoin = False
   2594             try:
-> 2595                 results = self.gather(packed, asynchronous=asynchronous, direct=direct)
   2596             finally:
   2597                 for f in futures.values():

~/.conda/envs/boa/lib/python3.7/site-packages/distributed/client.py in gather(self, futures, errors, direct, asynchronous)
   1891                 direct=direct,
   1892                 local_worker=local_worker,
-> 1893                 asynchronous=asynchronous,
   1894             )
   1895 

~/.conda/envs/boa/lib/python3.7/site-packages/distributed/client.py in sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
    778         else:
    779             return sync(
--> 780                 self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
    781             )
    782 

~/.conda/envs/boa/lib/python3.7/site-packages/distributed/utils.py in sync(loop, func, callback_timeout, *args, **kwargs)
    346     if error[0]:
    347         typ, exc, tb = error[0]
--> 348         raise exc.with_traceback(tb)
    349     else:
    350         return result[0]

~/.conda/envs/boa/lib/python3.7/site-packages/distributed/utils.py in f()
    330             if callback_timeout is not None:
    331                 future = asyncio.wait_for(future, callback_timeout)
--> 332             result[0] = yield future
    333         except Exception as exc:
    334             error[0] = sys.exc_info()

~/.conda/envs/boa/lib/python3.7/site-packages/tornado/gen.py in run(self)
    733 
    734                     try:
--> 735                         value = future.result()
    736                     except Exception:
    737                         exc_info = sys.exc_info()

CancelledError: 
distributed.client - ERROR - Failed to reconnect to scheduler after 10.00 seconds, closing client
_GatheringFuture exception was never retrieved
future: <_GatheringFuture finished exception=CancelledError()>
concurrent.futures._base.CancelledError