写点什么

Swin Transformer 实战: timm 使用、Mixup、Cutout 和评分一网打尽,图像分类任务

作者:AI浩
  • 2022 年 5 月 23 日
  • 本文字数:7032 字

    阅读完需:约 23 分钟

摘要

本例提取了植物幼苗数据集中的部分数据做数据集,数据集共有 12 种类别,演示如何使用 timm 版本的 Swin Transformer 图像分类模型实现分类任务已经对验证集得分的统计,本文实现了多个 GPU 并行训练。


通过本文你和学到:


1、如何从 timm 调用模型、loss 和 Mixup?


2、如何制作 ImageNet 数据集?


3、如何使用 Cutout 数据增强?


4、如何使用 Mixup 数据增强。


5、如何实现多个 GPU 训练和验证。


6、如何使用余弦退火调整学习率?


7、如何使用 classification_report 实现对模型的评价。


8、预测的两种写法。

Swin Transformer 简介

目标检测刷到 58.7 AP!


实例分割刷到 51.1 Mask AP!


语义分割在 ADE20K 上刷到 53.5 mIoU!


今年,微软亚洲研究院的 Swin Transformer 又开启了吊打 CNN 的模式,在速度和精度上都有很大的提高。这篇文章带你实现 Swin Transformer 图像分类。

资料汇总

论文: https://arxiv.org/abs/2103.14030


代码: https://github.com/microsoft/Swin-Transformer


论文翻译:https://wanghao.blog.csdn.net/article/details/120724040


一些大佬的 B 站视频:


1、霹雳吧啦 Wzhttps://www.bilibili.com/video/BV1yg411K7Yc?from=search&seid=18074716460851088132&spm_id_from=333.337.0.0


2、ClimbingVision 社区震惊!这个关于Swin Transformer的论文分享讲得太透彻了!_哔哩哔哩_bilibili


关于 Swin Transformer 的资料有很多,在这里就不一一列举了,我觉得理解这个模型的最好方式:源码+论文。

数据增强 Cutout 和 Mixup

为了提高成绩我在代码中加入 Cutout 和 Mixup 这两种增强方式。实现这两种增强需要安装 torchtoolbox。安装命令:


pip install torchtoolbox
复制代码


Cutout 实现,在 transforms 中。


from torchtoolbox.transform import Cutout
# 数据预处理
transform = transforms.Compose([ transforms.Resize((224, 224)), Cutout(), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
复制代码


需要导入包:from timm.data.mixup import Mixup,


定义 Mixup,和 SoftTargetCrossEntropy


  mixup_fn = Mixup(    mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,    prob=0.1, switch_prob=0.5, mode='batch',    label_smoothing=0.1, num_classes=12)     criterion_train = SoftTargetCrossEntropy()
复制代码

项目结构

Swin_demo├─data│  ├─Black-grass│  ├─Charlock│  ├─Cleavers│  ├─Common Chickweed│  ├─Common wheat│  ├─Fat Hen│  ├─Loose Silky-bent│  ├─Maize│  ├─Scentless Mayweed│  ├─Shepherds Purse│  ├─Small-flowered Cranesbill│  └─Sugar beet├─mean_std.py├─makedata.py├─train.py├─test1.py└─test.py
复制代码


mean_std.py:计算 mean 和 std 的值。


makedata.py:生成数据集。

计算 mean 和 std

为了使模型更加快速的收敛,我们需要计算出 mean 和 std 的值,新建 mean_std.py,插入代码:


from torchvision.datasets import ImageFolderimport torchfrom torchvision import transforms
def get_mean_and_std(train_data): train_loader = torch.utils.data.DataLoader( train_data, batch_size=1, shuffle=False, num_workers=0, pin_memory=True) mean = torch.zeros(3) std = torch.zeros(3) for X, _ in train_loader: for d in range(3): mean[d] += X[:, d, :, :].mean() std[d] += X[:, d, :, :].std() mean.div_(len(train_data)) std.div_(len(train_data)) return list(mean.numpy()), list(std.numpy())

if __name__ == '__main__': train_dataset = ImageFolder(root=r'data1', transform=transforms.ToTensor()) print(get_mean_and_std(train_dataset))
复制代码


数据集结构:



运行结果:


([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])


把这个结果记录下来,后面要用!

生成数据集

我们整理还的图像分类的数据集结构是这样的


data├─Black-grass├─Charlock├─Cleavers├─Common Chickweed├─Common wheat├─Fat Hen├─Loose Silky-bent├─Maize├─Scentless Mayweed├─Shepherds Purse├─Small-flowered Cranesbill└─Sugar beet
复制代码


pytorch 和 keras 默认加载方式是 ImageNet 数据集格式,格式是


├─data│  ├─val│  │   ├─Black-grass│  │   ├─Charlock│  │   ├─Cleavers│  │   ├─Common Chickweed│  │   ├─Common wheat│  │   ├─Fat Hen│  │   ├─Loose Silky-bent│  │   ├─Maize│  │   ├─Scentless Mayweed│  │   ├─Shepherds Purse│  │   ├─Small-flowered Cranesbill│  │   └─Sugar beet│  └─train│      ├─Black-grass│      ├─Charlock│      ├─Cleavers│      ├─Common Chickweed│      ├─Common wheat│      ├─Fat Hen│      ├─Loose Silky-bent│      ├─Maize│      ├─Scentless Mayweed│      ├─Shepherds Purse│      ├─Small-flowered Cranesbill│      └─Sugar beet
复制代码


新增格式转化脚本 makedata.py,插入代码:


import globimport osimport shutil
image_list=glob.glob('data1/*/*.png')print(image_list)file_dir='data'if os.path.exists(file_dir): print('true') #os.rmdir(file_dir) shutil.rmtree(file_dir)#删除再建立 os.makedirs(file_dir)else: os.makedirs(file_dir)
from sklearn.model_selection import train_test_split
trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir='train'val_dir='val'train_root=os.path.join(file_dir,train_dir)val_root=os.path.join(file_dir,val_dir)for file in trainval_files: file_class=file.replace("\\","/").split('/')[-2] file_name=file.replace("\\","/").split('/')[-1] file_class=os.path.join(train_root,file_class) if not os.path.isdir(file_class): os.makedirs(file_class) shutil.copy(file, file_class + '/' + file_name)
for file in val_files: file_class=file.replace("\\","/").split('/')[-2] file_name=file.replace("\\","/").split('/')[-1] file_class=os.path.join(val_root,file_class) if not os.path.isdir(file_class): os.makedirs(file_class) shutil.copy(file, file_class + '/' + file_name)
复制代码

训练

完成上面的步骤后,就开始 train 脚本的编写,新建 train.py.

导入项目使用的库

import torchimport torch.nn as nnimport torch.nn.parallelimport torch.optim as optimimport torch.utils.dataimport torch.utils.data.distributedimport torchvision.datasets as datasetsimport torchvision.transforms as transformsfrom sklearn.metrics import classification_reportfrom timm.data.mixup import Mixupfrom timm.loss import SoftTargetCrossEntropyfrom timm.models import swin_small_patch4_window7_224from torchtoolbox.transform import Cutout
复制代码

设置全局参数

设置学习率、BatchSize、epoch 等参数,判断环境中是否存在 GPU,如果没有则使用 CPU。建议使用 GPU,CPU 太慢了。


# 设置全局参数model_lr = 1e-4BATCH_SIZE = 4EPOCHS = 1000DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
复制代码

图像预处理与增强

数据处理比较简单,加入了 Cutout、做了 Resize 和归一化,定义 Mixup 函数。


# 数据预处理7transform = transforms.Compose([    transforms.Resize((224, 224)),    Cutout(),    transforms.ToTensor(),    transforms.Normalize(mean=[0.51819474, 0.5250407, 0.4945761], std=[0.24228974, 0.24347611, 0.2530049])
])transform_test = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.51819474, 0.5250407, 0.4945761], std=[0.24228974, 0.24347611, 0.2530049])])mixup_fn = Mixup( mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None, prob=0.1, switch_prob=0.5, mode='batch', label_smoothing=0.1, num_classes=12)
复制代码

读取数据

使用 pytorch 默认读取数据的方式,然后将 dataset_train.class_to_idx 打印出来,预测的时候要用到。


# 读取数据dataset_train = datasets.ImageFolder('data/train', transform=transform)dataset_test = datasets.ImageFolder("data/val", transform=transform_test)print(dataset_train.class_to_idx)# 导入数据train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
复制代码


class_to_idx 的结果:


{'Black-grass': 0, 'Charlock': 1, 'Cleavers': 2, 'Common Chickweed': 3, 'Common wheat': 4, 'Fat Hen': 5, 'Loose Silky-bent': 6, 'Maize': 7, 'Scentless Mayweed': 8, 'Shepherds Purse': 9, 'Small-flowered Cranesbill': 10, 'Sugar beet': 11}

设置模型

  • 设置 loss 函数,train 的 loss 为:SoftTargetCrossEntropy,val 的 loss:nn.CrossEntropyLoss()。

  • 设置模型为 swin_small_patch4_window7_224,预训练设置为 true,num_classes 设置为 12。

  • 检测可用显卡的数量,如果大于 1,则要用 torch.nn.DataParallel 加载模型,开启多卡训练。

  • 优化器设置为 adam。

  • 学习率调整策略选择为余弦退火。


# 实例化模型并且移动到GPUcriterion_train = SoftTargetCrossEntropy()criterion_val = torch.nn.CrossEntropyLoss()model_ft = swin_small_patch4_window7_224(pretrained=True)print(model_ft)num_ftrs = model_ft.head.in_featuresmodel_ft.head = nn.Linear(num_ftrs, 12)model_ft.to(DEVICE)print(model_ft)
if torch.cuda.device_count() > 1: print("Let's use", torch.cuda.device_count(), "GPUs!") model_ft = torch.nn.DataParallel(model_ft)print(model_ft)# 选择简单暴力的Adam优化器,学习率调低optimizer = optim.Adam(model_ft.parameters(), lr=model_lr)cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)
复制代码

定义训练和验证函数

定义训练函数和验证函数,在一个 epoch 完成后,使用 classification_report 计算详细的得分情况。


# 定义训练过程def train(model, device, train_loader, optimizer, epoch):    model.train()    sum_loss = 0    total_num = len(train_loader.dataset)    print(total_num, len(train_loader))    for batch_idx, (data, target) in enumerate(train_loader):        data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)        samples, targets = mixup_fn(data, target)        optimizer.zero_grad()        output = model(samples)        loss = criterion_train(output, targets)        loss.backward()        optimizer.step()        lr = optimizer.state_dict()['param_groups'][0]['lr']        print_loss = loss.data.item()        sum_loss += print_loss        if (batch_idx + 1) % 10 == 0:            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLR:{:.9f}'.format(                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),                       100. * (batch_idx + 1) / len(train_loader), loss.item(), lr))    ave_loss = sum_loss / len(train_loader)    print('epoch:{},loss:{}'.format(epoch, ave_loss))

ACC = 0

# 验证过程def val(model, device, test_loader): global ACC model.eval() test_loss = 0 correct = 0 total_num = len(test_loader.dataset) print(total_num, len(test_loader)) val_list = [] pred_list = [] with torch.no_grad(): for data, target in test_loader: for t in target: val_list.append(t.data.item()) data, target = data.to(device), target.to(device) output = model(data) loss = criterion_val(output, target) _, pred = torch.max(output.data, 1) for p in pred: pred_list.append(p.data.item()) correct += torch.sum(pred == target) print_loss = loss.data.item() test_loss += print_loss correct = correct.data.item() acc = correct / total_num avgloss = test_loss / len(test_loader) print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( avgloss, correct, len(test_loader.dataset), 100 * acc)) if acc > ACC: torch.save(model_ft, 'model_' + str(epoch) + '_' + str(round(acc, 3)) + '.pth') ACC = acc return val_list, pred_list

# 训练
for epoch in range(1, EPOCHS + 1): train(model_ft, DEVICE, train_loader, optimizer, epoch) cosine_schedule.step() val_list, pred_list = val(model_ft, DEVICE, test_loader) print(classification_report(val_list, pred_list, target_names=dataset_train.class_to_idx))
复制代码


运行结果:


测试

我介绍两种常用的测试方式,第一种是通用的,通过自己手动加载数据集然后做预测,具体操作如下:


测试集存放的目录如下图:



第一步 定义类别,这个类别的顺序和训练时的类别顺序对应,一定不要改变顺序!!!!


第二步 定义 transforms,transforms 和验证集的 transforms 一样即可,别做数据增强。


第三步 加载 model,并将模型放在 DEVICE 里,


第四步 读取图片并预测图片的类别,在这里注意,读取图片用 PIL 库的 Image。不要用 cv2,transforms 不支持。


import torch.utils.data.distributedimport torchvision.transforms as transformsfrom PIL import Imagefrom torch.autograd import Variableimport osclasses = ('Black-grass', 'Charlock', 'Cleavers', 'Common Chickweed',           'Common wheat','Fat Hen', 'Loose Silky-bent',           'Maize','Scentless Mayweed','Shepherds Purse','Small-flowered Cranesbill','Sugar beet')transform_test = transforms.Compose([         transforms.Resize((224, 224)),        transforms.ToTensor(),       transforms.Normalize(mean=[0.51819474, 0.5250407, 0.4945761], std=[0.24228974, 0.24347611, 0.2530049])]) DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = torch.load("model.pth")model.eval()model.to(DEVICE) path='data/test/'testList=os.listdir(path)for file in testList:        img=Image.open(path+file)        img=transform_test(img)        img.unsqueeze_(0)        img = Variable(img).to(DEVICE)        out=model(img)        # Predict        _, pred = torch.max(out.data, 1)        print('Image Name:{},predict:{}'.format(file,classes[pred.data.item()]))
复制代码


运行结果:



第二种 使用自定义的 Dataset 读取图片


import torch.utils.data.distributedimport torchvision.transforms as transformsfrom dataset.dataset import SeedlingDatafrom torch.autograd import Variable classes = ('Black-grass', 'Charlock', 'Cleavers', 'Common Chickweed',           'Common wheat','Fat Hen', 'Loose Silky-bent',           'Maize','Scentless Mayweed','Shepherds Purse','Small-flowered Cranesbill','Sugar beet')transform_test = transforms.Compose([    transforms.Resize((224, 224)),    transforms.ToTensor(),    transforms.Normalize(mean=[0.51819474, 0.5250407, 0.4945761], std=[0.24228974, 0.24347611, 0.2530049])]) DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = torch.load("model.pth")model.eval()model.to(DEVICE) dataset_test =SeedlingData('data/test/', transform_test,test=True)print(len(dataset_test))# 对应文件夹的label for index in range(len(dataset_test)):    item = dataset_test[index]    img, label = item    img.unsqueeze_(0)    data = Variable(img).to(DEVICE)    output = model(data)    _, pred = torch.max(output.data, 1)    print('Image Name:{},predict:{}'.format(dataset_test.imgs[index], classes[pred.data.item()]))    index += 1
复制代码


运行结果:



代码:https://download.csdn.net/download/hhhhhhhhhhwwwwwwwwww/81895764

用户头像

AI浩

关注

还未添加个人签名 2021.11.08 加入

还未添加个人简介

评论

发布
暂无评论
Swin Transformer实战: timm使用、Mixup、Cutout和评分一网打尽,图像分类任务_AI浩_InfoQ写作社区