Pytorch 基础 - 张量基本操作
授人以鱼不如授人以渔,原汁原味的知识才更富有精华,本文只是对张量基本操作知识的理解和学习笔记,看完之后,想要更深入理解,建议去 pytorch 官方网站,查阅相关函数和操作,英文版在这里,中文版在这里。本文的代码是在
pytorch1.7
版本上测试的,其他版本一般也没问题。
一,张量的基本操作
Pytorch
中,张量的操作分为结构操作和数学运算,其理解就如字面意思。结构操作就是改变张量本身的结构,数学运算就是对张量的元素值完成数学运算。
常使用的张量结构操作:维度变换(
tranpose
、view
等)、合并分割(split
、chunk
等)、索引切片(index_select
、gather
等)。常使用的张量数学运算:标量运算、向量运算、矩阵运算。
二,维度变换
2.1,squeeze vs unsqueeze 维度增减
squeeze()
:对 tensor 进行维度的压缩,去掉维数为1
的维度。用法:torch.squeeze(a)
将 a 中所有为 1 的维度都删除,或者a.squeeze(1)
是去掉a
中指定的维数为1
的维度。unsqueeze()
:对数据维度进行扩充,给指定位置加上维数为1
的维度。用法:torch.unsqueeze(a, N)
,或者a.unsqueeze(N)
,在a
中指定位置N
加上一个维数为1
的维度。
squeeze
用例程序如下:
程序输出结果如下:
torch.Size([3, 3])torch.Size([1, 3, 3])
unsqueeze
用例程序如下:
程序输出结果如下:
torch.Size([1, 3, 3])torch.Size([1, 3, 3])
2.2,transpose vs permute 维度交换
torch.transpose()
只能交换两个维度,而 .permute()
可以自由交换任意位置。函数定义如下:
在 CNN
模型中,我们经常遇到交换维度的问题,举例:四个维度表示的 tensor:[batch, channel, h, w]
(nchw
),如果想把 channel
放到最后去,形成[batch, h, w, channel]
(nhwc
),如果使用 torch.transpose()
方法,至少要交换两次(先 1 3
交换再 1 2
交换),而使用 .permute()
方法只需一次操作,更加方便。例子程序如下:
三,索引切片
3.1,规则索引切片方式
张量的索引切片方式和 numpy
、python 多维列表几乎一致,都可以通过索引和切片对部分元素进行修改。切片时支持缺省参数和省略号。实例代码如下:
以上切片方式相对规则,对于不规则的切片提取,可以使用 torch.index_select
, torch.take
, torch.gather
, torch.masked_select
。
3.2,gather 和 torch.index_select 算子
gather
算子的用法比较难以理解,在翻阅了官方文档和网上资料后,我有了一些自己的理解。
1,gather
是不规则的切片提取算子(Gathers values along an axis specified by dim. 在指定维度上根据索引 index 来选取数据)。函数定义如下:
参数解释:
input
(Tensor) – the source tensor.dim
(int) – the axis along which to index.index
(LongTensor) – the indices of elements to gather.
gather
算子的注意事项:
输入
input
和索引index
具有相同数量的维度,即input.shape = index.shape
对于任意维数,只要
d != dim
,index.size(d) <= input.size(d),即对于可以不用索引维数d
上的全部数据。输出
out
和 索引index
具有相同的形状。输入和索引不会相互广播。
对于 3D tensor,output
值的定义如下:gather
的官方定义如下:
通过理解前面的一些定义,相信读者对 gather
算子的用法有了一个基本了解,下面再结合 2D 和 3D tensor 的用例来直观理解算子用法。(1),对于 2D tensor 的例子:
output
值定义如下:
(2),索引更复杂的 2D tensor 例子:
output
值的计算如下:
总结:可以看到 gather
是通过将索引在指定维度 dim
上的值替换为 index
的值,但是其他维度索引不变的情况下获取 tensor
数据。直观上可以理解为对矩阵进行重排,比如对每一行(dim=1)的元素进行变换,比如 torch.gather(a, 1, torch.tensor([[1,2,0], [1,2,0]]))
的作用就是对 矩阵 a
每一行的元素,进行 permtute(1,2,0)
操作。2,理解了 gather
再看 index_select
就很简单,函数作用是返回沿着输入张量的指定维度的指定索引号进行索引的张量子集。函数定义如下:
函数返回一个新的张量,它使用数据类型为 LongTensor
的 index
中的条目沿维度 dim
索引输入张量。返回的张量具有与原始张量(输入)相同的维数。 维度尺寸与索引长度相同; 其他尺寸与原始张量中的尺寸相同。实例代码如下:
四,合并分割
4.1,torch.cat 和 torch.stack
可以用 torch.cat
方法和 torch.stack
方法将多个张量合并,也可以用 torch.split
方法把一个张量分割成多个张量。torch.cat
和 torch.stack
有略微的区别,torch.cat
是连接,不会增加维度,而 torch.stack
是堆叠,会增加一个维度。两者函数定义如下:
torch.cat
和 torch.stack
用法实例代码如下:
4.2,torch.split 和 torch.chunk
torch.split()
和 torch.chunk()
可以看作是 torch.cat()
的逆运算。split()
作用是将张量拆分为多个块,每个块都是原始张量的视图。split()
函数定义如下:
chunk()
作用是将 tensor
按 dim
(行或列)分割成 chunks
个 tensor
块,返回的是一个元组。chunk()
函数定义如下:
实例代码如下:
五,卷积相关算子
5.1,上采样方法总结
上采样大致被总结成了三个类别:
基于线性插值的上采样:最近邻算法(
nearest
)、双线性插值算法(bilinear
)、双三次插值算法(bicubic
)等,这是传统图像处理方法。基于深度学习的上采样(转置卷积,也叫反卷积
Conv2dTranspose2d
等)Unpooling
的方法(简单的补零或者扩充操作)
计算效果:最近邻插值算法 < 双线性插值 < 双三次插值。计算速度:最近邻插值算法 > 双线性插值 > 双三次插值。
5.2,F.interpolate 采样函数
Pytorch 老版本有
nn.Upsample
函数,新版本建议用torch.nn.functional.interpolate
,一个函数可实现定制化需求的上采样或者下采样功能,。
F.interpolate()
函数全称是 torch.nn.functional.interpolate()
,函数定义如下:
参数解释如下:
input
(Tensor):输入张量数据;size
: 输出的尺寸,数据类型为 tuple: ([optional D_out], [optional H_out], W_out),和scale_factor
二选一。scale_factor
:在高度、宽度和深度上面的放大倍数。数据类型既可以是 int——表明高度、宽度、深度都扩大同一倍数;也可是
tuple`——指定高度、宽度、深度等维度的扩大倍数。mode
: 上采样的方法,包括最近邻(nearest
),线性插值(linear
),双线性插值(bilinear
),三次线性插值(trilinear
),默认是最近邻(nearest
)。align_corners
: 如果设为True
,输入图像和输出图像角点的像素将会被对齐(aligned),这只在mode = linear, bilinear, or trilinear
才有效,默认为False
。
例子程序如下:
5.3,nn.ConvTranspose2d 反卷积
转置卷积(有时候也称为反卷积,个人觉得这种叫法不是很规范),它是一种特殊的卷积,先 padding
来扩大图像尺寸,紧接着跟正向卷积一样,旋转卷积核 180 度,再进行卷积计算。
参考资料
版权声明: 本文为 InfoQ 作者【嵌入式视觉】的原创文章。
原文链接:【http://xie.infoq.cn/article/37d461197620dbf627df4aed2】。文章转载请联系作者。
评论