
I am learning JAX, but I have run into a strange problem. If I use the following code,

import numpy as np
import jax.numpy as jnp
from jax import grad, value_and_grad
from jax import vmap # for auto-vectorizing functions
from functools import partial # for use with vmap
from jax import jit # for compiling functions for speedup
from jax import random # stax initialization uses jax.random
from jax.experimental import stax # neural network library
from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax # neural network layers
from jax.experimental import optimizers # Adam optimizer used below
import matplotlib.pyplot as plt # visualization

net_init, net_apply = stax.serial(
    Dense(40), Relu,
    Dense(40), Relu,
    Dense(40), Relu,
    Dense(1)
)
rng = random.PRNGKey(0)
in_shape = (-1, 1,)
out_shape, params = net_init(rng, in_shape)

def loss(params, X, Y):
    predictions = net_apply(params, X)
    return jnp.mean((Y - predictions)**2)

@jit
def step(i, opt_state, x1, y1):
    p = get_params(opt_state)
    val, g = value_and_grad(loss)(p, x1, y1)
    return val, opt_update(i, g, opt_state)

opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3)
opt_state = opt_init(params)

xrange_inputs = jnp.linspace(-5,5,100).reshape((100, 1)) # (k, 1)
targets = jnp.cos(xrange_inputs)

val_his = []
for i in range(1000):
    val, opt_state = step(i, opt_state, xrange_inputs, targets)
    val_his.append(val)
params = get_params(opt_state)
val_his = jnp.array(val_his)

predictions = vmap(partial(net_apply, params))(xrange_inputs)
losses = vmap(partial(loss, params))(xrange_inputs, targets) # per-input loss

plt.plot(xrange_inputs, predictions, label='prediction')
plt.plot(xrange_inputs, losses, label='loss')
plt.plot(xrange_inputs, targets, label='target')
plt.legend()

The neural network approximates the function cos(x) very well.

But if I rewrite the neural-network part myself as follows,

import numpy as np
import jax.numpy as jnp
from jax import grad, value_and_grad
from jax import vmap # for auto-vectorizing functions
from functools import partial # for use with vmap
from jax import jit # for compiling functions for speedup
from jax import random # stax initialization uses jax.random
from jax.experimental import stax # neural network library
from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax # neural network layers
import matplotlib.pyplot as plt # visualization
from jax.experimental import optimizers
from jax.tree_util import tree_multimap

def initialize_NN(layers, key):        
    params = []
    num_layers = len(layers)
    keys = random.split(key, len(layers))
    a = jnp.sqrt(0.1)
    #params.append(a)
    for l in range(0, num_layers-1):
        W = xavier_init((layers[l], layers[l+1]), keys[l])
        b = jnp.zeros((layers[l+1],), dtype=np.float32)
        params.append((W,b))
    return params

def xavier_init(size, key):
    in_dim = size[0]
    out_dim = size[1]      
    xavier_stddev = jnp.sqrt(2/(in_dim + out_dim))
    return random.truncated_normal(key, -2, 2, shape=(out_dim, in_dim), dtype=np.float32)*xavier_stddev
    
def net_apply(params, X):
    num_layers = len(params)
    #a = params[0]
    for l in range(0, num_layers-1):
        W, b = params[l]
        X = jnp.maximum(0, jnp.add(jnp.dot(X, W.T), b))
    W, b = params[-1]
    Y = jnp.dot(X, W.T)+ b
    Y = jnp.squeeze(Y)
    return Y
    
def loss(params, X, Y):
    predictions = net_apply(params, X)
    return jnp.mean((Y - predictions)**2)

key = random.PRNGKey(1)
layers = [1,40,40,40,1]
params = initialize_NN(layers, key)

@jit
def step(i, opt_state, x1, y1):
    p = get_params(opt_state)
    val, g = value_and_grad(loss)(p, x1, y1)
    return val, opt_update(i, g, opt_state)

opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3)
opt_state = opt_init(params)

xrange_inputs = jnp.linspace(-5,5,100).reshape((100, 1)) # (k, 1)
targets = jnp.cos(xrange_inputs)

val_his = []
for i in range(1000):
    val, opt_state = step(i, opt_state, xrange_inputs, targets)
    val_his.append(val)
params = get_params(opt_state)
val_his = jnp.array(val_his)

predictions = vmap(partial(net_apply, params))(xrange_inputs)
losses = vmap(partial(loss, params))(xrange_inputs, targets) # per-input loss

plt.plot(xrange_inputs, predictions, label='prediction')
plt.plot(xrange_inputs, losses, label='loss')
plt.plot(xrange_inputs, targets, label='target')
plt.legend()

My neural network always converges to a constant, as if it were stuck in a local minimum. Yet the very same network works well in the first part. I am really confused.

The only differences should be the initialization, the hand-written network, and the way params is set up. I have tried different initializations and it makes no difference. I wonder whether params is set up incorrectly for the optimizer, so that it fails to converge.


1 Answer


The questioner seems to have solved the problem on their own. Still, I would like to explain what actually happened, because I ran into exactly the same problem the questioner faced.

Indeed, the source of the network's awkward behavior was the line Y = jnp.squeeze(Y) that the questioner later removed. Before that fix, Y and predictions inside the definition of loss(params, X, Y) actually had different shapes: Y (the targets) is a column vector of shape (N, 1), while predictions, after the "squeeze" operation applied by the questioner him/herself inside net_apply, is a one-dimensional array of shape (N,), which broadcasting treats like a row vector of shape (1, N).
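To make the mismatch concrete, here is a quick shape check (a small sketch; it assumes the questioner's net_apply, params, xrange_inputs, and targets from the second code block are in scope):

preds = net_apply(params, xrange_inputs)   # squeezed inside net_apply
print(preds.shape)               # (100,)   -- a 1-D array, not a column
print(targets.shape)             # (100, 1) -- still a column vector
print((targets - preds).shape)   # (100, 100) -- broadcasting silently expands the difference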

NumPy, and JAX's version of NumPy (and MATLAB as well, for that matter), has a feature called broadcasting for array operations. Because of this feature, the interpreter carries out the following computation:

\begin{pmatrix}
a_{1}\\
a_{2}\\
\vdots \\
a_{m}
\end{pmatrix} -\begin{pmatrix}
b_{1} & b_{2} & \cdots  & b_{n}
\end{pmatrix} =\begin{pmatrix}
a_{1} -b_{1} & a_{1} -b_{2} & \cdots  & a_{1} -b_{n}\\
a_{2} -b_{1} & a_{2} -b_{2} & \cdots  & a_{2} -b_{n}\\
\vdots  & \vdots  & \ddots  & \vdots \\
a_{m} -b_{1} & a_{m} -b_{2} & \cdots  & a_{m} -b_{n}
\end{pmatrix}

(This formula is meant to be rendered with LaTeX.)
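As a self-contained illustration of this rule (a sketch with small arrays, independent of the question's code):

import jax.numpy as jnp

a = jnp.arange(3.0).reshape(3, 1)   # column vector, shape (3, 1)
b = jnp.arange(4.0)                 # 1-D array, broadcast as shape (1, 4)
print((a - b).shape)                # (3, 4): every a_i is paired with every b_j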

Therefore, before the questioner's own fix, Y - predictions was actually a matrix of shape (N, N), and jnp.mean() averaged over all entries of this N×N matrix. That is of course not the MSE loss one intends to compute, and it caused the strange convergence behavior the questioner observed.
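Accordingly, a minimal fix is to make sure both operands have the same shape, for example by dropping the squeeze inside net_apply, or by flattening both sides in the loss (a sketch; either variant should work for this example):

def loss(params, X, Y):
    predictions = net_apply(params, X)
    # ravel both sides so the subtraction is elementwise over (N,) arrays,
    # instead of broadcasting (N, 1) against (N,) into an (N, N) matrix
    return jnp.mean((jnp.ravel(Y) - jnp.ravel(predictions))**2)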

Answered 2022-01-22T09:42:13.143