
Week J8: Inception v1 — Hands-On Practice and Analysis

>- **🍨 This post is a learning-record blog for the [🔗365天深度学习训练营 (365-Day Deep Learning Training Camp)]**
>- **🍖 Original author: [K同学啊]**

📌 This week's tasks:
1. Understand how the operation counts of the convolutional layers in Figure 2 are computed (🏐 background knowledge -> computing the operation count of a convolutional layer; my derivation is included — try working it out by hand first, then check).
2. Understand the parallel convolution structure and the 1x1 convolution kernel (the key point).
3. Write the corresponding PyTorch code from the model architecture diagram, and use Inception v1 to recognize monkeypox.

🏡 My environment:

  • Language: Python 3.8
  • Editor: Jupyter Notebook
  • Deep learning framework: PyTorch
    • torch==2.3.1+cu118
    • torchvision==0.18.1+cu118

I. Inception v1

Inception v1 paper: Going Deeper with Convolutions

1. Background

GoogLeNet debuted in the 2014 ILSVRC competition, where it took first place; this version is commonly called Inception V1. The network is 22 layers deep with roughly 5M parameters. VGGNet, from the same period, performed about as well as Inception V1 but with a far larger parameter count.

The Inception Module is the core building block of Inception V1. It introduced a parallel convolution structure, allowing different kinds of features to be extracted within a single layer, as shown in the figure below.

Deepening the network by stacking such modules improves performance, but runs into heavy computation (too many parameters). To ease this, the Inception Module borrows the Network-in-Network idea and uses 1x1 convolutions for dimensionality reduction (which also indirectly deepens the network), cutting the parameter count and computation, as shown in figure (b) above.

Worked example: suppose the previous layer outputs 100x100x128. A convolutional layer with 256 kernels of size 5x5 (stride=1, pad=2) produces a 100x100x256 output and has 5x5x128x256+256 parameters. If the input instead first passes through a layer of 32 1x1 kernels (the 1x1 convolution reduces the channel count while leaving the feature-map size unchanged) and then through the layer of 256 5x5 kernels, the final output is still 100x100x256, but the parameter count falls to 128x1x1x32+32 + 32x5x5x256+256 — roughly a quarter of the original. The multiply-accumulate count drops from about 8.192 × 10^9 to about 2.048 × 10^9; for the full derivation see my training-camp post on computing the operation count of a convolutional layer.
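A quick sanity check of the arithmetic above, as a minimal Python sketch (note the 1x1 reduction layer itself adds a further ~4.1 × 10^7 multiply-accumulates, which is why the 2.048 × 10^9 figure is approximate):

# Verify the parameter and multiply-accumulate (MAC) counts from the example above
H, W = 100, 100                    # spatial size (unchanged, since stride=1 and pad=2)
C_in, C_mid, C_out = 128, 32, 256

# Direct 5x5 convolution: 128 -> 256 channels
params_direct = 5 * 5 * C_in * C_out + C_out       # 819,456 parameters
macs_direct = H * W * C_out * (5 * 5 * C_in)       # 8.192e9 MACs

# 1x1 reduction (128 -> 32) followed by 5x5 convolution (32 -> 256)
params_reduced = (C_in * C_mid + C_mid) + (5 * 5 * C_mid * C_out + C_out)  # 209,184 parameters
macs_reduced = H * W * C_mid * C_in + H * W * C_out * (5 * 5 * C_mid)      # ~2.089e9 MACs

print(params_reduced / params_direct)   # ~0.255, i.e. roughly a quarter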

The role of the 1x1 kernel: its chief use is to reduce the channel count of the input feature map, shrinking the network's parameter count and computation.

In the end, an Inception Module consists of four basic units — a 1x1 convolution, a 3x3 convolution, a 5x5 convolution, and a 3x3 max pooling — whose outputs are concatenated along the channel dimension. Kernels of different sizes provide receptive fields of different sizes, so features at multiple scales are extracted and fused into a richer representation of the image. This is the core idea of the Inception Module.
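Both ideas can be demonstrated in a few lines of PyTorch (a minimal sketch with illustrative tensor shapes): a 1x1 convolution changes only the channel count, and the four branch outputs are concatenated along the channel axis.

import torch
import torch.nn as nn

x = torch.randn(1, 128, 100, 100)             # N, C, H, W
reduce = nn.Conv2d(128, 32, kernel_size=1)    # 1x1 conv: channels 128 -> 32, spatial size unchanged
print(reduce(x).shape)                        # torch.Size([1, 32, 100, 100])

# Channel-wise concatenation of four branch outputs with identical spatial sizes
b1 = torch.randn(1, 64, 28, 28)
b2 = torch.randn(1, 128, 28, 28)
b3 = torch.randn(1, 32, 28, 28)
b4 = torch.randn(1, 32, 28, 28)
print(torch.cat([b1, b2, b3, b4], dim=1).shape)   # torch.Size([1, 256, 28, 28]); 64+128+32+32 = 256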

2. Network Architecture

The Inception v1 network structure implemented here is shown below:

Note: the original network also adds two auxiliary branches, which serve two purposes. First, they help counter vanishing gradients by feeding gradient into the middle of the network — in backpropagation, if any single layer's derivative is zero, the whole chain product becomes zero. Second, using an intermediate layer's output for classification acts as a form of model ensembling. At test time these two auxiliary softmax branches are removed. Later models rarely adopted this technique, so you can skip it and focus on the parallel convolution structure and the 1x1 convolution; a sketch of one auxiliary head follows for reference.
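For reference only, here is a sketch of one auxiliary classifier head following the description in the paper (average pooling 5x5 with stride 3, a 1x1 convolution with 128 filters, a 1024-unit fully connected layer, 70% dropout, and a linear classifier). It is illustrative and is not used in the implementation below.

import torch
import torch.nn as nn
import torch.nn.functional as F

class InceptionAux(nn.Module):
    """Auxiliary classifier head, roughly as described in the GoogLeNet paper."""
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.pool = nn.AvgPool2d(kernel_size=5, stride=3)       # 14x14 -> 4x4
        self.conv = nn.Conv2d(in_channels, 128, kernel_size=1)  # channel reduction
        self.fc1 = nn.Linear(128 * 4 * 4, 1024)
        self.dropout = nn.Dropout(0.7)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        x = F.relu(self.conv(self.pool(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(self.dropout(x))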

The detailed network structure is shown below:

II. Preparation

1. Set up the GPU

Use the GPU if the device has one; otherwise fall back to the CPU.

import warnings
warnings.filterwarnings("ignore")

import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")   # device string must be lowercase "cpu"
device

Output:

device(type='cuda')

2. Load the data

Also count how many images the dataset contains:

import pathlib

data_dir = r'D:\THE MNIST DATABASE\P4-data'
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*')))
image_count

Output:

2142

3. Inspect the dataset classes

data_paths=list(data_dir.glob('*'))
classNames=[str(path).split("\\")[3] for path in data_paths]
classNames

Output:

['Monkeypox', 'Others']
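The index-3 split above depends on the exact depth of the data directory; a more portable sketch takes the final path component directly:

# Equivalent but independent of how deep data_dir sits on disk
classNames = [path.name for path in data_paths]
classNames   # ['Monkeypox', 'Others']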

4. Preview random images

Randomly sample 20 images from the dataset for a quick look:

import PIL, random
import matplotlib.pyplot as plt
from PIL import Image

data_paths2 = list(data_dir.glob('*/*'))
plt.figure(figsize=(20, 4))
for i in range(20):
    plt.subplot(2, 10, i + 1)
    plt.axis('off')
    image = random.choice(data_paths2)   # pick a random image
    plt.title(image.parts[-2])           # the parent folder name, i.e. the class label
    plt.imshow(Image.open(str(image)))   # display the image

Output:

5. Preprocess the images

import torchvision.transforms as transforms
from torchvision import transforms, datasets

train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),        # resize every image to a uniform size
    transforms.RandomHorizontalFlip(),    # random horizontal flip
    transforms.ToTensor(),                # convert to a tensor
    transforms.Normalize(                 # normalize, which helps the model converge
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])
])

total_data = datasets.ImageFolder(
    r'D:\THE MNIST DATABASE\P4-data',
    transform=train_transforms
)
total_data

Output:

Dataset ImageFolder
    Number of datapoints: 2142
    Root location: D:\THE MNIST DATABASE\P4-data
    StandardTransform
Transform: Compose(
               Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=True)
               RandomHorizontalFlip(p=0.5)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )
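Note that train_transforms (including the random flip) is applied to the whole dataset, so the test split is also randomly flipped at evaluation time. A common refinement — my suggestion, not part of the original code — is a separate deterministic pipeline for the test images, e.g. via a second ImageFolder:

# Hypothetical test-time pipeline: no augmentation, only resize + normalize
test_transforms = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])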

Print the dataset's class-to-index mapping:

total_data.class_to_idx

Output:

{'Monkeypox': 0, 'Others': 1}

6. Split the dataset

train_size = int(0.8 * len(total_data))
test_size = len(total_data) - train_size

train_dataset, test_dataset = torch.utils.data.random_split(
    total_data, [train_size, test_size]
)
train_dataset, test_dataset

Output:

(<torch.utils.data.dataset.Subset at 0x241dec0e950>,<torch.utils.data.dataset.Subset at 0x241deef32d0>)

Check the sizes of the training and test sets:

train_size,test_size

Output:

(1713, 429)
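random_split draws a fresh random permutation on every run. If you want a reproducible split (an optional tweak, not in the original code), pass a seeded generator:

# Reproducible 80/20 split via an explicitly seeded generator
train_dataset, test_dataset = torch.utils.data.random_split(
    total_data, [train_size, test_size],
    generator=torch.Generator().manual_seed(42)
)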

7. Load the datasets with DataLoaders

batch_size = 16

train_dl = torch.utils.data.DataLoader(train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=1)
test_dl = torch.utils.data.DataLoader(test_dataset,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=1)

Check the shape of one batch from the training loader:

for x, y in train_dl:
    print("Shape of x [N,C,H,W]:", x.shape)
    print("Shape of y:", y.shape, y.dtype)
    break

Output:

Shape of x [N,C,H,W]: torch.Size([16, 3, 224, 224])
Shape of y: torch.Size([16]) torch.int64

III. Reproducing the Inception V1 Model in PyTorch

Build the Inception V1 network following the architecture and parameter table above.

1. Build the model

The two auxiliary branches are omitted here; only the main branch is reproduced.

Define a class named inception_block that inherits from nn.Module; it contains all the layers and parameters of a single Inception module.

import torch.nn as nn
import torch.nn.functional as F

class inception_block(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super(inception_block, self).__init__()
        # 1x1 conv branch
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channels, ch1x1, kernel_size=1),
            nn.BatchNorm2d(ch1x1),
            nn.ReLU(inplace=True)
        )
        # 1x1 conv -> 3x3 conv branch
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channels, ch3x3red, kernel_size=1),
            nn.BatchNorm2d(ch3x3red),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch3x3red, ch3x3, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch3x3),
            nn.ReLU(inplace=True)
        )
        # 1x1 conv -> 5x5 conv branch
        self.branch3 = nn.Sequential(
            nn.Conv2d(in_channels, ch5x5red, kernel_size=1),
            nn.BatchNorm2d(ch5x5red),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch5x5red, ch5x5, kernel_size=5, padding=2),
            nn.BatchNorm2d(ch5x5),
            nn.ReLU(inplace=True)
        )
        # 3x3 max pooling -> 1x1 conv branch
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels, pool_proj, kernel_size=1),
            nn.BatchNorm2d(pool_proj),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # compute the forward pass through all branches and concatenate the output feature maps
        branch1_output = self.branch1(x)
        branch2_output = self.branch2(x)
        branch3_output = self.branch3(x)
        branch4_output = self.branch4(x)

        outputs = [branch1_output, branch2_output, branch3_output, branch4_output]
        return torch.cat(outputs, 1)

In the __init__ method, four branches are defined:

(1) branch1: a 1x1 convolution;
(2) branch2: a 1x1 convolution followed by a 3x3 convolution;
(3) branch3: a 1x1 convolution followed by a 5x5 convolution;
(4) branch4: a 3x3 max pooling followed by a 1x1 convolution.

Each branch contains convolution, batch-normalization, and activation layers. These are all standard PyTorch layers: nn.Conv2d, nn.BatchNorm2d, and nn.ReLU define the convolution, batch normalization, and ReLU activation respectively.

In the forward method, we run the input through every branch and concatenate the branch output feature maps along the channel dimension, then return the concatenated result.
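A quick shape check (a minimal sketch using the inception3a configuration that appears below) confirms that the output channel count is the sum of the four branch widths, 64 + 128 + 32 + 32 = 256:

block = inception_block(192, 64, 96, 128, 16, 32, 32)
x = torch.randn(1, 192, 28, 28)
print(block(x).shape)   # torch.Size([1, 256, 28, 28])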

Next, we define the Inception v1 model itself, combining the Inception modules and the remaining layers with nn.Sequential.

class InceptionV1(nn.Module):
    def __init__(self, num_classes=1000):
        super(InceptionV1, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=1, stride=1, padding=0)
        self.conv3 = nn.Conv2d(64, 192, kernel_size=3, stride=1, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.inception3a = inception_block(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = inception_block(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.inception4a = inception_block(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = inception_block(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = inception_block(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = inception_block(512, 112, 114, 288, 32, 64, 64)  # note: the original paper uses 144 for this 3x3 reduce
        self.inception4e = inception_block(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.inception5a = inception_block(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = nn.Sequential(
            inception_block(832, 384, 192, 384, 48, 128, 128),
            nn.AvgPool2d(kernel_size=7, stride=1, padding=0),
            nn.Dropout(0.4)
        )

        # fully connected classifier head
        self.classifier = nn.Sequential(
            nn.Linear(in_features=1024, out_features=1024),
            nn.ReLU(),
            nn.Linear(in_features=1024, out_features=num_classes),
            nn.Softmax(dim=1)   # note: CrossEntropyLoss applies log-softmax internally, so this Softmax is redundant during training
        )

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = self.maxpool2(x)

        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.maxpool3(x)
        x = self.inception4a(x)
        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)
        x = self.inception4e(x)
        x = self.maxpool4(x)
        x = self.inception5a(x)
        x = self.inception5b(x)

        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x
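Before wiring the model into training, a one-off sanity check of the forward pass (a minimal sketch) confirms the expected input and output shapes:

# Two 3x224x224 images in, two class scores per image out
model_check = InceptionV1(num_classes=2)
dummy = torch.randn(2, 3, 224, 224)
print(model_check(dummy).shape)   # torch.Size([2, 2])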

2. Model summary

# report the parameter count and other statistics of the model
import torchsummary

# instantiate the model and move it to the GPU
model = InceptionV1(num_classes=2).to(device)

# display the network structure
torchsummary.summary(model, (3, 224, 224))
print(model)

Output:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,472
         MaxPool2d-2           [-1, 64, 56, 56]               0
            Conv2d-3           [-1, 64, 56, 56]           4,160
            Conv2d-4          [-1, 192, 56, 56]         110,784
         MaxPool2d-5          [-1, 192, 28, 28]               0
            Conv2d-6           [-1, 64, 28, 28]          12,352
       BatchNorm2d-7           [-1, 64, 28, 28]             128
              ReLU-8           [-1, 64, 28, 28]               0
            Conv2d-9           [-1, 96, 28, 28]          18,528
      BatchNorm2d-10           [-1, 96, 28, 28]             192
             ReLU-11           [-1, 96, 28, 28]               0
           Conv2d-12          [-1, 128, 28, 28]         110,720
      BatchNorm2d-13          [-1, 128, 28, 28]             256
             ReLU-14          [-1, 128, 28, 28]               0
           Conv2d-15           [-1, 16, 28, 28]           3,088
      BatchNorm2d-16           [-1, 16, 28, 28]              32
             ReLU-17           [-1, 16, 28, 28]               0
           Conv2d-18           [-1, 32, 28, 28]          12,832
      BatchNorm2d-19           [-1, 32, 28, 28]              64
             ReLU-20           [-1, 32, 28, 28]               0
        MaxPool2d-21          [-1, 192, 28, 28]               0
           Conv2d-22           [-1, 32, 28, 28]           6,176
      BatchNorm2d-23           [-1, 32, 28, 28]              64
             ReLU-24           [-1, 32, 28, 28]               0
  inception_block-25          [-1, 256, 28, 28]               0
           Conv2d-26          [-1, 128, 28, 28]          32,896
      BatchNorm2d-27          [-1, 128, 28, 28]             256
             ReLU-28          [-1, 128, 28, 28]               0
           Conv2d-29          [-1, 128, 28, 28]          32,896
      BatchNorm2d-30          [-1, 128, 28, 28]             256
             ReLU-31          [-1, 128, 28, 28]               0
           Conv2d-32          [-1, 192, 28, 28]         221,376
      BatchNorm2d-33          [-1, 192, 28, 28]             384
             ReLU-34          [-1, 192, 28, 28]               0
           Conv2d-35           [-1, 32, 28, 28]           8,224
      BatchNorm2d-36           [-1, 32, 28, 28]              64
             ReLU-37           [-1, 32, 28, 28]               0
           Conv2d-38           [-1, 96, 28, 28]          76,896
      BatchNorm2d-39           [-1, 96, 28, 28]             192
             ReLU-40           [-1, 96, 28, 28]               0
        MaxPool2d-41          [-1, 256, 28, 28]               0
           Conv2d-42           [-1, 64, 28, 28]          16,448
      BatchNorm2d-43           [-1, 64, 28, 28]             128
             ReLU-44           [-1, 64, 28, 28]               0
  inception_block-45          [-1, 480, 28, 28]               0
        MaxPool2d-46          [-1, 480, 14, 14]               0
           Conv2d-47          [-1, 192, 14, 14]          92,352
      BatchNorm2d-48          [-1, 192, 14, 14]             384
             ReLU-49          [-1, 192, 14, 14]               0
           Conv2d-50           [-1, 96, 14, 14]          46,176
      BatchNorm2d-51           [-1, 96, 14, 14]             192
             ReLU-52           [-1, 96, 14, 14]               0
           Conv2d-53          [-1, 208, 14, 14]         179,920
      BatchNorm2d-54          [-1, 208, 14, 14]             416
             ReLU-55          [-1, 208, 14, 14]               0
           Conv2d-56           [-1, 16, 14, 14]           7,696
      BatchNorm2d-57           [-1, 16, 14, 14]              32
             ReLU-58           [-1, 16, 14, 14]               0
           Conv2d-59           [-1, 48, 14, 14]          19,248
      BatchNorm2d-60           [-1, 48, 14, 14]              96
             ReLU-61           [-1, 48, 14, 14]               0
        MaxPool2d-62          [-1, 480, 14, 14]               0
           Conv2d-63           [-1, 64, 14, 14]          30,784
      BatchNorm2d-64           [-1, 64, 14, 14]             128
             ReLU-65           [-1, 64, 14, 14]               0
  inception_block-66          [-1, 512, 14, 14]               0
           Conv2d-67          [-1, 160, 14, 14]          82,080
      BatchNorm2d-68          [-1, 160, 14, 14]             320
             ReLU-69          [-1, 160, 14, 14]               0
           Conv2d-70          [-1, 112, 14, 14]          57,456
      BatchNorm2d-71          [-1, 112, 14, 14]             224
             ReLU-72          [-1, 112, 14, 14]               0
           Conv2d-73          [-1, 224, 14, 14]         226,016
      BatchNorm2d-74          [-1, 224, 14, 14]             448
             ReLU-75          [-1, 224, 14, 14]               0
           Conv2d-76           [-1, 24, 14, 14]          12,312
      BatchNorm2d-77           [-1, 24, 14, 14]              48
             ReLU-78           [-1, 24, 14, 14]               0
           Conv2d-79           [-1, 64, 14, 14]          38,464
      BatchNorm2d-80           [-1, 64, 14, 14]             128
             ReLU-81           [-1, 64, 14, 14]               0
        MaxPool2d-82          [-1, 512, 14, 14]               0
           Conv2d-83           [-1, 64, 14, 14]          32,832
      BatchNorm2d-84           [-1, 64, 14, 14]             128
             ReLU-85           [-1, 64, 14, 14]               0
  inception_block-86          [-1, 512, 14, 14]               0
           Conv2d-87          [-1, 128, 14, 14]          65,664
      BatchNorm2d-88          [-1, 128, 14, 14]             256
             ReLU-89          [-1, 128, 14, 14]               0
           Conv2d-90          [-1, 128, 14, 14]          65,664
      BatchNorm2d-91          [-1, 128, 14, 14]             256
             ReLU-92          [-1, 128, 14, 14]               0
           Conv2d-93          [-1, 256, 14, 14]         295,168
      BatchNorm2d-94          [-1, 256, 14, 14]             512
             ReLU-95          [-1, 256, 14, 14]               0
           Conv2d-96           [-1, 24, 14, 14]          12,312
      BatchNorm2d-97           [-1, 24, 14, 14]              48
             ReLU-98           [-1, 24, 14, 14]               0
           Conv2d-99           [-1, 64, 14, 14]          38,464
     BatchNorm2d-100           [-1, 64, 14, 14]             128
            ReLU-101           [-1, 64, 14, 14]               0
       MaxPool2d-102          [-1, 512, 14, 14]               0
          Conv2d-103           [-1, 64, 14, 14]          32,832
     BatchNorm2d-104           [-1, 64, 14, 14]             128
            ReLU-105           [-1, 64, 14, 14]               0
 inception_block-106          [-1, 512, 14, 14]               0
          Conv2d-107          [-1, 112, 14, 14]          57,456
     BatchNorm2d-108          [-1, 112, 14, 14]             224
            ReLU-109          [-1, 112, 14, 14]               0
          Conv2d-110          [-1, 114, 14, 14]          58,482
     BatchNorm2d-111          [-1, 114, 14, 14]             228
            ReLU-112          [-1, 114, 14, 14]               0
          Conv2d-113          [-1, 288, 14, 14]         295,776
     BatchNorm2d-114          [-1, 288, 14, 14]             576
            ReLU-115          [-1, 288, 14, 14]               0
          Conv2d-116           [-1, 32, 14, 14]          16,416
     BatchNorm2d-117           [-1, 32, 14, 14]              64
            ReLU-118           [-1, 32, 14, 14]               0
          Conv2d-119           [-1, 64, 14, 14]          51,264
     BatchNorm2d-120           [-1, 64, 14, 14]             128
            ReLU-121           [-1, 64, 14, 14]               0
       MaxPool2d-122          [-1, 512, 14, 14]               0
          Conv2d-123           [-1, 64, 14, 14]          32,832
     BatchNorm2d-124           [-1, 64, 14, 14]             128
            ReLU-125           [-1, 64, 14, 14]               0
 inception_block-126          [-1, 528, 14, 14]               0
          Conv2d-127          [-1, 256, 14, 14]         135,424
     BatchNorm2d-128          [-1, 256, 14, 14]             512
            ReLU-129          [-1, 256, 14, 14]               0
          Conv2d-130          [-1, 160, 14, 14]          84,640
     BatchNorm2d-131          [-1, 160, 14, 14]             320
            ReLU-132          [-1, 160, 14, 14]               0
          Conv2d-133          [-1, 320, 14, 14]         461,120
     BatchNorm2d-134          [-1, 320, 14, 14]             640
            ReLU-135          [-1, 320, 14, 14]               0
          Conv2d-136           [-1, 32, 14, 14]          16,928
     BatchNorm2d-137           [-1, 32, 14, 14]              64
            ReLU-138           [-1, 32, 14, 14]               0
          Conv2d-139          [-1, 128, 14, 14]         102,528
     BatchNorm2d-140          [-1, 128, 14, 14]             256
            ReLU-141          [-1, 128, 14, 14]               0
       MaxPool2d-142          [-1, 528, 14, 14]               0
          Conv2d-143          [-1, 128, 14, 14]          67,712
     BatchNorm2d-144          [-1, 128, 14, 14]             256
            ReLU-145          [-1, 128, 14, 14]               0
 inception_block-146          [-1, 832, 14, 14]               0
       MaxPool2d-147            [-1, 832, 7, 7]               0
          Conv2d-148            [-1, 256, 7, 7]         213,248
     BatchNorm2d-149            [-1, 256, 7, 7]             512
            ReLU-150            [-1, 256, 7, 7]               0
          Conv2d-151            [-1, 160, 7, 7]         133,280
     BatchNorm2d-152            [-1, 160, 7, 7]             320
            ReLU-153            [-1, 160, 7, 7]               0
          Conv2d-154            [-1, 320, 7, 7]         461,120
     BatchNorm2d-155            [-1, 320, 7, 7]             640
            ReLU-156            [-1, 320, 7, 7]               0
          Conv2d-157             [-1, 32, 7, 7]          26,656
     BatchNorm2d-158             [-1, 32, 7, 7]              64
            ReLU-159             [-1, 32, 7, 7]               0
          Conv2d-160            [-1, 128, 7, 7]         102,528
     BatchNorm2d-161            [-1, 128, 7, 7]             256
            ReLU-162            [-1, 128, 7, 7]               0
       MaxPool2d-163            [-1, 832, 7, 7]               0
          Conv2d-164            [-1, 128, 7, 7]         106,624
     BatchNorm2d-165            [-1, 128, 7, 7]             256
            ReLU-166            [-1, 128, 7, 7]               0
 inception_block-167            [-1, 832, 7, 7]               0
          Conv2d-168            [-1, 384, 7, 7]         319,872
     BatchNorm2d-169            [-1, 384, 7, 7]             768
            ReLU-170            [-1, 384, 7, 7]               0
          Conv2d-171            [-1, 192, 7, 7]         159,936
     BatchNorm2d-172            [-1, 192, 7, 7]             384
            ReLU-173            [-1, 192, 7, 7]               0
          Conv2d-174            [-1, 384, 7, 7]         663,936
     BatchNorm2d-175            [-1, 384, 7, 7]             768
            ReLU-176            [-1, 384, 7, 7]               0
          Conv2d-177             [-1, 48, 7, 7]          39,984
     BatchNorm2d-178             [-1, 48, 7, 7]              96
            ReLU-179             [-1, 48, 7, 7]               0
          Conv2d-180            [-1, 128, 7, 7]         153,728
     BatchNorm2d-181            [-1, 128, 7, 7]             256
            ReLU-182            [-1, 128, 7, 7]               0
       MaxPool2d-183            [-1, 832, 7, 7]               0
          Conv2d-184            [-1, 128, 7, 7]         106,624
     BatchNorm2d-185            [-1, 128, 7, 7]             256
            ReLU-186            [-1, 128, 7, 7]               0
 inception_block-187           [-1, 1024, 7, 7]               0
       AvgPool2d-188           [-1, 1024, 1, 1]               0
         Dropout-189           [-1, 1024, 1, 1]               0
          Linear-190                 [-1, 1024]       1,049,600
            ReLU-191                 [-1, 1024]               0
          Linear-192                    [-1, 2]           2,050
         Softmax-193                    [-1, 2]               0
================================================================
Total params: 6,945,912
Trainable params: 6,945,912
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 69.48
Params size (MB): 26.50
Estimated Total Size (MB): 96.55
----------------------------------------------------------------
InceptionV1(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (conv2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
  (conv3): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (maxpool2): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (inception3a): inception_block(
    (branch1): Sequential(
      (0): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (branch2): Sequential(
      (0): Conv2d(192, 96, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch3): Sequential(
      (0): Conv2d(192, 16, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
      (1): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1))
      (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
  )
  (inception3b): inception_block(
    (branch1): Sequential(
      (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (branch2): Sequential(
      (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch3): Sequential(
      (0): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(32, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
      (1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
  )
  (maxpool3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (inception4a): inception_block(
    (branch1): Sequential(
      (0): Conv2d(480, 192, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (branch2): Sequential(
      (0): Conv2d(480, 96, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(96, 208, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(208, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch3): Sequential(
      (0): Conv2d(480, 16, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(16, 48, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
      (1): Conv2d(480, 64, kernel_size=(1, 1), stride=(1, 1))
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
  )
  (inception4b): inception_block(
    (branch1): Sequential(
      (0): Conv2d(512, 160, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (branch2): Sequential(
      (0): Conv2d(512, 112, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(112, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(112, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch3): Sequential(
      (0): Conv2d(512, 24, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(24, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
      (1): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1))
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
  )
  (inception4c): inception_block(
    (branch1): Sequential(
      (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (branch2): Sequential(
      (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch3): Sequential(
      (0): Conv2d(512, 24, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(24, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
      (1): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1))
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
  )
  (inception4d): inception_block(
    (branch1): Sequential(
      (0): Conv2d(512, 112, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(112, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (branch2): Sequential(
      (0): Conv2d(512, 114, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(114, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(114, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch3): Sequential(
      (0): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
      (1): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1))
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
  )
  (inception4e): inception_block(
    (branch1): Sequential(
      (0): Conv2d(528, 256, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (branch2): Sequential(
      (0): Conv2d(528, 160, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(160, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch3): Sequential(
      (0): Conv2d(528, 32, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(32, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
      (1): Conv2d(528, 128, kernel_size=(1, 1), stride=(1, 1))
      (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
  )
  (maxpool4): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (inception5a): inception_block(
    (branch1): Sequential(
      (0): Conv2d(832, 256, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (branch2): Sequential(
      (0): Conv2d(832, 160, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(160, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch3): Sequential(
      (0): Conv2d(832, 32, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(32, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
      (1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1))
      (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
  )
  (inception5b): Sequential(
    (0): inception_block(
      (branch1): Sequential(
        (0): Conv2d(832, 384, kernel_size=(1, 1), stride=(1, 1))
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (branch2): Sequential(
        (0): Conv2d(832, 192, kernel_size=(1, 1), stride=(1, 1))
        (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
      (branch3): Sequential(
        (0): Conv2d(832, 48, kernel_size=(1, 1), stride=(1, 1))
        (1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(48, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
      (branch4): Sequential(
        (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
        (1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1))
        (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU(inplace=True)
      )
    )
    (1): AvgPool2d(kernel_size=7, stride=1, padding=0)
    (2): Dropout(p=0.4, inplace=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=1024, out_features=1024, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1024, out_features=2, bias=True)
    (3): Softmax(dim=1)
  )
)

IV. Training the Model

1. The training function

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)   # size of the training set
    num_batches = len(dataloader)    # number of batches
    train_loss, train_acc = 0, 0     # initialize training loss and accuracy

    for x, y in dataloader:          # fetch images and their labels
        x, y = x.to(device), y.to(device)

        # compute the prediction error
        pred = model(x)              # network output
        loss = loss_fn(pred, y)      # gap between prediction and ground truth, i.e. the loss

        # backpropagation
        optimizer.zero_grad()        # reset the grad attributes
        loss.backward()              # backpropagate
        optimizer.step()             # update the weights

        # accumulate accuracy and loss
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    train_acc /= size
    train_loss /= num_batches
    return train_acc, train_loss

2. The test function

The test function is much like the training function, but since no gradient descent is performed to update the weights, it does not take an optimizer.

# test function
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)   # size of the test set
    num_batches = len(dataloader)    # number of batches
    test_loss, test_acc = 0, 0

    # stop gradient tracking while not training, saving memory and compute
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)

            # compute the loss
            target_pred = model(imgs)
            loss = loss_fn(target_pred, target)

            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc /= size
    test_loss /= num_batches
    return test_acc, test_loss

3. Run the training

import copy

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)   # create the optimizer and set the learning rate
loss_fn = nn.CrossEntropyLoss()                             # create the loss function

epochs = 100
train_loss = []
train_acc = []
test_loss = []
test_acc = []

best_acc = 0   # best test accuracy so far, used as the criterion for the best model

for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)

    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    # keep the best model in J8_model
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        J8_model = copy.deepcopy(model)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    # get the current learning rate
    lr = optimizer.state_dict()['param_groups'][0]['lr']

    template = ('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f},Lr:{:.2E}')
    print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss,
                          epoch_test_acc * 100, epoch_test_loss, lr))

# save the best model (J8_model, the best checkpoint rather than the final-epoch model) to a file
PATH = r'D:\THE MNIST DATABASE\J-series\J8_model.pth'
torch.save(J8_model.state_dict(), PATH)
print('Done')

Output:

Epoch: 1,Train_acc:66.4%,Train_loss:0.633,Test_acc:61.5%,Test_loss:0.656,Lr:1.00E-04
Epoch: 2,Train_acc:67.3%,Train_loss:0.616,Test_acc:68.1%,Test_loss:0.624,Lr:1.00E-04
Epoch: 3,Train_acc:69.9%,Train_loss:0.591,Test_acc:68.5%,Test_loss:0.613,Lr:1.00E-04
Epoch: 4,Train_acc:73.3%,Train_loss:0.574,Test_acc:68.1%,Test_loss:0.605,Lr:1.00E-04
Epoch: 5,Train_acc:75.3%,Train_loss:0.556,Test_acc:65.0%,Test_loss:0.651,Lr:1.00E-04
Epoch: 6,Train_acc:74.9%,Train_loss:0.551,Test_acc:73.0%,Test_loss:0.563,Lr:1.00E-04
Epoch: 7,Train_acc:77.4%,Train_loss:0.529,Test_acc:80.4%,Test_loss:0.500,Lr:1.00E-04
Epoch: 8,Train_acc:80.2%,Train_loss:0.503,Test_acc:78.1%,Test_loss:0.518,Lr:1.00E-04
Epoch: 9,Train_acc:79.6%,Train_loss:0.509,Test_acc:77.4%,Test_loss:0.533,Lr:1.00E-04
Epoch:10,Train_acc:80.0%,Train_loss:0.504,Test_acc:77.6%,Test_loss:0.528,Lr:1.00E-04
Epoch:11,Train_acc:85.4%,Train_loss:0.456,Test_acc:86.0%,Test_loss:0.448,Lr:1.00E-04
Epoch:12,Train_acc:84.0%,Train_loss:0.466,Test_acc:84.4%,Test_loss:0.469,Lr:1.00E-04
Epoch:13,Train_acc:84.1%,Train_loss:0.467,Test_acc:80.0%,Test_loss:0.511,Lr:1.00E-04
Epoch:14,Train_acc:85.8%,Train_loss:0.449,Test_acc:77.4%,Test_loss:0.534,Lr:1.00E-04
Epoch:15,Train_acc:85.5%,Train_loss:0.460,Test_acc:83.2%,Test_loss:0.471,Lr:1.00E-04
Epoch:16,Train_acc:84.8%,Train_loss:0.464,Test_acc:81.8%,Test_loss:0.490,Lr:1.00E-04
Epoch:17,Train_acc:84.2%,Train_loss:0.467,Test_acc:76.5%,Test_loss:0.540,Lr:1.00E-04
Epoch:18,Train_acc:83.7%,Train_loss:0.471,Test_acc:84.6%,Test_loss:0.467,Lr:1.00E-04
Epoch:19,Train_acc:87.3%,Train_loss:0.436,Test_acc:88.3%,Test_loss:0.427,Lr:1.00E-04
Epoch:20,Train_acc:86.2%,Train_loss:0.454,Test_acc:82.5%,Test_loss:0.487,Lr:1.00E-04
Epoch:21,Train_acc:86.2%,Train_loss:0.444,Test_acc:87.9%,Test_loss:0.430,Lr:1.00E-04
Epoch:22,Train_acc:88.3%,Train_loss:0.437,Test_acc:87.9%,Test_loss:0.436,Lr:1.00E-04
Epoch:23,Train_acc:89.2%,Train_loss:0.423,Test_acc:86.7%,Test_loss:0.442,Lr:1.00E-04
Epoch:24,Train_acc:89.6%,Train_loss:0.424,Test_acc:89.0%,Test_loss:0.421,Lr:1.00E-04
Epoch:25,Train_acc:90.6%,Train_loss:0.405,Test_acc:91.6%,Test_loss:0.402,Lr:1.00E-04
Epoch:26,Train_acc:90.7%,Train_loss:0.413,Test_acc:89.7%,Test_loss:0.412,Lr:1.00E-04
Epoch:27,Train_acc:90.5%,Train_loss:0.406,Test_acc:87.6%,Test_loss:0.431,Lr:1.00E-04
Epoch:28,Train_acc:87.6%,Train_loss:0.427,Test_acc:86.0%,Test_loss:0.451,Lr:1.00E-04
Epoch:29,Train_acc:89.1%,Train_loss:0.417,Test_acc:89.0%,Test_loss:0.421,Lr:1.00E-04
Epoch:30,Train_acc:91.5%,Train_loss:0.393,Test_acc:90.7%,Test_loss:0.406,Lr:1.00E-04
Epoch:31,Train_acc:92.1%,Train_loss:0.395,Test_acc:88.1%,Test_loss:0.427,Lr:1.00E-04
Epoch:32,Train_acc:93.0%,Train_loss:0.385,Test_acc:88.8%,Test_loss:0.418,Lr:1.00E-04
Epoch:33,Train_acc:91.4%,Train_loss:0.397,Test_acc:91.1%,Test_loss:0.402,Lr:1.00E-04
Epoch:34,Train_acc:92.5%,Train_loss:0.385,Test_acc:88.8%,Test_loss:0.425,Lr:1.00E-04
Epoch:35,Train_acc:91.8%,Train_loss:0.400,Test_acc:92.1%,Test_loss:0.391,Lr:1.00E-04
Epoch:36,Train_acc:91.9%,Train_loss:0.396,Test_acc:92.1%,Test_loss:0.390,Lr:1.00E-04
Epoch:37,Train_acc:90.5%,Train_loss:0.409,Test_acc:90.0%,Test_loss:0.413,Lr:1.00E-04
Epoch:38,Train_acc:93.1%,Train_loss:0.381,Test_acc:86.0%,Test_loss:0.444,Lr:1.00E-04
Epoch:39,Train_acc:93.1%,Train_loss:0.381,Test_acc:93.7%,Test_loss:0.379,Lr:1.00E-04
Epoch:40,Train_acc:93.5%,Train_loss:0.387,Test_acc:93.0%,Test_loss:0.381,Lr:1.00E-04
Epoch:41,Train_acc:94.1%,Train_loss:0.379,Test_acc:91.8%,Test_loss:0.394,Lr:1.00E-04
Epoch:42,Train_acc:93.6%,Train_loss:0.377,Test_acc:93.2%,Test_loss:0.381,Lr:1.00E-04
Epoch:43,Train_acc:93.9%,Train_loss:0.380,Test_acc:92.5%,Test_loss:0.384,Lr:1.00E-04
Epoch:44,Train_acc:93.9%,Train_loss:0.381,Test_acc:92.5%,Test_loss:0.384,Lr:1.00E-04
Epoch:45,Train_acc:89.6%,Train_loss:0.413,Test_acc:92.3%,Test_loss:0.388,Lr:1.00E-04
Epoch:46,Train_acc:91.7%,Train_loss:0.395,Test_acc:90.9%,Test_loss:0.401,Lr:1.00E-04
Epoch:47,Train_acc:93.4%,Train_loss:0.387,Test_acc:90.4%,Test_loss:0.407,Lr:1.00E-04
Epoch:48,Train_acc:93.6%,Train_loss:0.375,Test_acc:92.1%,Test_loss:0.388,Lr:1.00E-04
Epoch:49,Train_acc:93.8%,Train_loss:0.375,Test_acc:94.9%,Test_loss:0.367,Lr:1.00E-04
Epoch:50,Train_acc:94.2%,Train_loss:0.379,Test_acc:92.1%,Test_loss:0.390,Lr:1.00E-04
Epoch:51,Train_acc:93.6%,Train_loss:0.385,Test_acc:92.5%,Test_loss:0.383,Lr:1.00E-04
Epoch:52,Train_acc:93.9%,Train_loss:0.380,Test_acc:90.0%,Test_loss:0.414,Lr:1.00E-04
Epoch:53,Train_acc:93.2%,Train_loss:0.378,Test_acc:90.2%,Test_loss:0.412,Lr:1.00E-04
Epoch:54,Train_acc:92.6%,Train_loss:0.393,Test_acc:85.5%,Test_loss:0.454,Lr:1.00E-04
Epoch:55,Train_acc:91.9%,Train_loss:0.397,Test_acc:91.6%,Test_loss:0.398,Lr:1.00E-04
Epoch:56,Train_acc:94.3%,Train_loss:0.368,Test_acc:92.5%,Test_loss:0.384,Lr:1.00E-04
Epoch:57,Train_acc:95.7%,Train_loss:0.354,Test_acc:93.5%,Test_loss:0.379,Lr:1.00E-04
Epoch:58,Train_acc:94.2%,Train_loss:0.380,Test_acc:93.5%,Test_loss:0.377,Lr:1.00E-04
Epoch:59,Train_acc:95.3%,Train_loss:0.361,Test_acc:93.0%,Test_loss:0.381,Lr:1.00E-04
Epoch:60,Train_acc:92.5%,Train_loss:0.385,Test_acc:90.0%,Test_loss:0.412,Lr:1.00E-04
Epoch:61,Train_acc:94.7%,Train_loss:0.372,Test_acc:95.1%,Test_loss:0.362,Lr:1.00E-04
Epoch:62,Train_acc:95.1%,Train_loss:0.369,Test_acc:90.2%,Test_loss:0.408,Lr:1.00E-04
Epoch:63,Train_acc:94.2%,Train_loss:0.380,Test_acc:92.8%,Test_loss:0.385,Lr:1.00E-04
Epoch:64,Train_acc:94.7%,Train_loss:0.374,Test_acc:91.8%,Test_loss:0.395,Lr:1.00E-04
Epoch:65,Train_acc:96.3%,Train_loss:0.359,Test_acc:94.2%,Test_loss:0.372,Lr:1.00E-04
Epoch:66,Train_acc:95.1%,Train_loss:0.361,Test_acc:92.1%,Test_loss:0.392,Lr:1.00E-04
Epoch:67,Train_acc:95.8%,Train_loss:0.354,Test_acc:92.3%,Test_loss:0.390,Lr:1.00E-04
Epoch:68,Train_acc:95.2%,Train_loss:0.358,Test_acc:93.7%,Test_loss:0.373,Lr:1.00E-04
Epoch:69,Train_acc:95.3%,Train_loss:0.360,Test_acc:93.7%,Test_loss:0.377,Lr:1.00E-04
Epoch:70,Train_acc:95.2%,Train_loss:0.368,Test_acc:93.2%,Test_loss:0.377,Lr:1.00E-04
Epoch:71,Train_acc:94.9%,Train_loss:0.373,Test_acc:91.8%,Test_loss:0.396,Lr:1.00E-04
Epoch:72,Train_acc:94.0%,Train_loss:0.371,Test_acc:94.2%,Test_loss:0.368,Lr:1.00E-04
Epoch:73,Train_acc:95.5%,Train_loss:0.356,Test_acc:93.2%,Test_loss:0.377,Lr:1.00E-04
Epoch:74,Train_acc:96.0%,Train_loss:0.362,Test_acc:93.9%,Test_loss:0.372,Lr:1.00E-04
Epoch:75,Train_acc:95.3%,Train_loss:0.368,Test_acc:93.2%,Test_loss:0.379,Lr:1.00E-04
Epoch:76,Train_acc:94.8%,Train_loss:0.368,Test_acc:91.4%,Test_loss:0.397,Lr:1.00E-04
Epoch:77,Train_acc:91.1%,Train_loss:0.398,Test_acc:90.9%,Test_loss:0.404,Lr:1.00E-04
Epoch:78,Train_acc:92.6%,Train_loss:0.388,Test_acc:90.0%,Test_loss:0.407,Lr:1.00E-04
Epoch:79,Train_acc:94.9%,Train_loss:0.362,Test_acc:94.9%,Test_loss:0.362,Lr:1.00E-04
Epoch:80,Train_acc:93.2%,Train_loss:0.380,Test_acc:91.4%,Test_loss:0.398,Lr:1.00E-04
Epoch:81,Train_acc:94.3%,Train_loss:0.368,Test_acc:93.7%,Test_loss:0.373,Lr:1.00E-04
Epoch:82,Train_acc:94.8%,Train_loss:0.363,Test_acc:93.2%,Test_loss:0.378,Lr:1.00E-04
Epoch:83,Train_acc:95.6%,Train_loss:0.364,Test_acc:88.1%,Test_loss:0.425,Lr:1.00E-04
Epoch:84,Train_acc:92.6%,Train_loss:0.386,Test_acc:92.3%,Test_loss:0.389,Lr:1.00E-04
Epoch:85,Train_acc:94.3%,Train_loss:0.377,Test_acc:91.4%,Test_loss:0.397,Lr:1.00E-04
Epoch:86,Train_acc:96.1%,Train_loss:0.359,Test_acc:96.3%,Test_loss:0.350,Lr:1.00E-04
Epoch:87,Train_acc:95.7%,Train_loss:0.363,Test_acc:94.2%,Test_loss:0.369,Lr:1.00E-04
Epoch:88,Train_acc:96.0%,Train_loss:0.360,Test_acc:93.2%,Test_loss:0.381,Lr:1.00E-04
Epoch:89,Train_acc:95.9%,Train_loss:0.353,Test_acc:93.7%,Test_loss:0.374,Lr:1.00E-04
Epoch:90,Train_acc:95.4%,Train_loss:0.361,Test_acc:93.2%,Test_loss:0.375,Lr:1.00E-04
Epoch:91,Train_acc:96.5%,Train_loss:0.347,Test_acc:95.1%,Test_loss:0.358,Lr:1.00E-04
Epoch:92,Train_acc:97.1%,Train_loss:0.342,Test_acc:94.9%,Test_loss:0.362,Lr:1.00E-04
Epoch:93,Train_acc:95.7%,Train_loss:0.363,Test_acc:93.7%,Test_loss:0.371,Lr:1.00E-04
Epoch:94,Train_acc:94.9%,Train_loss:0.361,Test_acc:95.1%,Test_loss:0.360,Lr:1.00E-04
Epoch:95,Train_acc:96.1%,Train_loss:0.360,Test_acc:95.1%,Test_loss:0.362,Lr:1.00E-04
Epoch:96,Train_acc:96.9%,Train_loss:0.352,Test_acc:93.7%,Test_loss:0.376,Lr:1.00E-04
Epoch:97,Train_acc:95.0%,Train_loss:0.367,Test_acc:93.7%,Test_loss:0.376,Lr:1.00E-04
Epoch:98,Train_acc:95.0%,Train_loss:0.371,Test_acc:95.1%,Test_loss:0.360,Lr:1.00E-04
Epoch:99,Train_acc:96.9%,Train_loss:0.352,Test_acc:95.6%,Test_loss:0.359,Lr:1.00E-04
Epoch:100,Train_acc:94.8%,Train_loss:0.374,Test_acc:93.7%,Test_loss:0.372,Lr:1.00E-04
Done

V. Visualizing the Results

1. Loss and accuracy curves

import matplotlib.pyplot as plt
import warnings

warnings.filterwarnings("ignore")              # suppress warnings
plt.rcParams['font.sans-serif'] = ['SimHei']   # render Chinese labels correctly
plt.rcParams['axes.unicode_minus'] = False     # render minus signs correctly
plt.rcParams['figure.dpi'] = 300               # figure resolution

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))

plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

Output:

2. Predicting a specified image

from PIL import Image

classes = list(total_data.class_to_idx)

def predict_one_image(image_path, model, transform, classes):
    test_img = Image.open(image_path).convert('RGB')
    plt.imshow(test_img)                    # show the image being predicted

    test_img = transform(test_img)
    img = test_img.to(device).unsqueeze(0)  # add a batch dimension

    model.eval()
    output = model(img)

    _, pred = torch.max(output, 1)
    pred_class = classes[pred]
    print(f'Predicted class: {pred_class}')

Predict an image:

# predict one image from the training set
predict_one_image(image_path=r'D:\THE MNIST DATABASE\P4-data\Monkeypox\M01_02_00.jpg',
                  model=model,
                  transform=train_transforms,
                  classes=classes)

Output:

Predicted class: Monkeypox

3. Model evaluation

J8_model.eval()
epoch_test_acc,epoch_test_loss=test(test_dl,J8_model,loss_fn)
epoch_test_acc,epoch_test_loss

Output:

(0.9627039627039627, 0.34947266622825907)
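To reuse the saved weights later — a minimal sketch, where PATH is the file written at the end of training — rebuild the architecture and load the state dict:

# Recreate the model and load the best weights for inference
best_model = InceptionV1(num_classes=2).to(device)
best_model.load_state_dict(torch.load(PATH, map_location=device))
best_model.eval()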

VI. Reflections

This project walked me through building an Inception V1 model in a PyTorch environment and deepened my understanding of the architecture. During training I also experimented with the learning rate: in practice, rates of 1e-7 and 1e-5 both gave worse results than 1e-4.

