0

TensorFlow Swift 中是否有等效的tf.repeat。我似乎无法在主要 api 或 .Raw 函数的任何地方找到它。

谢谢

4

2 回答 2

0

如果我没记错的话,tiled(multiples:)张量应该复制tf.repeat的功能。例如,

import TensorFlow
let tensor = Tensor<Float>([1.0, 2.0, 3.0])
let repeatTensor = tensor.tiled(multiples: [2])

将产生一个repeatTensor[1.0, 2.0, 3.0, 1.0, 2.0, 3.0]

于 2020-08-25T19:56:49.060 回答
0

到目前为止我发现的最好方法是使用 _Raw.raggedRange。

func rowIdsFrom(lengths: Tensor<Int32>) -> Tensor<Int32> {
        let starts = Tensor<Int32>(rangeFrom: Int32(0), to: lengths.shape[0], stride: 1)
        let (_, indices1) : (Tensor<Int32>, Tensor<Int32>) = _Raw.raggedRange(starts: starts, limits: starts + lengths, deltas: Tensor.one)
        let (_, indices2) : (Tensor<Int32>, Tensor<Int32>) = _Raw.raggedRange(starts: Tensor.zero, limits: lengths, deltas: Tensor.one)
        return indices1 - indices2
}
func repeat(values: Tensor<Scalar>, lengths: Tensor<Int32>) -> Tensor<Scalar> {
        return values.gathering(atIndices: rowIdsFrom(lengths: lengths))
}
于 2020-09-04T18:49:17.790 回答