
I am training MBART for Seq2Seq with SimpleTransformers, but I get an error that I did not see with BART:

TypeError: shift_tokens_right() missing 1 required positional argument: 'decoder_start_token_id'

So far I have tried various combinations of

model.decoder_tokenizer.add_special_tokens({"bos_token": "<s>"})

This is already set in advance. Using anything other than bos_token indicates that the token is not a special token.
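For reference, the bos_token can also be verified as already present outside of SimpleTransformers. A minimal check, assuming the checkpoint's tokenizer loads via AutoTokenizer (same model name as below):

from transformers import AutoTokenizer

# Sketch: confirm that <s> is already registered as bos_token on this checkpoint,
# so add_special_tokens({"bos_token": "<s>"}) does not change anything.
tok = AutoTokenizer.from_pretrained("IlyaGusev/mbart_ru_sum_gazeta")
print(tok.bos_token)           # '<s>'
print(tok.special_tokens_map)  # bos/eos/unk/pad etc. are already set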

That leaves the following code:

from simpletransformers.seq2seq import Seq2SeqModel, Seq2SeqArgs


# Model Config
model_args = Seq2SeqArgs()
model_args.do_sample = True
model_args.eval_batch_size = 4  # 64
model_args.evaluate_during_training = True
model_args.evaluate_during_training_steps = 2500
model_args.evaluate_during_training_verbose = True
model_args.fp16 = False  # False
model_args.learning_rate = 5e-5
model_args.max_length = 128
model_args.max_seq_length = 128
model_args.num_beams = 10  # 0
model_args.num_return_sequences = 3
model_args.num_train_epochs = 2
model_args.overwrite_output_dir = True
model_args.reprocess_input_data = True
model_args.save_eval_checkpoints = False
model_args.save_steps = -1
model_args.top_k = 50
model_args.top_p = 0.95
model_args.train_batch_size = 4  # 8
model_args.use_multiprocessing = False

model_ru = Seq2SeqModel(
    encoder_decoder_type="mbart",
    encoder_decoder_name="IlyaGusev/mbart_ru_sum_gazeta",
    args=model_args,
    use_cuda=True
)

# Add custom tokens
model_ru.encoder_tokenizer.add_tokens(["token1", "token2"])

# already set, as seen from: model_ru.decoder_tokenizer.bos_token
model_ru.decoder_tokenizer.add_special_tokens({"bos_token": "<s>"})

model_ru.model.resize_token_embeddings(len(model_ru.encoder_tokenizer))

model_ru.train_model(train, eval_data=dev)

This raises the following error:

/usr/local/lib/python3.7/dist-packages/transformers/tokenization_utils_base.py:3407: FutureWarning: 
`prepare_seq2seq_batch` is deprecated and will be removed in version 5 of HuggingFace Transformers. Use the regular
`__call__` method to prepare your inputs and the tokenizer under the `as_target_tokenizer` context manager to prepare
your targets.

Here is a short example:

model_inputs = tokenizer(src_texts, ...)
with tokenizer.as_target_tokenizer():
    labels = tokenizer(tgt_texts, ...)
model_inputs["labels"] = labels["input_ids"]

See the documentation of your specific tokenizer for more details on the specific arguments to the tokenizer of choice.
For a more complete example, see the implementation of `prepare_seq2seq_batch`.

  warnings.warn(formatted_warning, FutureWarning)

---------------------------------------------------------------------------

TypeError                                 Traceback (most recent call last)

/tmp/ipykernel_1538/3709317111.py in <module>
     15 model_ru.model.resize_token_embeddings(len(model_ru.encoder_tokenizer))
     16 
---> 17 model_ru.train_model(train_tydiqa_ru, eval_data=dev_tydiqa_ru)
     18 
     19 # Evaluation and training loss can also be found WandB

5 frames

/usr/local/lib/python3.7/dist-packages/simpletransformers/seq2seq/seq2seq_model.py in train_model(self, train_data, output_dir, show_running_loss, args, eval_data, verbose, **kwargs)
    433         self._move_model_to_device()
    434 
--> 435         train_dataset = self.load_and_cache_examples(train_data, verbose=verbose)
    436 
    437         os.makedirs(output_dir, exist_ok=True)

/usr/local/lib/python3.7/dist-packages/simpletransformers/seq2seq/seq2seq_model.py in load_and_cache_examples(self, data, evaluate, no_cache, verbose, silent)
   1489                 if args.model_type in ["bart", "mbart", "marian"]:
   1490                     return SimpleSummarizationDataset(
-> 1491                         encoder_tokenizer, self.args, data, mode
   1492                     )
   1493                 else:

/usr/local/lib/python3.7/dist-packages/simpletransformers/seq2seq/seq2seq_utils.py in __init__(self, tokenizer, args, data, mode)
    423             else:
    424                 self.examples = [
--> 425                     preprocess_fn(d) for d in tqdm(data, disable=args.silent)
    426                 ]
    427 

/usr/local/lib/python3.7/dist-packages/simpletransformers/seq2seq/seq2seq_utils.py in <listcomp>(.0)
    423             else:
    424                 self.examples = [
--> 425                     preprocess_fn(d) for d in tqdm(data, disable=args.silent)
    426                 ]
    427 

/usr/local/lib/python3.7/dist-packages/simpletransformers/seq2seq/seq2seq_utils.py in preprocess_data_mbart(data)
    359         decoder_input_ids,
    360         tokenizer.pad_token_id,
--> 361         tokenizer.lang_code_to_id[args.tgt_lang],
    362     )
    363 

/usr/local/lib/python3.7/dist-packages/simpletransformers/seq2seq/seq2seq_utils.py in <lambda>(input_ids, pad_token_id, decoder_start_token_id)
     30     shift_tokens_right = (
     31         lambda input_ids, pad_token_id, decoder_start_token_id: _shift_tokens_right(
---> 32             input_ids, pad_token_id
     33         )
     34     )

TypeError: shift_tokens_right() missing 1 required positional argument: 'decoder_start_token_id'
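For clarity on the last frame: the lambda in seq2seq_utils.py forwards only two arguments, while the error message says the wrapped shift_tokens_right requires three (input_ids, pad_token_id, decoder_start_token_id). A minimal sketch of that mismatch, assuming the newer three-argument signature from transformers.models.bart.modeling_bart:

import torch
from transformers.models.bart.modeling_bart import shift_tokens_right as _shift_tokens_right

def wrapped(input_ids, pad_token_id, decoder_start_token_id):
    # Same shape as the wrapper in seq2seq_utils.py: the third argument is dropped.
    return _shift_tokens_right(input_ids, pad_token_id)

dummy_ids = torch.tensor([[250021, 5, 6, 2, 1]])
wrapped(dummy_ids, pad_token_id=1, decoder_start_token_id=250021)
# TypeError: shift_tokens_right() missing 1 required positional argument: 'decoder_start_token_id'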