
Commit 7f966fd

docs: cleanup granite/mamba2 layer notes

1 parent 6063c8a commit 7f966fd

1 file changed: +22 -7 lines changed

notes/granite-model.md

Lines changed: 22 additions & 7 deletions
@@ -374,9 +374,9 @@ $103 = {type = GGML_TYPE_Q4_0, buffer = 0x555557a25f10, ne = {1536, 6448, 1, 1},
 Now, this is an interesting operation, this is a matrix multiplication that is
 expanding the input token embedding into the space required for Mamba2.
 So this is one large matrix multiplication for efficiency and then we can create
-views into this for each of the parts. Notice that we used cur for with this
-matrix multiplication so this is how the input tokens are used in the z, B, C,
-and dt operations/values.
+views into this for each of the parts. Notice that we used cur with this matrix
+multiplication so this is how the input tokens are used in the z, B, C, and dt
+operations/values.
 
 This projects each of the 7 tokens from 1536 dimensions to 6448 dimensions.
 The 6448 dimensions are made up of 3 parts:
@@ -553,6 +553,10 @@ The shape of xBC after the convolution is:
 $23 = {3328, 7, 1, 1}
 ```
 
+So the convolution is to incorporate some local context tokens and also to
+enable B and C to be influenced by the input. If we didn't allow for some local
+context the selective scan would only be handling long-range dependencies.
+
 The final operations to happen in the convolution is:
 ```c++
 // bias
@@ -562,7 +566,8 @@ The final operations to happen in the convolution is:
 ```
 
 Following the convolution we have the selective scan.
-First a views are created for x, B, and C from xBC ([3328, 7, 1, 1]):
+
+First views are created for x, B, and C from xBC ([3328, 7, 1, 1]):
 ```c++
 ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head,
     n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0);
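The carving of xBC into x, B and C can be sketched in NumPy (ggml prints shapes in reverse order, so gdb's {64, 48, 7, 1} is (7, 48, 64) here; the ordering x | B | C matches x starting at offset 0 in the view above, while B-before-C is an assumption):

```python
import numpy as np

n_tokens, head_dim, n_head, d_state = 7, 64, 48, 128

xBC = np.zeros((n_tokens, head_dim * n_head + 2 * d_state), np.float32)  # (7, 3328)

# x: first 3072 values per token, viewed as 48 heads of dim 64
x = xBC[:, :head_dim * n_head].reshape(n_tokens, n_head, head_dim)
# B and C: 128 values each (B-before-C ordering is an assumption)
B = xBC[:, head_dim * n_head : head_dim * n_head + d_state]
C = xBC[:, head_dim * n_head + d_state:]
```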
@@ -579,18 +584,21 @@ are organized as 48 heads each of dimension 64:
 (gdb) p x->ne
 $16 = {64, 48, 7, 1}
 ```
+
 B is the input-dependent B matrix which controls how the current input affects
 the state update:
 ```console
 (gdb) p B->ne
 $17 = {128, 1, 7, 1}
 ```
+
 C is the input-dependent C matrix which controls how the hidden state is read
 out to produce the output:
 ```console
 (gdb) p C->ne
 $18 = {128, 1, 7, 1}
 ```
+
 Next we adjust dt with a bias:
 ```c++
 dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);
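In Mamba-style models the biased dt is typically also passed through a softplus to keep the timestep positive (in llama.cpp this would happen inside the scan kernel rather than here, so treat this as an illustrative assumption). A NumPy sketch:

```python
import numpy as np

def softplus(v):
    # smooth, strictly positive approximation of max(0, v)
    return np.log1p(np.exp(v))

n_tokens, n_head = 7, 48
dt = np.random.randn(n_tokens, n_head).astype(np.float32)
dt_bias = np.random.randn(n_head).astype(np.float32)

# bias then softplus: one positive timestep per head per token
dt = softplus(dt + dt_bias)
```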
@@ -606,11 +614,18 @@ input dependent and determines how much the state evolves at each timestep.
 
 
 Next we get the A matrix for the selective scan which is the state transition
-matrix. This is not input dependant but is per layer. This matrix determines
+matrix. This is not input dependent but is per layer. This matrix determines
 how the hidden state evolves over time:
 ```c++
 ggml_tensor * A = model.layers[il].ssm_a;
 ```
+The values in this matrix are actually stored in log form for numerical stability:
+```python
+if name.endswith(".A_log"):
+    logger.debug("A_log --> A ==> " + new_name)
+    data_torch = -torch.exp(data_torch)
+```
+
 ```console
 (gdb) p A->ne
 $22 = {1, 48, 1, 1}
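The effect of the -exp storage can be shown directly in NumPy: whatever A_log contains, A = -exp(A_log) is strictly negative, so the discretized factor exp(dt*A) used by the scan stays in (0, 1) and the state decays instead of blowing up (dt = 0.5 is just an arbitrary positive timestep for illustration):

```python
import numpy as np

n_head = 48
A_log = np.random.randn(n_head).astype(np.float32)

A = -np.exp(A_log)                     # always negative, one scalar per head
dt = np.full(n_head, 0.5, np.float32)  # any positive timestep
dA = np.exp(dt * A)                    # discretized decay factor, in (0, 1)
```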
@@ -752,8 +767,8 @@ For each timestep t in 0..6:
   y[t, h] = C[t] * h_state[t, h]
 ```
 The returned tensor will contain both the updated hidden states (h_state above)
-as well as the the output of the SSM (y above) for each of the inputs.
-The lambda returns this into output_states (recall that we are in build_rs):
+as well as the output of the SSM (y above) for each of the inputs. The lambda
+returns this into output_states (recall that we are in build_rs):
 ```c++
 ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main);
 ggml_build_forward_expand(gf, output_states);
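The per-timestep recurrence from the pseudocode above can be written out as a minimal NumPy scan. Shapes follow the gdb output earlier in these notes; the dB = dt * outer(x, B) update and the broadcasting pattern are a sketch of the standard Mamba2 recurrence, not the ggml kernel:

```python
import numpy as np

n_tokens, n_head, head_dim, d_state = 7, 48, 64, 128

x  = np.random.rand(n_tokens, n_head, head_dim).astype(np.float32)
B  = np.random.rand(n_tokens, d_state).astype(np.float32)
C  = np.random.rand(n_tokens, d_state).astype(np.float32)
dt = np.random.rand(n_tokens, n_head).astype(np.float32)
A  = -np.exp(np.random.randn(n_head)).astype(np.float32)  # negative, per head

h_state = np.zeros((n_head, head_dim, d_state), np.float32)
y = np.zeros((n_tokens, n_head, head_dim), np.float32)

for t in range(n_tokens):
    dA = np.exp(dt[t] * A)  # (48,) decay factor per head
    # decay the old state, then write the dt-scaled outer product of x and B
    h_state = dA[:, None, None] * h_state \
            + dt[t][:, None, None] * (x[t][:, :, None] * B[t][None, None, :])
    # read the state out with C
    y[t] = h_state @ C[t]   # (48, 64)
```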
