我scipy.optimize.curve_fit
在 UDF 中有一个可能引发异常的调用。有没有办法处理来自 UDF 外部的异常?
我试图从 UDF 内部处理异常,但collect()
有时没有捕获到异常。
我试过:
import numpy as np
from pyspark.sql.functions import udf
from pyspark.sql.types import *
from scipy.optimize import curve_fit
def fsigmoid(x, x0, l, k):
return l / (1.0 + np.exp(-k*(x-x0)))
def curve_fitter(day_0, day_1, day_2, day_3, day_4, day_5, day_6):
try:
# Find sigmoid parameters
x = list(range(7))
y = [day_0, day_1, day_2, day_3, day_4, day_5, day_6]
param_bounds = [[0., 0, -10.], [6., 10., 10.]]
(x_0, l, k), _ = curve_fit(fsigmoid, x, y, method='dogbox', bounds=(param_bounds[0], param_bounds[1]), maxfev=100)
except IOError as e:
(x_0, l, k) = (-1, -1, -1)
return (float(x_0), float(l), float(k))
# Define UDF
udf_return_schema = StructType([
StructField("x_0", FloatType(), True),
StructField("l", FloatType(), True),
StructField("k", FloatType(), True)
])
udf_curve_fitter = udf(curve_fitter, udf_return_schema)
# Define df and call UDF
data = [(1.6710683580253483, 3.7414496594802005, 5.186749035232343, 8.552623021374485, 0.4000450281109358, 1.7832269020250069, 8.578459510083448)]
df = spark.createDataFrame(data, ['day_' + str(i) for i in range(7)])
df.select([udf_curve_fitter(df['day_0'], df['day_1'], df['day_2'], df['day_3'], df['day_4'], df['day_5'], df['day_6'])]).collect()
我希望udf_curve_fitter
返回(-1, -1, -1)
,而不是我得到:
---------------------------------------------------------------------------
Py4JJavaError Traceback (most recent call last)
<ipython-input-1-06f383278c7e> in <module>()
30 data = [(1.6710683580253483, 3.7414496594802005, 5.186749035232343, 8.552623021374485, 0.4000450281109358, 1.7832269020250069, 8.578459510083448)]
31 df = spark.createDataFrame(data, ['day_' + str(i) for i in range(7)])
---> 32 df.select([udf_curve_fitter(df['day_0'], df['day_1'], df['day_2'], df['day_3'], df['day_4'], df['day_5'], df['day_6'])]).collect()
[...]
Py4JJavaError: An error occurred while calling o130.collectToPython.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 1151 in stage 0.0 failed 4 times, most recent failure: Lost task 1151.3 in stage 0.0 (TID 1154, ip-10-0-32-85.ec2.internal, executor 8): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
File "/mnt/eider_environments/EiderPython/local/apollo/env/EiderPython/lib/python3.4/site-packages/pyspark-2.2.1-py3.4.egg/pyspark/worker.py", line 177, in main
process()
File "/mnt/eider_environments/EiderPython/local/apollo/env/EiderPython/lib/python3.4/site-packages/pyspark-2.2.1-py3.4.egg/pyspark/worker.py", line 172, in process
serializer.dump_stream(func(split_index, iterator), outfile)
File "/mnt/eider_environments/EiderPython/local/apollo/env/EiderPython/lib/python3.4/site-packages/pyspark-2.2.1-py3.4.egg/pyspark/serializers.py", line 220, in dump_stream
self.serializer.dump_stream(self._batched(iterator), stream)
File "/mnt/eider_environments/EiderPython/local/apollo/env/EiderPython/lib/python3.4/site-packages/pyspark-2.2.1-py3.4.egg/pyspark/serializers.py", line 138, in dump_stream
for obj in iterator:
File "/mnt/eider_environments/EiderPython/local/apollo/env/EiderPython/lib/python3.4/site-packages/pyspark-2.2.1-py3.4.egg/pyspark/serializers.py", line 209, in _batched
for item in iterator:
File "<string>", line 1, in <lambda>
File "/mnt/eider_environments/EiderPython/local/apollo/env/EiderPython/lib/python3.4/site-packages/pyspark-2.2.1-py3.4.egg/pyspark/worker.py", line 69, in <lambda>
return lambda *a: toInternal(f(*a))
File "<ipython-input-1-06f383278c7e>", line 16, in curve_fitter
File "/mnt/eider_environments/EiderPython/local/apollo/env/EiderPython/lib/python3.4/site-packages/scipy/optimize/minpack.py", line 750, in curve_fit
raise RuntimeError("Optimal parameters not found: " + res.message)
RuntimeError: Optimal parameters not found: The maximum number of function evaluations is exceeded.