5

我正在为每个元素构建一个具有两个形状 [batch,width,heigh,3] 和 [batch,class] 的张量的数据集。为简单起见,假设类 = 5。

您输入什么形状以dataset.padded_batch(1000,shape)使图像沿宽度/高度/ 3 轴填充?

我尝试了以下方法:

tf.TensorShape([[None,None,None,3],[None,5]])
[tf.TensorShape([None,None,None,3]),tf.TensorShape([None,5])]
[[None,None,None,3],[None,5]]
([None,None,None,3],[None,5])
(tf.TensorShape([None,None,None,3]),tf.TensorShape([None,5])‌​)

每次引发 TypeError

文档状态:

padded_shapes:tf.TensorShape 或 tf.int64 矢量张量对象的嵌套结构,表示每个输入元素的相应组件在批处理之前应填充到的形状。任何未知维度(例如 tf.TensorShape 中的 tf.Dimension(None) 或类张量对象中的 -1)将被填充到每个批次中该维度的最大大小。

相关代码:

dataset = tf.data.Dataset.from_generator(generator,tf.float32)
shapes = (tf.TensorShape([None,None,None,3]),tf.TensorShape([None,5]))
batch = dataset.padded_batch(1,shapes)
4

2 回答 2

8

感谢 mrry 找到解决方案。原来 from_generator 中的类型必须与条目中的张量数相匹配。

新代码:

dataset = tf.data.Dataset.from_generator(generator,(tf.float32,tf.float32))
shapes = (tf.TensorShape([None,None,None,3]),tf.TensorShape([None,5]))
batch = dataset.padded_batch(1,shapes)
于 2017-11-06T20:23:01.903 回答
1

TensorShape不接受嵌套列表。 tf.TensorShape([None, None, None, 3, None, 5])TensorShape(None)(注 no [])是合法的。

不过,将这两个张量结合起来对我来说听起来很奇怪。我不确定您要完成什么,但我建议您尝试在不组合不同维度的张量的情况下进行。

于 2017-11-04T06:06:44.717 回答