0

我正在尝试使用 TensorFlow Probabilities Dense Flip-Out 层构建贝叶斯 Softmax 回归模型。该模型正在 MNIST 数据集上进行训练。

TensorFlow 返回错误:

没有为任何变量提供梯度,请检查您的图表中不支持梯度的操作,在变量 [] 和 0x7f07fc127840 处的损失之间

我认为这是由于随机变量不是微分的,因此不存在梯度。但是,Tensorflow 在其网站上提供了此代码的清晰演示 -这里

有没有人对为什么会发生这种情况有一个全面的答案?

我的代码如下:

import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import numpy as np

mnist = tf.keras.datasets.mnist
tfe = tf.contrib.eager
ed = tfp.edward2

tf.reset_default_graph()
tf.enable_eager_execution()

(x_train, y_train), (x_test, y_test) = mnist.load_data()

pix_w = 28
pix_h = 28

def vis_pix(image):
    plt.imshow(image, cmap='Greys')
    plt.show()

def scale(x, min_val=0.0, max_val=255.0):
    x = tf.cast(x, tf.float32)
    return tf.div(tf.subtract(x, min_val), tf.subtract(max_val, min_val))

def create_dataset(x, y):
    dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    dataset = dataset.map(lambda x, y: (scale(x), tf.one_hot(y, 10)))
    dataset = dataset.map(lambda x, y: (tf.reshape(x, [pix_w * pix_h]), y))
    dataset = dataset.shuffle(10000).batch(30)
    return dataset

train_ds = create_dataset(x_train, y_train)

model = tf.keras.Sequential([
    tfp.layers.DenseFlipout(10)
])

optimizer = tf.train.AdamOptimizer()

def loss_fn(model, x, y):
    logits = model(x)
    neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits(labels=y, 
logits=logits)
    kl = sum(model.losses)
    elbow_loss = neg_log_likelihood + kl
    return elbow_loss

def get_accuracy(model, x, y):
    logits = model(x)
    yhat = tf.argmax(logits, 1)
    is_correct = tf.equal(yhat, tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))
    return accuracy

epochs = 1000
for (batch, (x, y)) in enumerate(train_ds):
    optimizer.minimize(lambda: loss_fn(model, x, y), 
global_step=tf.train.get_or_create_global_step())
    if batch % 10 == 0:
      acc = get_accuracy(model, x, y).numpy() * 100
      loss = loss_fn(model, x, y).numpy().mean()
      print("Iteration {}, loss: {:.3f}, train accuracy: 
{:.2f}%".format(batch, loss, acc))
    if batch > epochs:
        break
4

0 回答 0