【语义分割|代码解析】CMTFNet-4: CNN and Multiscale Transformer Fusion Network 用于遥感图像分割!
【语义分割|代码解析】CMTFNet-4: CNN and Multiscale Transformer Fusion Network 用于遥感图像分割!
【语义分割|代码解析】CMTFNet-4: CNN and Multiscale Transformer Fusion Network 用于遥感图像分割!
文章目录
- 【语义分割|代码解析】CMTFNet-4: CNN and Multiscale Transformer Fusion Network 用于遥感图像分割!
- 前言
- 1. 多尺度多头自注意力(Multi-Head Self-Attention,MHSA)模块
- `forward` 前向传播函数:
- 2. 块级模块 Block
- `forward` 前向传播函数:
- 3. 融合模块 Fusion
- `forward` 前向传播函数:
欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!
祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!
大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:
可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz
论文地址:https://ieeexplore.ieee.org/document/10247595
前言
该代码实现了一个多尺度多头自注意力(Multi-Head Self-Attention,MHSA)模块 Mutilscal_MHSA
、一个块级模块 Block
以及一个融合模块 Fusion
。此代码用于遥感图像语义分割模型 CMTFNet 中,主要通过多尺度卷积、MHSA 和融合机制增强图像特征提取。以下是逐行代码解析:
1. 多尺度多头自注意力(Multi-Head Self-Attention,MHSA)模块
class Mutilscal_MHSA(nn.Module):def __init__(self, dim, num_heads, atten_drop = 0., proj_drop = 0., dilation = [3, 5, 7], fc_ratio=4, pool_ratio=16):super(Mutilscal_MHSA, self).__init__()assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."self.dim = dimself.num_heads = num_headshead_dim = dim // num_headsself.scale = head_dim ** -0.5self.atten_drop = nn.Dropout(atten_drop)self.proj_drop = nn.Dropout(proj_drop)self.MSC = MutilScal(dim=dim, fc_ratio=fc_ratio, dilation=dilation, pool_ratio=pool_ratio)self.avgpool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Conv2d(in_channels=dim, out_channels=dim//fc_ratio, kernel_size=1),nn.ReLU6(),nn.Conv2d(in_channels=dim//fc_ratio, out_channels=dim, kernel_size=1),nn.Sigmoid())self.kv = Conv(dim, 2 * dim, 1)
__init__
构造函数:super(Mutilscal_MHSA, self).__init__()
: 初始化父类nn.Module
。assert dim % num_heads == 0
: 确保特征维度 dim 可被头数 num_heads 整除。self.dim、self.num_heads
: 初始化维度和多头数量。head_dim = dim // num_heads
: 每个头的维度大小。self.scale = head_dim ** -0.5
: 计算缩放因子,用于稳定点积结果。self.atten_drop、self.proj_drop
: 设置注意力和投影的 dropout 层。self.MSC = MutilScal(...)
: 多尺度卷积模块,用于提取多尺度特征。self.avgpool = nn.AdaptiveAvgPool2d(1)
: 全局平均池化,将特征图缩小至 (1,1)。self.fc = nn.Sequential(...)
: 两层全连接网络,用于生成通道注意力权重。self.kv = Conv(dim, 2 * dim, 1)
: 卷积层,将输入特征转换为键值对。
forward
前向传播函数:
def forward(self, x):u = x.clone()B, C, H, W = x.shapekv = self.MSC(x)kv = self.kv(kv)B1, C1, H1, W1 = kv.shapeq = rearrange(x, 'b (h d) (hh) (ww) -> (b) h (hh ww) d', h=self.num_heads,d=C // self.num_heads, hh=H, ww=W)k, v = rearrange(kv, 'b (kv h d) (hh) (ww) -> kv (b) h (hh ww) d', h=self.num_heads,d=C // self.num_heads, hh=H1, ww=W1, kv=2)dots = (q @ k.transpose(-2, -1)) * self.scaleattn = dots.softmax(dim=-1)attn = self.atten_drop(attn)attn = attn @ vattn = rearrange(attn, '(b) h (hh ww) d -> b (h d) (hh) (ww)', h=self.num_heads,d=C // self.num_heads, hh=H, ww=W)c_attn = self.avgpool(x)c_attn = self.fc(c_attn)c_attn = c_attn * ureturn attn + c_attn
u = x.clone()
: 复制输入x
,用于残差连接。B, C, H, W = x.shape
: 获取输入张量的维度信息。kv = self.MSC(x)
: 将输入x
传入多尺度卷积模块以提取键值特征。kv = self.kv(kv)
: 使用kv
卷积层进一步处理特征。B1, C1, H1, W1 = kv.shape
: 获取键值特征的维度信息。q = rearrange(...)
: 重排x
为query
形式,适用于多头自注意力。k, v = rearrange(...)
: 重排kv
为键和值形式,适用于多头自注意力。dots = (q @ k.transpose(-2, -1)) * self.scale
: 计算缩放的查询键点积。attn = dots.softmax(dim=-1)
: 计算点积的 softmax,生成注意力权重。attn = self.atten_drop(attn)
: 应用注意力 dropout。attn = attn @ v
: 将注意力权重和值相乘,得到新的特征表示。attn = rearrange(...)
: 重排 attn 为原始特征形状。c_attn = self.avgpool(x)
: 对 x 进行全局平均池化。c_attn = self.fc(c_attn)
: 通过全连接层生成通道注意力权重。c_attn = c_attn * u
: 将通道注意力权重与输入 u 相乘。return attn + c_attn
: 返回多头自注意力特征和通道注意力特征的和。
2. 块级模块 Block
class Block(nn.Module):def __init__(self, dim=512, num_heads=16, mlp_ratio=4, pool_ratio=16, drop=0., dilation=[3, 5, 7],drop_path=0., act_layer=nn.ReLU6, norm_layer=nn.BatchNorm2d):super().__init__()self.norm1 = norm_layer(dim)self.attn = Mutilscal_MHSA(dim, num_heads=num_heads, atten_drop=drop, proj_drop=drop, dilation=dilation,pool_ratio=pool_ratio, fc_ratio=mlp_ratio)self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()mlp_hidden_dim = int(dim // mlp_ratio)self.mlp = E_FFN(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim, act_layer=act_layer,drop=drop)
super().__init__()
: 初始化父类。self.norm1 = norm_layer(dim)
: 归一化层。self.attn = Mutilscal_MHSA(...)
: 多尺度多头自注意力模块。self.drop_path = DropPath(...)
: 随机丢弃路径,用于防止过拟合。mlp_hidden_dim = int(dim // mlp_ratio)
: 计算多层感知机的隐藏层维度。self.mlp = E_FFN(...)
: 全连接前馈网络。
forward
前向传播函数:
def forward(self, x):x = x + self.drop_path(self.norm1(self.attn(x)))x = x + self.drop_path(self.mlp(x))return x
x = x + self.drop_path(self.norm1(self.attn(x)))
: 对注意力模块进行归一化、添加残差连接。x = x + self.drop_path(self.mlp(x))
: 对全连接层输出添加残差连接。return x
: 返回块的输出。
3. 融合模块 Fusion
class Fusion(nn.Module):def __init__(self, dim, eps=1e-8):super(Fusion, self).__init__()self.weights = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)self.eps = epsself.post_conv = SeparableConvBNReLU(dim, dim, 5)
super(Fusion, self).__init__()
: 初始化父类。self.weights = nn.Parameter(...)
: 创建两个可训练的权重参数。self.eps = eps
: 用于避免除零的 epsilon。self.post_conv = SeparableConvBNReLU(...)
: 可分离卷积层,融合后的卷积处理。
forward
前向传播函数:
def forward(self, x, res):x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)weights = nn.ReLU6()(self.weights)fuse_weights = weights / (torch.sum(weights, dim=0) + self.eps)x = fuse_weights[0] * res + fuse_weights[1] * xx = self.post_conv(x)return x
x = F.interpolate(...)
: 上采样x
。weights = nn.ReLU6()(self.weights)
: 对权重参数应用 ReLU6 激活。fuse_weights = weights / (torch.sum(weights, dim=0) + self.eps)
: 归一化权重。x = fuse_weights[0] * res + fuse_weights[1] * x
: 加权融合x
和res
。x = self.post_conv(x)
: 通过可分离卷积进一步处理。return x
: 返回融合后的特征。
这些模块配合在一起实现了多尺度、多头自注意力机制以及融合处理,有效提升遥感图像语义分割性能。
欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!
祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!
大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:
可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz