0

我浏览了一个 TF 2.0 的基本示例,

包含非常简单的代码

from __future__ import absolute_import, division, print_function, unicode_literals
import os

import tensorflow as tf

import cProfile

# Load MNIST; the second element returned by load_data() is discarded.
(mnist_images, mnist_labels), _ = tf.keras.datasets.mnist.load_data()

# Append a trailing channel axis and divide pixel values by 255 before
# casting to float32; labels are cast to int64.
scaled_images = tf.cast(mnist_images[..., tf.newaxis] / 255, tf.float32)
int_labels = tf.cast(mnist_labels, tf.int64)

# Slice into per-example pairs, then shuffle (buffer of 1000) and batch by 32.
dataset = tf.data.Dataset.from_tensor_slices((scaled_images, int_labels))
dataset = dataset.shuffle(1000).batch(32)

# Build the model: two 3x3 convolutions with ReLU, global average pooling,
# then a 10-unit dense layer with no activation specified.
mnist_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(
        filters=16,
        kernel_size=[3, 3],
        activation='relu',
        input_shape=(None, None, 1),
    ),
    tf.keras.layers.Conv2D(filters=16, kernel_size=[3, 3], activation='relu'),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(units=10),
])

# Sanity check: run the untrained model on one example from the first batch
# and print its raw outputs.
for images, labels in dataset.take(1):
    single_example = images[:1]  # slice (not index) so the batch axis is kept
    print("Logits: ", mnist_model(single_example).numpy())

# Optimizer and loss used by train_step. from_logits=True tells the loss that
# it receives raw model outputs rather than probabilities.
optimizer = tf.keras.optimizers.Adam()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Per-batch loss values, appended to by train_step.
loss_history = []


def train_step(model, images, labels):
    """Run one optimization step on a single batch and record its loss.

    Args:
        model: Model invoked as ``model(images, training=True)``; its
            ``trainable_variables`` are updated in place.
        images: Batch of inputs; the leading dimension is the batch size.
        labels: Labels for the batch, passed to the module-level
            ``loss_object`` together with the logits.

    Side effects:
        Appends the batch loss to the module-level ``loss_history`` and
        applies gradients through the module-level ``optimizer``.
    """
    with tf.GradientTape() as tape:
        logits = model(images, training=True)

        # Check the output shape against the actual batch size rather than a
        # hard-coded 32, so a smaller final batch does not trip the assert.
        tf.debugging.assert_equal(logits.shape, (images.shape[0], 10))

        loss_value = loss_object(labels, logits)

    loss_history.append(loss_value.numpy().mean())
    grads = tape.gradient(loss_value, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))


def train(epochs):
    """Train the module-level ``mnist_model`` on ``dataset``.

    Iterates over the whole dataset ``epochs`` times, calling ``train_step``
    on every batch, and prints a completion line after each pass.
    """
    for epoch_index in range(epochs):
        for image_batch, label_batch in dataset:
            train_step(mnist_model, image_batch, label_batch)
        print('Epoch {} finished'.format(epoch_index))

我通过以下方式对其进行了训练,并在训练前后分别保存了 trainable_variables:


# Keep a handle on the trainable variables before training.
# NOTE(review): this stores the list returned by the model, not copies of the
# variable values -- presumably why the comparison below reports no change.
t0=mnist_model.trainable_variables  
train(epochs = 3)
# Fetch the trainable variables again after training.
t1=mnist_model.trainable_variables
# Mean absolute difference between the first "before" and "after" entries.
diff = tf.reduce_mean(tf.abs(t0[0] - t1[0])) 
# Whether indexing [0], [1], etc., the resulting diff is the same.
print(diff.numpy())

它们是一样的!!!那么,是我检查的方式不正确吗?如果是这样,我怎样才能正确地观察这些变量的更新?

4

1 回答 1

0

您并没有创建新的变量数组,而是创建了 2 个指向同一个对象的引用。请尝试这样做:

# Snapshot the current values: .numpy() copies each variable's contents into
# a NumPy array, so later training steps do not change what t0 holds.
t0 = [v.numpy() for v in mnist_model.trainable_variables]
于 2019-12-24T00:52:17.403 回答