当前位置: 首页 > news >正文

【语义分割|代码解析】CMTFNet-4: CNN and Multiscale Transformer Fusion Network 用于遥感图像分割!

【语义分割|代码解析】CMTFNet-4: CNN and Multiscale Transformer Fusion Network 用于遥感图像分割!

【语义分割|代码解析】CMTFNet-4: CNN and Multiscale Transformer Fusion Network 用于遥感图像分割!


文章目录

  • 【语义分割|代码解析】CMTFNet-4: CNN and Multiscale Transformer Fusion Network 用于遥感图像分割!
  • 前言
    • 1. 多尺度多头自注意力(Multi-Head Self-Attention,MHSA)模块
    • `forward` 前向传播函数:
    • 2. 块级模块 Block
    • `forward` 前向传播函数:
    • 3. 融合模块 Fusion
    • `forward` 前向传播函数:


欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!
祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!

大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:
可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz

论文地址:https://ieeexplore.ieee.org/document/10247595

前言

在这里插入图片描述
该代码实现了一个多尺度多头自注意力(Multi-Head Self-Attention,MHSA)模块 Mutilscal_MHSA、一个块级模块 Block 以及一个融合模块 Fusion。此代码用于遥感图像语义分割模型 CMTFNet 中,主要通过多尺度卷积、MHSA 和融合机制增强图像特征提取。以下是逐行代码解析:

在这里插入图片描述

1. 多尺度多头自注意力(Multi-Head Self-Attention,MHSA)模块

class Mutilscal_MHSA(nn.Module):def __init__(self, dim, num_heads, atten_drop = 0., proj_drop = 0., dilation = [3, 5, 7], fc_ratio=4, pool_ratio=16):super(Mutilscal_MHSA, self).__init__()assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."self.dim = dimself.num_heads = num_headshead_dim = dim // num_headsself.scale = head_dim ** -0.5self.atten_drop = nn.Dropout(atten_drop)self.proj_drop = nn.Dropout(proj_drop)self.MSC = MutilScal(dim=dim, fc_ratio=fc_ratio, dilation=dilation, pool_ratio=pool_ratio)self.avgpool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Conv2d(in_channels=dim, out_channels=dim//fc_ratio, kernel_size=1),nn.ReLU6(),nn.Conv2d(in_channels=dim//fc_ratio, out_channels=dim, kernel_size=1),nn.Sigmoid())self.kv = Conv(dim, 2 * dim, 1)
  • __init__ 构造函数:
  • super(Mutilscal_MHSA, self).__init__(): 初始化父类 nn.Module
  • assert dim % num_heads == 0: 确保特征维度 dim 可被头数 num_heads 整除。
  • self.dim、self.num_heads: 初始化维度和多头数量。
  • head_dim = dim // num_heads: 每个头的维度大小。
  • self.scale = head_dim ** -0.5: 计算缩放因子,用于稳定点积结果。
  • self.atten_drop、self.proj_drop: 设置注意力和投影的 dropout 层。
  • self.MSC = MutilScal(...): 多尺度卷积模块,用于提取多尺度特征。
  • self.avgpool = nn.AdaptiveAvgPool2d(1): 全局平均池化,将特征图缩小至 (1,1)。
  • self.fc = nn.Sequential(...): 两层全连接网络,用于生成通道注意力权重。
  • self.kv = Conv(dim, 2 * dim, 1): 卷积层,将输入特征转换为键值对。

forward 前向传播函数:

    def forward(self, x):u = x.clone()B, C, H, W = x.shapekv = self.MSC(x)kv = self.kv(kv)B1, C1, H1, W1 = kv.shapeq = rearrange(x, 'b (h d) (hh) (ww) -> (b) h (hh ww) d', h=self.num_heads,d=C // self.num_heads, hh=H, ww=W)k, v = rearrange(kv, 'b (kv h d) (hh) (ww) -> kv (b) h (hh ww) d', h=self.num_heads,d=C // self.num_heads, hh=H1, ww=W1, kv=2)dots = (q @ k.transpose(-2, -1)) * self.scaleattn = dots.softmax(dim=-1)attn = self.atten_drop(attn)attn = attn @ vattn = rearrange(attn, '(b) h (hh ww) d -> b (h d) (hh) (ww)', h=self.num_heads,d=C // self.num_heads, hh=H, ww=W)c_attn = self.avgpool(x)c_attn = self.fc(c_attn)c_attn = c_attn * ureturn attn + c_attn
  • u = x.clone(): 复制输入 x,用于残差连接。
  • B, C, H, W = x.shape: 获取输入张量的维度信息。
  • kv = self.MSC(x): 将输入 x 传入多尺度卷积模块以提取键值特征。
  • kv = self.kv(kv): 使用 kv 卷积层进一步处理特征。
  • B1, C1, H1, W1 = kv.shape: 获取键值特征的维度信息。
  • q = rearrange(...): 重排 xquery 形式,适用于多头自注意力。
  • k, v = rearrange(...): 重排 kv 为键和值形式,适用于多头自注意力。
  • dots = (q @ k.transpose(-2, -1)) * self.scale: 计算缩放的查询键点积。
  • attn = dots.softmax(dim=-1): 计算点积的 softmax,生成注意力权重。
  • attn = self.atten_drop(attn): 应用注意力 dropout。
  • attn = attn @ v: 将注意力权重和值相乘,得到新的特征表示。
  • attn = rearrange(...): 重排 attn 为原始特征形状。
  • c_attn = self.avgpool(x): 对 x 进行全局平均池化。
  • c_attn = self.fc(c_attn): 通过全连接层生成通道注意力权重。
  • c_attn = c_attn * u: 将通道注意力权重与输入 u 相乘。
  • return attn + c_attn: 返回多头自注意力特征和通道注意力特征的和。

2. 块级模块 Block

class Block(nn.Module):def __init__(self, dim=512, num_heads=16,  mlp_ratio=4, pool_ratio=16, drop=0., dilation=[3, 5, 7],drop_path=0., act_layer=nn.ReLU6, norm_layer=nn.BatchNorm2d):super().__init__()self.norm1 = norm_layer(dim)self.attn = Mutilscal_MHSA(dim, num_heads=num_heads, atten_drop=drop, proj_drop=drop, dilation=dilation,pool_ratio=pool_ratio, fc_ratio=mlp_ratio)self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()mlp_hidden_dim = int(dim // mlp_ratio)self.mlp = E_FFN(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim, act_layer=act_layer,drop=drop)
  • super().__init__(): 初始化父类。
  • self.norm1 = norm_layer(dim): 归一化层。
  • self.attn = Mutilscal_MHSA(...): 多尺度多头自注意力模块。
  • self.drop_path = DropPath(...): 随机丢弃路径,用于防止过拟合。
  • mlp_hidden_dim = int(dim // mlp_ratio): 计算多层感知机的隐藏层维度。
  • self.mlp = E_FFN(...): 全连接前馈网络。

forward 前向传播函数:

    def forward(self, x):x = x + self.drop_path(self.norm1(self.attn(x)))x = x + self.drop_path(self.mlp(x))return x
  • x = x + self.drop_path(self.norm1(self.attn(x))): 对注意力模块进行归一化、添加残差连接。
  • x = x + self.drop_path(self.mlp(x)): 对全连接层输出添加残差连接。
  • return x: 返回块的输出。

3. 融合模块 Fusion

class Fusion(nn.Module):def __init__(self, dim, eps=1e-8):super(Fusion, self).__init__()self.weights = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)self.eps = epsself.post_conv = SeparableConvBNReLU(dim, dim, 5)
  • super(Fusion, self).__init__(): 初始化父类。
  • self.weights = nn.Parameter(...): 创建两个可训练的权重参数。
  • self.eps = eps: 用于避免除零的 epsilon。
  • self.post_conv = SeparableConvBNReLU(...): 可分离卷积层,融合后的卷积处理。

forward 前向传播函数:

    def forward(self, x, res):x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)weights = nn.ReLU6()(self.weights)fuse_weights = weights / (torch.sum(weights, dim=0) + self.eps)x = fuse_weights[0] * res + fuse_weights[1] * xx = self.post_conv(x)return x
  • x = F.interpolate(...): 上采样 x
  • weights = nn.ReLU6()(self.weights): 对权重参数应用 ReLU6 激活。
  • fuse_weights = weights / (torch.sum(weights, dim=0) + self.eps): 归一化权重。
  • x = fuse_weights[0] * res + fuse_weights[1] * x: 加权融合 xres
  • x = self.post_conv(x): 通过可分离卷积进一步处理。
  • return x: 返回融合后的特征。

这些模块配合在一起实现了多尺度、多头自注意力机制以及融合处理,有效提升遥感图像语义分割性能。

欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!
祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!

大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:
可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz


http://www.mrgr.cn/news/64533.html

相关文章:

  • python中应该使用while 1吗?按位运算符可以代替逻辑运算符使用吗?
  • c++/qt连接阿里云视觉智能开发平台
  • MySQL 的 select * 会用到事务吗?
  • 写给粉丝们的信
  • 电赛入门之软件stm32keil+cubemx
  • 春秋云境CVE-2022-21661,sqlmap+json一把梭哈
  • 非线性数据结构之图
  • Python编程风格:保持逻辑完整性
  • Linux运行Java程序,并按天输出日志
  • 【Orange Pi 5 Linux 5.x 内核编程】-设备驱动中的sysfs
  • 【单片机C51两个按键K1、K2控制8个LED灯,初始值0xFE。摁下一次K1,LED灯左移;摁下一次K2,LED灯右移;】2022-1-5
  • 再学FreeRTOS---(中断管理)
  • 智能指针、移动语义、完美转发、lambda
  • 数字信号处理Python示例(3)生成三相正弦信号
  • 鸿蒙开发案例:分贝仪
  • Android中的Handle底层原理
  • 如何设置和使用低代码平台中的点击事件?
  • redis源码系列--(二)--eventlooop+set流程
  • 常用滤波算法(三)-算术平均滤波法
  • 【51蛋骗鸡单按键控制计数开始暂停复位】
  • 【ChatGPT】通过自定义参数让ChatGPT输出特定格式的文本
  • 同一局域网内A主机连接B主机的虚拟机中的服务
  • C++入门基础知识135—【关于C 库函数 - mktime()】
  • C++学习笔记----10、模块、头文件及各种主题(一)---- 模块(1)
  • 非线性数据结构之树
  • 【Vue3】一文全览基础语法-案例程序及配图版