
YOLO Plug-and-Play Module -- PPA

HCF-Net: Hierarchical Context Fusion Network for Infrared Small Object Detection

Paper (arXiv:2403.10778): https://arxiv.org/pdf/2403.10778

Problem: Infrared small target detection is difficult: targets are tiny, backgrounds are complex, and contrast is low, so target information is easily lost and background clutter interferes with detection.

Method

  • Models infrared small target detection as a semantic segmentation problem and proposes HCF-Net, a hierarchical context fusion network trained from scratch.

  • Proposes three practical modules:

    • Parallelized Patch-Aware Attention (PPA) module: uses a multi-branch feature extraction strategy to capture feature information at different scales and levels, and strengthens the representation of small targets with attention so that their information survives repeated downsampling (full code below).

    • Dimension-Aware Selective Integration (DASI) module: enhances the skip connections of U-Net with adaptive channel selection and fine-grained fusion of high- and low-dimensional features, increasing the saliency of small targets.

    • Multi-Dilated Channel Refiner (MDCR) module: captures spatial features over several receptive-field sizes through multiple depthwise separable dilated convolution layers, modeling the difference between target and background more finely and improving small-target localization (a minimal sketch follows this list).

  • Adopts a deep supervision strategy with a multi-scale loss function to further counter the loss of small targets during downsampling (a minimal loss sketch also follows this list).
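The MDCR bullet above describes the mechanism concretely enough to sketch. The following is a minimal sketch of the idea only, not the authors' implementation: the channels are split into groups, each group goes through a depthwise separable convolution with its own dilation rate, and the branch outputs are concatenated and fused by a 1x1 convolution. The class name MultiDilatedRefine, the dilation set, and the split-and-fuse layout are illustrative assumptions.

import torch
import torch.nn as nn

class MultiDilatedRefine(nn.Module):
    def __init__(self, channels, dilations=(1, 2, 4, 8)):
        super().__init__()
        assert channels % len(dilations) == 0
        group = channels // len(dilations)
        self.branches = nn.ModuleList([
            nn.Sequential(
                # depthwise 3x3 conv; dilation d widens the receptive field
                nn.Conv2d(group, group, 3, padding=d, dilation=d, groups=group),
                # pointwise conv completes the depthwise separable convolution
                nn.Conv2d(group, group, 1),
                nn.BatchNorm2d(group),
                nn.ReLU(inplace=True),
            )
            for d in dilations
        ])
        self.fuse = nn.Conv2d(channels, channels, 1)  # merge the branch outputs

    def forward(self, x):
        chunks = torch.chunk(x, len(self.branches), dim=1)
        out = torch.cat([b(c) for b, c in zip(self.branches, chunks)], dim=1)
        return self.fuse(out)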

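For the deep supervision strategy, here is a minimal sketch of a multi-scale loss under stated assumptions: each decoder stage emits a single-channel logit map, the ground-truth mask is resized to each stage's resolution, and the per-stage losses are combined with decreasing weights. BCE alone keeps the sketch short (the paper combines further terms), and the weights ws are illustrative.

import torch
import torch.nn.functional as F

def deep_supervision_loss(side_outputs, gt_mask, ws=(1.0, 0.5, 0.25, 0.125)):
    # side_outputs: list of (B, 1, Hi, Wi) logit maps, coarse to fine
    # gt_mask: (B, 1, H, W) binary mask at full resolution
    total = 0.0
    for w, logits in zip(ws, side_outputs):
        # resize the ground truth to this stage's resolution
        gt = F.interpolate(gt_mask, size=logits.shape[-2:], mode='nearest')
        total = total + w * F.binary_cross_entropy_with_logits(logits, gt)
    return total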
Results

  • Experiments on SIRST, a public single-frame infrared image dataset, show that HCF-Net outperforms other traditional and deep learning methods.

  • Ablation studies verify the effectiveness of each module.

  • Visualization results show HCF-Net's advantages in accuracy and detail delineation.

Summary: By combining hierarchical context fusion with deep supervision, HCF-Net effectively addresses the challenges of infrared small target detection and achieves higher detection accuracy and robustness.

Code:

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class SpatialAttentionModule(nn.Module):
    def __init__(self):
        super(SpatialAttentionModule, self).__init__()
        self.conv2d = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Channel-wise average and max pooling, concatenated, then a 7x7 conv
        # yields a single-channel spatial attention map.
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avgout, maxout], dim=1)
        out = self.sigmoid(self.conv2d(out))
        return out * x


class PPA(nn.Module):
    def __init__(self, in_features, filters) -> None:
        super().__init__()
        self.skip = conv_block(in_features=in_features, out_features=filters,
                               kernel_size=(1, 1), padding=(0, 0),
                               norm_type='bn', activation=False)
        self.c1 = conv_block(in_features=in_features, out_features=filters,
                             kernel_size=(3, 3), padding=(1, 1),
                             norm_type='bn', activation=True)
        self.c2 = conv_block(in_features=filters, out_features=filters,
                             kernel_size=(3, 3), padding=(1, 1),
                             norm_type='bn', activation=True)
        self.c3 = conv_block(in_features=filters, out_features=filters,
                             kernel_size=(3, 3), padding=(1, 1),
                             norm_type='bn', activation=True)
        self.sa = SpatialAttentionModule()
        self.cn = ECA(filters)
        self.lga2 = LocalGlobalAttention(filters, 2)
        self.lga4 = LocalGlobalAttention(filters, 4)
        self.bn1 = nn.BatchNorm2d(filters)
        self.drop = nn.Dropout2d(0.1)
        self.relu = nn.ReLU()
        self.gelu = nn.GELU()

    def forward(self, x):
        # Multi-branch extraction: serial 3x3 convs, a 1x1 skip, and two
        # local-global attention branches (patch sizes 2 and 4) are summed,
        # then refined by channel (ECA) and spatial attention.
        x_skip = self.skip(x)
        x_lga2 = self.lga2(x_skip)
        x_lga4 = self.lga4(x_skip)
        x1 = self.c1(x)
        x2 = self.c2(x1)
        x3 = self.c3(x2)
        x = x1 + x2 + x3 + x_skip + x_lga2 + x_lga4
        x = self.cn(x)
        x = self.sa(x)
        x = self.drop(x)
        x = self.bn1(x)
        x = self.relu(x)
        return x


class LocalGlobalAttention(nn.Module):
    def __init__(self, output_dim, patch_size):
        super().__init__()
        self.output_dim = output_dim
        self.patch_size = patch_size
        self.mlp1 = nn.Linear(patch_size * patch_size, output_dim // 2)
        self.norm = nn.LayerNorm(output_dim // 2)
        self.mlp2 = nn.Linear(output_dim // 2, output_dim)
        self.conv = nn.Conv2d(output_dim, output_dim, kernel_size=1)
        self.prompt = torch.nn.parameter.Parameter(torch.randn(output_dim, requires_grad=True))
        self.top_down_transform = torch.nn.parameter.Parameter(torch.eye(output_dim), requires_grad=True)

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        B, H, W, C = x.shape
        P = self.patch_size

        # Local branch: non-overlapping P x P patches
        local_patches = x.unfold(1, P, P).unfold(2, P, P)  # (B, H/P, W/P, C, P, P)
        local_patches = local_patches.reshape(B, -1, P * P, C)  # (B, H/P*W/P, P*P, C)
        local_patches = local_patches.mean(dim=-1)  # (B, H/P*W/P, P*P)
        local_patches = self.mlp1(local_patches)  # (B, H/P*W/P, output_dim // 2)
        local_patches = self.norm(local_patches)  # (B, H/P*W/P, output_dim // 2)
        local_patches = self.mlp2(local_patches)  # (B, H/P*W/P, output_dim)

        local_attention = F.softmax(local_patches, dim=-1)  # (B, H/P*W/P, output_dim)
        local_out = local_patches * local_attention  # (B, H/P*W/P, output_dim)

        # Gate each patch by its cosine similarity to a learned prompt,
        # then apply a learned top-down transform.
        cos_sim = F.normalize(local_out, dim=-1) @ F.normalize(self.prompt[None, ..., None], dim=1)  # (B, N, 1)
        mask = cos_sim.clamp(0, 1)
        local_out = local_out * mask
        local_out = local_out @ self.top_down_transform

        # Restore shapes and upsample back to the input resolution
        local_out = local_out.reshape(B, H // P, W // P, self.output_dim)  # (B, H/P, W/P, output_dim)
        local_out = local_out.permute(0, 3, 1, 2)
        local_out = F.interpolate(local_out, size=(H, W), mode='bilinear', align_corners=False)
        output = self.conv(local_out)
        return output


class ECA(nn.Module):
    def __init__(self, in_channel, gamma=2, b=1):
        super(ECA, self).__init__()
        # Kernel size adapts to the channel count (ECA-Net heuristic)
        k = int(abs((math.log(in_channel, 2) + b) / gamma))
        kernel_size = k if k % 2 else k + 1
        padding = kernel_size // 2
        self.pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.conv = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=1, kernel_size=kernel_size, padding=padding, bias=False),
            nn.Sigmoid())

    def forward(self, x):
        out = self.pool(x)
        out = out.view(x.size(0), 1, x.size(1))
        out = self.conv(out)
        out = out.view(x.size(0), x.size(1), 1, 1)
        return out * x


class conv_block(nn.Module):
    def __init__(self,
                 in_features,
                 out_features,
                 kernel_size=(3, 3),
                 stride=(1, 1),
                 padding=(1, 1),
                 dilation=(1, 1),
                 norm_type='bn',
                 activation=True,
                 use_bias=True,
                 groups=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_features,
                              out_channels=out_features,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              dilation=dilation,
                              bias=use_bias,
                              groups=groups)
        self.norm_type = norm_type
        self.act = activation
        if self.norm_type == 'gn':
            self.norm = nn.GroupNorm(32 if out_features >= 32 else out_features, out_features)
        if self.norm_type == 'bn':
            self.norm = nn.BatchNorm2d(out_features)
        if self.act:
            # self.relu = nn.GELU()
            self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x = self.conv(x)
        if self.norm_type is not None:
            x = self.norm(x)
        if self.act:
            x = self.relu(x)
        return x


if __name__ == '__main__':
    block = PPA(in_features=4, filters=64)  # input channels, output channels
    input = torch.rand(1, 4, 128, 128)  # input tensor: B, C, H, W
    output = block(input)
    print(input.size())
    print(output.size())
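As a quick plug-and-play illustration, here is a hypothetical sketch of PPA replacing a plain convolution stage in a small detection-style backbone. TinyBackbone and its layer sizes are illustrative stand-ins, not part of HCF-Net or any YOLO release; the sketch assumes the PPA definition above is in scope. Note that the spatial size fed into PPA must be divisible by the LocalGlobalAttention patch sizes (2 and 4 here).

import torch
import torch.nn as nn

class TinyBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 32, 3, stride=2, padding=1)    # 256 -> 128
        self.ppa = PPA(in_features=32, filters=64)              # replaces a plain 3x3 conv stage
        self.down = nn.Conv2d(64, 128, 3, stride=2, padding=1)  # 128 -> 64

    def forward(self, x):
        return self.down(self.ppa(self.stem(x)))

feat = TinyBackbone()(torch.rand(1, 3, 256, 256))
print(feat.shape)  # torch.Size([1, 128, 64, 64])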

