我正在构建一个神经网络来预测成对比较的结果。在合并和计算下游部分的结果之前,相同的编码器网络应用于两个输入。在我的用例中,我正在计算给定元素集的所有成对预测,因此预测的数量增长得非常快,因此我有兴趣加快预测过程。
天真地进行完整的成对预测涉及一遍又一遍地计算编码器网络在每个元素上的结果。由于编码器网络大于下游部分(合并+下游层),我认为在每个输入元素上预先计算编码器网络的结果,然后仅根据这些编码值计算下游将导致显着加速。然而,这并不是我在实践中发现的。对于下面在 Colab (CPU) 和我的机器 (CPU) 上的示例,我可以节省 10-15% 的运行时间,而如果您从层的角度考虑,我预计会节省 50%,如果您考虑到更多考虑参数。
我觉得我错过了一些东西,无论是在实现中还是 tensorflow/keras 已经做了某种魔法(缓存?)给定网络的结构,从而导致较小的收益?
import numpy as np # numpy will be used for mgrid to compute all the pairs of the input
import tensorflow as tf
# Encoder Network
input_a = tf.keras.Input(shape=(10,4))
x = tf.keras.layers.Flatten()(input_a)
x = tf.keras.layers.Dense(100, activation='relu')(x)
x = tf.keras.layers.Dense(20, activation='relu')(x)
x = tf.keras.layers.Dense(10, activation='relu')(x)
upstream_network = tf.keras.Model(input_a, x)
# Downstream network, from merge to final prediction
input_downstream_a = tf.keras.Input(shape = upstream_network.layers[-1].output_shape[1:])
input_downstream_b = tf.keras.Input(shape = upstream_network.layers[-1].output_shape[1:])
x = tf.keras.layers.subtract([input_downstream_a, input_downstream_b])
x = tf.keras.layers.Dense(20, activation='relu')(x)
x = tf.keras.layers.Dense(1, activation='sigmoid')(x)
downstream_network = tf.keras.Model((input_downstream_a, input_downstream_b), x)
# Full network
input_full_a = tf.keras.Input(shape=(10,4))
input_full_b = tf.keras.Input(shape=(10,4))
intermed_a = upstream_network(input_full_a)
intermed_b = upstream_network(input_full_b)
res = downstream_network([intermed_a, intermed_b])
full_network = tf.keras.Model([input_full_a, input_full_b], res)
full_network.compile(loss='binary_crossentropy')
# Experiment
population = np.random.random((300, 10, 4))
# %%timeit 10
# 1.9s on Colab CPU
indices = np.mgrid[range(population.shape[0]), range(population.shape[0])].reshape(2, -1)
full_network.predict([population[indices[0]], population[indices[1]]])
# %%timeit 10
# 1.7s on Colab CPU
out = upstream_network.predict(population)
indices = np.mgrid[range(population.shape[0]), range(population.shape[0])].reshape(2, -1)
downstream_network.predict([out[indices[0]], out[indices[1]]])