手撕注意力机制

import torch.nn as nn
import torch


class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        # 定义线性变换矩阵用于生成查询(Q),键(K)和值(V),输入和输出维度分别为d_in和d_out
        self.W_Q = nn.Linear(d_in, d_out, bias=qkv_bias)  # 查询矩阵
        self.W_K = nn.Linear(d_in, d_out, bias=qkv_bias)  # 键矩阵
        self.W_V = nn.Linear(d_in, d_out, bias=qkv_bias)  # 值矩阵

    def forward(self, x):
        # x的形状:(batch_size, seq_len, d_in)

        # 计算键(K)矩阵,形状为 (batch_size, seq_len, d_out)
        keys = self.W_K(x)

        # 计算值(V)矩阵,形状为 (batch_size, seq_len, d_out)
        values = self.W_V(x)

        # 计算查询(Q)矩阵,形状为 (batch_size, seq_len, d_out)
        queries = self.W_Q(x)

        # 计算注意力得分,先对键矩阵进行转置 (seq_len, d_out) -> (d_out, seq_len),
        # 然后矩阵乘法,attn_scores 形状为 (batch_size, seq_len, seq_len)
        attn_scores = queries @ keys.transpose(-2, -1)

        # 使用 softmax 计算注意力权重,并进行缩放,形状为 (batch_size, seq_len, seq_len)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)

        # 计算上下文向量,通过将注意力权重乘以值矩阵,形状为 (batch_size, seq_len, d_out)
        context_vec = attn_weights @ values

        # 返回上下文向量
        return context_vec


# 设置随机种子以确保可复现性
torch.manual_seed(123)

# 创建 SelfAttention 实例,设定输入和输出维度
d_in = 10  # 输入维度
d_out = 8  # 输出维度
self_attention = SelfAttention(d_in, d_out)

# 模拟输入数据 (batch_size, seq_len, d_in),这里我们创建一个随机张量
batch_size = 3
seq_len = 5
x = torch.randn(batch_size, seq_len, d_in)

# 调用 SelfAttention 层
output = self_attention(x)

# 输出上下文向量
print("上下文向量 (context vectors):")
print(output)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值