1. 实战简介
使用 MMClassification 训练花卉图片分类模型。
基于 MMClassification 提供的预训练模型,在 flowers 数据集上完成分类模型的微调训练。
1. 整理 flower 数据集
数据集介绍
flower 数据集包含 5 种类别的花卉图像:雏菊 daisy 588 张,蒲公英 dandelion 556 张,玫瑰 rose 583 张,向日葵 sunflower 536 张,郁金香 tulip 585 张。
数据集下载链接:
对数据集进行划分
将数据集按照 8:2 的比例划分成训练和验证子数据集,并将数据集整理成 ImageNet 的格式。
将训练子集和验证子集放到 train 和 val 文件夹下。
文件结构如下:
flower_dataset
|--- classes.txt
|--- train.txt
|--- val.txt
| |--- train
| | |--- daisy
| | | |--- NAME1.jpg
| | | |--- NAME2.jpg
| | | |--- ...
| | |--- dandelion
| | | |--- NAME1.jpg
| | | |--- NAME2.jpg
| | | |--- ...
| | |--- rose
| | | |--- NAME1.jpg
| | | |--- NAME2.jpg
| | | |--- ...
| | |--- sunflower
| | | |--- NAME1.jpg
| | | |--- NAME2.jpg
| | | |--- ...
| | |--- tulip
| | | |--- NAME1.jpg
| | | |--- NAME2.jpg
| | | |--- ...
| |--- val
| | |--- daisy
| | | |--- NAME1.jpg
| | | |--- NAME2.jpg
| | | |--- ...
| | |--- dandelion
| | | |--- NAME1.jpg
| | | |--- NAME2.jpg
| | | |--- ...
| | |--- rose
| | | |--- NAME1.jpg
| | | |--- NAME2.jpg
| | | |--- ...
| | |--- sunflower
| | | |--- NAME1.jpg
| | | |--- NAME2.jpg
| | | |--- ...
| | |--- tulip
| | | |--- NAME1.jpg
| | | |--- NAME2.jpg
| | | |--- ...
复制代码
创建并编辑标注文件将所有类别的名称写到 classes.txt
中,每行代表一个类别。
生成训练(可选)和验证子集标注列表 train.txt
和 val.txt
,每行应包含一个文件名和其对应的标签。样例:
...
daisy/NAME**.jpg 0
daisy/NAME**.jpg 0
...
dandelion/NAME**.jpg 1
dandelion/NAME**.jpg 1
...
rose/NAME**.jpg 2
rose/NAME**.jpg 2
...
sunflower/NAME**.jpg 3
sunflower/NAME**.jpg 3
...
tulip/NAME**.jpg 4
tulip/NAME**.jpg 4
复制代码
整理完成后,将处理好的数据集迁移到 mmclassification/data
文件夹下。
2. 构建模型微调的配置文件
使用 _base_
继承机制构建用于微调的配置文件,可以继承任何 MMClassification 提供的基于 ImageNet 的配置文件并进行修改。
对于新手不推荐这种方式,我后面会直接在文件上修改。
2. 环境搭建
由于我使用的是 Mac 电脑,内存不是很多,因此我将项目创建在一个固态的移动硬盘(T7)中。
前置环境:
Anaconda
Python:3.10.9
PyTorch:nightly
基础环境搭建
创建虚拟环境 MMCV
conda create -p /Volumes/T7/CodeSpace/OpenMMLab/mmcv_env python=3.10.9
复制代码
激活环境,并安装 PyTorch
conda activate /Volumes/T7/CodeSpace/OpenMMLab/mmcv_env
conda install pytorch torchvision torchaudio -c pytorch
复制代码
为了后续的操作方便我们再安装一个 jupyter-lab
为了后面代码的管理方便,我在当前目录下创建一个保存工程的目录 WorkSpace
mkdir WorkSpace
cd WorkSpace
复制代码
启动 jupyterlab
会自动打开一个浏览器,千万不要关闭当前这个命令行窗口,否则就会将 jupyter-lab 的环境结束。
OpenMMLab(mmcv)相关环境搭建
我们创建一个 Notebook
这个工具对后面很重要
在 Notebook 中执行命令,要在命令前面加!
-i 是切换为国内镜像进行安装,这样安装速度会快一些。
3. 开始真正的训练
mmclassification 已经为我们提供了很多训练好的模型,我们只要基于这些模型进行微调就可以完成我们这次的实战内容。
我们这次选择的预训练模型是MobileNetV2
下载预训练模型的相关配置文件等
既然要下载我们要知道,下载文件的名字,可以点击上面的链接进入到 GitHub
这里可以看到一些关于这个模型的信息,我们点击 config 就可以看到这个模型的名字。注意:下载时不需要带上.py 后缀。
下载命令如下:
mim download mmcls --config mobilenet-v2_8xb32_in1k --dest .
复制代码
下面这两个文件就是我们下载好的预训练模型(.pth)以及配置文件(.py)
新建一个 data 的文件,并将数据集保存在文件夹中,并解压。
flower_dataset 文件夹就是我们解压好的文件夹,接下来我们需要对这里的数据进行预处理,分成训练集(80%)和测试集(20%)
数据分割的 Python 脚本如下:
import os
import sys
import shutil
import numpy as np
def load_data(data_path):
count = 0
data = {}
for dir_name in os.listdir(data_path):
dir_path = os.path.join(data_path, dir_name)
if not os.path.isdir(dir_path):
continue
data[dir_name] = []
for file_name in os.listdir(dir_path):
file_path = os.path.join(dir_path, file_name)
if not os.path.isfile(file_path):
continue
data[dir_name].append(file_path)
count += len(data[dir_name])
print("{} :{}".format(dir_name, len(data[dir_name])))
print("total of image : {}".format(count))
return data
def copy_dataset(src_img_list, data_index, target_path):
target_img_list = []
for index in data_index:
src_img = src_img_list[index]
img_name = os.path.split(src_img)[-1]
shutil.copy(src_img, target_path)
target_img_list.append(os.path.join(target_path, img_name))
return target_img_list
def write_file(data, file_name):
if isinstance(data, dict):
write_data = []
for lab, img_list in data.items():
for img in img_list:
write_data.append("{} {}".format(img, lab))
else:
write_data = data
with open(file_name, "w") as f:
for line in write_data:
f.write(line + "\n")
print("{} write over!".format(file_name))
def split_data(src_data_path, target_data_path, train_rate=0.8):
src_data_dict = load_data(src_data_path)
classes = []
train_dataset, val_dataset = {}, {}
train_count, val_count = 0, 0
for i, (cls_name, img_list) in enumerate(src_data_dict.items()):
img_data_size = len(img_list)
random_index = np.random.choice(img_data_size, img_data_size,
replace=False)
train_data_size = int(img_data_size * train_rate)
train_data_index = random_index[:train_data_size]
val_data_index = random_index[train_data_size:]
train_data_path = os.path.join(target_data_path, "train", cls_name)
val_data_path = os.path.join(target_data_path, "val", cls_name)
os.makedirs(train_data_path, exist_ok=True)
os.makedirs(val_data_path, exist_ok=True)
classes.append(cls_name)
train_dataset[i] = copy_dataset(img_list, train_data_index,
train_data_path)
val_dataset[i] = copy_dataset(img_list, val_data_index, val_data_path)
print("target {} train:{}, val:{}".format(cls_name,
len(train_dataset[i]), len(val_dataset[i])))
train_count += len(train_dataset[i])
val_count += len(val_dataset[i])
print("train size:{}, val size:{}, total:{}".format(train_count, val_count,
train_count + val_count))
write_file(classes, os.path.join(target_data_path, "classes.txt"))
write_file(train_dataset, os.path.join(target_data_path, "train.txt"))
write_file(val_dataset, os.path.join(target_data_path, "val.txt"))
def main():
src_data_path = sys.argv[1]
target_data_path = sys.argv[2]
split_data(src_data_path, target_data_path, train_rate=0.8)
if __name__ == '__main__':
main()
复制代码
脚本运行方式:python split_data.py data/flower_dataset 目标文件夹
运行后我们看到在 data 目录下生成了两个文件夹,分别是训练集(train), 测试集(val), 以及两个标注文件(train.txt, val.txt)和一个分类信息文件(classes.txt)
修改配置文件
我们将之前下载好的配置文件(mobilenet-v2_8xb32_in1k.py)
复制一个出来(mobilenet-v2_8xb32_flower.py)
修改文件
我们的类别总共有 5 种,因此将之前的 1000 修改问 5
11 行~38 行,我们目前不需要,可以将其删除
修改训练集,测试集,验证集的文件路径
修改训练参数:
lr:预训练模型是在 8 卡上跑的,我们相应的改成一个的值,它影响模型训练的精度
max_epochs: 训练的轮次,由于我们的数据集不大,因此先跑 5 轮看看情况
interval:日志信息每 5 轮保存一次
load_from:就是我们之前下载好的预训练模型(注意路径)
开始训练
我们修改好配置文件后就可以开始训练了,运行命令如下:
mim train mmcls mobilenet-v2_8xb32_flower.py --work-dir work-dirs --gpus 0
复制代码
--work-dir: 运行结果保存的路径,不存在会自动创建
--gpus :0 代表不使用 GPU
我们可以看到下面的运行信息,这方便我们调试。对于 Warning 信息,我们可以忽略,但是报错信息我们需要调整,比如上面的长度不匹配,说明我们配置文件中有一个地方没有修改正确。
由于我们使用的是自定义数据集,因此需要将 ImageNet 改为 CustomDataset
我们再次运行一下
可以看到训练已经开始运行了
随着训练的次数的增加,我们可以看到 top1 的精度是有所提升的。
最终训练好的模型会保存为 latest.pth 文件。
4. 使用训练的模型进行验证
from mmcls.apis import init_model, inference_model
model = init_model("mobilenet-v2_8xb32_flower.py", "work-dirs/latest.pth", device="cpu")
result = inference_model(model, "ata/val/sunflower/1008566138_6927679c8a.jpg")
复制代码
今天的内容就是这些,如果对你有所帮助欢迎转发给你的朋友。
我是 Tango,一个热爱分享技术的程序猿,我们下期见
评论