当前位置: 首页 > news >正文

TransUNet 学习记录

研究背景

在医学图像分割领域中,CNN是主流,其中UNet十分常用。尽管基于CNN的方法有优秀的表征能力,对局部信息有很好的效果,但是其固有的局限性导致UNet在长距离上下文上的表现不佳

现在出现了使用Transformer进行分割,Transformer在全局上下文方面表现优异,但纯Transformer会丢失许多低级细节(例如形状与边界)

因此,本文提出将Transformer与CNN进行混合,并结合UNet的U型结构。结合二者的优点,让网络既可以获得低级特征,又可以获得全局上下文。

TransUNet具有强大的学习高级语义特征和低级细节的能力

网络

纯Transformer:Transformer Encoder+直接上采样

经过Patch Embedding与Transformer,得到的数据格式为 D ∗ H W / P 2 D*HW/P^2 DHW/P2,要将其变为 n u m _ c l a s s ∗ H ∗ W num\_class*H*W num_classHW,这里直接使用1x1的卷积改变通道数,然后直接使用双线性插值来进行上采样恢复尺寸。

这种做法的效果一般,在实验部分会有讲解。

TransUNet

网络结构如图所示
在这里插入图片描述
这里使用了CNN-Transformer的混合结构作为编码器与级联上采样器(cascaded upsampler ,CUP)来精确定位

CNN-Transformer的混合结构

不再使用纯Transformer作为编码器,而使用CNN-Transformer的混合结构作为编码器,其中CNN作为特征提取器,提取原图中的feature map。Patch Embedding的输入变为了CNN提取的feature map,而不是原图,patch的大小也变为了1x1。

这样设计的原因:1)可以在解码过程中使用CNN中间高分辨率的feature map,也就是可以跳层连接;2)混合结构的性能更好。

在代码中,作者使用了预训练的ResNet50V2作为特征提取器,而且将conv4_x与conv5_x融合为一层相当于conv4_x变为9个叠加,输出通道数为1024,尺寸为14x14。
其中跳层使用的卷积层为conv1/conv2_x/conv3_x这三层。
在这里插入图片描述
在这里插入图片描述

级联上采样器 CUP

模仿UNet的网络结构,进行逐层上采样与跳层连接。其中跳层连接时与编码器的CNN特征提取器进行连接。

在进行上采样之前,还需要对Transformer输出的数据的尺寸进行调整。通过reshape与一个1x1的卷积将尺寸调整为(512,H/16,W/16),方便后面上采样。
在这里插入图片描述
然后级联多个上采样模块来将数据恢复为(H,W),每个上采样模块依次由2x上采样模块、3x3卷积层与ReLU激活函数构成。其中上采样模块不会改变数据的通道数的。

而跳层连接中,使用了ResNet50V2中的conv1/conv2_x/conv3_x这三层输出的特征图来与上采样器中对应的层进行特征聚合,这样就可以包含CNN中间高分辨率的特征图,来实现精确定位,因为CNN会包含低级信息(细节)
在这里插入图片描述
最后使用分割头将通道数变为num_classes,得到结果。

实验

性能对比实验

在这里插入图片描述
DSC:Dice
hausdorff distance(HD):豪斯多夫距离

表格中不同数据之间的对比,证明了不同设计的优越性:

  1. ViT-Non与ViT-CUP对比:证明了CUP设计比直接上采样效果更好
  2. R50-ViT-CUP与ViT-CUP对比:证明CNN-Transformer混合结构比单纯的Transformer效果更好
  3. TransUNet与R50-ViT-CUP对比:证明跳层连接对性能是有提升的

而观察其他数据,也可以证明TransUNet是有改进的:
ViT-CUP相比于ViT-None是有提升,但性能还是不如V-Net与DARR。使用了混合结构后,性能就优于V-Net与DARR,但还是低于完全基于CNN的UNet与AttnUNet。最后加上跳层连接后,TransUNet达到了新的水平,表明TransUNet具有强大的学习高级语义特征和低级细节的能力

消融实验

跳层连接的数量

测试不同数量的跳层连接对性能的影响,个数分别有0/1/3个,其中1个是只使用1/4的跳层连接。

下图是不同数量的跳过连接下DSC的数值,可以看到基本上3个跳过连接是更优秀的。
在这里插入图片描述
论文中还提及了一种改进的方法,就是在跳过连接中加入Transformer,也就是将CNN的数据先经过一个Transformer再继续连接,作者实验在1/8的连接中加入一个轻量的Transformer,结果DSC提升了1.4%。不过作者的代码中没有这部分。
在这里插入图片描述

输入图像的分辨率

默认的输入分辨率为224x224,这里提供了512x512的尺寸。
在使用512x512时,patch大小还是16x16,导致Transformer的序列长度变长5倍。

结果如图所示,增加了图像分辨率后,DSC提高了6.88%。但作者也说了,代价是计算成本大很多,因此实验还是使用224x224的分辨率
在这里插入图片描述

Patch 大小

研究Patch Size对性能的影响,Sequence Length是Transformer序列长度,也就是patch的个数。Patch Size 与 Sequence Length 是成反比的

可以看到Patch Size越小,Sequence Length越大,性能越好,这是因为Transformer为更长的序列构建更加复杂的依赖关系
在这里插入图片描述

模型的规模

也就是Transformer的尺寸,有Base与Large两种:

  • Base:embed dim=768,层数=12,MLP大小=3072,num_head=12
  • Large:embed dim=1024,层数=24,MLP大小=4096,num_head=16

可以看到Large的性能更好。而作者考虑到计算成本,实验还是使用Base
在这里插入图片描述

可视化

在这里插入图片描述
分析图像可知,TransUNet的误检与错检的更少,而且预测的边缘更加准确。

这些观察结果也表明了TransUNet同时具有高级上下文与低级细节的优势。


http://www.mrgr.cn/news/62681.html

相关文章:

  • Jmeter如何进行多服务器远程测试
  • 计算机网络(五)——传输层
  • 大模型agent学习(day1)
  • Android Dex VMP 动态加载加密指令流
  • 网络传输层TCP协议
  • PHP 使用 Redis
  • 淘宝API接口(item_history_price- 淘宝商品历史价格信息查询)
  • idea git 设置Local Changes窗口
  • Python3 No module named ‘pymysql‘
  • SwiftUI(八)- 绑定对象与环境查询
  • vector的模拟实现
  • 【GO学习笔记 go基础】访问控制
  • 局域网实时监控电脑屏幕软件有哪些?8款优秀的局域网监控app!不看巨亏!
  • 使用Kubernetes自动化部署和管理容器化应用
  • 正则表达式(Regular Expressions)
  • zynq PS端跑Linux响应中断
  • 机器学习的模型评估与选择
  • Nodes —— Utility
  • 24下软考初级信息系统运行管理员,提供一条能过的野路子
  • 两数之和笔记
  • 通过js控制修改css变量
  • YOLOV8代码分析———持续更新中
  • LivePortrait代码调试—给图片实现动态表情
  • 2小时,我搭建了一整套车间生产看板
  • 做反向代购,采购订单应该怎么批量管理?
  • 一秒变高手!MODBUSTCP-DEVICENET网关与AB 1791D模块的完美搭档秘诀!