TransUNet 学习记录
研究背景
在医学图像分割领域中,CNN是主流,其中UNet十分常用。尽管基于CNN的方法有优秀的表征能力,对局部信息有很好的效果,但是其固有的局限性导致UNet在长距离上下文上的表现不佳。
现在出现了使用Transformer进行分割,Transformer在全局上下文方面表现优异,但纯Transformer会丢失许多低级细节(例如形状与边界)
因此,本文提出将Transformer与CNN进行混合,并结合UNet的U型结构。结合二者的优点,让网络既可以获得低级特征,又可以获得全局上下文。
TransUNet具有强大的学习高级语义特征和低级细节的能力
网络
纯Transformer:Transformer Encoder+直接上采样
经过Patch Embedding与Transformer,得到的数据格式为 D ∗ H W / P 2 D*HW/P^2 D∗HW/P2,要将其变为 n u m _ c l a s s ∗ H ∗ W num\_class*H*W num_class∗H∗W,这里直接使用1x1的卷积改变通道数,然后直接使用双线性插值来进行上采样恢复尺寸。
这种做法的效果一般,在实验部分会有讲解。
TransUNet
网络结构如图所示
这里使用了CNN-Transformer的混合结构作为编码器与级联上采样器(cascaded upsampler ,CUP)来精确定位
CNN-Transformer的混合结构
不再使用纯Transformer作为编码器,而使用CNN-Transformer的混合结构作为编码器,其中CNN作为特征提取器,提取原图中的feature map。Patch Embedding的输入变为了CNN提取的feature map,而不是原图,patch的大小也变为了1x1。
这样设计的原因:1)可以在解码过程中使用CNN中间高分辨率的feature map,也就是可以跳层连接;2)混合结构的性能更好。
在代码中,作者使用了预训练的ResNet50V2作为特征提取器,而且将conv4_x与conv5_x融合为一层,相当于conv4_x变为9个叠加,输出通道数为1024,尺寸为14x14。
其中跳层使用的卷积层为conv1/conv2_x/conv3_x这三层。
级联上采样器 CUP
模仿UNet的网络结构,进行逐层上采样与跳层连接。其中跳层连接时与编码器的CNN特征提取器进行连接。
在进行上采样之前,还需要对Transformer输出的数据的尺寸进行调整。通过reshape与一个1x1的卷积将尺寸调整为(512,H/16,W/16),方便后面上采样。
然后级联多个上采样模块来将数据恢复为(H,W),每个上采样模块依次由2x上采样模块、3x3卷积层与ReLU激活函数构成。其中上采样模块不会改变数据的通道数的。
而跳层连接中,使用了ResNet50V2中的conv1/conv2_x/conv3_x这三层输出的特征图来与上采样器中对应的层进行特征聚合,这样就可以包含CNN中间高分辨率的特征图,来实现精确定位,因为CNN会包含低级信息(细节)
最后使用分割头将通道数变为num_classes,得到结果。
实验
性能对比实验
DSC:Dice
hausdorff distance(HD):豪斯多夫距离
表格中不同数据之间的对比,证明了不同设计的优越性:
- ViT-Non与ViT-CUP对比:证明了CUP设计比直接上采样效果更好
- R50-ViT-CUP与ViT-CUP对比:证明CNN-Transformer混合结构比单纯的Transformer效果更好
- TransUNet与R50-ViT-CUP对比:证明跳层连接对性能是有提升的
而观察其他数据,也可以证明TransUNet是有改进的:
ViT-CUP相比于ViT-None是有提升,但性能还是不如V-Net与DARR。使用了混合结构后,性能就优于V-Net与DARR,但还是低于完全基于CNN的UNet与AttnUNet。最后加上跳层连接后,TransUNet达到了新的水平,表明TransUNet具有强大的学习高级语义特征和低级细节的能力
消融实验
跳层连接的数量
测试不同数量的跳层连接对性能的影响,个数分别有0/1/3个,其中1个是只使用1/4的跳层连接。
下图是不同数量的跳过连接下DSC的数值,可以看到基本上3个跳过连接是更优秀的。
论文中还提及了一种改进的方法,就是在跳过连接中加入Transformer,也就是将CNN的数据先经过一个Transformer再继续连接,作者实验在1/8的连接中加入一个轻量的Transformer,结果DSC提升了1.4%。不过作者的代码中没有这部分。
输入图像的分辨率
默认的输入分辨率为224x224,这里提供了512x512的尺寸。
在使用512x512时,patch大小还是16x16,导致Transformer的序列长度变长5倍。
结果如图所示,增加了图像分辨率后,DSC提高了6.88%。但作者也说了,代价是计算成本大很多,因此实验还是使用224x224的分辨率
Patch 大小
研究Patch Size对性能的影响,Sequence Length是Transformer序列长度,也就是patch的个数。Patch Size 与 Sequence Length 是成反比的。
可以看到Patch Size越小,Sequence Length越大,性能越好,这是因为Transformer为更长的序列构建更加复杂的依赖关系
模型的规模
也就是Transformer的尺寸,有Base与Large两种:
- Base:embed dim=768,层数=12,MLP大小=3072,num_head=12
- Large:embed dim=1024,层数=24,MLP大小=4096,num_head=16
可以看到Large的性能更好。而作者考虑到计算成本,实验还是使用Base
可视化
分析图像可知,TransUNet的误检与错检的更少,而且预测的边缘更加准确。
这些观察结果也表明了TransUNet同时具有高级上下文与低级细节的优势。