我正在使用 H2O DAI 1.9.0.6。我正在尝试在专家设置上加载自定义配方(使用自定义配方的 BERT 预保留模型)。我正在使用本地文件上传。但是上传没有发生。没有错误,没有进展,什么都没有。在那次活动之后,我无法在 RECIPE 选项卡下看到这个模型。
从下面的 URL 中获取示例食谱并根据我的需要进行修改。感谢创造这个食谱的人。
https://github.com/h2oai/driverlessai-recipes/blob/master/models/nlp/portuguese_bert.py
定制食谱
导入操作系统 进口舒蒂尔 从 urllib.parse 导入 urlparse 导入请求 从 h2oaicore.models 导入 TextBERTModel、CustomModel 从 h2oaicore.systemutils 导入 make_experiment_logger、temporary_files_path、atomic_move、loggerinfo def is_url(url): 尝试: 结果 = urlparse(url) 全部返回([result.scheme,result.netloc,result.path]) 除了: 返回假 def may_download_language_model(记录器, 保存目录, 模型链接, 配置链接, 词汇链接): model_name = "pytorch_model.bin" 如果是实例(模型链接,str): model_name = model_link.split('/')[-1] 如果 '.bin' 不在 model_name 中: model_name = "pytorch_model.bin" 也许_下载(url = config_link, dest=os.path.join(save_directory, "config.json"), 记录器=记录器) 也许_下载(网址=词汇链接, dest=os.path.join(save_directory, "vocab.txt"), 记录器=记录器) 也许_下载(url=model_link, dest=os.path.join(save_directory, model_name), 记录器=记录器) def maybe_download(url, dest, logger=None): 如果不是 is_url(url): loggerinfo(logger, f"{url} 不是有效的 URL。") 返回 dest_tmp = dest + ".tmp" 如果 os.path.exists(dest): loggerinfo(logger, f"已经下载 {url} -> {dest}") 返回 如果 os.path.exists(dest_tmp): loggerinfo(logger, f"下载已经开始 {url} -> {dest_tmp}。" f"删除 {dest_tmp} 以再次下载文件。") 返回 loggerinfo(logger, f"正在下载 {url} -> {dest}") url_data = requests.get(url, stream=True) 如果 url_data.status_code != requests.codes.ok: msg = "无法获取 url %s,代码:%s,原因:%s" % ( str(url),str(url_data.status_code),str(url_data.reason)) 提出 requests.exceptions.RequestException(msg) url_data.raw.decode_content = True 如果不是 os.path.isdir(os.path.dirname(dest)): os.makedirs(os.path.dirname(dest),exist_ok=True) 使用 open(dest_tmp, 'wb') 作为 f: shutil.copyfileobj(url_data.raw, f) atomic_move(dest_tmp, dest) def check_correct_name(custom_name): allowed_pretrained_models = ['bert','openai-gpt','gpt2','transfo-xl','xlnet','xlm-roberta', 'xlm'、'roberta'、'distilbert'、'卡门贝尔'、'ctrl'、'albert'] assert len([model_name for model_name in allowed_pretrained_models if model_name in custom_name]), f"{custom_name} 需要包含名称" \ " 的预训练模型架构(例如 bert 或 xlnet)" \ “能够正确处理模型。” 类 CustomBertModel(TextBERTModel,CustomModel): """ 用于使用预训练变压器模型的自定义模型类。 该类继承: - 真正只是一个标签的 CustomModel。它可以确保 DAI 知道它是自定义模型。 - TextBERTModel 使自定义模型继承所有属性和方法。 支持的模型架构: 'bert'、'openai-gpt'、'gpt2'、'transfo-xl'、'xlnet'、'xlm-roberta'、 'xlm','罗伯塔','distilbert','卡门贝尔','ctrl','阿尔伯特' 如何使用: - 您已经下载了权重、词汇和配置文件: - 将 _model_path 设置为存储权重、词汇和配置文件的文件夹。 - 根据预训练架构设置_model_name(例如,bert-base-uncased)。 - 你想下载权重、词汇和配置文件: - 相应地设置_model_link、_config_link 和_vocab_link。 - _model_path 是保存权重、词汇和配置文件的文件夹。 - 根据预训练架构设置_model_name(例如,bert-base-uncased)。 - 重要的: _model_path 需要包含预训练模型架构的名称(例如,bert 或 xlnet) 才能正确加载模型。 - 在专家设置中禁用遗传算法。 """ # _model_path 是保存权重、词汇和配置的目录的完整路径。 _model_name = NotImplemented # 将用于创建 MOJO _model_path = 未实施 _model_link = 未实施 _config_link = 未实施 _vocab_link = 未实施 _booster_str = "pytorch 自定义" # MOJO创作要求: # _model_name 必须是其中之一 # bert-base-uncased, bert-base-multilingual-cased, xlnet-base-cased, roberta-base, distilbert-base-uncased # vocab.txt 需要与 _model_name 中使用的 vocab.txt 相同(尚无自定义词汇表)。 _mojo = 假 @静态方法 def is_enabled(): return False # Abstract Base 模型不应该出现在模型中。 def _set_model_name(self, language_detected): self.model_path = self.__class__._model_path self.model_name = self.__class__._model_name check_correct_name(self.model_path) check_correct_name(self.model_name) def fit(self, X, y, sample_weight=None, eval_set=None, sample_weight_eval_set=None, **kwargs): 记录器 = 无 如果 self.context 和 self.context.experiment_id: logger = make_experiment_logger(experiment_id=self.context.experiment_id, tmp_dir=self.context.tmp_dir, Experiment_tmp_dir=self.context.experiment_tmp_dir) 也许_下载_语言_模型(记录器, save_directory=self.__class__._model_path, model_link=self.__class__._model_link, config_link=self.__class__._config_link, vocab_link=self.__class__._vocab_link) super().fit(X, y, sample_weight, eval_set, sample_weight_eval_set, **kwargs) 类 GermanBertModel(CustomBertModel): _model_name = "bert-base-german-dbmdz-uncased" _model_path = os.path.join(temporary_files_path, "german_bert_language_model/") _model_link = "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/pytorch_model.bin" _config_link = "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/config.json" _vocab_link = "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt" _mojo = 真 @静态方法 def is_enabled(): 返回真