我一直在玩 tensorflow (CPU) 和一些语言建模 - 到目前为止它一直很成功 - 一切都很好。
但是在看到我的旧 CPU 慢慢地从所有模型训练中被杀死之后 - 我决定是时候最终从我的 RTX 2080 中得到一些使用了。我一直在遵循华盛顿大学的指南:。很快我让 tensorflow-gpu 运行起来,在一些轻量级预测和类似的东西上运行它。
但是当我开始运行 GPT2 语言模型时,我遇到了一些小问题。我首先对数据进行标记:
from tokenizers.models import BPE
from tokenizers import Tokenizer
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
from tokenizers.normalizers import NFKC, Sequence
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.trainers import BpeTrainer
class BPE_token(object):
def __init__(self):
self.tokenizer = Tokenizer(BPE())
self.tokenizer.normalizer = Sequence([
NFKC()
])
self.tokenizer.pre_tokenizer = ByteLevel()
self.tokenizer.decoder = ByteLevelDecoder()
def bpe_train(self, paths):
trainer = BpeTrainer(vocab_size=50000, show_progress=True, inital_alphabet=ByteLevel.alphabet(), special_tokens=[
"<s>",
"<pad>",
"</s>",
"<unk>",
"<mask>"
])
self.tokenizer.train(trainer, paths)
def save_tokenizer(self, location, prefix=None):
if not os.path.exists(location):
os.makedirs(location)
self.tokenizer.model.save(location, prefix)
# ////////// TOKENIZE DATA ////////////
from pathlib import Pa th
import os# the folder 'text' contains all the files
paths = [str(x) for x in Path("./da_corpus/").glob("**/*.txt")]
tokenizer = BPE_token()# train the tokenizer model
tokenizer.bpe_train(paths)# saving the tokenized data in our specified folder
save_path = 'tokenized_data'
tokenizer.save_tokenizer(save_path)
上面的代码完美地工作并标记数据 - 就像使用 tensorflow (CPU) 一样。在对我的数据进行标记后,我开始训练我的模型 - 但在它开始之前,我得到以下 ImportError:
from transformers import GPT2Config, TFGPT2LMHeadModel, GPT2Tokenizer # loading tokenizer from the saved model path
ImportError: cannot import name 'TFGPT2LMHeadModel' from 'transformers' (unknown location)
Transformers 包似乎已正确安装在 site-packages 库中,我似乎能够使用其他变压器 - 但不是TFGPT2LMHeadModel 我已经阅读了 google 和hugging.co上的所有内容- 尝试了不同版本的 tensorflow-gpu、变压器、标记器和许多其他软件包 - 遗憾的是没有任何帮助。
套餐:
- 蟒蛇,3.7.1
- 张量流 2.1.0
- TensorFlow-GPU 2.1.0
- 基于 TensorFlow 的 2.1.0
- 张量流估计器 2.1.0
- 变形金刚 4.2.2
- 标记器 0.9.4
- cudnn 7.6.5
- cudatoolkit 10.1.243