Library Link: https://huggingface.co/docs/transformers/main/en/model_doc/mamba#transformers.MambaModel
By looking into the MambaModel class in Hugging Face's Transformers library, we can learn the structure and implementation of Mamba.
MambaModel
Setting aside the parts that handle special control flow and custom configurations, the main body of MambaModel.forward is as follows:
hidden_states = inputs_embeds
all_hidden_states = () if output_hidden_states else None
for mixer_block in self.layers:
    if self.gradient_checkpointing and self.training:
        hidden_states = self._gradient_checkpointing_func(
            mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask
        )
    else:
        hidden_states = mixer_block(
            hidden_states,
            cache_params=cache_params,
            cache_position=cache_position,
            attention_mask=attention_mask,
        )

    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

hidden_states = self.norm_f(hidden_states)

if output_hidden_states:
    all_hidden_states = all_hidden_states + (hidden_states,)

if not return_dict:
    return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)

return MambaOutput(
    last_hidden_state=hidden_states,
    cache_params=cache_params if use_cache else None,
    hidden_states=all_hidden_states,
)
The main computation is a for loop, which runs the mixer_blocks to process the hidden_states. self.layers is defined as
self.layers = nn.ModuleList([MambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
which is a list of MambaBlocks. The structure of MambaBlock will be introduced later.
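Before diving into the blocks, here is a minimal usage sketch (mine, not from the library docs) showing the model-level behavior discussed above; it assumes the state-spaces/mamba-130m-hf checkpoint and may need adjusting for other versions:

import torch
from transformers import AutoTokenizer, MambaModel

# Assumed example checkpoint; any Mamba checkpoint with a matching tokenizer should work
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = MambaModel.from_pretrained("state-spaces/mamba-130m-hf")

inputs = tokenizer("Mamba is a state space model.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

print(outputs.last_hidden_state.shape)  # (batch, seq_len, hidden_size)
print(len(outputs.hidden_states))       # one entry per MambaBlock plus the final normed states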
MambaRMSNorm
After the iteration, we can see that hidden_states is passed through a norm layer, defined as
self.norm_f = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
RMSNorm is a widely used normalization in LLMs because of its simplicity. The norm layer is defined as
class MambaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        MambaRMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{self.weight.shape[0]}, eps={self.variance_epsilon}"
The corresponding formula is

$$y_i = \frac{x_i}{\sqrt{\frac{1}{n}\sum_{j=1}^n x_j^2+\epsilon}} \times \gamma,$$
where $\gamma$ is a learnable parameter. Compared with the classic layer normalization

$$y_i = \frac{x_i - E(x)}{\sqrt{\mathrm{Var}(x)+\epsilon}}\times \gamma + \beta,$$

it simplifies the computation and removes the trainable parameter $\beta$, while maintaining comparable performance. Therefore, RMSNorm has become popular in today's LLMs.
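To make the formula concrete, here is a small self-contained check (my own sketch, assuming the MambaRMSNorm class above is in scope) that evaluates the RMSNorm expression directly and compares it with the module:

import torch

torch.manual_seed(0)
x = torch.randn(2, 5, 8)                      # (batch, seq_len, hidden_size)
norm = MambaRMSNorm(hidden_size=8, eps=1e-6)  # gamma (norm.weight) is initialized to ones

# Direct evaluation of y_i = x_i / sqrt(mean_j(x_j^2) + eps) * gamma
rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
manual = x / rms * norm.weight

print(torch.allclose(norm(x), manual, atol=1e-6))  # True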
MambaBlock
Now, we look into the fundamental building unit of Mamba, the MambaBlock. It is simple and straightforward: a Mamba block is just an RMSNorm followed by a Mamba mixer, wrapped with a residual connection.
class MambaBlock(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.residual_in_fp32 = config.residual_in_fp32
        self.norm = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.mixer = MambaMixer(config, layer_idx=layer_idx)

    def forward(
        self,
        hidden_states,
        cache_params: Optional[MambaCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        residual = hidden_states
        hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
        if self.residual_in_fp32:
            residual = residual.to(torch.float32)

        hidden_states = self.mixer(
            hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask
        )
        hidden_states = residual + hidden_states
        return hidden_states
MambaMixer
Here comes the most important core of Mamba: the MambaMixer. We will go through it step by step.
To begin with, let's denote the shape of input_states as $(B, L, D)$, where $B$ represents the batch size, $L$ represents the padded sequence length and $D$ represents the hidden embedding size. In Mamba, there is also a special intermediate size $N$, which is used to enhance the representational ability of the model. By default, $N=2D$, as can be seen in the source code of MambaConfig.
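For reference, these sizes can be read straight off the configuration object. A minimal sketch, assuming the default MambaConfig in Transformers (attribute names are as I recall them from the source and may vary across versions):

from transformers import MambaConfig

config = MambaConfig()              # default configuration
print(config.hidden_size)           # D
print(config.intermediate_size)     # N = expand * D, i.e. 2D by default
print(config.state_size)            # S, the SSM state size (16 by default)
print(config.conv_kernel)           # K, the depth-wise convolution kernel size (4 by default)
print(config.time_step_rank)        # rank used for the delta (time step) projection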
Input projection
self.in_proj is used to project the input states into the intermediate states hidden_states and the gate signal gate, so the output dimension of self.in_proj is $2N$. gate will later be used as a gating signal to control, or rather to select, the output.
# projection of the input hidden states
self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
# 1. Gated MLP's linear projection
projected_states = self.in_proj(input_states).transpose(1, 2) # (B,L,D)->(B,L,2N)->(B,2N,L)
hidden_states, gate = projected_states.chunk(2, dim=1) # (B,2N,L)->(B,N,L),(B,N,L)
Obviously, the shapes of hidden_states and gate are both $(B, N, L)$.
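As a quick shape check, the projection and split can be reproduced on dummy tensors (toy sizes of my own choosing, mirroring the snippet above):

import torch
import torch.nn as nn

B, L, D = 2, 10, 8
N = 2 * D                                          # intermediate size

in_proj = nn.Linear(D, 2 * N, bias=False)          # mirrors self.in_proj
input_states = torch.randn(B, L, D)

projected = in_proj(input_states).transpose(1, 2)  # (B,L,D)->(B,L,2N)->(B,2N,L)
hidden_states, gate = projected.chunk(2, dim=1)    # (B,2N,L)->(B,N,L),(B,N,L)

print(hidden_states.shape, gate.shape)             # torch.Size([2, 16, 10]) twice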
Convolution sequence transformation
self.conv1d = nn.Conv1d(
    in_channels=self.intermediate_size,
    out_channels=self.intermediate_size,
    bias=config.use_conv_bias,
    kernel_size=config.conv_kernel,
    groups=self.intermediate_size,
    padding=config.conv_kernel - 1,
)
The convolution layer is defined as shown above. It is a 1D depth-wise convolution whose out_channels and groups are both set to $N$, so there is essentially one convolution kernel per channel of the intermediate embedding, $N$ kernels in total. The default kernel_size, denoted $K$, is 4, so each token aggregates information from itself and its 3 preceding tokens. The padding is set to $K-1$ to make sure each token only aggregates information from tokens before itself.
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # (B,N,L)
The output of the convolution layer is then activated; the default activation is SiLU. Note that the padding argument adds padding on both sides of the input, so hidden_states is first padded to $(B, N, L+2(K-1))$ and then convolved down to $(B, N, L+K-1)$, where the last $K-1$ positions carry redundant information derived from the last $K-1$ original tokens, so we remove them with [..., :seq_len].
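To see why this slicing yields a causal convolution, here is a small standalone sketch (toy sizes of my own choosing) checking that output position t depends only on inputs at positions <= t:

import torch
import torch.nn as nn

B, N, L, K = 1, 4, 6, 4
conv1d = nn.Conv1d(N, N, kernel_size=K, groups=N, padding=K - 1, bias=False)

x = torch.randn(B, N, L)
y = conv1d(x)[..., :L]                  # keep only the first L positions

# Perturb the last input token: earlier output positions must not change.
x2 = x.clone()
x2[..., -1] += 10.0
y2 = conv1d(x2)[..., :L]

print(torch.allclose(y[..., :-1], y2[..., :-1]))  # True: outputs 0..L-2 are unaffected
print(torch.allclose(y[..., -1], y2[..., -1]))    # False: only the last output changes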
Compute parameters of SSM
# selective projection used to make dt, B and C input dependent
self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
# time step projection (discretization)
self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
By default, the size of the state-space latents, ssm_state_size, denoted $S$, is set to 16, and the rank of the discretization projection matrix, time_step_rank, is set to $D/16$ (rounded to an integer).
# 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))  # (B,L,N)->(B,L,D/16+S+S)
time_step, B, C = torch.split(
    ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
)  # (B,L,D/16+S+S)->(B,L,D/16),(B,L,S),(B,L,S)
discrete_time_step = self.dt_proj(time_step)  # (B,L,D/16)->(B,L,N)
discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2)  # (B,L,N)->(B,N,L)
In 3.a, the model computes the $\Delta$, $B$ and $C$ of the Mamba formulation: $B$ and $C$ are derived from hidden_states via a linear layer, while $\Delta$ is first projected down to the given rank, then projected back up to $N$, and finally passed through a softplus activation.
# S4D real initialization. These are not discretized!
# The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :] # (S,)->(1,S)
A = A.expand(self.intermediate_size, -1).contiguous() # (1,S)->(N,S)
self.A_log = nn.Parameter(torch.log(A))
self.D = nn.Parameter(torch.ones(self.intermediate_size)) # (N,)
In contrast, $A$ and $D$ are trainable parameters of each mixer. $A$ is discretized in 3.b, while $D$ is used later as a skip-connection weight.
You might also notice that $A$ is actually stored as A_log, but why? This guarantees that the discrete $A$, i.e. $\overline{A}$, lies within $(0,1)$. Since the value domain of A_log is $\mathbb{R}$ and $A = -e^{A_{\log}}$, the value domain of $A$ becomes $(-\infty, 0)$. Because $\Delta > 0$ (it is activated by softplus), $\overline{A} = e^{\Delta A}$ lies within $(0,1)$.
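A quick numerical sanity check of this argument (a standalone sketch, not library code):

import torch

torch.manual_seed(0)
A_log = torch.randn(5)                                  # any real values
A = -torch.exp(A_log)                                   # A is always negative
delta = torch.nn.functional.softplus(torch.randn(5))    # delta > 0 after softplus

A_bar = torch.exp(delta * A)                            # discretized A
print((A < 0).all())                                    # tensor(True)
print(((A_bar > 0) & (A_bar < 1)).all())                # tensor(True): A_bar lies in (0, 1)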
# 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
A = -torch.exp(self.A_log.float()) # (N,S)
discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None]) # (N,S)->(1,N,1,S)*(B,N,L,1)->(B,N,L,S)
discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float() # (B,N,L,1)*(B,1,L,S)->(B,N,L,S)
deltaB_u = discrete_B * hidden_states[:, :, :, None].float() # (B,N,L,S)*(B,N,L,1)->(B,N,L,S)
In 3.b, $A$ and $B$ are both discretized by $\Delta$, where $\overline{A} = e^{\Delta A}$ and $\overline{B} = \Delta B$, and deltaB_u is $\overline{B}$ multiplied by the input hidden_states $X$, representing the information coming from each input token.
Now, we come to the recurrence in 3.c. Keep the formula in mind:

$$s_i = \overline{A}_i s_{i-1} + \overline{B}_i x_i, \qquad o_i = C_i s_i.$$

ssm_state is initialized as $\mathbf{0}$ or restored from the cache back in step 2, the convolution sequence transformation.
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
scan_outputs = []
for i in range(seq_len):
    ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]  # (B,N,S)*(B,N,S)
    scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # (B,N,S)@(B,S,1)->(B,N,1)
    scan_outputs.append(scan_output[:, :, 0])  # (B,N)
Generally speaking, in the for loop the SSM iterates from left to right and aggregates information from the current and previous tokens, just like an RNN. Following the formula above, $\overline{A}_i$ controls how much information is kept from the previous SSM state and $\overline{B}_i$ controls the information flow from the current token $x_i$. Then $C_i$ produces the output based on the current SSM state.
However, if we look at the details, we can find a lot of tricks in the shapes of these tensors, which reflect the mathematical mechanism of Mamba. The recurrence is actually running a separate SSM for each of the $N$ intermediate dimensions: we can see that $\overline{A}$ and $\overline{B}$ are both $N$-dependent, but $C$ is not, which is quite interesting.
After collecting the token-by-token outputs, we stack them together and add a residual (skip) connection controlled by $D$. We can also notice that $D$ is $L$-independent; similar to $C$, the parameters that control the output all seem to be simplified.
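To make this per-channel view concrete, here is a small standalone reference implementation of the scan plus the D skip connection (my own sketch with toy shapes, mirroring the slow_forward logic rather than copying it):

import torch

def reference_selective_scan(discrete_A, deltaB_u, C, D, x):
    """Per-channel SSM: s_i = A_i * s_{i-1} + Bu_i, o_i = C_i s_i, then output o + D * x.

    Shapes: discrete_A, deltaB_u: (B, N, L, S); C: (B, L, S); D: (N,); x: (B, N, L).
    """
    batch, N, L, S = discrete_A.shape
    state = torch.zeros(batch, N, S)
    outputs = []
    for i in range(L):
        state = discrete_A[:, :, i, :] * state + deltaB_u[:, :, i, :]  # (B, N, S)
        outputs.append(torch.einsum("bns,bs->bn", state, C[:, i, :]))  # (B, N)
    y = torch.stack(outputs, dim=-1)                                   # (B, N, L)
    return y + x * D[None, :, None]                                    # D acts as a skip connection

batch, N, L, S = 2, 4, 6, 16
out = reference_selective_scan(
    torch.rand(batch, N, L, S), torch.randn(batch, N, L, S),
    torch.randn(batch, L, S), torch.ones(N), torch.randn(batch, N, L),
)
print(out.shape)  # torch.Size([2, 4, 6])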
Then, gate is activated and multiplied with the scan_output, acting like another selection process.
scan_output = torch.stack(scan_outputs, dim=-1) # (B,N,L)
scan_output = scan_output + (hidden_states * self.D[None, :, None]) # (B,N,L)+(B,N,L)*(1,N,1)->(B,N,L)
scan_output = (scan_output * self.act(gate)) # (B,N,L)*(B,N,L)
P.S. I found a typo in one of the tensor-shape comments, so I fixed it and submitted a PR. Let's see if I can make a contribution to the great Hugging Face Transformers!
Final projection
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
# 4. Final linear projection
contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # (B,N,L)->(B,L,N)->(B,L,D)
return contextualized_states
Finally, the scan_output is passed to a linear projection to project it from $N$ back to $D$.
Parameter illustration
Obviously, there are a lot of parameters in the Mamba mixer, and it is often confusing to figure out their relationships. Hopefully the following illustration can help; all trainable parameters are marked in red.
We can see that A_log, D, the convolution layer and all projection layers are trainable parameters of the Mamba mixer. $\Delta$, $B$ and $C$ are derived from input_states. $\overline{A}$ and $\overline{B}$ are discretized by $\Delta$ to adjust the memory rate according to input_states: the larger $\Delta$ is, the closer $\overline{A}$ gets to 0 and the larger $\overline{B}$ becomes, which means the ssm_state will remember more information from the current token and forget more information from previous tokens.
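In equation form, this intuition follows directly from the discretization definitions above: since $A<0$,

$$\overline{A} = e^{\Delta A} \longrightarrow 0 \quad \text{as } \Delta \to \infty, \qquad |\overline{B}| = \Delta\,|B| \ \text{grows linearly in } \Delta,$$

while as $\Delta \to 0^{+}$, $\overline{A} \to 1$ and $\overline{B} \to 0$, so the state is carried over almost unchanged and the current token is nearly ignored.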
Additionally, we can see that most of the operations here are pointwise multiplications. The only matrix multiplication is the $(S,) @ (S,1)$ product in the scan, where $S$ is a relatively small hyper-parameter. The overall computational complexity is linear in the sequence length.
All the code snippets are from the slow_forward function. Maybe we will look into its accelerated counterpart, cuda_kernels_forward, someday.