yolov8的标签匹配解析
本文内容简介:
1.yolov8标签匹配过程理论详解
2.yolov8标签匹配过程代码详解
3. 图文结合帮助快速理解(附带参考资料)
4.重要函数提要(TaskAlignedAssigner、bbox_decode、make_anchors、get_pos_mask、select_highest_overlaps)
本文参考链接:
yolov8源码:ultralytics/ultralytics: Ultralytics YOLO11 🚀
TOOD标签匹配论文:https://arxiv.org/pdf/2108.07755.pdf
yolov8结构图:
总览
1.如何解析box?
首先针对输入图片维度为[4,3,640,640],通过骨干网络可以得到三个feature,也就是1/8,1/16以及1/32的特征输出,对应的维度分别为[4,65,80,80]、[4,65,40,40]、[4,65,20,20]
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
然后通过上图的make_anchors方法去得到不同尺度特征图的锚点,feats就是输出的三个特征图,stride是对应的三个尺度,为[8,16,32]
def make_anchors(feats, strides, grid_cell_offset=0.5):"""Generate anchors from features."""anchor_points, stride_tensor = [], []assert feats is not Nonedtype, device = feats[0].dtype, feats[0].devicefor i, stride in enumerate(strides):_, _, h, w = feats[i].shape # 4,3,80,80/4,3,40,40/4,3,20,20sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift xsy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift ysy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))return torch.cat(anchor_points), torch.cat(stride_tensor)
以尺度为20的特征图为例,torch.arange生成一个包含20个从0到19的tensor
sy同理,再通过torch.meshgrid函数得到一个二维类似于20尺度特征图的二维tensor,对torch.meshgrid和torch.stack操作不懂得同学可以看看下边一幅图
最后得到的结果就是包含三个尺度的二维特征图表示anchor_points(维度为[8400,2]),以及包含stride信息的stride_tensor(维度为[8400,1])
2.如何通过锚点和模型预测解析得到box信息?
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
这里通过上个函数的得到的anchor_points维度为[8400,2])和预测信息pred_distri维度为[4,8400,64])来解析box,具体过程如下:
def bbox_decode(self, anchor_points, pred_dist):"""Decode predicted object bounding box coordinates from anchor points and distribution."""if self.use_dfl:b, a, c = pred_dist.shape # batch, anchors, channelspred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))# pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype))# pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)return dist2bbox(pred_dist, anchor_points, xywh=False)
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):"""Transform distance(ltrb) to box(xywh or xyxy)."""lt, rb = distance.chunk(2, dim)x1y1 = anchor_points - ltx2y2 = anchor_points + rbif xywh:c_xy = (x1y1 + x2y2) / 2wh = x2y2 - x1y1return torch.cat((c_xy, wh), dim) # xywh bboxreturn torch.cat((x1y1, x2y2), dim) # xyxy bbox
这里注意一下:yolov8是预测上下左右到锚点的偏移,
(1)在bbox_decode中,训练时use_dfl为true,模型输出pred_dist维度为[4,8400,64],这里的64是针对锚点的上下左右偏移预测了16次,所以维度变换为[4,8400,4,16],因为这16个预测是通过求积分的方式来得到一个真正的预测,所以先对最后一维求softmax,再乘以对应的序号,相加就是积分结果,做完torch.matmul乘法后,pred_dist维度为[4,8400,4]
(2)在dist2bbox中,distance就是预测输出,维度为[4,8400,4],使用torch.chunk将其划分为左边上边的偏移lt(维度为[4,8400,2])和右边和下边的偏移rb(维度为[4,8400,2]),最后通过锚点与偏移值得到最后的box信息,也就是8400个预测box
3.如何进行标签匹配?
首先简单介绍一下标签匹配策略,在V5中仅仅考虑了box的iou来进行匹配,v8中将类别信息和位置信息同时考虑了进来。
这里评价一个gt和预测是否匹配定义了一个指标t,这个指标s指的是box,u指的是class,分别使用两个超参数来调节,然后选择topk作为返回。标签匹配策略定义如下:
self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)