当前位置: 首页 > news >正文

使用Flask部署自己的pytorch模型(猫品种分类模型)

使用Flask部署自己的pytorch模型(猫品种分类模型)

全部代码开源在YTY666ZSY/Flask_Cat_7classify — yty666zsy/Flask_Cat_7classify (github.com)

一、数据集准备

来自大佬的文章调用pytorch的resnet,训练出准确率高达96%的猫12类分类模型。 - 知乎 (zhihu.com),在其基础上进行修改的。

在视觉中国中使用爬虫来进行猫咪品种的爬取,爬取后的图片需要自己去检查有没有错误,清洗图片数据。

如下代码所示,需要修改file_path,指定保存地址,修改base_url,例如"buoumao"为布偶猫的拼音,如果想搜索其他品种的猫,直接更改拼音就可以。

import asyncio
import re  import aiohttp
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.edge.options import Optionsdef ks_download_uel(image_urls):async def download_images(url_list):async with aiohttp.ClientSession() as session:global kfor url in url_list:try:async with session.get("https:" + url) as response:  # "https:" + url 进行网址拼接response.raise_for_status()file_path = fr"F:\project\猫12分类\data\孟买猫\{k}.jpg"  # 指定保存地址with open(file_path, 'wb') as file:while True:chunk = await response.content.read(8192)if not chunk:breakfile.write(chunk)print(f"已经完成 {k} 张")except Exception as e:print(f"下载第 {k} 张出现错误 :{str(e)}")k += 1  # 为下一张做标记# 创建事件循环对象loop = asyncio.get_event_loop()# 调用异步函数loop.run_until_complete(download_images(image_urls))if __name__ == '__main__':base_url = 'https://www.vcg.com/creative-image/mengmaimao/?page={page}'  # "buoumao"为布偶猫的拼音,如果想搜索其他品种的猫,直接更改拼音就可以edge_options = Options()edge_options.add_argument("--headless")  # 不显示浏览器敞口, 加快爬取速度。edge_options.add_argument("--no-sandbox")  # 防止启动失败driver = webdriver.Edge(options=edge_options)k = 1  # 为保存的每一种图片做标记for page in range(1, 5):  # 每一页150张,十页就够了。if page == 1:  # 目的是就打开一个网特,减少内存开销driver.get(base_url.format(page=page))  # 开始访问第page页elements = driver.find_elements(By.XPATH,'//*[@id="imageContent"]/section[1]')  # 将返回 //*[@id="imageContent"]/section/div 下的所有子标签元素urls_ls = []  # 所要的图片下载地址。for element in elements:html_source = element.get_attribute('outerHTML')urls_ls = re.findall('data-src="(.*?)"', str(html_source))  # 这里用了正则匹配,可以加快执行速度#  下面给大家推荐一个异步快速下载图片的方法, 建议这时候尽量多提供一下cpu和内存为程序ks_download_uel(urls_ls)driver.execute_script(f"window.open('{base_url.format(page=page)}', '_self')")  # 在当前窗口打开新网页,减少内存使用driver.quit()  # 在所有网页访问完成后退出 WebDriver

爬取后的图片保存在指定的位置

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

二、训练模型

如下面代码所示,我们使用res50的预训练模型,但是需要注意的是最后的线性层model.fc需要修改为自己需要的分类种类,train_data_path修改为自己data所在位置,需要说的是我们并不需要主动去划分测试集和训练集,我们只需要进行数据分类,在代码中会自动分类。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, WeightedRandomSampler, Subset
from tqdm import tqdm
from collections import Counter
import matplotlib.pyplot as plt  # 用于绘制图形class Trainer:def __init__(self, model, device, train_loader, valid_loader, lr=0.0001):self.model = model.to(device)  # 将模型转移到设备上self.device = deviceself.train_loader = train_loaderself.valid_loader = valid_loaderself.optimizer = optim.SGD(self.model.parameters(), lr=lr, momentum=0.9)def train_one_epoch(self, epoch):self.model.train()correct_predictions = 0total_samples = 0epoch_loss = 0  # 记录当前轮次的损失for inputs, targets in tqdm(self.train_loader):inputs, targets = inputs.to(self.device), targets.to(self.device)self.optimizer.zero_grad()outputs = self.model(inputs)predictions = outputs.argmax(dim=1)correct_predictions += (predictions == targets).sum().item()total_samples += targets.size(0)loss = nn.CrossEntropyLoss()(outputs, targets)epoch_loss += loss.item()  # 累加损失loss.backward()self.optimizer.step()accuracy = 100. * correct_predictions / total_samplesaverage_loss = epoch_loss / len(self.train_loader)  # 计算平均损失print(f"Epoch {epoch}: Train Accuracy: {accuracy:.2f}%, Loss: {average_loss:.4f}")return average_loss  # 返回当前轮次的平均损失def validate(self):self.model.eval()correct_predictions = 0total_samples = 0total_loss = 0.0with torch.no_grad():for inputs, targets in self.valid_loader:inputs, targets = inputs.to(self.device), targets.to(self.device)outputs = self.model(inputs)loss = nn.CrossEntropyLoss()(outputs, targets)total_loss += loss.item()predictions = outputs.argmax(dim=1)correct_predictions += (predictions == targets).sum().item()total_samples += targets.size(0)accuracy = 100. * correct_predictions / total_samplesaverage_loss = total_loss / len(self.valid_loader)  # 计算平均损失print(f"Validation Accuracy: {accuracy:.2f}%, Loss: {average_loss:.4f}")return accuracy, average_loss  # 返回准确率和损失def create_data_loaders(train_root, batch_size):transform = transforms.Compose([transforms.Resize(256),transforms.RandomResizedCrop(244, scale=(0.6, 1.0), ratio=(0.8, 1.0)),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.5),transforms.ToTensor(),transforms.Normalize(mean=[0.4848, 0.4435, 0.4023], std=[0.2744, 0.2688, 0.2757])])dataset = torchvision.datasets.ImageFolder(root=train_root, transform=transform)class_counts = Counter(dataset.targets)weights = [1.0 / class_counts[idx] for idx in dataset.targets]sampler = WeightedRandomSampler(weights, len(dataset), replacement=True)train_subset = Subset(dataset, list(sampler))valid_indices = [idx for idx in range(len(dataset)) if idx not in list(sampler)]valid_subset = Subset(dataset, valid_indices)train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)valid_loader = DataLoader(valid_subset, batch_size=batch_size, shuffle=False)return train_loader, valid_loaderif __name__ == '__main__':device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')learning_rate = 0.0001epochs = 30batch_size = 32train_data_path = r"E:\github\my_github\Flask_cat_7classify\data"train_loader, valid_loader = create_data_loaders(train_data_path, batch_size)model = torchvision.models.resnet50(weights='ResNet50_Weights.DEFAULT')model.fc = nn.Linear(2048, 7)  # 调整输出层以适应7个类别trainer = Trainer(model, device, train_loader, valid_loader)best_accuracy = 0.0best_model_state = Nonetrain_losses = []  # 记录每轮训练损失valid_losses = []  # 记录每轮验证损失for epoch in range(1, epochs + 1):train_loss = trainer.train_one_epoch(epoch)train_losses.append(train_loss)  # 存储训练损失accuracy, valid_loss = trainer.validate()valid_losses.append(valid_loss)  # 存储验证损失if accuracy > best_accuracy:best_accuracy = accuracybest_model_state = model.state_dict()print(f"Best Validation Accuracy: {best_accuracy:.2f}%")torch.save(best_model_state, fr"E:\github\my_github\Flask_cat_7classify\best_model_train{best_accuracy:.2f}.pth")# 绘制损失变化图plt.figure(figsize=(10, 5))plt.plot(range(1, epochs + 1), train_losses, label='Train Loss', color='blue')plt.plot(range(1, epochs + 1), valid_losses, label='Validation Loss', color='orange')plt.title('Loss Change Over Epochs')plt.xlabel('Epochs')plt.ylabel('Loss')plt.legend()plt.grid()plt.savefig('loss_plot.png')  # 保存图形为 PNG 文件plt.show()  # 显示图形

三、使用flask部署模型

如下面代码所示,我们需要修改模型导入的位置,然后修改线性层类别,特别需要注意的一点是在定义类别上categories,我们需要按照文件夹的顺序来填写。

from flask import Flask, request, jsonify, send_from_directory  
import torchvision
import torch
import torchvision.transforms as transforms
from PIL import Image
import io# 初始化 Flask 应用
app = Flask(__name__)# 设置设备为 GPU,如果不可用则使用 CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 定义图像转换管道
transform = transforms.Compose([transforms.Resize(256),  # 将图像调整为 256x256transforms.CenterCrop(244),  # 裁剪中心的 244x244 区域transforms.ToTensor(),  # 将图像转换为 PyTorch 张量transforms.Normalize(mean=[0.4848, 0.4435, 0.4023], std=[0.2744, 0.2688, 0.2757])  # 归一化
])# 加载 ResNet50 模型并修改为 7 个分类
model = torchvision.models.resnet50(weights=None)
model.fc = torch.nn.Linear(2048, 7)  # 设置输出层为 7 个类
model.load_state_dict(torch.load(r"E:\github\my_github\Flask_cat_7classify\best_model_train92.81.pth", map_location=device))
model.to(device)  # 将模型移动到指定设备
model.eval()  # 设置模型为评估模式# 定义类别
categories = ['俄罗斯蓝猫', '孟买猫', '布偶猫', '暹罗猫', '波斯猫', '缅因猫', '英国短毛猫']# 预测图像类别的函数
def predict_image(image_bytes):image = Image.open(io.BytesIO(image_bytes))  # 从字节加载图像image = transform(image).unsqueeze(0).to(device)  # 转换并添加批次维度with torch.no_grad():  # 禁用梯度计算output = model(image)  # 获取模型预测_, predicted = torch.max(output, 1)  # 获取预测的类别索引return predicted.item()  # 返回预测的类别索引# 定义预测路由
@app.route('/predict', methods=['POST'])
def predict():if 'file' not in request.files:return jsonify({'error': '没有文件部分'}), 400  # 如果没有文件部分,返回错误file = request.files['file']  # 获取上传的文件if file.filename == '':return jsonify({'error': '没有选择文件'}), 400  # 如果没有选择文件,返回错误if file:img_bytes = file.read()  # 读取文件字节try:prediction_index = predict_image(img_bytes)  # 获取预测索引prediction_label = categories[prediction_index]  # 将索引映射到标签return jsonify({'prediction': prediction_label})  # 返回预测结果except Exception as e:return jsonify({'error': str(e)}), 500  # 处理预测过程中出现的错误# 用于服务静态文件的路由
@app.route('/<path:path>')
def send_static(path):return send_from_directory('templates', path)  # 从 templates 目录提供静态文件# 运行 Flask 应用
if __name__ == '__main__':app.run(debug=True, host='0.0.0.0', port=5000)

另外在templates中定义了简单的静态html文件,如下图所示外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传最后运行app.py,然后在浏览器打开http://127.0.0.1:5000/index.html就可以成功实现
在这里插入图片描述

这样我们就实现了使用falsk部署深度学习模型的简易实现,在大佬文章中使用pyqt和gradio,感兴趣的可以实现一下。


http://www.mrgr.cn/news/72332.html

相关文章:

  • 记录一次Android Studio的下载、安装、配置
  • Jmeter进行http接口并发测试
  • 机器学习与人工智能的关系
  • ffmpeg常用命令及介绍
  • 不同路径 不同路径 II 整数拆分
  • Hive4.0.1集群安装部署(Hadoop版本为3.3.6)(详细教程)
  • 举例说明自然语言处理(NLP)技术。
  • 丹摩征文活动|CogVideoX-2b:从0到1,轻松完成安装与部署!
  • 功能性材料立式粉碎机、立式破碎机、立式超细磨、立式磨粉机
  • vxe-table 实现全部单元格都能编辑的方法
  • GPS L1信号捕获跟踪MATLAB仿真(终极版)
  • ubuntu20.04_从零LOD-3DGS的复现
  • 服务器数据恢复——Ext4文件系统使用fsck后mount不上的数据恢复案例
  • netmap.js:基于浏览器的网络发现工具
  • PET-文件包含-FINISHED
  • ManageOne_SC里业务员账号user01发布ECS
  • LeetCode【0024】两两交换链表中的节点
  • (11)(2.1.7) FETtec OneWire ESCs(二)
  • Sigrity SPEED2000 Power Ground Noise Simulation模式如何进行串扰分析操作指导-trace耦合
  • 遗传算法与深度学习实战(23)——利用遗传算法优化深度学习模型
  • Mysql详细知识点(建议收藏)
  • JUC-locks锁
  • Java基础-组件及事件处理(上)
  • AIDL HAL简介
  • Ajax 与 Vue 框架应用点——随笔谈
  • UI自动化测试|XPath元素定位实践