我做了一个简单的脚本来尝试用 JAX 进行梯度累积。这个想法是让大批量大小(例如 64)分成适合 GPU 内存的小块(例如 4)。对于每个块,将存储在 pytree 中的结果梯度添加到当前批次梯度中。仅当计算大批量的所有块时才完成更新。在这个特定示例中,我们只是尝试将随机 512 维向量拟合到具有线性层的随机布尔值。这是脚本:
import jax
import jax.numpy as jnp
from jax import jit, random
from jax.experimental import optimizers
from functools import partial
from jax.nn.initializers import normal, zeros
from typing import Callable
from dataclasses import dataclass
@dataclass
class Jax_model:
init_fun: Callable
apply_fun: Callable
def Dense(input_size: int, output_size: int, init_kernel=normal(), init_bias=zeros):
def init_fun(key):
key, sub_key1, sub_key2 = jax.random.split(key, 3)
params = {
'I': init_kernel(sub_key1, (input_size, output_size) ),
'I_b': init_bias(sub_key2, (1,output_size) ),
}
return params
def apply_fun(params, inputs):
I, I_b, = params['I'], params['I_b']
logits = inputs @ I + I_b
return logits
return Jax_model(init_fun, apply_fun)
def divide_pytree(pytree, div):
for pt in jax.tree_util.tree_leaves(pytree):
pt = pt / div
return pytree
def add_pytrees(pytree1, pytree2):
for pt1, pt2 in zip( jax.tree_util.tree_leaves(pytree1), jax.tree_util.tree_leaves(pytree2) ):
pt1 = pt1 + pt2
return pytree1
rng_key = random.PRNGKey(42)
batch_size = 64
accumulation_size = 4
model_dim = 512
n_iter = 50
model = Dense(model_dim, 1)
rng_key, sub_key = random.split(rng_key)
init_params = model.init_fun(sub_key)
opt_init, opt_update, get_params = optimizers.adam(0.001)
opt_state = opt_init(init_params)
@jit
def update(i, current_opt_state, current_batch):
N = current_batch[0].shape[0]
K = accumulation_size
num_gradients = N//K
accumulation_batch = (current_batch[ib][0:K] for ib in range(len(current_batch)))
value, grads = jax.value_and_grad(loss_func)(get_params(current_opt_state), accumulation_batch)
value = value / num_gradients
grads = divide_pytree(grads, num_gradients)
for k in range(K,N,K):
accumulation_batch = (current_batch[ib][k:k+K] for ib in range(len(current_batch)))
new_value, new_grads = jax.value_and_grad(loss_func)(get_params(current_opt_state), accumulation_batch)
value = value + (new_value / num_gradients)
grads = add_pytrees(grads, divide_pytree(new_grads, num_gradients))
return opt_update(i, grads, current_opt_state), value
def loss_func(current_params, current_batch):
inputs, labels = current_batch
predictions = model.apply_fun(current_params, inputs)
loss = jnp.square(labels-predictions).sum()
return loss
for i in range(n_iter):
rng_key, sub_key1, sub_key2 = random.split(rng_key, 3)
inputs = jax.random.uniform(sub_key1, (batch_size, model_dim))
labels = jax.random.uniform(sub_key2, (batch_size, 1)) > 0.5
batch = inputs, labels
opt_state, batch_loss = update(i, opt_state, batch)
print(i, batch_loss)
divide_pytree
我对and有疑问add_pytrees
。它是否真的修改了当前的批次梯度或者我错过了什么?此外,您是否看到此代码有任何速度问题?特别是,我应该使用jax.lax.fori_loop
in 代替传统的 python for 循环吗?
相关链接: