YOLO Plug-and-Play Module: PPA
HCF-Net: Hierarchical Context Fusion Network for Infrared Small Object Detection
Paper (arXiv:2403.10778): https://arxiv.org/pdf/2403.10778
Problem: Infrared small object detection is hard because targets are tiny, backgrounds are complex, and contrast is low; target information is easily lost during processing, and detections are easily disturbed by background clutter.
Method:
- Infrared small object detection is cast as a semantic segmentation problem, and HCF-Net, a hierarchical context fusion network trained from scratch, is proposed.
- Three practical modules are introduced:
  - Parallelized Patch-Aware Attention (PPA) module: uses a multi-branch feature-extraction strategy to capture feature information at different scales and levels, and strengthens the representation of small objects with attention so that their information survives repeated downsampling (full code in the Code section below).
  - Dimension-Aware Selective Integration (DASI) module: strengthens the skip connections of the U-Net, performing adaptive channel selection and fine-grained fusion of high- and low-dimensional features to enhance the saliency of small objects (see the first sketch after this list).
  - Multi-Dilated Channel Refiner (MDCR) module: captures spatial features across several receptive-field ranges with multiple depthwise separable convolution layers, modeling the difference between target and background more finely and improving the localization of small objects (see the second sketch after this list).
- A deep supervision strategy with a multi-scale loss further counters the loss of small objects during downsampling (a loss sketch follows the Summary).
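The DASI code is not included in the Code section below, but its core idea, a sigmoid gate derived from the current-level feature that convexly mixes low- and high-level features, can be sketched as follows. This is a minimal sketch, not the authors' implementation; the 1x1 gating convolution and the assumption that all three inputs are already aligned in channels and resolution are mine.

import torch
import torch.nn as nn

class DASISketch(nn.Module):
    # Minimal sketch of dimension-aware selective integration: a gate computed
    # from the current-level feature selects between low- and high-level
    # features per channel and position. The 1x1 gating conv is an assumption.
    def __init__(self, channels):
        super().__init__()
        self.gate = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, cur, low, high):
        # cur, low, high: current-, lower-, and higher-level features, assumed
        # already projected to the same channel count and spatial size.
        alpha = torch.sigmoid(self.gate(cur))
        return alpha * low + (1 - alpha) * high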
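MDCR is likewise not in the code below. A minimal sketch of its multi-dilated, depthwise-separable idea follows; the four-way channel split and the dilation rates (1, 2, 3, 4) are illustrative assumptions, not the authors' exact configuration.

import torch
import torch.nn as nn

class MDCRSketch(nn.Module):
    # Minimal sketch of multi-dilated channel refinement: split channels into
    # groups, apply a depthwise 3x3 conv with a different dilation per group
    # (one receptive-field range per group), then fuse with a pointwise conv.
    def __init__(self, channels, dilations=(1, 2, 3, 4)):
        super().__init__()
        assert channels % len(dilations) == 0
        g = channels // len(dilations)
        self.branches = nn.ModuleList(
            nn.Conv2d(g, g, kernel_size=3, padding=d, dilation=d, groups=g)
            for d in dilations)
        self.fuse = nn.Conv2d(channels, channels, kernel_size=1)  # pointwise part

    def forward(self, x):
        chunks = x.chunk(len(self.branches), dim=1)
        out = torch.cat([b(c) for b, c in zip(self.branches, chunks)], dim=1)
        return self.fuse(out)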
Results:
- Experiments on the public single-frame infrared image dataset SIRST show that HCF-Net outperforms both traditional and deep-learning methods.
- Ablation studies verify the effectiveness of each module.
- Visualizations demonstrate HCF-Net's advantage in accuracy and detail delineation.
Summary: Through hierarchical context fusion and deep supervision, HCF-Net effectively addresses the challenges of infrared small object detection and achieves higher detection accuracy and robustness.
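As an illustration of the deep-supervision idea mentioned above, here is a minimal multi-scale loss sketch. The per-scale weights and the BCE + soft-IoU combination are assumptions for illustration, not taken verbatim from the paper.

import torch
import torch.nn.functional as F

def deep_supervision_loss(preds, target, weights=(1.0, 0.5, 0.25, 0.125)):
    # preds: list of logit maps at decreasing resolutions (full scale first).
    # target: full-resolution binary mask of shape (B, 1, H, W).
    total = 0.0
    for w, p in zip(weights, preds):
        # Downsample the ground truth to match each prediction scale.
        t = F.interpolate(target, size=p.shape[-2:], mode='nearest')
        bce = F.binary_cross_entropy_with_logits(p, t)
        prob = torch.sigmoid(p)
        inter = (prob * t).sum(dim=(1, 2, 3))
        union = (prob + t - prob * t).sum(dim=(1, 2, 3))
        soft_iou = 1 - (inter + 1) / (union + 1)  # smoothed to avoid 0/0
        total = total + w * (bce + soft_iou.mean())
    return total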
Code:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class SpatialAttentionModule(nn.Module):
    # CBAM-style spatial attention: gate features with a 7x7 conv over
    # channel-wise average- and max-pooled maps.
    def __init__(self):
        super(SpatialAttentionModule, self).__init__()
        self.conv2d = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avgout, maxout], dim=1)
        out = self.sigmoid(self.conv2d(out))
        return out * x


class PPA(nn.Module):
    # Parallelized Patch-Aware Attention: fuse a 1x1 skip branch, a serial
    # 3x3 conv branch, and two local-global attention branches, then refine
    # with ECA (channel) and spatial attention.
    def __init__(self, in_features, filters) -> None:
        super().__init__()
        self.skip = conv_block(in_features=in_features, out_features=filters,
                               kernel_size=(1, 1), padding=(0, 0),
                               norm_type='bn', activation=False)
        self.c1 = conv_block(in_features=in_features, out_features=filters,
                             kernel_size=(3, 3), padding=(1, 1),
                             norm_type='bn', activation=True)
        self.c2 = conv_block(in_features=filters, out_features=filters,
                             kernel_size=(3, 3), padding=(1, 1),
                             norm_type='bn', activation=True)
        self.c3 = conv_block(in_features=filters, out_features=filters,
                             kernel_size=(3, 3), padding=(1, 1),
                             norm_type='bn', activation=True)
        self.sa = SpatialAttentionModule()
        self.cn = ECA(filters)
        self.lga2 = LocalGlobalAttention(filters, 2)
        self.lga4 = LocalGlobalAttention(filters, 4)
        self.bn1 = nn.BatchNorm2d(filters)
        self.drop = nn.Dropout2d(0.1)
        self.relu = nn.ReLU()
        self.gelu = nn.GELU()  # unused in forward; kept from the reference code

    def forward(self, x):
        x_skip = self.skip(x)
        x_lga2 = self.lga2(x_skip)
        x_lga4 = self.lga4(x_skip)
        x1 = self.c1(x)
        x2 = self.c2(x1)
        x3 = self.c3(x2)
        x = x1 + x2 + x3 + x_skip + x_lga2 + x_lga4
        x = self.cn(x)
        x = self.sa(x)
        x = self.drop(x)
        x = self.bn1(x)
        x = self.relu(x)
        return x


class LocalGlobalAttention(nn.Module):
    # Patch-level attention: pool PxP patches, run them through an MLP,
    # gate by cosine similarity to a learned prompt, and upsample back.
    def __init__(self, output_dim, patch_size):
        super().__init__()
        self.output_dim = output_dim
        self.patch_size = patch_size
        self.mlp1 = nn.Linear(patch_size * patch_size, output_dim // 2)
        self.norm = nn.LayerNorm(output_dim // 2)
        self.mlp2 = nn.Linear(output_dim // 2, output_dim)
        self.conv = nn.Conv2d(output_dim, output_dim, kernel_size=1)
        self.prompt = torch.nn.parameter.Parameter(torch.randn(output_dim, requires_grad=True))
        self.top_down_transform = torch.nn.parameter.Parameter(torch.eye(output_dim), requires_grad=True)

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        B, H, W, C = x.shape
        P = self.patch_size

        # Local branch
        local_patches = x.unfold(1, P, P).unfold(2, P, P)       # (B, H/P, W/P, P, P, C)
        local_patches = local_patches.reshape(B, -1, P * P, C)  # (B, H/P*W/P, P*P, C)
        local_patches = local_patches.mean(dim=-1)              # (B, H/P*W/P, P*P)

        local_patches = self.mlp1(local_patches)  # (B, H/P*W/P, output_dim // 2)
        local_patches = self.norm(local_patches)  # (B, H/P*W/P, output_dim // 2)
        local_patches = self.mlp2(local_patches)  # (B, H/P*W/P, output_dim)

        local_attention = F.softmax(local_patches, dim=-1)  # (B, H/P*W/P, output_dim)
        local_out = local_patches * local_attention          # (B, H/P*W/P, output_dim)

        cos_sim = F.normalize(local_out, dim=-1) @ F.normalize(self.prompt[None, ..., None], dim=1)  # (B, N, 1)
        mask = cos_sim.clamp(0, 1)
        local_out = local_out * mask
        local_out = local_out @ self.top_down_transform

        # Restore shapes
        local_out = local_out.reshape(B, H // P, W // P, self.output_dim)  # (B, H/P, W/P, output_dim)
        local_out = local_out.permute(0, 3, 1, 2)
        local_out = F.interpolate(local_out, size=(H, W), mode='bilinear', align_corners=False)
        output = self.conv(local_out)
        return output


class ECA(nn.Module):
    # Efficient Channel Attention: 1D conv over globally pooled channels,
    # with kernel size adapted to the channel count.
    def __init__(self, in_channel, gamma=2, b=1):
        super(ECA, self).__init__()
        k = int(abs((math.log(in_channel, 2) + b) / gamma))
        kernel_size = k if k % 2 else k + 1
        padding = kernel_size // 2
        self.pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.conv = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=1, kernel_size=kernel_size, padding=padding, bias=False),
            nn.Sigmoid())

    def forward(self, x):
        out = self.pool(x)
        out = out.view(x.size(0), 1, x.size(1))
        out = self.conv(out)
        out = out.view(x.size(0), x.size(1), 1, 1)
        return out * x


class conv_block(nn.Module):
    # Conv + optional BN/GN + optional ReLU.
    def __init__(self, in_features, out_features, kernel_size=(3, 3), stride=(1, 1),
                 padding=(1, 1), dilation=(1, 1), norm_type='bn', activation=True,
                 use_bias=True, groups=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
                              kernel_size=kernel_size, stride=stride, padding=padding,
                              dilation=dilation, bias=use_bias, groups=groups)
        self.norm_type = norm_type
        self.act = activation
        if self.norm_type == 'gn':
            self.norm = nn.GroupNorm(32 if out_features >= 32 else out_features, out_features)
        if self.norm_type == 'bn':
            self.norm = nn.BatchNorm2d(out_features)
        if self.act:
            # self.relu = nn.GELU()
            self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x = self.conv(x)
        if self.norm_type is not None:
            x = self.norm(x)
        if self.act:
            x = self.relu(x)
        return x


if __name__ == '__main__':
    block = PPA(in_features=4, filters=64)  # input channels, output channels
    input = torch.rand(1, 4, 128, 128)      # input B C H W
    output = block(input)
    print(input.size())
    print(output.size())
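Since PPA preserves the spatial size and only changes the channel count, it can stand in for an ordinary convolution stage. Below is a hypothetical drop-in example, assuming the PPA definition above is in scope; the stem is illustrative, not an actual YOLO configuration, and wiring PPA into a real YOLO model additionally requires registering it with that codebase's model parser.

import torch
import torch.nn as nn

class TinyStem(nn.Module):
    # Hypothetical: a small stem that uses PPA in place of a plain 3x3 conv stage.
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1)
        self.ppa = PPA(in_features=32, filters=64)

    def forward(self, x):
        return self.ppa(self.stem(x))

model = TinyStem()
print(model(torch.rand(1, 3, 256, 256)).shape)  # torch.Size([1, 64, 128, 128])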