[MHA] Attention Mask (with back & forward trace) / Causal Mask (with back trace)

This post describes how to implement different attention masks in TensorFlow for multi-head attention, including the causal mask, the causal mask with an N-frame backtrace, and custom masks.


An attention mask can be added in multi-head attention to restrict the range of positions each frame may attend to, for example:

  • Causal mask: each position may only look at the current and earlier frames, never at later ones. As a matrix, the causal mask is a lower triangle: the lower half is 1 and the upper half is 0.
  • Causal mask with n_backtrace: each position may additionally look back at most n_backtrace frames. In the matrix, the leftmost part of the lower triangle is cut away, leaving a diagonal band of 1s of width n_backtrace.
  • Backward and forward N frames: starting from the causal mask with n_backtrace, the band of 1s is extended forward (i.e. to the right) in the same way by n_forwardtrace frames.
  • Similarly, a mask can be defined for any custom requirement (a small sketch of these patterns follows this list).
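
To make these shapes concrete, here is a small illustrative sketch (not part of the original post's code) that prints the three boolean patterns for a toy sequence length of 6, with n_backtrace = 2 and n_forwardtrace = 1; the names are chosen only for this example.

import tensorflow as tf

T, n_backtrace, n_forwardtrace = 6, 2, 1
row = tf.range(T)[:, None]   # query index
col = tf.range(T)[None, :]   # key index

causal = row >= col                                                  # lower triangle
causal_bt = causal & (row <= col + n_backtrace)                      # diagonal band, back only
bt_ft = (row <= col + n_backtrace) & (col <= row + n_forwardtrace)   # band extended forward

for name, m in [("causal", causal), ("causal + backtrace", causal_bt), ("back + forward", bt_ft)]:
    print(name)
    print(tf.cast(m, tf.int32).numpy())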

ref:
TFA (TensorFlow Addons) MHA implementation: https://github.com/tensorflow/addons/blob/v0.15.0/tensorflow_addons/layers/multihead_attention.py#L23-L298

1. Attention Mask or Causal Mask

The causal flag selects whether a plain attention mask or a causal mask is generated:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from tensorflow.keras.layers import Layer, Masking
import tensorflow as tf

class AttentionMask(Layer):
    """
    Computes an attention mask.
    """

    def __init__(self, causal, mask_value=-1e9):
        """
        Argument/s:
            causal - causal attention mask flag.
            mask_value - value used to mask components that aren't to be attended
                to (typically -1e9).
        """
        super(AttentionMask, self).__init__()
        if not isinstance(mask_value, float):
            raise ValueError("Mask value must be a float.")
        self.causal = causal
        self.mask_value = mask_value

    def call(self, inp):
        """
        Compute attention mask.

        Argument/s:
            inp - used to compute the sequence mask.

        Returns:
            att_mask - additive attention mask (0 where attention is allowed,
                mask_value where it is not).
            seq_mask - sequence mask cast to float32.
        """
        batch_size = tf.shape(inp)[0]
        max_seq_len = tf.shape(inp)[1]
        # True for frames that are not all-zero padding.
        flat_seq_mask = Masking(mask_value=0.0).compute_mask(inp)
        # Combine the per-frame mask over both the query and key axes.
        seq_mask = self.merge_masks(tf.expand_dims(flat_seq_mask, axis=1), tf.expand_dims(flat_seq_mask, axis=2))
        ### HERE !!! ###
        causal_mask = self.lower_triangular_mask([1, max_seq_len, max_seq_len]) if self.causal else None
        ################
        logical_mask = self.merge_masks(causal_mask, seq_mask)
        unmasked = tf.zeros([batch_size, max_seq_len, max_seq_len])
        masked = tf.fill([batch_size, max_seq_len, max_seq_len], self.mask_value)
        att_mask = tf.where(logical_mask, unmasked, masked)
        seq_mask = tf.cast(seq_mask, tf.float32)
        return att_mask, seq_mask

    def lower_triangular_mask(self, shape):
        """
        Creates a lower-triangular boolean mask over the last 2 dimensions.

        Argument/s:
            shape - shape of mask.

        Returns:
            Causal mask.
        """
        row_index = tf.math.cumsum(
            tf.ones(shape=shape, dtype=tf.int32, name="row"), axis=-2)
        col_index = tf.math.cumsum(
            tf.ones(shape=shape, dtype=tf.int32, name="col"), axis=-1)
        # A query position (row) may attend to key positions (col) at or before it.
        return tf.math.greater_equal(row_index, col_index)

    def merge_masks(self, x, y):
        """
        Merges a sequence mask and a causal mask to make an attention mask.

        Argument/s:
            x - mask.
            y - mask.

        Returns:
            Attention mask.
        """
        if x is None: return y
        if y is None: return x
        return tf.math.logical_and(x, y)

Test:

if __name__ == '__main__':
    input = tf.ones([64, 526, 40])

    attention_mask = AttentionMask(causal=0)(input)
    causal_mask = AttentionMask(causal=1)(input)
    print('done')

Results (screenshots omitted):

attention mask (causal=0): all positions are allowed, so the additive mask is all zeros.

causal mask (causal=1): zeros in the lower triangle, with the upper triangle filled with the mask value -1e9.
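
As a rough illustration of how this additive mask would typically be used (a sketch with toy shapes, not from the original post), it is simply added to the raw attention scores before the softmax, so masked positions receive near-zero weight:

import tensorflow as tf

# Illustrative only: hypothetical query/key tensors with toy shapes.
q = tf.random.normal([2, 10, 40])
k = tf.random.normal([2, 10, 40])

att_mask, seq_mask = AttentionMask(causal=True)(q)           # att_mask: [2, 10, 10]
scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(40.0)   # raw attention scores
scores += att_mask                                           # masked positions become ~-1e9
weights = tf.nn.softmax(scores, axis=-1)                     # ~0 weight where masked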

2. Causal Mask (with n_backtrace)

This is the causal mask with n_backtrace; it inherits from the AttentionMask class above:

from tensorflow.keras.layers import Masking
import tensorflow as tf

from AttentionMask import AttentionMask

class AttentionMask_Causal_Backtrace(AttentionMask):
    """
    Computes an attention mask appropriate for tf.keras.layers.MultiHeadAttention.
    """

    def __init__(self, causal, n_backtrace=None):
        """
        Argument/s:
            causal - causal attention mask flag.
            n_backtrace - (int) maximum number of past frames each position may attend to.
        """
        super().__init__(causal)
        self.n_backtrace = n_backtrace

    def call(self, inp):
        """
        Compute attention mask.

        Argument/s:
            inp - used to compute the sequence mask.

        Returns:
            Attention mask of shape [batch_size, 1, max_seq_len, max_seq_len]
            (1.0 where attention is allowed, 0.0 where it is not).
        """
        batch_size = tf.shape(inp)[0]
        max_seq_len = tf.shape(inp)[1]
        flat_seq_mask = Masking(mask_value=0.0).compute_mask(inp)
        seq_mask = self.merge_masks(tf.expand_dims(flat_seq_mask, axis=1), tf.expand_dims(flat_seq_mask, axis=2))
        ### HERE !!! ###
        causal_mask = self.lower_triangular_mask([batch_size, max_seq_len, max_seq_len]) if self.causal else None
        bt_mask = self.backtrace_mask([1, max_seq_len, max_seq_len]) \
            if self.causal and self.n_backtrace else None
        ################
        logical_mask = self.merge_masks(causal_mask, seq_mask)
        logical_mask = self.merge_masks(logical_mask, bt_mask)
        att_mask = tf.cast(logical_mask, tf.float32)
        # Add a singleton head axis so the mask broadcasts over attention heads.
        att_mask = tf.reshape(att_mask, [batch_size, 1, max_seq_len, max_seq_len])
        return att_mask

    def backtrace_mask(self, shape):
        """
        Creates a boolean band mask over the last 2 dimensions that limits each
        query position to the previous n_backtrace key positions.

        Argument/s:
            shape - shape of mask.

        Returns:
            Backtrace mask.
        """
        row_index = tf.math.cumsum(
            tf.ones(shape=shape, dtype=tf.int32, name="row"), axis=-2)
        col_index = tf.math.cumsum(
            tf.ones(shape=shape, dtype=tf.int32, name="col"), axis=-1)
        # Allow attention only where row - col <= n_backtrace.
        return tf.math.less_equal(row_index, col_index + self.n_backtrace)

Test:

if __name__ == '__main__':
    input = tf.ones([64, 526, 40])

    causal_mask = AttentionMask_Causal_Backtrace(causal=1, n_backtrace=None)(input)
    causal_mask_backtrace = AttentionMask_Causal_Backtrace(causal=1, n_backtrace=50)(input)
    print('done')

Results (screenshots omitted):

causal_mask (n_backtrace=None): the plain lower-triangular 0/1 mask.

causal_mask_backtrace (n_backtrace=50): a diagonal band of 1s in which each frame sees itself and at most the previous 50 frames.

Test case 2:

causal_mask_backtrace = AttentionMask_Causal_Backtrace(causal=1, n_backtrace=5)(input)

(screenshot omitted)
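
Because this mask has shape [batch, 1, T, T] with 1.0 marking allowed positions, it should be usable directly as the attention_mask argument of tf.keras.layers.MultiHeadAttention, which broadcasts the singleton axis over the heads. A minimal sketch with toy shapes (not from the original post):

import tensorflow as tf

x = tf.random.normal([2, 10, 40])                                   # [batch, time, features]
mask = AttentionMask_Causal_Backtrace(causal=1, n_backtrace=3)(x)   # [2, 1, 10, 10]

mha = tf.keras.layers.MultiHeadAttention(num_heads=4, key_dim=16)
out, attn = mha(query=x, value=x, attention_mask=mask,
                return_attention_scores=True)
print(out.shape, attn.shape)   # (2, 10, 40) (2, 4, 10, 10)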

3. Attention Mask with Backtrace and Forwardtrace

This mask lets each frame attend to at most n_backtrace past frames and n_forwardtrace future frames; it also inherits from AttentionMask:

from tensorflow.keras.layers import Masking
import tensorflow as tf

from AttentionMask import AttentionMask

class AttentionMask_Backtrace_Forwardtrace(AttentionMask):
    """
    Computes an attention mask appropriate for tf.keras.layers.MultiHeadAttention.
    """

    def __init__(self, causal, n_backtrace=None, n_forwardtrace=None):
        """
        Argument/s:
            causal - causal attention mask flag.
            n_backtrace - (int) maximum number of past frames each position may attend to.
            n_forwardtrace - (int) maximum number of future frames each position may attend to.
        """
        super().__init__(causal)
        self.n_backtrace = n_backtrace
        self.n_forwardtrace = n_forwardtrace

    def call(self, inp):
        """
        Compute attention mask.

        Argument/s:
            inp - used to compute the sequence mask.

        Returns:
            Attention mask of shape [batch_size, 1, max_seq_len, max_seq_len].
        """
        batch_size = tf.shape(inp)[0]
        max_seq_len = tf.shape(inp)[1]
        flat_seq_mask = Masking(mask_value=0.0).compute_mask(inp)
        seq_mask = self.merge_masks(tf.expand_dims(flat_seq_mask, axis=1), tf.expand_dims(flat_seq_mask, axis=2))
        ### HERE !!! ###
        bt_ft_mask = self.backtrace_forwardtrace_mask([1, max_seq_len, max_seq_len]) \
            if self.n_backtrace and self.n_forwardtrace else None
        ################
        logical_mask = self.merge_masks(bt_ft_mask, seq_mask)
        att_mask = tf.cast(logical_mask, tf.float32)
        # Add a singleton head axis so the mask broadcasts over attention heads.
        att_mask = tf.reshape(att_mask, [batch_size, 1, max_seq_len, max_seq_len])
        return att_mask

    def backtrace_forwardtrace_mask(self, shape):
        """
        Creates a boolean band mask over the last 2 dimensions that limits each
        query position to the previous n_backtrace and the next n_forwardtrace
        key positions.

        Argument/s:
            shape - shape of mask.

        Returns:
            Backtrace/forwardtrace mask.
        """
        row_index = tf.math.cumsum(
            tf.ones(shape=shape, dtype=tf.int32, name="row"), axis=-2)
        col_index = tf.math.cumsum(
            tf.ones(shape=shape, dtype=tf.int32, name="col"), axis=-1)
        # row - col <= n_backtrace: look back at most n_backtrace frames.
        bt_mask = tf.math.less_equal(row_index, col_index + self.n_backtrace)
        # col - row <= n_forwardtrace: look forward at most n_forwardtrace frames.
        ft_mask = tf.math.greater_equal(row_index + self.n_forwardtrace, col_index)
        bt_ft_mask = self.merge_masks(bt_mask, ft_mask)
        return bt_ft_mask

Test:

if __name__ == '__main__':
    input = tf.ones([64, 526, 40])

    bt_ft_mask = AttentionMask_Backtrace_Forwardtrace(causal=0, n_backtrace=2, n_forwardtrace=5)(input)
    print('done')

Results (screenshots omitted):

bt_ft_mask: a band of 1s around the main diagonal, covering 2 frames back and 5 frames forward of each position.
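
As a quick sanity check (illustrative only, with toy shapes), one can slice a single query position of the returned mask and verify that it covers exactly the window [t - n_backtrace, t + n_forwardtrace]:

import tensorflow as tf

x = tf.ones([1, 12, 40])
mask = AttentionMask_Backtrace_Forwardtrace(causal=0, n_backtrace=2, n_forwardtrace=5)(x)

t = 6                          # arbitrary query position
print(mask[0, 0, t].numpy())   # 1.0 exactly at key positions 4..11, i.e. t-2 .. t+5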

4. Customized Mask

A mask can also be built for arbitrary custom requirements. The example below only lets the last trace frames of the sequence attend to each other:

from tensorflow.keras.layers import Masking
import tensorflow as tf

from AttentionMask import AttentionMask

class AttentionMask_Customization(AttentionMask):
    """
    Computes an attention mask appropriate for tf.keras.layers.MultiHeadAttention.
    """

    def __init__(self, causal, trace=None):
        """
        Argument/s:
            causal - causal attention mask flag.
            trace - (int) number of trailing frames that may attend to each other.
        """
        super().__init__(causal)
        self.trace = trace

    def call(self, inp):
        """
        Compute attention mask.

        Argument/s:
            inp - used to compute the sequence mask.

        Returns:
            Attention mask of shape [batch_size, 1, max_seq_len, max_seq_len].
        """
        batch_size = tf.shape(inp)[0]
        max_seq_len = tf.shape(inp)[1]
        flat_seq_mask = Masking(mask_value=0.0).compute_mask(inp)
        seq_mask = self.merge_masks(tf.expand_dims(flat_seq_mask, axis=1), tf.expand_dims(flat_seq_mask, axis=2))
        ### HERE !!! ###
        customized_mask = self.customized_mask(batch_size, max_seq_len, self.trace)
        ################
        logical_mask = self.merge_masks(customized_mask, seq_mask)
        att_mask = tf.cast(logical_mask, tf.float32)
        att_mask = tf.reshape(att_mask, [batch_size, 1, max_seq_len, max_seq_len])
        return att_mask

    @tf.function
    def customized_mask(self, batchsize, max_length, trace):
        """
        Creates a boolean mask in which only the last `trace` frames may attend
        to each other.
        """
        # Block of ones covering the last `trace` frames.
        mask = tf.ones(shape=[batchsize, trace, trace], dtype=tf.int32, name="row")
        # Zero-pad on the top/left so the block sits in the bottom-right corner.
        # (Kept as a tensor so this also works when max_length is symbolic.)
        shape_pad = max_length - trace
        mask = tf.pad(mask, paddings=[[0, 0], [shape_pad, 0], [shape_pad, 0]])
        mask = tf.cast(mask, dtype=bool)
        return mask

Test:

if __name__ == '__main__':
    input = tf.ones([64, 526, 40])

    customized_mask = AttentionMask_Customization(causal=1, trace=5)(input)
    print('done')

Results (screenshot omitted).
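
Since the screenshot is omitted here, a small illustrative check with toy shapes (not from the original post) shows the resulting pattern: only the bottom-right trace x trace block of the mask is 1.

import tensorflow as tf

x = tf.ones([1, 8, 40])
mask = AttentionMask_Customization(causal=1, trace=3)(x)   # [1, 1, 8, 8]

# Only the bottom-right 3x3 block (the last `trace` frames) is 1.0.
print(mask[0, 0].numpy())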
