Racetrack Detection Model: Full Training and Deployment Walkthrough

1. Environment Setup

1.1 Installing conda

See: https://docs.anaconda.com/free/anaconda/install/linux/

1.2 Installing PyTorch

See: https://pytorch.org/get-started/locally/. To keep the Python environment clean, install PyTorch inside a conda environment. Create the environment with:

conda create -n pytorch python=3.8 -y 
conda activate pytorch

For the sake of generality, the CPU build is used here; install it with conda:

conda install pytorch torchvision torchaudio cpuonly -c pytorch

Verify the PyTorch installation:

python
import torch
x = torch.rand(5, 3)
print(x)

1.3 Installing ai_toolchain

See: https://developer.horizon.ai/api/v1/fileData/documents_rdk/quant_toolchain_development/horizon_beginner/env_install.html#id4

# Download the toolchain packages
wget -c ftp://xj3ftp@vrftp.horizon.ai/ai_toolchain/ai_toolchain.tar.gz --ftp-password=xj3ftp@123$% 
wget -c ftp://xj3ftp@vrftp.horizon.ai/model_convert_sample/horizon_model_convert_sample.tar.gz --ftp-password=xj3ftp@123$%  

# Create a conda environment
conda create -n horizon_bpu python=3.8 -y 
conda activate horizon_bpu  

# Install the required packages
tar -xvf horizon_model_convert_sample.tar.gz 
tar -xvf ai_toolchain.tar.gz 
pip install ai_toolchain/h* -i https://mirrors.aliyun.com/pypi/simple 
pip install pycocotools -i https://mirrors.aliyun.com/pypi/simple

Verify the ai_toolchain installation:

hb_mapper --help  
Usage: hb_mapper [OPTIONS] COMMAND [ARGS]...    
  hb_mapper is an offline model transform tool provided by horizon.  
Options:   
  --version   Show the version and exit.   
  -h, --help  Show this message and exit.  
Commands:   
  checker    check whether the model meet the requirements.   
  infer      inference and dump output feature as float vector.   
  makertbin  transform caffe model to quantization model, generate runtime...

2. Dataset Collection

For the data-collection steps, refer to the NodeHub dataset-collection package, which captures camera images and saves them locally on the board. Take the racetrack-detection images from it and transfer them to a PC:

scp -r track_image <username>@<PC_ip_address>:<destination_path_on_PC>

3. Dataset Annotation

Use labelme to annotate the dataset.

# Install
sudo apt install labelme  

# Run
labelme

To annotate, open the dataset folder, right-click on an image, and choose Create Point.

Click the midpoint of the track in the image and give it a label name, then click OK -> Save -> Next Image to move to the next picture.

Repeat these steps for every image; pictures that don't fit the requirements can simply be skipped. The annotation files are saved in the same directory as the images.
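For reference, labelme writes each annotation as a small JSON file next to the image. A quick way to inspect the labelled track point from Python (the annotation file name below is just an example):

import json

# Read one labelme annotation and print the labelled track point.
with open("track_image/0001.json") as f:  # hypothetical file name
    ann = json.load(f)

for shape in ann["shapes"]:               # labelme stores every annotation under "shapes"
    if shape["shape_type"] == "point":
        (x, y), = shape["points"]         # a Create Point annotation is a single [x, y] pair
        print(shape["label"], x, y, ann["imagePath"])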

4. Dataset Conversion

Before training, the data has to be converted into the dataset format the training code expects. Put the script at the same directory level as the dataset folder and run the conversion script (provided as an attachment):

# The script lets you configure where the dataset is read from and where the converted dataset is written
python3 labelme2resnet.py

After a successful conversion, a line_follow_dataset folder holding the converted dataset is created in the same directory, with the following structure:

root@root-vpc:~/line_follow_dataset$ tree -L 2 
├── test 
│   ├── image 
│   └── label 
└── train  
    ├── image     
    └── label
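The actual labelme2resnet.py is provided as an attachment. As a rough, non-authoritative sketch of what such a conversion involves (the source folder name, the 9:1 train/test split, and the label file format below are assumptions), it reads each labelme JSON, copies the image into image/, writes the point coordinates into label/, and splits the data into train and test:

import glob, json, os, random, shutil

SRC = "track_image"             # assumed location of the labelme-annotated images
DST = "line_follow_dataset"     # output layout matching the tree shown above
random.seed(0)

json_files = sorted(glob.glob(os.path.join(SRC, "*.json")))
random.shuffle(json_files)
n_test = max(1, len(json_files) // 10)     # assumed 9:1 train/test split

for i, jf in enumerate(json_files):
    with open(jf) as f:
        ann = json.load(f)
    point = next(s for s in ann["shapes"] if s["shape_type"] == "point")
    x, y = point["points"][0]
    split = "test" if i < n_test else "train"
    img_dir = os.path.join(DST, split, "image")
    lbl_dir = os.path.join(DST, split, "label")
    os.makedirs(img_dir, exist_ok=True)
    os.makedirs(lbl_dir, exist_ok=True)
    img_name = os.path.basename(ann["imagePath"])
    shutil.copy(os.path.join(SRC, img_name), os.path.join(img_dir, img_name))
    # assumed label format: one text file per image containing "x y"
    with open(os.path.join(lbl_dir, os.path.splitext(img_name)[0] + ".txt"), "w") as f:
        f.write(f"{x} {y}\n")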

5. Training

Put the dataset generated by the conversion into the resnet18 folder (the resnet18 folder is provided as an attachment).

# Enter the conda environment
conda activate pytorch  

# Install dependencies
cd resnet18 
pip install -r requirements.txt  

# Train
python3 train.py

Once training starts, the training loss and test loss are printed; whenever the test loss drops, the current model is saved.
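The attached train.py is the authoritative training script. As a hedged sketch of the procedure described here, a ResNet18 backbone with a two-value head regresses the track point, and the weights are saved whenever the test loss improves (the data loaders, optimizer, learning rate, epoch count, and file name below are placeholders):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torchvision import models

# Placeholder loaders so the sketch runs end to end; train.py builds them from line_follow_dataset.
train_loader = DataLoader(TensorDataset(torch.rand(8, 3, 224, 224), torch.rand(8, 2)), batch_size=4)
test_loader = DataLoader(TensorDataset(torch.rand(4, 3, 224, 224), torch.rand(4, 2)), batch_size=4)

model = models.resnet18()                        # train.py may start from ImageNet-pretrained weights
model.fc = nn.Linear(model.fc.in_features, 2)    # regress the (x, y) track point
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # optimizer and learning rate are assumptions

best_test_loss = float("inf")
for epoch in range(10):                          # epoch count is an assumption
    model.train()
    for images, targets in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(images), targets)
        loss.backward()
        optimizer.step()

    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for images, targets in test_loader:
            test_loss += criterion(model(images), targets).item()
    test_loss /= len(test_loader)
    print(f"epoch {epoch}: train loss {loss.item():.4f}, test loss {test_loss:.4f}")

    if test_loss < best_test_loss:               # save whenever the test loss improves
        best_test_loss = test_loss
        torch.save(model.state_dict(), "best_model.pth")  # file name is illustrative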

6. Model Conversion

Before model quantization, the model has to be converted to ONNX format.

# Export
python3 export.py

After a successful export, an ONNX model is generated in the same directory.
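The attached export.py performs the actual conversion; the heart of such a script is a single torch.onnx.export call. A minimal sketch, in which the input size, file names, and opset version are assumptions:

import torch
import torch.nn as nn
from torchvision import models

# Rebuild the same architecture used for training, then export it to ONNX.
model = models.resnet18()
model.fc = nn.Linear(model.fc.in_features, 2)    # same two-output head as in training
# model.load_state_dict(torch.load("best_model.pth", map_location="cpu"))  # load the trained weights here
model.eval()

dummy_input = torch.randn(1, 3, 224, 224)        # assumed 224x224 RGB input
torch.onnx.export(
    model, dummy_input, "resnet18_track.onnx",   # output file name is illustrative
    input_names=["input"], output_names=["output"],
    opset_version=11,                            # a commonly used opset for the X3 toolchain
)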

7. Model Quantization

Take 100 images from the dataset and put them into the model_convert/dataset/image folder, and place the generated ONNX model under model_convert (the model_convert folder is provided as an attachment).

7.1 Preparing the Calibration Data

# Enter the conda environment
conda activate pytorch
cd model_convert

# Run the image conversion script
python generate_calibration_data.py --dataset ./dataset -o ./calibration_data

After the script runs successfully, a calibration_data folder is generated.
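The attached generate_calibration_data.py is the authoritative script. Judging from its command line, a sketch of what such a script does: read the images under ./dataset with an ImageFolder, apply 224x224 preprocessing, and dump each preprocessed image into ./calibration_data as a raw tensor file (the transform details and the output file extension are assumptions):

import argparse, os
import numpy as np
from torchvision import datasets, transforms

parser = argparse.ArgumentParser()
parser.add_argument("--dataset", default="./dataset")
parser.add_argument("-o", "--output", default="./calibration_data")
args = parser.parse_args()

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),   # assumed to match the training input size
    transforms.ToTensor(),
])

dataset = datasets.ImageFolder(args.dataset, transform=preprocess)
os.makedirs(args.output, exist_ok=True)

for i, (img, _) in enumerate(dataset):
    # each calibration sample is dumped as a raw float32 tensor file (the .rgb extension is an assumption)
    img.numpy().astype(np.float32).tofile(os.path.join(args.output, f"calib_{i:04d}.rgb"))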

7.2 Model Compilation

# Enter the conda environment
conda activate horizon_bpu
cd model_convert

# Compile the model
hb_mapper makertbin --config resnet18_config.yaml --model-type onnx

After a successful compilation, the final model file is generated under the model_output/model_output path.

8. Deployment

Refer to the NodeHub racetrack-detection package to install the required packages, then copy the generated .bin model to the development board and run:

ros2 run racing_track_detection_resnet racing_track_detection_resnet --ros-args -p model_path:=<path_to_model>

After this command succeeds, the track-detection node is up and running.

Attachments: model_convert.tar, labelme2resnet.py, resnet18.tar

When verifying the ai_toolchain installation, running hb_mapper gives "command not found"; what's going on?

I get an error when running train.py; I tried changing ".jpg" to ".png" on line 28 of train.py, but I still get the same error.

Does this also apply to obstacle avoidance? Can the obstacle-avoidance model be converted the same way?

I ran into the following error during conversion; what could be the reason?

Traceback (most recent call last):
  File "/home/ly/anaconda3/envs/pytorch/bin/hb_mapper", line 5, in <module>
    from horizon_tc_ui.hb_mapper import main
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/site-packages/horizon_tc_ui/__init__.py", line 14, in <module>
    from .hb_onnxruntime import HB_ONNXRuntime
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/site-packages/horizon_tc_ui/hb_onnxruntime.py", line 9, in <module>
    from horizon_nn import horizon_onnx
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/site-packages/horizon_nn/__init__.py", line 6, in <module>
    from .build import build, build_caffe, build_onnx, check_caffe, check_onnx  # noqa: F401
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/site-packages/horizon_nn/build.py", line 6, in <module>
    from horizon_nn import horizon_onnx
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/site-packages/horizon_nn/horizon_onnx/__init__.py", line 8, in <module>
    from .onnx_pb import *
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/site-packages/horizon_nn/horizon_onnx/onnx_pb.py", line 8, in <module>
    from .onnx_ml_pb2 import *  # noqa
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/site-packages/horizon_nn/horizon_onnx/onnx_ml_pb2.py", line 32, in <module>
    _descriptor.EnumValueDescriptor(
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/site-packages/google/protobuf/descriptor.py", line 914, in __new__
    _message.Message._CheckCalledFromGeneratedFile()
TypeError: Descriptors cannot be created directly.
If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0.
If you cannot immediately regenerate your protos, some other possible workarounds are:
 1. Downgrade the protobuf package to 3.20.x or lower.
 2. Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python (but this will use pure-Python parsing and will be much slower).
More information: https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates
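For reference, the two workarounds listed at the end of the message can be applied directly from the shell (the exact 3.20.x patch release below is only an example):

# Workaround 1: pin protobuf to a 3.20.x release
pip install "protobuf==3.20.3" -i https://mirrors.aliyun.com/pypi/simple

# Workaround 2: fall back to the pure-Python protobuf implementation (slower)
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python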

When running pip install -r requirements.txt I got the following error:

ERROR: Exception:
Traceback (most recent call last):
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/site-packages/pip/_vendor/urllib3/response.py", line 438, in _error_catcher
    yield
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/site-packages/pip/_vendor/urllib3/response.py", line 561, in read
    data = self._fp_read(amt) if not fp_closed else b""
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/site-packages/pip/_vendor/urllib3/response.py", line 527, in _fp_read
    return self._fp.read(amt) if amt is not None else self._fp.read()
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/site-packages/pip/_vendor/cachecontrol/filewrapper.py", line 98, in read
    data: bytes = self.__fp.read(amt)
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/http/client.py", line 459, in read
    n = self.readinto(b)
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/http/client.py", line 503, in readinto
    n = self.fp.readinto(b)
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/socket.py", line 669, in readinto
    return self._sock.recv_into(b)
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/ssl.py", line 1274, in recv_into
    return self.read(nbytes, buffer)
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/ssl.py", line 1132, in read
    return self._sslobj.read(len, buffer)
socket.timeout: The read operation timed out

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/site-packages/pip/_internal/cli/base_command.py", line 180, in exc_logging_wrapper
    status = run_func(*args)
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/site-packages/pip/_internal/cli/req_command.py", line 245, in wrapper
    return func(self, options, args)
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/site-packages/pip/_internal/commands/install.py", line 377, in run
    requirement_set = resolver.resolve(
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/site-packages/pip/_internal/resolution/resolvelib/resolver.py", line 179, in resolve
    self.factory.preparer.prepare_linked_requirements_more(reqs)
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/site-packages/pip/_internal/operations/prepare.py", line 552, in prepare_linked_requirements_more
    self._complete_partial_requirements(
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/site-packages/pip/_internal/operations/prepare.py", line 467, in _complete_partial_requirements
    for link, (filepath, _) in batch_download:
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/site-packages/pip/_internal/network/download.py", line 183, in __call__
    for chunk in chunks:
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/site-packages/pip/_internal/cli/progress_bars.py", line 53, in _rich_progress_bar
    for chunk in iterable:
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/site-packages/pip/_internal/network/utils.py", line 63, in response_chunks
    for chunk in response.raw.stream(
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/site-packages/pip/_vendor/urllib3/response.py", line 622, in stream
    data = self.read(amt=amt, decode_content=decode_content)
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/site-packages/pip/_vendor/urllib3/response.py", line 587, in read
    raise IncompleteRead(self._fp_bytes_read, self.length_remaining)
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/contextlib.py", line 131, in __exit__
    self.gen.throw(type, value, traceback)
  File "/home/ly/anaconda3/envs/pytorch/lib/python3.8/site-packages/pip/_vendor/urllib3/response.py", line 443, in _error_catcher
    raise ReadTimeoutError(self._pool, None, "Read timed out.")
pip._vendor.urllib3.exceptions.ReadTimeoutError: HTTPSConnectionPool(host='files.pythonhosted.org', port=443): Read timed out.

Sharing: an alternative approach to training:

1. Use the MIPI camera to capture 640*360 NV12 images, feed them into another VPS channel on the X3 to get 224*224 images, convert them to BGR, and save them as PNG for annotation.

2. At lines 46 and 47 of the training code, normalize the values to real numbers between 0 and 1:

x, y = float(value1/224.0), float(value2/224.0)

3. At deployment time, the pipeline is simply cam → VPS → 224*224. This is a hardware operation that finishes in about 10 microseconds; the image is then fed to the BPU for inference, and multiplying the output by 224 recovers the coordinates.

Rationale: the lane-detection network takes 224*224 input, so the resolution is only 224 steps anyway. Operations such as resize may be done in software (BPU tensor pre-processing), so it is better to get a 224*224 NV12 image directly from the VPS: pre-processing then takes almost no time, BPU inference takes about 10 ms, and the post-processing is just a multiply by 224 to get the keypoint coordinates. The GC4663 can sustain a full 30 fps, leaving roughly 70% of the time budget for other work.
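In code, the round trip described above is just a divide at training time and a multiply after inference (the function and variable names are illustrative, not taken from the actual training code):

# Illustrative round trip for the 0~1 normalization described above.
def normalize_label(x_px, y_px):
    # training targets in [0, 1], as in the poster's lines 46-47
    return x_px / 224.0, y_px / 224.0

def denormalize_pred(x_norm, y_norm):
    # post-processing after BPU inference: multiply by 224 to recover pixel coordinates
    return x_norm * 224.0, y_norm * 224.0

print(denormalize_pred(*normalize_label(112.0, 56.0)))  # -> (112.0, 56.0)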

Sharing: if you see the following error while preparing the calibration dataset:

Traceback (most recent call last):
  File "generate_calibration_data.py", line 54, in <module>
    main(args)
  File "generate_calibration_data.py", line 42, in main
    dataset = datasets.ImageFolder(args.dataset, transform=preprocess)
  File "/home/sunrise/.local/lib/python3.8/site-packages/torchvision/datasets/folder.py", line 309, in __init__
    super().__init__(
  File "/home/sunrise/.local/lib/python3.8/site-packages/torchvision/datasets/folder.py", line 144, in __init__
    classes, class_to_idx = self.find_classes(self.root)
  File "/home/sunrise/.local/lib/python3.8/site-packages/torchvision/datasets/folder.py", line 218, in find_classes
    return find_classes(directory)
  File "/home/sunrise/.local/lib/python3.8/site-packages/torchvision/datasets/folder.py", line 40, in find_classes
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
FileNotFoundError: [Errno 2] No such file or directory: '../dataset'

The command in the tutorial may be slightly off; the following command works:

python generate_calibration_data.py --dataset ./dataset -o ./calibration_data

Sharing: if you get the following error when training on your own dataset:

Traceback (most recent call last):
  File "train.py", line 136, in <module>
    main()
  File "train.py", line 73, in main
    train_loader = torch.utils.data.DataLoader(
  File "/home/sunrise/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 349, in __init__
    sampler = RandomSampler(dataset, generator=generator)  # type: ignore[arg-type]
  File "/home/sunrise/.local/lib/python3.8/site-packages/torch/utils/data/sampler.py", line 140, in __init__
    raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")
ValueError: num_samples should be a positive integer value, but got num_samples=0

A likely cause is that your training images are in .png format, so ".jpg" on line 28 of train.py needs to be changed to ".png":
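As a hedged illustration of the kind of change meant here (the actual line 28 of train.py may differ in its details), the image-matching pattern is switched from .jpg to .png:

import glob, os

# Hypothetical example of the pattern change; not the literal contents of train.py.
image_dir = "line_follow_dataset/train/image"
image_paths = glob.glob(os.path.join(image_dir, "*.png"))  # was "*.jpg"
print(len(image_paths))  # a count of 0 here is exactly what produces num_samples=0 above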

Sharing: if you see the following error while running the resnet export.py:

Traceback (most recent call last):
  File "export.py", line 23, in <module>
    main()
  File "export.py", line 13, in main
    torch.onnx.export(model,
  File "/home/sunrise/.local/lib/python3.8/site-packages/torch/onnx/utils.py", line 516, in export
    _export(
  File "/home/sunrise/.local/lib/python3.8/site-packages/torch/onnx/utils.py", line 1670, in _export
    proto = onnx_proto_utils._add_onnxscript_fn(
  File "/home/sunrise/.local/lib/python3.8/site-packages/torch/onnx/_internal/onnx_proto_utils.py", line 223, in _add_onnxscript_fn
    raise errors.OnnxExporterError("Module onnx is not installed!") from e
torch.onnx.errors.OnnxExporterError: Module onnx is not installed!

This happens because the onnx module is not installed; it can be installed with:
pip install onnx -i https://mirrors.aliyun.com/pypi/simple

Hi, I used your training code; does the final output value need to be multiplied by 224?

Hi, you mean pip install -r requirements.txt, right? It looks like the package index could not be reached; try switching pip to a domestic mirror.
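Concretely, using the same Aliyun mirror that appears earlier in this guide:

pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple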

The error occurred while running the hb_mapper makertbin --config resnet18_config.yaml --model-type onnx command.

Obstacle detection uses the yolov5s model; you can refer to "Efficient Deployment of YOLOv5 on the Horizon RDK X3" (horizon.cc). Of course, you can also try labelling points directly and letting the car avoid obstacles that way.

Please confirm whether you have put the dataset generated by the conversion into the resnet18 folder.

Could you add a more detailed description, please?

You could try changing the protobuf version (the error message above suggests downgrading it to 3.20.x or lower).

I found that one of the dependencies fails to install when installing the required packages.