0
import tensorflow as tf

# Reload the Keras model saved under the given directory, then evaluate
# get_weights(), which returns every weight array of the model, biases included.
new_model = tf.keras.models.load_model(filepath='saved_model/my_model_KNOCK_2_RMS')
new_model.get_weights()

用 get_weights() 检索权重时，如何取得其中的偏置（bias）矩阵？还是说需要用别的方法来获取偏置矩阵？

4

2 回答 2

1

model.get_weights() 会返回网络中所有变量的值（一个 NumPy 数组列表），其中也包括偏置（bias）；但这些数组不带名称，因此无法直接区分哪些是偏置。

您可以遍历模型的 variables 属性，并根据每个 tf.Variable 的 name 进行过滤，从而只保留偏置。

# Keep only the bias variables of the model.
# Variable names are scoped by layer (e.g. "dense/bias:0"). A plain
# `"bias" in var.name` test would also select, for example, the kernel of a
# layer whose *name* contains "bias" (e.g. "debias_dense/kernel:0").
# Testing only the last "/"-separated component avoids that. It still
# accepts unscoped names such as "bias".
biases = [
    var
    for var in new_model.variables
    if "bias" in var.name.rsplit("/", 1)[-1]
]
于 2021-05-21T14:32:13.690 回答
0
# Show every variable of the model. As the output below shows, each entry is a
# tf.Variable with its name (e.g. 'dense/kernel:0', 'dense/bias:0'), shape,
# dtype and current value, so kernels and biases can be told apart by name.
new_model.variables    

[<tf.Variable 'dense/kernel:0' shape=(4, 4) dtype=float32, numpy=
     array([[ 1.8315854 ,  1.6919162 ,  2.107687  ,  2.1731293 ],
            [-0.4066143 ,  0.24807486, -0.34563315,  0.70929044],
            [ 0.5660119 , -0.39092124,  0.57988596,  0.7534707 ],
            [ 0.38233787,  0.09385393,  0.25826836,  0.28291   ]],
           dtype=float32)>,
     <tf.Variable 'dense/bias:0' shape=(4,) dtype=float32, numpy=array([-0.84321135, -0.93141776, -0.95930505, -0.36669353], dtype=float32)>,
     <tf.Variable 'dense_1/kernel:0' shape=(4, 10) dtype=float32, numpy=
     array([[ 1.7296976 ,  0.68885595,  0.47779882,  1.2458457 ,  1.3748846 ,
              1.0451635 , -0.05860029,  0.2059054 ,  1.1549207 , -0.14830673],
            [ 2.5991514 ,  1.7558911 , -0.65219647,  1.948722  ,  1.2213669 ,
              2.0473976 , -0.57937807,  0.16753212,  1.2601147 , -0.4593185 ],
            [ 1.8368433 ,  1.253897  , -0.5726242 ,  0.83214754,  0.85430264,
              1.3974545 ,  0.8320734 , -0.624181  ,  1.2566972 , -0.26695323],
            [ 0.8506076 ,  0.52599937, -0.31859252,  0.45206892,  0.727149  ,
              0.8375796 , -0.6014804 , -0.963803  ,  0.4475311 , -0.52666175]],
           dtype=float32)>,
     <tf.Variable 'dense_1/bias:0' shape=(10,) dtype=float32, numpy=
     array([-0.69184536, -0.64424694,  0.        , -0.7509772 , -0.7106489 ,
            -0.8235373 ,  0.5724167 ,  0.8609912 , -0.5833405 ,  0.        ],
           dtype=float32)>,
     <tf.Variable 'dense_2/kernel:0' shape=(10, 10) dtype=float32, numpy=
     array([[-0.43679273, -0.18084297, -0.19233215,  2.261455  ,  0.0789384 ,
             -0.57109463, -0.18623981, -0.13344003, -0.16805501, -0.08802319],
            [-0.44742152,  0.5294885 ,  0.44007984, -0.8000098 ,  0.43828738,
             -0.14650299,  0.3873879 ,  0.4865123 ,  0.7062375 ,  0.3959973 ],
            [ 0.42817426, -0.25457212, -0.06503531,  0.25967544, -0.36173528,
             -0.28454632,  0.5185325 ,  0.52907014, -0.50651705, -0.01279312],
            [-0.17378198,  0.61804473,  0.40438575, -0.84975696,  0.46335214,
              0.03958785,  0.59150505,  0.3461628 ,  0.5345084 ,  0.64246666],
            [-0.05633026,  0.19354568,  0.14509334, -0.65395653,  0.82471824,
             -0.33441678,  0.45360735,  0.3273876 ,  0.48969913,  0.5674778 ],
            [-0.08906014,  0.35384002,  0.3765939 , -0.55490863,  0.49258858,
             -0.20532486,  0.31845653,  0.23653099,  0.29804358,  0.32997373],
            [-0.44412076, -2.754143  , -3.0303166 , -1.8454257 , -0.30921584,
             -0.338421  , -2.954837  , -2.8979506 , -2.6388538 , -2.6987133 ],
            [-0.35004058, -1.8016477 , -1.896363  , -0.9093097 ,  0.01669006,
             -0.57078224, -2.097838  , -2.196825  , -2.2054377 , -2.0546691 ],
            [ 0.0622101 ,  0.56875587,  0.64062774, -0.8085089 ,  0.2828998 ,
             -0.01109717,  0.42925444,  0.46951735,  0.45354208,  0.44528142],
            [-0.17158392,  0.00582236,  0.05534732,  0.52236724,  0.24274498,
              0.06249171, -0.39602056, -0.49559176,  0.22903812, -0.2540841 ]],
           dtype=float32)>,
     <tf.Variable 'dense_2/bias:0' shape=(10,) dtype=float32, numpy=
     array([-0.4239517 ,  5.001588  ,  4.77233   , -0.1700763 ,  3.3235748 ,
            -0.36088318,  5.0669746 ,  4.534893  ,  5.156966  ,  5.2919803 ],
           dtype=float32)>,
     <tf.Variable 'dense_3/kernel:0' shape=(10, 1) dtype=float32, numpy=
     array([[-0.11622595],
            [ 1.4069948 ],
            [ 1.5280082 ],
            [ 4.962922  ],
            [ 1.3706445 ],
            [-0.09094736],
            [ 1.4304978 ],
            [ 1.7208712 ],
            [ 1.4002761 ],
            [ 1.2989398 ]], dtype=float32)>,
     <tf.Variable 'dense_3/bias:0' shape=(1,) dtype=float32, numpy=array([25.658356], dtype=float32)>]
于 2021-05-21T15:52:07.757 回答