RT-DETR网络结构
时间: 2025-02-02 16:05:33 浏览: 236
### RT-DETR 网络架构详解
#### 1. 基础组件概述
RT-DETR 是一种基于 Transformer 架构的目标检测模型,其核心在于利用自注意力机制来捕捉图像特征间的全局依赖关系。这种设计区别于传统的卷积神经网络(CNN),后者主要依靠局部感受野提取特征[^1]。
#### 2. 输入表示
对于输入图像,RT-DETR 首先将其划分为固定大小的 patches,并通过线性投影转换成一系列 tokens。这些 tokens 被送入后续的编码器模块之前还会加上位置嵌入(position embedding),以便保留空间信息。
#### 3. 编码器部分
编码器由多个堆叠的标准 Transformer 层组成,每层内部包含了多头自注意(multi-head self-attention, MHSA)子层以及前馈神经网络(feed-forward network, FFN)子层。MHSA 子层负责计算各个 token 对其他所有 token 的关注度;FFN 则用于进一步变换特征向量。为了加速收敛并稳定训练过程,在这两个子层之后都加入了残差连接(residual connection) 和层归一化(layer normalization)。
#### 4. 解码器部分
解码器同样采用了类似的 Transformer 结构,不过这里引入了一个额外的概念——查询(query)。初始状态下,一组可学习的位置无关 queries 将被创建出来作为预测框(anchor box)候选者的基础。随着迭代次数增加,queries 不断更新直至最终形成精确的对象边界框估计。值得注意的是,与 DETR 中固定的 object query 数目不同,RT-DETR 支持动态调整数量以适应不同规模的任务需求。
#### 5. 头部设计
在最后一轮解码完成后,得到的结果会被传递给两个平行分支:一个是用来分类物体类别的类别头部(classification head),另一个则是回归具体坐标的坐标头部(regression head)。两者均采用全连接层实现,并且会附加 Softmax 或 Sigmoid 函数完成概率分布输出或数值范围约束操作。
```python
import torch.nn as nn
class ClassificationHead(nn.Module):
def __init__(self, input_dim, num_classes):
super(ClassificationHead, self).__init__()
self.fc = nn.Linear(input_dim, num_classes)
def forward(self, x):
return torch.softmax(self.fc(x), dim=-1)
class RegressionHead(nn.Module):
def __init__(self, input_dim, output_dim=4): # 默认为bbox参数个数
super(RegressionHead, self).__init__()
self.fc = nn.Linear(input_dim, output_dim)
def forward(self, x):
return self.fc(x)
```
阅读全文
相关推荐


















