使用one-click-mlflow安装 MLFlow 后,我使用在用户指南中找到的默认命令保存 pytorch 模型。您可以在下面找到命令:
mlflow.pytorch.log_model(net, artifact_path="model", pickle_module=pickle)
保存的神经网络很简单,这基本上是一个以Xavier初始化和双曲正切为激活函数的两层神经网络。
class Net(T.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.hid1 = T.nn.Linear(n_features, 10)
self.hid2 = T.nn.Linear(10, 10)
self.oupt = T.nn.Linear(10, 1)
T.nn.init.xavier_uniform_(self.hid1.weight)
T.nn.init.zeros_(self.hid1.bias)
T.nn.init.xavier_uniform_(self.hid2.weight)
T.nn.init.zeros_(self.hid2.bias)
T.nn.init.xavier_uniform_(self.oupt.weight)
T.nn.init.zeros_(self.oupt.bias)
def forward(self, x):
z = T.tanh(self.hid1(x))
z = T.tanh(self.hid2(z))
z = self.oupt(z)
return z
Jupyter Notebook 中的每件事都运行良好。我可以记录指标和其他工件,但是当我保存模型时,我收到以下错误消息:
2021/10/13 09:21:00 WARNING mlflow.utils.requirements_utils: Found torch version (1.9.0+cu111) contains a local version label (+cu111). MLflow logged a pip requirement for this package as 'torch==1.9.0' without the local version label to make it installable from PyPI. To specify pip requirements containing local version labels, please use `conda_env` or `pip_requirements`.
2021/10/13 09:21:00 WARNING mlflow.utils.requirements_utils: Found torchvision version (0.10.0+cu111) contains a local version label (+cu111). MLflow logged a pip requirement for this package as 'torchvision==0.10.0' without the local version label to make it installable from PyPI. To specify pip requirements containing local version labels, please use `conda_env` or `pip_requirements`.
2021/10/13 09:21:01 ERROR mlflow.utils.environment: Encountered an unexpected error while inferring pip requirements (model URI: /tmp/tmpnl9dsoye/model/data, flavor: pytorch)
Traceback (most recent call last):
File "/home/ucsky/.virtualenv/mymodel/lib/python3.9/site-packages/mlflow/utils/environment.py", line 212, in infer_pip_requirements
return _infer_requirements(model_uri, flavor)
File "/home/ucsky/.virtualenv/mymodel/lib/python3.9/site-packages/mlflow/utils/requirements_utils.py", line 263, in _infer_requirements
modules = _capture_imported_modules(model_uri, flavor)
File "/home/ucsky/.virtualenv/mymodel/lib/python3.9/site-packages/mlflow/utils/requirements_utils.py", line 221, in _capture_imported_modules
_run_command(
File "/home/ucsky/.virtualenv/mymodel/lib/python3.9/site-packages/mlflow/utils/requirements_utils.py", line 173, in _run_command
raise MlflowException(msg)
mlflow.exceptions.MlflowException: Encountered an unexpected error while running ['/home/ucsky/.virtualenv/mymodel/bin/python', '/home/ucsky/.virtualenv/mymodel/lib/python3.9/site-packages/mlflow/utils/_capture_modules.py', '--model-path', '/tmp/tmpnl9dsoye/model/data', '--flavor', 'pytorch', '--output-file', '/tmp/tmplyj0w2fr/imported_modules.txt', '--sys-path', '["/home/ucsky/project/ofi-ds-research/incubator/ofi-pe-fr/notebook/guillaume-simon/06-modelisation-pytorch", "/home/ucsky/.virtualenv/mymodel/lib/python3.9/site-packages/git/ext/gitdb", "/usr/lib/python39.zip", "/usr/lib/python3.9", "/usr/lib/python3.9/lib-dynload", "", "/home/ucsky/.virtualenv/mymodel/lib/python3.9/site-packages", "/home/ucsky/.virtualenv/mymodel/lib/python3.9/site-packages/IPython/extensions", "/home/ucsky/.ipython", "/home/ucsky/.virtualenv/mymodel/lib/python3.9/site-packages/gitdb/ext/smmap"]']
exit status: 1
stdout:
stderr: Traceback (most recent call last):
File "/home/ucsky/.virtualenv/mymodel/lib/python3.9/site-packages/mlflow/utils/_capture_modules.py", line 125, in <module>
main()
File "/home/ucsky/.virtualenv/mymodel/lib/python3.9/site-packages/mlflow/utils/_capture_modules.py", line 118, in main
importlib.import_module(f"mlflow.{flavor}")._load_pyfunc(model_path)
File "/home/ucsky/.virtualenv/mymodel/lib/python3.9/site-packages/mlflow/pytorch/__init__.py", line 723, in _load_pyfunc
return _PyTorchWrapper(_load_model(path, **kwargs))
File "/home/ucsky/.virtualenv/mymodel/lib/python3.9/site-packages/mlflow/pytorch/__init__.py", line 626, in _load_model
return torch.load(model_path, **kwargs)
File "/home/ucsky/.virtualenv/mymodel/lib/python3.9/site-packages/torch/serialization.py", line 607, in load
return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
File "/home/ucsky/.virtualenv/mymodel/lib/python3.9/site-packages/torch/serialization.py", line 882, in _load
result = unpickler.load()
File "/home/ucsky/.virtualenv/mymodel/lib/python3.9/site-packages/torch/serialization.py", line 875, in find_class
return super().find_class(mod_name, name)
AttributeError: Can't get attribute 'Net' on <module '__main__' from '/home/ucsky/.virtualenv/mymodel/lib/python3.9/site-packages/mlflow/utils/_capture_modules.py'>
有人可以解释一下有什么问题吗?