KV cache implementation

In autoregressive decoding, the keys and values of already-generated tokens never change, so they can be cached and reused: each step only projects the newly generated token into Q/K/V, appends the new K/V to the cache, and attends the new query over the cached keys and values instead of recomputing the whole prefix.

Using MHA (multi-head attention) as an example:

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, attention_size):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.attention_size = attention_size

        self.WQ = nn.Linear(d_model, d_model)
        self.WK = nn.Linear(d_model, d_model)
        self.WV = nn.Linear(d_model, d_model)
        self.linear = nn.Linear(d_model, d_model)
        
        # KV cache: holds the (trimmed) K/V produced by the most recent forward pass
        self.k_cache = None
        self.v_cache = None

    def forward(self, x, past_k=None, past_v=None, mask=None):
        
        Q = self.WQ(x)   # [batch,frame,d]
        K = self.WK(x)   # [batch,frame,d]
        V = self.WV(x)   # [batch,frame,d]
        
        batch_size = x.size(0)
        seq_len = x.size(1)
        
        # Initialize past_k and past_v if not provided
        if past_k is None:
            past_k = torch.zeros((batch_size, 0, self.d_model), device=x.device)
            past_v = torch.zeros((batch_size, 0, self.d_model), device=x.device)
        
        # Concatenate cached K, V with the current K, V along the time axis
        K = torch.cat([past_k, K], dim=1)  # (batch, past_len + frame, d_model)
        V = torch.cat([past_v, V], dim=1)  # (batch, past_len + frame, d_model)

        # Sliding-window attention: keep only the most recent attention_size positions
        if K.size(1) > self.attention_size:
            K = K[:, -self.attention_size:]  # (batch, attention_size, d_model)
            V = V[:, -self.attention_size:]  # (batch, attention_size, d_model)
        
        # Update caches
        self.k_cache = K
        self.v_cache = V
        
        Q = Q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)   # [batch,frame,d] -> [batch,num_heads,frame,head_dim]
        K = K.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)   # [batch,cache_len,d] -> [batch,num_heads,cache_len,head_dim]
        V = V.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)   # [batch,cache_len,d] -> [batch,num_heads,cache_len,head_dim]
        
        d_k = K.size(-1)
        scores = Q.matmul(K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))   # [batch,num_heads,frame,cache_len]

        if mask is not None:
            # mask must broadcast against [batch, num_heads, frame, cache_len], i.e. match the trimmed key length
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attention = F.softmax(scores, dim=-1)
        output = attention.matmul(V)
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)   # [batch,num_heads,frame,head_dim] -> [batch,frame,d]
        output = self.linear(output)
        return output
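
A minimal usage sketch follows, showing step-by-step decoding with the cache. The hyperparameters (batch of 2, d_model=64, 8 heads, a window of 16 positions) and the random inputs are illustrative assumptions, not values from the module above; after each step the caller reads k_cache / v_cache back and passes them in as past_k / past_v.

# Illustrative decoding loop (assumed shapes and hyperparameters, not part of the original module)
model = MultiHeadAttention(d_model=64, num_heads=8, attention_size=16)
model.eval()

past_k, past_v = None, None
with torch.no_grad():
    for step in range(20):
        x = torch.randn(2, 1, 64)                      # one new token per step: [batch, 1, d_model]
        out = model(x, past_k=past_k, past_v=past_v)   # attends the new query over the cached K/V
        past_k, past_v = model.k_cache, model.v_cache  # reuse the updated cache at the next step
        print(step, out.shape, past_k.shape)           # cache length grows until it reaches attention_size

Because only one new query is produced per step, no causal mask is needed in this loop; the sliding window also caps the cache at attention_size entries, so memory stays constant once the window is full.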