当前位置: 首页 > news >正文

J2:ResNet50v2算法实战与解析

J2周:ResNet50V2算法实战与解析

      • 论文解读
        • 1、ResNetV2结构与ResNet结构对比☕
        • 2、关于残差结构的不同尝试☕
        • 3、关于激活的尝试☕
      • Pytorch实现ResNet50V2算法
        • 1、导入库并设置GPU
        • 2、导入和检查数据
        • 3、划分数据集
        • 4、搭建ResNet-50V2模型
          • Residual Block
          • Stack(堆叠上述的Residual Blocks)
          • ResNet50V2搭建
        • 5、编写训练和测试函数
        • 6、正式训练
        • 7、模型评估
        • 8、预测
      • 总结

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

🍺本周任务

  • 根据本文Tensorflow代码,编写出相应的Pytorch代码
  • 了解ResNetV2与ResNetV的区别
  • 改进思路是否可以迁移到其他地方呢(自由探索)

⛽ 我的环境

  • 语言环境:Python3.10.12
  • 编译器:Google Colab
  • 深度学习环境:
    • torch==2.4.1+cu121
    • torchvision==0.19.1+cu121

⛵参考博客/文章:

  • 论文原文《Identity Mappings in Deep Residual Networks》

论文解读

1、ResNetV2结构与ResNet结构对比☕

实线表示测试误差(右边的y轴),虚线表示训练损失(左边的y轴),Iterations 表示迭代次数标题

实线表示测试误差(右边的y轴),虚线表示训练损失(左边的y轴),Iterations 表示迭代次数标题

  • 改进点:(a)original 表示原始的 ResNet 的残差结构,(b)proposed 表示新的 ResNet 的残差结构。主要差别就是(a)结构先卷积后进行 BN 和激活函数计算,最后执行 addition 后再进行ReLU 计算; (b)结构先进行 BN 和激活函数计算后卷积,把 addition 后的 ReLU 计算放到了残差结构内部。

  • 改进结果:作者使用这两种不同的结构在 CIFAR-10 数据集上做测试,模型用的是 1001层的 ResNet 模型。从图中结果我们可以看出,(b)proposed 的测试集错误率明显更低一些,达到了 4.92%的错误率,(a)original 的测试集错误率是 7.61%。

2、关于残差结构的不同尝试☕

在这里插入图片描述

(b-f)中的快捷连接被不同的组件阻碍。为了简化插图,我们不显示BN层,这里所有单位均采用权值层之后的BN层。图中(a-f)都是作者对残差结构的 shortcut部分进行的不同尝试,作者对不同 shortcut结构的尝试结果如下表所示。
在这里插入图片描述

使用ResNet-110在CIFAR-10测试集上的分类错误,对所有残差单元应用了不同类型的shortcut connections。当测试误差高于20%时,标注为“fail”。


用不同 shortcut 结构的 ResNet-110 在CIFAR-10 数据集上做测试,发现最原始的(a)original 结构是最好的,也就是identity mapping 恒等映射是最好的。

3、关于激活的尝试☕

在这里插入图片描述
在这里插入图片描述

上图显示在CIFAR-10上测试最好的结果是(e) full pre-activation,其次是(a) original;所有尝试所包含的组件是一样的但是顺序不同。

Pytorch实现ResNet50V2算法

1、导入库并设置GPU
#### 导入库和需要的包
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import torch.nn.functional as F
import copy, random, pathlib
import matplotlib.pyplot as plt
from PIL import Image
from torchsummary import summary
import numpy as np
import warnings
warnings.filterwarnings("ignore")### 设置GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cuda')
2、导入和检查数据
from google.colab import drive
drive.mount("/content/drive/")
%cd "/content/drive/My Drive/Colab Notebooks/jupyter notebook/data/J1"
Mounted at /content/drive/
/content/drive/My Drive/Colab Notebooks/jupyter notebook/data/J1
data_dir = "./bird_photos/"
data_dir = pathlib.Path(data_dir)
data_paths = list(data_dir.glob("*"))classnames = [str(path).split("/")[1] for path in data_paths]
print(classnames)
num_classes = len(classnames)
print(num_classes)
['Cockatoo', 'Black Skimmer', 'Bananaquit', 'Black Throated Bushtiti']
4
3、划分数据集
'''图像数据变换'''
train_transforms = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225],)
])total_data = datasets.ImageFolder(data_dir,transform=train_transforms)
print(total_data)
print(total_data.class_to_idx)
Dataset ImageFolderNumber of datapoints: 565Root location: bird_photosStandardTransform
Transform: Compose(ToTensor()Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
{'Bananaquit': 0, 'Black Skimmer': 1, 'Black Throated Bushtiti': 2, 'Cockatoo': 3}
'''划分数据集'''
train_size = int(0.8 * len(total_data))  # train_size表示训练集大小,通过将总体数据长度的80%转换为整数得到;
test_size = len(total_data) - train_size  # test_size表示测试集大小,是总体数据长度减去训练集大小。
# 使用torch.utils.data.random_split()方法进行数据集划分。该方法将总体数据total_data按照指定的大小比例([train_size, test_size])随机划分为训练集和测试集,
# 并将划分结果分别赋值给train_dataset和test_dataset两个变量。
train_ds, test_ds = random_split(total_data,[train_size,test_size])
print("train_dataset={}\ntest_dataset={}".format(train_ds, test_ds))
print("train_size={}\ntest_size={}".format(train_size, test_size))
train_dataset=<torch.utils.data.dataset.Subset object at 0x7fb3244300d0>
test_dataset=<torch.utils.data.dataset.Subset object at 0x7fb324430040>
train_size=452
test_size=113
'''加载数据'''
batch_size=8
train_dl = DataLoader(train_ds,batch_size=batch_size,shuffle = True,num_workers=1)
test_dl = DataLoader(test_ds,batch_size=batch_size,shuffle = True,num_workers=1)
for X,y in test_dl:print("Shape of X [N, C, H, W]:",X.shape)print("Shape of y:",y.shape,y.dtype)break
Shape of X [N, C, H, W]: torch.Size([8, 3, 224, 224])
Shape of y: torch.Size([8]) torch.int64
4、搭建ResNet-50V2模型

J2-6

如上图(简化出的模型结构,更详细的完整模型结构见末尾)ResNet50V2在ResNet-50基础上改变了BN和ReLU激活层的顺序,由于shortcut的设计存在3个不同的残差块(因此需要在构建时加以区分)。整体网络是通过不同数量的残差块组堆叠而成。

Residual Block
'''residual block'''
"""
残差块:Arguments:in_channels:输入通道数(输入张量)filters: filters of the bottleneck layerkernel_size: 默认是3,kernels_size of the bottleneck layerstride:default=1, stride of the first layerconv_shortcut:default False, use convolution shortcut if True, otherwise identity shortcutreturns: output tensor for the residual block.
"""
class Block2(nn.Module):def __init__(self, in_channel, filters, kernel_size=3, stride=1, conv_shortcut=False):super(Block2, self).__init__()self.preact = nn.Sequential(nn.BatchNorm2d(in_channel),nn.ReLU(True))self.shortcut = conv_shortcutif self.shortcut:self.short = nn.Conv2d(in_channel, 4*filters, 1, stride=stride, padding=0, bias=False)elif stride>1:self.short = nn.MaxPool2d(kernel_size=1, stride=stride, padding=0)else:self.short = nn.Identity()self.conv1 = nn.Sequential(nn.Conv2d(in_channel, filters, 1, stride=1, bias=False),nn.BatchNorm2d(filters),nn.ReLU(True))self.conv2 = nn.Sequential(nn.Conv2d(filters, filters, kernel_size, stride=stride, padding=1, bias=False),nn.BatchNorm2d(filters),nn.ReLU(True))self.conv3 = nn.Conv2d(filters, 4*filters, 1, stride=1, bias=False)def forward(self, x):x1 = self.preact(x)if self.shortcut:x2 = self.short(x1)else:x2 = self.short(x)x1 = self.conv1(x1)x1 = self.conv2(x1)x1 = self.conv3(x1)x = x1 + x2return x
Stack(堆叠上述的Residual Blocks)

注意辨析清楚输入通道数和输出通道数在整个残差块组中的变化规律

class Stack(nn.Module):def __init__(self, in_channels, filters, blocks, stride = 2):super(Stack, self).__init__()self.conv = nn.Sequential()#每个残差组的第一个残差块均含conv shortcutself.conv.add_module(str(0), Block2(in_channels, filters, conv_shortcut=True))#堆不同数量的identity块【不同组的实际区别就是这个数量】for i in range(1, blocks-1):self.conv.add_module(str(i), Block2(4*filters, filters))#每个残差组的最后一个残差块的shortcut均含maxpool,stride=2self.conv.add_module(str(blocks-1), Block2(4*filters, filters, stride=stride))def forward(self,x):x = self.conv(x)return x
ResNet50V2搭建
''' 构建ResNet50v2 '''
class ResNet50v2(nn.Module):def __init__(self,include_top=True,  # 是否包含位于网络顶部的全链接层preact=True,  # 是否使用预激活use_bias=True,  # 是否对卷积层使用偏置input_shape=[224, 224, 3],classes=1000,pooling=None):super(ResNet50v2, self).__init__()self.conv1 = nn.Sequential()self.conv1.add_module('conv', nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=use_bias, padding_mode='zeros'))if not preact:self.conv1.add_module('bn', nn.BatchNorm2d(64))self.conv1.add_module('relu', nn.ReLU())self.conv1.add_module('max_pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.conv2 = Stack(64, 64, 3)self.conv3 = Stack(256, 128, 4)self.conv4 = Stack(512, 256, 6)self.conv5 = Stack(1024, 512, 3, stride=1)self.post = nn.Sequential()if preact:self.post.add_module('bn', nn.BatchNorm2d(2048))self.post.add_module('relu', nn.ReLU())if include_top:self.post.add_module('avg_pool', nn.AdaptiveAvgPool2d((1, 1)))self.post.add_module('flatten', nn.Flatten())self.post.add_module('fc', nn.Linear(2048, classes))else:if pooling=='avg':self.post.add_module('avg_pool', nn.AdaptiveAvgPool2d((1, 1)))elif pooling=='max':self.post.add_module('max_pool', nn.AdaptiveMaxPool2d((1, 1)))def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = self.conv4(x)x = self.conv5(x)x = self.post(x)return xmodel = ResNet50v2().to(device)summary(model, (3, 224, 224))
----------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================Conv2d-1         [-1, 64, 112, 112]           9,472MaxPool2d-2           [-1, 64, 56, 56]               0BatchNorm2d-3           [-1, 64, 56, 56]             128ReLU-4           [-1, 64, 56, 56]               0Conv2d-5          [-1, 256, 56, 56]          16,384Conv2d-6           [-1, 64, 56, 56]           4,096BatchNorm2d-7           [-1, 64, 56, 56]             128ReLU-8           [-1, 64, 56, 56]               0Conv2d-9           [-1, 64, 56, 56]          36,864BatchNorm2d-10           [-1, 64, 56, 56]             128ReLU-11           [-1, 64, 56, 56]               0Conv2d-12          [-1, 256, 56, 56]          16,384Block2-13          [-1, 256, 56, 56]               0BatchNorm2d-14          [-1, 256, 56, 56]             512ReLU-15          [-1, 256, 56, 56]               0Identity-16          [-1, 256, 56, 56]               0Conv2d-17           [-1, 64, 56, 56]          16,384BatchNorm2d-18           [-1, 64, 56, 56]             128ReLU-19           [-1, 64, 56, 56]               0Conv2d-20           [-1, 64, 56, 56]          36,864BatchNorm2d-21           [-1, 64, 56, 56]             128ReLU-22           [-1, 64, 56, 56]               0Conv2d-23          [-1, 256, 56, 56]          16,384Block2-24          [-1, 256, 56, 56]               0BatchNorm2d-25          [-1, 256, 56, 56]             512ReLU-26          [-1, 256, 56, 56]               0MaxPool2d-27          [-1, 256, 28, 28]               0Conv2d-28           [-1, 64, 56, 56]          16,384BatchNorm2d-29           [-1, 64, 56, 56]             128ReLU-30           [-1, 64, 56, 56]               0Conv2d-31           [-1, 64, 28, 28]          36,864BatchNorm2d-32           [-1, 64, 28, 28]             128ReLU-33           [-1, 64, 28, 28]               0Conv2d-34          [-1, 256, 28, 28]          16,384Block2-35          [-1, 256, 28, 28]               0Stack-36          [-1, 256, 28, 28]               0BatchNorm2d-37          [-1, 256, 28, 28]             512ReLU-38          [-1, 256, 28, 28]               0Conv2d-39          [-1, 512, 28, 28]         131,072Conv2d-40          [-1, 128, 28, 28]          32,768BatchNorm2d-41          [-1, 128, 28, 28]             256ReLU-42          [-1, 128, 28, 28]               0Conv2d-43          [-1, 128, 28, 28]         147,456BatchNorm2d-44          [-1, 128, 28, 28]             256ReLU-45          [-1, 128, 28, 28]               0Conv2d-46          [-1, 512, 28, 28]          65,536Block2-47          [-1, 512, 28, 28]               0BatchNorm2d-48          [-1, 512, 28, 28]           1,024ReLU-49          [-1, 512, 28, 28]               0Identity-50          [-1, 512, 28, 28]               0Conv2d-51          [-1, 128, 28, 28]          65,536BatchNorm2d-52          [-1, 128, 28, 28]             256ReLU-53          [-1, 128, 28, 28]               0Conv2d-54          [-1, 128, 28, 28]         147,456BatchNorm2d-55          [-1, 128, 28, 28]             256ReLU-56          [-1, 128, 28, 28]               0Conv2d-57          [-1, 512, 28, 28]          65,536Block2-58          [-1, 512, 28, 28]               0BatchNorm2d-59          [-1, 512, 28, 28]           1,024ReLU-60          [-1, 512, 28, 28]               0Identity-61          [-1, 512, 28, 28]               0Conv2d-62          [-1, 128, 28, 28]          65,536BatchNorm2d-63          [-1, 128, 28, 28]             256ReLU-64          [-1, 128, 28, 28]               0Conv2d-65          [-1, 128, 28, 28]         147,456BatchNorm2d-66          [-1, 128, 28, 28]             256ReLU-67          [-1, 128, 28, 28]               0Conv2d-68          [-1, 512, 28, 28]          65,536Block2-69          [-1, 512, 28, 28]               0BatchNorm2d-70          [-1, 512, 28, 28]           1,024ReLU-71          [-1, 512, 28, 28]               0MaxPool2d-72          [-1, 512, 14, 14]               0Conv2d-73          [-1, 128, 28, 28]          65,536BatchNorm2d-74          [-1, 128, 28, 28]             256ReLU-75          [-1, 128, 28, 28]               0Conv2d-76          [-1, 128, 14, 14]         147,456BatchNorm2d-77          [-1, 128, 14, 14]             256ReLU-78          [-1, 128, 14, 14]               0Conv2d-79          [-1, 512, 14, 14]          65,536Block2-80          [-1, 512, 14, 14]               0Stack-81          [-1, 512, 14, 14]               0BatchNorm2d-82          [-1, 512, 14, 14]           1,024ReLU-83          [-1, 512, 14, 14]               0Conv2d-84         [-1, 1024, 14, 14]         524,288Conv2d-85          [-1, 256, 14, 14]         131,072BatchNorm2d-86          [-1, 256, 14, 14]             512ReLU-87          [-1, 256, 14, 14]               0Conv2d-88          [-1, 256, 14, 14]         589,824BatchNorm2d-89          [-1, 256, 14, 14]             512ReLU-90          [-1, 256, 14, 14]               0Conv2d-91         [-1, 1024, 14, 14]         262,144Block2-92         [-1, 1024, 14, 14]               0BatchNorm2d-93         [-1, 1024, 14, 14]           2,048ReLU-94         [-1, 1024, 14, 14]               0Identity-95         [-1, 1024, 14, 14]               0Conv2d-96          [-1, 256, 14, 14]         262,144BatchNorm2d-97          [-1, 256, 14, 14]             512ReLU-98          [-1, 256, 14, 14]               0Conv2d-99          [-1, 256, 14, 14]         589,824BatchNorm2d-100          [-1, 256, 14, 14]             512ReLU-101          [-1, 256, 14, 14]               0Conv2d-102         [-1, 1024, 14, 14]         262,144Block2-103         [-1, 1024, 14, 14]               0BatchNorm2d-104         [-1, 1024, 14, 14]           2,048ReLU-105         [-1, 1024, 14, 14]               0Identity-106         [-1, 1024, 14, 14]               0Conv2d-107          [-1, 256, 14, 14]         262,144BatchNorm2d-108          [-1, 256, 14, 14]             512ReLU-109          [-1, 256, 14, 14]               0Conv2d-110          [-1, 256, 14, 14]         589,824BatchNorm2d-111          [-1, 256, 14, 14]             512ReLU-112          [-1, 256, 14, 14]               0Conv2d-113         [-1, 1024, 14, 14]         262,144Block2-114         [-1, 1024, 14, 14]               0BatchNorm2d-115         [-1, 1024, 14, 14]           2,048ReLU-116         [-1, 1024, 14, 14]               0Identity-117         [-1, 1024, 14, 14]               0Conv2d-118          [-1, 256, 14, 14]         262,144BatchNorm2d-119          [-1, 256, 14, 14]             512ReLU-120          [-1, 256, 14, 14]               0Conv2d-121          [-1, 256, 14, 14]         589,824BatchNorm2d-122          [-1, 256, 14, 14]             512ReLU-123          [-1, 256, 14, 14]               0Conv2d-124         [-1, 1024, 14, 14]         262,144Block2-125         [-1, 1024, 14, 14]               0BatchNorm2d-126         [-1, 1024, 14, 14]           2,048ReLU-127         [-1, 1024, 14, 14]               0Identity-128         [-1, 1024, 14, 14]               0Conv2d-129          [-1, 256, 14, 14]         262,144BatchNorm2d-130          [-1, 256, 14, 14]             512ReLU-131          [-1, 256, 14, 14]               0Conv2d-132          [-1, 256, 14, 14]         589,824BatchNorm2d-133          [-1, 256, 14, 14]             512ReLU-134          [-1, 256, 14, 14]               0Conv2d-135         [-1, 1024, 14, 14]         262,144Block2-136         [-1, 1024, 14, 14]               0BatchNorm2d-137         [-1, 1024, 14, 14]           2,048ReLU-138         [-1, 1024, 14, 14]               0MaxPool2d-139           [-1, 1024, 7, 7]               0Conv2d-140          [-1, 256, 14, 14]         262,144BatchNorm2d-141          [-1, 256, 14, 14]             512ReLU-142          [-1, 256, 14, 14]               0Conv2d-143            [-1, 256, 7, 7]         589,824BatchNorm2d-144            [-1, 256, 7, 7]             512ReLU-145            [-1, 256, 7, 7]               0Conv2d-146           [-1, 1024, 7, 7]         262,144Block2-147           [-1, 1024, 7, 7]               0Stack-148           [-1, 1024, 7, 7]               0BatchNorm2d-149           [-1, 1024, 7, 7]           2,048ReLU-150           [-1, 1024, 7, 7]               0Conv2d-151           [-1, 2048, 7, 7]       2,097,152Conv2d-152            [-1, 512, 7, 7]         524,288BatchNorm2d-153            [-1, 512, 7, 7]           1,024ReLU-154            [-1, 512, 7, 7]               0Conv2d-155            [-1, 512, 7, 7]       2,359,296BatchNorm2d-156            [-1, 512, 7, 7]           1,024ReLU-157            [-1, 512, 7, 7]               0Conv2d-158           [-1, 2048, 7, 7]       1,048,576Block2-159           [-1, 2048, 7, 7]               0BatchNorm2d-160           [-1, 2048, 7, 7]           4,096ReLU-161           [-1, 2048, 7, 7]               0Identity-162           [-1, 2048, 7, 7]               0Conv2d-163            [-1, 512, 7, 7]       1,048,576BatchNorm2d-164            [-1, 512, 7, 7]           1,024ReLU-165            [-1, 512, 7, 7]               0Conv2d-166            [-1, 512, 7, 7]       2,359,296BatchNorm2d-167            [-1, 512, 7, 7]           1,024ReLU-168            [-1, 512, 7, 7]               0Conv2d-169           [-1, 2048, 7, 7]       1,048,576Block2-170           [-1, 2048, 7, 7]               0BatchNorm2d-171           [-1, 2048, 7, 7]           4,096ReLU-172           [-1, 2048, 7, 7]               0Identity-173           [-1, 2048, 7, 7]               0Conv2d-174            [-1, 512, 7, 7]       1,048,576BatchNorm2d-175            [-1, 512, 7, 7]           1,024ReLU-176            [-1, 512, 7, 7]               0Conv2d-177            [-1, 512, 7, 7]       2,359,296BatchNorm2d-178            [-1, 512, 7, 7]           1,024ReLU-179            [-1, 512, 7, 7]               0Conv2d-180           [-1, 2048, 7, 7]       1,048,576Block2-181           [-1, 2048, 7, 7]               0Stack-182           [-1, 2048, 7, 7]               0BatchNorm2d-183           [-1, 2048, 7, 7]           4,096ReLU-184           [-1, 2048, 7, 7]               0
AdaptiveAvgPool2d-185           [-1, 2048, 1, 1]               0Flatten-186                 [-1, 2048]               0Linear-187                 [-1, 1000]       2,049,000
================================================================
Total params: 25,549,416
Trainable params: 25,549,416
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 241.69
Params size (MB): 97.46
Estimated Total Size (MB): 339.73
----------------------------------------------------------------
5、编写训练和测试函数
'''
编写训练函数
'''
def train(dataloader, model, optimizer, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)train_acc, train_loss = 0, 0for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss /= num_batchestrain_acc /= sizereturn train_acc, train_loss
'''
编写测试函数
'''
def test(dataloader, model, loss_fn):size = len(dataloader.dataset)  # 测试集的大小num_batches = len(dataloader)  # 批次数目, (size/batch_size,向上取整)test_loss, test_acc = 0, 0# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad():for imgs, target in dataloader:imgs, target = imgs.to(device), target.to(device)# 计算losstarget_pred = model(imgs)loss = loss_fn(target_pred, target)test_loss += loss.item()test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss
6、正式训练
'''
设置损失函数和学习率
'''
loss_fn = nn.CrossEntropyLoss()   #交叉熵函数
learn_rate = 1e-6
optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer,step_size=5,gamma=0.96)
'''
正式训练
'''
epochs = 50
train_loss = []
train_acc = []
test_loss = []
test_acc = []
best_acc = 0# 开始训练
for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, optimizer, loss_fn)scheduler.step()model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)if epoch_test_acc > best_acc:best_acc = epoch_test_accbest_model = copy.deepcopy(model)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)lr = optimizer.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss,epoch_test_acc * 100, epoch_test_loss, lr))print('Done')
Epoch: 1, Train_acc:0.0%, Train_loss:6.872, Test_acc:0.0%, Test_loss:6.949, Lr:1.00E-06
Epoch: 2, Train_acc:0.9%, Train_loss:6.775, Test_acc:2.7%, Test_loss:6.842, Lr:1.00E-06
Epoch: 3, Train_acc:8.4%, Train_loss:6.677, Test_acc:5.3%, Test_loss:6.793, Lr:1.00E-06
Epoch: 4, Train_acc:14.2%, Train_loss:6.585, Test_acc:8.0%, Test_loss:6.666, Lr:1.00E-06
Epoch: 5, Train_acc:20.1%, Train_loss:6.485, Test_acc:15.9%, Test_loss:6.550, Lr:9.60E-07
Epoch: 6, Train_acc:25.0%, Train_loss:6.387, Test_acc:17.7%, Test_loss:6.481, Lr:9.60E-07
Epoch: 7, Train_acc:26.8%, Train_loss:6.300, Test_acc:19.5%, Test_loss:6.383, Lr:9.60E-07
Epoch: 8, Train_acc:28.1%, Train_loss:6.202, Test_acc:19.5%, Test_loss:6.304, Lr:9.60E-07
Epoch: 9, Train_acc:27.4%, Train_loss:6.111, Test_acc:23.0%, Test_loss:6.223, Lr:9.60E-07
Epoch:10, Train_acc:28.3%, Train_loss:6.013, Test_acc:23.9%, Test_loss:6.093, Lr:9.22E-07
Epoch:11, Train_acc:29.2%, Train_loss:5.922, Test_acc:23.9%, Test_loss:6.020, Lr:9.22E-07
Epoch:12, Train_acc:29.6%, Train_loss:5.837, Test_acc:29.2%, Test_loss:5.942, Lr:9.22E-07
Epoch:13, Train_acc:33.0%, Train_loss:5.732, Test_acc:26.5%, Test_loss:5.806, Lr:9.22E-07
Epoch:14, Train_acc:31.0%, Train_loss:5.636, Test_acc:33.6%, Test_loss:5.742, Lr:9.22E-07
Epoch:15, Train_acc:34.1%, Train_loss:5.529, Test_acc:30.1%, Test_loss:5.631, Lr:8.85E-07
Epoch:16, Train_acc:36.1%, Train_loss:5.430, Test_acc:38.9%, Test_loss:5.448, Lr:8.85E-07
Epoch:17, Train_acc:38.9%, Train_loss:5.304, Test_acc:40.7%, Test_loss:5.409, Lr:8.85E-07
Epoch:18, Train_acc:39.8%, Train_loss:5.247, Test_acc:42.5%, Test_loss:5.279, Lr:8.85E-07
Epoch:19, Train_acc:44.5%, Train_loss:5.109, Test_acc:45.1%, Test_loss:5.176, Lr:8.85E-07
Epoch:20, Train_acc:48.5%, Train_loss:4.988, Test_acc:46.9%, Test_loss:5.111, Lr:8.49E-07
Epoch:21, Train_acc:51.1%, Train_loss:4.848, Test_acc:51.3%, Test_loss:4.819, Lr:8.49E-07
Epoch:22, Train_acc:49.1%, Train_loss:4.719, Test_acc:52.2%, Test_loss:4.778, Lr:8.49E-07
Epoch:23, Train_acc:52.4%, Train_loss:4.575, Test_acc:52.2%, Test_loss:4.576, Lr:8.49E-07
Epoch:24, Train_acc:57.7%, Train_loss:4.408, Test_acc:51.3%, Test_loss:4.453, Lr:8.49E-07
Epoch:25, Train_acc:60.8%, Train_loss:4.223, Test_acc:52.2%, Test_loss:4.001, Lr:8.15E-07
Epoch:26, Train_acc:55.1%, Train_loss:4.158, Test_acc:60.2%, Test_loss:3.995, Lr:8.15E-07
Epoch:27, Train_acc:58.0%, Train_loss:3.888, Test_acc:57.5%, Test_loss:3.795, Lr:8.15E-07
Epoch:28, Train_acc:58.2%, Train_loss:3.781, Test_acc:57.5%, Test_loss:3.810, Lr:8.15E-07
Epoch:29, Train_acc:61.1%, Train_loss:3.554, Test_acc:58.4%, Test_loss:3.364, Lr:8.15E-07
Epoch:30, Train_acc:63.1%, Train_loss:3.471, Test_acc:63.7%, Test_loss:3.322, Lr:7.83E-07
Epoch:31, Train_acc:65.5%, Train_loss:3.286, Test_acc:63.7%, Test_loss:3.151, Lr:7.83E-07
Epoch:32, Train_acc:62.6%, Train_loss:3.189, Test_acc:62.8%, Test_loss:2.993, Lr:7.83E-07
Epoch:33, Train_acc:63.9%, Train_loss:3.064, Test_acc:60.2%, Test_loss:2.774, Lr:7.83E-07
Epoch:34, Train_acc:65.5%, Train_loss:2.920, Test_acc:69.0%, Test_loss:2.916, Lr:7.83E-07
Epoch:35, Train_acc:68.4%, Train_loss:2.843, Test_acc:61.9%, Test_loss:2.643, Lr:7.51E-07
Epoch:36, Train_acc:69.7%, Train_loss:2.696, Test_acc:69.9%, Test_loss:2.502, Lr:7.51E-07
Epoch:37, Train_acc:68.1%, Train_loss:2.659, Test_acc:66.4%, Test_loss:2.307, Lr:7.51E-07
Epoch:38, Train_acc:61.9%, Train_loss:2.634, Test_acc:64.6%, Test_loss:2.103, Lr:7.51E-07
Epoch:39, Train_acc:70.1%, Train_loss:2.432, Test_acc:69.0%, Test_loss:2.125, Lr:7.51E-07
Epoch:40, Train_acc:65.5%, Train_loss:2.396, Test_acc:72.6%, Test_loss:2.127, Lr:7.21E-07
Epoch:41, Train_acc:71.0%, Train_loss:2.269, Test_acc:69.0%, Test_loss:2.067, Lr:7.21E-07
Epoch:42, Train_acc:68.1%, Train_loss:2.266, Test_acc:70.8%, Test_loss:2.067, Lr:7.21E-07
Epoch:43, Train_acc:66.8%, Train_loss:2.186, Test_acc:73.5%, Test_loss:1.887, Lr:7.21E-07
Epoch:44, Train_acc:67.9%, Train_loss:2.128, Test_acc:76.1%, Test_loss:1.818, Lr:7.21E-07
Epoch:45, Train_acc:69.5%, Train_loss:2.102, Test_acc:69.9%, Test_loss:1.955, Lr:6.93E-07
Epoch:46, Train_acc:70.6%, Train_loss:1.962, Test_acc:77.0%, Test_loss:1.844, Lr:6.93E-07
Epoch:47, Train_acc:71.0%, Train_loss:1.967, Test_acc:78.8%, Test_loss:1.697, Lr:6.93E-07
Epoch:48, Train_acc:71.9%, Train_loss:1.910, Test_acc:77.9%, Test_loss:1.860, Lr:6.93E-07
Epoch:49, Train_acc:75.4%, Train_loss:1.809, Test_acc:76.1%, Test_loss:1.550, Lr:6.93E-07
Epoch:50, Train_acc:71.5%, Train_loss:1.786, Test_acc:73.5%, Test_loss:1.565, Lr:6.65E-07
Done
7、模型评估
# 结果可视化
import warnings
warnings.filterwarnings("ignore")             #忽略警告信息
plt.rcParams['axes.unicode_minus'] = False    #用来正常显示负号
plt.rcParams['figure.dpi'] = 100              #分辨率epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
8、预测
classes = list(total_data.class_to_idx)
# 定义反归一化操作
inv_normalize = transforms.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],std=[1/0.229, 1/0.224, 1/0.225]
)for images, _ in test_dl:break  # 获取一个batch的图片# 进行预测
with torch.no_grad():outputs = best_model(images.to(device))_, preds = torch.max(outputs, 1)# 可视化
plt.figure(figsize=(10, 4))
for i in range(8):plt.subplot(2, 4, i + 1)img = images[i].cpu()img = inv_normalize(img)  # 应用反归一化img = img.permute(1, 2, 0)  # 转为 [H, W, C] 格式plt.imshow(img)plt.title(classes[preds[i].item()])plt.axis('off')plt.show()

总结

本周以上周的训练为基础,整体较为简单,对残差模型的本质和特点有了更深的理解,但是第一次50epoch结果正确率不是很理想,需要改进。


http://www.mrgr.cn/news/63785.html

相关文章:

  • dolphinscheduler2.0.9升级3.1.9版本问题记录
  • 【Spring】注入方式
  • Dexcap复现代码数据预处理全流程(四)——demo_clipping_3d.py
  • 实战篇: BiLSTM+CRF实现中文分词
  • ubuntu22.04 编译安装libvirt 10.x
  • 面向对象分析和设计OOA/D,UML,GRASP
  • CTF顶级工具与资源
  • 市面上12款能帮忙微信记录的数据恢复软件神器!!!
  • Python For循环
  • 再探“构造函数”
  • 备考最后一周调整
  • shodan用法(完)
  • 在 Vue 3 中实现流畅的 Swiper 滑动效果
  • HJ36 字符串加密
  • c++仿函数--通俗易懂
  • 【p2p、分布式,区块链笔记 Torrent】WebTorrent 的lt_donthave插件
  • LeetCode总结-链表
  • 使用TensorFlow进行图像分类
  • 某小型CMS漏洞复现审计
  • Ceisum无人机巡检视频投放
  • NET Core的AOP实施方法1 DispatchProxy
  • 【Linux】基础指令
  • ERROR: Failed cleaning build dir for numpy Failed > to build numpy ERROR
  • 一键切换暗黑模式,这些代码片段你不可错过
  • 直流电机在液压泵领域的应用
  • ubuntu运行gazebo导致内存越来越少