
J1: ResNet-50 Hands-On and Analysis (Bird Recognition)

      • **Theoretical Background**☕
        • 1. The development of CNNs
        • 2. ResNet overview
        • 3. ResNet-50 overview
          • 1. Input -> STAGE 0
          • 2. Residual blocks (STAGE 1 -> STAGE 4)
      • **PyTorch Implementation**
        • 1. Import libraries and set up the GPU
        • 2. Import and inspect the data
        • 3. Split the dataset
        • 4. Build the ResNet-50 model
          • ① Build a single residual block, then assemble the residual groups
          • ② Build the two kinds of residual blocks separately
        • 5. Write the training and test functions
        • 6. Train the model
        • 7. Evaluate the model
        • 8. Predict

  • 🍨 This article is a learning-log post for the 🔗 365-Day Deep Learning Training Camp
  • 🍖 Original author: K同学啊

🍺 This week's tasks

  • Based on the TensorFlow code in the reference article, write the corresponding PyTorch code
  • Understand the residual structure
  • Explore (optionally) whether the residual module can be merged into a C3 module

⛽ My environment

  • Language: Python 3.10.12
  • Editor: Google Colab
  • Deep learning environment:
    • torch==2.4.1+cu121
    • torchvision==0.19.1+cu121

⭐ Reference blogs & articles:

  • J1 - ResNet-50实战
  • 基于resnet的鸟类图片分类实验(pytorch版本)
  • 深度学习 Day21——J1ResNet-50算法实战与解析
  • 进击J1:ResNet-50算法实战与解析
  • CNN算法(一)——残差网络ResNet-50
  • ResNet50网络框架图
  • ResNet-50网络理解
  • ResNet50超详细解析!!!
  • Deep Residual Learning for Image Recognition

Theoretical Background

1. The development of CNNs

[figure]

The figure above (borrowed from the web) lists several milestone convolutional neural networks. Network performance is evaluated along two dimensions: recognition accuracy (vertical axis) and network complexity, i.e. computational cost (horizontal axis). From this chart we can see:

  1. AlexNet: a classic CNN proposed by Alex Krizhevsky, Ilya Sutskever, and Geoffrey Hinton, which won the 2012 ImageNet image-classification competition. AlexNet was the first deep CNN, with 5 convolutional and 3 fully connected layers, and it introduced the ReLU activation, local response normalization, data augmentation, and Dropout.
  2. VGGNet: proposed by the Visual Geometry Group (VGG) at Oxford; it can be seen as a deeper AlexNet. It improves performance through smaller (3×3) kernels and a deeper network. VGG-16, for example, organizes the layers into groups, each stacking a different number of Conv-ReLU layers and ending with a max-pooling layer that shrinks the feature maps.
  3. GoogLeNet: deepens the network (22 layers) while introducing the Inception module, which applies convolution kernels of different sizes in parallel within the same layer to capture features at multiple scales. It later spawned a family of variants (V2, V3, V4, ...).
  4. ResNet: comes in several versions (V1, V2, ResNeXt, ...). It introduced the concept of identity mappings and shortcut (skip) connections, and its modular structure scales conveniently from 18 up to 152 layers.
  5. DenseNet: proposed by Gao Huang et al. in 2016; it strengthens feature propagation by connecting each layer to all preceding layers, improving both performance and parameter efficiency. It features feature reuse, direct inter-layer connections, and a recursively extensible structure.

[figure]

2. ResNet overview

The deep residual network ResNet was proposed by Kaiming He et al. in 2015. Because it is both simple and effective, much subsequent research has been built on top of ResNet-50 or ResNet-101.

  • ResNet mainly addresses the "degradation" problem that appears as convolutional networks get deeper. In an ordinary CNN, the first problem that comes with increased depth is vanishing or exploding gradients:

Vanishing gradients: during backpropagation, as the number of layers grows, the gradients in the early layers become extremely small, close to zero. Weights then update very slowly, making deep networks hard to train effectively.

Exploding gradients: during backpropagation, as the number of layers grows, the gradients become extremely large, so weight updates overshoot, the model fails to converge, and numerical overflow can occur.

The batch normalization (BN) layer, proposed by Ioffe and Szegedy in 2015, largely solved these problems. BN normalizes the outputs of each layer, so gradients keep a stable magnitude as they propagate backward through the layers. However, the ResNet authors found that even with BN, very deep networks were still hard to train, which points to a second problem: accuracy degradation. Beyond a certain depth, accuracy saturates and then drops rapidly. This drop is caused neither by vanishing gradients nor by overfitting; rather, the network becomes so complex that plain, unconstrained training can hardly reach the ideal error rate.

This accuracy degradation is not a flaw of the network architecture itself but of how such networks are trained: none of the widely used optimizers, whether SGD, RMSProp, or Adam, can reach the theoretically optimal solution once the network becomes very deep.
The paper argues that, given a suitable architecture, a deeper network should never do worse than a shallower one. The argument is simple: take a network A and append a few layers to form network B. If the added layers perform only an identity mapping, i.e. A's output passes through them unchanged to become B's output, then A and B have exactly the same error rate, proving that the deepened network need not be worse than the original.
[figure]

He et al. proposed the residual structure shown above to realize this identity mapping. Besides the normal convolutional path, the module has a branch that connects the input directly to the output; the branch output and the convolutional output are added element-wise to produce the final output. In formula form: H(x) = F(x) + x, where x is the input, F(x) is the output of the convolutional branch, and H(x) is the output of the whole block. It is easy to see that if all parameters in the F(x) branch are zero, H(x) is exactly the identity mapping. **The residual structure builds the identity mapping in by hand, which biases the whole block toward converging to an identity mapping and ensures that the error rate will not get worse as depth increases.** A general rule for designing complex networks: if a structure can reach a desired solution simply by setting parameter values by hand, it will also converge to that solution easily through training.
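The identity-mapping claim above (if all parameters of F are zero, H(x) = F(x) + x reduces to the identity) can be checked with a minimal sketch. The `TinyResidual` module below is illustrative only, not the ResNet block built later in this post:

```python
import torch
import torch.nn as nn

# Minimal residual block: H(x) = relu(F(x) + x), with F a single 3x3 convolution
class TinyResidual(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.f = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        return torch.relu(self.f(x) + x)

block = TinyResidual(8)
# Zero out all parameters of the F branch ...
nn.init.zeros_(block.f.weight)
nn.init.zeros_(block.f.bias)
# ... then, for non-negative inputs, the block is exactly the identity
x = torch.rand(1, 8, 4, 4)
print(torch.allclose(block(x), x))  # True
```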

[figure]

In the figure above, the left unit is the two-layer residual unit of ResNet: two 3x3 convolutions with the same number of output channels, used only in the shallower ResNets. Deeper networks mainly use the three-layer residual unit, also known as the bottleneck structure: a 1x1 convolution first reduces the dimensionality, a 3x3 convolution follows, and a final 1x1 convolution restores the original dimensionality. If the input and output dimensions differ, a linear projection can be applied to the input to match dimensions before connecting to the following layers. For the same number of layers, the three-layer unit uses fewer parameters, so the model can be extended deeper. Combining residual units yields the classic ResNet-50, ResNet-101, and similar architectures.

3. ResNet-50 overview

ResNet comes in several architectures, classified mainly by network depth and by the type of residual module used:

  • ResNet-18 and ResNet-34: these use the Basic Block residual module and suit shallower networks. A Basic Block consists of two 3x3 convolutional layers; the first may use stride 2 to downsample.

  • ResNet-50, ResNet-101, and ResNet-152: these deeper architectures use the Bottleneck Block as their residual module.

  • ResNet-50: the mid-depth member of the family, with 50 layers in total. Specifically, it consists of 4 groups of residual blocks containing 3, 4, 6, and 3 blocks respectively, plus the initial standalone convolutional layer and the final fully connected layer.

  • A design detail of ResNet-50: inside each residual block, the first two convolutional layers are each followed by batch normalization and a ReLU activation, while the last convolutional layer is followed by batch normalization only; ReLU is applied again only after the residual addition at the end of the block.

  • ResNet v2: an improved version of ResNet that rearranges the layers into a pre-activation order, with a BN layer and ReLU placed before each convolution, leaving the shortcut path as a clean identity.
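As a quick check on the 50-layer count quoted above, only the convolutional and fully connected layers are counted (BN and pooling layers are not):

```python
# (3 + 4 + 6 + 3) bottleneck blocks, each with 3 convolutions,
# plus the initial 7x7 convolution and the final fully connected layer
blocks_per_stage = [3, 4, 6, 3]
conv_layers = sum(blocks_per_stage) * 3   # 48 convolutions in the residual groups
total_layers = conv_layers + 1 + 1        # + stem conv + fc
print(total_layers)  # 50
```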

[figure]

[figure]

The network consists of the following parts:

  • Initial convolution: a 7×7 convolutional layer, followed by batch normalization and a max-pooling layer.
  • Four residual groups: each contains several residual units built from convolutions, batch normalization, ReLU activations, and so on.
  • Global average pooling: shrinks the feature map to a 1×1 spatial size.
  • Fully connected layer: performs the final classification.

    This structure extracts image features effectively, and thanks to the residual blocks, many more layers can be stacked without noticeable degradation.

[figure]

Above is another, more complete ResNet-50 diagram [source: https://blog.csdn.net/wuqitong123/article/details/132725824?spm=1001.2014.3001.5502]

1. Input -> STAGE 0:
  • The model input (say 3×224×224) first passes through a 7×7 convolution with stride 2 and padding 3, which extracts features. The spatial size becomes (224-7+2×3)/2+1 = 112.5, rounded down to 112, so the final output is 64×112×112.
  • The max-pooling layer then halves the spatial size but leaves the channel count unchanged (output: 64×56×56).
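The size arithmetic above follows the standard formula out = floor((W - K + 2P)/S) + 1; a small helper (the name `conv_out` is ours) reproduces both numbers:

```python
def conv_out(w, k, s, p):
    """Spatial output size of a conv/pool layer: floor((w - k + 2p) / s) + 1."""
    return (w - k + 2 * p) // s + 1

print(conv_out(224, 7, 2, 3))  # 112: the 7x7 stride-2 conv of STAGE 0
print(conv_out(112, 3, 2, 1))  # 56: the 3x3 stride-2 max pool
```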
2. Residual blocks (STAGE 1 -> STAGE 4)

[figure]

Next come the four residual groups, containing 3, 4, 6, and 3 residual blocks respectively. Each block is either a Conv Block or an Identity Block.

Identity Block (BTNK2 in the figure)

BTNK2 has two parameters: C and W.
C: the number of input channels.
W: the input spatial size.
The left path of BTNK2 passes through three convolution blocks (each including BN and ReLU); call its output F(x). F(x) and x are added and then passed through a ReLU activation to produce the BTNK2 output.

[figure]

Conv Block (BTNK1 in the figure)

BTNK1 has four parameters: C, W, C1, and S.

S: the stride of the convolution. When S is 1, the input and output sizes are equal, i.e. no downsampling is performed.

C1: the number of feature maps produced by the convolution, i.e. the output channel count.

C: the number of input channels. C = C1 means the left 1×1 convolution does not reduce the channel count; in the last three stages C = 2×C1, meaning the left 1×1 convolution does reduce it.

W: the input spatial size, i.e. height and width.

BTNK1, unlike BTNK2, handles the case where input and output channel counts differ. On its right path, the input first passes through a convolutional layer that changes the channel count; call its output G(x). G(x) matches the channel count of the left path's output, so F and G can be added and passed through a ReLU activation to produce the BTNK1 output.

ResNet-50 uses the Bottleneck structure in its residual blocks to reduce the number of parameters (several small convolutions replace one large one). Compare the two options for a 256-channel input:

1. A direct 3x3 convolution: the 256-channel input goes through a 3×3 convolution with 256 output channels, giving 256×3×3×256 = 589,824 parameters.

2. A 1x1, then 3x3, then 1x1 convolution: the 256-channel input first passes through a 1×1 convolution down to 64 channels, then a 3×3 convolution with 64 channels, and finally a 1×1 convolution back up to 256 channels, for a total of 256×1×1×64 + 64×3×3×64 + 64×1×1×256 = 69,632 parameters.

The comparison shows that the second option uses far fewer parameters than the first.

PyTorch Implementation

1. Import libraries and set up the GPU
#### Import libraries and required packages
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import torch.nn.functional as F
import copy, random, pathlib
import matplotlib.pyplot as plt
from PIL import Image
from torchsummary import summary
import numpy as np
import warnings
warnings.filterwarnings("ignore")

### Set up the GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cuda')
2. Import and inspect the data
from google.colab import drive
drive.mount("/content/drive/")
%cd "/content/drive/My Drive/Colab Notebooks/jupyter notebook/data/J1"
Mounted at /content/drive/
/content/drive/My Drive/Colab Notebooks/jupyter notebook/data/J1
data_dir = "./bird_photos/"
data_dir = pathlib.Path(data_dir)
data_paths = list(data_dir.glob("*"))
classnames = [str(path).split("/")[1] for path in data_paths]
classnames
['Cockatoo', 'Black Skimmer', 'Bananaquit', 'Black Throated Bushtiti']
img_list = list(data_dir.glob("*/*.jpg"))
count = len(img_list)
print("the total number:",count)
image = Image.open(str(img_list[2]))
print(image.format,image.size,image.mode)
plt.figure(figsize=(2,3))
plt.imshow(image)
plt.axis("off")
plt.show()
the total number: 565
JPEG (224, 224) RGB

[figure]

'''Visualize the data'''
plt.figure(figsize=(16,10))
for i in range(20):
    plt.subplot(4, 5, i + 1)
    plt.axis("off")
    image = random.choice(img_list)
    label_name = image.parts[-2]
    plt.title(label_name)
    plt.imshow(Image.open(str(image)))

[figure]

3. Split the dataset
'''Image transforms'''
train_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
total_data = datasets.ImageFolder(data_dir, transform=train_transforms)
print(total_data)
print(total_data.class_to_idx)
Dataset ImageFolder
    Number of datapoints: 565
    Root location: bird_photos
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )
{'Bananaquit': 0, 'Black Skimmer': 1, 'Black Throated Bushtiti': 2, 'Cockatoo': 3}
'''Split the dataset'''
train_size = int(0.8 * len(total_data))   # training-set size: 80% of the data, cast to int
test_size = len(total_data) - train_size  # test-set size: the remainder
# torch.utils.data.random_split() randomly partitions total_data into two
# subsets of the given sizes [train_size, test_size] and returns them as
# train_ds and test_ds.
train_ds, test_ds = random_split(total_data,[train_size,test_size])
print("train_dataset={}\ntest_dataset={}".format(train_ds, test_ds))
print("train_size={}\ntest_size={}".format(train_size, test_size))
train_dataset=<torch.utils.data.dataset.Subset object at 0x7de94ee1aa10>
test_dataset=<torch.utils.data.dataset.Subset object at 0x7de881193460>
train_size=452
test_size=113
'''Load the data'''
batch_size = 8
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=1)
test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=True, num_workers=1)
for X, y in test_dl:
    print("Shape of X [N, C, H, W]:", X.shape)
    print("Shape of y:", y.shape, y.dtype)
    break
Shape of X [N, C, H, W]: torch.Size([8, 3, 224, 224])
Shape of y: torch.Size([8]) torch.int64
4. Build the ResNet-50 model
  • Code adapted from this blog: https://blog.csdn.net/weixin_46620278/article/details/139161645?spm=1001.2014.3001.5502

  • For a detailed walk-through of each ResNet layer and the input/output dimension calculations (note: it contains some errors), see: https://blog.csdn.net/qq_51256566/article/details/122409854
① Build a single residual block, then assemble the residual groups

Building the basic residual block:

Construct ResNet's basic residual block, which contains three convolutional layers plus a shortcut connection that implements residual learning, making deep networks much easier to train.

import torch.nn.functional as F

# Build the ResNet-50 bottleneck block
class ResNetblock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResNetblock, self).__init__()
        self.blockconv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels * 4, kernel_size=1, stride=1),
            nn.BatchNorm2d(out_channels * 4))
        # Conv Blocks need a projection shortcut to match shapes
        if stride != 1 or in_channels != out_channels * 4:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * 4, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels * 4))

    def forward(self, x):
        residual = x
        out = self.blockconv(x)
        if hasattr(self, 'shortcut'):  # projection branch exists (Conv Block)
            residual = self.shortcut(x)
        out += residual  # F(x) + x
        out = F.relu(out)
        return out
# Sanity-check input/output shapes with a few dummy tensors
# 1. Simulate the first bottleneck of the third residual group (STAGE 3)
inputs = torch.zeros((8, 512, 28, 28))
model = ResNetblock(512, 256, stride=2)
outputs = model(inputs)
print(inputs.shape)
print(outputs.shape)
torch.Size([8, 512, 28, 28])
torch.Size([8, 1024, 14, 14])
# 2. Simulate the first bottleneck of the second residual group (STAGE 2);
#    note that the first conv layer uses stride 2 here
inputs = torch.zeros((8, 256, 56, 56))
model = ResNetblock(256, 128, stride=2)
outputs = model(inputs)
print(inputs.shape)
print(outputs.shape)
torch.Size([8, 256, 56, 56])
torch.Size([8, 512, 28, 28])

Assembling the full ResNet-50:

  • Initial convolution: a 7×7 convolutional layer, a BN layer, a ReLU activation, and a MaxPool layer;
  • Residual layers: four main residual groups, self.layer1 through self.layer4, created with the make_layer method; each differs in output channels and number of blocks. make_layer first builds a stride list strides, whose first element is the given stride and whose remaining elements are 1, then loops to create the requested number of residual blocks, appends them to the layers list, and finally returns them wrapped in an nn.Sequential container;
  • Global average pooling layer;
  • Fully connected layer.
class ResNet50(nn.Module):
    def __init__(self, block, num_classes):
        super(ResNet50, self).__init__()
        # Convolution stem before the residual groups
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d((3, 3), stride=2, padding=1))
        self.in_channels = 64
        # The four residual groups of ResNet-50; each group stacks one Conv Block
        # followed by Identity Blocks (the first block of every group is a Conv Block)
        self.layer1 = self.make_layer(ResNetblock, 64, 3, stride=1)
        self.layer2 = self.make_layer(ResNetblock, 128, 4, stride=2)
        self.layer3 = self.make_layer(ResNetblock, 256, 6, stride=2)
        self.layer4 = self.make_layer(ResNetblock, 512, 3, stride=2)
        self.avgpool = nn.AvgPool2d((7, 7))
        self.fc = nn.Linear(512 * 4, num_classes)  # use num_classes, not the global classnames

    # Build one residual group
    def make_layer(self, block, channels, num_blocks, stride=1):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, channels, stride))
            self.in_channels = channels * 4
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out
model=ResNet50(block=ResNetblock,num_classes=len(classnames)).to(device)
import torchsummary as ts
ts.summary(model,(3,224,224))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
               ...                         ...               ...
        AvgPool2d-157          [-1, 2048, 1, 1]               0
           Linear-158                   [-1, 4]           8,196
================================================================
Total params: 23,542,724
Trainable params: 23,542,724
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 232.38
Params size (MB): 89.81
Estimated Total Size (MB): 322.77
----------------------------------------------------------------
(summary truncated; the full table lists all 158 layers)
② Build the two kinds of residual blocks separately
class ConvBlock(nn.Module):
    def __init__(self, kernel_size, input_channel, output_channel, hidden_channel, stride=2):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channel, hidden_channel, 1, stride=stride),
            nn.BatchNorm2d(hidden_channel), nn.ReLU())
        self.conv2 = nn.Sequential(
            nn.Conv2d(hidden_channel, hidden_channel, kernel_size, padding='same'),
            nn.BatchNorm2d(hidden_channel), nn.ReLU())
        self.conv3 = nn.Sequential(
            nn.Conv2d(hidden_channel, output_channel, 1),
            nn.BatchNorm2d(output_channel))
        self.shortcut = nn.Sequential(
            nn.Conv2d(input_channel, output_channel, 1, stride=stride),
            nn.BatchNorm2d(output_channel))
        self.relu = nn.ReLU()

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x + self.shortcut(inputs)
        x = self.relu(x)
        return x
class IdentityBlock(nn.Module):
    def __init__(self, kernel_size, input_channel, hidden_channel):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channel, hidden_channel, 1),
            nn.BatchNorm2d(hidden_channel), nn.ReLU())
        self.conv2 = nn.Sequential(
            nn.Conv2d(hidden_channel, hidden_channel, 3, padding='same'),
            nn.BatchNorm2d(hidden_channel), nn.ReLU())
        self.conv3 = nn.Sequential(
            nn.Conv2d(hidden_channel, input_channel, 1),
            nn.BatchNorm2d(input_channel))
        self.relu = nn.ReLU()

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.conv2(x)
        x = self.conv3(x)
        x = inputs + x
        x = self.relu(x)
        return x
class ResNet50_2(nn.Module):
    def __init__(self, input_channel, classes):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channel, 64, 7, stride=2, padding=3),
            nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d((3, 3), stride=2, padding=1))
        self.block1 = nn.Sequential(
            ConvBlock(3, 64, 256, 64, 1),
            IdentityBlock(3, 256, 64),
            IdentityBlock(3, 256, 64))
        self.block2 = nn.Sequential(
            ConvBlock(3, 256, 512, 128),
            IdentityBlock(3, 512, 128),
            IdentityBlock(3, 512, 128),
            IdentityBlock(3, 512, 128))
        self.block3 = nn.Sequential(
            ConvBlock(3, 512, 1024, 256),
            IdentityBlock(3, 1024, 256),
            IdentityBlock(3, 1024, 256),
            IdentityBlock(3, 1024, 256),
            IdentityBlock(3, 1024, 256),
            IdentityBlock(3, 1024, 256))
        self.block4 = nn.Sequential(
            ConvBlock(3, 1024, 2048, 512),
            IdentityBlock(3, 2048, 512),
            IdentityBlock(3, 2048, 512))
        self.avgpool = nn.AvgPool2d(7)
        self.fc = nn.Linear(4 * 512, classes)  # use the classes argument instead of a hard-coded 4
        # Note: nn.CrossEntropyLoss expects raw logits, so this Softmax is
        # redundant for training and only mirrors probability-style outputs.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.softmax(x)
        return x
model2=ResNet50_2(3,4).to(device)
ts.summary(model2,(3,224,224))
----------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================Conv2d-1         [-1, 64, 112, 112]           9,472BatchNorm2d-2         [-1, 64, 112, 112]             128ReLU-3         [-1, 64, 112, 112]               0MaxPool2d-4           [-1, 64, 56, 56]               0Conv2d-5           [-1, 64, 56, 56]           4,160BatchNorm2d-6           [-1, 64, 56, 56]             128ReLU-7           [-1, 64, 56, 56]               0Conv2d-8           [-1, 64, 56, 56]          36,928BatchNorm2d-9           [-1, 64, 56, 56]             128ReLU-10           [-1, 64, 56, 56]               0Conv2d-11          [-1, 256, 56, 56]          16,640BatchNorm2d-12          [-1, 256, 56, 56]             512Conv2d-13          [-1, 256, 56, 56]          16,640BatchNorm2d-14          [-1, 256, 56, 56]             512ReLU-15          [-1, 256, 56, 56]               0ConvBlock-16          [-1, 256, 56, 56]               0Conv2d-17           [-1, 64, 56, 56]          16,448BatchNorm2d-18           [-1, 64, 56, 56]             128ReLU-19           [-1, 64, 56, 56]               0Conv2d-20           [-1, 64, 56, 56]          36,928BatchNorm2d-21           [-1, 64, 56, 56]             128ReLU-22           [-1, 64, 56, 56]               0Conv2d-23          [-1, 256, 56, 56]          16,640BatchNorm2d-24          [-1, 256, 56, 56]             512ReLU-25          [-1, 256, 56, 56]               0IdentityBlock-26          [-1, 256, 56, 56]               0Conv2d-27           [-1, 64, 56, 56]          16,448BatchNorm2d-28           [-1, 64, 56, 56]             128ReLU-29           [-1, 64, 56, 56]               0Conv2d-30           [-1, 64, 56, 56]          36,928BatchNorm2d-31           [-1, 64, 56, 56]             128ReLU-32           [-1, 64, 56, 56]               0Conv2d-33          [-1, 256, 56, 56]          16,640BatchNorm2d-34          [-1, 256, 56, 56]             512ReLU-35          [-1, 256, 56, 56]               0IdentityBlock-36          [-1, 256, 56, 56]               0Conv2d-37         
 [-1, 128, 28, 28]          32,896BatchNorm2d-38          [-1, 128, 28, 28]             256ReLU-39          [-1, 128, 28, 28]               0Conv2d-40          [-1, 128, 28, 28]         147,584BatchNorm2d-41          [-1, 128, 28, 28]             256ReLU-42          [-1, 128, 28, 28]               0Conv2d-43          [-1, 512, 28, 28]          66,048BatchNorm2d-44          [-1, 512, 28, 28]           1,024Conv2d-45          [-1, 512, 28, 28]         131,584BatchNorm2d-46          [-1, 512, 28, 28]           1,024ReLU-47          [-1, 512, 28, 28]               0ConvBlock-48          [-1, 512, 28, 28]               0Conv2d-49          [-1, 128, 28, 28]          65,664BatchNorm2d-50          [-1, 128, 28, 28]             256ReLU-51          [-1, 128, 28, 28]               0Conv2d-52          [-1, 128, 28, 28]         147,584BatchNorm2d-53          [-1, 128, 28, 28]             256ReLU-54          [-1, 128, 28, 28]               0Conv2d-55          [-1, 512, 28, 28]          66,048BatchNorm2d-56          [-1, 512, 28, 28]           1,024ReLU-57          [-1, 512, 28, 28]               0IdentityBlock-58          [-1, 512, 28, 28]               0Conv2d-59          [-1, 128, 28, 28]          65,664BatchNorm2d-60          [-1, 128, 28, 28]             256ReLU-61          [-1, 128, 28, 28]               0Conv2d-62          [-1, 128, 28, 28]         147,584BatchNorm2d-63          [-1, 128, 28, 28]             256ReLU-64          [-1, 128, 28, 28]               0Conv2d-65          [-1, 512, 28, 28]          66,048BatchNorm2d-66          [-1, 512, 28, 28]           1,024ReLU-67          [-1, 512, 28, 28]               0IdentityBlock-68          [-1, 512, 28, 28]               0Conv2d-69          [-1, 128, 28, 28]          65,664BatchNorm2d-70          [-1, 128, 28, 28]             256ReLU-71          [-1, 128, 28, 28]               0Conv2d-72          [-1, 128, 28, 28]         147,584BatchNorm2d-73          [-1, 128, 28, 28]             256ReLU-74          [-1, 128, 28, 28]   
            0Conv2d-75          [-1, 512, 28, 28]          66,048BatchNorm2d-76          [-1, 512, 28, 28]           1,024ReLU-77          [-1, 512, 28, 28]               0IdentityBlock-78          [-1, 512, 28, 28]               0Conv2d-79          [-1, 256, 14, 14]         131,328BatchNorm2d-80          [-1, 256, 14, 14]             512ReLU-81          [-1, 256, 14, 14]               0Conv2d-82          [-1, 256, 14, 14]         590,080BatchNorm2d-83          [-1, 256, 14, 14]             512ReLU-84          [-1, 256, 14, 14]               0Conv2d-85         [-1, 1024, 14, 14]         263,168BatchNorm2d-86         [-1, 1024, 14, 14]           2,048Conv2d-87         [-1, 1024, 14, 14]         525,312BatchNorm2d-88         [-1, 1024, 14, 14]           2,048ReLU-89         [-1, 1024, 14, 14]               0ConvBlock-90         [-1, 1024, 14, 14]               0Conv2d-91          [-1, 256, 14, 14]         262,400BatchNorm2d-92          [-1, 256, 14, 14]             512ReLU-93          [-1, 256, 14, 14]               0Conv2d-94          [-1, 256, 14, 14]         590,080BatchNorm2d-95          [-1, 256, 14, 14]             512ReLU-96          [-1, 256, 14, 14]               0Conv2d-97         [-1, 1024, 14, 14]         263,168BatchNorm2d-98         [-1, 1024, 14, 14]           2,048ReLU-99         [-1, 1024, 14, 14]               0IdentityBlock-100         [-1, 1024, 14, 14]               0Conv2d-101          [-1, 256, 14, 14]         262,400BatchNorm2d-102          [-1, 256, 14, 14]             512ReLU-103          [-1, 256, 14, 14]               0Conv2d-104          [-1, 256, 14, 14]         590,080BatchNorm2d-105          [-1, 256, 14, 14]             512ReLU-106          [-1, 256, 14, 14]               0Conv2d-107         [-1, 1024, 14, 14]         263,168BatchNorm2d-108         [-1, 1024, 14, 14]           2,048ReLU-109         [-1, 1024, 14, 14]               0IdentityBlock-110         [-1, 1024, 14, 14]               0Conv2d-111          [-1, 256, 14, 14]        
 262,400BatchNorm2d-112          [-1, 256, 14, 14]             512ReLU-113          [-1, 256, 14, 14]               0Conv2d-114          [-1, 256, 14, 14]         590,080BatchNorm2d-115          [-1, 256, 14, 14]             512ReLU-116          [-1, 256, 14, 14]               0Conv2d-117         [-1, 1024, 14, 14]         263,168BatchNorm2d-118         [-1, 1024, 14, 14]           2,048ReLU-119         [-1, 1024, 14, 14]               0IdentityBlock-120         [-1, 1024, 14, 14]               0Conv2d-121          [-1, 256, 14, 14]         262,400BatchNorm2d-122          [-1, 256, 14, 14]             512ReLU-123          [-1, 256, 14, 14]               0Conv2d-124          [-1, 256, 14, 14]         590,080BatchNorm2d-125          [-1, 256, 14, 14]             512ReLU-126          [-1, 256, 14, 14]               0Conv2d-127         [-1, 1024, 14, 14]         263,168BatchNorm2d-128         [-1, 1024, 14, 14]           2,048ReLU-129         [-1, 1024, 14, 14]               0IdentityBlock-130         [-1, 1024, 14, 14]               0Conv2d-131          [-1, 256, 14, 14]         262,400BatchNorm2d-132          [-1, 256, 14, 14]             512ReLU-133          [-1, 256, 14, 14]               0Conv2d-134          [-1, 256, 14, 14]         590,080BatchNorm2d-135          [-1, 256, 14, 14]             512ReLU-136          [-1, 256, 14, 14]               0Conv2d-137         [-1, 1024, 14, 14]         263,168BatchNorm2d-138         [-1, 1024, 14, 14]           2,048ReLU-139         [-1, 1024, 14, 14]               0IdentityBlock-140         [-1, 1024, 14, 14]               0Conv2d-141            [-1, 512, 7, 7]         524,800BatchNorm2d-142            [-1, 512, 7, 7]           1,024ReLU-143            [-1, 512, 7, 7]               0Conv2d-144            [-1, 512, 7, 7]       2,359,808BatchNorm2d-145            [-1, 512, 7, 7]           1,024ReLU-146            [-1, 512, 7, 7]               0Conv2d-147           [-1, 2048, 7, 7]       1,050,624BatchNorm2d-148           
[-1, 2048, 7, 7]           4,096Conv2d-149           [-1, 2048, 7, 7]       2,099,200BatchNorm2d-150           [-1, 2048, 7, 7]           4,096ReLU-151           [-1, 2048, 7, 7]               0ConvBlock-152           [-1, 2048, 7, 7]               0Conv2d-153            [-1, 512, 7, 7]       1,049,088BatchNorm2d-154            [-1, 512, 7, 7]           1,024ReLU-155            [-1, 512, 7, 7]               0Conv2d-156            [-1, 512, 7, 7]       2,359,808BatchNorm2d-157            [-1, 512, 7, 7]           1,024ReLU-158            [-1, 512, 7, 7]               0Conv2d-159           [-1, 2048, 7, 7]       1,050,624BatchNorm2d-160           [-1, 2048, 7, 7]           4,096ReLU-161           [-1, 2048, 7, 7]               0IdentityBlock-162           [-1, 2048, 7, 7]               0Conv2d-163            [-1, 512, 7, 7]       1,049,088BatchNorm2d-164            [-1, 512, 7, 7]           1,024ReLU-165            [-1, 512, 7, 7]               0Conv2d-166            [-1, 512, 7, 7]       2,359,808BatchNorm2d-167            [-1, 512, 7, 7]           1,024ReLU-168            [-1, 512, 7, 7]               0Conv2d-169           [-1, 2048, 7, 7]       1,050,624BatchNorm2d-170           [-1, 2048, 7, 7]           4,096ReLU-171           [-1, 2048, 7, 7]               0IdentityBlock-172           [-1, 2048, 7, 7]               0AvgPool2d-173           [-1, 2048, 1, 1]               0Linear-174                    [-1, 4]           8,196Softmax-175                    [-1, 4]               0
================================================================
Total params: 23,542,788
Trainable params: 23,542,788
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 274.49
Params size (MB): 89.81
Estimated Total Size (MB): 364.88
----------------------------------------------------------------
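The per-layer counts in the summary above can be verified by hand from the standard parameter formulas (weights plus bias). A minimal sketch in plain Python; `conv2d_params` and `linear_params` are illustrative helpers, not part of the model code:

```python
def conv2d_params(c_in, c_out, k, bias=True):
    """Parameter count of a Conv2d layer: c_out * c_in * k * k weights + optional bias."""
    return c_out * c_in * k * k + (c_out if bias else 0)

def linear_params(n_in, n_out, bias=True):
    """Parameter count of a fully connected layer."""
    return n_out * n_in + (n_out if bias else 0)

# These match the summary: the 1x1 reduce conv (1024 -> 256), the 3x3 conv
# (256 -> 256), and the final 4-class classifier on 2048 features.
print(conv2d_params(1024, 256, 1))  # 262400
print(conv2d_params(256, 256, 3))   # 590080
print(linear_params(2048, 4))       # 8196
```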
5、Writing the training and test functions
'''
Training function
'''
def train(dataloader, model, optimizer, loss_fn):
    size = len(dataloader.dataset)   # size of the training set
    num_batches = len(dataloader)    # number of batches
    train_acc, train_loss = 0, 0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()

    train_loss /= num_batches
    train_acc /= size
    return train_acc, train_loss
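The accuracy bookkeeping above counts, per batch, how many `argmax` predictions match the labels. The same logic can be illustrated without torch; `batch_correct` below is a hypothetical helper on plain lists, not part of the training code:

```python
def batch_correct(logits, labels):
    """Count correct predictions: argmax over each row, then compare to labels."""
    preds = [row.index(max(row)) for row in logits]
    return sum(int(p == t) for p, t in zip(preds, labels))

logits = [[0.1, 0.7, 0.2],   # predicted class 1
          [0.8, 0.1, 0.1],   # predicted class 0
          [0.2, 0.3, 0.5]]   # predicted class 2
labels = [1, 2, 2]
print(batch_correct(logits, labels))  # 2 correct out of 3
```

Dividing the running total by the dataset size, as `train()` does at the end, turns this count into an accuracy.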
'''
Test function
'''
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)   # size of the test set
    num_batches = len(dataloader)    # number of batches (ceil(size / batch_size))
    test_loss, test_acc = 0, 0

    # no gradient updates during evaluation; saves memory and compute
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)

            # compute loss
            target_pred = model(imgs)
            loss = loss_fn(target_pred, target)

            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc /= size
    test_loss /= num_batches
    return test_acc, test_loss
6、Training
'''
Loss function and learning rate
'''
loss_fn = nn.CrossEntropyLoss()   # cross-entropy loss
learn_rate = 1e-5
optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.98)
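With `StepLR(step_size=5, gamma=0.98)`, the learning rate after n calls to `scheduler.step()` is `learn_rate * 0.98 ** (n // 5)`, which is exactly the pattern visible in the `Lr` column of the training log below. A quick sanity check of that arithmetic in plain Python (no torch needed):

```python
def steplr_lr(base_lr, gamma, step_size, num_steps):
    """Learning rate produced by StepLR: decay by `gamma` once every `step_size` steps."""
    return base_lr * gamma ** (num_steps // step_size)

# Matches the log: epochs 1-4 keep the base lr, epoch 5 is the first decay,
# and by epoch 50 the lr has been decayed ten times.
print(f"{steplr_lr(1e-5, 0.98, 5, 4):.2E}")   # 1.00E-05
print(f"{steplr_lr(1e-5, 0.98, 5, 5):.2E}")   # 9.80E-06
print(f"{steplr_lr(1e-5, 0.98, 5, 50):.2E}")  # 8.17E-06
```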
'''
Training loop
'''
epochs = 50
train_loss = []
train_acc = []
test_loss = []
test_acc = []
best_acc = 0

# start training
for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, optimizer, loss_fn)
    scheduler.step()

    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    # keep a snapshot of the best model seen so far
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        best_model = copy.deepcopy(model)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    lr = optimizer.state_dict()['param_groups'][0]['lr']
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, '
                'Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
    print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss,
                          epoch_test_acc * 100, epoch_test_loss, lr))

print('Done')
Epoch: 1, Train_acc:31.6%, Train_loss:1.404, Test_acc:51.3%, Test_loss:1.302, Lr:1.00E-05
Epoch: 2, Train_acc:49.1%, Train_loss:1.249, Test_acc:50.4%, Test_loss:1.231, Lr:1.00E-05
Epoch: 3, Train_acc:58.4%, Train_loss:1.132, Test_acc:65.5%, Test_loss:1.144, Lr:1.00E-05
Epoch: 4, Train_acc:71.0%, Train_loss:0.920, Test_acc:58.4%, Test_loss:1.009, Lr:1.00E-05
Epoch: 5, Train_acc:70.4%, Train_loss:0.773, Test_acc:69.9%, Test_loss:0.837, Lr:9.80E-06
Epoch: 6, Train_acc:75.7%, Train_loss:0.666, Test_acc:73.5%, Test_loss:0.840, Lr:9.80E-06
Epoch: 7, Train_acc:72.3%, Train_loss:0.697, Test_acc:57.5%, Test_loss:0.998, Lr:9.80E-06
Epoch: 8, Train_acc:79.2%, Train_loss:0.541, Test_acc:66.4%, Test_loss:0.762, Lr:9.80E-06
Epoch: 9, Train_acc:83.0%, Train_loss:0.504, Test_acc:69.9%, Test_loss:0.721, Lr:9.80E-06
Epoch:10, Train_acc:85.4%, Train_loss:0.428, Test_acc:78.8%, Test_loss:0.594, Lr:9.60E-06
Epoch:11, Train_acc:82.7%, Train_loss:0.483, Test_acc:65.5%, Test_loss:0.971, Lr:9.60E-06
Epoch:12, Train_acc:86.1%, Train_loss:0.388, Test_acc:78.8%, Test_loss:0.617, Lr:9.60E-06
Epoch:13, Train_acc:85.8%, Train_loss:0.390, Test_acc:63.7%, Test_loss:1.206, Lr:9.60E-06
Epoch:14, Train_acc:88.1%, Train_loss:0.343, Test_acc:77.9%, Test_loss:0.695, Lr:9.60E-06
Epoch:15, Train_acc:89.6%, Train_loss:0.333, Test_acc:74.3%, Test_loss:0.658, Lr:9.41E-06
Epoch:16, Train_acc:89.8%, Train_loss:0.286, Test_acc:77.0%, Test_loss:0.603, Lr:9.41E-06
Epoch:17, Train_acc:89.4%, Train_loss:0.297, Test_acc:76.1%, Test_loss:0.680, Lr:9.41E-06
Epoch:18, Train_acc:88.9%, Train_loss:0.304, Test_acc:77.0%, Test_loss:0.879, Lr:9.41E-06
Epoch:19, Train_acc:92.9%, Train_loss:0.249, Test_acc:77.0%, Test_loss:0.708, Lr:9.41E-06
Epoch:20, Train_acc:92.3%, Train_loss:0.239, Test_acc:75.2%, Test_loss:0.760, Lr:9.22E-06
Epoch:21, Train_acc:90.7%, Train_loss:0.236, Test_acc:70.8%, Test_loss:1.053, Lr:9.22E-06
Epoch:22, Train_acc:93.4%, Train_loss:0.189, Test_acc:81.4%, Test_loss:0.521, Lr:9.22E-06
Epoch:23, Train_acc:93.6%, Train_loss:0.179, Test_acc:77.9%, Test_loss:0.749, Lr:9.22E-06
Epoch:24, Train_acc:96.7%, Train_loss:0.151, Test_acc:76.1%, Test_loss:0.539, Lr:9.22E-06
Epoch:25, Train_acc:93.1%, Train_loss:0.196, Test_acc:73.5%, Test_loss:0.693, Lr:9.04E-06
Epoch:26, Train_acc:93.8%, Train_loss:0.192, Test_acc:82.3%, Test_loss:0.539, Lr:9.04E-06
Epoch:27, Train_acc:95.8%, Train_loss:0.154, Test_acc:77.0%, Test_loss:0.660, Lr:9.04E-06
Epoch:28, Train_acc:94.9%, Train_loss:0.150, Test_acc:78.8%, Test_loss:0.557, Lr:9.04E-06
Epoch:29, Train_acc:95.8%, Train_loss:0.139, Test_acc:81.4%, Test_loss:0.628, Lr:9.04E-06
Epoch:30, Train_acc:95.1%, Train_loss:0.154, Test_acc:80.5%, Test_loss:1.146, Lr:8.86E-06
Epoch:31, Train_acc:93.6%, Train_loss:0.144, Test_acc:67.3%, Test_loss:1.120, Lr:8.86E-06
Epoch:32, Train_acc:94.9%, Train_loss:0.145, Test_acc:73.5%, Test_loss:0.674, Lr:8.86E-06
Epoch:33, Train_acc:95.1%, Train_loss:0.139, Test_acc:77.0%, Test_loss:0.754, Lr:8.86E-06
Epoch:34, Train_acc:95.8%, Train_loss:0.134, Test_acc:74.3%, Test_loss:0.787, Lr:8.86E-06
Epoch:35, Train_acc:97.1%, Train_loss:0.096, Test_acc:77.9%, Test_loss:0.610, Lr:8.68E-06
Epoch:36, Train_acc:97.6%, Train_loss:0.077, Test_acc:80.5%, Test_loss:0.693, Lr:8.68E-06
Epoch:37, Train_acc:95.6%, Train_loss:0.121, Test_acc:82.3%, Test_loss:0.639, Lr:8.68E-06
Epoch:38, Train_acc:98.2%, Train_loss:0.078, Test_acc:72.6%, Test_loss:0.711, Lr:8.68E-06
Epoch:39, Train_acc:98.2%, Train_loss:0.083, Test_acc:72.6%, Test_loss:0.822, Lr:8.68E-06
Epoch:40, Train_acc:96.5%, Train_loss:0.126, Test_acc:77.9%, Test_loss:1.102, Lr:8.51E-06
Epoch:41, Train_acc:96.9%, Train_loss:0.135, Test_acc:85.0%, Test_loss:0.556, Lr:8.51E-06
Epoch:42, Train_acc:95.8%, Train_loss:0.150, Test_acc:71.7%, Test_loss:1.173, Lr:8.51E-06
Epoch:43, Train_acc:95.1%, Train_loss:0.120, Test_acc:80.5%, Test_loss:0.793, Lr:8.51E-06
Epoch:44, Train_acc:96.2%, Train_loss:0.121, Test_acc:77.0%, Test_loss:0.762, Lr:8.51E-06
Epoch:45, Train_acc:97.8%, Train_loss:0.089, Test_acc:80.5%, Test_loss:0.600, Lr:8.34E-06
Epoch:46, Train_acc:98.2%, Train_loss:0.080, Test_acc:77.9%, Test_loss:0.626, Lr:8.34E-06
Epoch:47, Train_acc:96.9%, Train_loss:0.078, Test_acc:80.5%, Test_loss:0.706, Lr:8.34E-06
Epoch:48, Train_acc:98.5%, Train_loss:0.063, Test_acc:77.0%, Test_loss:0.831, Lr:8.34E-06
Epoch:49, Train_acc:98.0%, Train_loss:0.063, Test_acc:77.9%, Test_loss:0.723, Lr:8.34E-06
Epoch:50, Train_acc:97.1%, Train_loss:0.090, Test_acc:84.1%, Test_loss:0.597, Lr:8.17E-06
Done
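Note that the loop above saves the best checkpoint with `copy.deepcopy(model)` rather than `best_model = model`; plain assignment would only store a reference that keeps changing as training continues, so the "best" model would silently become the last one. A toy illustration, with a dict standing in for the model and hypothetical accuracy values:

```python
import copy

best_acc, best_model = 0.0, None
model = {"step": 0}                        # stand-in for a real nn.Module
for acc in [0.513, 0.504, 0.655, 0.584]:   # per-epoch test accuracies
    model["step"] += 1                     # "training" mutates the model in place
    if acc > best_acc:
        best_acc = acc
        best_model = copy.deepcopy(model)  # snapshot of the state at this point

print(best_acc)            # 0.655
print(best_model["step"])  # 3, not 4: later updates don't leak into the snapshot
```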
7、Model evaluation
# result visualization
import warnings
warnings.filterwarnings("ignore")             # suppress warnings
plt.rcParams['axes.unicode_minus'] = False    # display minus signs correctly
plt.rcParams['figure.dpi'] = 100              # figure resolution

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

(Figure: training/validation accuracy and loss curves)

8、Prediction
classes = list(total_data.class_to_idx)

# inverse of the normalization applied in the data transforms
inv_normalize = transforms.Normalize(
    mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
    std=[1/0.229, 1/0.224, 1/0.225]
)

for images, _ in test_dl:
    break  # grab one batch of images

# predict
with torch.no_grad():
    outputs = best_model(images.to(device))
    _, preds = torch.max(outputs, 1)

# visualize
plt.figure(figsize=(10, 4))
for i in range(8):
    plt.subplot(2, 4, i + 1)
    img = images[i].cpu()
    img = inv_normalize(img)    # undo normalization for display
    img = img.permute(1, 2, 0)  # convert to [H, W, C]
    plt.imshow(img)
    plt.title(classes[preds[i].item()])
    plt.axis('off')
plt.show()
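The odd-looking mean/std in `inv_normalize` come from inverting `Normalize(mean, std)`: since Normalize computes `(x - m) / s`, applying a second Normalize with `mean = -m/s` and `std = 1/s` gives `((x - m)/s + m/s) * s = x`, recovering the original pixel. A one-channel check of that arithmetic in plain Python:

```python
m, s = 0.485, 0.229   # per-channel stats used when normalizing
x = 0.3               # original pixel value

normed = (x - m) / s                      # what transforms.Normalize does
restored = (normed - (-m / s)) / (1 / s)  # the inv_normalize trick above

print(abs(restored - x) < 1e-12)  # True: the pixel is recovered
```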

(Figure: predicted class labels on eight test images)

