当前位置: 首页 > news >正文

快速学会一个算法:Faster R-CNN进行目标检测!

《博主简介》

小伙伴们好,我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。
更多学习资源,可关注公-仲-hao:【阿旭算法与机器学习】,共同学习交流~
👍感谢小伙伴们点赞、关注!

《------往期经典推荐------》

一、AI应用软件开发实战专栏【链接】

项目名称项目名称
1.【人脸识别与管理系统开发】2.【车牌识别与自动收费管理系统开发】
3.【手势识别系统开发】4.【人脸面部活体检测系统开发】
5.【图片风格快速迁移软件开发】6.【人脸表表情识别系统】
7.【YOLOv8多目标识别与自动标注软件开发】8.【基于YOLOv8深度学习的行人跌倒检测系统】
9.【基于YOLOv8深度学习的PCB板缺陷检测系统】10.【基于YOLOv8深度学习的生活垃圾分类目标检测系统】
11.【基于YOLOv8深度学习的安全帽目标检测系统】12.【基于YOLOv8深度学习的120种犬类检测与识别系统】
13.【基于YOLOv8深度学习的路面坑洞检测系统】14.【基于YOLOv8深度学习的火焰烟雾检测系统】
15.【基于YOLOv8深度学习的钢材表面缺陷检测系统】16.【基于YOLOv8深度学习的舰船目标分类检测系统】
17.【基于YOLOv8深度学习的西红柿成熟度检测系统】18.【基于YOLOv8深度学习的血细胞检测与计数系统】
19.【基于YOLOv8深度学习的吸烟/抽烟行为检测系统】20.【基于YOLOv8深度学习的水稻害虫检测与识别系统】
21.【基于YOLOv8深度学习的高精度车辆行人检测与计数系统】22.【基于YOLOv8深度学习的路面标志线检测与识别系统】
23.【基于YOLOv8深度学习的智能小麦害虫检测识别系统】24.【基于YOLOv8深度学习的智能玉米害虫检测识别系统】
25.【基于YOLOv8深度学习的200种鸟类智能检测与识别系统】26.【基于YOLOv8深度学习的45种交通标志智能检测与识别系统】
27.【基于YOLOv8深度学习的人脸面部表情识别系统】28.【基于YOLOv8深度学习的苹果叶片病害智能诊断系统】
29.【基于YOLOv8深度学习的智能肺炎诊断系统】30.【基于YOLOv8深度学习的葡萄簇目标检测系统】
31.【基于YOLOv8深度学习的100种中草药智能识别系统】32.【基于YOLOv8深度学习的102种花卉智能识别系统】
33.【基于YOLOv8深度学习的100种蝴蝶智能识别系统】34.【基于YOLOv8深度学习的水稻叶片病害智能诊断系统】
35.【基于YOLOv8与ByteTrack的车辆行人多目标检测与追踪系统】36.【基于YOLOv8深度学习的智能草莓病害检测与分割系统】
37.【基于YOLOv8深度学习的复杂场景下船舶目标检测系统】38.【基于YOLOv8深度学习的农作物幼苗与杂草检测系统】
39.【基于YOLOv8深度学习的智能道路裂缝检测与分析系统】40.【基于YOLOv8深度学习的葡萄病害智能诊断与防治系统】
41.【基于YOLOv8深度学习的遥感地理空间物体检测系统】42.【基于YOLOv8深度学习的无人机视角地面物体检测系统】
43.【基于YOLOv8深度学习的木薯病害智能诊断与防治系统】44.【基于YOLOv8深度学习的野外火焰烟雾检测系统】
45.【基于YOLOv8深度学习的脑肿瘤智能检测系统】46.【基于YOLOv8深度学习的玉米叶片病害智能诊断与防治系统】
47.【基于YOLOv8深度学习的橙子病害智能诊断与防治系统】48.【基于深度学习的车辆检测追踪与流量计数系统】
49.【基于深度学习的行人检测追踪与双向流量计数系统】50.【基于深度学习的反光衣检测与预警系统】
51.【基于深度学习的危险区域人员闯入检测与报警系统】52.【基于深度学习的高密度人脸智能检测与统计系统】
53.【基于深度学习的CT扫描图像肾结石智能检测系统】54.【基于深度学习的水果智能检测系统】
55.【基于深度学习的水果质量好坏智能检测系统】56.【基于深度学习的蔬菜目标检测与识别系统】
57.【基于深度学习的非机动车驾驶员头盔检测系统】58.【太基于深度学习的阳能电池板检测与分析系统】
59.【基于深度学习的工业螺栓螺母检测】60.【基于深度学习的金属焊缝缺陷检测系统】
61.【基于深度学习的链条缺陷检测与识别系统】62.【基于深度学习的交通信号灯检测识别】
63.【基于深度学习的草莓成熟度检测与识别系统】64.【基于深度学习的水下海生物检测识别系统】
65.【基于深度学习的道路交通事故检测识别系统】66.【基于深度学习的安检X光危险品检测与识别系统】

二、机器学习实战专栏【链接】,已更新31期,欢迎关注,持续更新中~~
三、深度学习【Pytorch】专栏【链接】
四、【Stable Diffusion绘画系列】专栏【链接】
五、YOLOv8改进专栏【链接】持续更新中~~
六、YOLO性能对比专栏【链接】,持续更新中~

《------正文------》

目录

  • 引言
  • Faster-R-CNN概述
  • 实现Faster R-CNN用于对象检测
  • 创建辅助函数
  • 准备模型
  • 模型检测
  • 结论

引言

在这里插入图片描述

计算机视觉领域的一个案例研究是创建一个解决方案,使系统能够像人类一样“看到”和“理解”图像中的对象。这个任务被称为目标检测。在分类任务中,模型识别图像中对象的类型(例如,它是猫还是狗的图片)。然而,在对象检测中,目标不仅是识别对象,而且还要通过在其周围绘制边界框来确定其在图像中的位置。

物体检测在日常生活中起着至关重要的作用,从安全方面到交通控制,甚至是自动驾驶汽车技术。这项任务的复杂性主要体现在两个方面:首先,如何检测图像中不同大小、形状和方向的物体;其次,如何快速准确地执行这种检测,即使在经常有噪声和复杂背景的现实世界中也是如此。

Faster-R-CNN概述

最流行的对象检测方法之一是Faster-R-CNN(基于区域的卷积神经网络)。Faster-R-CNN是对之前开发的R-CNN和Fast R-CNN方法的改进,在检测速度和准确性方面都有显着提高。虽然现在有许多更新和更快的方法与Faster R-CNN相比,但仍然值得理解和尝试这种方法作为对象检测技术的介绍。

Faster R-CNN将区域建议网络(RPN)集成到对象检测架构中。借助RPN,Faster R-CNN可以使用卷积网络提取的特征更有效地生成区域建议(可能包含对象的区域)。这允许在单个网络中进行端到端的对象检测,消除了以前依赖于外部过程(如选择性搜索)的方法中存在的瓶颈。为了更深入地了解Faster R-CNN的工作原理,您可以参考本文。

实现Faster R-CNN用于对象检测

在这里,我们不会从头开始训练模型,而是使用torchvision中提供的预训练模型。本文中使用的图像可以在Kaggle上找到。首先,我们需要为这个项目导入一些库:

import cv2
import torch
import requests
import numpy as np
import torchvision
from PIL import Image
from torch import no_grad
import matplotlib.pyplot as plt
from torchvision import transforms

我们导入的库与计算机视觉有关。例如,OpenCV(cv2)用于图像处理,NumPy(np)用于计算,PIL用于加载图像,Matplotlib(plt)用于可视化,Requests用于从Web下载图像。

创建辅助函数

# Function to get predictions with optional filtering by object and threshold
def get_predictions(pred, threshold=0.8, objects=None):"""Assign a string name to predicted classes and filter out predictions below a given threshold.Args:pred: List containing tuples with class labels, probabilities, and bounding boxes.threshold: Minimum probability required to consider a prediction valid.objects: Optional list of object names to filter predictions.Returns:List of tuples containing class name, probability, and bounding box for each valid prediction."""predicted_classes = [(COCO_INSTANCE_CATEGORY_NAMES[i], p, [(box[0], box[1]), (box[2], box[3])])for i, p, box in zip(list(pred[0]['labels'].numpy()),pred[0]['scores'].detach().numpy(),list(pred[0]['boxes'].detach().numpy()))]predicted_classes = [stuff for stuff in predicted_classes if stuff[1] > threshold]if objects and predicted_classes:predicted_classes = [(name, p, box) for name, p, box in predicted_classes if name in objects]return predicted_classes

get_predictions函数用于处理和过滤对象检测模型做出的预测。它接受三个参数:pred作为模型的原始输出,阈值用于过滤低置信度的预测,对象作为可选列表用于过滤基于特定类的预测。首先,该函数通过创建包含类名、检测概率和边界框坐标的元组列表,将原始预测转换为更可读的格式。然后过滤掉不符合概率阈值的预测,如果提供了objects参数,则只保留与指定对象名称匹配的预测。结果是一个经过过滤的元组列表,可供进一步分析或可视化。

# Function to draw bounding boxes around detected objects
def draw_box(predicted_classes, image, rect_th=1, text_size=1, text_th=1):"""Draw bounding boxes and labels around detected objects in an image.Args:predicted_classes: List of tuples containing class name, probability, and bounding box.image: Image tensor on which boxes and labels will be drawn.rect_th: Thickness of the rectangle.text_size: Font size of the label text.text_th: Thickness of the label text."""img = (np.clip(cv2.cvtColor(np.clip(image.numpy().transpose((1, 2, 0)), 0, 1), cv2.COLOR_RGB2BGR), 0, 1) * 255).astype(np.uint8).copy()for predicted_class in predicted_classes:label, probability, box = predicted_classt, l = box[0]r, b = box[1]t, l, r, b = [round(item) for item in [t, l, r, b]]cv2.rectangle(img, (t, l), (r, b), (0, 255, 0), rect_th)  # Draw Rectanglecv2.putText(img, f"{label}: {str(round(probability, 2))}", (t, l), cv2.FONT_HERSHEY_SIMPLEX, text_size, (0, 255, 0), thickness=text_th)plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))plt.show()

draw_box函数用于通过在检测到的对象周围绘制边界框和标签来可视化对象检测模型的结果。此函数还接受预测类列表、要处理的图像、边界框的厚度以及标签文本的大小和厚度的参数。该函数将图像从张量格式转换为NumPy数组。然后,它迭代预测类的列表,提取类信息,概率和边界框坐标,使用OpenCV在图像上绘制。之后,包含类名和概率的标签被添加到边界框的左上角。最后,图像被转换回RGB格式,并使用Matplotlib显示,使我们能够看到检测到的对象沿着及其边界框和标签。

# Function to clear GPU memory and delete images to free up RAM
def save_RAM(image_=False):"""Clear GPU memory and delete image variables to free up RAM.Args:image_: Boolean flag to indicate if the image object should be deleted."""torch.cuda.empty_cache()global image, img, preddel img, predif image_:image.close()del image

save_RAM函数用于管理和释放内存,特别是在GPU内存有限且需要仔细管理的环境中,例如在深度学习模型推理期间。该函数主要用于清除GPU内存,并可选地从RAM中删除图像变量。这是通过调用torch.cuda.empty_cache()来清除未使用的GPU内存缓存,使用del从内存中删除imgpred变量,如果image_parameter设置为True,则可以选择删除image变量来完成的。

准备模型

# Load Pre-Trained Faster RCNN Model
model_ = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model_.eval()  # Set the model to evaluation mode# Disable gradient computation for all parameters
for name, param in model_.named_parameters():param.requires_grad = False
print("Model loaded successfully.")

接下来,我们加载Faster R-CNN(基于区域的卷积神经网络)模型,该模型具有ResNet-50骨干和特征金字塔网络(FPN),该网络已针对对象检测任务进行了预训练。使用带有pretrained=True参数的 torchvision.models.detection.fasterrcnn_resnet50_fpn() 函数加载模型,这意味着模型的权重使用预训练版本初始化。然后使用model_.eval()将模型设置为评估模式,以确保层在推理过程中的行为是确定的。代码还通过为每个参数设置requires_grad= False冻结模型的权重,从而防止在推理期间更新。

# Function to get predictions from the model
def model(x):with no_grad():yhat = model_(x)return yhat

上面的模型函数用于从预训练的对象检测模型生成预测。

# COCO class names
COCO_INSTANCE_CATEGORY_NAMES = ['__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat','traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse','sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A', 'handbag','tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove','skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon','bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake','chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop','mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book','clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

上面的列表COCO_CATEGORY_NAMES包含COCO(上下文中的公共对象)数据集中使用的类名。

模型检测

一旦模型准备就绪,我们将通过预测各种图像中的对象来测试Faster R-CNN模型。下面是我们将执行对象检测的一些图像示例:

  1. 在包含一个人的图像上进行检测。
  2. 在包含多个人的图像上进行检测。
  3. 检测一只猫和一只狗。
  4. 检测一辆汽车和一架飞机。
  5. 从互联网上下载的图像上的检测。
# 1. Predicting a Person
img_path = '/kaggle/input/sample-images-for-object-detection/ronaldo.jpg'
image = Image.open(img_path)
image = image.resize([int(0.5 * s) for s in image.size])
plt.imshow(image)
plt.show()transform = transforms.Compose([transforms.ToTensor()])
img = transform(image)
pred = model([img])
pred_class = get_predictions(pred, objects=["person"])
draw_box(pred_class, img)
save_RAM(image_=True)

在这里插入图片描述

img

# 2. Predicting People
img_path = '/kaggle/input/sample-images-for-object-detection/people.jpg'
image = Image.open(img_path)
image = image.resize([int(0.5 * s) for s in image.size])
plt.imshow(image)
plt.show()img = transform(image)
pred = model([img])
pred_thresh = get_predictions(pred, threshold=0.9, objects=["person"])
draw_box(pred_thresh, img, rect_th=1, text_size=1, text_th=1)
save_RAM(image_=True)

img

在这里插入图片描述

# 3. Predicting Cat and Dog
img_path = '/kaggle/input/sample-images-for-object-detection/catanddog.jpg'
image = Image.open(img_path)
image = image.resize([int(0.5 * s) for s in image.size])
plt.imshow(image)
plt.show()img = transform(image)
pred = model([img])
pred_thresh = get_predictions(pred, threshold=0.8)
draw_box(pred_thresh, img, rect_th=10, text_size=10, text_th=10)
save_RAM(image_=True)

在这里插入图片描述

在这里插入图片描述

# 4. Predicting a Car and a Plane
img_path = '/kaggle/input/sample-images-for-object-detection/carandplane.jpg'
image = Image.open(img_path)
image is resized to [int(0.5 * s) for s in image size]
plt.imshow(image)
plt.show()img = transform(image)
pred = model([img])
pred_thresh = get_predictions(pred, threshold=0.9)
draw_box(pred_thresh, img)
save_RAM(image_=True)

在这里插入图片描述

在这里插入图片描述

# 5. Predicting on an Uploaded Image
url = 'https://www.plastform.ca/wp-content/themes/plastform/images/slider-image-2.jpg'
image = Image.open(requests.get(url, stream=True).raw).convert('RGB')
plt.imshow(image)
plt.show()img = transform(image)
pred = model([img])
pred_thresh = get_predictions(pred, threshold=0.95)
draw_box(pred_thresh, img, rect_th=2, text_size=1.5, text_th=2)
save_RAM(image_=True)

img

在这里插入图片描述

结论

在本文中,我们探讨了如何使用Faster R-CNN模型进行对象检测。我们在这里所做的只是一小部分,还有很多进一步的探索,你可以自己尝试。这可能包括在不同的数据上尝试,调整阈值,或微调模型以检测更多的对象。希望这篇文章对大家有帮助,谢谢大家!


关注文末名片G-Z-H:【阿旭算法与机器学习】,发送【开源】可获取更多学习资源

在这里插入图片描述

好了,这篇文章就介绍到这里,喜欢的小伙伴感谢给点个赞和关注,更多精彩内容持续更新~~
关于本篇文章大家有任何建议或意见,欢迎在评论区留言交流!


http://www.mrgr.cn/news/54900.html

相关文章:

  • qtcreator 仿制vscode黑色背景主题monokai
  • SpringBoot+MyBatis+MySQL项目基础搭建
  • 是时候和传统源代码保密方案说拜拜了
  • Android视频编解码 MediaCodec使用(2)
  • MongoDB常用语句
  • LCWLAN设备的实际使用案例
  • leetcode day1
  • resnetv1骨干
  • 轮班管理新策略,提高效率与降低员工抱怨
  • Vue3中使用自定义指令实现后台管理系统中对于按钮权限的控制
  • 五年三次冲刺IPO失败,企业业绩成长性恐不足,三年分红约1.5亿元
  • 对比迁移项目的改动
  • 值得收藏学习的人工智能学习框架!
  • 【重学 MySQL】七十三、灵活操控视图数据,轻松掌握视图删除技巧
  • DFF对比
  • SpringBoot运维
  • FHQtreap新模板
  • 诊断知识:NRC78(Response Pending)的回复时刻
  • @RequestMapping(“/api/users“)详细解释一下这行代码
  • 【云从】八、HTTPS流程与建站
  • Redux (八) 路由React-router、嵌套路由、路由传参、路由懒加载
  • 【Android】浅析OkHttp(1)
  • MySQL-29.事务-四大特性
  • web3学习-区块链基础知识
  • 图文深入介绍oracle资源管理(续)
  • 10.20学习