So assuming I have this:

TensorShape([Dimension(None), Dimension(32)])

And I use tf.split on this tensor _X with the dimension above:

_X = tf.split(_X, 128, 0) 

What is the shape of this new tensor? The output is a list, so it's hard to know the shape of this new tensor.

3 Answers

tf.split() returns a list of tensor objects. You can get the shape of each tensor object as follows:

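import tensorflow as tf  # TensorFlow 1.x API (tf.random_uniform, tf.Session)

X = tf.random_uniform([256, 32])
Y = tf.split(X, 128, 0)   # Python list of 128 tensors, each of shape (2, 32)
Y_shape = tf.shape(Y[1])  # shape of one piece, computed when the graph runs

sess = tf.Session()
X_v, Y_v, Y_shape_v = sess.run([X, Y, Y_shape])
# numpy style
print(X_v.shape)
print(len(Y_v))
print(Y_v[100].shape)
# TF style
print(len(Y))
print(Y_shape_v)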

Output:

(256, 32)
128
(2, 32)
128
[ 2 32]

I hope this helps!
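A minimal sketch of checking the static shapes without running a session, including the (None, 32) case from the question. It assumes TensorFlow 1.13+ or 2.x (so that `tf.compat.v1.placeholder` exists) and uses `tf.zeros` in place of real data; the names `known_pieces` and `unknown_pieces` are only illustrative.

```python
import tensorflow as tf

# Build the ops in an explicit graph so the placeholder works in TF 1.x and 2.x.
graph = tf.Graph()
with graph.as_default():
    # Known first dimension: 256 rows split into 128 pieces of 2 rows each.
    X = tf.zeros([256, 32])
    known_pieces = tf.split(X, 128, 0)

    # Unknown first dimension, like TensorShape([Dimension(None), Dimension(32)]).
    _X = tf.compat.v1.placeholder(tf.float32, shape=[None, 32])
    unknown_pieces = tf.split(_X, 128, 0)

# tf.split returns a Python list, so len() gives the number of pieces,
# and .shape on each element gives its static shape.
print(len(known_pieces), known_pieces[0].shape.as_list())      # 128 [2, 32]
print(len(unknown_pieces), unknown_pieces[0].shape.as_list())  # 128 [None, 32]
```

With the unknown first dimension, the row count of each piece is only known at run time, and the number of rows fed in must be divisible by 128, otherwise the split op fails.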

answered 2017-06-20T00:03:04.403

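tf.split(value, num_or_size_splits, axis) splits a tensor into pieces along a single axis: axis = 0 splits it row-wise and axis = 1 splits it column-wise.

For example, if we have a data set x of size (10, 10), then tf.split(x, 2, 0) will break x into 2 pieces of size (5, 10), and tf.split(x, 2, 1) will break it into 2 pieces of size (10, 5).

One call cannot split rows and columns at the same time: tf.split(x, 2, 2) raises an error, because a 2-D tensor has no axis 2. To get 4 pieces of size (5, 5), split along axis 0 first and then split each half along axis 1.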
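A small sketch of row-wise and column-wise splits of a 10x10 tensor. Each tf.split call works along a single axis, so getting four (5, 5) blocks takes two calls. The tensor built with `tf.range` is a hypothetical stand-in for the data set; the sketch only inspects static shapes, so it should behave the same in TensorFlow 1.x graph mode and 2.x eager mode.

```python
import tensorflow as tf

# Hypothetical 10x10 data set with values 0..99.
x = tf.reshape(tf.range(100), [10, 10])

rows = tf.split(x, 2, 0)  # 2 pieces of shape (5, 10)
cols = tf.split(x, 2, 1)  # 2 pieces of shape (10, 5)

# Four (5, 5) blocks: split by rows first, then split each half by columns.
blocks = [b for half in tf.split(x, 2, 0) for b in tf.split(half, 2, 1)]

print([r.shape.as_list() for r in rows])    # [[5, 10], [5, 10]]
print([c.shape.as_list() for c in cols])    # [[10, 5], [10, 5]]
print([b.shape.as_list() for b in blocks])  # [[5, 5], [5, 5], [5, 5], [5, 5]]
```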

answered 2018-01-30T16:21:31.467

Newer versions of TensorFlow define the split function as follows:

tf.split( value, num_or_size_splits, axis=0, num=None, name='split' )
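To illustrate the two forms of `num_or_size_splits` in that signature, here is a sketch assuming TensorFlow 1.x or 2.x and a zero-filled (256, 32) tensor: an int gives that many equal pieces, while a 1-D list gives the size of each piece along `axis`.

```python
import tensorflow as tf

X = tf.zeros([256, 32])

# An int splits axis 0 into that many equal pieces (256 / 2 = 128 rows each).
halves = tf.split(X, num_or_size_splits=2, axis=0)

# A 1-D list gives the size of each piece; here the sizes sum to 256.
uneven = tf.split(X, num_or_size_splits=[64, 192], axis=0)

print([t.shape.as_list() for t in halves])  # [[128, 32], [128, 32]]
print([t.shape.as_list() for t in uneven])  # [[64, 32], [192, 32]]
```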

However, when I try to run it in R:

X = tf$random_uniform(shape(256, 32), minval = 0,
                      maxval = 10, name = "X")

Y = tf$split(X,num_or_size_splits = 2,axis = 0)

it reports this error message:

Error in py_call_impl(callable, dots$args, dots$keywords) : 
  ValueError: Rank-0 tensors are not supported as the num_or_size_splits argument to split. Argument provided: 2.0
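The error comes from the argument's type. An R numeric literal such as `2` is a double, so it reaches Python as `2.0`. tf.split only treats an integer as a count of splits; anything else is converted to a tensor, and a scalar (rank-0) tensor is rejected. Writing `num_or_size_splits = 2L` in R passes an integer instead. This Python sketch (assuming TensorFlow 1.x or 2.x; `error_message` is an illustrative name) reproduces both cases:

```python
import tensorflow as tf

X = tf.zeros([256, 32])

# An integer count of splits works: two pieces of shape (128, 32).
ok = tf.split(X, num_or_size_splits=2, axis=0)
print(len(ok))  # 2

# A float such as 2.0 (what an R numeric 2 becomes) is not treated as a count;
# it is converted to a scalar tensor, which tf.split rejects with a ValueError.
error_message = None
try:
    tf.split(X, num_or_size_splits=2.0, axis=0)
except ValueError as err:
    error_message = str(err)
print(error_message)
```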
answered 2019-09-17T02:09:18.967