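I have a UDF that is slow on large datasets, and I'm trying to improve its execution time and scalability by using pandas_udfs. Everything I've found, including the official documentation, focuses on the scalar and map approaches, which I already use, but I haven't managed to extend my function to work on a pandas Series or DataFrame. Can you point me in the right direction?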

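I want this to run in parallel; the current UDF is very slow because it processes records one at a time. I do have another solution in Koalas, but I'd rather include this as part of a custom transformer in a PySpark pipeline: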


import re

from pyspark import keyword_only
from pyspark.ml import Transformer
from pyspark.ml.param.shared import HasInputCol, HasOutputCol, Param, Params, TypeConverters
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql import DataFrame
from pyspark.sql.types import ArrayType, StringType
import pyspark.sql.functions as F
from pyspark.sql.functions import pandas_udf, udf

def ngrams_udf(string, n=3):
    """Takes an input string, cleans it and converts to ngrams. 
    This script is focussed on cleaning UK company names but can be made generic by removing lines below"""
    string = str(string)
    string = string.lower() # lower case
    string = string.encode("ascii", errors="ignore").decode() #remove non ascii chars
    chars_to_remove = [")","(",".","|","[","]","{","}","'","-"]
    rx = '[' + re.escape(''.join(chars_to_remove)) + ']' #remove punc, brackets etc...
    string = re.sub(rx, '', string)
    string = string.replace('&', 'and')
    # replace the longer phrase first, otherwise 'limited' -> 'ltd' stops it from ever matching
    string = string.replace('public limited company', 'plc')
    string = string.replace('limited', 'ltd')
    string = string.replace('united states of america', 'usa')
    string = string.replace('community interest company', 'cic')
    string = string.title() # normalise case - capital at start of each word
    string = re.sub(' +',' ',string).strip() # get rid of multiple spaces and replace with a single
    string = ' '+ string +' ' # pad names for ngrams...
    ngrams = zip(*[string[i:] for i in range(n)])
    return [''.join(ngram) for ngram in ngrams]

# register UDF
dummy_ngram_udf = udf(ngrams_udf, ArrayType(StringType()))
# call the UDF on a string column; returns an array-typed column
# (df, in_col and out_col come from the calling transformer)
df = df.withColumn(out_col, dummy_ngram_udf(F.col(in_col)))


import re

import pandas as pd
from pyspark.sql.functions import col, pandas_udf, struct

def ngrams_udf(string: pd.Series, n=3) -> pd.Series:
    """Takes an input string, cleans it and converts to ngrams. 
    This script is focussed on cleaning UK company names but can be made generic by removing lines below"""
    string = str(string)
    string = string.lower() # lower case
    string = string.encode("ascii", errors="ignore").decode() #remove non ascii chars
    chars_to_remove = [")","(",".","|","[","]","{","}","'","-"]
    rx = '[' + re.escape(''.join(chars_to_remove)) + ']' #remove punc, brackets etc...
    string = re.sub(rx, '', string)
    string = string.replace('&', 'and')
    string = string.replace('limited', 'ltd')
    string = string.replace('public limited company', 'plc')
    string = string.replace('united states of america', 'usa')
    string = string.replace('community interest company', 'cic')
    string = string.title() # normalise case - capital at start of each word
    string = re.sub(' +',' ',string).strip() # get rid of multiple spaces and replace with a single
    string = ' '+ string +' ' # pad names for ngrams...
    ngrams = zip(*[string[i:] for i in range(n)])
    return [''.join(ngram) for ngram in ngrams]

1 Answer


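In general, avoid Python UDFs whenever the same logic can be expressed with Spark's built-in functions. A pandas_udf is faster than a plain row-at-a-time UDF because it exchanges data with Python in Arrow batches, but native Spark functions should almost always be faster still: they run inside the JVM, are optimized by Catalyst, and never ship data to a Python worker at all.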

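Now for your question: a Series-to-Series pandas_udf receives a whole batch of rows as a pd.Series (and must return a pd.Series of the same length), so you need to adapt your code. Your variable string is no longer a single string but a Series, so str(string) turns the entire batch into one string; use the vectorized .str accessor methods instead.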
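A minimal pandas-only illustration of that difference (no Spark needed). The two company names are hypothetical stand-ins for one batch that Spark would hand to the UDF: str() on the Series produces a single string, while the .str accessor works element by element.

```python
import pandas as pd

# Inside a Series-to-Series pandas_udf, each call receives a whole batch of rows
# as one pd.Series; two hypothetical company names stand in for such a batch here.
batch = pd.Series(["Acme Limited", "Foo & Bar (UK)"])

# str() on the Series yields ONE string (the Series repr, index and dtype included),
# which is what str(string) does in the attempted pandas version.
whole_batch = str(batch)
print(whole_batch)

# The .str accessor applies the string method to every element instead,
# returning a Series of the same length.
per_row = batch.str.lower()
print(per_row.tolist())  # ['acme limited', 'foo & bar (uk)']
```

The same swap applies to each cleaning step: string.lower() becomes string.str.lower(), re.sub(...) becomes string.str.replace(..., regex=True), and so on.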

import re

import pandas as pd
from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, StringType


@F.pandas_udf(ArrayType(StringType()))
def ngrams_udf(string: pd.Series, n: pd.Series) -> pd.Series:
    """Takes a batch of input strings, cleans them and converts each one to ngrams.
    This script is focussed on cleaning UK company names but can be made generic by removing lines below"""

    n = int(n.iloc[0])  # n arrives as a column (F.lit(3)), so read the constant from the first row
    string = string.astype(str)  # like str(x) in the row-at-a-time UDF; nulls become 'None'
    string = string.str.lower()  # lower case
    string = string.str.encode("ascii", errors="ignore").str.decode("ascii")  # remove non-ASCII chars

    chars_to_remove = [")", "(", ".", "|", "[", "]", "{", "}", "'", "-"]
    rx = '[' + re.escape(''.join(chars_to_remove)) + ']'  # remove punctuation, brackets etc.
    string = string.str.replace(rx, '', regex=True)

    string = string.str.replace('&', 'and', regex=False)
    # replace the longer phrase before 'limited', otherwise it can never match
    string = string.str.replace('public limited company', 'plc', regex=False)
    string = string.str.replace('limited', 'ltd', regex=False)
    string = string.str.replace('united states of america', 'usa', regex=False)
    string = string.str.replace('community interest company', 'cic', regex=False)
    string = string.str.title()  # normalise case - capital at start of each word

    # collapse runs of spaces into a single space
    string = string.str.replace(' +', ' ', regex=True).str.strip()

    string = ' ' + string + ' '  # pad names for ngrams

    # build the character ngrams for each row
    return string.apply(lambda x: [''.join(ngram) for ngram in zip(*[x[i:] for i in range(n)])])


df = df.withColumn("ngrams", ngrams_udf(F.col("company"), F.lit(3)))
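Because everything inside a Series-to-Series UDF body is plain pandas, one way to check the cleaning logic without a Spark session is to keep it in an undecorated helper and call it on a small pd.Series. This is a sketch under that assumption: the helper name clean_ngrams and the table-driven REPLACEMENTS list are hypothetical, and the steps follow the cleaning rules from the question (with 'public limited company' handled before 'limited').

```python
import re

import pandas as pd

# Characters stripped from names, as in the question's UDF.
CHARS_TO_REMOVE = [")", "(", ".", "|", "[", "]", "{", "}", "'", "-"]
RX = "[" + re.escape("".join(CHARS_TO_REMOVE)) + "]"

# Ordered literal replacements: longer phrases first so 'limited' does not pre-empt them.
REPLACEMENTS = [
    ("&", "and"),
    ("public limited company", "plc"),
    ("limited", "ltd"),
    ("united states of america", "usa"),
    ("community interest company", "cic"),
]


def clean_ngrams(names: pd.Series, n: int = 3) -> pd.Series:
    """Hypothetical plain-pandas helper: clean each name, then split it into character ngrams."""
    s = names.astype(str).str.lower()
    s = s.str.encode("ascii", errors="ignore").str.decode("ascii")  # drop non-ASCII chars
    s = s.str.replace(RX, "", regex=True)  # drop punctuation and brackets
    for old, new in REPLACEMENTS:
        s = s.str.replace(old, new, regex=False)
    s = s.str.title()
    s = s.str.replace(" +", " ", regex=True).str.strip()  # collapse repeated spaces
    s = " " + s + " "  # pad so first/last characters appear in edge ngrams
    return s.apply(lambda x: ["".join(g) for g in zip(*[x[i:] for i in range(n)])])


result = clean_ngrams(pd.Series(["Acme Limited", "ACME  (UK) plc"]))
print(result[0])  # [' Ac', 'Acm', 'cme', 'me ', 'e L', ' Lt', 'Ltd', 'td ']
```

Inside Spark, a function decorated with pandas_udf("array<string>") could then simply return clean_ngrams(string, int(n.iloc[0])); keeping the logic undecorated makes it easy to unit-test with plain pandas before running it on a cluster.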
answered 2022-01-26T15:31:11.803