Pytorch 数据加载
DataLoader
常用参数说明:
dataset: Dataset 类 ( 详见下文数据集构建 ),可以自定义数据集或者读取 pytorch 自带数据集
batch_size: 每个 batch 加载多少个样本, 默认 1
shuffle: 是否顺序读取,True 表示随机打乱,默认 False
sampler:定义从数据集中提取样本的策略。如果指定,则忽略 shuffle 参数。
batch_sampler: 定义一个按照 batch_size 大小返回索引的采样器。采样器详见下文 Batch_Sampler
num_workers: 数据读取进程数量, 默认 0
collate_fn: 自定义一个函数,接收一个 batch 的数据,进行自定义处理,然后返回处理后这个 batch 的数据。例如改变数据类型:
pin_memory:设置 pin_memory=True,则意味着生成的 Tensor 数据最开始是属于内存中的锁页内存,这样将内存的 Tensor 转义到 GPU 的显存就会更快一些。默认为 False.
主机中的内存,有两种,一种是锁页,一种是不锁页。锁页内存存放的内容在任何情况下都不会与主机的虚拟内存 (硬盘)进行交换,而不锁页内存在主机内存不足时,数据会存放在虚拟内存中。注意显卡中的显存全部都是锁业内存。如果计算机内存充足的话,设置为 True 可以加快数据交换顺序。
drop_last:默认 False, 最后剩余数据量不够 batch_size 时候,是否丢弃。
timeout: 设置数据读取的时间限制,超过限制时间还未完成数据读取则报错。数值必须大于等于 0
数据集构建
自定义数据集
自定义数据集,需要继承torch.utils.data.Dataset
,然后在__getitem__()
中,接受一个索引,返回一个样本, 基本流程,首先在__init__()
加载数据以及做一些处理,在__getitem__()
中返回单个数据样本,在__len__()
中,返回样本数量
torchvision 数据集
pytorch 自带 torchvision 库可以帮助我们方便快捷的读取和加载数据
TensorDataset
注意这里的 tensor 必须是一维度的数据。
从文件夹中加载数据集
如果想要加载自己的数据集可以这样,用猫狗数据集举例,根目录下 ( "data/train" )
,分别放置两个文件夹,dog 和 cat,这样使用 ImageFolder 函数就可以自动的将猫狗照片自动的按照文件夹定义为猫狗两个标签
数据集操作
数据拼接
连接不同的数据集以构成更大的新数据集。
class torch.utils.data.ConcatDataset( [datasets, ... ] )
数据切分
方法一: class torch.utils.data.Subset(dataset, indices)
取指定一个索引序列对应的子数据集。
方法二:torch.utils.data.random_split(dataset, lengths)
采样器
所有采样器都在 torch.utils.data
中,采样器会根据该有的策略返回一组索引,在 DataLoader 中设定了采样器之后,会根据索引读取相应的样本, 不同采样器生成的索引不一样,从而实现不同的采样目的。
Sampler
所有采样器的基类,自定义采样器的时候需要实现 __iter__()
函数
RandomSampler
RandomSampler,当 DataLoader 的shuffle
参数为 True 时,系统会自动调用这个采样器,实现打乱数据。默认的是采用 SequentialSampler,它会按顺序一个一个进行采样。
SequentialSampler
按顺序采样,当 DataLoader 的shuffle
参数为 False 时,使用的就是 SequentialSampler。
SubsetRandomSampler
输入一个列表,按照这个列表采样。也可以通过这个采样器来分割数据集。
BatchSampler
参数:sampler, batch_size, drop_last
每此返回batch_size
数量的采样索引,通过设置sampler
参数来使用不同的采样方法。
WeightedRandomSampler
参数:weights, num_samples, replacement
它会根据每个样本的权重选取数据,在样本比例不均衡的问题中,可用它来进行重采样。通过weights
设定样本权重,权重越大的样本被选中的概率越大,待选取的样本数目一般小于全部的样本数目。num_samples
为返回索引的数量,replacement
表示是否是放回抽样,如果为 True,表示可以重复采样,默认为 True
自定义采样器
集成 Sampler 类,然后实现__iter__()
方法,比如,下面实现一个 SequentialSampler 类
评论