I have a UDF that is too slow for large datasets, and I'm trying to improve execution time and scalability by leveraging pandas_udfs. All my searching and the official docs focus on the scalar and map variants, which is where I already am, but I haven't managed to extend this to the Series or pandas DataFrame variants. Can you point me in the right direction?
I want the work to execute in parallel; the current UDF approach is very slow because records are processed one at a time. I do have another working solution in Koalas, but I'd rather include this as part of a custom Transformer in a PySpark pipeline:
The (working) UDF approach is listed below:
import re

from pyspark import keyword_only
from pyspark.ml import Transformer
from pyspark.ml.param.shared import HasInputCol, HasOutputCol, Param, Params, TypeConverters
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql import DataFrame
import pyspark.sql.functions as F
from pyspark.sql.functions import col, udf
from pyspark.sql.types import ArrayType, StringType
def ngrams_udf(string, n=3):
    """Takes an input string, cleans it and converts it to ngrams.
    This function is focused on cleaning UK company names but can be made generic by removing the lines below."""
    string = str(string)
    string = string.lower()  # lower case
    string = string.encode("ascii", errors="ignore").decode()  # remove non-ASCII chars
    chars_to_remove = [")", "(", ".", "|", "[", "]", "{", "}", "'", "-"]
    rx = '[' + re.escape(''.join(chars_to_remove)) + ']'
    string = re.sub(rx, '', string)  # remove punctuation, brackets etc...
    string = string.replace('&', 'and')
    string = string.replace('public limited company', 'plc')  # longest phrase first, so 'limited' doesn't clobber it
    string = string.replace('limited', 'ltd')
    string = string.replace('united states of america', 'usa')
    string = string.replace('community interest company', 'cic')
    string = string.title()  # normalise case - capital at start of each word
    string = re.sub(' +', ' ', string).strip()  # collapse multiple spaces into a single space
    string = ' ' + string + ' '  # pad names for ngrams...
    ngrams = zip(*[string[i:] for i in range(n)])
    return [''.join(ngram) for ngram in ngrams]
# register the UDF with an array<string> return type
dummy_ngram_udf = udf(ngrams_udf, ArrayType(StringType()))
# call the UDF on a string column; it returns an array<string> column
df = df.withColumn(out_col, dummy_ngram_udf(col(in_col)))
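For context, this UDF is wired into the pipeline through a custom Transformer, roughly like the simplified sketch below (the NgramTransformer name is just illustrative, and I've cut the params boilerplate down to the standard HasInputCol/HasOutputCol mixins):

class NgramTransformer(Transformer, HasInputCol, HasOutputCol,
                       DefaultParamsReadable, DefaultParamsWritable):
    """Sketch: applies ngrams_udf to inputCol and writes the result to outputCol."""

    @keyword_only
    def __init__(self, inputCol=None, outputCol=None):
        super().__init__()
        kwargs = self._input_kwargs
        self._set(**kwargs)

    def _transform(self, df: DataFrame) -> DataFrame:
        return df.withColumn(self.getOutputCol(), dummy_ngram_udf(col(self.getInputCol())))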
I tried the following, but it doesn't map element-wise over the Series input and output... so the input and output vectors end up with different sizes...:
import pandas as pd
from pyspark.sql.functions import col, pandas_udf
@pandas_udf("string")
def ngrams_udf(string: pd.Series, n=3) -> pd.Series:
    """Takes an input string, cleans it and converts it to ngrams.
    This function is focused on cleaning UK company names but can be made generic by removing the lines below."""
    string = str(string)  # NOTE: this stringifies the whole Series at once, not each element
    string = string.lower()  # lower case
    string = string.encode("ascii", errors="ignore").decode()  # remove non-ASCII chars
    chars_to_remove = [")", "(", ".", "|", "[", "]", "{", "}", "'", "-"]
    rx = '[' + re.escape(''.join(chars_to_remove)) + ']'
    string = re.sub(rx, '', string)  # remove punctuation, brackets etc...
    string = string.replace('&', 'and')
    string = string.replace('public limited company', 'plc')
    string = string.replace('limited', 'ltd')
    string = string.replace('united states of america', 'usa')
    string = string.replace('community interest company', 'cic')
    string = string.title()  # normalise case - capital at start of each word
    string = re.sub(' +', ' ', string).strip()  # collapse multiple spaces into a single space
    string = ' ' + string + ' '  # pad names for ngrams...
    ngrams = zip(*[string[i:] for i in range(n)])
    return [''.join(ngram) for ngram in ngrams]
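From the docs, my best guess is that the Series variant should keep the cleaning logic per row and only batch at the Series level, something like the sketch below (untested at scale; ngrams_pandas_udf is a name I made up, and I'm assuming the Spark 3 type-hinted pandas_udf API):

import pandas as pd
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import ArrayType, StringType

@pandas_udf(ArrayType(StringType()))  # return type describes each row's value, not "string"
def ngrams_pandas_udf(names: pd.Series) -> pd.Series:
    # apply the plain per-row function element-wise, so the output Series
    # has the same length as the input batch that Arrow hands us
    return names.apply(ngrams_udf)

df = df.withColumn(out_col, ngrams_pandas_udf(col(in_col)))

My understanding is that the speed-up here comes from Arrow batch serialisation rather than from vectorising the Python logic itself. Is this the right direction, or is there a more idiomatic way to vectorise the string cleaning (e.g. with the pd.Series.str methods) instead of .apply?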