MedCLIP-SAMv2 实验计划
1. 模型搭建
1.1 下游SAM模型架构
SAM模型将接收从BiomedCLIP生成的显著性图作为输入,通过点提示(Point Prompts)和框提示(Box Prompts)生成精确的分割掩码。需要完成以下工作:
-
BiomedCLIP模型接口
- 确保微调后的BiomedCLIP模型能够正确输出显著性图
- 实现有效的模型检查点加载机制
-
SAM模型配置
- 使用预训练的SAM模型(ViT-H)
- 实现自定义的Prompt生成策略
- 修改SAM预测器以适应医学图像特点
-
后处理流程
- 实现多种后处理算法,包括K-means、CRF和形态学操作
- 设计投票机制整合多次预测结果
1.2 代码实现
创建一个整合脚本,将微调的BiomedCLIP、后处理和SAM分割连接起来:
# integration.py
import torch
import cv2
import numpy as np
from transformers import AutoModel, AutoProcessor, AutoTokenizer
from segment_anything import sam_model_registry, SamPredictor
# 1. 加载已微调的BiomedCLIP模型
model = AutoModel.from_pretrained("./model", trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained("chuhac/BiomedCLIP-vit-bert-hf", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("chuhac/BiomedCLIP-vit-bert-hf", trust_remote_code=True)
# 2. 加载SAM模型
sam = sam_model_registry["vit_h"](checkpoint="segment-anything/sam_checkpoints/sam_vit_h_4b8939.pth")
sam.to(device)
predictor = SamPredictor(sam)
# 3. 自定义模型推理流程
def segment_with_text(image_path, text_prompt, post_process="kmeans"):
# 图像预处理
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# 文本处理
text_ids = torch.tensor([tokenizer.encode(text_prompt, add_special_tokens=True)]).to(device)
image_feat = processor(images=image, return_tensors="pt")['pixel_values'].to(device)
# 生成显著性图
vmap = vision_heatmap_iba(text_ids, image_feat, model, vlayer=9, vbeta=1.0, vvar=1.0)
# 应用后处理
if post_process == "kmeans":
processed_map = apply_kmeans(vmap)
elif post_process == "crf":
processed_map = apply_crf(vmap, image)
else:
processed_map = vmap > 0.5
# 获取SAM提示
points, point_labels, boxes = get_prompts(processed_map)
# 使用SAM生成分割掩码
predictor.set_image(image)
masks, _, _ = predictor.predict(
point_coords=points,
point_labels=point_labels,
box=boxes,
multimask_output=False
)
return masks[0]
2. 比较实验
2.1 基线模型选择
与以下模型进行比较:
- 原始SAM(没有BiomedCLIP的引导)
- 使用原始CLIP (OpenAI)的MedCLIP-SAM
- 未微调的BiomedCLIP-SAM
- 最新的医学图像分割模型(如nnU-Net)
2.2 数据集选择
在以下医学影像数据集上进行评估:
- 乳腺超声(BUSI数据集)
- 脑肿瘤MRI(Brain Tumor数据集)
- 肺部X光(COVID-QU-Ex数据集)
- 肺部CT(Lung CT数据集)
2.3 评估指标
使用以下指标进行评估:
- Dice系数(DSC)- 评估区域重叠
- 归一化表面距离(NSD)- 评估边界准确性
- 准确率、召回率、精确率 - 评估分割质量
- 可视化结果对比
2.4 实验脚本
创建一个比较实验脚本:
#!/bin/bash
# compare_models.sh
# 数据集路径
DATASETS=("data/breast_tumors" "data/brain_tumors" "data/lung_xray" "data/lung_ct")
MODELS=("original_sam" "original_clip_sam" "biomedclip_sam_not_finetuned" "biomedclip_sam_finetuned" "nnunet")
for DATASET in "${DATASETS[@]}"; do
for MODEL in "${MODELS[@]}"; do
echo "Running model $MODEL on dataset $DATASET"
# 根据模型类型选择不同的命令
if [ "$MODEL" == "original_sam" ]; then
python segment-anything/segment_image.py --input ${DATASET}/images --output results/${MODEL}/${DATASET}
elif [ "$MODEL" == "original_clip_sam" ]; then
python saliency_maps/generate_saliency_maps.py --model-name CLIP --input-path ${DATASET}/images --output-path saliency_outputs/${MODEL}/${DATASET}
python postprocessing/postprocess_saliency_maps.py --input-path ${DATASET}/images --output-path coarse_outputs/${MODEL}/${DATASET} --sal-path saliency_outputs/${MODEL}/${DATASET} --postprocess kmeans
python segment-anything/prompt_sam.py --input ${DATASET}/images --mask-input coarse_outputs/${MODEL}/${DATASET} --output results/${MODEL}/${DATASET} --model-type vit_h --checkpoint segment-anything/sam_checkpoints/sam_vit_h_4b8939.pth --prompts boxes
elif [ "$MODEL" == "biomedclip_sam_not_finetuned" ]; then
python saliency_maps/generate_saliency_maps.py --model-name BiomedCLIP --finetuned false --input-path ${DATASET}/images --output-path saliency_outputs/${MODEL}/${DATASET}
python postprocessing/postprocess_saliency_maps.py --input-path ${DATASET}/images --output-path coarse_outputs/${MODEL}/${DATASET} --sal-path saliency_outputs/${MODEL}/${DATASET} --postprocess kmeans
python segment-anything/prompt_sam.py --input ${DATASET}/images --mask-input coarse_outputs/${MODEL}/${DATASET} --output results/${MODEL}/${DATASET} --model-type vit_h --checkpoint segment-anything/sam_checkpoints/sam_vit_h_4b8939.pth --prompts boxes
elif [ "$MODEL" == "biomedclip_sam_finetuned" ]; then
python saliency_maps/generate_saliency_maps.py --model-name BiomedCLIP --finetuned true --input-path ${DATASET}/images --output-path saliency_outputs/${MODEL}/${DATASET}
python postprocessing/postprocess_saliency_maps.py --input-path ${DATASET}/images --output-path coarse_outputs/${MODEL}/${DATASET} --sal-path saliency_outputs/${MODEL}/${DATASET} --postprocess kmeans
python segment-anything/prompt_sam.py --input ${DATASET}/images --mask-input coarse_outputs/${MODEL}/${DATASET} --output results/${MODEL}/${DATASET} --model-type vit_h --checkpoint segment-anything/sam_checkpoints/sam_vit_h_4b8939.pth --prompts boxes
elif [ "$MODEL" == "nnunet" ]; then
cd weak_segmentation
python -m nnunetv2.inference.predict_from_raw_data -i ${DATASET}/images -o results/${MODEL}/${DATASET} -d DATASET_ID -c 2d
cd ..
fi
# 评估结果
python evaluation/eval.py --gt_path ${DATASET}/test_masks --seg_path results/${MODEL}/${DATASET}
done
done
3. 消融实验
3.1 实验设计
消融实验将验证各组件对整体性能的贡献:
-
文本提示变体
- 测试不同的提示模板对分割性能的影响
- 简单vs复杂提示
- 一般vs特定疾病提示
-
BiomedCLIP层选择
- 测试不同的中间层作为特征提取源
- 测试不同超参数(vbeta, vvar)的影响
-
后处理方法
- 比较不同后处理算法的效果:Kmeans vs CRF vs 阈值法
- 测试多次后处理的组合效果
-
SAM提示类型
- 点提示vs框提示vs两者结合
- 测试点的数量对结果的影响
- 测试正负点提示的影响
3.2 实验脚本
创建消融实验脚本:
#!/bin/bash
# ablation_study.sh
# 测试文本提示变体
echo "Testing different text prompts"
PROMPTS=("breast_tumor_P2_prompts" "benign_breast_tumor_P3_prompts" "malignant_breast_tumor_P3_prompts")
for PROMPT in "${PROMPTS[@]}"; do
python integration.py --dataset data/breast_tumors --text_prompt_set $PROMPT --output ablation/text_prompts/$PROMPT
python evaluation/eval.py --gt_path data/breast_tumors/test_masks --seg_path ablation/text_prompts/$PROMPT
done
# 测试不同的BiomedCLIP层和超参数
echo "Testing different BiomedCLIP layers and hyperparameters"
LAYERS=(7 8 9)
VBETAS=(0.1 1.0 2.0)
VVARS=(0.1 1.0 2.0)
for LAYER in "${LAYERS[@]}"; do
for VBETA in "${VBETAS[@]}"; do
for VVAR in "${VVARS[@]}"; do
OUT_DIR="ablation/clip_params/layer${LAYER}_beta${VBETA}_var${VVAR}"
python integration.py --dataset data/breast_tumors --vlayer $LAYER --vbeta $VBETA --vvar $VVAR --output $OUT_DIR
python evaluation/eval.py --gt_path data/breast_tumors/test_masks --seg_path $OUT_DIR
done
done
done
# 测试不同后处理方法
echo "Testing different postprocessing methods"
METHODS=("kmeans" "crf" "threshold" "morphology")
for METHOD in "${METHODS[@]}"; do
python integration.py --dataset data/breast_tumors --post_process $METHOD --output ablation/postprocess/$METHOD
python evaluation/eval.py --gt_path data/breast_tumors/test_masks --seg_path ablation/postprocess/$METHOD
done
# 测试SAM提示类型和参数
echo "Testing different SAM prompt types"
PROMPT_TYPES=("points" "boxes" "both")
POINT_COUNTS=(5 10 20)
for TYPE in "${PROMPT_TYPES[@]}"; do
for COUNT in "${POINT_COUNTS[@]}"; do
OUT_DIR="ablation/sam_prompts/${TYPE}_${COUNT}"
python integration.py --dataset data/breast_tumors --prompt_type $TYPE --num_points $COUNT --output $OUT_DIR
python evaluation/eval.py --gt_path data/breast_tumors/test_masks --seg_path $OUT_DIR
done
done
4. 结果分析与可视化
4.1 定量分析
- 创建表格和图表比较不同模型和设置的性能
- 进行统计显著性测试,验证改进是否显著
- 分析不同医学模态上的表现差异
4.2 定性分析
- 生成分割结果的可视化对比图
- 显示成功案例和失败案例
- 分析边界准确性和小结构保留情况
4.3 可视化工具
创建可视化脚本:
# visualize_results.py
import matplotlib.pyplot as plt
import numpy as np
import cv2
import os
import pandas as pd
import seaborn as sns
def plot_segmentation_results(image_path, gt_path, pred_paths, model_names, save_path):
# 加载图像和标签
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
gt = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
# 设置图表
n_models = len(model_names)
plt.figure(figsize=(15, 8))
# 原始图像
plt.subplot(2, n_models+1, 1)
plt.imshow(image)
plt.title('Original Image')
plt.axis('off')
# 真实标签
plt.subplot(2, n_models+1, n_models+2)
plt.imshow(image)
mask = np.ma.masked_where(gt == 0, gt)
plt.imshow(mask, alpha=0.5, cmap='jet')
plt.title('Ground Truth')
plt.axis('off')
# 各模型预测结果
for i, (pred_path, model_name) in enumerate(zip(pred_paths, model_names)):
pred = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
plt.subplot(2, n_models+1, i+2)
plt.imshow(image)
mask = np.ma.masked_where(pred == 0, pred)
plt.imshow(mask, alpha=0.5, cmap='jet')
plt.title(model_name)
plt.axis('off')
# 计算和显示Dice系数
dice = np.sum(2 * (pred & gt)) / (np.sum(pred) + np.sum(gt))
plt.subplot(2, n_models+1, i+n_models+3)
plt.imshow(np.abs(pred.astype(float) - gt.astype(float)), cmap='hot')
plt.title(f'Error Map - Dice: {dice:.4f}')
plt.axis('off')
plt.tight_layout()
plt.savefig(save_path)
plt.close()
# 绘制比较实验结果
def plot_comparison_results(results_csv, save_path):
results = pd.read_csv(results_csv)
plt.figure(figsize=(12, 6))
sns.barplot(x='Model', y='DSC', data=results)
plt.title('Dice Coefficient Comparison Across Models')
plt.ylim(0, 1)
plt.savefig(save_path + '/dsc_comparison.png')
plt.figure(figsize=(12, 6))
sns.barplot(x='Model', y='NSD', data=results)
plt.title('Normalized Surface Distance Comparison Across Models')
plt.ylim(0, 1)
plt.savefig(save_path + '/nsd_comparison.png')
5. 实施时间表
阶段 | 任务 | 预计时间 |
---|---|---|
1 | SAM模型搭建和集成 | 1周 |
2 | 基线模型准备 | 3天 |
3 | 比较实验执行 | 1周 |
4 | 消融实验执行 | 1周 |
5 | 结果分析与可视化 | 3天 |
6 | 报告撰写与总结 | 2天 |
6. 潜在问题与解决方案
-
计算资源限制
- 解决方案:使用较小的SAM模型变体(vit_b),批量处理数据,利用预计算结果
-
标签质量问题
- 解决方案:实施数据清洗步骤,排除低质量样本
-
模型集成问题
- 解决方案:详细记录中间结果,确保每个组件单独工作正常
-
超参数调优
- 解决方案:使用网格搜索或贝叶斯优化自动寻找最佳参数