SCANet代码解读
论文链接:[2406.07189] RGB-Sonar Tracking Benchmark and Spatial Cross-Attention Transformer Tracker
代码链接:
GitHub - LiYunfengLYF/SCANet
通过代码需要了解到的事情:
1. 两种模态的数据集是怎么传入模型的?
2. 模型的输入输出是什么,模型结构代码是如何搭建?
train.py
路径:tracking/train.py
该文件主要用于是参数设置。
--script
:训练脚本的名称,用于指定要运行的训练逻辑,比如模型结构或任务。--config
:训练配置文件名,默认是baseline
(通常是 YAML 文件),用于指定训练的超参数、模型配置等。--save_dir
:指定保存训练结果(如模型权重、日志、TensorBoard 日志)的目录路径。--mode
:训练模式,支持以下几种:
single
:单卡训练(单 GPU)。multiple
:多卡训练(多 GPU,单节点)。multi_node
:多节点训练(分布式训练,跨多台机器)。--nproc_per_node
:每个节点上使用的 GPU 数量(仅在multiple
或multi_node
模式下需要)。--use_lmdb
:是否使用 LMDB 格式的数据集(0 或 1)。--use_wandb
:是否使用 Weights & Biases(WandB) 工具进行训练监控(0 或 1)。--env_num
:指定环境编号,支持多个环境的开发(0, 1, 2 等)。
通过选择训练模式(单卡)后,os.system()
会创建一个新的子进程来执行 train_cmd
。在子进程中,系统调用相应的 Python 解释器来执行指定的脚本( lib/train/run_training.py
)。
run_training.py
路径:lib/train/run_training.py
这个文件包含训练的入口程序。
功能概述:
- 接收用户输入的命令行参数(例如训练脚本名、配置文件名、随机种子等)。
- 初始化训练环境,包括设置随机种子、CUDNN 加速、分布式通信等。
- 调用具体的训练逻辑,基于用户指定的脚本和配置文件启动训练任务。
- 支持扩展功能:
- 知识蒸馏(使用教师模型指导学生模型训练)。
- 多环境开发(
env_num
参数)。- 数据存储格式(如 LMDB)。
- 集成工具(如 WandB,用于训练过程的可视化)。
初始化随机种子 (init_seeds
)
作用:设置随机种子以确保训练的可复现性。
细节:
- 使用 Python 的
random
和 NumPy 的np.random.seed
设置全局随机数种子。- 使用 PyTorch 的
torch.manual_seed
和torch.cuda.manual_seed
设置 GPU 和 CPU 的随机数种子。- 如果使用 CUDNN 后端(GPU 加速),设置
torch.backends.cudnn.benchmark
为True
,以加速特定的卷积操作。
初始化训练环境
ws_settings.Settings(env_num)
:加载训练环境的配置,例如设备编号、超参数等。路径设置:
settings.project_path
:项目路径,通常表示当前训练任务的文件夹。
settings.cfg_file
:训练配置文件的完整路径(experiments/<script_name>/<config_name>.yaml
)。settings.save_dir
:将用户指定的保存路径转为绝对路径。
知识蒸馏支持
如果启用了知识蒸馏(
distill=1
),会加载教师模型的脚本和配置文件。细节:
- 教师模型和学生模型的训练逻辑分别存放在不同的模块中:
train_script_distill
:蒸馏模式。train_script
:普通训练模式(使用)。- 动态加载模块:使用
importlib.import_module
动态导入模块,通过反射调用相应的训练逻辑。
执行训练
训练选择普通模式(不使用知识蒸馏)。
getattr(expr_module, 'run')
:从动态加载的模块中获取run
函数。expr_func(settings)
:将训练的配置(settings
)传入,并启动训练。进入lib.train.train_script.run()执行训练。
train_script.py
路径:lib/train/train_script.py
加载配置文件
初始化随机种子
日志目录初始化
数据加载器构建
build_dataloaders
:加载训练和验证数据集,返回两个数据加载器loader_train
和loader_val
。
- 训练加载器(
loader_train
):用于训练数据的迭代读取。- 验证加载器(
loader_val
):用于在训练期间验证模型性能。
模型构建
根据配置文件指定的模型名称,动态从
TRACKER_REGISTRY
注册表中找到模型类。
TRACKER_REGISTRY
是一个注册表,类似于 Python 的字典,存储了模型名称和模型类之间的映射关系。在代码中,
TRACKER_REGISTRY.get(cfg.MODEL.NETWORK)
用于通过模型名称(cfg.MODEL.NETWORK
)从注册表中查找对应的模型类。在代码的其他部分,模型类会通过装饰器注册到
TRACKER_REGISTRY,如下:
分布式训练包裹
损失函数和 Actor
损失函数:
- 根据配置文件中指定的 IoU 类型(例如
giou
或wiou
),选择相应的 IoU 损失函数。- 其他损失函数包括:
l1_loss
:用于回归目标框。focal_loss
:用于处理类别不平衡。BCEWithLogitsLoss
:用于分类精度。loss_weight
:定义每个损失函数的权重,用于加权损失求和。Actor:
- 从
ACTOR_Registry
中获取 Actor 实例(类似于控制训练逻辑的管理者)。- Actor 将模型(
net
)、损失函数(objective
)和训练设置(settings
、cfg
)结合在一起。
优化器和学习率调度器
Trainer 实例化
LTRTrainer
:训练流程管理器,负责实际的训练逻辑。参数:
actor
:负责模型前向传播、损失计算和后向传播的组件。- 数据加载器
[loader_train, loader_val]
:训练和验证数据源。optimizer
和lr_scheduler
:优化器和学习率调度器。use_amp
:是否使用混合
精度训练(AMP)。rgb_mode
:是否启用 RGB 模式(配置文件中定义)
启动训练过程
trainer.train
:启动训练流程。
- 参数:
cfg.TRAIN.LEARN.EPOCH
:训练的总 epoch 数。load_latest=True
:是否加载最近保存的检查点(如果存在)。fail_safe=True
:是否启用故障恢复机制(例如断点续训)。
base_function.py
路径:lib/train/base_functions.py
build_dataloaders
在上述的train_script.py文件中,数据集传入build_dataloaders进行数据加载
build_dataloaders函数在base_functions.py文件中
names2datasets
names2datasets
函数根据配置中的数据集名称列表构建相应的数据集实例。对于每个数据集,代码会检查是否使用 LMDB 格式,并创建相应的数据集对象。
只有LASOT 和 GOT10K 两个数据集使用了rgbs_mode,其中使用了
cv2.saliency.StaticSaliencySpectralResidual_create()进行了视觉显著性的图像转换,猜测作者是将其模拟声呐图做训练。
因为作者公开了声光融合数据集,但在代码中,其公开的RGBS50数据集只作为test数据集在使用。
trackerModel.py——重点
路径:lib/models/scanet/trackerModel.py
在train_script.py中,通过TRACKER_REGISTRY.get()
动态查找模型类。lib/models/scanet/trackerModel.py中的SCANet_network类被装饰器注册到TRACKER_REGISTRY
中。
模型基于
OSTrack
构建,目标是实现RGB-T(可见光 & 声呐)目标跟踪任务。
SCANet_network类
构造函数
self.backbone:构建骨干网络,从输入的输入数据中提取特征,参数cfg.MODEL.BACKBONE.LOAD_MODE指定了加载模式
- 加载预训练权重
- 调用骨干网络的微调函数
- 构建RGB和Sonar的预测头模块,头部模块的类型和参数由配置文件指定(
cfg.MODEL.HEAD
和cfg.MODEL.RGBS_HEAD)
forward前向传播
- 输入:模板帧(template frame)和搜索帧(search frame)。模板帧用于定义要跟踪的目标,而搜索帧用于定位目标。
- 骨干网络特征提取
- 调用
forward_head
方法,将骨干网络的输出传递给头部模块,生成 RGB 和声呐的预测结果。
forward_head预测头模块函数
- 输入:骨干网络的输出 cat_feature,以及可选参数
gt_score_map
- 特征切片
- 特征变换
RGB 和声呐头部的前向传播
- 将 RGB 和声呐特征分别传递给对应的头部模块
box_head
和sonar_head
。- 生成的输出包含预测的边界框(
pred_boxes
)和其他特征图(如score_map
、size_map
)。- 上述的具体参数是来自配置文件cfg的,具体是通过找到
settings.cfg_file
中指定的 YAML 文件路径,查看MODEL.HEAD
和MODEL.RGBS_HEAD
的具体配置
ltr_trainer.py
路径:lib/train/trainers/ltr_trainer.py
在train_script.py文件中,数据集、模型等相关信息传入LTRTrainer类开始训练:
LTRTrainer类是继承BaseTrainer类的
在 LTRTrainer 类中,实际是将数据传入actor中执行训练过程中所需的前向传播、损失计算和数据处理等功能。
![]()
rgbser_actor.py
路径:lib/train/actors/rgbser_actor.py
在上述ltr_trainer.py中,数据实际传入acotr类中实现训练过程
从数据加载器中读取到的输入传入actor,然后传入data2temp_search 就可以分出两个模态的template 和 search 的图片了,即template_img_1, template_img_2, search_img_1, search_img_2。
data
: 一个字典,包含了图像数据,格式如下:
data['template_images']
: 包含模板图像的张量,形状为(batch_size, N_t, 3, H, W)
,其中N_t
是模板图像的数量(在这个函数中假设为 2)。data['search_images']
: 包含搜索图像的张量,形状为(batch_size, N_s, 3, H, W)
,其中N_s
是搜索图像的数量(在这个函数中假设为 2)。
cycle_dataset
数据加载和模型训练的主要功能在 cycle_dataset方法中实现。
输入是loader(数据加载器),负责提供训练或验证数据。
baseline.yaml
路径:experiments/scanet/baseline.yaml
这份 YAML 配置文件包含了模型架构、数据集、训练、验证和测试的全流程参数:
- 数据配置(DATA):定义了输入数据的模态、归一化参数,以及训练/验证数据集。
- 模型配置(MODEL):包括骨干网络、头部模块等配置,支持多模态(RGB + 声呐)。
- 训练配置(TRAIN):详细定义了优化器、损失函数、混合精度等训练参数。
- 测试配置(TEST):与训练类似,定义了搜索和模板区域的参数。
数据配置DATA
模态设置
RGB_MODE
和RGBS_MODE
:指示模型是否同时处理 RGB 和 RGBS(声呐)数据模态。RGBS_ROTATE
:是否对 RGBS 数据增加旋转增强。MAX_SAMPLE_INTERVAL
:最大采样间隔,用于决定从视频中采样帧的时间跨度。
数据归一化参数
MEAN
和STD
:输入图像的归一化参数,通常是 RGB 图像的标准均值和标准差,用于将像素值归一化到标准分布。
搜索区域 (SEARCH) 和模板区域 (TEMPLATE) 配置
- 搜索区域(SEARCH):
- 定义了模型在当前帧中搜索目标的参数。
- 例如,
FACTOR=4.0
表示搜索区域的宽度/高度是目标框的 4 倍。SIZE=256
表示将搜索区域缩放到 256x256 的固定尺寸。- 模板区域(TEMPLATE):
- 定义了模板帧(通常是目标的初始位置)的参数。
SIZE=128
表示将模板区域缩放到 128x128。- 通常模板帧会被固定且不抖动(
CENTER_JITTER=0
和SCALE_JITTER=0
)。
数据集
模型配置 MODEL
网络整体信息
NETWORK
:模型的主网络名称,这里是SCANet_network
。RETURN_STAGES
:骨干网络中需要返回的特征层索引(例如第 2、5、8 和 11 层)。
骨干网络 (BACKBONE)
![]()
TYPE
:骨干网络类型,这里采用vit_base_patch16_224_midlayer
(Vision Transformer)。PARAMS
:
ffm: SCAM
:特征融合方式为 SCAM。rgbs_loc: [3, 6, 9]
:指定在第 3、6、9 层使用 RGB-S 融合。
头部模块
HEAD
(RGB 头) 和RGBS_HEAD
(声呐头):
- 类型为
center_head
,参数为:
inplanes
:输入通道数。channel
:中间特征通道数。feat_sz
:特征图大小。stride
:卷积步幅。
训练配置 TRAIN
- 优化目标:
- 使用
giou
(广义 IoU)作为 IoU 损失。GIOU_WEIGHT
和L1_WEIGHT
:分别为 GIoU 和 L1 损失的权重。- 优化器参数:
LR
:初始学习率为0.00001
。WEIGHT_DECAY
:权重衰减值为0.0001
。
优化器和学习率调度器
- 优化器:使用
ADAMW
优化器。- 学习率调度器:
- 使用
step
调度器,每隔 30 个 epoch 降低学习率。
AMP混合精度训练
- 是否启用 AMP(混合精度训练):
USED: False
。- 梯度裁剪的最大范数:
GRAD_CLIP_NORM: 0.1
。
测试配置 TEST
- 测试的搜索区域和模板区域与训练保持一致,分别为
256x256
和128x128
。
vit_rgbs.py
路径:lib/models/scanet/vit_rgbs.py
在trackerModel.py中
在YAML文件中,backbone的配置如下:
backbone的type为vit_base_patch16_224_midlayer,在vit_rgbs.py文件中定义
vit_rgbs.py文件代码结构概览
Attention 和 Block 模块:
- 实现了 Transformer 中的注意力机制(
Attention
)及基本的 Transformer 块(Block
)。VisionTransformer_midlayer:
- 核心 Transformer 模型,支持 ViT 结构,包含可见光(RGB)和 T 模态的融合能力。
- 支持冻结部分 Transformer 层(
freeze_layer
)和插入中间融合层(rgbs_layers
)。辅助函数:
- 包括权重初始化、加载预训练权重、调整位置嵌入(
pos_embed
)等。模型注册:
- 定义了
vit_base
,vit_small
, 和vit_tiny
三种 ViT 模型,并通过注册表(MODEL_REGISTRY
)动态加载。
vit_base_patch16_224_midlayer --------> _create_vision_transformer
_create_vision_transformer ---------> VisionTransformer_midlayer
VisionTransformer_midlayer——重点
这部分是这个py文件的核心,它实现了标准的 Vision Transformer,同时支持红外(T)与 RGB 的融合。是BaseBackbone的子类。
模型结构
整体模型架构可以拆解为以下几个主要部分:
输入处理模块:
- 图像被切分为 Patch,并通过
PatchEmbed
层进行线性嵌入,生成低维特征。- 位置编码(
pos_embed
)被添加到特征中,使得 Transformer 感知输入的空间位置。Transformer 主干网络:
- 若干个堆叠的 Transformer Block(这里的block就是ViT中的基础block),每个 Block 包含:
- 多头注意力机制(
Attention
)。- 前馈网络(MLP)。
- 这些 Block 用于提取输入的高层次特征。
融合模块(
rgbs_layers
):
- 在指定的 Transformer Block 层(通过
rgbs_loc
参数定义)插入特征融合模块(ffm-SCAM
)。- 融合模块用于交互和融合 RGB 和 T 模态的特征。
输出层:
- Transformer 的输出经过归一化(
norm
)后,分别恢复为 RGB 和 T 模态的各自特征。- 最终将两种模态的特征拼接,作为模型的输出。
a. Patch 嵌入
- 将输入图像切分成固定大小的 Patch,并使用线性投影将每个 Patch 映射到
embed_dim
维度。
b. 位置嵌入
- 提供位置编码,使得 Transformer 能够感知输入的空间位置信息。
c. Transformer Block块
- 核心组件是多个堆叠的
Block
模块,数量由depth
参数指定。
d. 融合层——SCAM的插入
- 插入融合层(
ffm——SCAM
)以实现 RGB 和 T 特征的融合。- 通过
rgbs_loc
指定融合层插入的位置。
前向传播
a. 输入处理
可见光(
x[0], z[0]
)和声纳(x[1], z[1]
)数据分别通过独立的Patch Embedding层。详看base_backbone.py节解读。
b. 特征融合
- 逐层通过 Transformer 块。
- 在指定ViT层(如第3、6、9层)插入SCAM模块,实现特征对齐与融合。
c. 模态恢复
- 从融合后的特征中恢复 RGB 和 T 的独立特征。
- 最后将两种模态的特征拼接在一起。
Block
这里的block如图所示,就是ViT的基础block。
base_backbone.py
路径:lib/models/scanet/base_backbone.py
上述表示模型结构的类 VisionTransformer_midlayer 继承了 BaseBackbone
BaseBackbone是基本ViT骨干网的联合特征提取和关系建模。他的输入可以理解为单模态的模板(template)----z,和搜索区域(search region)-----x;
VisionTransformer_midlayer继承后,输入的其实也是单序列,但可能在输入之前把两种模态拼接了,所以在传入函数后,又分别提取了x_v,x_s以及z_v,z_s两种模态的template和search:
原文:
scam.py
路径:lib/models/scanet/mid_rgbs_layer/scam.py
在vit_rgbs.py文件中,搭建了backbone结构,其中融合层ffm 在YAML文件中被定义为SCAM,其在scam.py文件中被定义。
在vit_rgbs.py文件中,对整个网络进行了搭建。
网络由block搭建,层数为depth,然后在指定的rgbs_loc的位置添加ffm——SCAM融合块
SCAM模块的作用是接收两个输入特征(
x1
和x2
),通过交叉注意力机制进行交互,然后在特征中加入残差连接和前馈网络(FFN)。这里输入的两个特征分别是两个模态经过ViT Block后的两个输出。
SCAM模块的输出任然是两个模态的特征。
原论文对SCAM的解释:
SCAM旨在实现空间未对齐的RGB特征与声呐特征的有效跨模态交互。
SCAM由一个空间交叉注意力层和两个独立的全局整合模块(GIM)组成。我们的SCAM模块的操作如下所述:
- SCAM的输入是Hr和Hs。
- 首先,计算两个模态的QKV矩阵。RGB模态的查询矩阵表示为Qr = Hr ∈ RN ×C,其中N = Nz + Nx。
- 键矩阵和值矩阵通过Kr, Vr = Split(Linear(Hr))获得,其中Linear表示没有偏置的线性层,它将特征从RC映射到R2C。
- Split将特征从R2C映射到两个RC。声呐的Qs, Ks, Vs以相同的方式获得。
center_head.py
路径:lib/models/scanet/head/center_head.py
在YAML配置文件中,Head类型为center_head,它的定义在center_head.py文件中。
conv
函数: 构建一个卷积层、BatchNorm 和 ReLU 激活函数的组合。center_head
类:
- 继承自
torch.nn.Module
,这是 PyTorch 中用于构建神经网络的基类。- 包含:
- 构造函数 (
__init__
):定义网络结构。forward
函数:定义前向传播逻辑。- 分支模块:
- 中心点分支(预测目标中心点概率图)。
- 偏移分支(预测目标中心点的偏移量)。
- 尺寸分支(预测目标的宽度和高度)。
- 辅助函数:
cal_bbox
:根据预测结果计算目标边界框。get_score_map
:计算三个分支的输出(中心点、尺寸和偏移量)。
原文解释:
总结
整体网络搭建是SCANet_network类,forward流程:template和search的输入----->backbone(输出一个拼接后的特征)------>Head
template
是指在给定的视频序列中,最初用于表示目标的图像。这通常是在第一帧或某一帧中提取的目标图像或区域,包含了目标的外观特征。---------用来做目标检测search
是指在后续帧中寻找目标的图像。这通常是一个较大的区域,包含了潜在的目标位置,模型将在这个区域内进行搜索。-------用来做目标跟踪
表示template的输入是什么?两个模态的图像吗?
答:是两个模态的图像,但是是被concat成了一个序列,x的第0维表示一个模态,x的第1维表示另一个模态。
- 整个网络 输入是template(两个模态拼接)和search(两个模态拼接);
- 进入网络后将两个模态提取出来,针对各自模态,将template和search进行拼接;即x_v表示光学template+search的内容;x_s表示声呐template+search内容;
- 将两个模态(x_v和x_s)分别输入ViT Blocks,输出认识该模态的特征;
- 将经过ViT Blocks的两个模态特征输入SCAM模块,融合后输出任然是两个模态(掺杂的另一个模态信息);
- 经过多个blocks 和 SCAM 后,分别输入各自对应Head。