当前位置: 首页 > news >正文

yolov8的标签匹配解析

本文内容简介:

        1.yolov8标签匹配过程理论详解

        2.yolov8标签匹配过程代码详解

        3. 图文结合帮助快速理解(附带参考资料)

        4.重要函数提要(TaskAlignedAssigner、bbox_decode、make_anchors、get_pos_mask、select_highest_overlaps)

本文参考链接:

yolov8源码:ultralytics/ultralytics: Ultralytics YOLO11 🚀

TOOD标签匹配论文:https://arxiv.org/pdf/2108.07755.pdf

yolov8结构图:

总览

1.如何解析box?

首先针对输入图片维度为[4,3,640,640],通过骨干网络可以得到三个feature,也就是1/8,1/16以及1/32的特征输出,对应的维度分别为[4,65,80,80]、[4,65,40,40]、[4,65,20,20]

anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

然后通过上图的make_anchors方法去得到不同尺度特征图的锚点,feats就是输出的三个特征图,stride是对应的三个尺度,为[8,16,32]

def make_anchors(feats, strides, grid_cell_offset=0.5):"""Generate anchors from features."""anchor_points, stride_tensor = [], []assert feats is not Nonedtype, device = feats[0].dtype, feats[0].devicefor i, stride in enumerate(strides):_, _, h, w = feats[i].shape  # 4,3,80,80/4,3,40,40/4,3,20,20sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # shift xsy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # shift ysy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))return torch.cat(anchor_points), torch.cat(stride_tensor)

以尺度为20的特征图为例,torch.arange生成一个包含20个从0到19的tensor

 sy同理,再通过torch.meshgrid函数得到一个二维类似于20尺度特征图的二维tensor,对torch.meshgrid和torch.stack操作不懂得同学可以看看下边一幅图

                         

最后得到的结果就是包含三个尺度的二维特征图表示anchor_points(维度为[8400,2]),以及包含stride信息的stride_tensor(维度为[8400,1])

                                 

2.如何通过锚点和模型预测解析得到box信息?

pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)

这里通过上个函数的得到的anchor_points维度为[8400,2])和预测信息pred_distri维度为[4,8400,64])来解析box,具体过程如下:

def bbox_decode(self, anchor_points, pred_dist):"""Decode predicted object bounding box coordinates from anchor points and distribution."""if self.use_dfl:b, a, c = pred_dist.shape  # batch, anchors, channelspred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))# pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype))# pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)return dist2bbox(pred_dist, anchor_points, xywh=False)
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):"""Transform distance(ltrb) to box(xywh or xyxy)."""lt, rb = distance.chunk(2, dim)x1y1 = anchor_points - ltx2y2 = anchor_points + rbif xywh:c_xy = (x1y1 + x2y2) / 2wh = x2y2 - x1y1return torch.cat((c_xy, wh), dim)  # xywh bboxreturn torch.cat((x1y1, x2y2), dim)  # xyxy bbox

这里注意一下:yolov8是预测上下左右到锚点的偏移

(1)在bbox_decode中,训练时use_dfl为true,模型输出pred_dist维度为[4,8400,64],这里的64是针对锚点的上下左右偏移预测了16次,所以维度变换为[4,8400,4,16],因为这16个预测是通过求积分的方式来得到一个真正的预测,所以先对最后一维求softmax,再乘以对应的序号,相加就是积分结果,做完torch.matmul乘法后,pred_dist维度为[4,8400,4]

(2)在dist2bbox中,distance就是预测输出,维度为[4,8400,4],使用torch.chunk将其划分为左边上边的偏移lt(维度为[4,8400,2])和右边和下边的偏移rb(维度为[4,8400,2]),最后通过锚点与偏移值得到最后的box信息,也就是8400个预测box

3.如何进行标签匹配?

首先简单介绍一下标签匹配策略,在V5中仅仅考虑了box的iou来进行匹配,v8中将类别信息和位置信息同时考虑了进来。

                                                  

这里评价一个gt和预测是否匹配定义了一个指标t,这个指标s指的是box,u指的是class,分别使用两个超参数来调节,然后选择topk作为返回。标签匹配策略定义如下:

self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)


http://www.mrgr.cn/news/81032.html

相关文章:

  • ChatGPT接口测试用例生成的流程
  • C vs C++: 一场编程语言的演变与对比
  • vue CSS 自定义宽高 翻页 剥离 效果
  • 使用 Docker 打包和运行 Vue 应用
  • Vue 浏览器录音、播放、上传服务端(PCM 8000采样率 16位)
  • postman读取文件执行
  • 39.在 Vue3 中使用 OpenLayers 导出 GeoJSON 文件及详解 GEOJSON 格式
  • 多个Echart遍历生成 / 词图云
  • [Java]合理封装第三方工具包(附视频)
  • 数据仓库工具箱—读书笔记02(Kimball维度建模技术概述03、维度表技术基础)
  • 海格通信嵌入式面试题及参考答案
  • draw.io 导出svg图片插入word后模糊(不清晰 )的解决办法
  • Restaurants WebAPI(四)——Identity
  • nodejs利用子进程child_process执行命令及child.stdout输出数据
  • LLMs之rStar:《Mutual Reasoning Makes Smaller LLMs Stronger Problem-Solvers》翻译与解读
  • 开源知识库open source knowledge base
  • 计算机毕业设计hadoop+spark知网文献论文推荐系统 知识图谱 知网爬虫 知网数据分析 知网大数据 知网可视化 预测系统 大数据毕业设计 机器学习
  • 5G -- 网络安全
  • 【测试】APP测试
  • Go by Example学习
  • LeetCode 刷题笔记
  • qemu源码解析【06】qemu启动初始化流程
  • Ubuntu 22.04,Rime / luna_pinyin.schema 输入法:外挂词库,自定义词库 (****) OK
  • Docker 入门:如何使用 Docker 容器化 AI 项目(一)
  • ubuntu 安装更新 ollama新版本
  • CAD xy坐标标注(跟随鼠标位置实时移动)——C#插件实现