寻找专业知识来指导我解决以下问题。
背景:
- 我正在尝试使用受此示例启发的基本 PySpark 脚本
- 作为部署基础架构,我使用 Google Cloud Dataproc 集群。
- 我的代码中的基石是此处记录的函数“recommendProductsForUsers” ,它为我提供了模型中所有用户的顶级 X 产品
我遇到的问题
ALS.Train 脚本运行平稳,并且在 GCP 上可以很好地扩展(轻松超过 100 万客户)。
但是,应用预测:即使用函数“PredictAll”或“recommendProductsForUsers”,根本无法扩展。我的脚本对于一个小数据集(<100 个客户,<100 个产品)运行顺利。但是,当将其扩展到与业务相关的规模时,我无法对其进行扩展(例如,>50k 客户和 >10k 产品)
然后我得到的错误如下:
16/08/16 14:38:56 WARN org.apache.spark.scheduler.TaskSetManager: Lost task 22.0 in stage 411.0 (TID 15139, productrecommendation-high-w-2.c.main-nova-558.internal): java.lang.StackOverflowError at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1942) at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1808) at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1353) at java.io.ObjectInputStream.readObject(ObjectInputStream.java:373) at scala.collection.immutable.$colon$colon.readObject(List.scala:362) at sun.reflect.GeneratedMethodAccessor11.invoke(Unknown Source) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1058) at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1909) at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1808) at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1353) at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2018) at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1942) at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1808) at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1353) at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2018) at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1942) at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1808) at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1353) at java.io.ObjectInputStream.readObject(ObjectInputStream.java:373) at scala.collection.immutable.$colon$colon.readObject(List.scala:362)
我什至得到了一个 300 GB 的集群(1 个 108 GB 的主节点 + 2 个 108 GB RAM 的节点)来尝试运行它;它适用于 50k 客户,但不适用于更多客户
野心是建立一个可以为超过 80 万客户运行的设置
细节
失败的代码行
predictions = model.recommendProductsForUsers(10).flatMap(lambda p: p[1]).map(lambda p: (str(p[0]), str(p[1]), float(p[2])))
pprint.pprint(predictions.take(10))
schema = StructType([StructField("customer", StringType(), True), StructField("sku", StringType(), True), StructField("prediction", FloatType(), True)])
dfToSave = sqlContext.createDataFrame(predictions, schema).dropDuplicates()
你建议如何进行?我觉得脚本末尾的“合并”部分(即当我将其写入 dfToSave 时)会导致错误;有没有办法绕过这个并部分保存?