TorchSharp 是对 Torch c++的封装,基本继承了c++的全部接口。但使用中会有一些小问题,需要特别注意一些。
- 语义分割(semantic segmentation)神经网络训练
训练的代码可以参考github里的官方代码 https://github.com/pytorch/vision/tree/main/references/segmentation
2.模型输出
官方代码的模型 默认输出是list 虽然可以强制输出script文件,但TorchSharp 调用后会报错”Expected Tensor but got GenericDict”.因此需要修改网络
export代码如下:
import torch
import torchvision
from torch import nn
import numpy as np
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self._model = torchvision.models.segmentation.lraspp_mobilenet_v3_large(num_classes=2,
aux_loss=False,
pretrained=False)
#修改了输入,将输入改为单通道图片
#如果输入的是3通道 则不需要修改
for item in self._model.backbone.items():
item[1][0] = nn.Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
break
checkpoint = torch.load('model/model_119.pth', map_location='cpu')
self._model.load_state_dict(checkpoint['model'], strict=not True)
self._model.eval()
def forward(self, x):
# 修改了输出,将List修改为Tensor 并进行了 argmax 和 转float操作
result = self._model.forward(x)
return result["out"].argmax(1).flatten().float()
model = MyModel()
x = torch.rand(1,1, 240, 400)
predictions = model(x)
#使用torch.jit.trace 输出 script pt 网络文件
traced_script_module = torch.jit.trace(model, x)
traced_script_module.save('_seg.pt')
3.使用TorchSharp 进行预测
由于测试程序使用了opencvsharp,所以下面的代码使用了opencv来读取图片和简单的数据预处理,并使用opencv可视化输出
//导入网络
torch.jit.ScriptModule torch_model;
torch_model = torch.jit.load("_seg.pt");
//导入图片,并使用BlobFromImage 进行数据类型 维度 和归一化的转换
Mat temp = Cv2.ImRead("Pic_42633.bmp", ImreadModes.Grayscale);
Mat tensor_mat = OpenCvSharp.Dnn.CvDnn.BlobFromImage(temp, 1 / 255.0);
//初始化输入的tensor
float[] data_bytes = new float[tensor_mat.Total()];
Marshal.Copy(tensor_mat.Data, data_bytes, 0, data_bytes.Length);
torch.Tensor x = torch.tensor(data_bytes, torch.ScalarType.Float32);
//维度转换 和 normalize(进行normalize是因为官方代码里有这一步处理)
x = x.reshape(1, 1, 240, 400);
x = TorchSharp.torchvision.transforms.functional.normalize(x, new double[] { 0.485 }, new double[] { 0.229 });
DateTime date1 = DateTime.Now;
//进行预测
torch.Tensor _out = torch_model.forward(x);
DateTime date2 = DateTime.Now;
TimeSpan ts = date2 - date1;
Console.WriteLine("No. of Seconds (Difference) = {0}", ts.TotalMilliseconds);
Console.WriteLine(_out);
//使用opencv 将预测出的tensor输出为可视化的图片
Mat result_mat = new Mat(240, 400, MatType.CV_32FC1, _out.bytes.ToArray());
result_mat.ConvertTo(result_mat, MatType.CV_8UC1);
Cv2.ImWrite("mask.bmp", result_mat);