问题标签 [state-dict]

For questions regarding programming in ECMAScript (JavaScript/JS) and its various dialects/implementations (excluding ActionScript). Note JavaScript is NOT the same as Java! Please include all relevant tags on your question; e.g., [node.js], [jquery], [json], [reactjs], [angular], [ember.js], [vue.js], [typescript], [svelte], etc.

0 投票
1 回答
3851 浏览

python - RuntimeError:为生成器加载 state_dict 时出错:state_dict 中缺少键

gen.state_dict()我试图使用 MNIST 数据集训练 DCGAN 模型,但在完成训练后无法加载。

我无法在此步骤中加载 gen state_dict:

这是错误:

0 投票
0 回答
49 浏览

pytorch - 如何加载多任务模型但只预测其中一项任务?

这是模型结构。为了方便超参数调整,我将几乎所有内容组合在一起。''' 类 MultiTaskDNN(nn.Module):

'''

如果结构看起来正确,也请告诉我。在训练这个具有 10 个任务(头)的多任务模型之后。我只想预测任务 7,即头 7。我应该如何加载模型并进行预测?谢谢你。

0 投票
0 回答
101 浏览

pytorch - 如何找回只有权重字典的 pytorch 模型的架构?

  • 我想使用 多语言代码搜索模型,但首先代码不起作用并输出以下错误,表明它不能仅加载权重:
  • 然后我下载了 pytorch bin 文件,但它只包含权重字典(这里提到的状态字典),这意味着如果我想使用模型,我必须初始化好的架构,然后加载权重。

但是我应该如何找到适合如此复杂模型重量的架构?我看到一些方法可以根据权重字典找到模型,但我没有设法让它们工作(我想在这里输入链接描述)。

如何找回权重字典的架构以使模型工作?甚至可能吗?

0 投票
1 回答
218 浏览

python - 使用文件名中的日期时间保存 Pytorch 模型 state_dict


torch.save(agent.qnetwork_local.state_dict(), filename)我一直在尝试用where保存 Pytorch 模型的 state_dict
filename = datetime.now().strftime('%d-%m-%y-%H:%M_dqnweights.pth')

type(filename)返回str不应该有问题的torch.save(),它应该输出一个非空文件。相反,我得到一个只有日期和时间的空文件,之后什么都没有。将日期和文件名放在中间会导致一个空文件,其中包含截止日期和时间之后的所有内容。

torch.save(agent.qnetwork_local.state_dict(), 'checkpoint1.pth')并且任何时候我硬编码字符串都可以工作并给我预期的非空文件。

发生了什么事,我该如何解决?

我在 Windows 10 上使用 Pytorch v1.8.1+cpu 在 Python v3.6.8 virtualenv 中运行此代码。

0 投票
1 回答
665 浏览

pytorch - RuntimeError:为 DataParallel 加载 state_dict 时出错:state_dict 中出现意外键:“module.scibert_layer.embeddings.position_ids”

尝试加载已保存的模型检查点(.pth文件)时出现以下错误。

nn.DataParallel我在(火炬版本)中训练了我的序列标记模型,1.7.0但我试图在没有nn.DataParallel(火炬版本1.9.0)的情况下加载它。目前,我知道不使用nn.DataParallel导致的问题RuntimeError: Error(s) in loading state_dict for DataParallel:,但也可能是因为我使用了不同版本的火炬或训练和加载模型检查点?

该模型nn.DataParallel使用以下代码块进行包装。

这是我的模型。

我应该如何继续正确包装检查点,nn.DataParallel或者我应该使用可以解决此问题的正确版本的火炬?

我将不胜感激任何帮助或提示。

0 投票
1 回答
32 浏览

pytorch - 关于在函数中保存 state_dict/checkpoint (PyTorch)

我正在尝试实现以下函数来保存 model_state 检查点:

以前我只是通过运行一个 for 循环来做同样的事情:

有没有办法在checkpoint={}每个循环中启动和更新它?或者checkpoint={}在每个时期都很好,因为模型本身持有state_dict(). 只是我每次都覆盖检查点。

0 投票
0 回答
307 浏览

deep-learning - 类型错误:load_state_dict() 缺少 1 个必需的位置参数:'state_dict'

加载自定义 model.pt 以进行推理时出现错误。错误是 TypeError:load_state_dict() 缺少 1 个必需的位置参数:'state_dict'。

model = get_model(model_path, model_type='UNet',problem_type='parts')

这是 Unet 模型

从火炬导入火炬 从火炬视觉.models.vgg 导入nn 导入vgg16_bn

0 投票
0 回答
53 浏览

pytorch - model.load_state_dict 没有给出相同的评估结果

我正在做一个项目,我最初训练模型并存储最佳模型。为了评估经过训练的模型,我尝试使用 model.load_state_dict(torch.load(path)),其中路径指定检查点。当我在 gpu 机器上运行该模型时,该模型没有给出一致的结果,但代码在 cpu 的本地机器上运行良好。以前有人遇到过类似的问题吗?

0 投票
0 回答
42 浏览

pytorch - 当来自 ckpt 和模型的 state_dict 的键相同时,使用 load_state_dict() 时出现意外键

我正在尝试加载一个略有不同的预训练 resnet56 模型。pretrainde 模型在构建时是普通的 resnet 模型,而我要阅读的模型将整个阶段分为两部分,其中第一部分是 Sequential,其余部分放入列表中,我将它们命名为 normal resnet 做: 建立网络

现在我得到了这个意外的关键错误: 错误消息

RuntimeError:为 ResNetc 加载 state_dict 时出错:state_dict 中出现意外键:“layer2.5.conv1.weight”、“layer2.5.bn1.weight”、“layer2.5.bn1.bias”、 “layer2.5.bn1.running_mean”、“layer2.5.bn1.running_var”、“layer2.5.bn1.num_batches_tracked”、“layer2.5.conv2.weight”、“layer2.5.bn2.weight”、 “layer2.5.bn2.bias”、“layer2.5.bn2.running_mean”、“layer2.5.bn2.running_var”、“layer2.5.bn2.num_batches_tracked”、“layer2.6.conv1.weight”、 “layer2.6.bn1.weight”、“layer2.6.bn1.bias”、“layer2.6.bn1.running_mean”、“layer2.6.bn1.running_var”、“layer2.6.bn1.num_batches_tracked”、 “layer2.6.conv2.weight”、“layer2.6.bn2.weight”、“layer2.6.bn2.bias”、“layer2.6.bn2.running_mean”、“layer2.6.bn2.running_var”、“layer2.6.bn2.num_batches_tracked”、“layer2.7.conv1.weight”、“layer2.7.bn1.weight”、 “layer2.7.bn1.bias”、“layer2.7.bn1.running_mean”、“layer2.7.bn1.running_var”、“layer2.7.bn1.num_batches_tracked”、“layer2.7.conv2.weight”、 “layer2.7.bn2.weight”、“layer2.7.bn2.bias”、“layer2.7.bn2.running_mean”、“layer2.7.bn2.running_var”、“layer2.7.bn2.num_batches_tracked”、 “layer2.8.conv1.weight”、“layer2.8.bn1.weight”、“layer2.8.bn1.bias”、“layer2.8.bn1.running_mean”、“layer2.8.bn1.running_var”、 “layer2.8.bn1.num_batches_tracked”,“layer2.8.conv2.weight”,“layer2.8.bn2.weight”,“layer2.8.bn2.bias”、“layer2.8.bn2.running_mean”、“layer2.8.bn2.running_var”、“layer2.8.bn2.num_batches_tracked”、“layer3.5.conv1.weight”、 “layer3.5.bn1.weight”、“layer3.5.bn1.bias”、“layer3.5.bn1.running_mean”、“layer3.5.bn1.running_var”、“layer3.5.bn1.num_batches_tracked”、 “layer3.5.conv2.weight”、“layer3.5.bn2.weight”、“layer3.5.bn2.bias”、“layer3.5.bn2.running_mean”、“layer3.5.bn2.running_var”、 “layer3.5.bn2.num_batches_tracked”、“layer3.6.conv1.weight”、“layer3.6.bn1.weight”、“layer3.6.bn1.bias”、“layer3.6.bn1.running_mean”、 “layer3.6.bn1.running_var”、“layer3.6.bn1.num_batches_tracked”、“layer3.6.conv2.weight”、“layer3.6.bn2.weight”、“layer3.6.bn2.bias”、“layer3.6.bn2.running_mean”、“layer3.6.bn2.running_var”、“layer3.6.bn2.num_batches_tracked”、 “layer3.7.conv1.weight”、“layer3.7.bn1.weight”、“layer3.7.bn1.bias”、“layer3.7.bn1.running_mean”、“layer3.7.bn1.running_var”、 “layer3.7.bn1.num_batches_tracked”、“layer3.7.conv2.weight”、“layer3.7.bn2.weight”、“layer3.7.bn2.bias”、“layer3.7.bn2.running_mean”、 “layer3.7.bn2.running_var”、“layer3.7.bn2.num_batches_tracked”、“layer3.8.conv1.weight”、“layer3.8.bn1.weight”、“layer3.8.bn1.bias”、 “layer3.8.bn1.running_mean”、“layer3.8.bn1.running_var”、“layer3.8.bn1.num_batches_tracked”、“layer3.8.conv2.weight”、“layer3.8.bn2.weight”、“layer3.8.bn2.bias”、“layer3.8.bn2.running_mean”、“layer3.8.bn2.running_var”、“ layer3.8.bn2.num_batches_tracked”。

我还尝试使用以下代码确保两个 state_dict 的键相同:

然后我得到了结果:

不在模型中时从 state_dict 打印键----------------------------

state_dict 中键的长度:344

模型中键的长度:344