目录
在C++环境中加载一个TORCHSCRIP
Step1: 将pytorch模型转为torch scrip类型的模型
1.1、基于Tracing的方法来转换为Torch Script
1.2、基于Annotating(Script)的方法来转换为Torch Script
Step2: 序列化torch.jit.ScriptModule类型的对象,并保存为文件
Step3: 在libtorch中加载ScriptModule模型
总结
在C++环境中加载一个TORCHSCRIP
一般地,类似python的脚本语言可用于算法快速实现、验证;但在产品化过程中,一般采用效率更高的C++语言,下面的工作就是将模型从python环境中移植到c++环境。
Step1: 将pytorch模型转为torch scrip类型的模型
通过TorchSript,我们可将pytorch模型从python转为c++。那么,什么是TorchScript呢?其实,它也是Pytorch模型的一种,这种模型能够被TorchScript的编译器识别读取、序列化。一般地,在处理模型过程中,我们都会先将模型转为torch script格式,例如:”.pt” -> “yolov5x.torchscript.pt”。转为torchscript格式有两种方法:一是函数torch.jit.trace;二是函数torch.jit.script。
torch.jit.trace原理:基于跟踪机制,需要输入一张图(0矩阵、张量亦可),模型会对输入的tensor进行处理,并记录所有张量的操作,torch::jit::trace能够捕获模型的结构、参数并保存。由于跟踪仅记录张量上的操作,因此它不会记录任何控制流操作,如if语句或循环。
torch.jit.script原理:需要开发者先定义好神经网络模型结构,即:提前写好classMyModule(torch.nn.Module),这样TorchScript可以根据定义好的MyModule来解析网络结构。
1.1、基于Tracing的方法来转换为Torch Script
如下代码,给torch.jit.trace 函数输入一个指定size的随机张量、ResNet18的网络模型,得到一个类型为 torch.jit.ScriptModule 的对象,即:traced_script_module
import torch
import torchvision
# An instance of your model.
model = torchvision.models.resnet18()
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)
经过上述处理,traced_script_module变量已经包含网络的结构和参数,可以直接用于推理,如下代码:
1 In[1]: output = traced_script_module(torch.ones(1, 3, 224, 224))
2 In[2]: output[0, :5]
3 Out[2]: tensor([-0.2698, -0.0381, 0.4023, -0.3010, -0.0448], grad_fn=)
1.2、基于Annotating(Script)的方法来转换为Torch Script
如果你的模型中有类似于控制流操作(例如:if or for循环),基于上述tracing的方式不再适用,这种方式会排上用场,下面以vanilla模型为例子,注:下面网络结构中有个if判断。
# 定义一个vanilla模型
import torch
class MyModule(torch.nn.Module):
def __init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
def forward(self, input):
if input.sum() > 0:
output = self.weight.mv(input)
else:
output = self.weight + input
return output
这里调用 torch.jit.script 来获取torch.jit.ScriptModule 类型的对象,即:sm
class MyModule(torch.nn.Module):
def __init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
def forward(self, input):
if input.sum() > 0:
output = self.weight.mv(input)
else:
output = self.weight + input
return output
my_module = MyModule(10,20)
sm = torch.jit.script(my_module)
Step2: 序列化torch.jit.ScriptModule类型的对象,并保存为文件
注:上述的tacing和script方法都将得到一个类型为torch.jit.ScriptModule的对象(这里简单记为:ScriptModule
),该对象就是常规的前向传播模块。不管是哪一种方法,此时,只需要将ScriptModule进行序列化保存就行。这里保存的是上述基于Tracing得到的ResNet推理模块traced_script_module。
traced_script_module.save("traced_resnet_model.pt") # 序列化,保存
# 保存后可用工具:https://netron.app/ 进行可视化
同理,如下是保存基于Annotating得到推理模块my_module后续,在libtorch中加载上述保存的模型文件就行,不再依赖任何python包。
my_module.save("my_module_model.pt") # 为什么不是sm
Step3: 在libtorch中加载ScriptModule模型
如何配置libtorh?,我这里仅贴下vs环境下的属性表:
1 include:
2 D:ThirdPartylibtorch-win-shared-with-deps-1.7.1+cu110libtorchinclude
4 D:ThirdPartylibtorch-win-shared-with-deps-1.7.1+cu110libtorchincludetorchcsrcapiinclude
7 lib:
8 D:ThirdPartylibtorch-win-shared-with-deps-1.7.1+cu110libtorchlib
9
11 链接器:
12 c10.lib
13 c10_cuda.lib
14 torch.lib
15 torch_cpu.lib
16 torch_cuda.lib
17
18 环境变量:
19 D:ThirdPartylibtorch-win-shared-with-deps-1.7.1+cu110libtorchlib
以下c++代码加载上述模型文件
#include
#include
#include
#include
int main()
{
torch::jit::script::Module module;
std::string str = "traced_resnet_model.pt";
try
{
module = torch::jit::load(str);
}
catch (const c10::Error& e)
{
std::cerr inputs;
inputs.push_back(torch::ones({ 1, 3, 224, 224 }));
// 推理
at::Tensor output = module.forward(inputs).toTensor();
std::c服务器托管网out
总结
python模型的序列化、保存代码:
import torchvision
import torch
model = torchvision.models.resnet18()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
output = traced_script_module(torch.ones(1, 3, 224, 224))
#traced_script_module.save("traced_resnet_model.pt") # 和下面等价,格式名称不同,仅此而已,在libtorch中是一样的
traced_script_module.save("traced_resnet_model.torchscript.pt")
print()
libtorch的模型加载,推理代码:
#include
#include
#include
#include
int main()
{
torch::jit::script::Module module;
std::string str = "traced_resnet_model.pt";
//std::string str = "traced_resnet_model.torchscript.pt服务器托管网"; // 和上面等价,模型格式而已
try
{
module = torch::jit::load(str);
}
catch (const c10::Error& e)
{
std::cerr inputs;
inputs.push_back(torch::ones({ 1, 3, 224, 224 }));
// 推理
at::Tensor output = module.forward(inputs).toTensor();
std::cout
服务器托管,北京服务器托管,服务器租用 http://www.fwqtg.net
相关推荐: Spring FrameWork从入门到NB – 定制Bean
Customizing the Nature of a Bean,最早准备跳过这部分内容,但是觉得这部分内容是Spring Bean生命周期中的一个重要部分,跳过了可能会影响通往NB之路,所以还是要认真学习一下。 Spring通过三种类型的接口实现对Bean行…