您可以只打乱索引，然后使用 `tf.gather()` 提取与这些打乱后的索引相对应的值：
TF2.x(更新)
import tensorflow as tf
import numpy as np
# Two parallel tensors whose i-th elements belong together.
x = tf.convert_to_tensor(np.arange(5))
y = tf.convert_to_tensor(['a', 'b', 'c', 'd', 'e'])

# Draw ONE random permutation of the row positions 0..len(x)-1 and apply it
# to both tensors, so x[i] and y[i] are still paired after shuffling.
# NOTE(review): assumes x and y have the same length along axis 0.
num_rows = tf.shape(x)[0]
indices = tf.range(num_rows, dtype=tf.int32)
shuffled_indices = tf.random.shuffle(indices)
shuffled_x, shuffled_y = (tf.gather(t, shuffled_indices) for t in (x, y))

# Show both tensors before and after applying the shared permutation.
for stage, pair in (('before', (x, y)), ('after', (shuffled_x, shuffled_y))):
    print(stage)
    for name, tensor in zip(('x', 'y'), pair):
        print(name, tensor.numpy())

# Example output (the "after" order is random and differs between runs,
# but x and y are always permuted the same way):
# before
# x [0 1 2 3 4]
# y [b'a' b'b' b'c' b'd' b'e']
# after
# x [4 0 1 2 3]
# y [b'e' b'a' b'b' b'c' b'd']
TF1.x
import tensorflow as tf
import numpy as np
# Inputs: x is a batch of 1x1x1 arrays, y holds one int32 value per batch row.
x = tf.placeholder(tf.float32, (None, 1, 1, 1))
# `(None,)` is a 1-tuple: a rank-1 tensor of unknown length. The previous
# `(None)` is just `None`, which declares a tensor of completely unknown shape.
y = tf.placeholder(tf.int32, (None,))

# Build one random permutation of the batch indices and gather both tensors
# with it. Both gathers consume the same `shuffled_indices` op, so x and y are
# permuted identically as long as they are fetched in the same sess.run call.
# NOTE(review): assumes x and y have the same leading dimension. tf.gather
# raises on out-of-range indices on CPU but fills zeros on GPU.
indices = tf.range(start=0, limit=tf.shape(x)[0], dtype=tf.int32)
shuffled_indices = tf.random.shuffle(indices)
shuffled_x = tf.gather(x, shuffled_indices)
shuffled_y = tf.gather(y, shuffled_indices)
确保在同一次会话运行（同一个 `sess.run` 调用）中同时计算 `shuffled_x` 和 `shuffled_y`，否则它们可能会得到不同的索引顺序。
# Testing: feed a 3-row batch through the graph and check that the rows of
# x and the values in y are permuted together.
x_data = np.concatenate([np.zeros((1, 1, 1, 1)),
                         np.ones((1, 1, 1, 1)),
                         2*np.ones((1, 1, 1, 1))]).astype('float32')  # rows 0., 1., 2.
# Cast explicitly so the fed array matches the int32 dtype of placeholder `y`.
y_data = np.arange(4, 7, 1).astype('int32')  # values 4, 5, 6 for rows 0, 1, 2

print('Before shuffling:')
print('x:')
print(x_data.squeeze())
print('y:')
print(y_data)

with tf.Session() as sess:
    # Fetch both results in ONE sess.run call: the shuffle op is evaluated
    # once per call, so x and y share a single permutation. Fetching them in
    # separate calls could give them different orderings.
    x_res, y_res = sess.run([shuffled_x, shuffled_y],
                            feed_dict={x: x_data, y: y_data})
    print('After shuffling:')
    print('x:')
    print(x_res.squeeze())
    print('y:')
    print(y_res)
Before shuffling:
x:
[0. 1. 2.]
y:
[4 5 6]
After shuffling:
x:
[1. 2. 0.]
y:
[5 6 4]