PyTorch DataLoader详解:如何高效加载和处理大规模数据集?

您所在的位置:网站首页 如何添加数据分析库的数据 PyTorch DataLoader详解:如何高效加载和处理大规模数据集?

PyTorch DataLoader详解:如何高效加载和处理大规模数据集?

2024-07-17 16:50| 来源: 网络整理| 查看: 265

一、DataLoader概述

DataLoader是深度学习中重要的数据处理工具之一,旨在有效加载、处理和管理大规模数据集,用于训练和测试机器学习和深度学习模型。

DataLoader主要用于两个关键任务:数据加载和批次处理

数据加载:DataLoader可以从不同来源加载数据,如硬盘上的文件、数据库、网络等。它能够自动将数据集划分为小批次,从而减小内存需求,确保数据的高效加载。数据批次处理:每个批次由多个样本组成,可以并行地进行数据预处理和数据增强。这有助于提高模型训练的效率,同时确保每个批次的数据都经过适当的处理。 二、DataLoader读取数据流程

DataLoader读取数据的详细流程如下:

1、根据dataset和sampler,生成数据索引。

2、根据这些索引,从dataset中读取指定数量的数据,并对其进行预处理(例如归一化、裁剪 等)。

3、如果设置了collate_fn,则将处理后的数据打包成批次数据。

4、如果设置了num_workers > 0,则将数据加载任务分配给多个子进程并行完成。

5、在模型训练时,每个epoch从DataLoader中获取一个批次的数据,作为模型的输入。

三、DataLoader使用 def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1, shuffle: Optional[bool] = None, sampler: Union[Sampler, Iterable, None] = None, batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None, num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None, pin_memory: bool = False, drop_last: bool = False, timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None, multiprocessing_context=None, generator=None, *, prefetch_factor: int = 2, persistent_workers: bool = False, pin_memory_device: str = ""):

上述为DataLoader所能使用的所有参数,下述为参数的相关说明:

dataset:加载的数据集对象,类型为Dataset[T_co],其中T_co是数据集中元素的数据类型。batch_size:每个batch的大小,类型为Optional[int],默认为1。shuffle:是否打乱数据集,类型为Optional[bool],默认为None,表示不打乱。sampler:样本采样器,用于定义数据采样策略,类型为Union[Sampler, Iterable, None],默认为None。batch_sampler:批量采样器,用于定义批量采样策略,类型为Union[Sampler[Sequence], Iterable[Sequence], None],默认为None。num_workers:数据加载时的并发数,类型为int,默认为0。collate_fn:将一个list的sample组成一个mini-batch的函数,类型为Optional[_collate_fn_t],默认为None。pin_memory:如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中,类型为bool,默认为False。drop_last:如果设置为True,对于最后的未完成的batch,会被扔掉,类型为bool,默认为False。timeout:如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了,类型为float,默认为0。worker_init_fn:用于初始化worker进程的函数,类型为Optional[_worker_init_fn_t],默认为None。multiprocessing_context:用于创建进程的上下文,类型为Optional[MultiprocessingContext],默认为None。generator:用于生成数据的生成器函数或对象,类型为Optional[Generator],默认为None。prefetch_factor:用于设置预取因子的参数,类型为int,默认为2。persistent_workers:是否使用持久化的worker进程,类型为bool,默认为False。pin_memory_device:用于设置固定内存的设备名称,类型为str,默认为空字符串。 四、代码实践应用 training_dataloader = DataLoader(training_data,batch_size=64) test_dataloader = DataLoader(test_data,batch_size=64)

在这段代码中,创建了两个DataLoader对象,一个用于训练数据('training_dataloader'),另一个用于测试数据('test_dataloader')。'batch_size=64'表示每个批次中包含的样本数量是64。

具体来说,'training_dataloader = DataLoader(training_data,batch_size=64)'这行代码将'training_data'(训练数据集)加载到DataLoader中,并设置每个批次的样本数为64。同样,'test_dataloader = DataLoader(test_data,batch_size=64)'这行代码将'test_data'(测试数据集)加载到DataLoader中,并设置每个批次的样本数为64。

在实际使用时,每次迭代(或循环)会从对应的dataloader中获取一个批次的数据,然后将这个批次的数据用于模型的训练或测试。



【本文地址】


今日新闻


推荐新闻


    CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3