我正在尝试按照此示例微调预训练模型。目的是对蛋白质序列进行二元分类。
我的代码如下所示:
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, serial, Sigmoid
from jax.nn import relu, sigmoid
from jax.experimental.stax import elementwise
import jax
from jax_unirep.layers import AAEmbedding, mLSTM, mLSTMAvgHidden
import jax.numpy as jnp
import numpy as np
from jax.experimental.optimizers import adam
from jax import grad, jit
from jax_unirep.utils import seq_to_oh, load_params
# Training data: two antibody heavy-chain protein sequences, right-padded
# with '-' to a common fixed length (presumably so every one-hot array has
# the same shape — TODO confirm how seq_to_oh encodes the '-' pad symbol).
train_h_seqs = [
"EVQLVQSGAEVKKPGASVKVSCKASGYTFTGYYMHWVRQAPGQGLEWMGWIDPDEGDTNYAQKFQGRVTMTRDTSISTAYMELSRLRSDDTAVYYCARLASGFRDYWGQGTLVTVSS------------------------------------------------------",
"EVQLVESGGGLVQPGRSLKLSCAASGFTFSNYGMAWVRQTPTKGLEWIASISAGGDKTYYGDSVKGRFSISRDNAKTTHYLQMDSLRSEDTATYYCAKTSRVYFDYWGQGVMVTVSS------------------------------------------------------"
]
# Binary class labels, one per training sequence.
train_labels = [0, 1]
# Single held-out sequence (same padding scheme) and its label, used after
# training to sanity-check predictions.
test_h_seqs = ["QVQLQQPGAELVKPGASVKMSCKASGYSFTSYWMNWVKQRPGRGLEWIGRIDPSDNETHYNQDFKDKVTLTVDKSSSTVYIQLSSLTSEDSAVYYCGRLGYVYGFDYWGQGTTLTVSS-----------------------------------------------------"]
test_labels = [0]
### MODEL ###
# Stax-style network: pretrained UniRep front end (amino-acid embedding +
# mLSTM encoder, averaged hidden states) followed by a small dense
# classification head. serial() returns an (init_fun, apply_fun) pair.
init_fun, apply_fun = serial(
AAEmbedding(10),
mLSTM(1900),
mLSTMAvgHidden(),
Dense(512),
elementwise(relu),
Dense(1)
)
batch_size = 1
num_classes = 2
# 173 is length of sequence
# 26 is size of alphabet (sequences are one-hot encoded)
input_shape = (batch_size, 173, 26)
step_size = 0.1  # Adam learning rate
num_steps = 10  # number of optimizer steps to run
# NOTE(review): load_params appears to return only the pretrained UniRep
# (embedding/mLSTM) weights; init_fun is never called, so the Dense head
# layers added above have no parameters in `params`. That structural
# mismatch is a likely cause of the unexpected (25,) output — TODO confirm
# against the jax_unirep documentation.
params = load_params(paper_weights=1900)
def loss(params, batch):
    """Mean binary cross-entropy loss for a batch.

    Args:
        params: model parameters consumed by ``apply_fun``.
        batch: ``(inputs, targets)`` tuple where ``targets`` are 0/1 labels.

    Returns:
        A scalar. ``jax.grad`` is only defined for scalar-output functions,
        so the per-example losses are reduced with a mean; for batch size 1
        the mean equals the single per-example loss, so this is correct for
        every batch size.
    """
    inputs, targets = batch
    logits = apply_fun(params, inputs)
    # Numerically stable BCE from logits:
    #   log(p)     = log_sigmoid(x)
    #   log(1 - p) = log(1 - sigmoid(x)) = log_sigmoid(-x)
    # The original used log_sigmoid(1 - logits), which is NOT log(1 - p).
    log_p = jax.nn.log_sigmoid(logits)
    log_not_p = jax.nn.log_sigmoid(-logits)
    res = -targets * log_p - (1.0 - targets) * log_not_p
    return jnp.mean(res)
def get_batches():
    """Yield ``(inputs, labels)`` minibatches from the training set.

    Inputs are one-hot encoded sequences stacked along a leading batch
    axis, giving shape ``(batch_size, seq_len, alphabet)`` — consistent
    with the declared ``input_shape`` and with the label slice for any
    batch size.
    """
    # one-hot encode sequences
    oh_seqs = [seq_to_oh(seq) for seq in train_h_seqs]
    labels = train_labels
    num_batches = len(labels) // batch_size
    for i in range(num_batches):
        # Slice the sequences with the same window used for the labels.
        # The original indexed oh_seqs[i] (a single, un-batched example),
        # which only lined up with the labels when batch_size == 1.
        start, stop = i * batch_size, (i + 1) * batch_size
        x = np.asarray(oh_seqs[start:stop])
        y = np.asarray(labels[start:stop])
        yield x, y
# Adam optimizer with the configured learning rate; stax-style optimizers
# return a triple of pure functions (init, update, get_params).
opt_init, opt_update, get_params = adam(step_size)
# NOTE(review): this generator yields only len(train_labels)//batch_size
# batches (2 here) — fewer than num_steps — so repeatedly calling
# next(batches) in the training loop will exhaust it. Verify.
batches = get_batches()
@jit
def update(i, opt_state, batch):
    """Run one Adam step on `batch` and return the new optimizer state.

    Args:
        i: step counter passed through to ``opt_update``.
        opt_state: current optimizer state.
        batch: ``(inputs, targets)`` tuple, as produced by ``get_batches``.
    """
    params = get_params(opt_state)
    # The original also evaluated loss(params, batch) just to print its
    # shape; inside a @jit function a Python print only fires at trace
    # time, so that was a doubled forward pass with no runtime output.
    return opt_update(i, grad(loss)(params, batch), opt_state)
opt_state = opt_init(params)

# Optimization: one Adam step per batch, restarting the (finite) batch
# generator when it is exhausted — the original called next(batches)
# num_steps (10) times on a generator that yields only 2 batches, which
# raises StopIteration on the third step.
for i in range(num_steps):
    try:
        batch = next(batches)
    except StopIteration:
        batches = get_batches()  # start a new epoch
        batch = next(batches)
    opt_state = update(i, opt_state, batch)
trained_params = get_params(opt_state)

# Testing: one-hot encode the held-out sequences and run the trained model.
test_oh_seqs = [seq_to_oh(seq) for seq in test_h_seqs]
pred = apply_fun(trained_params, test_oh_seqs[0])
print(pred.shape)  # original line had an unbalanced closing parenthesis
print(pred)
如果我理解正确,模型末尾的大小为 1 的密集层应该产生大小为 1 的输出,但是,最后两行的输出是:
(25,)
DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)
为什么我得到 25 个数字而不是 1 个?我哪里出错了?
在损失函数的训练过程中已经出现了这个问题,它给出了以下错误:
TypeError: Gradient only defined for scalar-output functions. Output had shape: (25,).
jax_unirep 的源代码可以在这里找到。
提前感谢您的任何建议。