0

我正在尝试恢复训练monkeyAI pytorch Retinanet。我已经加载了 .pt 文件而不是实际模型。更改在 Monk_Object_Detection/5_pytorch_retinanet/lib/train_detector.py 中进行,在其修改的地方检查“#change”。

def Model(self, model_name="resnet18",gpu_devices=[0]):
    '''
    User function: Set Model parameters

        Available Models
            resnet18
            resnet34
            resnet50
            resnet101
            resnet152

    Args:
        model_name (str): Select model from available models
        gpu_devices (list): List of GPU Device IDs to be used in training

    Returns:
        None
    '''

    num_classes = self.system_dict["local"]["dataset_train"].num_classes();
    if model_name == "resnet18":
        retinanet = model.resnet18(num_classes=num_classes, pretrained=True)
    elif model_name == "resnet34":
        retinanet = model.resnet34(num_classes=num_classes, pretrained=True)
    elif model_name == "resnet50":
        # retinanet = model.resnet50(num_classes=num_classes, pretrained=True)
        # change
        retinanet = torch.load('/content/drive/MyDrive/Object_detection_retinanet/trained_retinanet_40.pt')
    elif model_name == "resnet101":
        retinanet = model.resnet101(num_classes=num_classes, pretrained=True)
    elif model_name == "resnet152":
        retinanet = model.resnet152(num_classes=num_classes, pretrained=True)

    if self.system_dict["params"]["use_gpu"]:
        self.system_dict["params"]["gpu_devices"] = gpu_devices
        if len(self.system_dict["params"]["gpu_devices"])==1:
            os.environ["CUDA_VISIBLE_DEVICES"] = str(self.system_dict["params"]["gpu_devices"][0])
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(id) for id in self.system_dict["params"]["gpu_devices"]])
        self.system_dict["local"]["device"] = 'cuda' if torch.cuda.is_available() else 'cpu'

        # change - added 3 lines below
        if isinstance(retinanet,torch.nn.DataParallel):
            retinanet = retinanet.module
        retinanet.load_state_dict(torch.load('/content/drive/MyDrive/Object_detection_retinanet/trained_retinanet_40.pt'))
        
        retinanet = retinanet.to(self.system_dict["local"]["device"])
        retinanet = torch.nn.DataParallel(retinanet).to(self.system_dict["local"]["device"])

    
    retinanet.training = True
    retinanet.train()
    retinanet.module.freeze_bn()

    self.system_dict["local"]["model"] = retinanet;

当我从主函数调用 Model() 时,出现属性错误,如下所示:

from train_detector import Detector
gtf = Detector() 

#Loading the dataset
root_dir = './'
coco_dir = 'coco_dir'
img_dir = 'images'
set_dir ='train'
gtf.Train_Dataset(root_dir, coco_dir, img_dir, set_dir, batch_size=8, use_gpu=True)

gtf.Model(model_name="resnet50", gpu_devices=[0, 1, 2, 3])

错误:

AttributeError                            Traceback (most recent call last)
<ipython-input-22-1a0c8d446904> in <module>()
      3 if PRE_TRAINED:
      4   #Initialising Model
----> 5   gtf.Model(model_name="resnet50", gpu_devices=[0, 1, 2, 3])
      6   #Setting up hyperparameters
      7   gtf.Set_Hyperparams(lr=0.001, val_interval=1, print_interval=20)

2 frames
/content/Monk_Object_Detection/5_pytorch_retinanet/lib/train_detector.py in Model(self, model_name, gpu_devices)
    245             if isinstance(retinanet,torch.nn.DataParallel):
    246                 retinanet = retinanet.module
--> 247             retinanet.load_state_dict(torch.load('/content/drive/MyDrive/Object_detection_retinanet/trained_retinanet_40.pt'))
    248 
    249             retinanet = retinanet.to(self.system_dict["local"]["device"])

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1453         # copy state_dict so _load_from_state_dict can modify it
   1454         metadata = getattr(state_dict, '_metadata', None)
-> 1455         state_dict = state_dict.copy()
   1456         if metadata is not None:
   1457             # mypy isn't aware that "_metadata" exists in state_dict

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in __getattr__(self, name)
   1176                 return modules[name]
   1177         raise AttributeError("'{}' object has no attribute '{}'".format(
-> 1178             type(self).__name__, name))
   1179 
   1180     def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:

AttributeError: 'DataParallel' object has no attribute 'copy'

请帮我解决问题!

4

1 回答 1

0

我通过简单地搜索您的问题发现了这一点:

retinanet.load_state_dict(torch.load('filename').module.state_dict())

讨论的链接在这里

于 2021-12-22T22:43:45.610 回答