import torch
import torch.nn as nn
class CrossAttention(nn.Module):
    def __init__(self, d_in, d_out, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # dimension of each head
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)  # linear projection for queries
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)    # linear projection for keys
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)  # linear projection for values
        self.out_proj = nn.Linear(d_out, d_out)  # output projection
        self.dropout = nn.Dropout(dropout)  # dropout layer

    def forward(self, queries, keys, values, mask=None):
        """
        queries: (batch_size, seq_len_q, d_in)
        keys:    (batch_size, seq_len_kv, d_in)
        values:  (batch_size, seq_len_kv, d_in)
        mask:    (batch_size, num_heads, seq_len_q, seq_len_kv) or None
        """
        b, seq_len_q, d_in = queries.shape
        _, seq_len_kv, _ = keys.shape
        # Linear projections
        queries = self.W_query(queries)  # (b, seq_len_q, d_out)
        keys = self.W_key(keys)          # (b, seq_len_kv, d_out)
        values = self.W_value(values)    # (b, seq_len_kv, d_out)
        # Reshape into multi-head form
        queries = queries.view(b, seq_len_q, self.num_heads, self.head_dim).transpose(1, 2)  # (b, num_heads, seq_len_q, head_dim)
        keys = keys.view(b, seq_len_kv, self.num_heads, self.head_dim).transpose(1, 2)       # (b, num_heads, seq_len_kv, head_dim)
        values = values.view(b, seq_len_kv, self.num_heads, self.head_dim).transpose(1, 2)   # (b, num_heads, seq_len_kv, head_dim)
        # Compute attention scores
        attn_scores = torch.matmul(queries, keys.transpose(-2, -1))  # (b, num_heads, seq_len_q, seq_len_kv)
        attn_scores = attn_scores / (self.head_dim ** 0.5)  # scale
        # Apply the mask (if provided)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
        # Compute attention weights
        attn_weights = torch.softmax(attn_scores, dim=-1)  # (b, num_heads, seq_len_q, seq_len_kv)
        attn_weights = self.dropout(attn_weights)  # dropout
        # Compute the context vectors
        context_vec = torch.matmul(attn_weights, values)  # (b, num_heads, seq_len_q, head_dim)
        context_vec = context_vec.transpose(1, 2).contiguous()  # (b, seq_len_q, num_heads, head_dim)
        context_vec = context_vec.view(b, seq_len_q, self.d_out)  # (b, seq_len_q, d_out)
        # Output projection
        context_vec = self.out_proj(context_vec)  # (b, seq_len_q, d_out)
        return context_vec
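For reference, each head in forward() computes the standard scaled dot-product attention over its slice of the projected tensors,

\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_{\mathrm{head}}}}\right) V,

after which the per-head context vectors are concatenated back to d_out dimensions and passed through out_proj.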
# Set a random seed so the results are reproducible
torch.manual_seed(123)
# Simulated query and key/value sequences (batch_size, seq_len_q/kv, d_in)
batch_size = 3
seq_len_q = 5    # query sequence length
seq_len_kv = 7   # key/value sequence length
d_in = 10        # input dimension
d_out = 8        # output dimension
num_heads = 2    # number of attention heads
dropout = 0.1    # dropout probability
# Create a CrossAttention instance
cross_attention = CrossAttention(d_in, d_out, dropout, num_heads)
# Random input tensors
queries = torch.randn(batch_size, seq_len_q, d_in)
keys = torch.randn(batch_size, seq_len_kv, d_in)
values = torch.randn(batch_size, seq_len_kv, d_in)
# If a mask is needed (e.g. causal masking), it can be created here.
# Below we build an all-ones mask, i.e. nothing is masked out.
mask = torch.ones(batch_size, num_heads, seq_len_q, seq_len_kv)
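# Aside: positions where mask == 0 are blocked in forward(). For causal (look-ahead)
# masking in the self-attention case, where seq_len_q == seq_len_kv, such a mask could
# for example be built with torch.tril. Illustrative sketch only, not used here because
# seq_len_q != seq_len_kv in this cross-attention example:
# causal_mask = torch.tril(torch.ones(seq_len_q, seq_len_q)).view(1, 1, seq_len_q, seq_len_q)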
# Call forward to obtain the context vectors
output = cross_attention(queries, keys, values, mask=mask)
# Print the shape and contents of the context vectors
print("Context vector shape:", output.shape)
print("Context vectors:")
print(output)
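As a sanity check, the manual implementation can be compared against PyTorch's built-in scaled dot-product attention kernel. The sketch below is an assumption-laden aside rather than part of the original post: it assumes PyTorch 2.x (which provides torch.nn.functional.scaled_dot_product_attention), reuses the module's own projection layers, and switches to eval() so dropout is disabled and both paths are deterministic. The printed maximum difference should be on the order of floating-point error.

import torch.nn.functional as F

cross_attention.eval()  # disable dropout so both computations are deterministic
with torch.no_grad():
    out_manual = cross_attention(queries, keys, values, mask=mask)

    # Re-project the inputs with the same weights and reshape to (b, num_heads, seq_len, head_dim)
    q = cross_attention.W_query(queries).view(batch_size, seq_len_q, num_heads, -1).transpose(1, 2)
    k = cross_attention.W_key(keys).view(batch_size, seq_len_kv, num_heads, -1).transpose(1, 2)
    v = cross_attention.W_value(values).view(batch_size, seq_len_kv, num_heads, -1).transpose(1, 2)

    # A boolean attn_mask marks positions that MAY be attended to (True = keep)
    ctx = F.scaled_dot_product_attention(q, k, v, attn_mask=mask.bool())
    ctx = ctx.transpose(1, 2).contiguous().view(batch_size, seq_len_q, d_out)
    out_builtin = cross_attention.out_proj(ctx)

print("max abs difference:", (out_manual - out_builtin).abs().max().item())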