写点什么

PyTorch 实现 FCN 网络用于图像语义分割

  • 2025-05-24
    上海
  • 本文字数:12760 字

    阅读完需:约 42 分钟

本文主要介绍了如何在昇腾上,使用 pytorch 对图像分割任务的开山之作 FCN 网络在 VOC2012 数据集上进行训练的实战过程讲解。主要内容包括 FCN 网络的创新点分析、FCN 网络架构分析,实战训练代码分析等等。


本文的目录结构安排如下所示:

  • FCN 创新点分析

  • FCN 网络架构分析

  • FCN 网络搭建过程及代码详解

  • 端到端训练 Voc2012 数据集全过程分析

FCN(Fully Convolutional Networks)网络创新点分析

  • 采用全卷积结构替换全连接层


FCN 将传统分类网络中的全连接层替换为卷积层,使得网络能够接受任意尺寸的输入图像,而不需要固定输入大小。这种设计使得 FCN 能够直接对图像进行端到端的像素级预测,适用于语义分割任务。


  • 使用多尺度特征融合的思想


通过改进的上采样和下采样技术及高效的跳跃连接结构,FCN 能够整合网络中不同层的特有特征。采用多尺度特征融合可以显著地提高模型在复杂场景分析中的精度和鲁棒性。


  • 自适应和动态卷积核


FCN 中的卷积核可以根据输入数据的特性动态调整其大小和形状,从而更有效地提取特征。这种自适应能力使得 FCN 在多种不同类型的图像处理任务中都能表现出色。


  • 跳跃连接(跟第二点结合)


FCN 使用了类似 ResNet 的跳跃连接结构,将深层的粗糙语义信息和浅层的精细表征信息融合,以实现更加精细的语义分割。这种跳跃连接结构在上采样过程中融合了不同维度的特征,保留了更多细节,帮助模型更精细地重建图像信息。


  • 使用反卷积操作用于上采样


FCN 使用反卷积层进行上采样,将最后一个卷积层的特征图恢复到与输入图像相同的尺寸。反卷积层通过不同尺度的上采样操作,保留了原始输入图像的空间信息,使得网络对每个像素都能实现高效的类别预测。

FCN 网络架构分析

从网络的总体架构图中可以看出 FCN 的结构非常简单。首先,输入图片经过若干个卷积层实现特征提取后,再通过反卷积操作将图像大小还原到指定大小实现像素级别的类别预测。最后输出的 channel 维度是 21,这是因为论文中使用的 pascal voc 数据集总共 20 个类别,加上背景一起总共 21 个类别。


根据这 21 个值进行 softmax 处理就能得到图像中每个像素属于这 21 个类别的概率值,取最大的那个值作为该像素最终的类别预测结果,这样就可以得到整张图所有像素点的类别预测情况,然后,每个类别用不同的颜色区分从而整张图片的背景与类别被清晰的分割开(例如:图中的猫、狗及背景分别用蓝色、棕色与绿色区分)。

上图通过使用全连接层得到最终的维度为 1000 的向量,由于全连接层要求的输入大小必须是固定的,因此作者将网络中的全连接层转换为卷积层,输入图像的大小可以是任意的。 那么最后的输出就不是一个一维向量了,就变成了(m,n,c),对应每个 channel 就是一个 2D 的数据,可以可视化成一个 heatmap 图。


图中将全连接层全部替换成了卷积层,其中全计算量与卷积层的计算量分别为:全连接是 25088 × 4096 = 102760448,卷积的计算量是 7 × 7 × 512 × 4096 = 102760448,可以看到他们的计算量是一模一样的相当于把全连接的权重进行了 reshape 操作。


论文中 FCN 网络有三种模式的模型,分别是 FCN-32s,16s,8s,其中数字的含义是将最后得到的特征图通过上采样多少倍后能够恢复到原图尺寸的大小。图中省略了卷积层与其他层级信息,只保留了池化层用于展示多尺度特征融合的过程。


整个网络第一步通过将特征图上采样 32 倍得到原图大小的输出,此时得到的是 FCN-32s 模型。然后将该特征图进行 2 倍的上采样与 pool4 层的特征图进行结合得到 FCN-16s 模型,此时的网络能够预测更精细的细节,同时保留高级语义信息。同理,将得到的 FCN-16s 模型进行 2 倍的上采样后与 pool3 的特征图进行融合得到 FCN-8s 模型,该模型可以得到更加精准的预测。


除此以外,从上述的分析可以发现,FCN 网络在结合不同尺度特征信息的过程中,还可以继续往深层次的继续结合得到 FCN-4s,2s 模型,这里可以根据需要结合前面 pool2 与 pool1 层的信息即可。

FCN 网络搭建过程及代码详解

基于 torch 搭建 FCN 网络,需要导入 torch 相关模块,其中 nn.Module 是各个神经网络模型需要继承的基类。

import torchimport torch.nn as nn
复制代码


由于 FCN 网络采用的是全类卷积层操作,论文中分别使用 Alexnet、VGG16 与 GoogleNet 网络作为 backone 后用 VOC 数据集进行微调对比,得到 FCN-VGG16 的 mean IU 最高,IU 与目标检测模型中的 IOU 意思一样,用来反映模型预测与框定的效果好坏。


目前 PyTorch 官方实现中使用 ResNet-50 作为 backbone‌,原始论文中提出的 FCN 使用的是 VGG16 作为 backbone,但在 PyTorch 的官方实现中,由于 ResNet-50 在性能上有更好的表现,因此一般都会选择 ResNet-50 作为 backbone 后用数据做微调。本文的实现不采用任何 backbone,从零到一搭建一个 FCN8s 网络模型。


整个 FCN8s 网络模型通过一个 FCN8 类来实现,其中 FCN8 类中继承了'nn.Module'模块,网络总共包含两部分,前一部分对图像输入进行特征提取并不断降维,后一部分通过对得到的特征图进行不同倍率的上采样,从而融合不同尺度特征得到 FCN8s 模型。


前一部分总共包含 5 个 stage,每个 stage 最后都用一个'Maxpool'操作用于特征图提取与降维,对应类中'nn.MaxPool2d(kernel_size=2,padding=0)'。stage1、stage2 与 stage5 均定义了一层'Conv2d'、'Relu'与'BatchNorm2d'操作组合并结合'Maxpool'操作。stage3 与 stage4 分别定义了三层与两层 Conv2d、Relu 与 BatchNorm2d 操作组合并结合 Maxpool 操作。


后一部分定义了 upsample_2、upsample_4、upsample_81 与 upsample_82,也就是 2、4 与 8 三种不同倍率的下采样。VOC 数据集图片的输入后本文会将其裁剪到 224x224,因此网络的输入 size 是 224x224。


class FCN8(nn.Module):     def __init__(self, num_classes):        # 调用super方法调用父类nn.Module的初始化函数        super(FCN8, self).__init__()          '''        定义stage1, Conv2d中in_channels=3与输入图像3通道相对应,out_channels=96表示输出的通道维度是96,kernel=3表示卷积核大小是3x3,对输入图像padding=1。        根据卷积的size计算公式output= ((i + 2p -k) /s + 1),i表示输入图像的尺寸,p表示padding,k表示卷积核大小,s表示步长。        假设batch=1的情况下输入图像为224x224x3,通过'Conv2d'后输出的size为 (224 + 2 -3)/1 +1 = 224,因此conv2d后输出图像为224x224x96。        BatchNorm2d输入前后不改变图像size大小,通过MaxPool2d操作降维后得到最终输出图像大小为112x112x96。        '''        self.stage1 = nn.Sequential(            nn.Conv2d(in_channels=3, out_channels=96, kernel_size=3, padding=1),            nn.ReLU(),            nn.BatchNorm2d(num_features=96),            nn.MaxPool2d(kernel_size=2, padding=0)        )        '''        定义stage2, Conv2d中in_channels=96与stage1中输出维度相对应,out_channels=256表示输出的通道维度是256。同理通过Conv2d后得到输出size为112x112x256        通过'MaxPool2d'操作后变为56x56x256是stage2的最终输出。        '''        self.stage2 = nn.Sequential(            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=3, padding=1),            nn.ReLU(),            nn.BatchNorm2d(num_features=256),            nn.MaxPool2d(kernel_size=2, padding=0)         )        # 定义stage3, 假设batch=1,input = 56x56x256,则output= 28x28x256。        self.stage3 = nn.Sequential(            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1),            nn.ReLU(),            nn.BatchNorm2d(num_features=384),
nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1), nn.ReLU(), nn.BatchNorm2d(num_features=384),
nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1), nn.ReLU(), nn.BatchNorm2d(num_features=256),
nn.MaxPool2d(kernel_size=2, padding=0) ) # 定义stage4, 假设batch=1,input = 28x28x256,则output= 14x14x512。 self.stage4 = nn.Sequential( nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1), nn.ReLU(), nn.BatchNorm2d(num_features=512),
nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1), nn.ReLU(), nn.BatchNorm2d(num_features=512),
nn.MaxPool2d(kernel_size=2, padding=0) ) # 定义stage5, 假设batch=1,input = 14x14x512,则output= 7x7xnum_classes。 self.stage5 = nn.Sequential( nn.Conv2d(in_channels=512, out_channels=num_classes, kernel_size=3, padding=1), nn.ReLU(), nn.BatchNorm2d(num_features=num_classes), nn.MaxPool2d(kernel_size=2,padding=0) ) ''' 定义2倍率的上采样过程,可以分别得到一个2倍率上采样的特征图便于做特征融合。 上采样过程 out_size = (i -1)*S-2P + k,其中i、S、P与k分别表示输入图像size,步长、padding与卷积核大小。 upsample_2结合的是stage4的输出特征图,因此input = 14x14x512,output = 28x28x512。 ''' self.upsample_2 = nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=4, padding= 1, stride=2) ''' 定义4倍率的上采样过程,可以分别得到一个4倍率上采样的特征图便于做特征融合。 upsample_4结合的是stage5的输出特征图,因此input = 7x7xnum_classes,output = 28x28xnum_classes。 ''' self.upsample_4 = nn.ConvTranspose2d(in_channels=num_classes, out_channels=num_classes, kernel_size=4, padding= 0,stride=4)
''' 下述upsample_81与upsample_82的作用是将上述特征融合后的图像分别通过4倍与2的上采样将图像还原成原始输入大小(28 * 4 * 2 = 224)。 其中in_channels与out_channels中512 + num_classes + 256表示三个不同维度的channel进行拼接得到。 ''' self.upsample_81 = nn.ConvTranspose2d(in_channels=512 + num_classes + 256, out_channels=512 + num_classes + 256, kernel_size=4, padding= 0,stride=4) self.upsample_82 = nn.ConvTranspose2d(in_channels=512 + num_classes + 256, out_channels=512 + num_classes + 256, kernel_size=4, padding= 1,stride=2)
# 最后的预测模块,input:224x224x(512 + num_classes + 256), output:224x224xnum_classes。 self.final = nn.Sequential( nn.Conv2d(512 + num_classes + 256, num_classes, kernel_size=7, padding=3), )
def forward(self, x): x = x.float() # conv1->pool1->输出 x = self.stage1(x) # conv2->pool2->输出 x = self.stage2(x) # conv3->pool3->输出, 经过上采样后, 需要用pool3暂存 x = self.stage3(x) pool3 = x # conv4->pool4->输出, 经过上采样后, 需要用pool4暂存 x = self.stage4(x) pool4 = self.upsample_2(x) x = self.stage5(x) conv7 = self.upsample_4(x)
# 对所有上采样过的特征图进行concat, 在channel维度上进行叠加 x = torch.cat([pool3, pool4, conv7], dim = 1)
# 经过一个分类网络,输出结果(这里采样到原图大小,分别一次2倍一次4倍上采样来实现8倍上采样) output = self.upsample_81(x) output = self.upsample_82(output) output = self.final(output)
return output
复制代码


将网络模型结构进行打印,可以看到网络的整体结构与上述描述相一致。至此,FCN8s 网络架构全部搭建完成,接下来将用该网络来介绍如何训练 VOC 数据集。


print(FCN8(21))
复制代码


FCN8(  (stage1): Sequential(    (0): Conv2d(3, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))    (1): ReLU()    (2): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)  )  (stage2): Sequential(    (0): Conv2d(96, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))    (1): ReLU()    (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)  )  (stage3): Sequential(    (0): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))    (1): ReLU()    (2): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)    (3): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))    (4): ReLU()    (5): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)    (6): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))    (7): ReLU()    (8): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)  )  (stage4): Sequential(    (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))    (1): ReLU()    (2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)    (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))    (4): ReLU()    (5): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)  )  (stage5): Sequential(    (0): Conv2d(512, 21, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))    (1): ReLU()    (2): BatchNorm2d(21, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)  )  (upsample_2): ConvTranspose2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))  (upsample_4): ConvTranspose2d(21, 21, kernel_size=(4, 4), stride=(4, 4))  (upsample_81): ConvTranspose2d(789, 789, kernel_size=(4, 4), stride=(4, 4))  (upsample_82): ConvTranspose2d(789, 789, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))  (final): Sequential(    (0): Conv2d(789, 21, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))  ))
复制代码

端到端训练 Voc2012 数据集全过程分析

VOC 数据集介绍

VOC 数据集,全称 Visual Object Classes,是一个广泛使用的计算机视觉数据集,主要用于目标检测、图像分割和图像分类等任务。该数据集最初由英国牛津大学的计算机视觉小组创建,并在 PASCAL VOC 挑战赛中使用。VOC 数据集包含了大量不同类别的标记图像,每个图像都有与之相关联的边界框(bounding box)和对象类别的标签。


VOC 数据集在类别上可以分为 4 大类,20 小类,涵盖了人、汽车、猫、狗等常见目标类别。此外,VOC 数据集还提供了用于图像分割任务的像素级标注,该数据集分为 21 类,其中 20 类为前景物体,1 类为背景。数据集量级方面,VOC2007 和 VOC2012 是两个最流行的版本,分别包含了约 10000 张和 20000 张标注图像,本文采用 VOC2012 数据集。下载地址:http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar

VOC 数据集结构

VOC 数据集的结构相对复杂,但非常有序。它主要包含以下几个文件夹:


ImageSets:包含三个子文件夹(Layout、Main、Segmentation),用于存放不同数据集划分(训练集、验证集、测试集)的文件名列表。


JPEGImages:存放所有的图片,包括训练、验证和测试用到的所有图片。


SegmentationClass:包含已经标注好的图像。


Annotations:存放每张图片相关的标注信息,以 XML 格式的文件存储。这些文件包含了图像中每个目标的类别、边界框坐标等详细信息。


SegmentationObject:文件夹中包含实例分割用到的标签图像。


其中本文实验需要用到的 3 个文件夹均已标粗。

如图所示,图像分割任务需要将图中的物体与物体间,物体与背景间信息区分开来,不同物体标记不同颜色,本文实验用到的 VOC 数据集总共包含 21 种类别,'VOC_COLORMAP'定义了每一个类别的颜色信息,包含 RGB 三个,'VOC_CLASSES'对应数据集中 21 个类别。

VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],                [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],                [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],                [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],                [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],                [0, 64, 128]]
VOC_CLASSES = ['background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor']
# 定义一个一维向量colormap2label,含有256^3个元素,目的是为了让三通道图像的每一点像素特征都有所对应类别所对应。colormap2label = torch.zeros(256 ** 3, dtype=torch.uint8)
# 给包含类别的物体赋予颜色标签,不属于类别内的rgb是全为0,也就是整个图片中除了背景与物体以外的颜色为全黑。for i, colormap in enumerate(VOC_COLORMAP): colormap2label[(colormap[0] * 256 + colormap[1]) * 256 + colormap[2]] = i
复制代码

VOC 数据集 patch

为了方便整个训练过程快速有效的进行,我们对输入的数据按照一定的 patch 输入进行训练,因此需要定义一个数据个数转换类'VOCSegDataset',用来生成每一个批次送给网络所需要的数据。


在定义该类以前,我们需要定义一些对于文件及标签处理操作函数,分别是'voc_label_indices'、'read_file_list'与'voc_rand_crop'。


import numpy as npdef voc_label_indices(colormap):    """    convert colormap (PIL image) to colormap2label (uint8 tensor).    """    colormap = np.array(colormap).astype('int32')    idx = ((colormap[:, :, 0] * 256 + colormap[:, :, 1]) * 256           + colormap[:, :, 2])    return colormap2label[idx]
复制代码


# 针对于VOC2012数据集读取训练与验证集文件返回训练集或验集所有图片路径及标签def read_file_list(root, is_train=True):    txt_fname = root + '/ImageSets/Segmentation/' + ('train.txt' if is_train else 'val.txt')    with open(txt_fname, 'r') as f:        filenames = f.read().split()    images = [os.path.join(root, 'JPEGImages', i + '.jpg') for i in filenames]    labels = [os.path.join(root, 'SegmentationClass', i + '.png') for i in filenames]    return images, labels  # file list
复制代码


# 对输入的VOC图片进行裁剪到指定的height与widthdef voc_rand_crop(image, label, height, width):    """    Random crop image (PIL image) and label (PIL image).    """    i, j, h, w = transforms.RandomCrop.get_params(        image, output_size=(height, width))
image = transforms.functional.crop(image, i, j, h, w) label = transforms.functional.crop(label, i, j, h, w)
return image, label
复制代码


数转换类'VOCSegDataset',继承'torch.utils.data.Dataset'用于迭代送入图片给模型进行训练及验证,


class VOCSegDataset(torch.utils.data.Dataset):    def __init__(self, is_train, crop_size, voc_root):        """        crop_size: (h, w)        """        self.transform = transforms.Compose([            transforms.ToTensor(),            #transforms.Normalize(mean=self.rgb_mean, std=self.rgb_std)        ])        # (h, w)        self.crop_size = crop_size        images, labels = read_file_list(root=voc_root, is_train=is_train)        # images list        self.images = self.filter(images)         # labels list        self.labels = self.filter(labels)        print('Read ' + str(len(self.images)) + ' valid examples')    # 过滤掉尺寸小于crop_size的图片    def filter(self, imgs):        return [img for img in imgs if (                Image.open(img).size[1] >= self.crop_size[0] and                Image.open(img).size[0] >= self.crop_size[1])]     def __getitem__(self, idx):        image = self.images[idx]        label = self.labels[idx]        image = Image.open(image).convert('RGB')        label = Image.open(label).convert('RGB')        image, label = voc_rand_crop(image, label, *self.crop_size)        image = self.transform(image)        label = voc_label_indices(label)        # float32 tensor, uint8 tensor        return image, label 
def __len__(self): return len(self.images)
复制代码


调用上述定义好的数据格式转化类'VOCSegDataset'生成训练集与验证集集合'voc_train'与'voc_val',从打印可以看出本次实验只读取了 1456 与 1436 张图片用于训练与测试。


import osfrom torchvision import transformsfrom PIL import Image
voc_train = VOCSegDataset(is_train = True, crop_size=(224,224), voc_root = '/home/pengyongrong/workspace/VocData')voc_val = VOCSegDataset(is_train = False, crop_size=(224,224), voc_root = '/home/pengyongrong/workspace/VocData')
复制代码


Read 1456 valid examplesRead 1436 valid examples
复制代码


接下来对训练集中的部分图片进行可视化,通过引入 matplotlib 库来进行可视化,这里展示了 5 张图片及对应标签,如果想要展示更多,可以设置 i 的取值即可。从可视化结果可以看出图中每一个不同类别与背景都被用不同颜色区分开啦,例如图一中的飞机、人与背景分别用暗紫色、黄色与紫色区分开来。


import matplotlib.pyplot as pltfor i, (img, label) in enumerate(voc_train):    plt.figure(figsize=(10,10))    plt.subplot(221)    plt.imshow(img.moveaxis(0,2))    plt.subplot(222)    plt.imshow(label)    plt.show()    plt.close()    if i ==5:        break
复制代码







导入昇腾 npu 相关库 transfer_to_npu、该模块可以使能模型自动迁移至昇腾上。

import torch_npufrom torch_npu.contrib import transfer_to_npu
复制代码


from torch.utils.data import Dataset, DataLoader
#创建dataloader,定义每一批次送入模型进行训练的batch_size这里设置成8,也可以根据需要改成任意>=2的取值。trainloader = DataLoader(voc_train, batch_size = 8, shuffle=True,)testloader = DataLoader(voc_val, batch_size = 4)
复制代码


optim 实现了各种优化算法的库(例如:SGD 与 Adam),在使用 optimizer 时候需要构建一个 optimizer 对象,这个对象能够保持当前参数状态并基于计算得到的梯度进行参数更新。


# 导入torch及相关模块库,便于后续搭建神经网络模型使用import torch.optim as optimimport torch.nn.functional as F
复制代码


#定义模型训练在哪种类型的设备上跑device = 'npu'# 构建模型,这里VOC数据类别是21,因此入参num_classes=21,若是其他的类别,此处可以根据需要进行设置。net = FCN8(num_classes=21)#将网络模型加载到指定设备上,这里device是昇腾的npunet = net.to(device)criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=1.0, weight_decay=5e-4)lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,0.1,steps_per_epoch=len(trainloader),                                                   epochs=150,div_factor=25,final_div_factor=10000,pct_start=0.3)
复制代码


训练模块: 根据传入的迭代次数开始训练网络模型,这里需要在 model 开始前加入 net.train(),使用随机梯度下降算法是将梯度值初始化为 0(zero_grad()),计算梯度、通过梯度下降算法更新模型参数的值以及统计每次训练后的 loss 值(每隔 100 次打印一次)


from tqdm import tqdmdef train(epoch):    net.train()    train_loss = 0.0    epoch_loss = 0.0    for batch_idx, (inputs, targets) in enumerate(tqdm(trainloader, 0)):        inputs, targets = inputs.to(device), targets.to(device)        optimizer.zero_grad()        outputs = net(inputs)        loss = criterion(outputs, targets)        loss.backward()        optimizer.step()        lr_scheduler.step()
train_loss += loss.item() epoch_loss += loss.item()
if batch_idx % 100 == 99: # 每100次迭代打印一次损失 print(f'[Epoch {epoch + 1}, Iteration {batch_idx + 1}] loss: {train_loss / 100:.3f}') train_loss = 0.0 return epoch_loss / len(trainloader)
复制代码


测试模块: 每训练一轮将会对最新得到的训练模型效果进行测试,使用的是数据集准备时期划分得到的测试集。


def test():    net.eval()    val_loss = 0    val_loss_all=[]    val_num = 0    total = 0    with torch.no_grad():        for batch_idx, (inputs, targets) in enumerate(tqdm(testloader)):            inputs, targets = inputs.to(device), targets.to(device)            outputs = net(inputs)            out = F.log_softmax(outputs, dim=1)            loss = criterion(out, targets)            val_loss += loss.item() * len(targets)            val_num += len(targets)
# 计算一个epoch在验证集上的损失和精度 val_loss_all.append(val_loss / val_num) return val_loss_all[-1]
复制代码


训练与测试的次数为 2 次,这里用户可以根据需要自行选择设置更高或更低,每个 epoch 的准确率都会被打印出来,如果不需要将代码注释掉即可,这里可以看到两个 epoch 间的 loss 在下降(从 1.94->1.63)。


#开启模型训练与测试过程for epoch in range(2):    epoch_loss = train(epoch)    test_accuray = test()    print(f'Epoch loss for FCN8s at epoch {epoch + 1}: {epoch_loss:.3f}')
复制代码


 55%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                                                                            | 100/182 [01:49<01:28,  1.08s/it][Epoch 1, Iteration 100] loss: 1.940

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 182/182 [03:14<00:00, 1.07s/it]100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 359/359 [02:36<00:00, 2.29it/s]Epoch loss for FCN8s at epoch 1: 1.825
55%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 100/182 [01:42<01:22, 1.01s/it][Epoch 2, Iteration 100] loss: 1.628
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 182/182 [03:04<00:00, 1.01s/it]100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 359/359 [02:39<00:00, 2.25it/s]Epoch loss for FCN8s at epoch 2: 1.630
复制代码


内存使用情况: 整个训练过程的内存使用情况可以通过"npu-smi info"命令在终端查看,因此本文实验只用到了单个 npu 卡(也就是 chip 0),内存占用约 13G,对内存、精度或性能优化有兴趣的可以自行尝试进行优化。

Reference

[1] Long, Jonathan , E. Shelhamer , and T. Darrell . "Fully Convolutional Networks for Semantic Segmentation." IEEE Transactions on Pattern Analysis and Machine Intelligence 39.4(2015):640-651.

用户头像

还未添加个人签名 2024-12-19 加入

还未添加个人简介

评论

发布
暂无评论
PyTorch 实现FCN网络用于图像语义分割_永荣带你玩转昇腾_InfoQ写作社区