0

我正在尝试使用图形框架和 Pregel API 构建一棵树。

我有以下联系:

  src  |  dst  
   0   |   1
   0   |   7
   1   |   2
   1   |   5
   1   |   8
   2   |   4
   7   |   4
   7   |   5
   8   |   5

并需要以下输出作为数据框:

sequence
[0,1,2,4]
[0,1,8,5]
[0,1,5]
[0,7,4]
[0,7,5]

实际代码如下所示:

#DATA CREATION
raw_data = [
  ("0","1"),
  ("0","7"),
  ("1","2"),
  ("1","5"),
  ("1","8"),
  ("2","4"),
  ("7","4"),
  ("7","5"),
  ("8","5")]

schema = ["src","dst"]
data = spark.createDataFrame(data=raw_data, schema = schema)

from graphframes import GraphFrame

vertices=(data.select("src").union(data.select("dst")).distinct().withColumnRenamed('src', 'id'))
vertices=vertices.union(spark.createDataFrame(["10"], "string").toDF("id"))
edges=data
graph = GraphFrame(vertices, edges)

import pyspark.sql.functions as F

inDegrees=graph.inDegrees
outDegrees=graph.outDegrees

init_vertices=(vertices
  .join(outDegrees,on="id",how="left")
  .join(inDegrees,on="id",how="left")
.withColumn("nodeType",F.when(F.col("inDegree").isNull(),"root").otherwise(F.when(F.col("outDegree").isNull(),"leaf").otherwise("child"))))

gx = GraphFrame(init_vertices, edges)
# PREGEL API
vertColSchema = T.ArrayType( T.ArrayType( T.StringType(), True), True)

def sendMsgToDst(src, dst, dst_id):
  if src:
    src_tuple = [tuple(lst) for lst in src]
    dst_tuple = [tuple(lst) for lst in dst]
    if not set(src_tuple).issubset(set(dst_tuple)):
      return [i + [dst_id] for i in src] 
  return None

def vertexProgram(vd, msg):
  if msg:
    return msg
  return vd

sendMsgToDstUdf = F.udf(
    sendMsgToDst, vertColSchema
)

vertexProgramUdf = F.udf(
    vertexProgram, vertColSchema
)

start = ["0"]
tree=(gx.pregel
    .withVertexColumn("sequence",F.when(F.col("id").isin(start), F.array(F.array(F.col("id")))).otherwise(F.lit(F.array())),
      updateAfterAggMsgsExpr=vertexProgramUdf(
        F.col("sequence"), Pregel.msg()
      )
    )
    .sendMsgToDst(
      sendMsgToDstUdf(
        Pregel.src("sequence"), Pregel.dst("sequence"), Pregel.dst("id")
      )
    )
    .aggMsgs(F.collect_list(Pregel.msg()))
    .setMaxIter(10)
    .setCheckpointInterval(2)
    .run()
)
# RESULT
df = cycles.withColumn("sequence", F.explode("sequence"))
result = df.filter(F.col("nodeType")=="leaf").select("sequence")

sequence
["[[0, 1], 5]"]
["[[0, 7], 5]"]
["[[[0, 1], 8], 5]"]
["[[[0, 1], 2], 4]"]
["[[0, 7], 4]"]

我不知道为什么,但 dst_id 附加没有按预期工作。我尝试了许多不同的方法,但在使用列表列表时仍然出现相同的错误。

4

0 回答 0