1. Image Segmentation
1) Definition: create a pixel-level mask for every object in an image. The goal is to identify the location and shape of the different objects by classifying each pixel.
2) Metrics (a combined PyTorch sketch of all three follows this list):
(1) Binary cross-entropy (BCE): the cross-entropy loss for binary classification.
- In image segmentation, every pixel is treated as an independent binary classification problem (object class vs. background class).
- BCE measures the discrepancy between the predicted probabilities and the ground-truth labels, and training minimizes this discrepancy.
(2) Dice coefficient: a common measure of the overlap between the prediction and the ground truth.
- Formula: Dice = 2 × (area of overlap between prediction and ground truth) / (total area of the prediction plus the ground truth).
- Ranges from 0 to 1, where 1 indicates perfect, complete overlap.
(3) Intersection over Union (IoU): a simple overlap metric.
- Formula: IoU = (area of overlap between prediction and ground truth) / (area of their union).
- Ranges from 0 to 1, where 0 means no overlap and 1 means the prediction and the ground truth overlap completely.
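As a minimal sketch of these three metrics in PyTorch (the function names dice_coeff and iou_score and the smoothing constant eps are illustrative assumptions, not part of the original):

import torch
import torch.nn.functional as F

def dice_coeff(logits, targets, eps=1e-6):
    # Dice = 2 * |intersection| / (|prediction| + |target|), computed on probabilities
    probs = torch.sigmoid(logits)
    inter = (probs * targets).sum()
    return (2 * inter + eps) / (probs.sum() + targets.sum() + eps)

def iou_score(logits, targets, eps=1e-6):
    # IoU = |intersection| / |union|
    probs = torch.sigmoid(logits)
    inter = (probs * targets).sum()
    union = probs.sum() + targets.sum() - inter
    return (inter + eps) / (union + eps)

logits = torch.randn(2, 1, 160, 160)                      # stand-in raw model outputs
targets = torch.randint(0, 2, (2, 1, 160, 160)).float()   # stand-in binary ground-truth masks
bce = F.binary_cross_entropy_with_logits(logits, targets)
print(bce.item(), dice_coeff(logits, targets).item(), iou_score(logits, targets).item())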
2. UNet Architecture Implementation
U-Net is widely used for image segmentation; it was originally proposed for biomedical images but applies equally well to natural images. The implementation below follows the classic encoder-decoder layout: four DoubleConv stages with 2x2 max pooling on the way down, a bottleneck, and four decoder stages that upsample with transposed convolutions and concatenate the matching encoder feature map via a skip connection.
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
class DoubleConv(nn.Module):
    # Two (3x3 Conv -> BatchNorm -> ReLU) blocks; padding=1 preserves spatial size
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),  # bias is redundant before BatchNorm
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)
class UNet(nn.Module):
    # Defaults in_channels=3, out_channels=1 are used unless overridden
    # by arguments passed when the model is instantiated
    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
        super(UNet, self).__init__()
        # Decoder (upsampling path)
        self.ups = nn.ModuleList()
        # Encoder (downsampling path)
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Encoder: top to bottom
        for feature in features:
            # channel progression: 3->64, 64->128, 128->256, 256->512
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature
        # Decoder: bottom to top
        for feature in reversed(features):
            # upsample 1024->512, concatenate the skip, DoubleConv 1024->512, ... down to 64
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.ups.append(DoubleConv(feature * 2, feature))
        # Bottleneck at the very bottom: 512 -> 1024
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        # Final 1x1 convolution: 64 -> out_channels
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
    def forward(self, x):
        skip_connections = []
        # Encoder: top to bottom
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)
        # Bottom of the network
        x = self.bottleneck(x)
        # Reverse skip_connections so the deepest feature map comes first
        # ([::-1] walks the list from back to front with step -1)
        skip_connections = skip_connections[::-1]
        # Decoder: bottom to top
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]
            if x.shape != skip_connection.shape:
                # shape is [batch_size, channels, height, width];
                # [2:] keeps only the spatial part (height, width).
                # TF.resize supports several interpolation modes; bilinear is the default
                x = TF.resize(x, size=skip_connection.shape[2:])
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)
        return self.final_conv(x)
def test():
    x = torch.randn((3, 1, 160, 160))  # batch of 3 single-channel 160x160 images
    model = UNet(in_channels=1, out_channels=1)
    preds = model(x)
    print(preds.shape)
    print(x.shape)
    # The output spatial size should match the input
    assert preds.shape == x.shape
if __name__ == "__main__":
test()
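To tie the metrics in section 1 to the model above, here is a minimal training-step sketch; the optimizer settings and the stand-in tensors (images, masks) are illustrative assumptions, not part of the original:

model = UNet(in_channels=1, out_channels=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# BCEWithLogitsLoss applies the sigmoid internally, so the model outputs raw logits
loss_fn = nn.BCEWithLogitsLoss()

images = torch.randn(4, 1, 160, 160)                     # stand-in input batch
masks = torch.randint(0, 2, (4, 1, 160, 160)).float()    # stand-in binary masks

optimizer.zero_grad()
logits = model(images)
loss = loss_fn(logits, masks)
loss.backward()
optimizer.step()
print(loss.item())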
3. AttentionUNet Architecture Implementation
In plain terms: the network is the same as U-Net except that, before the encoder (skip) feature map is concatenated with the upsampled decoder feature map, both are passed through an Attention Gate; the gate's weighted output replaces the raw skip connection in the concatenation, and every other part of the architecture is unchanged.
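Concretely, the gate computes attention coefficients alpha from the gating signal g (the upsampled decoder feature) and the skip feature x, matching the AttentionGate module below:

alpha = Sigmoid(psi(ReLU(W_g(g) + W_x(x)))),    output = x * alpha

where W_g, W_x, and psi are 1x1 convolutions (each followed by BatchNorm) and * denotes element-wise multiplication.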
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
class DoubleConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(DoubleConv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.conv(x)
class AttentionGate(nn.Module):
    """
    F_g: number of channels in the upsampling-path (gating) feature map
    F_l: number of channels in the skip-connection feature map
    F_int: number of channels used internally by the attention mechanism
    """
    def __init__(self, F_g, F_l, F_int):
        super(AttentionGate, self).__init__()
        # W_g: 1x1 convolution applied to the gating feature map g from the upsampling path
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        # W_x: 1x1 convolution applied to the feature map x from the skip connection
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        # psi: 1x1 convolution that produces the attention weights,
        # followed by BatchNorm and a Sigmoid activation
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        # ReLU used as the non-linearity after computing W_g(g) + W_x(x)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, g, x):
        """
        g: feature map from the upsampling path (gating signal)
        x: feature map from the skip connection
        return: skip feature map weighted by the attention coefficients
        """
        # Apply 1x1 conv + BatchNorm to the gating signal
        g1 = self.W_g(g)
        # Apply 1x1 conv + BatchNorm to the skip-connection feature map
        x1 = self.W_x(x)
        # If the spatial sizes differ, resize g1 to match x1.
        # [2:] extracts the last two entries of the shape, i.e. (H, W);
        # antialias=False disables anti-aliasing during the resize
        if g1.shape[2:] != x1.shape[2:]:
            g1 = TF.resize(g1, size=x1.shape[2:], antialias=False)
        # Add the two feature maps and apply ReLU
        psi = self.relu(g1 + x1)
        # Generate the attention weights
        psi = self.psi(psi)
        # Resize psi if its spatial size does not match x
        if psi.shape[2:] != x.shape[2:]:
            psi = TF.resize(psi, size=x.shape[2:], antialias=False)
        # Return the skip feature map weighted by the attention coefficients
        # (* is element-wise multiplication, broadcast across channels)
        return x * psi
class AttentionUNet(nn.Module):
def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
super(AttentionUNet, self).__init__()
self.ups = nn.ModuleList()
self.downs = nn.ModuleList()
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.attentions = nn.ModuleList()
for feature in features:
self.downs.append(DoubleConv(in_channels, feature))
in_channels = feature
self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
for feature in reversed(features):
self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
self.ups.append(DoubleConv(feature * 2, feature))
            # Attention gate applied to the skip feature map before concatenation,
            # so both inputs carry `feature` channels
            self.attentions.append(AttentionGate(feature, feature, feature // 2))
self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
def forward(self, x):
skip_connections = []
for down in self.downs:
x = down(x)
skip_connections.append(x)
x = self.pool(x)
x = self.bottleneck(x)
skip_connections = skip_connections[::-1]
for idx in range(0, len(self.ups), 2):
x = self.ups[idx](x)
            # idx // 2 is integer division; it maps each (upconv, DoubleConv) pair to its skip connection
skip_connection = skip_connections[idx // 2]
if x.shape != skip_connection.shape:
x = TF.resize(x, size=skip_connection.shape[2:])
            # Apply the attention gate before concatenation:
            # g is the upsampled decoder feature, x is the skip-connection feature
skip_connection = self.attentions[idx // 2](g=x, x=skip_connection)
concat_skip = torch.cat((skip_connection, x), dim=1)
x = self.ups[idx + 1](concat_skip)
return self.final_conv(x)
if __name__ == "__main__":
    x = torch.randn((3, 1, 160, 160))  # batch of 3 single-channel 160x160 images
    model = AttentionUNet(in_channels=1, out_channels=1)
    preds = model(x)
    print(preds.shape)
    print(x.shape)
    # AttentionUNet preserves the input spatial size
    assert preds.shape == x.shape
# model = AttentionUNet(in_channels=3, out_channels=1)
# print(model)
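As a quick standalone sanity check of the gate itself (the tensor sizes here are illustrative, and the AttentionGate class above is assumed to be in scope):

g = torch.randn(1, 64, 80, 80)   # decoder (gating) feature map
x = torch.randn(1, 64, 80, 80)   # skip-connection feature map
gate = AttentionGate(F_g=64, F_l=64, F_int=32)
out = gate(g, x)
print(out.shape)   # torch.Size([1, 64, 80, 80]), same shape as x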