7

如上。我尝试了那些无济于事:

tf.random.shuffle( (a,b) )
tf.random.shuffle( zip(a,b) )

我曾经连接它们并进行改组,然后取消连接/解包。但现在我处于(a)是 4D 秩张量而(b)是 1D 的情况,所以,没有办法连接。

我还尝试将种子参数提供给 shuffle 方法,以便它重现相同的 shuffle 并且我使用它两次 => 失败。还尝试用随机打乱的数字范围对自己进行改组,但 TF 在花式索引和东西 ==> 失败时不如 numpy 灵活。

我现在正在做的是,将所有内容转换回 numpy 然后使用 sklearn 中的 shuffle 然后通过重铸回到张量。这纯粹是愚蠢的方式。这应该发生在图表内。

4

1 回答 1

15

您可以只打乱索引,然后用于tf.gather()提取与这些打乱索引相对应的值:

TF2.x(更新)

import tensorflow as tf
import numpy as np

x = tf.convert_to_tensor(np.arange(5))
y = tf.convert_to_tensor(['a', 'b', 'c', 'd', 'e'])

indices = tf.range(start=0, limit=tf.shape(x)[0], dtype=tf.int32)
shuffled_indices = tf.random.shuffle(indices)

shuffled_x = tf.gather(x, shuffled_indices)
shuffled_y = tf.gather(y, shuffled_indices)

print('before')
print('x', x.numpy())
print('y', y.numpy())

print('after')
print('x', shuffled_x.numpy())
print('y', shuffled_y.numpy())
# before
# x [0 1 2 3 4]
# y [b'a' b'b' b'c' b'd' b'e']
# after
# x [4 0 1 2 3]
# y [b'e' b'a' b'b' b'c' b'd']

TF1.x

import tensorflow as tf
import numpy as np

x = tf.placeholder(tf.float32, (None, 1, 1, 1))
y = tf.placeholder(tf.int32, (None))

indices = tf.range(start=0, limit=tf.shape(x)[0], dtype=tf.int32)
shuffled_indices = tf.random.shuffle(indices)

shuffled_x = tf.gather(x, shuffled_indices)
shuffled_y = tf.gather(y, shuffled_indices)

确保您在同一会话运行中进行了shuffled_x计算。shuffled_y否则他们可能会得到不同的索引排序。

# Testing
x_data = np.concatenate([np.zeros((1, 1, 1, 1)),
                         np.ones((1, 1, 1, 1)),
                         2*np.ones((1, 1, 1, 1))]).astype('float32')
y_data = np.arange(4, 7, 1)

print('Before shuffling:')
print('x:')
print(x_data.squeeze())
print('y:')
print(y_data)

with tf.Session() as sess:
  x_res, y_res = sess.run([shuffled_x, shuffled_y],
                          feed_dict={x: x_data, y: y_data})
  print('After shuffling:')
  print('x:')
  print(x_res.squeeze())
  print('y:')
  print(y_res)
Before shuffling:
x:
[0. 1. 2.]
y:
[4 5 6]
After shuffling:
x:
[1. 2. 0.]
y:
[5 6 4]
于 2019-06-13T08:55:47.340 回答