写点什么

UNet 语义分割实战:使用 UNet 实现对人物的抠图

作者:AI浩
  • 2022 年 6 月 03 日
  • 本文字数:7266 字

    阅读完需:约 24 分钟

摘要

在上一篇文章,我总结了一些 UNet 的基础知识,对 UNet 不了解的可以看看,文章链接:https://wanghao.blog.csdn.net/article/details/123714994


我也整理的 UNet 的 pytorch 版本,文章链接:


https://blog.csdn.net/hhhhhhhhhhwwwwwwwwww/article/details/123280059


今天这篇文章讲解如何使用 UNet 实现图像的二分类分割。


关于二分类一般有两种做法:


第一种输出是单通道,即网络的输出 output 为 [batch_size, 1, height, width] 形状。其中 batch_szie 为批量大小,1 表示输出一个通道,heightwidth 与输入图像的高和宽保持一致。


在训练时,输出通道数是 1,网络得到的 output 包含的数值是任意的数。给定的 target ,是一个单通道标签图,数值只有 0 和 1 这两种。为了让网络输出 output 不断逼近这个标签,首先会让 output 经过一个 sigmoid 函数,使其数值归一化到[0, 1],得到 output1 ,然后让这个 output1target 进行交叉熵计算,得到损失值,反向传播更新网络权重。最终,网络经过学习,会使得 output1 逼近target


训练结束后,网络已经具备让输出的 output 经过转换从而逼近 target 的能力。首先将输出的 output 通过 sigmoid 函数,然后取一个阈值(一般设置为 0.5),大于阈值则取 1 反之则取 0,从而得到预测图 predict。后续则是一些评估相关的计算。


如果网络的最后一层使用 sigmoid,则选用 BCELoss,如果没有则选择用 BCEWithLogitsLoss,例:


最后一层没有 sigmod


output = net(input)  # net的最后一层没有使用sigmoidloss_func1 = torch.nn.BCEWithLogitsLoss()loss = loss_func1(output, target)
复制代码


加上 sigmod


output = net(input)  # net的最后一层没有使用sigmoidoutput = F.sigmoid(output)loss_func1 = torch.nn.BCEWithLoss()loss = loss_func1(output, target)
复制代码


预测的时:


output = net(input)  # net的最后一层没有使用sigmoidoutput = F.sigmoid(output)predict=torch.where(output>0.5,torch.ones_like(output),torch.zeros_like(output))
复制代码


第二种输出是多通道,即网络的输出 output 为 [batch_size, num_class, height, width] 形状。其中 batch_szie 为批量大小,num_class 表示输出的通道数与分类数量一致,heightwidth 与输入图像的高和宽保持一致。


在训练时,输出通道数是 num_class(这里取 2)。给定的 target ,是一个单通道标签图,数值只有 0 和 1 这两种。为了让网络输出 output 不断逼近这个标签,首先会让 output 经过一个 softmax 函数,使其数值归一化到[0, 1],得到 output1 ,在各通道中,这个数值加起来会等于 1。对于target 他是一个单通道图,首先使用onehot编码,转换成 num_class个通道的图像,每个通道中的取值是根据单通道中的取值计算出来的,例如单通道中的第一个像素取值为 1(0<= 1 <=num_class-1,这里 num_class=2),那么onehot编码后,在第一个像素的位置上,两个通道的取值分别为 0,1。也就是说像素的取值决定了对应序号的通道取 1,其他的通道取 0,这个非常关键。上面的操作执行完后得到target1,让这个 output1target1 进行交叉熵计算,得到损失值,反向传播更新网路权重。最终,网络经过学习,会使得 output1 逼近target1(在各通道层面上)。


训练结束后,网络已经具备让输出的 output 经过转换从而逼近 target 的能力。计算 output 中各通道每一个像素位置上,取值最大的那个对应的通道序号,从而得到预测图 predict


训练选择用的 loss 是加插上损失函数,例:


output = net(input)  # net的最后一层没有使用sigmoidloss_func = torch.nn.CrossEntropyLoss()loss = loss_func(output, target)
复制代码


预测时


output = net(input)  # net的最后一层没有使用sigmoidpredict = output.argmax(dim=1)
复制代码


本次实战选用的第二种做法。


选用的代码地址:milesial/Pytorch-UNet: PyTorch implementation of the U-Net for image semantic segmentation with high quality images (github.com)


下载代码后,解压到本地,如下图:


数据集

数据集地址:http://www.cse.cuhk.edu.hk/~leojia/projects/automatting/,发布于 2016 年。


数据集包含 2000 张图,训练集 1700 张,测试集 300 张,数据都是来源于 Flickr 的肖像图,图像原始分辨率大小为 600×800,其中 Matting 用 closed-form matting 和 KNN matting 方法生成。


由于肖像分割数据集商业价值较高,因此公开的大规模数据集很少,这个数据集是其中发布较早,使用范围也较广的一个数据集,它有几个比较重要的特点:


(1) 图像分辨率统一,拍摄清晰,质量很高。


(2) 所有图像均为上半身的肖像图,人像区域在长度和宽度均至少占据图像的 2/3。


(3) 人物的姿态变化很小,都为小角度的正面图,背景较为简单。





[1] Shen X, Tao X, Gao H, et al. Deep Automatic Portrait Matting[M]// ComputerVision – ECCV 2016. Springer International Publishing, 2016:92-107.


将数据集下载后放到将训练集放到 data 文件夹中,其中图片放到 imgs 文件夹中,mask 放到 masks 文件夹中,测试集放到 test 文件夹下面:



由于原程序是用于Carvana Image Masking Challenge,所以我们需要修改加载数据集的逻辑,打开 utils/data_loading.py 文件:


class CarvanaDataset(BasicDataset):    def __init__(self, images_dir, masks_dir, scale=1):        super().__init__(images_dir, masks_dir, scale, mask_suffix='_matte')
复制代码


将 mask_suffix 改为“_matte”

训练

打开 train.py,先查看全局参数:


def get_args():    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=300, help='Number of epochs')    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=16, help='Batch size')    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=0.001,                        help='Learning rate', dest='lr')    parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')    parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')    parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,                        help='Percent of the data that is used as validation (0-100)')    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')    return parser.parse_args()
复制代码


epochs:epoch 的个数,一般设置为 300。


batch-size:批处理的大小,根据显存的大小设置。


learning-rate:学习率,一般设置为 0.001,如果优化器不同,初始的学习率也要做相应的调整。


load:加载模型的路径,如果接着上次的训练,就需要设置上次训练的权重文件路径,如果有预训练权重,则设置预训练权重的路径。


scale:放大的倍数,这里设置为 0.5,把图片大小变为原来的一半。


validation:验证验证集的百分比。


amp:是否使用混合精度?


比较重要的参数是 epochs、batch-size 和 learning-rate,可以反复调整做实验,达到最好的精度。


接下来是设置模型:


net = UNet(n_channels=3, n_classes=2, bilinear=True)    logging.info(f'Network:\n'                 f'\t{net.n_channels} input channels\n'                 f'\t{net.n_classes} output channels (classes)\n'                 f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')    if args.load:        net.load_state_dict(torch.load(args.load, map_location=device))        logging.info(f'Model loaded from {args.load}')
复制代码


设置 UNet 参数,n_channels 是 imgs 图片的通道数,如果是 rgb 则是 3,如果是黑白图片就是 1,n_classes 设置为 2,在这里把背景也当做一个类别,所以有两个类。


如果设置了权重文件,则加载权重文件,加载权重文件做迁移学习可以加快训练,减少迭代次数,所以如果有还是尽量加载预训练权重。


接下来修改 train_net 函数的逻辑。


try:    dataset = CarvanaDataset(dir_img, dir_mask, img_scale)except (AssertionError, RuntimeError):    dataset = BasicDataset(dir_img, dir_mask, img_scale)
# 2. Split into train / validation partitionsn_val = int(len(dataset) * val_percent)n_train = len(dataset) - n_valtrain_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))# 3. Create data loadersloader_args = dict(batch_size=batch_size, num_workers=4, pin_memory=True)train_loader = DataLoader(train_set, shuffle=True, **loader_args)val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)
复制代码


1、加载数据集。


2、按照比例切分训练集和验证集。


3、将训练集和验证集放入 DataLoader 中。


 # (Initialize logging) experiment = wandb.init(project='U-Net', resume='allow', anonymous='must') experiment.config.update(dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,                                  val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale,                                  amp=amp))
复制代码


设置 wandb,wandb 是一款非常好用的可视化工具。安装和使用方法见:https://blog.csdn.net/hhhhhhhhhhwwwwwwwwww/article/details/116124285


 # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP    optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)  # goal: maximize Dice score    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)    criterion = nn.CrossEntropyLoss()    global_step = 0
复制代码


1、设置优化器 optimizer 为 RMSprop,我也尝试了改为 SGD,通常情况下 SGD 的表现好一些。但是在训练时发现,二者最终的结果都差不多。


2、ReduceLROnPlateau 学习率调整策略,和 keras 的类似。本次选择用的是 Dice score,所以将 mode 设置为 max,当得分不再上升时,则降低学习率。


3、设置 loss 为 nn.CrossEntropyLoss()。交叉熵,多分类常用的 loss。


接下来是 train 部分的逻辑,这里需要修改的如下:


 masks_pred = net(images) true_masks = F.one_hot(true_masks.squeeze_(1), net.n_classes).permute(0, 3, 1, 2).float() print(masks_pred.shape) print(true_masks.shape)
复制代码


masks_pred = net(images)计算出来的结果是:[batch, 2, 400, 300],其中 2 代表两个类别。


true_masks.shape 是[batch, 1, 400, 300],所以要对 true_masks 做 onehot 处理。如果直接对 true_masks 做 onehot 处理,你会发现处理后的 shape 是[batch, 1, 400, 300,2],这样就和 masks_pred 对不上了,所以在做 onehot 之前,先将第二维(也就是 1 这一维度)去掉,这样 onehot 后的 shape 是[batch, 400, 300,2],然后调整顺序,和 masks_pred 的维度对上。


接下来就要计算 loss,loss 分为两部分,一部分时交叉熵,另一部分是 dice_loss,这两个 loss 各有优势,组合使用效果更优。dice_loss 在 utils/dice_sorce.py 文件中,代码如下:


import torchfrom torch import Tensor
def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6): # Average of Dice coefficient for all batches, or for a single mask assert input.size() == target.size() if input.dim() == 2 and reduce_batch_first: raise ValueError(f'Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})') if input.dim() == 2 or reduce_batch_first: inter = torch.dot(input.reshape(-1), target.reshape(-1)) sets_sum = torch.sum(input) + torch.sum(target) if sets_sum.item() == 0: sets_sum = 2 * inter return (2 * inter + epsilon) / (sets_sum + epsilon) else: # compute and average metric for each batch element dice = 0 for i in range(input.shape[0]): dice += dice_coeff(input[i, ...], target[i, ...]) return dice / input.shape[0]
def dice_coeff_1(pred, target): smooth = 1. num = pred.size(0) m1 = pred.view(num, -1) # Flatten m2 = target.view(num, -1) # Flatten intersection = (m1 * m2).sum() return 1 - (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)
def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6): # Average of Dice coefficient for all classes assert input.size() == target.size() dice = 0 for channel in range(input.shape[1]): dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon)
return dice / input.shape[1]def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False): # Dice loss (objective to minimize) between 0 and 1 assert input.size() == target.size() fn = multiclass_dice_coeff if multiclass else dice_coeff return 1 - fn(input, target, reduce_batch_first=True)
复制代码


导入到 train.py 中,然后和交叉熵组合作为本项目的 loss。


 loss = criterion(masks_pred, true_masks) \        + dice_loss(F.softmax(masks_pred, dim=1).float(),                                       true_masks,                                       multiclass=True)
复制代码


接下来是对 evaluate 函数的逻辑做修改。


 mask_true = mask_true.to(device=device, dtype=torch.long) mask_true = F.one_hot(mask_true.squeeze_(1), net.n_classes).permute(0, 3, 1, 2).float()
复制代码


增加对 mask_trued 的 onehot 逻辑。


修改完上面的逻辑就可以开始训练了。


测试

完成训练后就可以测试了。打开 predict.py,修改全局参数:


def get_args():    parser = argparse.ArgumentParser(description='Predict masks from input images')    parser.add_argument('--model', '-m', default='checkpoints/checkpoint_epoch7.pth', metavar='FILE',                        help='Specify the file in which the model is stored')    parser.add_argument('--input', '-i', metavar='INPUT',default='test/00002.png', nargs='+', help='Filenames of input images')    parser.add_argument('--output', '-o', metavar='INPUT',default='00001.png', nargs='+', help='Filenames of output images')    parser.add_argument('--viz', '-v', action='store_true',                        help='Visualize the images as they are processed')    parser.add_argument('--no-save', '-n', action='store_true',default=False, help='Do not save the output masks')    parser.add_argument('--mask-threshold', '-t', type=float, default=0.5,                        help='Minimum probability value to consider a mask pixel white')    parser.add_argument('--scale', '-s', type=float, default=0.5,                        help='Scale factor for the input images')
复制代码


model:设置权重文件路径。这里修改为自己训练的权重文件。


scale:0.5,和训练的参数对应上。


其他的参数,通过命令输入。


def mask_to_image(mask: np.ndarray):    if mask.ndim == 2:        return Image.fromarray((mask * 255).astype(np.uint8))    elif mask.ndim == 3:        img_np=(np.argmax(mask, axis=0) * 255 / (mask.shape[0]-1)).astype(np.uint8)        print(img_np.shape)        print(np.max(img_np))        return Image.fromarray(img_np)
复制代码


img_np=(np.argmax(mask, axis=0) * 255 / (mask.shape[0]-1)).astype(np.uint8)这里的逻辑需要修改。


源代码:


 return Image.fromarray((np.argmax(mask, axis=0) * 255 / mask.shape[0]).astype(np.uint8))
复制代码


我们增加了一类背景,所以 mask.shape[0]为 2,需要减去背景。


展示结果的方法也需要修改;


def plot_img_and_mask(img, mask):    print(mask.shape)    classes = mask.shape[0] if len(mask.shape) > 2 else 1    fig, ax = plt.subplots(1, classes + 1)    ax[0].set_title('Input image')    ax[0].imshow(img)    if classes > 1:        for i in range(classes):            ax[i + 1].set_title(f'Output mask (class {i + 1})')            ax[i + 1].imshow(mask[i, :, :])    else:        ax[1].set_title(f'Output mask')        ax[1].imshow(mask)    plt.xticks([]), plt.yticks([])    plt.show()
复制代码


将原来的 ax[i + 1].imshow(mask[:, :, i])改为:ax[i + 1].imshow(mask[i, :, :])。


执行命令:


python predict.py -i test/00002.png -o output.png  -v 
复制代码


输出结果:



到这里我们已经实现将人物从背景图片中完整的抠出来了!

总结

本文实现了用 Unet 对图像做分割,通过本文,你可以学习到:


1、如何使用 Unet 对图像对二分类的语义分割。


2、如何使用 wandb 可视化。


3、如何使用交叉熵和 dice_loss 组合。


4、如何实现二分类语义分割的预测。


完整的代码:https://download.csdn.net/download/hhhhhhhhhhwwwwwwwwww/85083165

发布于: 刚刚阅读数: 3
用户头像

AI浩

关注

还未添加个人签名 2021.11.08 加入

还未添加个人简介

评论

发布
暂无评论
UNet语义分割实战:使用UNet实现对人物的抠图_图像分割_AI浩_InfoQ写作社区