当前位置: 首页 > news >正文

深度学习Python基础(2)

二 数据处理

一般来说PyTorch中深度学习训练的流程是这样的:

1. 创建Dateset 

2. Dataset传递给DataLoader

3. DataLoader迭代产生训练数据提供给模型

对应的一般都会有这三部分代码

# 创建Dateset(可以自定义)

    dataset = face_dataset # Dataset部分自定义过的face_dataset

# Dataset传递给DataLoader

    dataloader = torch.utils.data.DataLoader(dataset,batch_size=64,shuffle=False,num_workers=8)

# DataLoader迭代产生训练数据提供给模型

    for i in range(epoch):

        for index,(img,label) in enumerate(dataloader):

            pass

到这里应该就PyTorch的数据集和数据传递机制应该就比较清晰明了了。Dataset负责建立索引到样本的映射DataLoader负责以特定的方式从数据集中迭代的产生一个个batch的样本集合。在enumerate过程中实际上是dataloader按照其参数sampler规定的策略调用了其dataset的getitem方法。其中,还会涉及数据的变化形式。

1.数据收集

找数据集,注意数据集格式.

Dataset是DataLoader实例化的一个参数。

CIFAR10是CV训练中经常使用到的一个数据集,在PyTorch中CIFAR10是一个写好的Dataset,我们使用时只需以下代码:

data = datasets.CIFAR10("./data/", transform=transform, train=True, download=True)

datasets.CIFAR10就是一个Datasets子类,data是这个类的一个实例。

用自己在一个文件夹中的数据作为数据集时可以使用ImageFolder这个方便的API。

FaceDataset = datasets.ImageFolder('./data', transform=img_transform)

如何自定义一个数据集

torch.utils.data.Dataset 是一个表示数据集的抽象类。任何自定义的数据集都需要继承这个类并覆写相关方法。

所谓数据集,其实就是一个负责处理索引(index)到样本(sample)映射的一个类(class)。

Pytorch提供两种数据集: Map式数据集 Iterable式数据集

Map式数据集

一个Map式的数据集必须要重写getitem(self, index),len(self) 两个内建方法,用来表示从索引到样本的映射(Map).

这样一个数据集dataset,举个例子,当使用dataset[idx]命令时,可以在你的硬盘中读取你的数据集中第idx张图片以及其标签(如果有的话);len(dataset)则会返回这个数据集的容量。

自定义类大致是这样的:

class CustomDataset(data.Dataset):#需要继承data.Dataset

    def __init__(self):

        # TODO

        # 1. Initialize file path or list of file names.

        pass

    def __getitem__(self, index):

        # TODO

        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).

        # 2. Preprocess the data (e.g. torchvision.Transform).

        # 3. Return a data pair (e.g. image and label).

        #这里需要注意的是,第一步:read one data,是一个data

        pass

    def __len__(self):

        # You should change 0 to the total size of your dataset.

        return 0

例子-1: 自己实验中写的一个例子:这里我们的图片文件储存在“./data/faces/”文件夹下,图片的名字并不是从1开始,而是从final_train_tag_dict.txt这个文件保存的字典中读取,label信息也是用这个文件中读取。大家可以照着上面的注释阅读这段代码。

from torch.utils import data

import numpy as np

from PIL import Image

class face_dataset(data.Dataset):

    def __init__(self):

        self.file_path = './data/faces/'

        f=open("final_train_tag_dict.txt","r")

        self.label_dict=eval(f.read()) # eval除了计算,还可以将str转为dict

        f.close()

    def __getitem__(self,index):

        label = list(self.label_dict.values())[index-1]

        img_id = list(self.label_dict.keys())[index-1]

        img_path = self.file_path+str(img_id)+".jpg"

        img = np.array(Image.open(img_path))

        return img,label

    def __len__(self):

        return len(self.label_dict)

Iterable式数据集

一个Iterable(迭代)式数据集是抽象类data.IterableDataset的子类,并且覆写了iter方法成为一个迭代器。这种数据集主要用于数据大小未知,或者以流的形式的输入,本地文件不固定的情况,需要以迭代的方式来获取样本索引。

2.数据划分

数据划分主要是路径的处理。

def makedir(new_dir):

    if not os.path.exists(new_dir):

        os.makedirs(new_dir)

检测路径是否存在,若不存在,则创建此路径。

dataset_dir = os.path.join("..", "..", "data", "RMB_data")

设置路径,将它们组合在一起。相对于Python文件所在位置的相对路径。

    for root, dirs, files in os.walk(dataset_dir):

        for sub_dir in dirs:

            imgs = os.listdir(os.path.join(root, sub_dir))

            imgs = list(filter(lambda x: x.endswith('.jpg'), imgs))

            random.shuffle(imgs)

            img_count = len(imgs)

            train_point = int(img_count * train_pct)

            valid_point = int(img_count * (train_pct + valid_pct))

            for i in range(img_count):

                if i < train_point:

                    out_dir = os.path.join(train_dir, sub_dir)

                elif i < valid_point:

                    out_dir = os.path.join(valid_dir, sub_dir)

                else:

                    out_dir = os.path.join(test_dir, sub_dir)

                makedir(out_dir)

                target_path = os.path.join(out_dir, imgs[i])

                src_path = os.path.join(dataset_dir, sub_dir, imgs[i])

                shutil.copy(src_path, target_path)

os.walk

每一层遍历:

root保存的就是当前遍历的文件夹的绝对路径;

dirs保存当前文件夹下的所有子文件夹的名称(仅一层,孙子文件夹不包括)

files保存当前文件夹下的所有文件的名称

其次,发现它的遍历文件方式,在图的遍历方式中,那可不就是深度遍历嘛!!

  1. os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表

shutil.copy()Python中的方法用于将源文件的内容复制到目标文件或目录。它还会保留文件的权限模式,但不会保留文件的其他元数据(例如文件的创建和修改时间)。源必须代表文件,但目标可以是文件或目录。如果目标是目录,则文件将使用源中的基本文件名复制到目标中。另外,目的地必须是可写的。如果目标是文件并且已经存在,则将其替换为源文件,否则将创建一个新文件。

3.图像预处理-transforms

3.1 图像标准化

transforms.Normalize(mean,std,inplace)

逐通道的标准化,每个通道先求出平均值和标准差,然后标准化。Inplace表示是否原地操作。

3.2 图像裁剪

train_transform = transforms.Compose([

    transforms.Resize((32, 32)),

    transforms.RandomCrop(32, padding=4),

    transforms.ToTensor(),

    transforms.Normalize(norm_mean, norm_std),

])

(1)transforms.CenterCrop(size)

从图片中心截取size大小的图片。

(2)transforms.RandomCrop(size,padding,padding_mode)

随机裁剪区域。

(3)transforms.RandomResizedCrop(size,scale,ratio)

随机大小,随机长宽比的裁剪。

3.3图像旋转

(1)transforms.RandomHorizationalFlip(p)

依据概率p水平翻转。

(2)transforms.RandomVerticalFlip(p)

依据概率p垂直翻转。

(3)transforms.RandomRotation(degrees,resample,expand)

transforms方法

Transforms Methods

一、裁剪

1. transforms.CenterCrop

2. transforms.RandomCrop

3. transforms.RandomResizedCrop

4. transforms.FiveCrop

5. transforms.TenCrop

二、翻转和旋转

1. transforms.RandomHorizontalFlip

2. transforms.RandomVerticalFlip

3. transforms.RandomRotation

三、图像变换

• 1. transforms.Pad

• 2. transforms.ColorJitter

• 3. transforms.Grayscale

• 4. transforms.RandomGrayscale

• 5. transforms.RandomAffine

• 6. transforms.LinearTransformation

• 7. transforms.RandomErasing

• 8. transforms.Lambda

• 9. transforms.Resize

• 10. transforms.Totensor

• 11. transforms.Normalize

四、transforms的操作

• 1. transforms.RandomChoice

• 2. transforms.RandomApply

• 3. transforms.RandomOrder

train_transform = transforms.Compose([

    transforms.Resize((224, 224)),

    # 1 CenterCrop

    # transforms.CenterCrop(512),     # 512

    # 2 RandomCrop

    # transforms.RandomCrop(224, padding=16),

    # transforms.RandomCrop(224, padding=(16, 64)),

    # transforms.RandomCrop(224, padding=16, fill=(255, 0, 0)),

    # transforms.RandomCrop(512, pad_if_needed=True),   # pad_if_needed=True

    # transforms.RandomCrop(224, padding=64, padding_mode='edge'),

    # transforms.RandomCrop(224, padding=64, padding_mode='reflect'),

    # transforms.RandomCrop(1024, padding=1024, padding_mode='symmetric'),

    # 3 RandomResizedCrop

    # transforms.RandomResizedCrop(size=224, scale=(0.5, 0.5)),

    # 4 FiveCrop

    # transforms.FiveCrop(112),

    # transforms.Lambda(lambda crops: torch.stack([(transforms.ToTensor()(crop)) for crop in crops])),

    # 5 TenCrop

    # transforms.TenCrop(112, vertical_flip=False),

    # transforms.Lambda(lambda crops: torch.stack([(transforms.ToTensor()(crop)) for crop in crops])),

    # 1 Horizontal Flip

    # transforms.RandomHorizontalFlip(p=1),

    # 2 Vertical Flip

    # transforms.RandomVerticalFlip(p=0.5),

    # 3 RandomRotation

    # transforms.RandomRotation(90),

    # transforms.RandomRotation((90), expand=True),

    # transforms.RandomRotation(30, center=(0, 0)),

    # transforms.RandomRotation(30, center=(0, 0), expand=True),   # expand only for center rotation

    transforms.ToTensor(),

    transforms.Normalize(norm_mean, norm_std),

])

4.数据加载-DataLoader


http://www.mrgr.cn/news/80994.html

相关文章:

  • Spring Cloud工程完善
  • aio-pika 快速上手(Python 异步 RabbitMQ 客户端)
  • 基于python+Django+mysql鲜花水果销售商城网站系统设计与实现
  • 电脑重启后vscode快捷方式失效,找不到code.exe
  • deepseek本地部署-linux
  • 【免费】2011-2020年各省互联网宽带接入用户数据
  • 移植 OLLVM 到 LLVM18,修复控制流平坦化报错
  • EdgeX Core Service 核心服务之 Meta Data 元数据
  • 精通Redis(一)
  • SpringBoot Redis 消息队列
  • JWT,OAuth 2.0,Apigee的区别与关系
  • MySQL的详细使用教程
  • .NET重点
  • iOS + watchOS Tourism App(含源码可简单复现)
  • 【Lua热更新】上篇
  • Restaurants WebAPI(三)——Serilog/
  • BenchmarkSQL使用教程
  • 使用RTP 协议 对 H264 封包和解包
  • 使用“NodeMCU”、“红外模块”实现空调控制
  • Day12 梯度下降法的含义与公式
  • php各个版本的特性以及绕过方式
  • 基础电路的学习
  • 在VBA中结合正则表达式和查找功能给文档添加交叉连接
  • 分析excel硕士序列数据提示词——包含对特征的筛选
  • k8s迁移——岁月云实战笔记
  • JWT令牌与微服务