0

使用 PySpark,我正在尝试向现有数据框中添加一个新列,其中新列中的条目表示最接近现有列的 bin 值。在我将在下面显示的示例中,numpy 数组bucket_array表示箱(桶)。

PySpark 代码的相关部分,我将很快提到其错误,如下所示:

#Function for finding nearest bucket
def find_nearest(value, bin_array):
    bin_array = np.array(list(bin_array))
    value = float(value)
    idx = np.argmin(np.abs(bin_array - value))
    return float(bin_array[idx])
def metric_analyze(entity_peer_labeled_df, metric, delta_weeks, normalize):
    # delta_weeks = 1
    # normalize = True
    # metric : string which denotes column name
    # entity_peer_labeled_df : some Pyspark dataframe which has a column titled "pct_difference"

    bucket_array = np.arange(-1000, 1000, 5)

    udf_nearest_bin = F.udf(find_nearest, T.FloatType())    
    bucket_df = ( entity_pct_metric_df.withColumn("bucket_array", 
                              F.array(*[F.lit(i) for i in bucket_array])) ).withColumn( "pct_diff_{}_bucket".format(metric) , 
                                                                                       udf_nearest_bin("pct_difference", "bucket_array") )

    bucket_df.show() 

当我在 Jupyter notebook 中运行上述代码时,它运行良好,并且我能够看到数据框bucket_df

同样,当我将上述代码保存为单独的 python 函数时,将其导入我的 Jupyter 笔记本,然后最后执行它,我得到错误。我注意到错误发生在行bucket_df.show()。该错误的一部分如下所示:

/mnt1/jupyter/notebooks/username/custom_function.py in metric_analyze(entity_peer_labeled_df, metric, delta_weeks, normalize)
    100                                                                                        udf_nearest_bin("pct_difference", "bucket_array") )
    101 
--> 102     bucket_df.show()

/usr/lib/spark/python/pyspark/sql/dataframe.py in show(self, n, truncate, vertical)
    376         """
    377         if isinstance(truncate, bool) and truncate:
--> 378             print(self._jdf.showString(n, 20, vertical))
    379         else:
    380             print(self._jdf.showString(n, int(truncate), vertical))

/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py in __call__(self, *args)
   1255         answer = self.gateway_client.send_command(command)
   1256         return_value = get_return_value(
-> 1257             answer, self.gateway_client, self.target_id, self.name)
   1258 
   1259         for temp_arg in temp_args:

完整的错误可以在这里找到。

当我用 替换该行时bucket.show()print( bucket.count() )我看不到任何错误并且它运行良好(即使我将上述代码用作单独的函数)。

entity_pct_metric_df下面给出一个例子:

+--------------------+----------+-------------------+-------------------+------------------------------+--------------+
|           entity_id|. category|         sampled_ts|         some_score|         some_score_prev_value|pct_difference|
+--------------------+----------+-------------------+-------------------+------------------------------+--------------+
|abccccccccccccccc...|         A|2017-12-03 00:00:00|                192|                           824|        -632.0|
|defffffffffffffff...|         A|2017-12-10 00:00:00|                515|                           192|         323.0|
|ghiiiiiiiiiiiiiii...|         A|2017-12-17 00:00:00|                494|                           515|         -21.0|
+--------------------+----------+-------------------+-------------------+------------------------------+--------------+

如何解决上述错误?

4

0 回答 0