I am learning JAX, but I have run into a strange problem. If I use the following code,
import numpy as np
import jax.numpy as jnp
from jax import grad, value_and_grad
from jax import vmap  # for auto-vectorizing functions
from functools import partial  # for use with vmap
from jax import jit  # for compiling functions for speedup
from jax import random  # stax initialization uses jax.random
from jax.experimental import stax  # neural network library
from jax.experimental import optimizers  # Adam optimizer used below
from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax  # neural network layers
import matplotlib.pyplot as plt  # visualization
net_init, net_apply = stax.serial(
    Dense(40), Relu,
    Dense(40), Relu,
    Dense(40), Relu,
    Dense(1)
)
rng = random.PRNGKey(0)
in_shape = (-1, 1,)
out_shape, params = net_init(rng, in_shape)
def loss(params, X, Y):
    predictions = net_apply(params, X)
    return jnp.mean((Y - predictions)**2)

@jit
def step(i, opt_state, x1, y1):
    p = get_params(opt_state)
    val, g = value_and_grad(loss)(p, x1, y1)
    return val, opt_update(i, g, opt_state)
opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3)
opt_state = opt_init(params)

xrange_inputs = jnp.linspace(-5, 5, 100).reshape((100, 1))  # (k, 1)
targets = jnp.cos(xrange_inputs)

val_his = []
for i in range(1000):
    val, opt_state = step(i, opt_state, xrange_inputs, targets)
    val_his.append(val)
params = get_params(opt_state)
val_his = jnp.array(val_his)
predictions = vmap(partial(net_apply, params))(xrange_inputs)
losses = vmap(partial(loss, params))(xrange_inputs, targets) # per-input loss
plt.plot(xrange_inputs, predictions, label='prediction')
plt.plot(xrange_inputs, losses, label='loss')
plt.plot(xrange_inputs, targets, label='target')
plt.legend()
then the network approximates the function cos(x) very well.
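For reference, the recorded loss history can also be plotted to confirm that training converged; this is just a small sketch, assuming the training loop above has already been run so that val_his is in scope:

plt.figure()
plt.plot(val_his, label='training loss (stax version)')
plt.yscale('log')  # the loss usually spans several orders of magnitude
plt.xlabel('step')
plt.legend()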
But if I rewrite the neural-network part myself as follows,
import numpy as np
import jax.numpy as jnp
from jax import grad, value_and_grad
from jax import vmap # for auto-vectorizing functions
from functools import partial # for use with vmap
from jax import jit # for compiling functions for speedup
from jax import random # stax initialization uses jax.random
from jax.experimental import stax # neural network library
from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax # neural network layers
import matplotlib.pyplot as plt # visualization
from jax.experimental import optimizers
from jax.tree_util import tree_multimap
def initialize_NN(layers, key):
    params = []
    num_layers = len(layers)
    keys = random.split(key, len(layers))
    a = jnp.sqrt(0.1)
    #params.append(a)
    for l in range(0, num_layers-1):
        W = xavier_init((layers[l], layers[l+1]), keys[l])
        b = jnp.zeros((layers[l+1],), dtype=np.float32)
        params.append((W, b))
    return params
def xavier_init(size, key):
    in_dim = size[0]
    out_dim = size[1]
    xavier_stddev = jnp.sqrt(2/(in_dim + out_dim))
    return random.truncated_normal(key, -2, 2, shape=(out_dim, in_dim), dtype=np.float32) * xavier_stddev
def net_apply(params, X):
    num_layers = len(params)
    #a = params[0]
    for l in range(0, num_layers-1):
        W, b = params[l]
        X = jnp.maximum(0, jnp.add(jnp.dot(X, W.T), b))
    W, b = params[-1]
    Y = jnp.dot(X, W.T) + b
    Y = jnp.squeeze(Y)
    return Y

def loss(params, X, Y):
    predictions = net_apply(params, X)
    return jnp.mean((Y - predictions)**2)
key = random.PRNGKey(1)
layers = [1,40,40,40,1]
params = initialize_NN(layers, key)
@jit
def step(i, opt_state, x1, y1):
    p = get_params(opt_state)
    val, g = value_and_grad(loss)(p, x1, y1)
    return val, opt_update(i, g, opt_state)
opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3)
opt_state = opt_init(params)
xrange_inputs = jnp.linspace(-5,5,100).reshape((100, 1)) # (k, 1)
targets = jnp.cos(xrange_inputs)
val_his = []
for i in range(1000):
    val, opt_state = step(i, opt_state, xrange_inputs, targets)
    val_his.append(val)
params = get_params(opt_state)
val_his = jnp.array(val_his)
predictions = vmap(partial(net_apply, params))(xrange_inputs)
losses = vmap(partial(loss, params))(xrange_inputs, targets) # per-input loss
plt.plot(xrange_inputs, predictions, label='prediction')
plt.plot(xrange_inputs, losses, label='loss')
plt.plot(xrange_inputs, targets, label='target')
plt.legend()
then my network always converges to a constant, as if it were stuck in a local minimum, yet the same architecture works fine in the first version. I am really confused.
The only differences should be the initialization, the hand-written forward pass, and the way the parameters params are set up. I have tried different initializations and it made no difference, so I wonder whether the way params is passed to the optimizer is wrong and that is what prevents convergence.
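To narrow things down, here is a small diagnostic one could run after the second listing; it only assumes that params, net_apply, loss, xrange_inputs, targets and the optimizers import from above are still in scope. It prints the shapes flowing through the hand-written network and checks that optimizers.adam round-trips the custom params pytree:

preds = net_apply(params, xrange_inputs)
print("predictions shape:", preds.shape)    # output of the hand-written network
print("targets shape:    ", targets.shape)  # (100, 1)
print("batch loss:", loss(params, xrange_inputs, targets))

# Does the optimizer preserve the list-of-(W, b)-tuples structure?
opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3)
restored = get_params(opt_init(params))
print("pytree round-trip OK:",
      all(W0.shape == W1.shape and b0.shape == b1.shape
          for (W0, b0), (W1, b1) in zip(params, restored)))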