写点什么

Transforms 预处理

作者:测试人
  • 2025-05-19
    北京
  • 本文字数:1526 字

    阅读完需:约 5 分钟

Transforms

  • 为什么要预处理

  • 什么是预处理

  • Transforms 简介

  • Transforms 案例

为什么要预处理

  • 神经网络模型接收的数据类型是 Tensor,而不是 PIL 对象,因此需要对数据进行预处理操作。

  • 图像分析中,图像质量的好坏直接影响识别算法的设计与效果的精度,因此在图像分析前,需要进行图像预处理操作。


  • 适应神经网络结构

  • 对训练样本进行提纯

  • 进行数据增强

  • 数据归一化

  • 压缩数据体积

  • 图像预处理的主要目的是消除图像中无关的信息,恢复有用的真实信息,即抑制不想要的变形或者增强某些对于后续处理重要的图像特征。通过增强有关信息的可检测性、最大限度地简化数据,从而改进特征提取、图像分割、匹配和识别的可靠性。

  • 预处理不会增加图像的信息量,一般会减少信息量。




什么是预处理

  • 预处理方法分为四类:

  • 像素亮度变换——亮度矫正要考虑该像素原来的亮度和其在图像中的位置

  • 几何变换——可以消除图像获取时所出现的几何变形

  • 局部邻域预处理——使用像素的小邻域来产生输出图像中新的亮度数值

  • 图像复原——旨在利用有关退化性质知识来抑制退化

Transforms 简介:Transforms 中常用的图像预处理方法

裁剪

  • 中心裁剪:transforms.CenterCrop

  • 随机裁剪:transforms.RandomCrop

  • 随机长宽比裁剪:transforms.RandomResizedCrop

翻转和旋转

  • 随机旋转:transforms.RandomRotation

图像变换

  • 标准化:transforms.Normalize

  • 转为 tensor,并归一化至[0-1]:transforms.ToTensor

  • 修改亮度、对比度和饱和度:transforms.ColorJitter

  • 转灰度图:transforms.Grayscale

  • 线性变换: transforms.LinearTransformation()

Transforms 简介:Transforms 的机制


Transforms 简介: Transforms 案例

  • 在 pytorch 中,图像的预处理过程中常常需要对图片的格式、尺寸等做一系列的变化,这就需要借助 transforms。

  • torchvision.transforms 模块主要用于对图像进行转换等一系列预处理操作,其主要目的是对图像数据进行增强,进而提高模型的泛化能力。

  • 主要包括对 Tensor 及 PIL Image 对象的操作,例如随机切割、旋转、数据类型转换等。


  • 所有 TorchVision 数据集都有两个参数——用于修改要素的 transform 和用于修改标注的 target _ transform。它们接受包含变换逻辑的调用。

  • torchvision.transforms 模块提供了一些现成的常用转换。

  • FashionMNIST 特征是 PIL 图像格式,标签是整数。对于训练,我们需要归一化张量形式的特征,以及 one-hot 编码张量形式的标签。为了进行这些转换,我们使用 ToTensor 和 Lambda。

ToTensor()

ToTensor 将 PIL 图像或数字数组转换为浮点型。并将图像的像素亮度值缩放到范围[0., 1.]

  • transforms.ToTensor 的作用是将一个 PIL Image 格式的图片或者是取值范围为[0,255],形状为[H×W×C]的 numpy.ndarray 的数组转换为取值范围为[0.0,1.0],形状为[C×H×W]的 tensor 格式图片

Lambda Transforms

Lambda 转换应用任何用户定义的 lambda 函数。这里,我们定义了一个函数,将整数转换成 one-hot 张量。它首先创建一个大小为 10(我们数据集中标签的数量)的零张量,并调用 scatter_在标签 y 给出的索引上赋值=1。

Transforms

import torchfrom torchvision import datasetsfrom torchvision.transforms import ToTensor, Lambda

def test_demo(): test = datasets.FashionMNIST( root="data", train=True, download=True, transform=ToTensor(), target_transform=Lambda(lambda y: torch.zeros( 10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1)) )#lambda表达式,输入一个y,创建一个torch,10的zeros,进行原地的scatter,在0个维度,y个位置数据设置value为1

t = test[0]#输出代表了张量形式的特征,one hot张量形式的标签 print(t)
if __name__ == '__main__': test_demo()
复制代码


用户头像

测试人

关注

专注于软件测试开发 2022-08-29 加入

霍格沃兹测试开发学社,测试人社区:https://ceshiren.com/t/topic/22284

评论

发布
暂无评论
Transforms预处理_人工智能_测试人_InfoQ写作社区