Transformer中的三种注意力机制

原本想学习一下交叉注意力,结果发现交叉注意力并不是一种特殊的注意力模型,只是Transformer中的一种注意力,因此这里都整理一下,方便以后查阅。

Self-Attention、Cross-Attention 和 Causal-Attention 是深度学习中注意力机制的三种重要变体,广泛应用于 Transformer 模型等架构中。它们在功能、输入来源和应用场景上各有不同。

在这里插入图片描述

Transformer中的三种注意力机制

1. Self-Attention(自注意力机制)

1.1 定义

Self-Attention 是一种注意力机制,允许模型在处理一个输入序列时,关注序列内部的每个元素之间的关系。每个元素既作为查询(Query),又作为键(Key)和值(Value),通过计算自身与其他元素的相关性来更新表示

Scaled Dot-Product Attention(缩放点积注意力)

  • 输入:单一序列 X X X,形状为 ( n , d ) (n, d) (n,d),其中 n n n 是序列长度, d d d 是嵌入维度。
  • 计算
    • 生成查询、键、值: Q = X W q Q = X W_q Q=XWq, K = X W k K = X W_k K=XWk, V = X W v V = X W_v V=XWv
    • 计算注意力分数: Score = Q K T d k \text{Score} = \frac{Q K^T}{\sqrt{d_k}} Score=dk QKT
    • 应用 softmax: Attention Weights = softmax ( Score ) \text{Attention Weights} = \text{softmax}(\text{Score}) Attention Weights=softmax(Score)
    • 加权求和: Output = Attention Weights ⋅ V \text{Output} = \text{Attention Weights} \cdot V Output=Attention WeightsV

体现如何计算注意力分数,关注Q、K、V计算公式。

在这里插入图片描述

Scaled Dot-Product Attention(缩放点积注意力)

Self Attention(自注意力)

对同一个序列,通过缩放点积注意力计算注意力分数,最终对值向量进行加权求和,从而得到输入序列中每个位置的加权表示。

表达的是一种注意力机制,如何使用缩放点积注意力对同一个序列计算注意力分数,从而得到同一序列中每个位置的注意力权重。

在这里插入图片描述

Self Attention(自注意力)

Multi-Head Self Attention(多头自注意力)

多个注意力头并行运行,每个头都会独立地计算注意力权重和输出,然后将所有头的输出拼接起来得到最终的输出。

强调的是一种实操方法,实际操作中我们并不会使用单个维度来执行单一的注意力函数,而是通过h=8个头分别计算,然后加权平均。这样为了避免单个计算的误差。

在这里插入图片描述

Multi-Head Self Attention(多头自注意力)

1.2 应用场景

  • 自然语言处理:如 BERT 的编码器,用于理解句子中词与词之间的上下文关系(例如,捕捉“bank”在“river bank”和“bank account”中的不同含义)。
  • 计算机视觉:如 Vision Transformer (ViT),用于建模图像中不同区域之间的关系。
  • 序列建模:适用于需要全局上下文的任务,如文本分类、语义表示学习。

1.3 优点 and 缺点

优点

  • 捕捉序列内部的长距离依赖关系。
  • 并行计算,相比 RNN 更高效。
  • 灵活性强,适用于多种任务。

缺点

  • 计算复杂度为 O ( n 2 ) O(n^2) O(n2),对长序列计算成本高。
  • 缺乏时间顺序约束,可能不适合生成任务。

2. Cross-Attention(交叉注意力机制)

2.1 定义

Cross-Attention 用于建模两个不同序列之间的关系。一个序列提供查询(Query),另一个序列提供键(Key)和值(Value)。它通常用于需要融合来自不同数据源或模态的信息的任务。

在这里插入图片描述

Cross Attention(交叉注意力)

2.2 工作原理

  • 输入
    • 查询序列 X q X_q Xq,形状为 ( n , d q ) (n, d_q) (n,dq),通常来自目标序列。
    • 键-值序列 X k v X_{kv} Xkv, 形状为 ( m , d k v ) (m, d_{kv}) (m,dkv),通常来自源序列。
  • 计算
    • 生成查询、键、值: Q = X q W q Q = X_q W_q Q=XqWq, K = X k v W k K = X_{kv} W_k K=XkvWk, V = X k v W v V = X_{kv} W_v V=XkvWv
    • 计算注意力分数: Score = Q K T d k \text{Score} = \frac{Q K^T}{\sqrt{d_k}} Score=dk QKT
    • 应用 softmax: Attention Weights = softmax ( Score ) \text{Attention Weights} = \text{softmax}(\text{Score}) Attention Weights=softmax(Score)
    • 加权求和: Output = Attention Weights ⋅ V \text{Output} = \text{Attention Weights} \cdot V Output=Attention WeightsV
  • 特点:查询和键-值来自不同序列,输出反映了查询序列对源序列的关注。
  • 多头机制:同样支持多头注意力以增强表达能力。

2.3 应用场景

  • 机器翻译:在 Transformer 解码器中,查询来自目标语言序列,键-值来自源语言序列(如将“Je t’aime”翻译为“I love you”时对齐“aime”和“love”)。
  • 视觉-语言模型:如 CLIP(对齐图像和文本特征)或 DALL·E(将文本描述融入图像生成)。
  • 问答系统:从文档(键-值)中提取与问题(查询)相关的答案。
  • 跨模态任务:如图像-文本检索、视频-文本对齐。

2.4 优点 and 缺点

优点

  • 有效融合来自不同序列或模态的信息。
  • 适合跨模态或跨语言任务。
  • 支持并行计算,效率高。

缺点

  • 计算复杂度为 O ( n ⋅ m ) O(n \cdot m) O(nm),长序列时仍可能成本高。
  • 依赖输入序列的质量,噪声可能影响对齐效果。

3. Causal-Attention(因果注意力机制)

3.1 定义

Causal-Attention(也称为 Masked Self-Attention)是自注意力的一种变体,通过引入掩码(Mask)限制模型只关注序列中当前及之前的元素,防止“看到未来”的信息。它通常用于自回归生成任务,确保生成过程符合时间顺序。

Predict The Next Word(预测下一个词)

模型通常需要基于已经生成的词来预测下一个词。这种特性要求模型在预测时不能“看到”未来的信息,以避免预测受到未来信息的影响。

请添加图片描述

Predict The Next Word(预测下一个词)

Masked Language Model(掩码语言模型)

遮盖一些词语来让模型学习预测被遮盖的词语,从而帮助模型学习语言规律。

在这里插入图片描述

Masked Language Model(掩码语言模型)

Autoregressive(自回归)

在生成序列的某个词时,解码器会考虑已经生成的所有词,包括当前正在生成的这个词本身。为了保持自回归属性,即模型在生成序列时只能基于已经生成的信息进行预测,需要防止解码器中的信息向左流动。换句话说,当解码器在生成第t个词时,它不应该看到未来(即第t+1, t+2,…等位置)的信息。

在这里插入图片描述

Autoregressive(自回归)

Causal Attention(因果注意力)

为了确保模型在生成序列时,只依赖于之前的输入信息,而不会受到未来信息的影响。Causal Attention通过掩盖(mask)未来的位置来实现这一点,使得模型在预测某个位置的输出时,只能看到该位置及其之前的输入。

在这里插入图片描述

Causal Attention(因果注意力)

3.2 工作原理

  • 输入:单一序列 X X X,形状为 ( n , d ) (n, d) (n,d)
  • 计算
    • 生成查询、键、值: Q = X W q Q = X W_q Q=XWq, K = X W k K = X W_k K=XWk, V = X W v V = X W_v V=XWv
    • 计算注意力分数: Score = Q K T d k \text{Score} = \frac{Q K^T}{\sqrt{d_k}} Score=dk QKT(加掩码)
    • 掩码:在 softmax 之前,对分数矩阵的上三角(未来位置)施加负无穷大(或 0),确保每个位置只关注自身及之前的位置。
    • 应用 softmax: Attention Weights = softmax ( Score with Mask ) \text{Attention Weights} = \text{softmax}(\text{Score with Mask}) Attention Weights=softmax(Score with Mask)
    • 加权求和: Output = Attention Weights ⋅ V \text{Output} = \text{Attention Weights} \cdot V Output=Attention WeightsV.
  • 特点:通过掩码实现因果约束,适合自回归生成任务。
  • 多头机制:同样支持多头注意力。

3.3 应用场景

  • 语言生成:如 GPT 系列模型,用于生成连贯的文本(例如,生成下一个词时只依赖之前的词)。
  • 机器翻译:在 Transformer 解码器中,用于确保生成目标序列时遵循时间顺序。
  • 语音生成:如 WaveNet,用于生成音频序列。

3.4 优点 and 缺点

优点

  • 适合自回归任务,保证生成过程的时间顺序。
  • 保留自注意力的并行计算优势。
  • 能捕捉序列中之前的上下文信息。

缺点

  • 计算复杂度仍为 O ( n 2 ) O(n^2) O(n2)
  • 无法利用序列中的未来信息,限制了某些任务的性能。

4. 对比总结

特性Self-AttentionCross-AttentionCausal-Attention
输入来源单一序列(Q, K, V 均来自同一序列)两个序列(Q 来自目标,K, V 来自源)单一序列(带因果掩码)
注意力范围全局(所有位置)目标序列关注源序列当前及之前位置
计算复杂度 O ( n 2 ) O(n^2) O(n2), n n n 为序列长度 O ( n ⋅ m ) O(n \cdot m) O(nm), n , m n, m n,m为序列长度 O ( n 2 ) O(n^2) O(n2), n n n 为序列长度
主要应用语义表示(如 BERT)、图像分类翻译、跨模态任务(如 CLIP)语言生成(如 GPT)、自回归任务
时间顺序约束有(通过掩码实现)
典型模型BERT, ViT RosformerTransformer (解码器), CLIP, DALL·EGPT, Transformer (解码器部分)

5. 代码示例(PyTorch)

以下是一个简单的 PyTorch 实现,展示三种注意力机制的区别:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    def __init__(self, dim, num_heads):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        self.out = nn.Linear(dim, dim)

    def forward(self, query, key, value, mask=None, causal=False):
        batch_size = query.size(0)
        seq_len_q, seq_len_k = query.size(1), key.size(1)

        # 线性变换
        Q = self.query(query)  # (batch_size, seq_len_q, dim)
        K = self.key(key)      # (batch_size, seq_len_k, dim)
        V = self.value(value)  # (batch_size, seq_len_k, dim)

        # 分割多头
        Q = Q.view(batch_size, seq_len_q, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len_k, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len_k, self.num_heads, self.head_dim).transpose(1, 2)

        # 计算注意力分数
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # 因果掩码(仅 Causal-Attention 使用)
        if causal:
            mask = torch.triu(torch.ones(seq_len_q, seq_len_k), diagonal=1).bool()
            scores = scores.masked_fill(mask[None, None, :, :], float('-inf'))

        # 普通掩码(可选,用于 padding 或其他场景)
        if mask is not None:
            scores = scores.masked_fill(mask[None, None, :, :], float('-inf'))

        # Softmax
        attn_weights = F.softmax(scores, dim=-1)

        # 加权求和
        out = torch.matmul(attn_weights, V)

        # 合并多头
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len_q, -1)
        out = self.out(out)

        return out

# 示例用法
dim = 512
num_heads = 8
batch_size = 32
seq_len_q, seq_len_k = 10, 20

# 准备输入
query = torch.randn(batch_size, seq_len_q, dim)
key = torch.randn(batch_size, seq_len_k, dim)
value = torch.randn(batch_size, seq_len_k, dim)

# 初始化注意力模块
attn = Attention(dim, num_heads)

# Self-Attention
self_attn_output = attn(query, query, query)
print("Self-Attention Output Shape:", self_attn_output.shape)  # (32, 10, 512)

# Cross-Attention
cross_attn_output = attn(query, key, value)
print("Cross-Attention Output Shape:", cross_attn_output.shape)  # (32, 10, 512)

# Causal-Attention
causal_attn_output = attn(query, query, query, causal=True)
print("Causal-Attention Output Shape:", causal_attn_output.shape)  # (32, 10, 512)

6. 总结

  • Self-Attention 适合需要捕捉序列内部全局关系的任务,如语义表示学习。
  • Cross-Attention 专为跨序列或跨模态任务设计,如翻译和多模态融合。
  • Causal-Attention 适用于自回归生成任务,确保时间顺序约束。

选择哪种机制取决于任务需求:

  • 如果需要全局上下文,使用 Self-Attention。
  • 如果涉及两个序列的交互,使用 Cross-Attention。
  • 如果需要生成序列并遵循时间顺序,使用 Causal-Attention。

参考资料

神经网络算法 - 一文搞懂Transformer中的三种注意力机制

第四篇:一文搞懂Transformer架构的三种注意力机制

三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值