此教程参考:赛道检测模型训练部署全过程讲解,https://developer.horizon.cc/forumDetail/185446272545810434
在此篇教程的流程和代码上作了一些补充和说明,若使用02准备的数据集,03标注,04打散,则可以直接使用05训练即可,注意需要将DATASET_NMAE修改一致。
如果是第一次训练会自动下载预训练权重,约40MB,训练结束后会在当前目录下生成一个模型文件(文件名由BEST_MODEL_PATH指定,默认为./model_best.pth),关于BATCH_SIZE和EPOCH参数设定可自行查阅相关资料。
建议使用conda工具管理Python环境,训练ResNet在Windows环境或者Linux环境下均可,有无GPU影响不大,帖主的4800H训练一个Epoch只需要12秒,100个Epoch就差不多了,用CPU训练约20分钟就能训完。
参考上文教程安装依赖和训练,基本上可以流畅跑通,在此仅仅对原文作几点补充。
-
代码第97~98行,限制了数据集的图片类型,需要根据自己采集的数据集图片格式修改。
self.image_paths = glob.glob(os.path.join(self.directory + "/image", '*.png'))
-
代码第110~115行,直接读取浮点数存储的数据标签,数据标签采用yolo风格,在部署时无论图片经过拉伸或者放缩,都可以使用图片的长和宽乘以这个浮点数得到点的坐标,更加直观。
def __getitem__(self, idx):
image_path = self.image_paths[idx]
image = PIL.Image.open(image_path)
with open(os.path.join(self.directory + "/label",
os.path.splitext(os.path.basename(image_path))[0]+".txt"), 'r') as label_file:
content = label_file.read()
values = content.split()
if len(values) == 2:
value1 = float(values[0]) ## 主要修改在此处
value2 = float(values[1]) ## 主要修改在此处
else:
print("文件格式不正确")
x, y = value1, value2 ## 主要修改在此处
代码参考:(附件中可以下载)
import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
import glob
import PIL.Image
import os
import numpy as np
from threading import Thread
from time import time, sleep
DATASET_NMAE = "DataSet_3_1119" # dataset directory name (the "NMAE" spelling is a typo, kept because the rest of the script uses it)
BEST_MODEL_PATH = './model_best.pth' # where the best (lowest test-loss) weights are saved
BATCH_SIZE = 32
NUM_EPOCHS = 300 # number of training epochs
def main(args=None):
    """Train a ResNet18 regressor that predicts a normalized (x, y) point per image.

    Reads train/test splits from ./<DATASET_NMAE>/train and ./<DATASET_NMAE>/test,
    trains for NUM_EPOCHS epochs with MSE loss, and saves the weights with the
    lowest test loss to BEST_MODEL_PATH.

    :param args: unused, kept for interface compatibility.
    """
    best_loss = 1e9
    train_image = "./" + DATASET_NMAE + "/train/"
    test_image = "./" + DATASET_NMAE + "/test/"
    train_dataset = XYDataset(train_image)
    test_dataset = XYDataset(test_image)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,  # fix: no reason to shuffle during evaluation
        num_workers=0
    )
    # Pretrained ResNet18 with its classifier head replaced by a 2-unit
    # linear layer: the two outputs are the regressed x and y coordinates.
    model = models.resnet18(pretrained=True)
    model.fc = torch.nn.Linear(512, 2)
    # Generalization: use the GPU when one is available; falls back to the
    # original CPU behavior otherwise.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    optimizer = optim.Adam(model.parameters())
    for epoch in range(NUM_EPOCHS):
        epoch_time_begin = time()
        # ---- training pass ----
        model.train()
        train_loss = 0.0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = F.mse_loss(outputs, labels)
            train_loss += float(loss)
            loss.backward()
            optimizer.step()
        train_loss /= len(train_loader)
        # ---- evaluation pass ----
        model.eval()
        test_loss = 0.0
        # Fix: the original built autograd graphs during evaluation,
        # wasting time and memory; no gradients are needed here.
        with torch.no_grad():
            for images, labels in test_loader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                loss = F.mse_loss(outputs, labels)
                test_loss += float(loss)
        test_loss /= len(test_loader)
        msgStr = "Epoch" + "\033[32;40m" + " %d " % epoch + "\033[0m"
        msgStr += "-> time: \033[32;40m%.3f\033[0m s, train_loss: \033[32;40m%f\033[0m, test_loss: \033[32;40m%f\033[0m" % (
            time() - epoch_time_begin, train_loss, test_loss)
        # Keep only the checkpoint with the lowest test loss seen so far.
        if test_loss < best_loss:
            msgStr += (" \033[31m" + " Saved" + "\033[0m")
            torch.save(model.state_dict(), BEST_MODEL_PATH)
            best_loss = test_loss
        else:
            msgStr += " Done"
        print(msgStr)
class XYDataset(torch.utils.data.Dataset):
    """Dataset of (image_tensor, (x, y)) pairs for point regression.

    Expects <directory>/image/*.png and, for every image, a matching
    <directory>/label/<name>.txt containing exactly two floats:
    YOLO-style normalized x and y in [0, 1].
    """

    def __init__(self, directory, random_hflips=False):
        """
        :param directory: dataset root containing "image" and "label" subfolders.
        :param random_hflips: if True, randomly mirror images (and their labels)
            as augmentation.
        """
        self.directory = directory
        self.random_hflips = random_hflips
        # NOTE: change '*.png' here if your collected images use another format.
        self.image_paths = glob.glob(os.path.join(
            self.directory + "/image", '*.png'))
        self.color_jitter = transforms.ColorJitter(0.3, 0.3, 0.3, 0.3)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = PIL.Image.open(image_path)
        label_path = os.path.join(
            self.directory + "/label",
            os.path.splitext(os.path.basename(image_path))[0] + ".txt")
        with open(label_path, 'r') as label_file:
            values = label_file.read().split()
        if len(values) != 2:
            # Fix: the original only printed a warning and then crashed with
            # NameError on undefined value1/value2; fail loudly instead.
            raise ValueError("文件格式不正确: " + label_path)
        x, y = float(values[0]), float(values[1])
        if self.random_hflips and float(np.random.rand(1)) > 0.5:
            image = transforms.functional.hflip(image)
            # Fix: labels are normalized to [0, 1], so a horizontal flip maps
            # x -> 1 - x; the original's x = -x left the valid range entirely.
            x = 1.0 - x
        image = self.color_jitter(image)
        image = transforms.functional.resize(image, (224, 224))
        image = transforms.functional.to_tensor(image)
        # (Removed a redundant tensor -> numpy -> tensor round trip that had
        # no effect on the result.)
        image = transforms.functional.normalize(
            image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        return image, torch.tensor([x, y]).float()
# Run training only when executed as a script (not on import).
if __name__ == '__main__':
    main()
觉得还不错就点个赞再走呗~
