当前位置: 首页 > news >正文

使用Vision Transformer进行图像分类

一、背景

在上一篇文章《下载数据集用于图像分类并自动分为训练集和测试集方法》,我们已经下载好了花分类数据集,并自动将其分为训练集和测试集。接下来我们训练Vision Transformer使其可以正确分类我们的花数据集。

二、环境配置

系统:Windows 11

使用上一篇文章创建的conda环境,conda名称为Vit。
Anaconda3
python3.8
pycharm(IDE)

然后在该conda环境下安装pytorch以及相关依赖,具体操作如下:

1、安装PytTorch
(1)浏览器打开PytTorch网站:https://pytorch.org/

鼠标往下翻,翻到安装界面,选择安装之前的pytorch版本。
在这里插入图片描述

因为windows系统下安装cuda以及cudnn太麻烦,所以我选择GPU版本,如下图所示:

在这里插入图片描述

# 打开Anaconda Prompt
conda activate Vit
conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cpuonly -c pytorch

注意:作者WZMIAOMIAO要求的版本为pytorch 1.10 (pip package)、torchvision 0.11.1 (pip package)、tensorflow 2.4.1 (pip package),我因为没找到torchvision 0.11.1,所以安装了torchvision==0.11.2 ,经验证,程序依然跑的通,问题不大。

(2)打开pycharm,新建项目deep-learning-for-image-processing-master,并设置环境为我们创建的Vit conda环境。

pycharm设置conda环境可参考我之前的博客《Windows系统配置Anaconda虚拟环境,并安装Numpy、Scipy和Matplotlib等模块方法》

在这里插入图片描述

(3)若后续运行train.py程序时报错AttributeError: module ‘distutils’ has no attribute ‘version’.

原因:setuptools版本过高
解决办法:安装低版本setuptools

conda activate Vit
pip uninstall setuptools
pip install setuptools==59.5.0

如果通过上面一系列更新/固定setuptools包版本,依然报错:AttributeError: module ‘distutils’ has no attribute ‘version’. 则进行以下操作,亲测有效,那就是修改init.py文件,参考《AttributeError: module ‘distutils‘ has no attribute ‘version‘》

修改init.py文件的方法如下:
init.py文件位置:C:\Users\13679\.conda\envs\Vit\Lib\site-packages\torch\utils\tensorboard
a.将from setuptools import distutils替换成from distutils.version import LooseVersion
b.注释LooseVersion = distutils.version.LooseVersio
c.注释del distutils
因此,修改后完整的init.py文件如下:

import tensorboard
# from setuptools import distutils
from distutils.version import LooseVersion# LooseVersion = distutils.version.LooseVersionif not hasattr(tensorboard, '__version__') or LooseVersion(tensorboard.__version__) < LooseVersion('1.15'):raise ImportError('TensorBoard logging requires TensorBoard version 1.15 or above')# del distutils
del LooseVersion
del tensorboardfrom .writer import FileWriter, SummaryWriter  # noqa: F401
from tensorboard.summary.writer.record_writer import RecordWriter  # noqa: F401
(4)若后续运行train.py程序时报错ModuleNotFoundError: No module named ‘tqdm’

tqdm是一个快速、可扩展的Python进度条库,用于展示迭代器的长循环执行进度。
解决办法:通过以下命令安装

conda install tqdm
(5)若后续运行train.py程序时报错ModuleNotFoundError: No module named ‘matplotlib’

解决办法:通过以下命令安装

conda install matplotlib

三、训练花分类数据集

经过一系列环境配置,及bug调试,我们终于能成功运行训练花分类数据集的train.py文件了,这里记录下详细过程,以及注意事项。

千万注意:不要着急运行train.py,不要着急运行train.py,不要着急运行train.py,因为训练新的数据集,之前训练的权重文件model-9.pth以及分类文件class_indices.json会被覆盖,如果没有提前保存,之前花费很长时间训练的结果则会丢失。

因此,训练结束后,我会把权重文件model-9.pth和分类文件class_indices.json另存在weightsSave和jsonSave文件夹中。

总结写在前面:使用train.py训练新的数据集只需要修改两个地方:①数据集路径–data-path以及②Vit的预训练权重–weights(预训练权重第一次设置完成,后面也很少修改),所以通常只需修改数据集路径即可。具体步骤如下:

1 在train.py脚本中将–data-path设置成解压后的flower_photos文件夹绝对路径
 parser.add_argument('--data-path', type=str,default="E:/manipulator_programming/ViT/deep-learning-for-image-processing-master/data_set/flower_data/flower_photos")

注意:因为是在windows系统运行的,所以路径的层级结构使用/符号,而不是\符号,不然会报错。

2 下载预训练权重,在vit_model.py文件中每个模型都有提供预训练权重的下载地址,根据自己使用的模型下载对应预训练权重

我这里考虑训练时间的问题,选择了最基础的预训练权重vit_base_patch16_224_in21k,下载链接:https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth,下载完成后,需要将其重命名为vit_base_patch16_224_in21k.pth

3 在train.py脚本中将–weights参数设成下载好的预训练权重路径
    parser.add_argument('--weights', type=str, default='E:/manipulator_programming/ViT/deep-learning-for-image-processing-master/pytorch_classification/vision_transformer/vit_base_patch16_224_in21k.pth',help='initial weights path')

设置好数据集的路径–data-path以及预训练权重的路径–weights就能使用train.py脚本开始训练了(训练过程中会自动生成权重文件model-0.pth-model-9.pth以及class_indices.json文件)

因为每个epoch训练都会生成一个权重文件,代码使用了10个epoch,所以生成了从model-0.pth到model-9.pth的10个权重文件。

下图是训练后的结果。

在这里插入图片描述

为了防止自己把train.py修改的面目全非,这里保存下train.py的完整代码

import os
import math
import argparseimport torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.tensorboard import SummaryWriter
from torchvision import transformsfrom my_dataset import MyDataSet
from vit_model import vit_base_patch16_224_in21k as create_model
from utils import read_split_data, train_one_epoch, evaluatedef main(args):device = torch.device(args.device if torch.cuda.is_available() else "cpu")if os.path.exists("./weights") is False:os.makedirs("./weights")tb_writer = SummaryWriter()train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),"val": transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}# 实例化训练数据集train_dataset = MyDataSet(images_path=train_images_path,images_class=train_images_label,transform=data_transform["train"])# 实例化验证数据集val_dataset = MyDataSet(images_path=val_images_path,images_class=val_images_label,transform=data_transform["val"])batch_size = args.batch_sizenw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,pin_memory=True,num_workers=nw,collate_fn=train_dataset.collate_fn)val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=batch_size,shuffle=False,pin_memory=True,num_workers=nw,collate_fn=val_dataset.collate_fn)model = create_model(num_classes=args.num_classes, has_logits=False).to(device)if args.weights != "":assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)weights_dict = torch.load(args.weights, map_location=device)# 删除不需要的权重del_keys = ['head.weight', 'head.bias'] if model.has_logits \else ['pre_logits.fc.weight', 'pre_logits.fc.bias', 'head.weight', 'head.bias']for k in del_keys:del weights_dict[k]print(model.load_state_dict(weights_dict, strict=False))if args.freeze_layers:for name, para in model.named_parameters():# 除head, pre_logits外,其他权重全部冻结if "head" not in name and "pre_logits" not in name:para.requires_grad_(False)else:print("training {}".format(name))pg = [p for p in model.parameters() if p.requires_grad]optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=5E-5)# Scheduler https://arxiv.org/pdf/1812.01187.pdflf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosinescheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)for epoch in range(args.epochs):# traintrain_loss, train_acc = train_one_epoch(model=model,optimizer=optimizer,data_loader=train_loader,device=device,epoch=epoch)scheduler.step()# validateval_loss, val_acc = evaluate(model=model,data_loader=val_loader,device=device,epoch=epoch)tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]tb_writer.add_scalar(tags[0], train_loss, epoch)tb_writer.add_scalar(tags[1], train_acc, epoch)tb_writer.add_scalar(tags[2], val_loss, epoch)tb_writer.add_scalar(tags[3], val_acc, epoch)tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--num_classes', type=int, default=5)parser.add_argument('--epochs', type=int, default=10)parser.add_argument('--batch-size', type=int, default=8)parser.add_argument('--lr', type=float, default=0.001)parser.add_argument('--lrf', type=float, default=0.01)# 数据集所在根目录# https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgzparser.add_argument('--data-path', type=str,default="E:/manipulator_programming/ViT/deep-learning-for-image-processing-master/data_set/flower_data/flower_photos")#parser.add_argument('--data-path', type=str,#default="E:/manipulator_programming/ViT/deep-learning-for-image-processing-master/data_set/HumanIntention_data/Marker_photos")parser.add_argument('--model-name', default='', help='create model name')# 预训练权重路径,如果不想载入就设置为空字符parser.add_argument('--weights', type=str, default='E:/manipulator_programming/ViT/deep-learning-for-image-processing-master/pytorch_classification/vision_transformer/vit_base_patch16_224_in21k.pth',help='initial weights path')# 是否冻结权重parser.add_argument('--freeze-layers', type=bool, default=True)parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')opt = parser.parse_args()main(opt)

四、使用predict.py脚本测试训练效果

写在前面:使用predict.py脚本测试图片分类效果只需要修改三个地方:①在predict.py脚本中导入和训练脚本中同样的模型,即from vit_model import vit_base_patch16_224_in21k as create_model,但这一步基本不用修改。②设置权重文件的路径model_weight_path。③设置需要测试的图片的绝对路径img_path。

通常只修改②权重文件和③需要预测图片的路径即可,具体步骤如下:

1 在predict.py脚本中导入和训练脚本中同样的模型,并将model_weight_path设置成训练好的模型权重路径(默认保存在weights文件夹下)
model_weight_path = "./weightsSave/model-9Flower.pth"

针对花分类权重,我这里改名为model-9Flower.pth并另外在weightsSave文件夹,防止训练其他数据集时被覆盖。

2 在predict.py脚本中将img_path设置成你自己需要预测的图片绝对路径
img_path = "E:/manipulator_programming/ViT/deep-learning-for-image-processing-master/data_set/flower_data/flower_photos/roses/159079265_d77a9ac920_n.jpg"

设置好权重路径model_weight_path和预测的图片路径img_path就能使用predict.py脚本进行预测了。

同样为了防止自己将原始文件修改的面目全非,这里保存在predict.py的完整代码

import os
import jsonimport torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as pltfrom vit_model import vit_base_patch16_224_in21k as create_modeldef main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")data_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])# load imageimg_path = "E:/manipulator_programming/ViT/deep-learning-for-image-processing-master/data_set/flower_data/flower_photos/roses/159079265_d77a9ac920_n.jpg"assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)img = Image.open(img_path)plt.imshow(img)# [N, C, H, W]img = data_transform(img)# expand batch dimensionimg = torch.unsqueeze(img, dim=0)# read class_indictjson_path = './jsonSave/class_indicesFlower.json'assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)with open(json_path, "r") as f:class_indict = json.load(f)# create modelmodel = create_model(num_classes=5, has_logits=False).to(device)# load model weightsmodel_weight_path = "./weightsSave/model-9Flower.pth"model.load_state_dict(torch.load(model_weight_path, map_location=device))model.eval()with torch.no_grad():# predict classoutput = torch.squeeze(model(img.to(device))).cpu()predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy()print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],predict[predict_cla].numpy())plt.title(print_res)for i in range(len(predict)):print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],predict[i].numpy()))plt.show()if __name__ == '__main__':main()

五、预测结果展示

1、蒲公英dandelion,60%的概率是。
在这里插入图片描述
2、玫瑰roses,94.9%的概率是,效果显著。

在这里插入图片描述

六、总结

如果要使用自己的数据集,请按照花分类数据集的文件结构进行摆放(即一个类别对应一个文件夹),
并且将训练以及预测脚本中的num_classes设置成你自己数据的类别数。

小Tip:圈数字
① ② ③ ④ ⑤ ⑥ ⑦ ⑧ ⑨ ⑩
⑪ ⑫ ⑬ ⑭ ⑮ ⑯ ⑰ ⑱ ⑲ ⑳


http://www.mrgr.cn/news/61166.html

相关文章:

  • SAP消息号 V1599 对于项目 000010 无法确定业务区域
  • 【CSS】HTML页面定位CSS - position 属性 relative 、absolute、fixed 、sticky
  • 计算机网络 (42)远程终端协议TELNET
  • Windows 安装 Docker 和 Docker Compose
  • 【python】OpenCV—Local Translation Warps
  • docker安装mysql详细教程
  • Vue.js(2) 入门指南:从基础知识到核心功能
  • 【动态三维重建】MonST3R:运动中的几何估计
  • 【专题】2024中国B2B市场营销现况白皮书报告汇总PDF洞察(附原数据表)
  • DevEco Studio的使用 习题答案<HarmonyOS第一课>
  • 【射频器件】QPM1017 QPM2102 QPC1031D QPC7333 QPC7339 QM45398 QM14068- Qorvo特点、及应用
  • RTT工具学习
  • signal() -函数的详细使用说明
  • 树莓集团:以数字化平台为基,构建智慧园区生态体系
  • UI设计软件全景:13款工具助力创意实现
  • Spring Boot 经典九设计模式全览
  • AI绘画:SD3.5来了,Flux又不行了?
  • 【JAVA】第四张_Eclipse创建Maven项目
  • 【PUCCH——Format和资源集】
  • SpringBoot 定时任务 @Scheduled 详细解析
  • CentOS7安装Docker-2024
  • 软考信息系统监理师 高分必背
  • 3. 无重复字符的最长子串
  • 虚拟现实辅助工程技术助力航空航天高端制造业破局
  • 3211、生成不含相邻零的二进制字符串-cangjie
  • 富格林:曝光可信经验击败陷阱