写点什么

数据增强(二)-SamplePairing

  • 2022 年 5 月 10 日
  • 本文字数:1507 字

    阅读完需:约 5 分钟

1.背景

在数据增强(一)中介绍了 imgaug 图像增强库,本文介绍 SamplePairing 的数据增强策略。


参考文献:Data Augmentation by Pairing Samples for Images Classification


2. 内容

2.1 增强策略

  • 训练集中任一的图片 A(256x256, 标签为 A),经过普通增强后随机 patch 出 224*224 的区域

  • 随机选择(256x256,标签为 B)的图片 B,也经过普通增强后随机 patch 出 224*224 的区域

  • patch 部分像素平均:(patch A + pathc B)/2

  • 新合成的 A(256x256 标签为 A) 参与训练


关键点:合成后的图像保留标签 A 不变,不考虑标签 B 的部分;但是合成的数据中既有标签 A 也有标签 B 的数据,训练中会存在会存在训练误差和验证误差降低的情况; 所以需要再最后不使用 SamplePairing 增强情况,做 finetune 后的效果要高于训练中没有 SamplePairing 增强的训练;

2.2 训练步骤

  • without SamplePairing: 先按照普通的数据增强策略训练若干 epoch(如 100 个 epoch)

  • enable SamplePairing: 接下来的 8 个 epoch SamplePairing增强+2 个 epoch普通增强, 执行若干个组合 (enable SamplePairing for 8 epochs and disable it for the next 2 epochs)

  • disable the SamplePairing as the fine-tuning: 执行普通的数据增强策略最后训练若干 epoch。

2.3 实现细节

import randomdef patch_range(h, w, patch_h, patch_w):    diff_h = h - patch_h    diff_w = w - patch_w    h_select_range = [i for i in range(0, diff_h)]    w_select_range = [i for i in range(0, diff_w)]    h_idx = random.sample(h_select_range, 1)[0]    w_idx = random.sample(w_select_range, 1)[0]    return h_idx, h_idx + patch_h, w_idx, w_idx + patch_w
def sample_pair_batch(x, y, h=224, w=224, patch_h=196, patch_w=196, class_num=3): """ Data Augmentation by Pairing Samples for Images Classification :param x: [n, h, w, c] :param y: [b, class_num] :param h: input height :param w: input width :param patch_h: patch height ( 3/4 * input height) :param patch_w: patch width ( 3/4 * input width) :param class_num: the number of class :return: """ class_idx = np.arange(0, class_num) class_idx_another = class_idx + 1 class_idx_another[-1] = 0 class_rela_map = dict(zip(class_idx, class_idx_another))
# get different lable index label_index = np.argmax(y, axis=1)
# different class sample index class_label_index_list = [np.where(label_index == i)[0] for i in range(class_num)] pair_sample = [] for idx, sample in enumerate(x): sample_label = label_index[idx] label_index_list = class_label_index_list[class_rela_map[sample_label]].tolist() if len(label_index_list) == 0: pair_sample.append(x[[idx]]) else: pair_sample.append(x[random.sample(label_index_list, 1)])
pair_sample_x = np.concatenate(pair_sample)
begin_h, end_h, begin_w, end_w = patch_range(h, w, patch_h, patch_w) begin_h2, end_h2, begin_w2, end_w2 = patch_range(h, w, patch_h, patch_w)
x[:, begin_h:end_h, begin_w:end_w] = (x[:, begin_h:end_h, begin_w:end_w] + pair_sample_x[:, begin_h2:end_h2,begin_w2:end_w2]) // 2 return x, y
复制代码

3.结语

本实验介绍了 samples pair 的数据增强策略,通过随机叠加两个图片的方式来形成一个强的正则化器, 提高模型的泛化能力。


希望对大家有帮助。




发布于: 刚刚阅读数: 4
用户头像

公众号:人工智能微客(weker) 2019.11.21 加入

人工智能微客(weker)长期跟踪和分享人工智能前沿技术、应用、领域知识,不定期的发布相关产品和应用,欢迎关注和转发

评论

发布
暂无评论
数据增强(二)-SamplePairing_人工智能_AIWeker-人工智能微客_InfoQ写作社区