[Code Included] Topping Nature, Ready for Takeoff! KAN-UNet Is Crushing It Again

KAN-UNet for Remote Sensing Applications

The Nature sub-journal Nature Reviews Electrical Engineering has published a review on deep learning for tree monitoring in remote sensing: High-resolution Sensors and Deep Learning Models for Tree Resource Monitoring.


The review summarizes advances in high-resolution satellite and sensor technology and shows how, combined with deep learning models (CNN / ViT / UNet), they enable precise monitoring of three-dimensional tree structure such as canopy height and wood volume.

The model itself is not complicated; it is based on UNet, as shown in the figure.

So, could we take the recently popular KAN, replace the underlying MLPs with KAN layers, design a better model called KAN-UNet, and land a Nature paper of our own? (Just kidding.)
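For intuition about the swap: where an MLP applies a fixed activation to each weighted sum, a KAN (Kolmogorov-Arnold Network) learns a univariate function on every edge. Below is a minimal, hypothetical sketch in the spirit of FastKAN, which approximates those edge functions with Gaussian RBF bases; the class name SimpleKANLayer and all hyperparameters are illustrative assumptions, not part of any official package.

import torch
from torch import nn

class SimpleKANLayer(nn.Module):
    """Minimal KAN-style layer, for illustration only (not an official package).

    An MLP layer computes activation(W @ x) with a fixed nonlinearity; a KAN
    instead learns a univariate function per input feature. Here those
    functions are approximated with Gaussian RBF bases (the FastKAN trick),
    plus a plain linear path as a residual.
    """

    def __init__(self, in_dim, out_dim, num_grids=8, grid_range=(-2.0, 2.0)):
        super().__init__()
        grid = torch.linspace(grid_range[0], grid_range[1], num_grids)
        self.register_buffer("grid", grid)                           # fixed RBF centers
        self.denom = (grid_range[1] - grid_range[0]) / (num_grids - 1)
        self.spline_linear = nn.Linear(in_dim * num_grids, out_dim)  # mixes basis responses
        self.base_linear = nn.Linear(in_dim, out_dim)                # residual linear path

    def forward(self, x):
        # x: (..., in_dim) -> RBF responses: (..., in_dim, num_grids)
        rbf = torch.exp(-((x[..., None] - self.grid) / self.denom) ** 2)
        return self.spline_linear(rbf.flatten(-2)) + self.base_linear(x)

# usage: y = SimpleKANLayer(64, 128)(torch.randn(4, 64))  # -> shape (4, 128)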

Model Structure

The architecture is also easy to understand: the core layers are simply swapped out for KAN:

  1. Feature extraction stage (yellow blocks):

    • The encoder path uses stacked convolution blocks; the spatial resolution halves at each level (H/2, H/4, H/8, ...) while the channel count C_i grows, extracting features at multiple levels.

    • The decoder path upsamples level by level while fusing features from the corresponding encoder levels.

  2. Tokenized KAN Block (green blocks):

    • Intermediate features are tokenized and then fed into the KAN layers (Kolmogorov-Arnold Network layers), which model complex relationships between features with learnable univariate functions.

    • The KAN layers perform multi-stage feature interaction, shown in the figure as three stages Φ1, Φ2, Φ3.

    • The block also includes depthwise convolution and layer normalization to further strengthen processing; a sketch of this block follows the list.

  3. Time embedding:

    • When the model is used as a Diffusion U-KAN, a time embedding (red circles in the figure) is injected to enable dynamic feature representation; a standard embedding sketch appears after the summary below.

  4. Module connections:

    • The whole network uses the classic U-Net skip-connection design, wiring intermediate encoder layers directly into the decoder path to preserve both fine detail and semantic information.
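
To make item 2 concrete, here is a minimal sketch of one tokenized KAN stage (roughly one of the Φ stages), reusing SimpleKANLayer from the sketch above. The exact ordering of normalization, residuals, and stages in U-KAN may differ; treat this as an assumption-laden illustration, not the paper's verbatim design.

import torch
from torch import nn

class TokenizedKANBlock(nn.Module):
    """Illustrative tokenized KAN stage (assumptions, not U-KAN verbatim).

    A (B, C, H, W) feature map is flattened into HW tokens, each token is
    transformed by a KAN layer, then a depthwise convolution restores spatial
    mixing, with layer normalization and a residual connection for stability.
    """

    def __init__(self, channels):
        super().__init__()
        self.norm = nn.LayerNorm(channels)
        self.kan = SimpleKANLayer(channels, channels)  # defined in the earlier sketch
        # depthwise conv: one 3x3 filter per channel (groups == channels)
        self.dwconv = nn.Conv2d(channels, channels, kernel_size=3, padding=1, groups=channels)

    def forward(self, x):                              # x: (B, C, H, W)
        b, c, h, w = x.shape
        tokens = x.flatten(2).transpose(1, 2)          # tokenization: (B, HW, C)
        tokens = self.kan(self.norm(tokens))           # per-token KAN transform
        mixed = tokens.transpose(1, 2).reshape(b, c, h, w)
        return x + self.dwconv(mixed)                  # spatial mixing + residual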

Overall, by integrating U-Net's multi-scale segmentation ability with KAN's learnable activation functions, the U-KAN model offers strong feature representation power for highly complex segmentation tasks.
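
As for the time embedding in item 3, a DDPM-style sinusoidal timestep embedding is a plausible form for the injection marked by the red circles; the paper's exact variant may differ, so this is an assumption.

import math
import torch

def sinusoidal_time_embedding(t, dim):
    """Standard DDPM-style sinusoidal timestep embedding; dim is assumed even.

    t: (B,) integer diffusion timesteps -> returns (B, dim) embeddings that
    downstream blocks can add or concatenate to their features.
    """
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = t.float()[:, None] * freqs[None, :]     # (B, half)
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

# usage: emb = sinusoidal_time_embedding(torch.tensor([0, 10, 500]), 128)  # -> (3, 128)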

Code

Enough talk; here is the code. Interested readers can go and publish a Nature paper of their own.

import torch
from torch import nn
import torch.nn.functional as F

# FastKANConvLayer is a KAN-style drop-in replacement for nn.Conv2d,
# imported from the KANUmain repository that this code assumes is available
from KANUmain.src.fastkanconv import FastKANConvLayer

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, device):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.device = device

        self.double_conv = nn.Sequential(
            FastKANConvLayer(self.in_channels, self.out_channels//2, padding=1, kernel_size=3, stride=1, kan_type='RBF'),
            nn.BatchNorm2d(self.out_channels//2),
            nn.ReLU(inplace=True),
            FastKANConvLayer(self.out_channels//2, self.out_channels, padding=1, kernel_size=3, stride=1, kan_type='RBF'),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)
    
class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels, device='mps'):
        super().__init__()
        self.device = device
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels, device=self.device)
        )

    def forward(self, x):
        return self.maxpool_conv(x)
    
class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True, device='mps'):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, device=device)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels, device=device)
        
    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
    
class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = FastKANConvLayer(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

class KANU_Net(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True, device='mps'):
        super(KANU_Net, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.device = device

        self.channels = [64, 128, 256, 512, 1024]

        self.inc = DoubleConv(n_channels, self.channels[0], device=self.device)

        self.down1 = Down(self.channels[0], self.channels[1], self.device)
        self.down2 = Down(self.channels[1], self.channels[2], self.device)
        self.down3 = Down(self.channels[2], self.channels[3], self.device)
        factor = 2 if bilinear else 1
        self.down4 = Down(self.channels[3], self.channels[4] // factor, self.device)
        self.up1 = Up(self.channels[4], self.channels[3] // factor, bilinear, self.device)
        self.up2 = Up(self.channels[3], self.channels[2] // factor, bilinear, self.device)
        self.up3 = Up(self.channels[2], self.channels[1] // factor, bilinear, self.device)
        self.up4 = Up(self.channels[1], self.channels[0], bilinear, self.device)
        self.outc = OutConv(self.channels[0], n_classes)

    def forward(self, x):
        # Encoder
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        
        # Decoder
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
    # smoke test: 6 input channels, 6 output classes
    model = KANU_Net(6, 6, device=device).to(device)
    x = torch.randn((1, 6, 224, 224)).to(device)
    print(model(x).shape)  # expected: torch.Size([1, 6, 224, 224])
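
For completeness, here is a minimal single-training-step sketch for semantic segmentation with this model. The batch, loss, and learning rate are illustrative placeholders, not settings from any paper.

import torch
from torch import nn

# hypothetical setup: 6-channel images, 6 segmentation classes
model = KANU_Net(n_channels=6, n_classes=6, device='cpu')
criterion = nn.CrossEntropyLoss()                    # per-pixel classification loss
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

images = torch.randn(2, 6, 224, 224)                 # stand-in for a dataloader batch
labels = torch.randint(0, 6, (2, 224, 224))          # per-pixel class indices

optimizer.zero_grad()
logits = model(images)                               # (B, n_classes, H, W)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
print(f"loss: {loss.item():.4f}")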

