Implementing the Cross-Attention Mechanism from Scratch
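Cross-attention lets one sequence attend to another: the queries come from one input (for example, a Transformer decoder's hidden states), while the keys and values come from a different input (for example, the encoder's outputs), so the two sequences may have different lengths. Each head computes scaled dot-product attention, softmax(Q K^T / sqrt(d_k)) V, where d_k is the per-head dimension. The PyTorch implementation below mirrors a standard multi-head self-attention module, except that forward() takes separate query and key/value tensors.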

import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, d_in, d_out, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # dimension of each attention head

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)  # linear projection for queries
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)    # linear projection for keys
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)  # linear projection for values
        self.out_proj = nn.Linear(d_out, d_out)               # output projection
        self.dropout = nn.Dropout(dropout)                    # dropout layer

    def forward(self, queries, keys, values, mask=None):
        """
        queries: (batch_size, seq_len_q, d_in)
        keys: (batch_size, seq_len_kv, d_in)
        values: (batch_size, seq_len_kv, d_in)
        mask: (batch_size, num_heads, seq_len_q, seq_len_kv) or None
        """
        b, seq_len_q, d_in = queries.shape
        _, seq_len_kv, _ = keys.shape

        # Linear projections
        queries = self.W_query(queries)  # (b, seq_len_q, d_out)
        keys = self.W_key(keys)          # (b, seq_len_kv, d_out)
        values = self.W_value(values)    # (b, seq_len_kv, d_out)

        # Reshape into multi-head form
        queries = queries.view(b, seq_len_q, self.num_heads, self.head_dim).transpose(1, 2)  # (b, num_heads, seq_len_q, head_dim)
        keys = keys.view(b, seq_len_kv, self.num_heads, self.head_dim).transpose(1, 2)      # (b, num_heads, seq_len_kv, head_dim)
        values = values.view(b, seq_len_kv, self.num_heads, self.head_dim).transpose(1, 2)  # (b, num_heads, seq_len_kv, head_dim)

        # Compute attention scores
        attn_scores = torch.matmul(queries, keys.transpose(-2, -1))  # (b, num_heads, seq_len_q, seq_len_kv)
        attn_scores = attn_scores / (self.head_dim ** 0.5)           # scale by sqrt(head_dim)

        # Apply the mask (if provided)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))

        # Compute attention weights
        attn_weights = torch.softmax(attn_scores, dim=-1)            # (b, num_heads, seq_len_q, seq_len_kv)
        attn_weights = self.dropout(attn_weights)                   # Dropout

        # Compute context vectors
        context_vec = torch.matmul(attn_weights, values)            # (b, num_heads, seq_len_q, head_dim)
        context_vec = context_vec.transpose(1, 2).contiguous()      # (b, seq_len_q, num_heads, head_dim)
        context_vec = context_vec.view(b, seq_len_q, self.d_out)    # (b, seq_len_q, d_out)

        # Output projection
        context_vec = self.out_proj(context_vec)                    # (b, seq_len_q, d_out)

        return context_vec

# Set a random seed for reproducibility
torch.manual_seed(123)

# Simulated query and key/value sequences (batch_size, seq_len_q/kv, d_in)
batch_size = 3
seq_len_q = 5   # query sequence length
seq_len_kv = 7  # key/value sequence length
d_in = 10       # input dimension
d_out = 8       # output dimension
num_heads = 2   # number of attention heads
dropout = 0.1   # dropout probability

# Create a CrossAttention instance
cross_attention = CrossAttention(d_in, d_out, dropout, num_heads)

# Random input tensors
queries = torch.randn(batch_size, seq_len_q, d_in)
keys = torch.randn(batch_size, seq_len_kv, d_in)
values = torch.randn(batch_size, seq_len_kv, d_in)

# A mask (e.g., causal masking) could be supplied here if needed;
# this example uses an all-ones mask, so no positions are hidden
# (see the padding-mask sketch after this listing).
mask = torch.ones(batch_size, num_heads, seq_len_q, seq_len_kv)

# Run the forward pass to obtain the context vectors
output = cross_attention(queries, keys, values, mask=mask)

# Print the shape and contents of the context vectors
print("Context vector shape:", output.shape)
print("Context vectors:")
print(output)
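
The all-ones mask above leaves every position visible. In cross-attention, a more common use of the mask is to hide padded key/value positions so that queries never attend to padding. The snippet below is a minimal sketch of that idea against the module defined above; the per-example lengths in kv_lengths are made up for illustration and are not part of the original example.

# Minimal sketch: masking out padded key/value positions.
# kv_lengths holds a hypothetical number of valid (non-padded) tokens per batch element.
kv_lengths = torch.tensor([7, 5, 3])

# True where a key position is valid, False where it is padding: (batch_size, seq_len_kv)
valid = torch.arange(seq_len_kv).unsqueeze(0) < kv_lengths.unsqueeze(1)

# Broadcast to the (batch_size, num_heads, seq_len_q, seq_len_kv) shape expected by forward()
padding_mask = valid.unsqueeze(1).unsqueeze(2).expand(batch_size, num_heads, seq_len_q, seq_len_kv)

masked_output = cross_attention(queries, keys, values, mask=padding_mask)
print("Context vector shape with padding mask:", masked_output.shape)  # (3, 5, 8)

Because masked_fill sets the scores of padded positions to -inf before the softmax, those positions receive zero attention weight and contribute nothing to the context vectors.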