第18章:基于Global Context Vision Transformers(GCTx_unet)网络实现的oct图像中的黄斑水肿和裂孔分割

1. 网络概述

GCTx-UNET是基于传统UNet架构的改进版本,主要引入了以下几个关键创新:

  • GCT模块:全局上下文变换器(Global Context Transformer)模块

  • 多尺度特征融合:增强的特征提取能力

  • 改进的跳跃连接:优化编码器与解码器之间的信息传递

1. 网络架构组成

1.1 编码器部分(下采样路径)

编码器由多个阶段组成,每个阶段包含:

  1. 卷积块:通常由2-3个卷积层组成,每个卷积后接批量归一化和ReLU激活

  2. GCT模块:全局上下文特征提取

  3. 下采样层:通常使用最大池化或步长卷积

与传统UNet相比,GCTx-UNet的编码器在每个阶段后加入了GCT模块,增强了全局上下文信息的捕获能力。

1.2 解码器部分(上采样路径)

解码器同样由多个阶段组成,每个阶段包含:

  1. 上采样层:通常使用转置卷积或插值方法

  2. 特征拼接:与对应编码器阶段的特征图拼接

  3. 卷积块:与编码器类似的卷积结构

  4. GCT模块:增强解码过程中的全局信息利用

1.3 瓶颈层

位于编码器和解码器之间的过渡层,包含:

  • 深度卷积层

  • GCT模块

  • 常规卷积层

2. GCT模块详解

GCT(Global Context Transformer)模块是GCTx-UNet的核心创新,其结构如下:

2.1 主要组件

  1. 全局平均池化(GAP):将空间维度压缩为1×1,获取通道级全局信息

  2. 通道注意力机制:学习各通道的重要性权重

  3. 空间注意力机制:捕捉空间位置间的长程依赖关系

  4. 特征变换:通过1×1卷积调整特征维度

2.2 数学表达

给定输入特征图X∈R^(H×W×C),GCT模块的处理流程:

  1. 全局上下文提取:

    z = GAP(X) ∈ R^(1×1×C)
  2. 通道注意力计算:

    a_c = σ(W_c z + b_c) ∈ R^C
  3. 空间注意力计算:

    a_s = σ(conv(W_s * X)) ∈ R^(H×W×1)
  4. 特征融合:

    Y = X ⊗ a_c ⊗ a_s + X

其中σ表示sigmoid函数,⊗表示逐元素乘法。

3. 改进的跳跃连接

GCTx-UNet对传统UNet的跳跃连接进行了优化:

  1. 特征对齐:使用1×1卷积调整通道数

  2. GCT增强:在拼接前对编码器特征应用GCT模块

  3. 注意力门控:可选地加入注意力机制,抑制不相关区域

GCTx-UNet通过巧妙结合CNN的局部特征提取能力和Transformer的全局建模优势,在医学图像分割领域取得了优异的性能表现。

2. OCT 图像中的黄斑水肿和裂孔分割

数据集总共有2500张图片和mask标签,按照7:3划分了训练集和验证集

数据集自动预处理结果:

3. 训练参数

参数如下:

  • ct 如果数据集是医学影像的ct格式,那么会自动进行windowing增强,效果会更好
  • model 可以选择unet、unet++、unet3+三种模型
  • base size是预处理的图像大小,也就是输入model的图像尺寸
  • batch size 批大小,epoch 训练轮次,lr 是cos衰减的起始值,lrf是衰减倍率,这里lr最终大小是0.01 * 0.001 
  • results 是训练生成的主目录
    parser.add_argument("--ct", default=False,type=bool,help='is CT?')    # Ct --> True
    parser.add_argument("--model", default='gctx_unet',help='gctx_unet')

    parser.add_argument("--base-size",default=(224,224),type=tuple)         # 根据图像大小更改

    parser.add_argument("--batch-size", default=16, type=int)
    parser.add_argument("--epochs", default=300, type=int)

    parser.add_argument('--lr', default=0.002, type=float)
    parser.add_argument('--lrf',default=0.002,type=float)                  # 最终学习率 = lr * lrf

    parser.add_argument("--img_f", default='.jpg', type=str)               # 数据图像的后缀
    parser.add_argument("--mask_f", default='.png', type=str)              # mask图像的后缀

    parser.add_argument("--results", default='runs', type=str)                  # 保存目录
    parser.add_argument("--data-train", default='./data/train/images', type=str)
    parser.add_argument("--data-val", default='./data/val/images', type=str)

4. 训练结果

如下:这里以unet为例
这里新添了网络参数个数、flops

    "train parameters": {
        "model": "gctx_unet",
        "input size": [
            224,
            224
        ],
        "batch size": 16,
        "lr": 0.002,
        "lrf": 0.002,
        "ct": false,
        "epochs": 300,
        "num classes": 3
    },
    "model": {
        "total parameters": 14048512.0,
        "flops": 4427196736.0

曲线如下:

最后一个epoch的指标:

    "epoch:299": {
        "train log:": {
            "info": {
                "pixel accuracy": [
                    0.9974969029426575
                ],
                "Precision": [
                    "0.9147",
                    "0.9549"
                ],
                "Recall": [
                    "0.8954",
                    "0.9517"
                ],
                "F1 score": [
                    "0.9050",
                    "0.9533"
                ],
                "Dice": [
                    "0.9050",
                    "0.9533"
                ],
                "IoU": [
                    "0.8264",
                    "0.9107"
                ],
                "mean precision": 0.9347942471504211,
                "mean recall": 0.9235433340072632,
                "mean f1 score": 0.9291158318519592,
                "mean dice": 0.929115891456604,
                "mean iou": 0.868566632270813
            }
        },
        "val log:": {
            "info": {
                "pixel accuracy": [
                    0.9966462850570679
                ],
                "Precision": [
                    "0.8974",
                    "0.9246"
                ],
                "Recall": [
                    "0.8812",
                    "0.9088"
                ],
                "F1 score": [
                    "0.8893",
                    "0.9166"
                ],
                "Dice": [
                    "0.8893",
                    "0.9166"
                ],
                "IoU": [
                    "0.8006",
                    "0.8461"
                ],
                "mean precision": 0.9109943509101868,
                "mean recall": 0.8950419425964355,
                "mean f1 score": 0.9029476642608643,
                "mean dice": 0.9029476642608643,
                "mean iou": 0.8233512043952942
            }

5. 推理

把想要推理的图片直接放在inference/img下,运行predict脚本即可

6. 下载

下载地址:基于GlobalContextVisionTransformers-UNet(GCViT)模型实现的oct黄斑水肿和裂孔完整分割项目、代码、数据、和训练结果、推理结果等资源-CSDN文库

关于图像分类、分割网络的改进:图像分类网络改进_听风吹等浪起的博客-CSDN博客 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

听风吹等浪起

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值