给定数据集中的一些数据(或张量)
tensor = tf.constant([1, 2, 3, 4, 5, 6, 7])
我需要通过绘制(比如)替换来创建N
一批元组。一个示例小批量可能是M
4 x 3
[[1 2 3], [3, 4 5], [2, 3, 4], [5, 6, 7]]
目的是避免以这种形式创建数据集
[[1, 2, 3]
[2, 3, 4]
[4, 5, 6]
]
因为大量的冗余。当我将新的小批量输入到训练过程中时,应该即时创建这些批次。
给定数据集中的一些数据(或张量)
tensor = tf.constant([1, 2, 3, 4, 5, 6, 7])
我需要通过绘制(比如)替换来创建N
一批元组。一个示例小批量可能是M
4 x 3
[[1 2 3], [3, 4 5], [2, 3, 4], [5, 6, 7]]
目的是避免以这种形式创建数据集
[[1, 2, 3]
[2, 3, 4]
[4, 5, 6]
]
因为大量的冗余。当我将新的小批量输入到训练过程中时,应该即时创建这些批次。
我在这里找到了一种方法,您认为这是最佳方法吗?还是以某种方式直接部署队列更好?
此代码基于上面的链接
import tensorflow as tf
import numpy as np
def gen_batch():
# compute number of batches to emit
num_of_batches = round(((len(sequence) - batch_size) / stride))
# emit batches
for i in range(0, num_of_batches * stride, stride):
result = np.array(sequence[i:i + batch_size])
yield result
sequence = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
batch_size = 3
stride = 1
ds = tf.data.Dataset.from_generator(gen_batch, tf.float64)
ds = ds.shuffle(100)
ds_out = ds.make_one_shot_iterator().get_next()
sess = tf.Session()
print(sess.run(ds_out))
print(sess.run(ds_out))
print(sess.run(ds_out))
print(sess.run(ds_out))
print(sess.run(ds_out))
印刷:
[3. 4. 5.]
[1. 2. 3.]
[2. 3. 4.]
[4. 5. 6.]
[5. 6. 7.]