当前位置: 首页 > news >正文

视觉 注意力机制——通道注意力、空间注意力、自注意力、交叉注意力

计算机视觉——探索视觉注意力机制:通道、空间、自注意力及交叉注意力

在计算机视觉领域,注意力机制已经成为了提升模型性能的关键技术之一。通过模拟人类视觉注意力,模型能够更加高效地处理图像数据,关注重要的特征并忽略无关信息。本文将详细介绍几种主要的视觉注意力机制,包括通道注意力、空间注意力、自注意力和交叉注意力,并提供相应的代码示例。

通道注意力机制(Channel Attention)

通道注意力机制,如SENet中的Squeeze-and-Excitation (SE) 模块,通过强调重要的通道特征并抑制不重要的通道特征来增强模型的特征表达能力。

代码示例

import torch
import torch.nn as nnclass SELayer(nn.Module):def __init__(self, channel, reduction=16):super(SELayer, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // reduction, channel, bias=False),nn.Sigmoid())def forward(self, x):b, c, _, _ = x.size()y = self.avg_pool(x).view(b, c)y = self.fc(y).view(b, c, 1, 1)return x * y.expand_as(x)# 假设输入特征图
input_feature = torch.randn(1, 64, 56, 56)
se_layer = SELayer(channel=64)
output_feature = se_layer(input_feature)

空间注意力机制(Spatial Attention)

空间注意力机制关注于图像中的重要空间位置,通常通过学习图像中每个位置的重要性权重来实现。

代码示例

class SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super(SpatialAttention, self).__init__()assert kernel_size % 2 == 1, "Kernel size must be odd."self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)x = torch.cat([avg_out, max_out], dim=1)x = self.conv1(x)attention = self.sigmoid(x)return x * attention# 假设输入特征图
input_feature = torch.randn(1, 64, 56, 56)
spatial_attention = SpatialAttention()
output_feature = spatial_attention(input_feature)

自注意力机制(Self-Attention)

自注意力机制,如Transformer模型中的机制,允许模型在处理序列数据时考虑序列内部的长距离依赖关系。

代码示例

class SelfAttention(nn.Module):def __init__(self, in_dim):super(SelfAttention, self).__init__()self.query_conv = nn.Conv2d(in_dim, in_dim // 8, 1)self.key_conv = nn.Conv2d(in_dim, in_dim // 8, 1)self.value_conv = nn.Conv2d(in_dim, in_dim, 1)self.softmax = nn.Softmax(dim=-1)def forward(self, x):batch_size, C, width, height = x.size()query = self.query_conv(x).view(batch_size, -1, width*height).permute(0, 2, 1)key = self.key_conv(x).view(batch_size, -1, width*height)energy = torch.bmm(query, key)attention = self.softmax(energy)value = self.value_conv(x).view(batch_size, -1, width*height)out = torch.bmm(value, attention.permute(0, 2, 1))out = out.view(batch_size, C, width, height)return out# 假设输入特征图
input_feature = torch.randn(1, 64, 56, 56)
self_attention = SelfAttention(in_dim=64)
output_feature = self_attention(input_feature)

交叉注意力机制(Cross-Attention)

交叉注意力机制通常用于序列到序列的任务中,如机器翻译,它允许模型在生成输出序列时考虑输入序列的信息。

代码示例

class CrossAttention(nn.Module):def __init__(self, query_dim, key_dim):super(CrossAttention, self).__init__()self.query_conv = nn.Conv2d(query_dim, query_dim // 8, 1)self.key_conv = nn.Conv2d(key_dim, key_dim // 8, 1)self.value_conv = nn.Conv2d(key_dim, query_dim, 1)self.softmax = nn.Softmax(dim=-1)def forward(self, query, key, value):batch_size, C_q, width, height = query.size()batch_size_k, C_k, width_k, height_k = key.size()query = self.query_conv(query).view(batch_size, -1, width*height).permute(0, 2, 1)key = self.key_conv(key).view(batch_size_k, -1, width_k*height_k)energy = torch.bmm(query, key)attention = self.softmax(energy)value = self.value_conv(value).view(batch_size_k, -1, width_k*height_k)out = torch.bmm(value, attention.permute(0, 2, 1))out = out.view(batch_size, -1, width, height)return out# 假设输入特征图
query = torch.randn(1, 64, 56, 56)
key = torch.randn(1, 64, 56, 56)
value = torch.randn(1, 64, 56, 56)
cross_attention = CrossAttention(query_dim=64, key_dim=64)
output_feature = cross_attention(query, key, value)

结论

注意力机制在计算机视觉中的应用极大地提高了模型对图像特征的处理能力。通过通道注意力、空间注意力、自注意力和交叉注意力等不同的机制,模型能够更加关注于图像中的关键信息,从而提升识别、分类和分割等任务的性能。随着研究的深入,注意力机制将继续在计算机视觉领域发挥重要作用,并推动相关技术的发展。

✅作者简介:热爱科研的人工智能开发者,修心和技术同步精进

❤欢迎关注我的知乎:对error视而不见

代码获取、问题探讨及文章转载可私信。

☁ 愿你的生命中有够多的云翳,来造就一个美丽的黄昏。

🍎获取更多人工智能资料可点击链接进群领取,谢谢支持!👇

点击领取更多详细资料


http://www.mrgr.cn/news/30661.html

相关文章:

  • 工作和学习遇到的技术问题
  • GIT:如何查找已删除的文件的历史记录
  • 【ubuntu】Geogebra
  • 【时间之外】IT人求职和创业应知【31】
  • 设计模式之责任链模式(Chain Of Responsibility)
  • Elasticsearch中什么是倒排索引?
  • C# 访问Access存取图片
  • 软件安全最佳实践:首先关注的地方
  • 【macOS】【Python】安装Python到虚拟环境的命令
  • 版本控制之Git
  • 电力施工作业安全行为检测图像数据集
  • 算法打卡 Day41(动态规划)-理论基础 + 斐波那契数 + 爬楼梯 + 使用最小花费爬楼梯
  • MATLAB矩阵下标引用
  • 图数据库之HugeGraph
  • 深度学习笔记(8)预训练模型
  • Linux文件系统
  • 8.1差分边缘检测
  • 介绍几个AI生成视频的工具
  • 新发布的OpenAI o1生成式AI模型在强化学习方面迈出了重要的一步
  • iptables 基础示例
  • 电脑维修的基本原则
  • AI助力智慧农田作物病虫害监测,基于YOLOv8全系列【n/s/m/l/x】参数模型开发构建花田作物种植场景下棉花作物常见病虫害检测识别系统
  • 【ShuQiHere】 从逻辑门到组合电路:构建数字系统的核心
  • Python习题 192:编写一个猜单词游戏
  • 算法打卡 Day34(贪心算法)-分发饼干 + 摆动序列 + 最大子序和
  • 链式栈讲解