写点什么

OpenMMLab 图像分类实战代码演示

作者:IT蜗壳-Tango
  • 2023-02-10
    江苏
  • 本文字数:4499 字

    阅读完需:约 15 分钟

OpenMMLab图像分类实战代码演示

1. 实战简介

使用 MMClassification 训练花卉图片分类模型。

基于 MMClassification 提供的预训练模型,在 flowers 数据集上完成分类模型的微调训练。

1. 整理 flower 数据集

数据集介绍

flower 数据集包含 5 种类别的花卉图像:雏菊 daisy 588 张,蒲公英 dandelion 556 张,玫瑰 rose 583 张,向日葵 sunflower 536 张,郁金香 tulip 585 张。


数据集下载链接:

  • 国际网:https://www.dropbox.com/s/snom6v4zfky0flx/flower_dataset.zip?dl=0

  • 国内网:https://pan.baidu.com/s/1RJmAoxCD_aNPyTRX6w97xQ 提取码: 9x5u

对数据集进行划分

  1. 将数据集按照 8:2 的比例划分成训练和验证子数据集,并将数据集整理成 ImageNet 的格式。

  2. 将训练子集和验证子集放到 train 和 val 文件夹下。


文件结构如下:

flower_dataset|--- classes.txt|--- train.txt|--- val.txt|    |--- train|    |    |--- daisy|    |    |    |--- NAME1.jpg|    |    |    |--- NAME2.jpg|    |    |    |--- ...|    |    |--- dandelion|    |    |    |--- NAME1.jpg|    |    |    |--- NAME2.jpg|    |    |    |--- ...|    |    |--- rose|    |    |    |--- NAME1.jpg|    |    |    |--- NAME2.jpg|    |    |    |--- ...|    |    |--- sunflower|    |    |    |--- NAME1.jpg|    |    |    |--- NAME2.jpg|    |    |    |--- ...|    |    |--- tulip|    |    |    |--- NAME1.jpg|    |    |    |--- NAME2.jpg|    |    |    |--- ...|    |--- val|    |    |--- daisy|    |    |    |--- NAME1.jpg|    |    |    |--- NAME2.jpg|    |    |    |--- ...|    |    |--- dandelion|    |    |    |--- NAME1.jpg|    |    |    |--- NAME2.jpg|    |    |    |--- ...|    |    |--- rose|    |    |    |--- NAME1.jpg|    |    |    |--- NAME2.jpg|    |    |    |--- ...|    |    |--- sunflower|    |    |    |--- NAME1.jpg|    |    |    |--- NAME2.jpg|    |    |    |--- ...|    |    |--- tulip|    |    |    |--- NAME1.jpg|    |    |    |--- NAME2.jpg|    |    |    |--- ...
复制代码


  1. 创建并编辑标注文件将所有类别的名称写到 classes.txt 中,每行代表一个类别。

  2. 生成训练(可选)和验证子集标注列表 train.txtval.txt ,每行应包含一个文件名和其对应的标签。样例:


...daisy/NAME**.jpg 0daisy/NAME**.jpg 0...dandelion/NAME**.jpg 1dandelion/NAME**.jpg 1...rose/NAME**.jpg 2rose/NAME**.jpg 2...sunflower/NAME**.jpg 3sunflower/NAME**.jpg 3...tulip/NAME**.jpg 4tulip/NAME**.jpg 4
复制代码

整理完成后,将处理好的数据集迁移到 mmclassification/data 文件夹下。

2. 构建模型微调的配置文件

使用 _base_ 继承机制构建用于微调的配置文件,可以继承任何 MMClassification 提供的基于 ImageNet 的配置文件并进行修改。

对于新手不推荐这种方式,我后面会直接在文件上修改。

2. 环境搭建

由于我使用的是 Mac 电脑,内存不是很多,因此我将项目创建在一个固态的移动硬盘(T7)中。

前置环境:

  • Anaconda

  • Python:3.10.9

  • PyTorch:nightly

基础环境搭建

创建虚拟环境 MMCV

conda create -p /Volumes/T7/CodeSpace/OpenMMLab/mmcv_env python=3.10.9
复制代码



激活环境,并安装 PyTorch

conda activate /Volumes/T7/CodeSpace/OpenMMLab/mmcv_envconda install pytorch torchvision torchaudio -c pytorch
复制代码




为了后续的操作方便我们再安装一个 jupyter-lab

conda install jupyterlab
复制代码


为了后面代码的管理方便,我在当前目录下创建一个保存工程的目录 WorkSpace

mkdir WorkSpacecd WorkSpace
复制代码


启动 jupyterlab

jupyter-lab
复制代码


会自动打开一个浏览器,千万不要关闭当前这个命令行窗口,否则就会将 jupyter-lab 的环境结束。

OpenMMLab(mmcv)相关环境搭建

我们创建一个 Notebook


  • 安装 mim 工具

这个工具对后面很重要

!pip install -U openmim
复制代码

在 Notebook 中执行命令,要在命令前面加!

-i 是切换为国内镜像进行安装,这样安装速度会快一些。

  • 安装 mmcv

mim install mmcv
复制代码


  • 安装 mmclassification 相关工具(mmcls)

!mim install mmcls
复制代码


3. 开始真正的训练

mmclassification 已经为我们提供了很多训练好的模型,我们只要基于这些模型进行微调就可以完成我们这次的实战内容。

我们这次选择的预训练模型是MobileNetV2

下载预训练模型的相关配置文件等

既然要下载我们要知道,下载文件的名字,可以点击上面的链接进入到 GitHub

这里可以看到一些关于这个模型的信息,我们点击 config 就可以看到这个模型的名字。注意:下载时不需要带上.py 后缀。


下载命令如下:

mim download mmcls --config mobilenet-v2_8xb32_in1k --dest .
复制代码


下面这两个文件就是我们下载好的预训练模型(.pth)以及配置文件(.py)

新建一个 data 的文件,并将数据集保存在文件夹中,并解压。

flower_dataset 文件夹就是我们解压好的文件夹,接下来我们需要对这里的数据进行预处理,分成训练集(80%)和测试集(20%)

数据分割的 Python 脚本如下:

import osimport sysimport shutilimport numpy as np

def load_data(data_path): count = 0 data = {} for dir_name in os.listdir(data_path): dir_path = os.path.join(data_path, dir_name) if not os.path.isdir(dir_path): continue
data[dir_name] = [] for file_name in os.listdir(dir_path): file_path = os.path.join(dir_path, file_name) if not os.path.isfile(file_path): continue data[dir_name].append(file_path)
count += len(data[dir_name]) print("{} :{}".format(dir_name, len(data[dir_name])))
print("total of image : {}".format(count)) return data

def copy_dataset(src_img_list, data_index, target_path): target_img_list = [] for index in data_index: src_img = src_img_list[index] img_name = os.path.split(src_img)[-1] shutil.copy(src_img, target_path) target_img_list.append(os.path.join(target_path, img_name)) return target_img_list

def write_file(data, file_name): if isinstance(data, dict): write_data = [] for lab, img_list in data.items(): for img in img_list: write_data.append("{} {}".format(img, lab)) else: write_data = data with open(file_name, "w") as f: for line in write_data: f.write(line + "\n") print("{} write over!".format(file_name))

def split_data(src_data_path, target_data_path, train_rate=0.8): src_data_dict = load_data(src_data_path)
classes = [] train_dataset, val_dataset = {}, {} train_count, val_count = 0, 0 for i, (cls_name, img_list) in enumerate(src_data_dict.items()): img_data_size = len(img_list) random_index = np.random.choice(img_data_size, img_data_size, replace=False) train_data_size = int(img_data_size * train_rate) train_data_index = random_index[:train_data_size] val_data_index = random_index[train_data_size:] train_data_path = os.path.join(target_data_path, "train", cls_name) val_data_path = os.path.join(target_data_path, "val", cls_name) os.makedirs(train_data_path, exist_ok=True) os.makedirs(val_data_path, exist_ok=True) classes.append(cls_name) train_dataset[i] = copy_dataset(img_list, train_data_index, train_data_path)
val_dataset[i] = copy_dataset(img_list, val_data_index, val_data_path)
print("target {} train:{}, val:{}".format(cls_name, len(train_dataset[i]), len(val_dataset[i])))
train_count += len(train_dataset[i])
val_count += len(val_dataset[i])
print("train size:{}, val size:{}, total:{}".format(train_count, val_count, train_count + val_count))
write_file(classes, os.path.join(target_data_path, "classes.txt")) write_file(train_dataset, os.path.join(target_data_path, "train.txt")) write_file(val_dataset, os.path.join(target_data_path, "val.txt"))

def main(): src_data_path = sys.argv[1] target_data_path = sys.argv[2] split_data(src_data_path, target_data_path, train_rate=0.8)

if __name__ == '__main__': main()
复制代码


脚本运行方式:python split_data.py data/flower_dataset 目标文件夹

运行后我们看到在 data 目录下生成了两个文件夹,分别是训练集(train), 测试集(val), 以及两个标注文件(train.txt, val.txt)和一个分类信息文件(classes.txt)

修改配置文件

我们将之前下载好的配置文件(mobilenet-v2_8xb32_in1k.py)

复制一个出来(mobilenet-v2_8xb32_flower.py)

修改文件

我们的类别总共有 5 种,因此将之前的 1000 修改问 5

11 行~38 行,我们目前不需要,可以将其删除

修改训练集,测试集,验证集的文件路径


修改训练参数:


lr:预训练模型是在 8 卡上跑的,我们相应的改成一个的值,它影响模型训练的精度

max_epochs: 训练的轮次,由于我们的数据集不大,因此先跑 5 轮看看情况

interval:日志信息每 5 轮保存一次

load_from:就是我们之前下载好的预训练模型(注意路径)

开始训练

我们修改好配置文件后就可以开始训练了,运行命令如下:

mim train mmcls mobilenet-v2_8xb32_flower.py --work-dir work-dirs --gpus 0
复制代码

--work-dir: 运行结果保存的路径,不存在会自动创建

--gpus :0 代表不使用 GPU


我们可以看到下面的运行信息,这方便我们调试。对于 Warning 信息,我们可以忽略,但是报错信息我们需要调整,比如上面的长度不匹配,说明我们配置文件中有一个地方没有修改正确。

由于我们使用的是自定义数据集,因此需要将 ImageNet 改为 CustomDataset

我们再次运行一下

可以看到训练已经开始运行了


随着训练的次数的增加,我们可以看到 top1 的精度是有所提升的。


最终训练好的模型会保存为 latest.pth 文件。

4. 使用训练的模型进行验证

from mmcls.apis import init_model, inference_modelmodel = init_model("mobilenet-v2_8xb32_flower.py", "work-dirs/latest.pth", device="cpu")result = inference_model(model, "ata/val/sunflower/1008566138_6927679c8a.jpg")
复制代码

今天的内容就是这些,如果对你有所帮助欢迎转发给你的朋友。

我是 Tango,一个热爱分享技术的程序猿,我们下期见

发布于: 2023-02-10阅读数: 30
用户头像

一个日语专业的程序猿。 2017-09-10 加入

【坐标】无锡 【软件技能】Java,C#,Python 【爱好】炉石传说 【称号】InfoQ年度人气作者,Intel OpenVINO领航者联盟成员 【B站】https://space.bilibili.com/397260706/ 【个人站】www.it-worker.club

评论

发布
暂无评论
OpenMMLab图像分类实战代码演示_CV_IT蜗壳-Tango_InfoQ写作社区