0

我正在尝试使用 GradientTape 在没有 softmax 的情况下执行多类逻辑回归。似乎模型没有收敛,因此我的准确度为 0。它是标准的 MNIST 数据集,具有以下参数和代码:

num_classes = 10
num_features = 784

learning_rate = 0.0001
training_steps = 1000
batch_size = 256
display_step = 50

from tensorflow.keras.datasets import mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train, y_train = map(list, zip(*[(x, y) for x, y in zip(x_train, y_train) if y in range(0, num_classes)]))
x_test, y_test = map(list, zip(*[(x, y) for x, y in zip(x_test, y_test) if y in range(0, num_classes)]))

x_train, x_test = np.array(x_train, np.float32), np.array(x_test, np.float32)

x_train, x_test = x_train.reshape([-1, num_features]), x_test.reshape([-1, num_features])

x_train, x_test = x_train/255., x_test/255.

train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_data = train_data.repeat().shuffle(5000).batch(batch_size).prefetch(1)

b = tf.Variable(tf.ones((num_features, )) * 0.000001, name = "weight")
b0 = tf.Variable(0., name = "bias”)

def logistic_regression(x, b, b0):
  return 1. / (1. + tf.exp(-tf.reduce_sum(tf.multiply(x, b), axis = 1) - b0))

def loglikelihood(p, y_true):
  y_true = tf.cast(y_true, tf.float32)
  return y_true * tf.math.log(p) + (1 - y_true) * tf.math.log(1 - p)

def accuracy(y_pred, y_true):
  correct_prediction = tf.equal(tf.round(y_pred), tf.cast(y_true, tf.float32))
  return tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

for step, (batch_x, batch_y) in enumerate(train_data.take(training_steps), 1):
  with tf.GradientTape() as g:
    g.watch([b, b0])
    p = logistic_regression(batch_x, b, b0)
    ll = loglikelihood(p, batch_y)
    ll_sum = tf.reduce_mean(ll)
  grad_b, grad_b0 = g.gradient(ll_sum, [b, b0])
  b = b + learning_rate * grad_b
  b0 = b0 + learning_rate * grad_b0

  if step % display_step == 0:
    p = logistic_regression(batch_x, b, b0)
    acc = accuracy(p, batch_y)
    print("step: %i, accuracy: %f" % (step, acc))

p = logistic_regression(x_test, b, b0)
print("Test Accuracy: %f" % accuracy(p, y_test))

谁能解释我可能是什么原因?

4

0 回答 0