1. 网络概述
GCTx-UNET是基于传统UNet架构的改进版本,主要引入了以下几个关键创新:
-
GCT模块:全局上下文变换器(Global Context Transformer)模块
-
多尺度特征融合:增强的特征提取能力
-
改进的跳跃连接:优化编码器与解码器之间的信息传递
1. 网络架构组成
1.1 编码器部分(下采样路径)
编码器由多个阶段组成,每个阶段包含:
-
卷积块:通常由2-3个卷积层组成,每个卷积后接批量归一化和ReLU激活
-
GCT模块:全局上下文特征提取
-
下采样层:通常使用最大池化或步长卷积
与传统UNet相比,GCTx-UNet的编码器在每个阶段后加入了GCT模块,增强了全局上下文信息的捕获能力。
1.2 解码器部分(上采样路径)
解码器同样由多个阶段组成,每个阶段包含:
-
上采样层:通常使用转置卷积或插值方法
-
特征拼接:与对应编码器阶段的特征图拼接
-
卷积块:与编码器类似的卷积结构
-
GCT模块:增强解码过程中的全局信息利用
1.3 瓶颈层
位于编码器和解码器之间的过渡层,包含:
-
深度卷积层
-
GCT模块
-
常规卷积层
2. GCT模块详解
GCT(Global Context Transformer)模块是GCTx-UNet的核心创新,其结构如下:
2.1 主要组件
-
全局平均池化(GAP):将空间维度压缩为1×1,获取通道级全局信息
-
通道注意力机制:学习各通道的重要性权重
-
空间注意力机制:捕捉空间位置间的长程依赖关系
-
特征变换:通过1×1卷积调整特征维度
2.2 数学表达
给定输入特征图X∈R^(H×W×C),GCT模块的处理流程:
-
全局上下文提取:
z = GAP(X) ∈ R^(1×1×C)
-
通道注意力计算:
a_c = σ(W_c z + b_c) ∈ R^C
-
空间注意力计算:
a_s = σ(conv(W_s * X)) ∈ R^(H×W×1)
-
特征融合:
Y = X ⊗ a_c ⊗ a_s + X
其中σ表示sigmoid函数,⊗表示逐元素乘法。
3. 改进的跳跃连接
GCTx-UNet对传统UNet的跳跃连接进行了优化:
-
特征对齐:使用1×1卷积调整通道数
-
GCT增强:在拼接前对编码器特征应用GCT模块
-
注意力门控:可选地加入注意力机制,抑制不相关区域
GCTx-UNet通过巧妙结合CNN的局部特征提取能力和Transformer的全局建模优势,在医学图像分割领域取得了优异的性能表现。
2. OCT 图像中的黄斑水肿和裂孔分割
数据集总共有2500张图片和mask标签,按照7:3划分了训练集和验证集
数据集自动预处理结果:
3. 训练参数
参数如下:
- ct 如果数据集是医学影像的ct格式,那么会自动进行windowing增强,效果会更好
- model 可以选择unet、unet++、unet3+三种模型
- base size是预处理的图像大小,也就是输入model的图像尺寸
- batch size 批大小,epoch 训练轮次,lr 是cos衰减的起始值,lrf是衰减倍率,这里lr最终大小是0.01 * 0.001
- results 是训练生成的主目录
parser.add_argument("--ct", default=False,type=bool,help='is CT?') # Ct --> True
parser.add_argument("--model", default='gctx_unet',help='gctx_unet')
parser.add_argument("--base-size",default=(224,224),type=tuple) # 根据图像大小更改
parser.add_argument("--batch-size", default=16, type=int)
parser.add_argument("--epochs", default=300, type=int)
parser.add_argument('--lr', default=0.002, type=float)
parser.add_argument('--lrf',default=0.002,type=float) # 最终学习率 = lr * lrf
parser.add_argument("--img_f", default='.jpg', type=str) # 数据图像的后缀
parser.add_argument("--mask_f", default='.png', type=str) # mask图像的后缀
parser.add_argument("--results", default='runs', type=str) # 保存目录
parser.add_argument("--data-train", default='./data/train/images', type=str)
parser.add_argument("--data-val", default='./data/val/images', type=str)
4. 训练结果
如下:这里以unet为例
这里新添了网络参数个数、flops
"train parameters": {
"model": "gctx_unet",
"input size": [
224,
224
],
"batch size": 16,
"lr": 0.002,
"lrf": 0.002,
"ct": false,
"epochs": 300,
"num classes": 3
},
"model": {
"total parameters": 14048512.0,
"flops": 4427196736.0
曲线如下:
最后一个epoch的指标:
"epoch:299": {
"train log:": {
"info": {
"pixel accuracy": [
0.9974969029426575
],
"Precision": [
"0.9147",
"0.9549"
],
"Recall": [
"0.8954",
"0.9517"
],
"F1 score": [
"0.9050",
"0.9533"
],
"Dice": [
"0.9050",
"0.9533"
],
"IoU": [
"0.8264",
"0.9107"
],
"mean precision": 0.9347942471504211,
"mean recall": 0.9235433340072632,
"mean f1 score": 0.9291158318519592,
"mean dice": 0.929115891456604,
"mean iou": 0.868566632270813
}
},
"val log:": {
"info": {
"pixel accuracy": [
0.9966462850570679
],
"Precision": [
"0.8974",
"0.9246"
],
"Recall": [
"0.8812",
"0.9088"
],
"F1 score": [
"0.8893",
"0.9166"
],
"Dice": [
"0.8893",
"0.9166"
],
"IoU": [
"0.8006",
"0.8461"
],
"mean precision": 0.9109943509101868,
"mean recall": 0.8950419425964355,
"mean f1 score": 0.9029476642608643,
"mean dice": 0.9029476642608643,
"mean iou": 0.8233512043952942
}
5. 推理
把想要推理的图片直接放在inference/img下,运行predict脚本即可
6. 下载
下载地址:基于GlobalContextVisionTransformers-UNet(GCViT)模型实现的oct黄斑水肿和裂孔完整分割项目、代码、数据、和训练结果、推理结果等资源-CSDN文库
关于图像分类、分割网络的改进:图像分类网络改进_听风吹等浪起的博客-CSDN博客