JMU黑客松 - 05_train.py 开始训练激动人心的ResNet18!

此教程参考:赛道检测模型训练部署全过程讲解,https://developer.horizon.cc/forumDetail/185446272545810434

本文在上述教程的流程和代码基础上作了一些补充和说明。若使用02准备的数据集、03标注、04打散,则可直接使用05进行训练,注意需要将代码中的DATASET_NMAE修改为与自己数据集文件夹一致的名称。

如果是第一次训练,会自动下载ResNet18的预训练权重(约40MB)。训练过程中,每当测试集loss创新低时,都会把模型保存(覆盖)到BEST_MODEL_PATH指定的文件中(默认为当前目录下的model_best.pth)。关于BATCH_SIZE和NUM_EPOCHS参数的设定可自行百度。

建议使用conda工具管理Python环境。训练ResNet在Windows或Linux环境下均可,有无GPU影响不大:帖主的4800H训练一个Epoch只需要12秒,100个Epoch就差不多了,用CPU训练约20分钟就能训完(代码中NUM_EPOCHS默认为300,可按需调小)。

参考上文教程安装依赖和训练,基本上可以流畅跑通,在此仅仅对原文作几点补充。

  1. `XYDataset.__init__`中的以下语句(原文件第97~98行)把数据集的图片类型限制为png,需要根据自己采集的数据集图片格式修改:

    self.image_paths = glob.glob(os.path.join(self.directory + "/image", '*.png'))

  2. `XYDataset.__getitem__`中(原文件第110~115行)直接读取以浮点数存储的数据标签。数据标签采用yolo风格(归一化到0~1的浮点数),部署时无论图片经过拉伸还是缩放,都可以用图片的宽和高分别乘以这两个浮点数得到点的坐标,更加直观。

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = PIL.Image.open(image_path)
        with open(os.path.join(self.directory + "/label",
                  os.path.splitext(os.path.basename(image_path))[0]+".txt"), 'r') as label_file:
            content = label_file.read()
            values = content.split()
            if len(values) == 2:
                value1 = float(values[0])  ## 主要修改在此处
                value2 = float(values[1])  ## 主要修改在此处
            else:
                print("文件格式不正确")
        x, y = value1, value2  ## 主要修改在此处

代码参考:(附件中可以下载)

import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
import glob
import PIL.Image
import os
import numpy as np
from threading import Thread
from time import time, sleep

# Name of the dataset folder (relative to the working directory) produced by steps 02-04.
DATASET_NMAE: str = "DataSet_3_1119"
# File that receives the weights with the lowest test loss seen so far.
BEST_MODEL_PATH: str = './model_best.pth'
# Number of samples per mini-batch.
BATCH_SIZE: int = 32
# Number of full passes over the training split.
NUM_EPOCHS: int = 300


def main(args=None):
    best_loss = 1e9
    train_image = "./" + DATASET_NMAE + "/train/"
    test_image = "./" + DATASET_NMAE + "/test/"
    train_dataset = XYDataset(train_image)
    test_dataset = XYDataset(test_image)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0
    )

    # 创建ResNet18模型,这里选用已经预训练的模型,
    # 更改fc输出为2,即x、y坐标值
    model = models.resnet18(pretrained=True)
    model.fc = torch.nn.Linear(512, 2)
    device = torch.device('cpu')
    model = model.to(device)
    optimizer = optim.Adam(model.parameters())

    for epoch in range(NUM_EPOCHS):
        epoch_time_begin = time()
        model.train()
        train_loss = 0.0
        for images, labels in iter(train_loader):
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = F.mse_loss(outputs, labels)
            train_loss += float(loss)
            loss.backward()
            optimizer.step()
        train_loss /= len(train_loader)

        model.eval()
        test_loss = 0.0
        for images, labels in iter(test_loader):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = F.mse_loss(outputs, labels)
            test_loss += float(loss)
        test_loss /= len(test_loader)
        msgStr = "Epoch" + "\033[32;40m" + " %d " % epoch + "\033[0m"
        msgStr += "-> time: \033[32;40m%.3f\033[0m s,  train_loss: \033[32;40m%f\033[0m,  test_loss: \033[32;40m%f\033[0m" % (
            time() - epoch_time_begin, train_loss, test_loss)

        if test_loss < best_loss:
            msgStr += (" \033[31m" + " Saved" + "\033[0m")
            torch.save(model.state_dict(), BEST_MODEL_PATH)
            best_loss = test_loss
        else:
            msgStr += " Done"
        print(msgStr)


class XYDataset(torch.utils.data.Dataset):
    """Dataset of images paired with one normalized (x, y) point label each.

    Expected layout under ``directory``::

        image/<name>.<ext>   image files matched by ``patterns``
        label/<name>.txt     one line "x y": two floats

    The label floats are treated as coordinates normalized to [0, 1]
    (YOLO style), so they stay valid after the image is resized to 224x224.

    Args:
        directory: root of one split, e.g. ``"./DataSet/train"``.
        random_hflips: if True, mirror the image horizontally with
            probability 0.5 and mirror the x label to match.
        patterns: glob pattern(s) for image files inside ``directory/image``.
            Defaults to PNG only, which matches the original behaviour.
            Pass e.g. ``("*.png", "*.jpg")`` for other formats.
    """

    def __init__(self, directory, random_hflips=False, patterns=("*.png",)):
        self.directory = directory
        self.random_hflips = random_hflips
        if isinstance(patterns, str):
            patterns = (patterns,)
        image_dir = os.path.join(self.directory, "image")
        # set() drops files matched by overlapping patterns. sorted() gives a
        # deterministic order (raw glob order depends on the filesystem).
        self.image_paths = sorted({
            path
            for pattern in patterns
            for path in glob.glob(os.path.join(image_dir, pattern))
        })
        self.color_jitter = transforms.ColorJitter(0.3, 0.3, 0.3, 0.3)

    def __len__(self):
        return len(self.image_paths)

    def _label_path(self, image_path):
        """Return ``directory/label/<image stem>.txt`` for ``image_path``."""
        stem = os.path.splitext(os.path.basename(image_path))[0]
        return os.path.join(self.directory, "label", stem + ".txt")

    @staticmethod
    def _read_label(label_path):
        """Parse a label file holding exactly two floats and return them as ``(x, y)``.

        Raises:
            ValueError: if the file does not contain exactly two values, or
                a value is not a float.
        """
        with open(label_path, "r", encoding="utf-8") as label_file:
            values = label_file.read().split()
        if len(values) != 2:
            # Previously this only printed a message and then crashed with an
            # UnboundLocalError. Fail with the offending file named instead.
            raise ValueError("文件格式不正确: %s (expected 2 values, got %d)"
                             % (label_path, len(values)))
        return float(values[0]), float(values[1])

    def __getitem__(self, idx):
        """Return ``(image_tensor, tensor([x, y]))`` for sample ``idx``.

        The image is color-jittered, resized to 224x224, converted to a
        float tensor and normalized with the ImageNet mean/std.
        """
        image_path = self.image_paths[idx]
        # convert("RGB") turns RGBA, palette or grayscale PNGs into the 3
        # channels that normalize() below expects. The with-block closes the
        # file handle that PIL.Image.open keeps open lazily.
        with PIL.Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        x, y = self._read_label(self._label_path(image_path))

        if self.random_hflips and np.random.rand() > 0.5:
            image = transforms.functional.hflip(image)
            # x is normalized to [0, 1], so the horizontal mirror is 1 - x
            # (negating is only correct for labels centred on 0).
            x = 1.0 - x

        image = self.color_jitter(image)
        image = transforms.functional.resize(image, (224, 224))
        image = transforms.functional.to_tensor(image)
        image = transforms.functional.normalize(image,
                                                [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        return image, torch.tensor([x, y]).float()


if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main(args=None)

觉得还不错就点个赞再走呗~

05_train.py

佬,如果是height=224,width=640的话,训练的代码需要修改吗?

佬,如何把你归一化后的坐标反归一化呢?

乘以224

+1

修改一下尾部全连接应该就行了