0

我正在使用 H2O DAI 1.9.0.6。我正在尝试在专家设置上加载自定义配方(使用自定义配方的 BERT 预保留模型)。我正在使用本地文件上传。但是上传没有发生。没有错误,没有进展,什么都没有。在那次活动之后,我无法在 RECIPE 选项卡下看到这个模型。

从下面的 URL 中获取示例食谱并根据我的需要进行修改。感谢创造这个食谱的人。

https://github.com/h2oai/driverlessai-recipes/blob/master/models/nlp/portuguese_bert.py

定制食谱

导入操作系统
进口舒蒂尔
从 urllib.parse 导入 urlparse

导入请求

从 h2oaicore.models 导入 TextBERTModel、CustomModel
从 h2oaicore.systemutils 导入 make_experiment_logger、temporary_files_path、atomic_move、loggerinfo


def is_url(url):
    尝试:
        结果 = urlparse(url)
        全部返回([result.scheme,result.netloc,result.path])
    除了:
        返回假


def may_download_language_model(记录器,
                                  保存目录,
                                  模型链接,
                                  配置链接,
                                  词汇链接):
    model_name = "pytorch_model.bin"
    如果是实例(模型链接,str):
        model_name = model_link.split('/')[-1]
        如果 '.bin' 不在 model_name 中:
            model_name = "pytorch_model.bin"

    也许_下载(url = config_link,
                   dest=os.path.join(save_directory, "config.json"),
                   记录器=记录器)
    也许_下载(网址=词汇链接,
                   dest=os.path.join(save_directory, "vocab.txt"),
                   记录器=记录器)
    也许_下载(url=model_link,
                   dest=os.path.join(save_directory, model_name),
                   记录器=记录器)


def maybe_download(url, dest, logger=None):
    如果不是 is_url(url):
        loggerinfo(logger, f"{url} 不是有效的 URL。")
        返回

    dest_tmp = dest + ".tmp"
    如果 os.path.exists(dest):
        loggerinfo(logger, f"已经下载 {url} -> {dest}")
        返回

    如果 os.path.exists(dest_tmp):
        loggerinfo(logger, f"下载已经开始 {url} -> {dest_tmp}。"
        f"删除 {dest_tmp} 以再次下载文件。")
        返回

    loggerinfo(logger, f"正在下载 {url} -> {dest}")
    url_data = requests.get(url, stream=True)
    如果 url_data.status_code != requests.codes.ok:
        msg = "无法获取 url %s,代码:%s,原因:%s" % (
            str(url),str(url_data.status_code),str(url_data.reason))
        提出 requests.exceptions.RequestException(msg)
    url_data.raw.decode_content = True
    如果不是 os.path.isdir(os.path.dirname(dest)):
        os.makedirs(os.path.dirname(dest),exist_ok=True)
    使用 open(dest_tmp, 'wb') 作为 f:
        shutil.copyfileobj(url_data.raw, f)

    atomic_move(dest_tmp, dest)


def check_correct_name(custom_name):
    allowed_pretrained_models = ['bert','openai-gpt','gpt2','transfo-xl','xlnet','xlm-roberta',
                                 'xlm'、'roberta'、'distilbert'、'卡门贝尔'、'ctrl'、'albert']
    assert len([model_name for model_name in allowed_pretrained_models
                if model_name in custom_name]), f"{custom_name} 需要包含名称" \
        " 的预训练模型架构(例如 bert 或 xlnet)" \
        “能够正确处理模型。”


类 CustomBertModel(TextBERTModel,CustomModel):
    """
    用于使用预训练变压器模型的自定义模型类。
    该类继承:
      - 真正只是一个标签的 CustomModel。它可以确保 DAI 知道它是自定义模型。
      - TextBERTModel 使自定义模型继承所有属性和方法。
    支持的模型架构:
    'bert'、'openai-gpt'、'gpt2'、'transfo-xl'、'xlnet'、'xlm-roberta'、
    'xlm','罗伯塔','distilbert','卡门贝尔','ctrl','阿尔伯特'
    如何使用:
        - 您已经下载了权重、词汇和配置文件:
            - 将 _model_path 设置为存储权重、词汇和配置文件的文件夹。
            - 根据预训练架构设置_model_name(例如,bert-base-uncased)。
        - 你想下载权重、词汇和配置文件:
            - 相应地设置_model_link、_config_link 和_vocab_link。
            - _model_path 是保存权重、词汇和配置文件的文件夹。
            - 根据预训练架构设置_model_name(例如,bert-base-uncased)。
        - 重要的:
          _model_path 需要包含预训练模型架构的名称(例如,bert 或 xlnet)
          才能正确加载模型。
        - 在专家设置中禁用遗传算法。
    """

    # _model_path 是保存权重、词汇和配置的目录的完整路径。
    _model_name = NotImplemented # 将用于创建 MOJO
    _model_path = 未实施

    _model_link = 未实施
    _config_link = 未实施
    _vocab_link = 未实施
    _booster_str = "pytorch 自定义"

    # MOJO创作要求:
    # _model_name 必须是其中之一
    # bert-base-uncased, bert-base-multilingual-cased, xlnet-base-cased, roberta-base, distilbert-base-uncased
    # vocab.txt 需要与 _model_name 中使用的 vocab.txt 相同(尚无自定义词汇表)。
    _mojo = 假

    @静态方法
    def is_enabled():
        return False # Abstract Base 模型不应该出现在模型中。

    def _set_model_name(self, language_detected):
        self.model_path = self.__class__._model_path
        self.model_name = self.__class__._model_name
        check_correct_name(self.model_path)
        check_correct_name(self.model_name)

    def fit(self, X, y, sample_weight=None, eval_set=None, sample_weight_eval_set=None, **kwargs):
        记录器 = 无
        如果 self.context 和 self.context.experiment_id:
            logger = make_experiment_logger(experiment_id=self.context.experiment_id, tmp_dir=self.context.tmp_dir,
                                            Experiment_tmp_dir=self.context.experiment_tmp_dir)
        也许_下载_语言_模型(记录器,
                                      save_directory=self.__class__._model_path,
                                      model_link=self.__class__._model_link,
                                      config_link=self.__class__._config_link,
                                      vocab_link=self.__class__._vocab_link)
        super().fit(X, y, sample_weight, eval_set, sample_weight_eval_set, **kwargs)


类 GermanBertModel(CustomBertModel):
    _model_name = "bert-base-german-dbmdz-uncased"
    _model_path = os.path.join(temporary_files_path, "german_bert_language_model/")

                  
    _model_link = "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/pytorch_model.bin"
    _config_link = "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/config.json"
    _vocab_link = "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt"

    _mojo = 真

    @静态方法
    def is_enabled():
        返回真
4

1 回答 1

0

检查您的自定义食谱是否已is_enabled()返回True

    def is_enabled():
        return True
于 2021-02-08T02:28:02.750 回答