I have the label array and logits array as:

label = [1,1,0,1,-1,-1,1,0,-1,0,-1,-1,0,0,0,1,1,1,-1,1]
logits = [0.2,0.3,0.4,0.1,-1.4,-2,0.4,0.5,-0.231,1.9,1.4,-1.456,0.12,-0.45,0.5,0.3,0.4,0.2,1.2,12]

Using TensorFlow, I want to get the values from label and logits where:

1> label is greater than zero
2> label is less than zero
3> label is equal to zero

I would like to get a result something like this:

label1, logits1 = some_Condition_logic_Where(label > 0)  # returns the respective labels and logits

Can anyone suggest how this can be achieved?

EDIT:

>>> label = [1,1,0,1,-1,-1,1,0,-1,0,-1,-1,0,0,0,1,1,1,-1,1]
>>> logits = [0.2,0.3,0.4,0.1,-1.4,-2,0.4,0.5,-0.231,1.9,1.4,-1.456,0.12,-0.45,0.5,0.3,0.4,0.2,1.2,12]
>>> label1 = [];logits1 = []
>>> for l1,l2 in zip(label,logits):
...     if(l1>0):
...         label1.append(l1)
...         logits1.append(l2)
...
>>> label1
[1, 1, 1, 1, 1, 1, 1, 1]
>>> logits1
[0.2, 0.3, 0.1, 0.4, 0.3, 0.4, 0.2, 12]

I want this logic implemented in TensorFlow, and the same for the values equal to -1 and 0. How can I achieve this?

1 Answer

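You can use tf.boolean_mask. It keeps only the elements of a tensor at the positions where a boolean mask is True. So you can build one mask per condition from label and apply that same mask to both label and logits.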

import tensorflow as tf  # TensorFlow 1.x: graph mode with tf.Session (removed in TF 2.x)

label = tf.constant([1,1,0,1,-1,-1,1,0,-1,0,-1,-1,0,0,0,1,1,1,-1,1], dtype=tf.float32)
logits = tf.constant([0.2,0.3,0.4,0.1,-1.4,-2,0.4,0.5,-0.231,1.9,1.4,-1.456,0.12,-0.45,0.5,0.3,0.4,0.2,1.2,12], dtype=tf.float32)

# Build each boolean mask from label once, then apply it to both tensors
pos_mask = tf.greater(label, 0)   # label > 0
neg_mask = tf.less(label, 0)      # label < 0
zero_mask = tf.equal(label, 0)    # label == 0

# label > 0
label1 = tf.boolean_mask(label, pos_mask)
logits1 = tf.boolean_mask(logits, pos_mask)
# label < 0
label2 = tf.boolean_mask(label, neg_mask)
logits2 = tf.boolean_mask(logits, neg_mask)
# label == 0
label3 = tf.boolean_mask(label, zero_mask)
logits3 = tf.boolean_mask(logits, zero_mask)

# Fetch all six results in a single run call and print them in order
with tf.Session() as sess:
    for result in sess.run([label1, logits1, label2, logits2, label3, logits3]):
        print(result)

Output:

[1. 1. 1. 1. 1. 1. 1. 1.]
[ 0.2  0.3  0.1  0.4  0.3  0.4  0.2 12. ]
[-1. -1. -1. -1. -1. -1.]
[-1.4   -2.    -0.231  1.4   -1.456  1.2  ]
[0. 0. 0. 0. 0. 0.]
[ 0.4   0.5   1.9   0.12 -0.45  0.5 ]
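The code above targets TensorFlow 1.x. In TensorFlow 2.x, eager execution is the default and tf.Session no longer exists. A minimal sketch of the same approach for TF 2.x follows; it assumes TF 2.x is installed.

- select_where is a hypothetical helper, not a TensorFlow API. It mirrors the "label1, logits1 = ..." interface asked for in the question.
- As an alternative, the sketch also uses tf.dynamic_partition to split logits into all three groups with a single op. The partition ids are derived from tf.sign(label).

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution by default)

label = tf.constant([1, 1, 0, 1, -1, -1, 1, 0, -1, 0, -1, -1, 0, 0, 0, 1, 1, 1, -1, 1],
                    dtype=tf.float32)
logits = tf.constant([0.2, 0.3, 0.4, 0.1, -1.4, -2, 0.4, 0.5, -0.231, 1.9, 1.4, -1.456,
                      0.12, -0.45, 0.5, 0.3, 0.4, 0.2, 1.2, 12], dtype=tf.float32)


def select_where(mask, *tensors):
    """Hypothetical helper (not part of TensorFlow): apply one boolean mask to several tensors."""
    return [tf.boolean_mask(t, mask) for t in tensors]


# Option 1: one boolean mask per condition, same idea as the tf.Session code above
label1, logits1 = select_where(tf.greater(label, 0), label, logits)  # label > 0
label2, logits2 = select_where(tf.less(label, 0), label, logits)     # label < 0
label3, logits3 = select_where(tf.equal(label, 0), label, logits)    # label == 0

# Option 2: split logits into all three groups with a single op.
# tf.sign maps negatives to -1, zero to 0, positives to 1; adding 1 gives ids 0, 1, 2.
partitions = tf.cast(tf.sign(label), tf.int32) + 1
neg_logits, zero_logits, pos_logits = tf.dynamic_partition(logits, partitions, num_partitions=3)

# In eager mode, .numpy() returns the values directly; no session is needed
print(label1.numpy(), logits1.numpy())
print(label2.numpy(), logits2.numpy())
print(label3.numpy(), logits3.numpy())
print(neg_logits.numpy(), zero_logits.numpy(), pos_logits.numpy())
```

tf.dynamic_partition keeps the elements of each partition in their original order, so its three outputs should match the boolean_mask results. The tf.sign trick also works for arbitrary real-valued labels, not only -1, 0 and 1.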
answered 2019-05-14T08:23:36.063