I haven't been able to solve this problem for days. Since I'm new to NLP, the actual solution may be very simple.
class QAModel(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.model = MT5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)

    def forward(self, input_ids, attention_mask, labels=None):
        output = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        return output.loss, output.logits

    def training_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log('train_loss', loss, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log('val_loss', loss, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log('test_loss', loss, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        return AdamW(self.parameters(), lr=0.0001)
model = QAModel()
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath='/content/checkpoints',
    filename='best-checkpoint',
    save_top_k=1,
    verbose=True,
    monitor='val_loss',
    mode='min'
)

trainer = pl.Trainer(
    checkpoint_callback=checkpoint_callback,
    max_epochs=N_EPOCHS,
    gpus=1,
    progress_bar_refresh_rate=30
)
trainer.fit(model, data_module)
Running this code gives me AttributeError: 'QAModel' object has no attribute 'automatic_optimization' after the fit() call. Possibly the problem is in MT5ForConditionalGeneration, because after passing it to the function we run into the same error.
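For reference, the error message actually points at the base class rather than at MT5ForConditionalGeneration: automatic_optimization is an attribute that Trainer.fit() looks up on a pl.LightningModule, and QAModel above subclasses pl.LightningDataModule, the data-pipeline base class, which never defines it. Below is a minimal sketch of the corrected skeleton, assuming MODEL_NAME is defined earlier in the notebook (the value in the comment is just illustrative); note that forward should also call self.model rather than the global model:

import pytorch_lightning as pl
from torch.optim import AdamW
from transformers import MT5ForConditionalGeneration

class QAModel(pl.LightningModule):  # LightningModule, not LightningDataModule
    def __init__(self):
        super().__init__()
        # MODEL_NAME is assumed to be defined earlier, e.g. 'google/mt5-small'
        self.model = MT5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)

    def forward(self, input_ids, attention_mask, labels=None):
        # use the model stored on self, not the module-level variable named model
        output = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        return output.loss, output.logits

    def configure_optimizers(self):
        # torch.optim.AdamW here; swap in transformers' AdamW if that is what the notebook imports
        return AdamW(self.parameters(), lr=0.0001)

    # training_step / validation_step / test_step stay exactly as in the snippet above

With the base class changed, Trainer.fit() finds automatic_optimization on the module and training can proceed; LightningDataModule is only meant for the dataloader plumbing, i.e. the role data_module plays in the last line of the snippet.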