1

我在 Keras 中训练了一个 yolov3 模型,它有 3 个输出(3D 张量)。那里没问题。然后我将此模型转换为 tfjs 以在浏览器中运行它。如果我将模型截断为这个特定的输出,我可以很容易地得到每个输出的内容。但我无法一次获得全部(3 个输出)。我想知道这是否可能?

这是我想做的(它不起作用,它挂起):

const myTracker = await tf.loadLayersModel(v3_model); 
const prediction = tf.tidy(() => {

      // prepare inputs tensor
      const inputs = tf.browser.fromPixels(canvas, 3).expandDims(0).toFloat().div(tf.scalar(255));
      console.log("============ inputs tensor shape:" + inputs.shape);  //--> 1,416,416,3  

      // get all 3 outputs
      const outputs = myTracker.predict(inputs).arraySync(); 
      outputs.print();
      return outputs; 
});            

如果我只对 1 个输出感兴趣,在确定 3 个输出名称后,这就是有效的方法。

const myTracker = await tf.loadLayersModel(v3_model); 
const prediction = tf.tidy(() => {

      // prepare inputs tensor
      const inputs = tf.browser.fromPixels(canvas, 3).expandDims(0).toFloat().div(tf.scalar(255));
      console.log("============ inputs tensor shape:" + inputs.shape);  //--> 1,416,416,3  

      // get the full object of interesting layers
      const layer3 = myTracker.getLayer('conv2d_3');
      const layer8 = myTracker.getLayer('conv2d_8'); 
      const layer13= myTracker.getLayer('conv2d_13'); 

      // get 1 specific output from a new model build from the original
      const myTracker_truncated = tf.model({inputs: myTracker.inputs, outputs: layer3.output});
      const outputs = myTracker_truncated.predict(inputs).arraySync();  
      return outputs; 
});

如果我检查 (console.log(myTracker)) 我会为我的输出得到这个结构:

outputNames: (3) […]
0: "conv2d_3
1: "conv2d_8"
2: "conv2d_13"
length: 3

outputs: (3) […]
0: Object { dtype: "float32", id: 662, originalName: "conv2d_3/conv2d_3", … }
1: Object { dtype: "float32", id: 665, originalName: "conv2d_8/conv2d_8", … }
2: Object { dtype: "float32", id: 668, originalName: "conv2d_13/conv2d_13", … }
length: 3

有人知道我想要实现的目标是否可行(在原始 keras 模型中没有将 3 个输出连接成 1 个)?

4

1 回答 1

0

在使用以下方法创建模型时,可以通过使用输出数组来获得这些输出tf.model

tf.model({inputs: myTracker.inputs, outputs: [layer3.output, layer8.output, layer13.output, ...]});

prediction = myTracker_truncated.predict(inputs).arraySync()

现在预测将是一个包含三个值的数组,分别对应于定义的层outputs

于 2020-02-20T13:14:21.100 回答