KV cache implementation
Taking MHA (multi-head attention) as an example. During autoregressive decoding, the K and V projections of already-processed tokens never change, so they are cached and reused: each new step only projects the new token(s) and attends against the cached keys/values instead of re-projecting the whole sequence.
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, attention_size):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.attention_size = attention_size
        self.WQ = nn.Linear(d_model, d_model)
        self.WK = nn.Linear(d_model, d_model)
        self.WV = nn.Linear(d_model, d_model)
        self.linear = nn.Linear(d_model, d_model)
        # KV cache: keys/values of previously processed tokens
        self.k_cache = None
        self.v_cache = None

    def forward(self, x, past_k=None, past_v=None, mask=None):
        Q = self.WQ(x)  # [batch, frame, d_model]
        K = self.WK(x)  # [batch, frame, d_model]
        V = self.WV(x)  # [batch, frame, d_model]
        batch_size = x.size(0)
        seq_len = x.size(1)  # length of the current chunk (1 during incremental decoding)

        # Initialize past_k and past_v as empty caches if not provided
        if past_k is None:
            past_k = torch.zeros((batch_size, 0, self.d_model), device=x.device)
            past_v = torch.zeros((batch_size, 0, self.d_model), device=x.device)

        # Concatenate cached K, V with the current step's K, V
        K = torch.cat([past_k, K], dim=1)  # [batch, past_len + frame, d_model]
        V = torch.cat([past_v, V], dim=1)  # [batch, past_len + frame, d_model]

        # Trim the cache to attention_size: sliding-window attention
        if K.size(1) > self.attention_size:
            K = K[:, -self.attention_size:]  # [batch, attention_size, d_model]
            V = V[:, -self.attention_size:]  # [batch, attention_size, d_model]

        # Update the caches so the caller can reuse them at the next step
        self.k_cache = K
        self.v_cache = V

        # [batch, frame, d_model] -> [batch, num_heads, frame, head_dim]
        Q = Q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        # [batch, kv_len, d_model] -> [batch, num_heads, kv_len, head_dim]
        K = K.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        d_k = K.size(-1)
        # Scaled dot-product attention scores: [batch, num_heads, frame, kv_len]
        scores = Q.matmul(K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
        if mask is not None:
            mask = mask.unsqueeze(1)  # broadcast over the head dimension
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention = F.softmax(scores, dim=-1)
        output = attention.matmul(V)

        # [batch, num_heads, frame, head_dim] -> [batch, frame, d_model]
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.linear(output)
        return output
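
A minimal usage sketch follows, assuming the class above; the shapes, hyperparameters, and causal-mask construction are illustrative and not from the original. Prefill the prompt in one call, then decode one token per step, feeding the cached K/V back in as past_k / past_v (the module also keeps the latest cache in self.k_cache / self.v_cache).

torch.manual_seed(0)
mha = MultiHeadAttention(d_model=64, num_heads=8, attention_size=16)

# Prefill: process the whole prompt once with a causal mask; no past cache yet
prompt = torch.randn(1, 5, 64)                      # [batch, prompt_len, d_model]
causal = torch.tril(torch.ones(1, 5, 5))            # [batch, prompt_len, prompt_len]
out = mha(prompt, mask=causal)                      # caches K/V for the 5 prompt tokens

# Incremental decoding: one new token per step, reusing the cached K/V
past_k, past_v = mha.k_cache, mha.v_cache
for _ in range(3):
    new_token = torch.randn(1, 1, 64)               # [batch, 1, d_model], a single new step
    out = mha(new_token, past_k=past_k, past_v=past_v)  # attends over cache + new token
    past_k, past_v = mha.k_cache, mha.v_cache       # cache holds at most attention_size steps

Because the cache is trimmed to attention_size, memory stays bounded: this is sliding-window attention, so tokens older than attention_size steps are no longer attended to.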