Because of its distributed nature, Spark cannot guarantee that a value populated inside your function during one call will still be available to later calls, so you cannot fill something in an earlier call and rely on using it later (a short sketch after the list below illustrates this). There are two possible options.
- Since you are applying an inner join and the players df has the list of all distinct players, you can add the current_team column to this df before applying the join. If the players df is cached before the join, then the UDF will most likely be called only once per player. See the discussion here on why a UDF can be called multiple times per record.
- You can memoize getCurrentTeam.
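A minimal sketch of why such state is lost (this is not part of either option; it just assumes a running spark session, like the examples below do): a dict mutated inside a UDF lives in the executor's Python worker process, so the driver never sees the updates, and different executors each have their own copy, meaning nothing populated in one call is guaranteed to be visible elsewhere.

from pyspark.sql import functions as F
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

seen = {}  # defined on the driver; a pickled copy is shipped to the executors

@udf(StringType())
def remember(player_id):
    seen[player_id] = True  # mutates the executor's copy only
    return f"seen_{player_id}"

df = spark.createDataFrame([(1,), (2,)], ("player_id",))
df.withColumn("flag", remember(F.col("player_id"))).show()
print(seen)  # still {} on the driver: the executor-side updates never come back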
Working example - prefilling current_team
from pyspark.sql import functions as F
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

events_data = [(1, 1, 1, 10), (1, 2, 1, 20), (1, 3, 1, 30), (2, 3, 1, 30), (2, 1, 1, 10), (2, 2, 1, 20)]
players_data = [(1, "Player1", "Nat"), (2, "Player2", "Nat"), (3, "Player3", "Nat")]

events = spark.createDataFrame(events_data, ("event_id", "player_id", "match_id", "impact_score")).repartition(3)
players = spark.createDataFrame(players_data, ("player_id", "player_name", "nationality")).repartition(3)

@udf(StringType())
def getCurrentTeam(player_id):
    # stand-in for the real lookup (e.g. a network call)
    return f"player_{player_id}_team"

# add current_team to the players df and cache it before the join,
# so the UDF runs (at most) once per player when the cache is materialized
players_with_current_team = players.withColumn("current_team", getCurrentTeam(F.col("player_id"))).cache()

events.join(players_with_current_team, ["player_id"]).show()
Output
+---------+--------+--------+------------+-----------+-----------+-------------+
|player_id|event_id|match_id|impact_score|player_name|nationality| current_team|
+---------+--------+--------+------------+-----------+-----------+-------------+
| 2| 2| 1| 20| Player2| Nat|player_2_team|
| 2| 1| 1| 20| Player2| Nat|player_2_team|
| 3| 2| 1| 30| Player3| Nat|player_3_team|
| 3| 1| 1| 30| Player3| Nat|player_3_team|
| 1| 2| 1| 10| Player1| Nat|player_1_team|
| 1| 1| 1| 10| Player1| Nat|player_1_team|
+---------+--------+--------+------------+-----------+-----------+-------------+
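To check that the prefill-and-cache approach really limits the lookups, one option (a sketch reusing the players and events dataframes from above; the names prefill_call_counter and getCurrentTeamCounted are made up here) is to count UDF invocations with an accumulator, the same trick used in the memoization example below:

prefill_call_counter = spark.sparkContext.accumulator(0)

@udf(StringType())
def getCurrentTeamCounted(player_id):
    prefill_call_counter.add(1)  # incremented on the executors for every invocation
    return f"player_{player_id}_team"

players_with_current_team = players.withColumn("current_team", getCurrentTeamCounted(F.col("player_id"))).cache()
events.join(players_with_current_team, ["player_id"]).show()
print(prefill_call_counter.value)  # roughly one call per player; task retries can inflate the count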
Working example - memoization
I used a Python dict to mimic a cache and an accumulator to count the number of mimicked network calls.
from pyspark.sql import functions as F
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
import time

events_data = [(1, 1, 1, 10), (1, 2, 1, 20), (1, 3, 1, 30), (2, 3, 1, 30), (2, 1, 1, 10), (2, 2, 1, 20)]
players_data = [(1, "Player1", "Nat"), (2, "Player2", "Nat"), (3, "Player3", "Nat")]

events = spark.createDataFrame(events_data, ("event_id", "player_id", "match_id", "impact_score")).repartition(3)
players = spark.createDataFrame(players_data, ("player_id", "player_name", "nationality")).repartition(3)

players_events_joined = events.join(players, ["player_id"])

memoized_call_counter = spark.sparkContext.accumulator(0)

def memoize_call():
    cache = {}
    def getCurrentTeam(player_id):
        global memoized_call_counter
        cached_value = cache.get(player_id, None)
        if cached_value is not None:
            return cached_value
        # sleep to mimic a network call
        time.sleep(1)
        # increment the counter every time the value can't be looked up in the cache
        memoized_call_counter.add(1)
        cache[player_id] = f"player_{player_id}_team"
        return cache[player_id]
    return getCurrentTeam

getCurrentTeam_udf = udf(memoize_call(), StringType())

players_events_joined.withColumn("current_team", getCurrentTeam_udf(F.col("player_id"))).show()
Output
+---------+--------+--------+------------+-----------+-----------+-------------+
|player_id|event_id|match_id|impact_score|player_name|nationality| current_team|
+---------+--------+--------+------------+-----------+-----------+-------------+
| 2| 2| 1| 20| Player2| Nat|player_2_team|
| 2| 1| 1| 20| Player2| Nat|player_2_team|
| 3| 2| 1| 30| Player3| Nat|player_3_team|
| 3| 1| 1| 30| Player3| Nat|player_3_team|
| 1| 2| 1| 10| Player1| Nat|player_1_team|
| 1| 1| 1| 10| Player1| Nat|player_1_team|
+---------+--------+--------+------------+-----------+-----------+-------------+
>>> memoized_call_counter.value
3
Since there are 3 unique players in total, the logic from time.sleep(1) onwards was executed only three times. The number of calls depends on the number of workers, because the cache is not shared between workers. Since I ran the example in local mode (with a single worker), the number of calls equals the number of unique players; with more workers each one keeps its own cache, so the count can grow up to (unique players) x (workers).
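If the lookups should also be shared across workers, a variant of the prefill idea is to resolve every distinct player once on the driver and broadcast the resulting mapping. This is only a sketch: it assumes the real lookup can be performed from the driver (the dict comprehension below stands in for it), it reuses players and players_events_joined from above, and the name lookupCurrentTeam is made up.

# resolve each distinct player once on the driver, then broadcast the mapping
# so every executor reads from the same shared copy
distinct_ids = [row.player_id for row in players.select("player_id").distinct().collect()]
team_map = spark.sparkContext.broadcast({pid: f"player_{pid}_team" for pid in distinct_ids})

@udf(StringType())
def lookupCurrentTeam(player_id):
    return team_map.value.get(player_id)  # no per-row network call on the executors

players_events_joined.withColumn("current_team", lookupCurrentTeam(F.col("player_id"))).show()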