2

寻找专业知识来指导我解决以下问题。

背景:

我遇到的问题

  • ALS.Train 脚本运行平稳,并且在 GCP 上可以很好地扩展(轻松超过 100 万客户)。

  • 但是,应用预测:即使用函数“PredictAll”或“recommendProductsForUsers”,根本无法扩展。我的脚本对于一个小数据集(<100 个客户,<100 个产品)运行顺利。但是,当将其扩展到与业务相关的规模时,我无法对其进行扩展(例如,>50k 客户和 >10k 产品)

  • 然后我得到的错误如下:

     16/08/16 14:38:56 WARN org.apache.spark.scheduler.TaskSetManager:
       Lost task 22.0 in stage 411.0 (TID 15139,
       productrecommendation-high-w-2.c.main-nova-558.internal):
       java.lang.StackOverflowError
            at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1942)
            at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1808)
            at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1353)
            at java.io.ObjectInputStream.readObject(ObjectInputStream.java:373)
            at scala.collection.immutable.$colon$colon.readObject(List.scala:362)
            at sun.reflect.GeneratedMethodAccessor11.invoke(Unknown Source)
            at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
            at java.lang.reflect.Method.invoke(Method.java:498)
            at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1058)
            at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1909)
            at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1808)
            at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1353)
            at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2018)
            at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1942)
            at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1808)
            at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1353)
            at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2018)
            at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1942)
            at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1808)
            at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1353)
            at java.io.ObjectInputStream.readObject(ObjectInputStream.java:373)
            at scala.collection.immutable.$colon$colon.readObject(List.scala:362)
    
  • 我什至得到了一个 300 GB 的集群(1 个 108 GB 的主节点 + 2 个 108 GB RAM 的节点)来尝试运行它;它适用于 50k 客户,但不适用于更多客户

  • 野心是建立一个可以为超过 80 万客户运行的设置

细节

失败的代码行

predictions = model.recommendProductsForUsers(10).flatMap(lambda p: p[1]).map(lambda p: (str(p[0]), str(p[1]), float(p[2])))
pprint.pprint(predictions.take(10))
schema = StructType([StructField("customer", StringType(), True), StructField("sku", StringType(), True), StructField("prediction", FloatType(), True)])
dfToSave = sqlContext.createDataFrame(predictions, schema).dropDuplicates()

你建议如何进行?我觉得脚本末尾的“合并”部分(即当我将其写入 dfToSave 时)会导致错误;有没有办法绕过这个并部分保存?

4

1 回答 1

2

从堆栈跟踪来看,这似乎与使用 ALS 训练时 Spark 给出 StackOverflowError的问题相同

基本上,Spark 递归地表达 RDD 沿袭,因此当在迭代工作负载的过程中没有对事物进行惰性评估时,您最终会得到深度嵌套的对象。调用 sc.setCheckpointDir 并调整检查点间隔将减少此 RDD 沿袭的长度。

于 2016-08-16T18:47:02.960 回答