1.另一种方法
我不确定这是否是最好的方法,但它更快。您可以使用tf.boolean_mask
而不是tf.map_fn
.
import tensorflow as tf
import numpy as np
a = tf.cast(tf.constant(np.reshape(np.arange(60), (3,2,2,5))), tf.int32)
idx2 = tf.constant([0, 1, 0])
fn = lambda i: a[i,:,:][:,:,idx2[i]]
idx = tf.range(tf.shape(a)[0])
masks = tf.map_fn(fn, idx)
# new method
idx = tf.one_hot(idx2,depth=a.shape[-1])
masks2 = tf.boolean_mask(tf.transpose(a,[0,3,1,2]), idx)
with tf.Session() as sess:
print('tf.map_fn version:\n',sess.run(masks))
print('tf.boolean_mask version:\n',sess.run(masks2))
# print
tf.map_fn version:
[[[ 0 5]
[10 15]]
[[21 26]
[31 36]]
[[40 45]
[50 55]]]
tf.boolean_mask version:
[[[ 0 5]
[10 15]]
[[21 26]
[31 36]]
[[40 45]
[50 55]]]
2.性能对比
向量化方法 1000 次迭代0.07s
和tf.map_fn
方法 1000 次迭代占用0.85s
我的 8GB GPU 内存。矢量化方法将明显快于tf.map_fn()
.
import datetime
...
with tf.Session() as sess:
start = datetime.datetime.now()
for _ in range(1000):
sess.run(masks)
end = datetime.datetime.now()
print('tf.map_fn version cost time(seconds) : %.2f' % ((end - start).total_seconds()))
start = datetime.datetime.now()
for _ in range(1000):
sess.run(masks2)
end = datetime.datetime.now()
print('tf.boolean_mask version cost time(seconds) : %.2f' % ((end - start).total_seconds()))
# print
tf.map_fn version cost time(seconds) : 0.85
tf.boolean_mask version cost time(seconds) : 0.07
我相信随着形状的a
增加,性能差异会变得更加明显。