I am trying to prune my model in PyTorch with torch.nn.utils.prune, which provides 2 tensors:
- one is the original weight, and
- the other is a mask containing 0s and 1s that helps us turn off certain connections in the network.
I have tried both of these solutions, but neither improves the inference speed:
- Use the pruned network to infer, which first turns off some connections with the mask and then runs inference.
- Zero out the original weights with the mask, then remove the mask from the state_dict before running inference.
Is there a way to improve the speed using the model tensor and the mask? Wouldn't multiplying a non-zero float by 0 be faster than multiplying 2 floats?
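To isolate this from my model, here is a minimal standalone sketch (my own synthetic sizes, assuming a CUDA GPU) reproducing the same observation on a single dense matmul: zeroing out about 80% of the weights does not change the runtime, because the dense kernel still performs every multiply-add.

import time
import torch

# Standalone check: a dense GEMM with mostly-zero weights vs. fully dense weights.
torch.manual_seed(0)
x = torch.randn(4096, 4096, device='cuda')
w_dense = torch.randn(4096, 4096, device='cuda')
w_zeroed = w_dense * (torch.rand_like(w_dense) > 0.8).float()  # ~80% zeros

def bench(weight, iters=50):
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        _ = x @ weight
    torch.cuda.synchronize()
    return time.perf_counter() - start

print('fully dense weights:', bench(w_dense))
print('80% zeroed weights :', bench(w_zeroed))  # roughly the same time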
Here is my pruning function and my procedure for measuring the pruned speed:
import copy

def prune_net(net):
    """Prune the 20% of the net's weights whose abs(value) is closest to 0.

    Called each time a pruning iteration is reached.
    Args:
        net (nn.Module): the network to prune
    Return:
        newnet (nn.Module): a new net containing the masks that prune the network's weights
    """
    if not isinstance(net, nn.Module):
        print('Invalid input. Must be nn.Module')
        return
    newnet = copy.copy(net)  # note: shallow copy, so the submodules are still shared with net
    modules_list = []

    # Collect the weight and bias of every Conv2d and Linear layer so they are pruned jointly
    for name, module in newnet.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            modules_list += [(module, 'weight'), (module, 'bias')]
        if isinstance(module, torch.nn.Linear):
            modules_list += [(module, 'weight'), (module, 'bias')]

    # Globally remove the 20% of listed parameters with the smallest L1 magnitude
    prune.global_unstructured(
        modules_list,
        pruning_method=prune.L1Unstructured,
        amount=0.2,
    )
    return newnet
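For reference, this is roughly what the pruning reparameterization leaves on each module (a small standalone sketch with a toy Linear layer): the original weight becomes weight_orig, the 0/1 mask is stored as the buffer weight_mask, and weight itself is recomputed as their product by a forward pre-hook on every forward pass, so the multiplication stays dense.

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Toy example showing the tensors that torch.nn.utils.prune attaches to a module.
layer = nn.Linear(8, 4)
prune.l1_unstructured(layer, name='weight', amount=0.2)

print(sorted(n for n, _ in layer.named_parameters()))  # ['bias', 'weight_orig']
print(sorted(n for n, _ in layer.named_buffers()))     # ['weight_mask']

# 'weight' is now an ordinary dense tensor recomputed as weight_orig * weight_mask
print(torch.equal(layer.weight, layer.weight_orig * layer.weight_mask))  # True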
Test inference speed, first case:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import time
from torch.autograd import Variable

torch.set_default_tensor_type('torch.cuda.FloatTensor')
old_net = init_your_net()

new_net = prune_net(old_net)
new_net = prune_net(new_net)

old_net.eval()
new_net.eval()

old_net = old_net.cuda()
new_net = new_net.cuda()
dataset = load_your_dataset()

time_old = 0.0
time_new = 0.0
for i in range(100):
    x = dataset[i]
    x = x.cuda()
    y = x.cuda()

    # new infer
    start_time = time.perf_counter()
    detections = new_net(x).data
    time_new += time.perf_counter() - start_time

    # old infer
    start_time = time.perf_counter()
    detections = old_net(y).data
    time_old += time.perf_counter() - start_time

print('old ', time_old)
print('new ', time_new)
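One caveat about these timings (my addition, not part of the original script): CUDA kernels are launched asynchronously, so time.perf_counter() around a forward pass mostly measures launch overhead unless the GPU is synchronized first. A sketch of a more reliable measurement, using CUDA events and the same new_net and x as above:

# Sketch: timing one forward pass with CUDA events so the GPU work is
# actually included in the measurement.
start_evt = torch.cuda.Event(enable_timing=True)
end_evt = torch.cuda.Event(enable_timing=True)

start_evt.record()
detections = new_net(x)
end_evt.record()
torch.cuda.synchronize()  # wait for the recorded events to complete
print('elapsed ms:', start_evt.elapsed_time(end_evt))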
Test inference speed, second case:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import time
from torch.autograd import Variable

torch.set_default_tensor_type('torch.cuda.FloatTensor')
old_net = init_your_net()

new_net = prune_net(old_net)
new_net = prune_net(new_net)

# Apply the masks to the model tensors and remove the masks from the state_dict
for name, module in new_net.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        prune.remove(module, 'weight')
        prune.remove(module, 'bias')
    if isinstance(module, torch.nn.Linear):
        prune.remove(module, 'weight')
        prune.remove(module, 'bias')

old_net.eval()
new_net.eval()

old_net = old_net.cuda()
new_net = new_net.cuda()
dataset = load_your_dataset()

time_old = 0.0
time_new = 0.0
for i in range(100):
    x = dataset[i]
    x = x.cuda()
    y = x.cuda()

    # new infer
    start_time = time.perf_counter()
    detections = new_net(x).data
    time_new += time.perf_counter() - start_time

    # old infer
    start_time = time.perf_counter()
    detections = old_net(y).data
    time_old += time.perf_counter() - start_time

print('old ', time_old)
print('new ', time_new)
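A quick check (my own sketch, reusing new_net from above) shows why the second case is no faster either: after prune.remove the weights are ordinary dense tensors that simply contain explicit zeros, so each layer still performs the same dense computation.

# Inspect the pruned weights after prune.remove: same shape, same dense storage,
# just with a fraction of the entries set to exactly zero.
for name, module in new_net.named_modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        w = module.weight
        sparsity = (w == 0).float().mean().item()
        print(f'{name}: shape={tuple(w.shape)}, zeros={sparsity:.1%}')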
Update
I found that torch has a sparse module, which can reduce memory usage if we prune enough parameters, but it does not yet support nn.Module, only Tensor objects. Here are some useful links:
https://github.com/pytorch/pytorch/issues/36214#issuecomment-619586452
https://pytorch.org/docs/stable/sparse.html
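A rough sketch of the direction those links point in (my own standalone example, synthetic sizes): a heavily pruned weight tensor can be stored and multiplied as a torch.sparse tensor, which only keeps the nonzero values and their indices. As noted above, this works on Tensor objects, not on nn.Module layers.

import torch

# Heavily pruned weight stored in sparse COO format.
w = torch.randn(1024, 1024)
w[torch.rand_like(w) < 0.95] = 0.0       # ~95% zeros, like a heavily pruned layer
w_sparse = w.to_sparse().coalesce()      # keeps only indices + nonzero values

x = torch.randn(1024, 64)
y_dense = w @ x
y_sparse = torch.sparse.mm(w_sparse, x)  # same product computed from the sparse form
print(torch.allclose(y_dense, y_sparse, atol=1e-4))

print('dense floats stored :', w.numel())
print('sparse values stored:', w_sparse.values().numel(), '(+ 2 int64 indices per value)')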