当前位置: 首页 > news >正文

SCANet代码解读

论文链接:[2406.07189] RGB-Sonar Tracking Benchmark and Spatial Cross-Attention Transformer Tracker

代码链接:

GitHub - LiYunfengLYF/SCANet


 通过代码需要了解到的事情:

1. 两种模态的数据集是怎么传入模型的?
2. 模型的输入输出是什么,模型结构代码是如何搭建?


train.py

路径:tracking/train.py

该文件主要用于是参数设置。

  • --script:训练脚本的名称,用于指定要运行的训练逻辑,比如模型结构或任务。
  • --config:训练配置文件名,默认是 baseline(通常是 YAML 文件),用于指定训练的超参数、模型配置等。
  • --save_dir:指定保存训练结果(如模型权重、日志、TensorBoard 日志)的目录路径。
  • --mode:训练模式,支持以下几种:
    • single:单卡训练(单 GPU)。
    • multiple:多卡训练(多 GPU,单节点)。
    • multi_node:多节点训练(分布式训练,跨多台机器)。
  • --nproc_per_node:每个节点上使用的 GPU 数量(仅在 multiple 或 multi_node 模式下需要)。
  • --use_lmdb:是否使用 LMDB 格式的数据集(0 或 1)。
  • --use_wandb:是否使用 Weights & Biases(WandB) 工具进行训练监控(0 或 1)。
  • --env_num:指定环境编号,支持多个环境的开发(0, 1, 2 等)。

通过选择训练模式(单卡)后,os.system() 会创建一个新的子进程来执行 train_cmd。在子进程中,系统调用相应的 Python 解释器来执行指定的脚本( lib/train/run_training.py)。


run_training.py

 路径:lib/train/run_training.py

这个文件包含训练的入口程序。

功能概述:

  1. 接收用户输入的命令行参数(例如训练脚本名、配置文件名、随机种子等)。
  2. 初始化训练环境,包括设置随机种子、CUDNN 加速、分布式通信等。
  3. 调用具体的训练逻辑,基于用户指定的脚本和配置文件启动训练任务。
  4. 支持扩展功能:
    • 知识蒸馏(使用教师模型指导学生模型训练)。
    • 多环境开发(env_num 参数)。
    • 数据存储格式(如 LMDB)。
    • 集成工具(如 WandB,用于训练过程的可视化)。

 初始化随机种子 (init_seeds)

作用:设置随机种子以确保训练的可复现性。

细节

  • 使用 Python 的 random 和 NumPy 的 np.random.seed 设置全局随机数种子。
  • 使用 PyTorch 的 torch.manual_seed 和 torch.cuda.manual_seed 设置 GPU 和 CPU 的随机数种子。
  • 如果使用 CUDNN 后端(GPU 加速),设置 torch.backends.cudnn.benchmark 为 True,以加速特定的卷积操作。

 初始化训练环境

ws_settings.Settings(env_num):加载训练环境的配置,例如设备编号、超参数等。

路径设置

settings.project_path:项目路径,通常表示当前训练任务的文件夹。

  • settings.cfg_file:训练配置文件的完整路径(experiments/<script_name>/<config_name>.yaml)。
  • settings.save_dir:将用户指定的保存路径转为绝对路径。

 知识蒸馏支持

如果启用了知识蒸馏(distill=1),会加载教师模型的脚本和配置文件。

细节

  • 教师模型和学生模型的训练逻辑分别存放在不同的模块中:
    • train_script_distill:蒸馏模式。
    • train_script普通训练模式(使用)
  • 动态加载模块:使用 importlib.import_module 动态导入模块,通过反射调用相应的训练逻辑。

 执行训练

训练选择普通模式(不使用知识蒸馏)。

  • getattr(expr_module, 'run'):从动态加载的模块中获取 run 函数。
  • expr_func(settings):将训练的配置(settings)传入,并启动训练。

进入lib.train.train_script.run()执行训练。

train_script.py 

路径:lib/train/train_script.py

 加载配置文件

初始化随机种子

日志目录初始化

数据加载器构建

build_dataloaders:加载训练和验证数据集,返回两个数据加载器loader_train 和 loader_val 。

  • 训练加载器(loader_train):用于训练数据的迭代读取。
  • 验证加载器(loader_val):用于在训练期间验证模型性能。

 模型构建

根据配置文件指定的模型名称,动态从 TRACKER_REGISTRY 注册表中找到模型类

TRACKER_REGISTRY 是一个注册表,类似于 Python 的字典,存储了模型名称和模型类之间的映射关系。

在代码中,TRACKER_REGISTRY.get(cfg.MODEL.NETWORK) 用于通过模型名称(cfg.MODEL.NETWORK)从注册表中查找对应的模型类。

在代码的其他部分,模型类会通过装饰器注册到 TRACKER_REGISTRY,如下:

分布式训练包裹

损失函数和 Actor

损失函数

  • 根据配置文件中指定的 IoU 类型(例如 giou 或 wiou),选择相应的 IoU 损失函数。
  • 其他损失函数包括:
    • l1_loss:用于回归目标框。
    • focal_loss:用于处理类别不平衡。
    • BCEWithLogitsLoss:用于分类精度。
  • loss_weight:定义每个损失函数的权重,用于加权损失求和。

Actor

  • 从 ACTOR_Registry 中获取 Actor 实例(类似于控制训练逻辑的管理者)。
  • Actor 将模型(net)、损失函数(objective)和训练设置(settingscfg)结合在一起。

 优化器和学习率调度器

Trainer 实例化

LTRTrainer:训练流程管理器,负责实际的训练逻辑。

参数

  • actor:负责模型前向传播、损失计算和后向传播的组件。
  • 数据加载器 [loader_train, loader_val]:训练和验证数据源。
  • optimizer 和 lr_scheduler:优化器和学习率调度器。
  • use_amp:是否使用混合
    精度训练(AMP)。
  • rgb_mode:是否启用 RGB 模式(配置文件中定义)

启动训练过程

trainer.train:启动训练流程。

  • 参数
    • cfg.TRAIN.LEARN.EPOCH:训练的总 epoch 数。
    • load_latest=True:是否加载最近保存的检查点(如果存在)。
    • fail_safe=True:是否启用故障恢复机制(例如断点续训)。

 base_function.py

路径:lib/train/base_functions.py


build_dataloaders

在上述的train_script.py文件中,数据集传入build_dataloaders进行数据加载

 build_dataloaders函数在base_functions.py文件中


names2datasets

names2datasets 函数根据配置中的数据集名称列表构建相应的数据集实例。对于每个数据集,代码会检查是否使用 LMDB 格式,并创建相应的数据集对象。

只有LASOT 和 GOT10K 两个数据集使用了rgbs_mode,其中使用了
cv2.saliency.StaticSaliencySpectralResidual_create()进行了视觉显著性的图像转换,猜测作者是将其模拟声呐图做训练。

因为作者公开了声光融合数据集,但在代码中,其公开的RGBS50数据集只作为test数据集在使用。


 trackerModel.py——重点

路径:lib/models/scanet/trackerModel.py

在train_script.py中,通过TRACKER_REGISTRY.get() 动态查找模型类。lib/models/scanet/trackerModel.py中的SCANet_network类被装饰器注册到TRACKER_REGISTRY中。

 模型基于 OSTrack 构建,目标是实现RGB-T(可见光 & 声呐)目标跟踪任务。


 SCANet_network类

构造函数

  • self.backbone:构建骨干网络,从输入的输入数据中提取特征,参数cfg.MODEL.BACKBONE.LOAD_MODE指定了加载模式

  • 加载预训练权重
  • 调用骨干网络的微调函数
  • 构建RGB和Sonar的预测头模块,头部模块的类型和参数由配置文件指定(cfg.MODEL.HEAD 和 cfg.MODEL.RGBS_HEAD

forward前向传播

  • 输入:模板帧(template frame)搜索帧(search frame)。模板帧用于定义要跟踪的目标,而搜索帧用于定位目标。
  • 骨干网络特征提取
  • 调用 forward_head 方法,将骨干网络的输出传递给头部模块,生成 RGB 和声呐的预测结果。

 forward_head预测头模块函数

  • 输入:骨干网络的输出 cat_feature,以及可选参数 gt_score_map
  • 特征切片
  • 特征变换
  • RGB 和声呐头部的前向传播

    • 将 RGB 和声呐特征分别传递给对应的头部模块 box_head 和 sonar_head
    • 生成的输出包含预测的边界框(pred_boxes)和其他特征图(如 score_mapsize_map)。
    • 上述的具体参数是来自配置文件cfg的,具体是通过找到 settings.cfg_file 中指定的 YAML 文件路径,查看 MODEL.HEADMODEL.RGBS_HEAD 的具体配置

ltr_trainer.py

路径:lib/train/trainers/ltr_trainer.py

在train_script.py文件中,数据集、模型等相关信息传入LTRTrainer类开始训练:

LTRTrainer类是继承BaseTrainer类的

在 LTRTrainer 类中,实际是将数据传入actor中执行训练过程中所需的前向传播、损失计算和数据处理等功能。

 


 rgbser_actor.py

路径:lib/train/actors/rgbser_actor.py

在上述ltr_trainer.py中,数据实际传入acotr类中实现训练过程

从数据加载器中读取到的输入传入actor,然后传入data2temp_search 就可以分出两个模态的template 和 search 的图片了,即template_img_1, template_img_2, search_img_1, search_img_2。

  • data: 一个字典,包含了图像数据,格式如下:
    • data['template_images']: 包含模板图像的张量,形状为 (batch_size, N_t, 3, H, W),其中 N_t 是模板图像的数量(在这个函数中假设为 2)。
    • data['search_images']: 包含搜索图像的张量,形状为 (batch_size, N_s, 3, H, W),其中 N_s 是搜索图像的数量(在这个函数中假设为 2)。

 cycle_dataset

数据加载和模型训练的主要功能在 cycle_dataset方法中实现。

输入是loader(数据加载器),负责提供训练或验证数据。


baseline.yaml

路径:experiments/scanet/baseline.yaml

这份 YAML 配置文件包含了模型架构、数据集、训练、验证和测试的全流程参数:

  1. 数据配置(DATA):定义了输入数据的模态、归一化参数,以及训练/验证数据集。
  2. 模型配置(MODEL):包括骨干网络、头部模块等配置,支持多模态(RGB + 声呐)。
  3. 训练配置(TRAIN):详细定义了优化器、损失函数、混合精度等训练参数。
  4. 测试配置(TEST):与训练类似,定义了搜索和模板区域的参数。

数据配置DATA

模态设置

  • RGB_MODE 和 RGBS_MODE:指示模型是否同时处理 RGB 和 RGBS(声呐)数据模态。
  • RGBS_ROTATE:是否对 RGBS 数据增加旋转增强。
  • MAX_SAMPLE_INTERVAL:最大采样间隔,用于决定从视频中采样帧的时间跨度。

数据归一化参数

  • MEAN 和 STD:输入图像的归一化参数,通常是 RGB 图像的标准均值和标准差,用于将像素值归一化到标准分布。

搜索区域 (SEARCH) 和模板区域 (TEMPLATE) 配置

  • 搜索区域(SEARCH)
    • 定义了模型在当前帧中搜索目标的参数。
    • 例如,FACTOR=4.0 表示搜索区域的宽度/高度是目标框的 4 倍。
    • SIZE=256 表示将搜索区域缩放到 256x256 的固定尺寸。
  • 模板区域(TEMPLATE)
    • 定义了模板帧(通常是目标的初始位置)的参数。
    • SIZE=128 表示将模板区域缩放到 128x128。
    • 通常模板帧会被固定且不抖动(CENTER_JITTER=0 和 SCALE_JITTER=0)。

数据集


模型配置 MODEL

网络整体信息 

  • NETWORK:模型的主网络名称,这里是 SCANet_network
  • RETURN_STAGES:骨干网络中需要返回的特征层索引(例如第 2、5、8 和 11 层)。

 骨干网络 (BACKBONE)


  • TYPE:骨干网络类型,这里采用 vit_base_patch16_224_midlayer(Vision Transformer)。
  • PARAMS
    • ffm: SCAM:特征融合方式为 SCAM。
    • rgbs_loc: [3, 6, 9]:指定在第 3、6、9 层使用 RGB-S 融合。

 头部模块

  • HEAD(RGB 头) 和 RGBS_HEAD(声呐头)
    • 类型为 center_head,参数为:
      • inplanes:输入通道数。
      • channel:中间特征通道数。
      • feat_sz:特征图大小。
      • stride:卷积步幅。


 训练配置 TRAIN

  • 优化目标
    • 使用 giou(广义 IoU)作为 IoU 损失。
    • GIOU_WEIGHT 和 L1_WEIGHT:分别为 GIoU 和 L1 损失的权重。
  • 优化器参数
    • LR:初始学习率为 0.00001
    • WEIGHT_DECAY:权重衰减值为 0.0001

 优化器和学习率调度器

  • 优化器:使用 ADAMW 优化器。
  • 学习率调度器
    • 使用 step 调度器,每隔 30 个 epoch 降低学习率。

AMP混合精度训练 

  • 是否启用 AMP(混合精度训练):USED: False
  • 梯度裁剪的最大范数:GRAD_CLIP_NORM: 0.1

测试配置 TEST

 

  • 测试的搜索区域和模板区域与训练保持一致,分别为 256x256 和 128x128

 vit_rgbs.py

路径:lib/models/scanet/vit_rgbs.py

在trackerModel.py中

 在YAML文件中,backbone的配置如下:

 backbone的type为vit_base_patch16_224_midlayer,在vit_rgbs.py文件中定义


vit_rgbs.py文件代码结构概览

  1. Attention 和 Block 模块

    • 实现了 Transformer 中的注意力机制(Attention)及基本的 Transformer 块(Block)。
  2. VisionTransformer_midlayer

    • 核心 Transformer 模型,支持 ViT 结构,包含可见光(RGB)和 T 模态的融合能力。
    • 支持冻结部分 Transformer 层(freeze_layer)和插入中间融合层(rgbs_layers)。
  3. 辅助函数

    • 包括权重初始化、加载预训练权重、调整位置嵌入(pos_embed)等。
  4. 模型注册

    • 定义了 vit_basevit_small, 和 vit_tiny 三种 ViT 模型,并通过注册表(MODEL_REGISTRY)动态加载。

 vit_base_patch16_224_midlayer --------> _create_vision_transformer

_create_vision_transformer ---------> VisionTransformer_midlayer


VisionTransformer_midlayer——重点

这部分是这个py文件的核心,它实现了标准的 Vision Transformer,同时支持红外(T)与 RGB 的融合。是BaseBackbone的子类。


模型结构

整体模型架构可以拆解为以下几个主要部分:

  1. 输入处理模块

    • 图像被切分为 Patch,并通过 PatchEmbed 层进行线性嵌入,生成低维特征。
    • 位置编码(pos_embed)被添加到特征中,使得 Transformer 感知输入的空间位置。
  2. Transformer 主干网络

    • 若干个堆叠的 Transformer Block(这里的block就是ViT中的基础block),每个 Block 包含:
      • 多头注意力机制(Attention)。
      • 前馈网络(MLP)。
    • 这些 Block 用于提取输入的高层次特征。
  3. 融合模块rgbs_layers):

    • 在指定的 Transformer Block 层(通过 rgbs_loc 参数定义)插入特征融合模块(ffm-SCAM)。
    • 融合模块用于交互和融合 RGB 和 T 模态的特征。
  4. 输出层

    • Transformer 的输出经过归一化(norm)后,分别恢复为 RGB 和 T 模态的各自特征。
    • 最终将两种模态的特征拼接,作为模型的输出。
a. Patch 嵌入

  • 将输入图像切分成固定大小的 Patch,并使用线性投影将每个 Patch 映射到 embed_dim 维度。
 b. 位置嵌入

  • 提供位置编码,使得 Transformer 能够感知输入的空间位置信息。
c. Transformer Block块

  • 核心组件是多个堆叠的 Block 模块,数量由 depth 参数指定。
 d. 融合层——SCAM的插入

  • 插入融合层( ffm——SCAM)以实现 RGB 和 T 特征的融合。
  • 通过 rgbs_loc 指定融合层插入的位置。

前向传播

a. 输入处理

可见光(x[0], z[0])和声纳(x[1], z[1])数据分别通过独立的Patch Embedding层。

详看base_backbone.py节解读。

b. 特征融合

  • 逐层通过 Transformer 块。
  • 在指定ViT层(如第3、6、9层)插入SCAM模块,实现特征对齐与融合。
 c. 模态恢复

  • 从融合后的特征中恢复 RGB 和 T 的独立特征。
  • 最后将两种模态的特征拼接在一起。


Block

这里的block如图所示,就是ViT的基础block。


 base_backbone.py

路径:lib/models/scanet/base_backbone.py

 上述表示模型结构的类  VisionTransformer_midlayer  继承了   BaseBackbone

BaseBackbone是基本ViT骨干网的联合特征提取和关系建模。他的输入可以理解为单模态的模板(template)----z,和搜索区域(search region)-----x;

VisionTransformer_midlayer继承后,输入的其实也是单序列,但可能在输入之前把两种模态拼接了,所以在传入函数后,又分别提取了x_v,x_s以及z_v,z_s两种模态的template和search:

原文:


scam.py

路径:lib/models/scanet/mid_rgbs_layer/scam.py

在vit_rgbs.py文件中,搭建了backbone结构,其中融合层ffm 在YAML文件中被定义为SCAM,其在scam.py文件中被定义。

在vit_rgbs.py文件中,对整个网络进行了搭建。

网络由block搭建,层数为depth,然后在指定的rgbs_loc的位置添加ffm——SCAM融合块

SCAM模块的作用是接收两个输入特征(x1x2),通过交叉注意力机制进行交互,然后在特征中加入残差连接和前馈网络(FFN)。

这里输入的两个特征分别是两个模态经过ViT Block后的两个输出。

SCAM模块的输出任然是两个模态的特征。

原论文对SCAM的解释:

SCAM旨在实现空间未对齐的RGB特征与声呐特征的有效跨模态交互。

SCAM由一个空间交叉注意力层和两个独立的全局整合模块(GIM)组成。我们的SCAM模块的操作如下所述:

  • SCAM的输入是Hr和Hs。
  • 首先,计算两个模态的QKV矩阵。RGB模态的查询矩阵表示为Qr = Hr ∈ RN ×C,其中N = Nz + Nx。
  • 键矩阵和值矩阵通过Kr, Vr = Split(Linear(Hr))获得,其中Linear表示没有偏置的线性层,它将特征从RC映射到R2C。
  • Split将特征从R2C映射到两个RC。声呐的Qs, Ks, Vs以相同的方式获得。


 center_head.py

路径:lib/models/scanet/head/center_head.py

在YAML配置文件中,Head类型为center_head,它的定义在center_head.py文件中。

  • conv 函数: 构建一个卷积层、BatchNorm 和 ReLU 激活函数的组合。
  • center_head 类
    • 继承自 torch.nn.Module,这是 PyTorch 中用于构建神经网络的基类。
    • 包含:
      1. 构造函数 (__init__):定义网络结构。
      2. forward 函数:定义前向传播逻辑。
      3. 分支模块
        • 中心点分支(预测目标中心点概率图)。
        • 偏移分支(预测目标中心点的偏移量)。
        • 尺寸分支(预测目标的宽度和高度)。
      4. 辅助函数
        • cal_bbox:根据预测结果计算目标边界框。
        • get_score_map:计算三个分支的输出(中心点、尺寸和偏移量)。

原文解释:

 


总结 

整体网络搭建是SCANet_network类,forward流程:template和search的输入----->backbone(输出一个拼接后的特征)------>Head

  • template 是指在给定的视频序列中,最初用于表示目标的图像。这通常是在第一帧或某一帧中提取的目标图像或区域,包含了目标的外观特征。---------用来做目标检测
  • search 是指在后续帧中寻找目标的图像。这通常是一个较大的区域,包含了潜在的目标位置,模型将在这个区域内进行搜索。-------用来做目标跟踪

表示template的输入是什么?两个模态的图像吗? 

答:是两个模态的图像,但是是被concat成了一个序列,x的第0维表示一个模态,x的第1维表示另一个模态。

  • 整个网络 输入是template(两个模态拼接)和search(两个模态拼接);
  • 进入网络后将两个模态提取出来,针对各自模态,将template和search进行拼接;即x_v表示光学template+search的内容;x_s表示声呐template+search内容;
  • 将两个模态(x_v和x_s)分别输入ViT Blocks,输出认识该模态的特征;
  • 将经过ViT Blocks的两个模态特征输入SCAM模块,融合后输出任然是两个模态(掺杂的另一个模态信息);
  • 经过多个blocks  和 SCAM 后,分别输入各自对应Head。


http://www.mrgr.cn/news/91691.html

相关文章:

  • 【Elasticsearch】`nested`和`flattened`字段在索引时有显著的区别
  • 【DeepSeek系列】04 DeepSeek-R1:带有冷启动的强化学习
  • TCP和Http协议
  • PyTorch 源码学习:阅读经验 代码结构
  • 嵌入式音视频开发(三)直播协议及编码器
  • 【Java】泛型与集合篇 —— Set 接口
  • 前端常见面试题-2025
  • C语言——时基
  • Linux-----进程(多任务)
  • C#发送邮件
  • 基于正则化密集连接金字塔网络的显著实例分割
  • mysql总结
  • Day6 25/2/19 WED
  • Windows 启动 SSH 服务报错 1067
  • Compose 常用UI组件
  • PVE使用一个物理网卡采用VLAN为管理IP和VM分配网络的问题
  • springboot-ffmpeg-m3u8-convertor nplayer视频播放弹幕 artplayer视频弹幕
  • 【SQL】多表查询案例
  • OpenResty
  • [数据结构]顺序表详解