写点什么

SWA 实战:使用 SWA 进行微调,提高模型的泛化

作者:AI浩
  • 2022 年 5 月 23 日
  • 本文字数:3392 字

    阅读完需:约 11 分钟

摘要

论文链接:https://arxiv.org/abs/1803.05407.pdf


官方代码:https://github.com/timgaripov/swa


论文翻译:【第32篇】SWA:平均权重导致更广泛的最优和更好的泛化_AI浩的博客-CSDN博客


SWA 简单来说就是对训练过程中的多个 checkpoints 进行平均,以提升模型的泛化性能。记训练过程第个 epoch 的 checkpoint 为,一般情况下我们会选择训练过程中最后的一个 epoch 的模型或者在验证集上效果最好的一个模型作为最终模型。但 SWA 一般在最后采用较高的固定学习速率或者周期式学习速率额外训练一段时间,取多个 checkpoints 的平均值。


pytorch 使用举例:


from torch.optim.swa_utils import AveragedModel, SWALR# 采用SGD优化器optimizer = torch.optim.SGD(model.parameters(),lr=1e-4, weight_decay=1e-3, momentum=0.9)# 随机权重平均SWA,实现更好的泛化swa_model = AveragedModel(model).to(device)# SWA调整学习率swa_scheduler = SWALR(optimizer, swa_lr=1e-6)for epoch in range(1, epoch + 1):    for batch_idx, (data, target) in enumerate(train_loader):           data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)        # 在反向传播前要手动将梯度清零        optimizer.zero_grad()        output = model(data)        #计算losss        loss = train_criterion(output, targets)        # 反向传播求解梯度        loss.backward()        optimizer.step()        lr = optimizer.state_dict()['param_groups'][0]['lr']       swa_model.update_parameters(model)    swa_scheduler.step()# 最后更新BN层参数torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)# 保存结果torch.save(swa_model.state_dict(), "last.pt")
复制代码


上面的代码展示了 SWA 的主要代码,实现的步骤:


1、定义 SGD 优化器。


2、定义 SWA。


3、定义 SWALR,调整模型的学习率。


4、开始训练,等待训练完成。


5、在每个 epoch 中更新模型的参数,更新学习率。


6、等待训练完成后,更新 BN 层的参数。

详细实现过程

环境

pyotrch:1.10

准备

在开始今天的代码前,我们要准备好训练好的模型。然后才能开始今天的代码。

实现过程

定义模型,并将训练好的模型载入,代码如下:


    model_ft = efficientnet_b1(pretrained=True)    print(model_ft)    num_ftrs = model_ft.classifier.in_features    model_ft.classifier = nn.Linear(num_ftrs, classes)    model_ft.to(DEVICE)    model_ft = torch.load(model_path)    print(model_ft)    fine_epoch = 80    fine_tune(model_ft, DEVICE, train_loader, test_loader, criterion_train, criterion_val, fine_epoch, mixup_fn,              use_amp)
复制代码


定义模型为 efficientnet_b1,这里要和训练的模型保持一致。


如果保存的整个模型,则使用 torch.load(model_path)载入模型,如果只保存了权重信息,则要使用 model_ft=load_state_dict(torch.load(model_path)),载入模型。


然后,设置 fine 的 epoch 为 80。


接下来,我们一起去看 fine_tune 函数中的内容。


 # 采用SGD优化器    optimizer = torch.optim.SGD(model.parameters(),lr=1e-4, weight_decay=1e-3, momentum=0.9)    if use_amp:        model, optimizer = amp.initialize(model_ft, optimizer, opt_level="O1")  # 这里是“欧一”,不是“零一”
复制代码


定义优化器为 SGD。


如果使用混合精度,则对 amp 初始化。


 # 随机权重平均SWA,实现更好的泛化 swa_model = AveragedModel(model).to(device) # SWA调整学习率 swa_scheduler = SWALR(optimizer, swa_lr=1e-6)
复制代码


初始化 SWA。


使用 SWALR 调整学习率。


接下来循环 epoch,这里都是比较通用的逻辑。


 for epoch in range(1, epoch + 1):        model.train()        train_loss = 0        total_num = len(train_loader.dataset)        print(total_num, len(train_loader))        for batch_idx, (data, target) in enumerate(train_loader):            if len(data) % 2 != 0:                print(len(data))                data = data[0:len(data) - 1]                target = target[0:len(target) - 1]                print(len(data))            data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)            samples, targets = mixup_fn(data, target)            output = model(samples)            loss = train_criterion(output, targets)            optimizer.zero_grad()            if use_amp:                with amp.scale_loss(loss, optimizer) as scaled_loss:                    scaled_loss.backward()            else:                loss.backward()            optimizer.step()            lr = optimizer.state_dict()['param_groups'][0]['lr']            print_loss = loss.data.item()            train_loss += print_loss            if (batch_idx + 1) % 10 == 0:                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLR:{:.9f}'.format(                    epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),                           100. * (batch_idx + 1) / len(train_loader), loss.item(), lr))        swa_model.update_parameters(model)        swa_scheduler.step()
复制代码


主要步骤有:


1、计算 loss。

2、是否使用 amp 混合精度,如果使用混合精度则使用 scaled_loss 反向传播求梯度,否则直接 loss 反向传播求梯度。

3、 swa_model.update_parameters(model)更新 swa_model 的参数。

4、 swa_scheduler.step()更新学习率。


等待所有的 epoch 执行完成后。


torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)torch.save(swa_model.state_dict(), "last.pt")
复制代码


更新 BN 层参数。


然后保存模型的权重。注意:这里只能保存模型的权重,不能保存整个模型。


完成之后就可以测试了,执行代码:


import torch.utils.data.distributedimport torchvision.transforms as transformsfrom PIL import Imagefrom torch.autograd import Variableimport osfrom torchvision.models.mobilenetv3 import mobilenet_v3_largeimport torch.nn as nnfrom torch.optim.swa_utils import AveragedModel, SWALRfrom timm.models.efficientnet import efficientnet_b1import numpy as np
def show_outputs(output):
output_sorted = sorted(output, reverse=True) top5_str = '-----TOP 5-----\n' for i in range(5): value = output_sorted[i] index = np.where(output == value) for j in range(len(index)): if (i + j) >= 5: break if value > 0: topi = '{}: {}\n'.format(index[j], value) else: topi = '-1: 0.0\n' top5_str += topi print(top5_str)
transform_test = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = efficientnet_b1(pretrained=True)
num_ftrs = model.classifier.in_featuresmodel.classifier = nn.Linear(num_ftrs, 8)swa_model = AveragedModel(model)swa_model.load_state_dict(torch.load("last.pt"))swa_model.to(DEVICE)swa_model.eval()
path = 'test/'testList = os.listdir(path)for file in testList: img = Image.open(path + file) img = transform_test(img) img.unsqueeze_(0) img = Variable(img).to(DEVICE) out = swa_model(img) out = out.data.cpu().numpy()[0] print(file) show_outputs(out)
复制代码


这里测试代码和以前的写法没有啥区别,唯一不同的地方:


重新定义模型,然后载入权重。运行结果:



完整代码:https://download.csdn.net/download/hhhhhhhhhhwwwwwwwwww/85223146

发布于: 刚刚阅读数: 2
用户头像

AI浩

关注

还未添加个人签名 2021.11.08 加入

还未添加个人简介

评论

发布
暂无评论
SWA实战:使用SWA进行微调,提高模型的泛化_AI浩_InfoQ写作社区