Pytorch数据加载(Dataset+DataLoader模块)【讲解+代码】

您所在的位置:网站首页 pytorch重写dataset Pytorch数据加载(Dataset+DataLoader模块)【讲解+代码】

Pytorch数据加载(Dataset+DataLoader模块)【讲解+代码】

2023-05-26 02:30| 来源: 网络整理| 查看: 265

数据载入由什么组成?

数据载入由dataset和dataloader组成。 dataset:提供一种方式去获取数据及其label dataloader: 为后面网络提供不同的数据形式

1. Dataset的功能

dataset主要为了实现两个功能 1.如何获取每个数据及其label 2.告诉我们总共有多少的数据

2. Dataset代码: 1) 查看官方文档解释

首先,在anaconda prompt中输入如下代码,打开jupyter环境

conda activate #激活pytorch环境 jupyter notebook #打开jupyter

然后,在jupyter中创建新的文件,并输入以下指令既可以看到关于dataset官方文档解释。

from torch.utils.data import Dataset help(Dataset) 2) pycharm中进行程序编写

(1)下载数据集 蚁蜜蜂分类数据集 https://download.pytorch.org/tutorial/hymenoptera_data.zip (2) 建立dataset文件,并将数据集放入程序下 在这里插入图片描述 (3)编写数据集载入程序,实现dataset两个功能,第一,如何获取每个数据及其label 第二,告诉我们总共有多少的数据。

from torch.utils.data import Dataset from PIL import Image import os from torchvision import transforms class mydata(Dataset): #设置全局参数 def __init__(self,root_dir,label_dir): self.root_dir=root_dir self.label_dir=label_dir # 获得图片的路径地址 self.path=os.path.join(self.root_dir,self.label_dir) #os.path.join()函数用于路径拼接文件路径,可以传入多个路径。如果不存在以’/’开始的参数,则函数会自动加上 # 获得图片的所有列表 self.img_path=os.listdir(self.path) #os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表。 #获取每一个图片 def __getitem__(self, item): #获取单张图片名称 img_name=self.img_path[item] #获取单张图片相对路径 img_item_path=os.path.join(self.root_dir,self.label_dir,img_name) #读取单张图片 img=Image.open(img_item_path) img=img.resize((256,256),Image.ANTIALIAS) #统一图片尺寸 trans = transforms.ToTensor() #转换为tensor类型 img_tensor = trans(img)#转换为tensor类型 #获取lable label=self.label_dir return img_tensor, label #列表有多长 def __len__(self): return len(self.img_path) 3. Dataloader的功能

dataloader的功能是为了实现从dataset中取数据,例如,每次取多少数据?,数据集是否打乱?,加载过程是单进程还是多进程?,如果最后剩余数据不足一次需要获取数据,剩余数据是否舍弃。

4. Dataloader代码 1) 查看官方文档解释

在jupyter中创建新的文件,并输入以下指令既可以看到关于dataset官方文档解释。

from torch.utils.data import DataLoader help(DataLoader) 2) pycharm中进行程序编写

新建.py文件,写入以下内容

from dataset import mydata from torch.utils.data import DataLoader import torch #准备测试数据集 root_dir= "dataset/val" bees_label_dir="bees" test_dataset=mydata(root_dir,bees_label_dir)#输入数据集路径 #数据载入 test_loader=DataLoader(dataset=test_dataset,batch_size=4,shuffle=True,num_workers=0,drop_last=False) for data1 in test_loader: imgs,labs=data1 print(imgs.shape) print(labs)

输出如下内容则表示数据载入成功 在这里插入图片描述

感谢: PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】 视频网址:https://www.bilibili.com/video/BV1hE411t7RN?p=15&vd_source=5b6e0605c1ed0f1db9c92503dd5994e0



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3