1
import tensorflow as tf

cluster_size = tf.constant(6) # size of the cluster
m = tf.constant(6) # number of contigs (column size)
n = tf.constant(3) # number of points in a single contigs (column size)
contigs_index = tf.reshape(tf.range(0, m, 1, dtype=tf.int32), [1, -1])
contigs = tf.constant(
  [[1.1, 2.2, 3.3], [6.6, 5.5, 4.4], [7.7, 8.8, 9.9], [11.1, 22.2, 33.3],
    [66.6, 55.5, 44.4], [77.7, 88.8, 99.9]])

# pad zeo to the right till fixed length
def rpad_with_zero(points):
  points = tf.slice(tf.pad(points, tf.reshape(tf.concat(
    [tf.zeros([1, 2], tf.int32), tf.add(
      tf.zeros([1, 2], tf.int32),
      tf.subtract(cluster_size, tf.size(points)))], 0), [2, -1]), "CONSTANT"),
                    (0, tf.subtract(cluster_size, tf.size(points))),
                    (1, cluster_size))
  return points

#calculate pearson correlation coefficient r value
def calculate_pcc(row, contigs):
  r = tf.divide(tf.subtract(
      tf.multiply(tf.to_float(n), tf.reduce_sum(tf.multiply(row, contigs), 1)),
      tf.multiply(tf.reduce_sum(row, 1), tf.reduce_sum(contigs, 1))),
                tf.multiply(
      tf.sqrt(tf.subtract(
        tf.multiply(tf.to_float(n), tf.reduce_sum(tf.square(row), 1)), 
        tf.square(tf.reduce_sum(row, 1)))), 
      tf.sqrt(tf.subtract(tf.multiply(
        tf.to_float(n), tf.reduce_sum(tf.square(contigs), 1)),
        tf.square(tf.reduce_sum(contigs, 1)))
      )))
  return r

#slice first row from contigs
row = tf.slice(contigs, (0, 0), (1, 3))
#calculate pcc
r = calculate_pcc(row, contigs)
#cluster member index whose r value is greater than 0.90, then casting to
# int32,
members0_index = tf.cast(tf.reshape(tf.where(tf.greater(r, 0.90)), [1, -1]),
                         tf.int32)
#members = index <intersection> members, padding the members index with
# zeros at right, to keep the fixed cluster length
members0_index = rpad_with_zero(
  tf.reshape(tf.sets.set_intersection(contigs_index, members0_index).values,
             [1, -1]))
#update index with the rest element index from contigs, and padding
contigs_index = rpad_with_zero(
  tf.reshape(tf.sets.set_difference(contigs_index, members0_index).values,
             [1, -1]))

#def condition(contigs, contigs_index, members0_index):
def condition(contigs_index, members0_index):
  return tf.greater(tf.count_nonzero(contigs_index),
                    0) # iterate until there is a contig

#def body(contigs, contigs_index, members0_index):
def body(contigs_index, members0_index):
  i = tf.reshape(tf.slice(contigs_index, [0, 0], [1, 1]),
                 []) #the first element in the contigs_index
  row = tf.slice(contigs, (i, 0),
                 (1, 3)) #slice the ith contig from contigs
  r = calculate_pcc(row, contigs)
  members_index = tf.cast(tf.reshape(tf.where(tf.greater(r, 0.90)), [1, -1]),
                          tf.int32)
  members_index = rpad_with_zero(rpad_with_zero(
    tf.reshape(tf.sets.set_intersection(contigs_index, members_index).values,
               [1, -1])))
  members0_index = tf.concat([members0_index, members_index], 0)
  contigs_index = rpad_with_zero(
    tf.reshape(tf.sets.set_difference(contigs_index, members_index).values,
               [1, -1]))
  #return [contigs, contigs_index, members0_index]
  return [contigs_index, members0_index]

sess = tf.Session()
sess.run(tf.while_loop(condition, body,
   #loop_vars=[contigs, contigs_index, members0_index],
   loop_vars=[contigs_index, members0_index],
   #shape_invariants=[contigs.get_shape(), contigs_index.get_shape(), 
   # tf.TensorShape([None, 6])]))
   shape_invariants=[contigs_index.get_shape(), tf.TensorShape([None, 6])]))

错误是:

ValueError:while_12/Merge:0 的形状不是循环的不变量。它以形状 (1, 6) 进入循环,但在一次迭代后具有形状 (?, ?)。shape_invariants在循环变量上使用 tf.while_loop 或 set_shape()的参数提供形状不 变量。

似乎变数

contigs_index

是负责任的,但我真的不知道为什么!我展开循环执行每个语句,但找不到任何形状不匹配!

4

1 回答 1

1

shape_invariants=[contigs_index.get_shape(), tf.TensorShape([None, 6])]))应该变成shape_invariants=[tf.TensorShape([None, None]), tf.TensorShape([None, 6])])),以允许contigs_index变量的形状变化(在rpad_with_zero调用中)。

于 2017-11-21T17:26:59.170 回答