
I am trying to fine-tune a pretrained model following this example. The goal is binary classification of protein sequences.
My code looks like this:

import jax
import jax.numpy as jnp
import numpy as np
from jax import grad, jit
from jax.experimental.optimizers import adam
from jax.experimental.stax import Dense, elementwise, serial
from jax.nn import relu
from jax_unirep.layers import AAEmbedding, mLSTM, mLSTMAvgHidden
from jax_unirep.utils import load_params, seq_to_oh


train_h_seqs = [
    "EVQLVQSGAEVKKPGASVKVSCKASGYTFTGYYMHWVRQAPGQGLEWMGWIDPDEGDTNYAQKFQGRVTMTRDTSISTAYMELSRLRSDDTAVYYCARLASGFRDYWGQGTLVTVSS------------------------------------------------------",
    "EVQLVESGGGLVQPGRSLKLSCAASGFTFSNYGMAWVRQTPTKGLEWIASISAGGDKTYYGDSVKGRFSISRDNAKTTHYLQMDSLRSEDTATYYCAKTSRVYFDYWGQGVMVTVSS------------------------------------------------------"
]
train_labels = [0, 1]

test_h_seqs = ["QVQLQQPGAELVKPGASVKMSCKASGYSFTSYWMNWVKQRPGRGLEWIGRIDPSDNETHYNQDFKDKVTLTVDKSSSTVYIQLSSLTSEDSAVYYCGRLGYVYGFDYWGQGTTLTVSS-----------------------------------------------------"]
test_labels = [0]

### MODEL ###
init_fun, apply_fun = serial(
    AAEmbedding(10),    # one-hot (batch, length, 26) -> embeddings (batch, length, 10)
    mLSTM(1900),        # UniRep mLSTM with hidden size 1900
    mLSTMAvgHidden(),   # average the hidden states over the sequence length
    Dense(512),
    elementwise(relu),
    Dense(1),           # single logit for binary classification
)

batch_size = 1
num_classes = 2
# 173 is the length of each (padded) sequence,
# 26 is the size of the alphabet (sequences are one-hot encoded).
input_shape = (batch_size, 173, 26)
step_size = 0.1
num_steps = 10

# Start from the pretrained paper weights instead of calling init_fun.
params = load_params(paper_weights=1900)

def loss(params, batch):
    inputs, targets = batch
    logits = apply_fun(params, inputs)
    log_p = jax.nn.log_sigmoid(logits)
    # 1 - sigmoid(x) == sigmoid(-x), so negate the logits here.
    log_not_p = jax.nn.log_sigmoid(-logits)
    res = -targets * log_p - (1. - targets) * log_not_p
    return res
    # I get an error here during training.
    # It can be bypassed by the line below, but that is not correct for batch_size = 1:
    # return res.mean()


def get_batches():
    # One-hot encode the training sequences.
    oh_seqs = [seq_to_oh(seq) for seq in train_h_seqs]
    labels = train_labels
    # Note: this generator is exhausted after len(labels) // batch_size batches.
    num_batches = len(labels) // batch_size
    for i in range(num_batches):
        x = np.asarray(oh_seqs[i])
        yield x, np.asarray(labels[i * batch_size : (i + 1) * batch_size])

    
opt_init, opt_update, get_params = adam(step_size)
batches = get_batches()

@jit
def update(i, opt_state, batch):
    params = get_params(opt_state)
    l = loss(params, batch)
    print(l.shape)  # printed once at trace time, since update is jit-compiled
    return opt_update(i, grad(loss)(params, batch), opt_state)


opt_state = opt_init(params)

# Optimization
for i in range(num_steps):
    opt_state = update(i, opt_state, next(batches))
trained_params = get_params(opt_state)

# Testing
test_oh_seqs = [seq_to_oh(seq) for seq in test_h_seqs]
pred = apply_fun(trained_params, test_oh_seqs[0])

print(pred.shape)
print(pred)


If I understand this correctly, the Dense layer of size 1 at the end of the model should produce an output of size 1. However, the last two print statements output:

(25,)
DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)

Why do I get 25 numbers instead of one? Where am I going wrong?
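
For debugging, here is a minimal sketch that propagates input_shape through each stage's init_fun and prints the resulting output shapes (this assumes the jax_unirep layers follow the stax (init_fun, apply_fun) convention, which serial() already relies on):

from jax import random

stages = [AAEmbedding(10), mLSTM(1900), mLSTMAvgHidden(),
          Dense(512), elementwise(relu), Dense(1)]
rng = random.PRNGKey(0)
shape = input_shape  # (1, 173, 26)
for idx, (init_fn, _) in enumerate(stages):
    rng, layer_rng = random.split(rng)
    # init_fn maps an input shape to (output_shape, params); we only need the shape
    shape, _ = init_fn(layer_rng, shape)
    print(f"stage {idx}: output shape = {shape}")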

This problem already shows up during training, in the loss function, where it raises the following error: TypeError: Gradient only defined for scalar-output functions. Output had shape: (25,).
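
As far as I understand, the error itself only states JAX's general rule that jax.grad differentiates scalar-valued functions; a toy example, independent of jax_unirep, reproduces it:

import jax
import jax.numpy as jnp

def vector_loss(w):
    return w * jnp.arange(3.0)           # output shape (3,): grad() rejects this

def scalar_loss(w):
    return (w * jnp.arange(3.0)).sum()   # scalar output: grad() accepts this

# jax.grad(vector_loss)(2.0)  # TypeError: Gradient only defined for scalar-output functions
print(jax.grad(scalar_loss)(2.0))        # 3.0

So a reduction like res.mean() does silence the error; my real question is why res has 25 entries in the first place.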

The source for jax_unirep can be found here.

Thanks in advance for any suggestions.
