I am training MBART for Seq2Seq with SimpleTransformers, but I get an error that I never saw with BART:
TypeError: shift_tokens_right() missing 1 required positional argument: 'decoder_start_token_id'
So far I have tried various combinations of
model.decoder_tokenizer.add_special_tokens({"bos_token": "<s>"})
which was already set beforehand; using anything other than the bos_token indicates that the token is not a special token.
This leaves me with the following code:
from simpletransformers.seq2seq import Seq2SeqModel, Seq2SeqArgs
# Model Config
model_args = Seq2SeqArgs()
model_args.do_sample = True
model_args.eval_batch_size = 4 # 64
model_args.evaluate_during_training = True
model_args.evaluate_during_training_steps = 2500
model_args.evaluate_during_training_verbose = True
model_args.fp16 = False # False
model_args.learning_rate = 5e-5
model_args.max_length = 128
model_args.max_seq_length = 128
model_args.num_beams = 10 # 0
model_args.num_return_sequences = 3
model_args.num_train_epochs = 2
model_args.overwrite_output_dir = True
model_args.reprocess_input_data = True
model_args.save_eval_checkpoints = False
model_args.save_steps = -1
model_args.top_k = 50
model_args.top_p = 0.95
model_args.train_batch_size = 4 # 8
model_args.use_multiprocessing = False
model_ru = Seq2SeqModel(
    encoder_decoder_type="mbart",
    encoder_decoder_name="IlyaGusev/mbart_ru_sum_gazeta",
    args=model_args,
    use_cuda=True
)
# Add custom tokens
model_ru.encoder_tokenizer.add_tokens(["token1", "token2"])
# already set, as seen from: model_ru.decoder_tokenizer.bos_token
model_ru.decoder_tokenizer.add_special_tokens({"bos_token": "<s>"})
model_ru.model.resize_token_embeddings(len(model_ru.encoder_tokenizer))
model_ru.train_model(train, eval_data=dev)
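For completeness, train and dev are pandas DataFrames in the two-column input_text / target_text format that SimpleTransformers expects for seq2seq models; the rows below are made-up placeholders, not my actual data:

import pandas as pd

# Placeholder examples only; the real training/eval data is not shown here.
train = pd.DataFrame([
    {"input_text": "token1 Какой-то исходный текст.", "target_text": "Какой-то целевой текст."},
    {"input_text": "token2 Ещё один пример.", "target_text": "Ещё один ответ."},
])
dev = pd.DataFrame([
    {"input_text": "token1 Проверочный текст.", "target_text": "Проверочный ответ."},
])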
Calling train_model then raises the following error:
/usr/local/lib/python3.7/dist-packages/transformers/tokenization_utils_base.py:3407: FutureWarning:
`prepare_seq2seq_batch` is deprecated and will be removed in version 5 of HuggingFace Transformers. Use the regular
`__call__` method to prepare your inputs and the tokenizer under the `as_target_tokenizer` context manager to prepare
your targets.
Here is a short example:
model_inputs = tokenizer(src_texts, ...)
with tokenizer.as_target_tokenizer():
    labels = tokenizer(tgt_texts, ...)
model_inputs["labels"] = labels["input_ids"]
See the documentation of your specific tokenizer for more details on the specific arguments to the tokenizer of choice.
For a more complete example, see the implementation of `prepare_seq2seq_batch`.
warnings.warn(formatted_warning, FutureWarning)
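(The FutureWarning above is just a deprecation notice from transformers, not the failure itself. For reference, a minimal sketch of the pattern it recommends, assuming the checkpoint's standard MBart sentencepiece tokenizer is used directly rather than through the SimpleTransformers wrapper, and assuming ru_RU language codes for Russian-to-Russian data:

from transformers import MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("IlyaGusev/mbart_ru_sum_gazeta")
tokenizer.src_lang = "ru_RU"  # assumed language codes
tokenizer.tgt_lang = "ru_RU"

src_texts = ["Пример исходного текста."]
tgt_texts = ["Пример целевого текста."]

# Tokenize the sources with the regular __call__ ...
model_inputs = tokenizer(src_texts, max_length=128, truncation=True, return_tensors="pt")

# ... and the targets under the as_target_tokenizer context manager
with tokenizer.as_target_tokenizer():
    labels = tokenizer(tgt_texts, max_length=128, truncation=True, return_tensors="pt")

model_inputs["labels"] = labels["input_ids"]

The actual exception follows:)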
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/tmp/ipykernel_1538/3709317111.py in <module>
15 model_ru.model.resize_token_embeddings(len(model_ru.encoder_tokenizer))
16
---> 17 model_ru.train_model(train_tydiqa_ru, eval_data=dev_tydiqa_ru)
18
19 # Evaluation and training loss can also be found WandB
5 frames
/usr/local/lib/python3.7/dist-packages/simpletransformers/seq2seq/seq2seq_model.py in train_model(self, train_data, output_dir, show_running_loss, args, eval_data, verbose, **kwargs)
433 self._move_model_to_device()
434
--> 435 train_dataset = self.load_and_cache_examples(train_data, verbose=verbose)
436
437 os.makedirs(output_dir, exist_ok=True)
/usr/local/lib/python3.7/dist-packages/simpletransformers/seq2seq/seq2seq_model.py in load_and_cache_examples(self, data, evaluate, no_cache, verbose, silent)
1489 if args.model_type in ["bart", "mbart", "marian"]:
1490 return SimpleSummarizationDataset(
-> 1491 encoder_tokenizer, self.args, data, mode
1492 )
1493 else:
/usr/local/lib/python3.7/dist-packages/simpletransformers/seq2seq/seq2seq_utils.py in __init__(self, tokenizer, args, data, mode)
423 else:
424 self.examples = [
--> 425 preprocess_fn(d) for d in tqdm(data, disable=args.silent)
426 ]
427
/usr/local/lib/python3.7/dist-packages/simpletransformers/seq2seq/seq2seq_utils.py in <listcomp>(.0)
423 else:
424 self.examples = [
--> 425 preprocess_fn(d) for d in tqdm(data, disable=args.silent)
426 ]
427
/usr/local/lib/python3.7/dist-packages/simpletransformers/seq2seq/seq2seq_utils.py in preprocess_data_mbart(data)
359 decoder_input_ids,
360 tokenizer.pad_token_id,
--> 361 tokenizer.lang_code_to_id[args.tgt_lang],
362 )
363
/usr/local/lib/python3.7/dist-packages/simpletransformers/seq2seq/seq2seq_utils.py in <lambda>(input_ids, pad_token_id, decoder_start_token_id)
30 shift_tokens_right = (
31 lambda input_ids, pad_token_id, decoder_start_token_id: _shift_tokens_right(
---> 32 input_ids, pad_token_id
33 )
34 )
TypeError: shift_tokens_right() missing 1 required positional argument: 'decoder_start_token_id'
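For context, the last frame shows simpletransformers wrapping transformers' _shift_tokens_right in a lambda that drops the decoder_start_token_id argument, while the installed transformers expects three arguments, which looks to me like a version mismatch between the two libraries. A minimal sketch reproducing the same TypeError in isolation, assuming a recent transformers release where the BART shift_tokens_right takes decoder_start_token_id (the import path and dummy ids below are mine, not from the traceback):

import torch
from transformers.models.bart.modeling_bart import shift_tokens_right

input_ids = torch.tensor([[250021, 47, 8, 2]])  # arbitrary dummy token ids

# Three-argument form expected by recent transformers releases: works
shifted = shift_tokens_right(input_ids, pad_token_id=1, decoder_start_token_id=250021)

# Two-argument form, as called by the simpletransformers lambda above: fails
try:
    shift_tokens_right(input_ids, 1)
except TypeError as e:
    print(e)  # missing 1 required positional argument: 'decoder_start_token_id'

What do I need to change so that MBART trains the same way BART did?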