问题标签 [state-dict]
For questions regarding programming in ECMAScript (JavaScript/JS) and its various dialects/implementations (excluding ActionScript). Note JavaScript is NOT the same as Java! Please include all relevant tags on your question; e.g., [node.js], [jquery], [json], [reactjs], [angular], [ember.js], [vue.js], [typescript], [svelte], etc.
python - RuntimeError:为生成器加载 state_dict 时出错:state_dict 中缺少键
gen.state_dict()
我试图使用 MNIST 数据集训练 DCGAN 模型,但在完成训练后无法加载。
我无法在此步骤中加载 gen state_dict:
这是错误:
pytorch - 如何加载多任务模型但只预测其中一项任务?
这是模型结构。为了方便超参数调整,我将几乎所有内容组合在一起。''' 类 MultiTaskDNN(nn.Module):
'''
如果结构看起来正确,也请告诉我。在训练这个具有 10 个任务(头)的多任务模型之后。我只想预测任务 7,即头 7。我应该如何加载模型并进行预测?谢谢你。
python - 使用文件名中的日期时间保存 Pytorch 模型 state_dict
torch.save(agent.qnetwork_local.state_dict(), filename)
我一直在尝试用where保存 Pytorch 模型的 state_dict
filename = datetime.now().strftime('%d-%m-%y-%H:%M_dqnweights.pth')
type(filename)
返回str
不应该有问题的torch.save()
,它应该输出一个非空文件。相反,我得到一个只有日期和时间的空文件,之后什么都没有。将日期和文件名放在中间会导致一个空文件,其中包含截止日期和时间之后的所有内容。
torch.save(agent.qnetwork_local.state_dict(), 'checkpoint1.pth')
并且任何时候我硬编码字符串都可以工作并给我预期的非空文件。
发生了什么事,我该如何解决?
我在 Windows 10 上使用 Pytorch v1.8.1+cpu 在 Python v3.6.8 virtualenv 中运行此代码。
pytorch - RuntimeError:为 DataParallel 加载 state_dict 时出错:state_dict 中出现意外键:“module.scibert_layer.embeddings.position_ids”
尝试加载已保存的模型检查点(.pth
文件)时出现以下错误。
nn.DataParallel
我在(火炬版本)中训练了我的序列标记模型,1.7.0
但我试图在没有nn.DataParallel
(火炬版本1.9.0
)的情况下加载它。目前,我知道不使用nn.DataParallel
导致的问题RuntimeError: Error(s) in loading state_dict for DataParallel:
,但也可能是因为我使用了不同版本的火炬或训练和加载模型检查点?
该模型nn.DataParallel
使用以下代码块进行包装。
这是我的模型。
我应该如何继续正确包装检查点,nn.DataParallel
或者我应该使用可以解决此问题的正确版本的火炬?
我将不胜感激任何帮助或提示。
pytorch - 关于在函数中保存 state_dict/checkpoint (PyTorch)
我正在尝试实现以下函数来保存 model_state 检查点:
以前我只是通过运行一个 for 循环来做同样的事情:
有没有办法在checkpoint={}
每个循环中启动和更新它?或者checkpoint={}
在每个时期都很好,因为模型本身持有state_dict()
. 只是我每次都覆盖检查点。
deep-learning - 类型错误:load_state_dict() 缺少 1 个必需的位置参数:'state_dict'
加载自定义 model.pt 以进行推理时出现错误。错误是 TypeError:load_state_dict() 缺少 1 个必需的位置参数:'state_dict'。
model = get_model(model_path, model_type='UNet',problem_type='parts')
这是 Unet 模型
从火炬导入火炬 从火炬视觉.models.vgg 导入nn 导入vgg16_bn
pytorch - model.load_state_dict 没有给出相同的评估结果
我正在做一个项目,我最初训练模型并存储最佳模型。为了评估经过训练的模型,我尝试使用 model.load_state_dict(torch.load(path)),其中路径指定检查点。当我在 gpu 机器上运行该模型时,该模型没有给出一致的结果,但代码在 cpu 的本地机器上运行良好。以前有人遇到过类似的问题吗?
pytorch - 当来自 ckpt 和模型的 state_dict 的键相同时,使用 load_state_dict() 时出现意外键
我正在尝试加载一个略有不同的预训练 resnet56 模型。pretrainde 模型在构建时是普通的 resnet 模型,而我要阅读的模型将整个阶段分为两部分,其中第一部分是 Sequential,其余部分放入列表中,我将它们命名为 normal resnet 做: 建立网络
现在我得到了这个意外的关键错误: 错误消息
RuntimeError:为 ResNetc 加载 state_dict 时出错:state_dict 中出现意外键:“layer2.5.conv1.weight”、“layer2.5.bn1.weight”、“layer2.5.bn1.bias”、 “layer2.5.bn1.running_mean”、“layer2.5.bn1.running_var”、“layer2.5.bn1.num_batches_tracked”、“layer2.5.conv2.weight”、“layer2.5.bn2.weight”、 “layer2.5.bn2.bias”、“layer2.5.bn2.running_mean”、“layer2.5.bn2.running_var”、“layer2.5.bn2.num_batches_tracked”、“layer2.6.conv1.weight”、 “layer2.6.bn1.weight”、“layer2.6.bn1.bias”、“layer2.6.bn1.running_mean”、“layer2.6.bn1.running_var”、“layer2.6.bn1.num_batches_tracked”、 “layer2.6.conv2.weight”、“layer2.6.bn2.weight”、“layer2.6.bn2.bias”、“layer2.6.bn2.running_mean”、“layer2.6.bn2.running_var”、“layer2.6.bn2.num_batches_tracked”、“layer2.7.conv1.weight”、“layer2.7.bn1.weight”、 “layer2.7.bn1.bias”、“layer2.7.bn1.running_mean”、“layer2.7.bn1.running_var”、“layer2.7.bn1.num_batches_tracked”、“layer2.7.conv2.weight”、 “layer2.7.bn2.weight”、“layer2.7.bn2.bias”、“layer2.7.bn2.running_mean”、“layer2.7.bn2.running_var”、“layer2.7.bn2.num_batches_tracked”、 “layer2.8.conv1.weight”、“layer2.8.bn1.weight”、“layer2.8.bn1.bias”、“layer2.8.bn1.running_mean”、“layer2.8.bn1.running_var”、 “layer2.8.bn1.num_batches_tracked”,“layer2.8.conv2.weight”,“layer2.8.bn2.weight”,“layer2.8.bn2.bias”、“layer2.8.bn2.running_mean”、“layer2.8.bn2.running_var”、“layer2.8.bn2.num_batches_tracked”、“layer3.5.conv1.weight”、 “layer3.5.bn1.weight”、“layer3.5.bn1.bias”、“layer3.5.bn1.running_mean”、“layer3.5.bn1.running_var”、“layer3.5.bn1.num_batches_tracked”、 “layer3.5.conv2.weight”、“layer3.5.bn2.weight”、“layer3.5.bn2.bias”、“layer3.5.bn2.running_mean”、“layer3.5.bn2.running_var”、 “layer3.5.bn2.num_batches_tracked”、“layer3.6.conv1.weight”、“layer3.6.bn1.weight”、“layer3.6.bn1.bias”、“layer3.6.bn1.running_mean”、 “layer3.6.bn1.running_var”、“layer3.6.bn1.num_batches_tracked”、“layer3.6.conv2.weight”、“layer3.6.bn2.weight”、“layer3.6.bn2.bias”、“layer3.6.bn2.running_mean”、“layer3.6.bn2.running_var”、“layer3.6.bn2.num_batches_tracked”、 “layer3.7.conv1.weight”、“layer3.7.bn1.weight”、“layer3.7.bn1.bias”、“layer3.7.bn1.running_mean”、“layer3.7.bn1.running_var”、 “layer3.7.bn1.num_batches_tracked”、“layer3.7.conv2.weight”、“layer3.7.bn2.weight”、“layer3.7.bn2.bias”、“layer3.7.bn2.running_mean”、 “layer3.7.bn2.running_var”、“layer3.7.bn2.num_batches_tracked”、“layer3.8.conv1.weight”、“layer3.8.bn1.weight”、“layer3.8.bn1.bias”、 “layer3.8.bn1.running_mean”、“layer3.8.bn1.running_var”、“layer3.8.bn1.num_batches_tracked”、“layer3.8.conv2.weight”、“layer3.8.bn2.weight”、“layer3.8.bn2.bias”、“layer3.8.bn2.running_mean”、“layer3.8.bn2.running_var”、“ layer3.8.bn2.num_batches_tracked”。
我还尝试使用以下代码确保两个 state_dict 的键相同:
然后我得到了结果:
不在模型中时从 state_dict 打印键----------------------------
state_dict 中键的长度:344
模型中键的长度:344