我正在使用 H2O DAI 1.9.0.6。我正在尝试在专家设置上加载自定义配方(使用自定义配方的 BERT 预保留模型)。我正在使用本地文件上传。但是上传没有发生。没有错误,没有进展,什么都没有。在那次活动之后,我无法在 RECIPE 选项卡下看到这个模型。
从下面的 URL 中获取示例食谱并根据我的需要进行修改。感谢创造这个食谱的人。
https://github.com/h2oai/driverlessai-recipes/blob/master/models/nlp/portuguese_bert.py
定制食谱
导入操作系统
进口舒蒂尔
从 urllib.parse 导入 urlparse
导入请求
从 h2oaicore.models 导入 TextBERTModel、CustomModel
从 h2oaicore.systemutils 导入 make_experiment_logger、temporary_files_path、atomic_move、loggerinfo
def is_url(url):
尝试:
结果 = urlparse(url)
全部返回([result.scheme,result.netloc,result.path])
除了:
返回假
def may_download_language_model(记录器,
保存目录,
模型链接,
配置链接,
词汇链接):
model_name = "pytorch_model.bin"
如果是实例(模型链接,str):
model_name = model_link.split('/')[-1]
如果 '.bin' 不在 model_name 中:
model_name = "pytorch_model.bin"
也许_下载(url = config_link,
dest=os.path.join(save_directory, "config.json"),
记录器=记录器)
也许_下载(网址=词汇链接,
dest=os.path.join(save_directory, "vocab.txt"),
记录器=记录器)
也许_下载(url=model_link,
dest=os.path.join(save_directory, model_name),
记录器=记录器)
def maybe_download(url, dest, logger=None):
如果不是 is_url(url):
loggerinfo(logger, f"{url} 不是有效的 URL。")
返回
dest_tmp = dest + ".tmp"
如果 os.path.exists(dest):
loggerinfo(logger, f"已经下载 {url} -> {dest}")
返回
如果 os.path.exists(dest_tmp):
loggerinfo(logger, f"下载已经开始 {url} -> {dest_tmp}。"
f"删除 {dest_tmp} 以再次下载文件。")
返回
loggerinfo(logger, f"正在下载 {url} -> {dest}")
url_data = requests.get(url, stream=True)
如果 url_data.status_code != requests.codes.ok:
msg = "无法获取 url %s,代码:%s,原因:%s" % (
str(url),str(url_data.status_code),str(url_data.reason))
提出 requests.exceptions.RequestException(msg)
url_data.raw.decode_content = True
如果不是 os.path.isdir(os.path.dirname(dest)):
os.makedirs(os.path.dirname(dest),exist_ok=True)
使用 open(dest_tmp, 'wb') 作为 f:
shutil.copyfileobj(url_data.raw, f)
atomic_move(dest_tmp, dest)
def check_correct_name(custom_name):
allowed_pretrained_models = ['bert','openai-gpt','gpt2','transfo-xl','xlnet','xlm-roberta',
'xlm'、'roberta'、'distilbert'、'卡门贝尔'、'ctrl'、'albert']
assert len([model_name for model_name in allowed_pretrained_models
if model_name in custom_name]), f"{custom_name} 需要包含名称" \
" 的预训练模型架构(例如 bert 或 xlnet)" \
“能够正确处理模型。”
类 CustomBertModel(TextBERTModel,CustomModel):
"""
用于使用预训练变压器模型的自定义模型类。
该类继承:
- 真正只是一个标签的 CustomModel。它可以确保 DAI 知道它是自定义模型。
- TextBERTModel 使自定义模型继承所有属性和方法。
支持的模型架构:
'bert'、'openai-gpt'、'gpt2'、'transfo-xl'、'xlnet'、'xlm-roberta'、
'xlm','罗伯塔','distilbert','卡门贝尔','ctrl','阿尔伯特'
如何使用:
- 您已经下载了权重、词汇和配置文件:
- 将 _model_path 设置为存储权重、词汇和配置文件的文件夹。
- 根据预训练架构设置_model_name(例如,bert-base-uncased)。
- 你想下载权重、词汇和配置文件:
- 相应地设置_model_link、_config_link 和_vocab_link。
- _model_path 是保存权重、词汇和配置文件的文件夹。
- 根据预训练架构设置_model_name(例如,bert-base-uncased)。
- 重要的:
_model_path 需要包含预训练模型架构的名称(例如,bert 或 xlnet)
才能正确加载模型。
- 在专家设置中禁用遗传算法。
"""
# _model_path 是保存权重、词汇和配置的目录的完整路径。
_model_name = NotImplemented # 将用于创建 MOJO
_model_path = 未实施
_model_link = 未实施
_config_link = 未实施
_vocab_link = 未实施
_booster_str = "pytorch 自定义"
# MOJO创作要求:
# _model_name 必须是其中之一
# bert-base-uncased, bert-base-multilingual-cased, xlnet-base-cased, roberta-base, distilbert-base-uncased
# vocab.txt 需要与 _model_name 中使用的 vocab.txt 相同(尚无自定义词汇表)。
_mojo = 假
@静态方法
def is_enabled():
return False # Abstract Base 模型不应该出现在模型中。
def _set_model_name(self, language_detected):
self.model_path = self.__class__._model_path
self.model_name = self.__class__._model_name
check_correct_name(self.model_path)
check_correct_name(self.model_name)
def fit(self, X, y, sample_weight=None, eval_set=None, sample_weight_eval_set=None, **kwargs):
记录器 = 无
如果 self.context 和 self.context.experiment_id:
logger = make_experiment_logger(experiment_id=self.context.experiment_id, tmp_dir=self.context.tmp_dir,
Experiment_tmp_dir=self.context.experiment_tmp_dir)
也许_下载_语言_模型(记录器,
save_directory=self.__class__._model_path,
model_link=self.__class__._model_link,
config_link=self.__class__._config_link,
vocab_link=self.__class__._vocab_link)
super().fit(X, y, sample_weight, eval_set, sample_weight_eval_set, **kwargs)
类 GermanBertModel(CustomBertModel):
_model_name = "bert-base-german-dbmdz-uncased"
_model_path = os.path.join(temporary_files_path, "german_bert_language_model/")
_model_link = "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/pytorch_model.bin"
_config_link = "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/config.json"
_vocab_link = "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt"
_mojo = 真
@静态方法
def is_enabled():
返回真