Look into Hugging Face MambaModel

Library Link: https://huggingface.co/docs/transformers/main/en/model_doc/mamba#transformers.MambaModel

By looking into Hugging Face's transformers.MambaModel, we can learn the structure and implementation of Mamba.

MambaModel

Setting aside the parts that handle different control flows and custom configurations, the core of MambaModel.forward is as follows:

        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        for mixer_block in self.layers:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask
                )
            else:
                hidden_states = mixer_block(
                    hidden_states,
                    cache_params=cache_params,
                    cache_position=cache_position,
                    attention_mask=attention_mask,
                )

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = self.norm_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)

        return MambaOutput(
            last_hidden_state=hidden_states,
            cache_params=cache_params if use_cache else None,
            hidden_states=all_hidden_states,
        )

The main computation is a for loop that runs the hidden_states through a sequence of mixer_block modules. self.layers is defined as

self.layers = nn.ModuleList([MambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])

which is a list of MambaBlock modules. The structure of MambaBlock will be introduced later.
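To see these pieces in action, here is a minimal sketch that instantiates MambaModel with a deliberately tiny MambaConfig and runs a forward pass. The configuration values are arbitrary (chosen only to keep the example fast), and it assumes a transformers version that ships Mamba support; without the optional CUDA kernels it simply falls back to the slow path.

import torch
from transformers import MambaConfig, MambaModel

# Tiny, arbitrary configuration just for illustration (not a real checkpoint).
config = MambaConfig(vocab_size=1000, hidden_size=64, num_hidden_layers=2)
model = MambaModel(config)

print(len(model.layers), type(model.layers[0]).__name__)  # 2 MambaBlock

input_ids = torch.randint(0, config.vocab_size, (2, 5))   # (batch, seq_len)
outputs = model(input_ids)
print(outputs.last_hidden_state.shape)                    # torch.Size([2, 5, 64])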

MambaRMSNorm

After the iteration, we can see that the hidden_states is passed through a norm layer, defined as

self.norm_f = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

RMSNorm is widely used in LLMs because of its simplicity. The norm layer is defined as

class MambaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        MambaRMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{self.weight.shape[0]}, eps={self.variance_epsilon}"

The corresponding formula is
$$y_i = \frac{x_i}{\sqrt{\frac{1}{n}\sum_{i=1}^n x_i^2+\epsilon}} \times \gamma,$$
where $\gamma$ is a learnable parameter. Compared with classic layer normalization, $y_i = \frac{x_i - E(x)}{\sqrt{Var(x)+\epsilon}}\times \gamma + \beta$, it simplifies the computation and drops the trainable parameter $\beta$ while maintaining comparable performance. That is why RMSNorm has become so popular in today's LLMs.
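As a quick sanity check of the formula, here is a small sketch in plain PyTorch (with arbitrary tensor sizes) that computes the RMS denominator by hand; it should reproduce what MambaRMSNorm.forward does.

import torch

x = torch.randn(2, 5, 8)                       # (batch, seq_len, hidden_size)
gamma = torch.ones(8)                          # the learnable weight, initialized to ones
eps = 1e-6

rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
y_manual = x / rms * gamma                     # y_i = x_i / sqrt(mean(x^2) + eps) * gamma

# This should match MambaRMSNorm(8)(x) from the class above (up to dtype casting):
# torch.allclose(MambaRMSNorm(8)(x), y_manual)
print(y_manual.shape)                          # torch.Size([2, 5, 8])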

MambaBlock

Now we look into the fundamental building block of Mamba, the MambaBlock. It is simple and straightforward: an RMSNorm followed by a Mamba mixer, wrapped in a residual connection.

class MambaBlock(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.residual_in_fp32 = config.residual_in_fp32
        self.norm = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.mixer = MambaMixer(config, layer_idx=layer_idx)

    def forward(
        self,
        hidden_states,
        cache_params: Optional[MambaCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        residual = hidden_states
        hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
        if self.residual_in_fp32:
            residual = residual.to(torch.float32)

        hidden_states = self.mixer(
            hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask
        )
        hidden_states = residual + hidden_states
        return hidden_states

MambaMixer

Here comes the core of Mamba: the MambaMixer. We will go through it step by step.

To begin with, let's write the shape of input_states as $(B, L, D)$, where $B$ is the batch size, $L$ is the (padded) sequence length, and $D$ is the hidden embedding size. Mamba also has a special intermediate size $N$, which is used to enhance the representational capacity of the model. By default, $N = 2D$, as can be seen in the source code of MambaConfig.

Input projection

self.in_proj projects the input states into the intermediate states hidden_states and the gate signal gate, so the output dimension of self.in_proj is $2N$. gate will be used as the gating signal to control, or say select, the output.

# projection of the input hidden states
self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
# 1. Gated MLP's linear projection
projected_states = self.in_proj(input_states).transpose(1, 2)  # (B,L,D)->(B,L,2N)->(B,2N,L)
hidden_states, gate = projected_states.chunk(2, dim=1)  # (B,2N,L)->(B,N,L),(B,N,L)

Obviously, the shapes of hidden_states and gate are both $(B, N, L)$.
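Here is a shape-only sketch of this projection with toy sizes $B=2$, $L=5$, $D=8$, $N=2D=16$ (arbitrary numbers, not the library defaults):

import torch
import torch.nn as nn

B, L, D = 2, 5, 8
N = 2 * D                                                  # intermediate_size = expand * hidden_size
in_proj = nn.Linear(D, 2 * N, bias=False)

input_states = torch.randn(B, L, D)
projected_states = in_proj(input_states).transpose(1, 2)   # (B,L,D)->(B,L,2N)->(B,2N,L)
hidden_states, gate = projected_states.chunk(2, dim=1)     # (B,2N,L)->(B,N,L),(B,N,L)
print(hidden_states.shape, gate.shape)                     # torch.Size([2, 16, 5]) twice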

Convolution sequence transformation

self.conv1d = nn.Conv1d(
            in_channels=self.intermediate_size,
            out_channels=self.intermediate_size,
            bias=config.use_conv_bias,
            kernel_size=config.conv_kernel,
            groups=self.intermediate_size,
            padding=config.conv_kernel - 1,
        )

The convolution layer is defined as shown above. It is a 1D depth-wise convolution whose out_channels and groups are both set to $N$, so there is basically one convolution kernel for each of the $N$ channels in the embedding. The default kernel_size, denoted $K$, is 4, so each token aggregates information from itself and its 3 preceding tokens. The padding is set to $K-1$ to make sure each token only aggregates information from tokens before itself.

hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])  # (B,N,L)

The output of the convolution layer is then activated; the default activation is SiLU. Note that the padding argument adds padding on both sides of the input, so hidden_states is first padded to $(B, N, L+2(K-1))$ and then convolved down to $(B, N, L+K-1)$, where the last $K-1$ positions carry redundant information from the last $K-1$ tokens of the sequence, so we remove them with [..., :seq_len].
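The causal behavior is easy to verify with a small sketch (toy sizes again, $N=16$, $K=4$, $L=5$, not tied to any real checkpoint):

import torch
import torch.nn as nn

B, N, L, K = 2, 16, 5, 4
conv1d = nn.Conv1d(N, N, kernel_size=K, groups=N, padding=K - 1)  # depth-wise, padded on both sides
act = nn.SiLU()

hidden_states = torch.randn(B, N, L)
conv_out = conv1d(hidden_states)          # (B, N, L + K - 1)
hidden_states = act(conv_out[..., :L])    # drop the trailing K-1 positions -> (B, N, L)
print(conv_out.shape, hidden_states.shape)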

Compute parameters of SSM

# selective projection used to make dt, B and C input dependent
self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
# time step projection (discretization)
self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)

By default, the size of the state-space latents, ssm_state_size (denoted $S$), is set to 16, and the rank of the discretization projection matrix, time_step_rank, is set to $\lceil \frac{D}{16} \rceil$.

# 3.a. Selection:  [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))  # (B,L,N)->(B,L,D/16+S+S)
time_step, B, C = torch.split(
    ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
)  # (B,L,D/16+S+S)->(B,L,D/16),(B,L,S),(B,L,S)
discrete_time_step = self.dt_proj(time_step)                                    # (B,L,D/16)->(B,L,N)
discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) # (B,L,N)->(B,N,L)

In 3.a, the model computes the $\Delta$, $B$, and $C$ of the Mamba formulation: $B$ and $C$ are derived from hidden_states via a linear layer, while $\Delta$ is first projected down to the given rank, then projected back up to $N$, and finally passed through a softplus activation.
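The shape flow of step 3.a can be sketched with toy sizes ($D=8$, $N=16$, $S=4$, time_step_rank $R=2$; all arbitrary):

import torch
import torch.nn as nn

Bsz, L, D, N, S, R = 2, 5, 8, 16, 4, 2
x_proj = nn.Linear(N, R + 2 * S, bias=False)   # produces time_step, B, C
dt_proj = nn.Linear(R, N, bias=True)           # projects the time step back up to N

hidden_states = torch.randn(Bsz, N, L)
ssm_parameters = x_proj(hidden_states.transpose(1, 2))                            # (B,L,N)->(B,L,R+2S)
time_step, B, C = torch.split(ssm_parameters, [R, S, S], dim=-1)                  # (B,L,R),(B,L,S),(B,L,S)
discrete_time_step = nn.functional.softplus(dt_proj(time_step)).transpose(1, 2)   # (B,N,L)
print(B.shape, C.shape, discrete_time_step.shape)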

# S4D real initialization. These are not discretized!
# The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]  # (S,)->(1,S)
A = A.expand(self.intermediate_size, -1).contiguous()  # (1,S)->(N,S)

self.A_log = nn.Parameter(torch.log(A))
self.D = nn.Parameter(torch.ones(self.intermediate_size))  # (N,)

In contrast, $A$ and $D$ are trainable parameters of each mixer; $A$ is discretized in 3.b.
You might also notice that $A$ is actually stored as A_log, but why? This guarantees that the discrete $A$, or say $\overline{A}$, lies within $(0, 1)$. Since the value domain of A_log is $\mathbb{R}$, the resulting $A$ (computed as -exp(A_log)) lies in $(-\infty, 0)$. Because $\Delta > 0$, as it is activated by softplus, $\overline{A} = e^{\Delta A}$ is within $(0, 1)$.
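A tiny numeric check of this claim (with arbitrary values): whatever real values A_log takes, the discretized $\overline{A}$ ends up strictly between 0 and 1.

import torch
import torch.nn.functional as F

A_log = torch.tensor([-2.0, 0.0, 3.0])               # any real values
A = -torch.exp(A_log)                                # strictly negative
delta = F.softplus(torch.tensor([-1.0, 0.5, 4.0]))   # strictly positive

A_bar = torch.exp(delta * A)                         # discretized A
print(A_bar)                                         # every entry lies in (0, 1)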

# 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
A = -torch.exp(self.A_log.float())                                              # (N,S)
discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None]) # (N,S)->(1,N,1,S)*(B,N,L,1)->(B,N,L,S)
discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()       # (B,N,L,1)*(B,1,L,S)->(B,N,L,S)
deltaB_u = discrete_B * hidden_states[:, :, :, None].float()                    # (B,N,L,S)*(B,N,L,1)->(B,N,L,S)

In 3.b, $A$ and $B$ are both discretized by $\Delta$, with $\overline{A}=e^{\Delta A}$ and $\overline{B}=\Delta B$. deltaB_u is $\overline{B}$ multiplied by the input hidden_states $X$, and represents the information coming in from each input token.

Now we move on to the recurrence in 3.c. Keep the formula in mind:
$$s_i = \overline{A}_i s_{i-1} + \overline{B}_i x_i, \qquad o_i = C_i s_i$$

ssm_state is initialized as $\mathbf{0}$ or restored from the cache in step 2, the convolution sequence transformation.

# 3.c perform the recurrence y ← SSM(A, B, C)(x)
scan_outputs = []
for i in range(seq_len):
    ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]      # (B,N,S)*(B,N,S)+(B,N,S)
    scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # (B,N,S)@(B,S,1)->(B,N,1)
    scan_outputs.append(scan_output[:, :, 0])  # (B,N)

Generally speaking, in the for loop the SSM iterates from left to right and aggregates information from the current and previous tokens, just like an RNN. Following the formula above, $\overline{A}_i$ controls how much information is kept from the previous SSM state, and $\overline{B}_i$ controls the information flow from the current token $x_i$. Then $C_i$ produces the output based on the current SSM state.

However, if we look at the details, there are quite a few tricks in the shapes of these tensors, which reflect the mathematical mechanism of Mamba. The recurrence actually runs a separate SSM for each of the $N$ intermediate channels: $\overline{A}$ and $\overline{B}$ are both channel-dependent, but $C$ is not, which is quite interesting.
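To make the recurrence concrete, here is a standalone sketch of step 3.c written as a plain function, following the shape conventions used above (a readability re-implementation, not the library code):

import torch

def naive_selective_scan(discrete_A, deltaB_u, C, dtype=torch.float32):
    """discrete_A, deltaB_u: (B, N, L, S); C: (B, L, S) -> scan output of shape (B, N, L)."""
    Bsz, N, L, S = discrete_A.shape
    ssm_state = torch.zeros(Bsz, N, S)                 # s_0 = 0 (or restored from the cache)
    scan_outputs = []
    for i in range(L):
        # s_i = A_bar_i * s_{i-1} + B_bar_i * x_i
        ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]
        # o_i = C_i s_i: (B,N,S) @ (B,S,1) -> (B,N,1)
        scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))
        scan_outputs.append(scan_output[:, :, 0])
    return torch.stack(scan_outputs, dim=-1)           # (B, N, L)

out = naive_selective_scan(torch.rand(2, 16, 5, 4), torch.randn(2, 16, 5, 4), torch.randn(2, 5, 4))
print(out.shape)                                       # torch.Size([2, 16, 5])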

After collecting the token-by-token outputs, we stack them together and add a skip connection scaled by $D$. We can also notice that $D$, like $C$, is much simpler in shape than $\overline{A}$ and $\overline{B}$ ($D$ is just a learned vector over the $N$ channels); the parameters that control the output all seem to be simplified.

Then, gate is activated and multiplied with scan_output, acting like another selection process.

scan_output = torch.stack(scan_outputs, dim=-1)                                # (B,N,L)
scan_output = scan_output + (hidden_states * self.D[None, :, None])  # (B,N,L)+(B,N,L)*(1,N,1)->(B,N,L)
scan_output = (scan_output * self.act(gate))  # (B,N,L)*(B,N,L)

P.S. I found a typo in one of the tensor-shape comments, fixed it, and submitted a PR. Let's see whether I can make a contribution to the great Hugging Face Transformers!

Final projection

self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
# 4. Final linear projection
contextualized_states = self.out_proj(scan_output.transpose(1, 2))  # (B,N,L)->(B,L,N)->(B,L,D)
return contextualized_states

Finally, scan_output is passed through a linear projection that maps it from $N$ back to $D$.
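Putting the whole mixer together, the following hedged end-to-end check confirms that it maps $(B, L, D)$ back to $(B, L, D)$. It assumes MambaMixer can be imported from transformers.models.mamba.modeling_mamba in your installed version, and the tiny config values are arbitrary.

import torch
from transformers import MambaConfig
from transformers.models.mamba.modeling_mamba import MambaMixer

config = MambaConfig(hidden_size=64, state_size=16)    # tiny, arbitrary settings
mixer = MambaMixer(config, layer_idx=0)                # falls back to slow_forward without the CUDA kernels

hidden_states = torch.randn(2, 5, config.hidden_size)  # (B, L, D)
out = mixer(hidden_states)
print(out.shape)                                       # torch.Size([2, 5, 64])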

Parameter illustration

Obviously, there are a lot of parameters in the Mamba mixer, and it is often confusing to figure out their relationships. Hopefully the following illustration helps; all trainable parameters are marked in red.

[Figure: relationships among the Mamba mixer's parameters; trainable parameters marked in red]

We can see that A_log, D, the convolution layer, and all the projection layers are trainable parameters of the Mamba mixer. $\Delta$, $B$, and $C$ are derived from input_states. $\overline{A}$ and $\overline{B}$ are discretized by $\Delta$, which adjusts the memory rate according to input_states: the larger $\Delta$ is, the closer $\overline{A}$ gets to 0 and the larger $\overline{B}$ becomes, which means the ssm_state remembers more information about the current token and forgets more information from previous tokens.

Additionally, we can see that most of the operations here are pointwise multiplications. The only matrix multiplication is the tiny (S,) @ (S,1) product per batch element and channel, where $S$ is a relatively small hyper-parameter. The overall computational complexity is linear in the sequence length.

All of the code snippets above are from the slow_forward function. Maybe we will look into its accelerated counterpart, cuda_kernels_forward, someday.
