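I'm trying to train a model with Dask-ML. My end goal is to run predictions on larger-than-memory datasets, so I'm using Dask's ParallelPostFit wrapper to train the model on a relatively small dataset (4 GB), expecting to predict on much larger dataframes later. I connect to a YARN cluster with 50 workers, load my data from Parquet into a Dask dataframe, build a pipeline, and train. Training works, but I run into problems when I try to evaluate on a held-out test set. With sklearn's LogisticRegression as the classifier, both training and prediction succeed. With sklearn's RandomForestClassifier with 100 estimators, the training step succeeds but prediction fails with the error below. During the predict compute step, I noticed that memory usage on my local machine starts to explode just before the disconnect error. When I reduce the number of RF estimators to 10, the prediction step runs successfully. Can anyone help me understand what is going on?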
My code (condensed):
from dask.distributed import Client
from dask_ml.wrappers import ParallelPostFit
from dask_yarn import YarnCluster
from sklearn.base import clone
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline

# Condensed: path_to_packed_conda_env, worker_env, RANDOM_STATE, X, y,
# preprocessor and train_test_split_dd are defined elsewhere.
cluster = YarnCluster(environment=path_to_packed_conda_env,
                      n_workers=50,
                      worker_vcores=10,
                      worker_env=worker_env,
                      worker_restarts=10,
                      scheduler_memory='10GiB',
                      scheduler_vcores=5,
                      worker_memory='20GiB')
cluster.adapt(minimum=50, maximum=100)

# connect client
client = Client(cluster)

# instantiate classifier
clf_rfc = RandomForestClassifier(n_estimators=100,
                                 n_jobs=5,
                                 criterion='gini',
                                 max_features='auto',
                                 min_samples_split=50,
                                 class_weight='balanced',
                                 verbose=1,
                                 random_state=RANDOM_STATE)

# train/test split
X_train, X_test, y_train, y_test = train_test_split_dd(
    X, y, train_size=0.7, random_state=RANDOM_STATE)

# build pipeline
pipe = ParallelPostFit(Pipeline(steps=[
    ('preprocessor', preprocessor),
    ('classifier_', clone(clf_rfc))
]))

# keep the training data in cluster memory
X_train = X_train.persist()
y_train = y_train.persist()

# Train
pipe.fit(X_train, y_train)

# Evaluate
X_test = X_test.persist()
y_test = y_test.persist()

print('computing ypred')
y_preds = pipe.predict(X_test).compute()

print('computing yprob')
y_probs = pipe.predict_proba(X_test).compute()
Output:
computing ypred
distributed.batched - INFO - Batched Comm Closed: in <closed TCP>: ConnectionResetError: [Errno 104] Connection reset by peer
---------------------------------------------------------------------------
CancelledError Traceback (most recent call last)
<ipython-input-108-23f303f7584c> in <module>
5
6 print('computing ypred')
----> 7 y_preds = [pipe.predict(X_test).compute() for pipe in pipes]
8
9 print('computing yprob')
<ipython-input-108-23f303f7584c> in <listcomp>(.0)
5
6 print('computing ypred')
----> 7 y_preds = [pipe.predict(X_test).compute() for pipe in pipes]
8
9 print('computing yprob')
~/.conda/envs/boa/lib/python3.7/site-packages/dask/base.py in compute(self, **kwargs)
164 dask.base.compute
165 """
--> 166 (result,) = compute(self, traverse=False, **kwargs)
167 return result
168
~/.conda/envs/boa/lib/python3.7/site-packages/dask/base.py in compute(*args, **kwargs)
435 keys = [x.__dask_keys__() for x in collections]
436 postcomputes = [x.__dask_postcompute__() for x in collections]
--> 437 results = schedule(dsk, keys, **kwargs)
438 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
439
~/.conda/envs/boa/lib/python3.7/site-packages/distributed/client.py in get(self, dsk, keys, restrictions, loose_restrictions, resources, sync, asynchronous, direct, retries, priority, fifo_timeout, actors, **kwargs)
2593 should_rejoin = False
2594 try:
-> 2595 results = self.gather(packed, asynchronous=asynchronous, direct=direct)
2596 finally:
2597 for f in futures.values():
~/.conda/envs/boa/lib/python3.7/site-packages/distributed/client.py in gather(self, futures, errors, direct, asynchronous)
1891 direct=direct,
1892 local_worker=local_worker,
-> 1893 asynchronous=asynchronous,
1894 )
1895
~/.conda/envs/boa/lib/python3.7/site-packages/distributed/client.py in sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
778 else:
779 return sync(
--> 780 self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
781 )
782
~/.conda/envs/boa/lib/python3.7/site-packages/distributed/utils.py in sync(loop, func, callback_timeout, *args, **kwargs)
346 if error[0]:
347 typ, exc, tb = error[0]
--> 348 raise exc.with_traceback(tb)
349 else:
350 return result[0]
~/.conda/envs/boa/lib/python3.7/site-packages/distributed/utils.py in f()
330 if callback_timeout is not None:
331 future = asyncio.wait_for(future, callback_timeout)
--> 332 result[0] = yield future
333 except Exception as exc:
334 error[0] = sys.exc_info()
~/.conda/envs/boa/lib/python3.7/site-packages/tornado/gen.py in run(self)
733
734 try:
--> 735 value = future.result()
736 except Exception:
737 exc_info = sys.exc_info()
CancelledError:
distributed.client - ERROR - Failed to reconnect to scheduler after 10.00 seconds, closing client
_GatheringFuture exception was never retrieved
future: <_GatheringFuture finished exception=CancelledError()>
concurrent.futures._base.CancelledError
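
Since prediction works with 10 estimators but not with 100, one thing that could be checked is how much larger the fitted forest becomes when serialized. The sketch below is only a diagnostic idea, not a confirmed cause: it assumes model size plays a role, which nothing above establishes. The synthetic data from `make_classification` and the helper name `pickled_size_bytes` are hypothetical stand-ins for the real pipeline and data.

```python
import pickle

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier


def pickled_size_bytes(n_estimators, X, y, random_state=0):
    """Fit a random forest and return the size of its pickled form in bytes."""
    model = RandomForestClassifier(
        n_estimators=n_estimators,
        min_samples_split=50,
        class_weight="balanced",
        random_state=random_state,
    )
    model.fit(X, y)
    return len(pickle.dumps(model))


# Small synthetic dataset (hypothetical; the real data is a 4 GB Dask dataframe).
X, y = make_classification(n_samples=2000, n_features=20, random_state=0)

size_10 = pickled_size_bytes(10, X, y)
size_100 = pickled_size_bytes(100, X, y)
print(f"10 trees: {size_10} bytes, 100 trees: {size_100} bytes")
```

Running the same measurement on the actual fitted pipeline (for example `len(pickle.dumps(pipe))`) would show how many bytes have to be shipped for the 100-estimator model compared with the 10-estimator one.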