考虑以下代码,它使用ND4J 库来创建“moons”测试数据集的更简单版本:
val n = 100
val n1: Int = n/2
val n2: Int = n-n1
val outerX = Nd4j.getExecutioner.execAndReturn(new Cos(Nd4j.linspace(0, Math.PI, n1)))
val outerY = Nd4j.getExecutioner.execAndReturn(new Sin(Nd4j.linspace(0, Math.PI, n1)))
val innerX = Nd4j.getExecutioner.execAndReturn(new Cos(Nd4j.linspace(0, Math.PI, n2))).mul(-1).add(1)
val innerY = Nd4j.getExecutioner.execAndReturn(new Sin(Nd4j.linspace(0, Math.PI, n2))).mul(-1).add(1)
val X: INDArray = Nd4j.vstack(
Nd4j.concat(1, outerX, innerX), // 1 x n
Nd4j.concat(1, outerY, innerY) // 1 x n
) // 2 x n
val y: INDArray = Nd4j.hstack(
Nd4j.zeros(n1), // 1 x n1
Nd4j.ones(n2) // 1 x n2
) // 1 x n
println(s"# y shape: ${y.shape().toList}") // 1x100
println(s"# y data length: ${y.data().length()}") // 100
println(s"# X shape: ${X.shape().toList}") // 2x100
println(s"# X row 0 shape: ${X.getRow(0).shape().toList}") // 1x100
println(s"# X row 1 shape: ${X.getRow(1).shape().toList}") // 1x100
println(s"# X row 0 data length: ${X.getRow(0).data().length()}") // 200 <- !
println(s"# X row 1 data length: ${X.getRow(1).data().length()}") // 100
令人惊讶的是,倒数第二行X.getRow(0).data().length()
是 200 而不是 100。经过检查,这是因为返回的结构data()
包含整个矩阵,即两行连接在一起。
如何将 X 矩阵的实际第一行放入 Java(或 Scala)List
?我可以只取 200 个元素的“第一行”中的前 100 个项目,但这似乎不太优雅。