假设我们有一个这样的数据集:
+---+-----+
| id|label|
+---+-----+
| 0| 0.0|
| 1| 1.0|
| 2| 0.0|
| 3| 1.0|
| 4| 0.0|
| 5| 1.0|
| 6| 0.0|
| 7| 1.0|
| 8| 0.0|
| 9| 1.0|
+---+-----+
这个数据集是完美平衡的,但这种方法也适用于不平衡的数据。
现在,让我们用额外的信息来扩充这个 DataFrame,这些信息对于决定哪些行应该去训练集很有用。步骤如下:
- 确定每个标签的多少示例应该是给定一些的训练集的一部分
ratio
。
- 打乱 DataFrame 的行。
- 使用窗口函数对 DataFrame 进行分区和排序
label
,然后使用 对每个标签的观察结果进行排名row_number()
。
我们最终得到以下数据框:
+---+-----+----------+
| id|label|row_number|
+---+-----+----------+
| 6| 0.0| 1|
| 2| 0.0| 2|
| 0| 0.0| 3|
| 4| 0.0| 4|
| 8| 0.0| 5|
| 9| 1.0| 1|
| 5| 1.0| 2|
| 3| 1.0| 3|
| 1| 1.0| 4|
| 7| 1.0| 5|
+---+-----+----------+
注意:行被打乱(参见:id
列中的随机顺序),按标签分区(参见:label
列)并排名。
假设我们想要进行 80% 的拆分。在这种情况下,我们希望四个1.0
标签和四个0.0
标签用于训练数据集,一个1.0
标签和一个0.0
标签用于测试数据集。我们在row_number
列中有这些信息,所以现在我们可以简单地在用户定义的函数中使用它(如果row_number
小于或等于四,则示例转到训练集)。
应用UDF后,得到的数据帧如下:
+---+-----+----------+----------+
| id|label|row_number|isTrainSet|
+---+-----+----------+----------+
| 6| 0.0| 1| true|
| 2| 0.0| 2| true|
| 0| 0.0| 3| true|
| 4| 0.0| 4| true|
| 8| 0.0| 5| false|
| 9| 1.0| 1| true|
| 5| 1.0| 2| true|
| 3| 1.0| 3| true|
| 1| 1.0| 4| true|
| 7| 1.0| 5| false|
+---+-----+----------+----------+
现在,要获得训练/测试数据,必须做:
val train = df.where(col("isTrainSet") === true)
val test = df.where(col("isTrainSet") === false)
对于一些非常大的数据集,这些排序和分区步骤可能会让人望而却步,因此我建议首先尽可能过滤数据集。物理计划如下:
== Physical Plan ==
*(3) Project [id#4, label#5, row_number#11, if (isnull(row_number#11)) null else UDF(label#5, row_number#11) AS isTrainSet#48]
+- Window [row_number() windowspecdefinition(label#5, label#5 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS row_number#11], [label#5], [label#5 ASC NULLS FIRST]
+- *(2) Sort [label#5 ASC NULLS FIRST, label#5 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(label#5, 200)
+- *(1) Project [id#4, label#5]
+- *(1) Sort [_nondeterministic#9 ASC NULLS FIRST], true, 0
+- Exchange rangepartitioning(_nondeterministic#9 ASC NULLS FIRST, 200)
+- LocalTableScan [id#4, label#5, _nondeterministic#9
这是完整的工作示例(使用 Spark 2.3.0 和 Scala 2.11.12 测试):
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.functions.{col, row_number, udf, rand}
class StratifiedTrainTestSplitter {
def getNumExamplesPerClass(ss: SparkSession, label: String, trainRatio: Double)(df: DataFrame): Map[Double, Long] = {
df.groupBy(label).count().createOrReplaceTempView("labelCounts")
val query = f"SELECT $label AS ratioLabel, count, cast(count * $trainRatio as long) AS trainExamples FROM labelCounts"
import ss.implicits._
ss.sql(query)
.select("ratioLabel", "trainExamples")
.map((r: Row) => r.getDouble(0) -> r.getLong(1))
.collect()
.toMap
}
def split(df: DataFrame, label: String, trainRatio: Double): DataFrame = {
val w = Window.partitionBy(col(label)).orderBy(col(label))
val rowNumPartitioner = row_number().over(w)
val dfRowNum = df.sort(rand).select(col("*"), rowNumPartitioner as "row_number")
dfRowNum.show()
val observationsPerLabel: Map[Double, Long] = getNumExamplesPerClass(df.sparkSession, label, trainRatio)(df)
val addIsTrainColumn = udf((label: Double, rowNumber: Int) => rowNumber <= observationsPerLabel(label))
dfRowNum.withColumn("isTrainSet", addIsTrainColumn(col(label), col("row_number")))
}
}
object StratifiedTrainTestSplitter {
def getDf(ss: SparkSession): DataFrame = {
val data = Seq(
(0, 0.0), (1, 1.0), (2, 0.0), (3, 1.0), (4, 0.0), (5, 1.0), (6, 0.0), (7, 1.0), (8, 0.0), (9, 1.0)
)
ss.createDataFrame(data).toDF("id", "label")
}
def main(args: Array[String]): Unit = {
val spark: SparkSession = SparkSession
.builder()
.config(new SparkConf().setMaster("local[1]"))
.getOrCreate()
val df = new StratifiedTrainTestSplitter().split(getDf(spark), "label", 0.8)
df.cache()
df.where(col("isTrainSet") === true).show()
df.where(col("isTrainSet") === false).show()
}
}
注意:Double
在这种情况下,标签是 s。如果您的标签是String
s,您将不得不在这里和那里切换类型。