import torch.nn as nn
import torch
class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        # Linear projections that produce the queries (Q), keys (K), and values (V);
        # each maps from the input dimension d_in to the output dimension d_out
        self.W_Q = nn.Linear(d_in, d_out, bias=qkv_bias)  # query projection
        self.W_K = nn.Linear(d_in, d_out, bias=qkv_bias)  # key projection
        self.W_V = nn.Linear(d_in, d_out, bias=qkv_bias)  # value projection

    def forward(self, x):
        # x has shape (batch_size, seq_len, d_in)
        # Compute the key (K) matrix, shape (batch_size, seq_len, d_out)
        keys = self.W_K(x)
        # Compute the value (V) matrix, shape (batch_size, seq_len, d_out)
        values = self.W_V(x)
        # Compute the query (Q) matrix, shape (batch_size, seq_len, d_out)
        queries = self.W_Q(x)
        # Compute the attention scores: transpose the last two dimensions of the keys,
        # (seq_len, d_out) -> (d_out, seq_len), then matrix-multiply;
        # attn_scores has shape (batch_size, seq_len, seq_len)
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale the scores by sqrt(d_out) and apply softmax to get the attention weights,
        # shape (batch_size, seq_len, seq_len)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        # Compute the context vectors by multiplying the attention weights with the values,
        # shape (batch_size, seq_len, d_out)
        context_vec = attn_weights @ values
        # Return the context vectors
        return context_vec
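
For reference, the forward pass above computes standard (unmasked) scaled dot-product attention, where the scaling factor corresponds to keys.shape[-1] in the code:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_{\text{out}}}}\right)V$$
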
# 设置随机种子以确保可复现性
torch.manual_seed(123)
# 创建 SelfAttention 实例,设定输入和输出维度
d_in = 10 # 输入维度
d_out = 8 # 输出维度
self_attention = SelfAttention(d_in, d_out)
# 模拟输入数据 (batch_size, seq_len, d_in),这里我们创建一个随机张量
batch_size = 3
seq_len = 5
x = torch.randn(batch_size, seq_len, d_in)
# 调用 SelfAttention 层
output = self_attention(x)
# 输出上下文向量
print("上下文向量 (context vectors):")
print(output)
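
As an optional sanity check (a minimal sketch, assuming PyTorch >= 2.0 so that torch.nn.functional.scaled_dot_product_attention is available), the hand-rolled result can be compared against PyTorch's built-in scaled dot-product attention, re-using the same projection weights:

import torch.nn.functional as F

# Re-use the projections from the module above and call the built-in kernel,
# which applies the same softmax(Q K^T / sqrt(d_out)) V computation by default.
q = self_attention.W_Q(x)
k = self_attention.W_K(x)
v = self_attention.W_V(x)
builtin_out = F.scaled_dot_product_attention(q, k, v)

# The two results should agree up to floating-point tolerance.
print(torch.allclose(output, builtin_out, atol=1e-6))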