
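I am trying to convert a Keras model into a TensorRT engine. The model contains CNN layers followed by GRU layers, and I am using the TensorRT Python API. The engine builds successfully, but accuracy drops sharply and I cannot find the cause.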

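Testing layer by layer, I found that the outputs of the layers before the GRU layer match the Keras outputs. Once the GRU layer is added, the output differs substantially, so I am confident the problem comes from the GRU layer, but I cannot pin down the specific cause. This has been bothering me for a long time.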
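To put a number on "the output differs substantially" at each checkpoint, a small NumPy helper like the one below can be used. It is a sketch: `compare_outputs` and its tolerances are hypothetical choices, and it assumes the TensorRT output has already been copied back to the host with the same number of elements, in the same element order, as the Keras output.

```python
import numpy as np


def compare_outputs(keras_out, trt_out, atol=1e-3, rtol=1e-3):
    """Compare a Keras layer output with the matching TensorRT output buffer."""
    keras_out = np.asarray(keras_out, dtype=np.float32)
    # Host buffers copied back from an engine are often flat; reshape to the
    # Keras layout (assumes both use the same element order)
    trt_out = np.asarray(trt_out, dtype=np.float32).reshape(keras_out.shape)
    diff = np.abs(keras_out - trt_out)
    # Position of the single largest deviation
    worst = np.unravel_index(int(diff.argmax()), diff.shape)
    return {
        "max_abs_diff": float(diff.max()),
        "mean_abs_diff": float(diff.mean()),
        "worst_index": tuple(int(i) for i in worst),
        "allclose": bool(np.allclose(keras_out, trt_out, atol=atol, rtol=rtol)),
    }
```

Calling it on the GRU output and on the layer just before it shows where the error jumps; for a 3-D output, `worst_index` gives the batch, time step and feature with the largest deviation.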

The Keras code:

from keras.layers import Bidirectional, GRU

# `input` is the output tensor of the preceding CNN layers
gru_layer = Bidirectional(GRU(256, return_sequences=True), name='gru_layer')(input)
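One way to find which part of the GRU computation diverges is to reproduce the Keras layer in plain NumPy and compare it step by step with TensorRT. The sketch below assumes the Keras 2.1.x `GRU` defaults (`activation='tanh'`, `recurrent_activation='hard_sigmoid'`, `reset_after=False`) and the default `merge_mode='concat'` of `Bidirectional`; it processes a single sample of shape (time, input_dim), and the function names are hypothetical.

```python
import numpy as np


def hard_sigmoid(x):
    # Keras 2.x backend definition: clip(0.2 * x + 0.5, 0, 1)
    return np.clip(0.2 * x + 0.5, 0.0, 1.0)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def keras_gru(x, kernel, recurrent_kernel, bias, recurrent_activation=hard_sigmoid):
    """One direction of a Keras GRU (reset_after=False, return_sequences=True).

    x: (time, input_dim), kernel: (input_dim, 3 * units),
    recurrent_kernel: (units, 3 * units), bias: (3 * units,).
    The packed weights are ordered z (update), r (reset), h (candidate).
    """
    units = recurrent_kernel.shape[0]
    kz, kr, kh = np.split(kernel, 3, axis=1)
    uz, ur, uh = np.split(recurrent_kernel, 3, axis=1)
    bz, br, bh = np.split(bias, 3)
    h = np.zeros(units, dtype=x.dtype)
    outputs = []
    for x_t in x:
        z = recurrent_activation(x_t @ kz + h @ uz + bz)
        r = recurrent_activation(x_t @ kr + h @ ur + br)
        # The reset gate scales h_{t-1} before the recurrent matmul
        hh = np.tanh(x_t @ kh + (r * h) @ uh + bh)
        h = z * h + (1.0 - z) * hh
        outputs.append(h)
    return np.stack(outputs)


def keras_bigru(x, fwd_weights, bwd_weights, **kwargs):
    """Bidirectional(GRU(..., return_sequences=True)) with merge_mode='concat'."""
    fwd = keras_gru(x, *fwd_weights, **kwargs)
    # The backward GRU reads the reversed sequence; its outputs are flipped back
    bwd = keras_gru(x[::-1], *bwd_weights, **kwargs)[::-1]
    return np.concatenate([fwd, bwd], axis=-1)
```

If I read Keras correctly, `get_weights()` on the `Bidirectional` wrapper returns the forward `kernel`, `recurrent_kernel`, `bias` followed by the same three for the backward GRU, so `fwd_weights = w[:3]` and `bwd_weights = w[3:]`. Running the reference once with the default `hard_sigmoid` and once with `recurrent_activation=sigmoid` shows how much the gate activation alone changes the output, which is relevant because, as far as I can tell from the TensorRT documentation, its GRU gates use the logistic sigmoid.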

The TensorRT code:

import numpy as np
import tensorrt as trt

# Method of the (not shown) engine-builder class: self.network is the TensorRT
# network definition, self.pb_weight holds the exported weights keyed by name.
def add_bigru(self, input_layer, layer_name, layer_index, nhidden=1):
    hidden_size = 256
    x = input_layer.get_output(0)
    gru_1 = self.network.add_rnn_v2(x, layer_count=1, hidden_size=hidden_size,
                                    max_seq_length=x.shape[1], op=trt.RNNOperation.GRU)
    gru_1.direction = trt.RNNDirection.BIDIRECTION

    # Split the Keras kernel / recurrent_kernel / bias of each direction into
    # 3 per-gate pieces keyed "0", "1", "2" (update, reset, hidden)
    direction_weights = []
    for direction in ("forward", "backward"):
        prefix = layer_name + "/" + direction + "_gru_" + str(layer_index)
        direction_weights.append(self.split_gru_weight(
            [self.pb_weight[prefix + "/kernel"],
             self.pb_weight[prefix + "/recurrent_kernel"],
             self.pb_weight[prefix + "/bias"]], 3))

    gates = [trt.RNNGateType.UPDATE, trt.RNNGateType.RESET, trt.RNNGateType.HIDDEN]
    zero_bias = np.zeros(hidden_size, dtype=np.float32)

    # Index 0 receives the forward weights, index 1 the backward weights
    for i, (kernel_w, recurrent_w, bias_w) in enumerate(direction_weights):
        for k, gate in enumerate(gates):
            # Input (W) and recurrent (R) weight matrices for this gate
            gru_1.set_weights_for_gate(layer_index=i, gate=gate, is_w=True,
                                       weights=kernel_w[str(k)])
            gru_1.set_weights_for_gate(layer_index=i, gate=gate, is_w=False,
                                       weights=recurrent_w[str(k)])
            # The Keras bias goes on the W side; the R-side bias is all zeros
            gru_1.set_bias_for_gate(layer_index=i, gate=gate, is_w=True,
                                    bias=bias_w[str(k)])
            gru_1.set_bias_for_gate(layer_index=i, gate=gate, is_w=False,
                                    bias=zero_bias)
    return gru_1
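Two details of the TensorRT side can be checked in isolation. As far as I can tell from the TensorRT documentation for `RNNOperation.GRU`, each gate matrix is applied as `W @ x` with shape (hidden_size, input_size), so a Keras kernel slice of shape (input_dim, units) would need transposing (`split_gru_weight` is not shown, so this is worth checking), and the candidate state is tanh(W_h x + Wb_h + r ⊙ (R_h h + Rb_h)), with the reset gate applied after the recurrent matmul. The sketch below encodes that reading; `split_keras_gru_for_trt` and `trt_style_gru_step` are hypothetical names, not TensorRT API.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def split_keras_gru_for_trt(kernel, recurrent_kernel, bias):
    """Hypothetical helper: split packed Keras GRU weights per gate, ordered
    [update, reset, hidden], transposing each matrix from Keras's (in, units)
    layout to (units, in) so that a gate computes W @ x."""
    ws = [np.ascontiguousarray(k.T) for k in np.split(kernel, 3, axis=1)]
    rs = [np.ascontiguousarray(u.T) for u in np.split(recurrent_kernel, 3, axis=1)]
    wbs = np.split(bias, 3)
    return ws, rs, wbs


def trt_style_gru_step(x_t, h, ws, rs, wbs, rbs):
    """One GRU step in the form described for TensorRT's RNNOperation.GRU:
    the reset gate multiplies (R_h @ h + Rb_h), i.e. it is applied after the
    recurrent matmul, and the gates use the logistic sigmoid."""
    wz, wr, wh = ws
    rz, rr, rh = rs
    z = sigmoid(wz @ x_t + rz @ h + wbs[0] + rbs[0])
    r = sigmoid(wr @ x_t + rr @ h + wbs[1] + rbs[1])
    hh = np.tanh(wh @ x_t + r * (rh @ h + rbs[2]) + wbs[2])
    return (1.0 - z) * hh + z * h
```

Feeding the same weights (Keras bias on the W side, zero R-side bias) to a Keras-style step with sigmoid gates and the reset applied before the matmul gives identical results at the first time step, where h = 0, but different results from the second step on, because r ⊙ (h·U_h) and (r ⊙ h)·U_h generally differ unless h is zero or all entries of r are equal.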

Environment:

TensorRT Version: 7.0.0.11
NVIDIA GPU: Tesla P100
NVIDIA Driver Version: 440.118.02
CUDA Version:
CUDNN Version: cudatoolkit 10.0.130
Operating System: CentOS
Python Version (if applicable): 3.6.13
Tensorflow Version (if applicable): 1.13.1
Keras Version (if applicable): 2.1.6

Can anyone help me fix this?
