
I created a UNet model with TensorFlow in a Python environment and saved it for tfjs using tfjs.converters.save_keras_model().

The model summary is as follows:

Model: "model_4"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_4 (InputLayer)            [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
sequential_43 (Sequential)      (None, 128, 128, 64) 3072        input_4[0][0]                    
__________________________________________________________________________________________________
sequential_44 (Sequential)      (None, 64, 64, 128)  131584      sequential_43[1][0]              
__________________________________________________________________________________________________
sequential_45 (Sequential)      (None, 32, 32, 256)  525312      sequential_44[0][0]              
__________________________________________________________________________________________________
sequential_46 (Sequential)      (None, 16, 16, 512)  2099200     sequential_45[0][0]              
__________________________________________________________________________________________________
sequential_47 (Sequential)      (None, 8, 8, 512)    4196352     sequential_46[0][0]              
__________________________________________________________________________________________________
sequential_48 (Sequential)      (None, 4, 4, 512)    4196352     sequential_47[0][0]              
__________________________________________________________________________________________________
sequential_49 (Sequential)      (None, 2, 2, 512)    4196352     sequential_48[0][0]              
__________________________________________________________________________________________________
sequential_50 (Sequential)      (None, 1, 1, 512)    4196352     sequential_49[0][0]              
__________________________________________________________________________________________________
sequential_51 (Sequential)      (None, 2, 2, 512)    4196352     sequential_50[0][0]              
__________________________________________________________________________________________________
concatenate_16 (Concatenate)    (None, 2, 2, 1024)   0           sequential_51[0][0]              
                                                                 sequential_49[0][0]              
__________________________________________________________________________________________________
sequential_52 (Sequential)      (None, 4, 4, 512)    8390656     concatenate_16[0][0]             
__________________________________________________________________________________________________
concatenate_17 (Concatenate)    (None, 4, 4, 1024)   0           sequential_52[0][0]              
                                                                 sequential_48[0][0]              
__________________________________________________________________________________________________
sequential_53 (Sequential)      (None, 8, 8, 512)    8390656     concatenate_17[0][0]             
__________________________________________________________________________________________________
concatenate_18 (Concatenate)    (None, 8, 8, 1024)   0           sequential_53[0][0]              
                                                                 sequential_47[0][0]              
__________________________________________________________________________________________________
sequential_54 (Sequential)      (None, 16, 16, 512)  8390656     concatenate_18[0][0]             
__________________________________________________________________________________________________
concatenate_19 (Concatenate)    (None, 16, 16, 1024) 0           sequential_54[0][0]              
                                                                 sequential_46[0][0]              
__________________________________________________________________________________________________
sequential_55 (Sequential)      (None, 32, 32, 256)  4195328     concatenate_19[0][0]             
__________________________________________________________________________________________________
concatenate_20 (Concatenate)    (None, 32, 32, 512)  0           sequential_55[0][0]              
                                                                 sequential_45[0][0]              
__________________________________________________________________________________________________
sequential_56 (Sequential)      (None, 64, 64, 128)  1049088     concatenate_20[0][0]             
__________________________________________________________________________________________________
concatenate_21 (Concatenate)    (None, 64, 64, 256)  0           sequential_56[0][0]              
                                                                 sequential_44[0][0]              
__________________________________________________________________________________________________
sequential_57 (Sequential)      (None, 128, 128, 64) 262400      concatenate_21[0][0]             
__________________________________________________________________________________________________
concatenate_22 (Concatenate)    (None, 128, 128, 128 0           sequential_57[0][0]              
                                                                 sequential_43[1][0]              
__________________________________________________________________________________________________
conv2d_transpose_26 (Conv2DTran (None, 256, 256, 3)  6147        concatenate_22[0][0]             
==================================================================================================
Total params: 54,425,859
Trainable params: 54,414,979
Non-trainable params: 10,880

When I check model.input_shape, it outputs (None, 256, 256, 3).

So the model does define an input shape, but when I load it in tfjs, I get an error saying that the first layer of a Sequential model must have an input_shape or batch_input_shape.

When I inspect the model.json produced by the tfjs conversion, the first input layer has the following config:

"config": 
    {
        "batch_input_shape": [null, 256, 256, 3],
        ...
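Since the top-level input layer already carries batch_input_shape, one thing worth checking is whether the error comes from the nested Sequential blocks (sequential_43 to sequential_57 in the summary) rather than from the top-level model. Below is a minimal, standard-library-only Python sketch that walks model.json and lists every nested Sequential whose first layer has no batch_input_shape or input_shape. It assumes the usual Keras serialization layout (each layer is a dict with "class_name" and "config"; a Sequential's "config" is either a dict with a "layers" list or, in older Keras versions, a bare list of layers). Verify that against your own model.json. The function names are hypothetical helpers, not part of tfjs.

```python
import json
from typing import Any, Iterator, List


def iter_sequential_configs(node: Any) -> Iterator[dict]:
    """Yield every {"class_name": "Sequential", ...} dict found anywhere in node."""
    if isinstance(node, dict):
        if node.get("class_name") == "Sequential":
            yield node
        # Keep descending: Sequential blocks may be nested inside other configs.
        for value in node.values():
            yield from iter_sequential_configs(value)
    elif isinstance(node, list):
        for item in node:
            yield from iter_sequential_configs(item)


def first_layer_has_input_shape(seq: dict) -> bool:
    """Return True if the Sequential's first layer declares an input shape."""
    config = seq.get("config", {})
    # Assumption: older Keras stored a Sequential config as a bare list of
    # layers; newer versions use {"name": ..., "layers": [...]}.
    layers = config if isinstance(config, list) else config.get("layers", [])
    if not layers:
        return False
    first_cfg = layers[0].get("config", {})
    return "batch_input_shape" in first_cfg or "input_shape" in first_cfg


def sequentials_missing_input_shape(model_json: dict) -> List[str]:
    """Names of nested Sequential blocks whose first layer lacks an input shape."""
    missing = []
    for seq in iter_sequential_configs(model_json):
        if not first_layer_has_input_shape(seq):
            cfg = seq.get("config", {})
            name = cfg.get("name", "<unnamed>") if isinstance(cfg, dict) else "<unnamed>"
            missing.append(name)
    return missing


def check_model_json(path: str) -> List[str]:
    """Load a converted model.json from disk and run the check on it."""
    with open(path, encoding="utf-8") as f:
        return sequentials_missing_input_shape(json.load(f))
```

Calling check_model_json("model.json") on the converted file should show whether any of the Sequential sub-blocks are the ones tfjs is complaining about. The walk is deliberately generic (it searches the whole JSON tree) so it does not depend on the exact key path of the top-level model config, which can vary between converter versions.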

So what should I do to load the model with TensorFlow.js in the browser?
