@@ -277,8 +277,8 @@ ggml_tensor * llm_graph_context::build_rs(
277277 get_state_rows);
278278}
279279```
280- So notice that we are passing in the conv_states_all as `s` here which is the
281- which contains state for each sequence (we only have one sequence in this case):
280+ So notice that we are passing in the conv_states_all as `s` here which contains
281+ the state for each sequence (we only have one sequence in this case):
282282```c++
283283ggml_tensor * llm_graph_context::build_rs(
284284 ggml_tensor * s, // conv_states_all
@@ -304,7 +304,7 @@ ggml_tensor * llm_graph_context::build_rs(
304304 // {state_size, rs_size} -> {state_size, n_seqs}
305305 ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main);
306306 // Uses indices in state_copy_main to select which rows to copy for the
307- // three previous states.
307+ // 3 previous states.
308308 ggml_build_forward_expand(gf, output_states);
309309
310310 // copy extra states which won't be changed further (between n_seqs and n_rs)
@@ -549,6 +549,330 @@ Notice that for each step we take we are mixing in the previous three timesteps!
549549This is how local temporal context is added before the SSM computation.
550550
551551This operation is performed for each of the 3328 dimensions (rows) in parallel.
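To make this concrete, here is a minimal sketch (plain C++, not the actual ggml kernel) of what a depthwise causal convolution with kernel size 4 does for a single row; all values are made up for illustration:
```c++
#include <cstdio>
#include <vector>

int main() {
    // One channel (row) of xBC over 7 timesteps, plus the 3 previous
    // timesteps taken from the convolution state (all made-up values).
    std::vector<float> prev = {0.1f, 0.2f, 0.3f};        // conv state
    std::vector<float> x    = {1, 2, 3, 4, 5, 6, 7};     // current tokens
    std::vector<float> w    = {0.4f, 0.3f, 0.2f, 0.1f};  // kernel of size 4

    // Prepend the state so each output can look back 3 steps.
    std::vector<float> padded = prev;
    padded.insert(padded.end(), x.begin(), x.end());

    // Causal convolution: the output at t only depends on t and earlier.
    for (size_t t = 0; t < x.size(); ++t) {
        float y = 0.0f;
        for (size_t k = 0; k < w.size(); ++k) {
            y += w[k] * padded[t + k];
        }
        printf("y[%zu] = %f\n", t, y);
    }
    return 0;
}
```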
552+ The shape of xBC after the convolution is:
553+ ``` console
554+ (gdb) p xBC->ne
555+ $23 = {3328, 7, 1, 1}
556+ ```
557+
558+ The final operations in the convolution are the bias add and the SiLU activation:
559+ ``` c++
560+ // bias
561+ xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b);
562+
563+ xBC = ggml_silu(ctx0, xBC);
564+ ```
565+
566+ Following the convolution we have the selective scan.
567+ First, views are created for x, B, and C from xBC ([3328, 7, 1, 1]):
568+ ``` c++
569+ ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head,
570+     n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0);
571+
572+ ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group,
573+     n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC));
574+
575+ ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group,
576+ n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC));
577+ ```
578+ So x is the input to the selective scan and has 3072 elements per token, which
579+ are organized as 48 heads each of dimension 64:
580+ ```console
581+ (gdb) p x->ne
582+ $16 = {64, 48, 7, 1}
583+ ```
584+ B is the input-dependent B matrix which controls how the current input affects
585+ the state update:
586+ ``` console
587+ (gdb) p B->ne
588+ $17 = {128, 1, 7, 1}
589+ ```
590+ C is the input-dependent C matrix which controls how the hidden state is read
591+ out to produce the output:
592+ ``` console
593+ (gdb) p C->ne
594+ $18 = {128, 1, 7, 1}
595+ ```
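A rough way to read these three views (a sketch of the offsets, not the actual graph code): each row of xBC holds 3328 values laid out as 3072 for x, then n_group*d_state = 128 for B, then another 128 for C, so the views just pick out contiguous slices at different element offsets:
```c++
#include <cstdio>

int main() {
    const int d_inner = 3072;   // = head_dim (64) * n_head (48)
    const int d_state = 128;
    const int n_group = 1;

    // Per-token layout of one xBC row: [ x | B | C ]
    const int x_offset = 0;
    const int B_offset = d_inner;                     // 3072
    const int C_offset = d_inner + n_group * d_state; // 3200

    printf("x: elements [%d, %d)\n", x_offset, x_offset + d_inner);
    printf("B: elements [%d, %d)\n", B_offset, B_offset + n_group * d_state);
    printf("C: elements [%d, %d)\n", C_offset, C_offset + n_group * d_state);
    // Total: 3072 + 128 + 128 = 3328, which matches xBC->ne[0].
    return 0;
}
```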
596+ Next we adjust dt with a bias:
597+ ``` c++
598+ dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);
599+ ```
600+ ``` console
601+ (gdb) p dt->ne
602+ $21 = {48, 7, 1, 1}
603+ (gdb) p model.layers[il].ssm_dt_b->ne
604+ $20 = {48, 1, 1, 1}
605+ ```
606+ The delta (dt) controls the discretization of the continuous-time SSM; it is also
607+ input-dependent and determines how much the state evolves at each timestep.
608+
609+
610+ Next we get the A matrix for the selective scan which is the state transition
611+ matrix. This is not input dependent but is per layer. This matrix determines
612+ how the hidden state evolves over time:
613+ ``` c++
614+ ggml_tensor * A = model.layers[il].ssm_a;
615+ ```
616+ ``` console
617+ (gdb) p A->ne
618+ $22 = {1, 48, 1, 1}
619+ ```
620+ The hidden state evolves over time according to the equations of Mamba2 or the
621+ SSM (not sure which is the most correct term here), and the delta (dt) determines
622+ when we sample the state.
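Roughly, the discretization uses dt as a step size; a minimal sketch of the simplified Mamba2-style form (not necessarily the exact expression used by the kernel, but note that A really is a single scalar per head here):
```c++
#include <cmath>
#include <cstdio>

int main() {
    // Made-up values for one head at one timestep.
    float A  = -0.5f;  // continuous-time state transition (scalar per head)
    float dt = 0.1f;   // input-dependent step size for this token

    // Discretize: how much of the old state survives this step.
    float dA = std::exp(dt * A);

    // A larger dt means the state moves further per token,
    // a dt near zero means the state barely changes.
    printf("dA = %f (state decay factor for this step)\n", dA);
    return 0;
}
```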
623+
624+ Next we have a lambda function:
625+ ``` c++
626+ auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
627+ ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());
628+
629+ // TODO: use semistructured matrices to implement state-space duality
630+ // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
631+ return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
632+ };
633+ ```
634+ Notice that it takes a states tensor and an ids tensor. This is where the actual
635+ ggml_ssm_scan function is called to build the ssm scan operation.
636+
637+ Next we see the first usage of ssm_states_all, which contains all the hidden
638+ ssm states. Notice that this also calls build_rs, just like we did for
639+ conv_states_all earlier, but this time the lambda function defined above is
640+ passed in, and it is what will call ggml_ssm_scan:
641+ ```c++
642+ ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
643+ ```
644+ ``` console
645+ (gdb) p d_state
646+ $39 = 128
647+
648+ (gdb) p d_inner
649+ $40 = 3072
650+
651+ (gdb) p d_state * d_inner
652+ $41 = 393216
653+
654+ (gdb) p hparams.n_embd_s()
655+ $38 = 393216
656+ ```
657+
658+ Just like the previous build_rs call, which we saw loaded the previous conv states,
659+ this time it will load the previous hidden states:
660+ ``` c++
661+ ggml_tensor * llm_graph_context::build_rs(
662+ llm_graph_input_rs * inp,
663+ ggml_tensor * s, // ssm_states_all
664+ int32_t state_size,
665+ int32_t n_seqs,
666+ const llm_graph_get_rows_fn & get_state_rows) const {
667+ const auto * kv_state = inp->mctx;
668+
669+ return build_rs(s, inp->s_copy_main, inp->s_copy_extra, state_size, n_seqs,
670+ kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(),
671+ get_state_rows);
672+ }
673+ ```
674+
675+ ```c++
676+ ggml_tensor * llm_graph_context::build_rs(
677+ ggml_tensor * s, // ssm_states_all
678+ ggml_tensor * state_copy_main,
679+ ggml_tensor * state_copy_extra,
680+ int32_t state_size,
681+ int32_t n_seqs,
682+ uint32_t n_rs,
683+ uint32_t rs_head,
684+ uint32_t rs_size,
685+ int32_t rs_zero,
686+ const llm_graph_get_rows_fn & get_state_rows) const {
687+
688+ ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);
689+ ```
690+ This reshapes ssm_states_all into a 2d tensor where each row is the hidden
691+ state for a sequence slot in the cache. In our case we only have one sequence,
692+ but there could be more:
693+ ``` console
694+ (gdb) p states->ne
695+ $36 = {393216, 1, 1, 1}
696+ ↑
697+ sequences
698+
699+ Row 0: [393216 elements] <-- SSM state for sequence slot 0
552700
701+ If `rs_size` were larger (say 4), it would be:
702+ Row 0: [393216 elements] <-- Slot 0
703+ Row 1: [393216 elements] <-- Slot 1
704+ Row 2: [393216 elements] <-- Slot 2
705+ Row 3: [393216 elements] <-- Slot 3
706+ ```
707+
708+ Next if there are new sequences we need to zero out their states. This is only
709+ done if rs_zero is greater than or equal to 0:
710+ ``` console
711+ (gdb) p rs_zero
712+ $42 = 0
713+ ```
714+ So this is creating a view into the states tensor with the number of elements
715+ specified by state_size * (rs_zero >= 0). Remember that if rs_zero is non-negative
716+ the expression (rs_zero >= 0) evaluates to 1, so the number of elements in our
717+ case will be the state size of 393216, and the offset will be zero:
718+ ``` c++
719+ ggml_tensor * state_zero = ggml_view_1d(ctx0, states,
720+     state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
721+ ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
722+ ```
723+ So in our case this will create a view having 393216 elements starting at offset
724+ 0. And notice that the scale_inplace with 0 will zero out this view. So this is
725+ zeroing out the state for the new sequence that will be written to slot 0, which
726+ is the one we are processing currently.
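The (rs_zero >= 0) factors are a branchless way of turning the zeroing into a no-op when there is no new sequence (rs_zero == -1). A small sketch of the idea with plain integers (a hypothetical helper, not the ggml code):
```c++
#include <cstdio>

// Computes how many elements to zero and at which byte offset to start,
// using the same branchless trick as build_rs. row_bytes corresponds to
// states->nb[1], i.e. the byte size of one state row.
void zero_view_params(int state_size, int row_bytes, int rs_zero,
                      int * n_elems, int * byte_offset) {
    // If rs_zero is -1 (no new sequence) both become 0, so the
    // view/scale pair would touch nothing.
    *n_elems     = state_size * (rs_zero >= 0);
    *byte_offset = rs_zero * row_bytes * (rs_zero >= 0);
}

int main() {
    int n, off;

    zero_view_params(393216, 393216 * 4, 0, &n, &off);   // new sequence in slot 0
    printf("rs_zero=0  -> zero %d elements at byte offset %d\n", n, off);

    zero_view_params(393216, 393216 * 4, -1, &n, &off);  // no new sequence
    printf("rs_zero=-1 -> zero %d elements at byte offset %d\n", n, off);
    return 0;
}
```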
727+
728+ Next we have the call to the lambda function get_state_rows, which in our case is the get_ssm_rows lambda defined above:
729+ ```c++
730+ ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main);
731+ ggml_build_forward_expand(gf, output_states);
732+ ```
733+ ``` c++
734+ auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
735+ ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());
736+ ```
737+ This will reshape the states into a 4d tensor:
738+ ```console
739+ (gdb) p ssm->ne
740+ $52 = {128, 64, 48, 1}
741+ ```
742+ This is d_state = 128, head_dim = 64, n_head = 48, and the last dimension is mctx_cur->get_size(), which is 1 in our case.
743+ Next we call the ggml_ssm_scan function, which will operate on the ssm states
744+ (these may have just been zeroed if this is the first timestep for this sequence):
745+ ``` c++
746+ return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
747+ };
748+ ```
749+ This will perform something like the following:
750+ ``` console
751+ For each timestep t in 0..6:
752+ For each head h:
753+ h_state[t, h] = discretize(A, dt[t,h]) * h_state[t-1, h] + discretize(B[t], dt[t,h]) * x[t, h]
754+ y[t, h] = C[t] * h_state[t, h]
755+ ```
756+ The returned tensor will contain both the updated hidden states (h_state above)
757+ as well as the output of the SSM (y above) for each of the inputs.
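To make the recurrence a bit more concrete, here is a tiny scalar sketch for a single head and a single state dimension (made-up values and a simplified discretization; the real kernel works over all 128 state dimensions and 48 heads per token):
```c++
#include <cmath>
#include <cstdio>

int main() {
    const int T = 7;            // timesteps in the ubatch
    float h = 0.0f;             // one state element, starts at zero

    // Made-up per-token inputs for one head / one state dimension.
    float x [T] = {1, -1, 2, 0, 1, 3, -2};
    float dt[T] = {0.1f, 0.2f, 0.1f, 0.3f, 0.1f, 0.2f, 0.1f};
    float B [T] = {0.5f, 0.4f, 0.6f, 0.5f, 0.3f, 0.5f, 0.4f};
    float C [T] = {0.2f, 0.1f, 0.3f, 0.2f, 0.2f, 0.1f, 0.3f};
    float A = -1.0f;            // scalar state transition for this head

    for (int t = 0; t < T; ++t) {
        float dA = std::exp(dt[t] * A);      // discretized A
        float dB = dt[t] * B[t];             // discretized B (simplified)
        h = dA * h + dB * x[t];              // state update
        float y = C[t] * h;                  // readout
        printf("t=%d  h=%f  y=%f\n", t, h, y);
    }
    return 0;
}
```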
758+ The lambda returns this into output_states (recall that we are in build_rs):
759+ ``` c++
760+ ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main);
761+ ggml_build_forward_expand(gf, output_states);
762+ ```
763+ In a batched scenario where n_rs > n_seqs, some sequence slots might be included
764+ in the request but not actively processed. This step preserves those states by
765+ copying them forward in the cache.
766+ ```c++
767+ ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
768+ ggml_build_forward_expand(gf,
769+ ggml_cpy(ctx0,
770+ states_extra,
771+ ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s))));
772+
773+ return output_states;
774+ ```
775+ In our case this will create a view with zero elements since n_rs == n_seqs:
776+ ``` console
777+ (gdb) p state_size*(n_rs - n_seqs)
778+ $62 = 0
779+ ```
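For example, in a hypothetical batch with n_rs = 4 and n_seqs = 2 (made-up numbers, not this debug session), the destination view would cover the two untouched slots that follow the active ones:
```c++
#include <cstdio>

int main() {
    // Hypothetical cache layout, not the values from this debug session.
    const long state_size = 393216;  // elements per recurrent state
    const long elem_size  = 4;       // float32
    const int  n_rs    = 4;          // slots involved in this ubatch
    const int  n_seqs  = 2;          // slots actually being processed
    const int  rs_head = 0;          // first slot used by this ubatch

    long n_copy      = state_size * (n_rs - n_seqs);
    long byte_offset = (rs_head + n_seqs) * state_size * elem_size;

    printf("copy %ld elements into ssm_states_all at byte offset %ld\n",
           n_copy, byte_offset);
    // With n_rs == n_seqs (our single-sequence case) n_copy is 0 and
    // the copy is effectively a no-op.
    return 0;
}
```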
780+ build_rs then returns the output_states back to us in the build_mamba2_layer function:
781+ ``` c++
782+ ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
783+ ```
784+ ``` console
785+ (gdb) p y_ssm->ne
786+ $65 = {414720, 1, 1, 1}
787+ ```
788+ Baked into the following is the copying of the updated ssm states into
789+ ssm_states_all for the next timestep:
790+ ``` c++
791+ ggml_build_forward_expand(gf,
792+     ggml_cpy(ctx0,
793+         ggml_view_1d(ctx0, y_ssm,
794+             d_state*d_inner*n_seqs, // same number of elements
795+             ggml_nelements(x)*x->nb[0]),
796+         ggml_view_1d(ctx0, ssm_states_all,
797+             d_state*d_inner*n_seqs, // same number of elements
798+             kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
799+ ```
800+ The source of the copy is 393216 elements starting from byte offset 86016 in y_ssm:
801+ ```console
802+ (gdb) p d_state * d_inner * n_seqs
803+ $67 = 393216
804+
805+ (gdb) p ggml_nelements(x) * x->nb[0]
806+ $71 = 86016
807+ ```
808+ In our case kv_head is 0 so the offset into ssm_states_all is zero, and the
809+ elements are copied into the start of it.
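The numbers line up because y_ssm packs the per-token outputs first and the updated states after them; a small arithmetic sketch (assuming float32 elements, which matches the 86016 byte offset):
```c++
#include <cstdio>

int main() {
    const long head_dim     = 64;
    const long n_head       = 48;
    const long n_seq_tokens = 7;
    const long n_seqs       = 1;
    const long d_state      = 128;
    const long d_inner      = head_dim * n_head;   // 3072
    const long elem_size    = 4;                   // float32

    long y_elems     = head_dim * n_head * n_seq_tokens * n_seqs; // 21504
    long state_elems = d_state * d_inner * n_seqs;                // 393216

    printf("y_ssm total elements: %ld\n", y_elems + state_elems);      // 414720
    printf("states start at byte offset: %ld\n", y_elems * elem_size); // 86016
    return 0;
}
```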
810+
811+ Next a view is created for y which is the output of the SSM for each head:
812+ ``` c++
813+ ggml_tensor * y = ggml_view_4d(ctx0, y_ssm,
814+ head_dim,
815+ n_head,
816+ n_seq_tokens,
817+ n_seqs,
818+     x->nb[1],
819+     n_head*x->nb[1],
820+     n_seq_tokens*n_head*x->nb[1], 0);
821+ ```
822+ ``` console
823+ (gdb) p y->ne
824+ $76 = {64, 48, 7, 1}
825+ ```
826+ Then the D matrix is applied, which is the skip/residual connection, that is the
827+ `D*x` term in the SSM output equation `y = C*h + D*x`:
828+ ``` c++
829+ y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
830+ cb(y, "mamba2_y_add_d", il);
831+ ```
832+ Then we have the SwiGLU gating using z:
833+ ```c++
834+ y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
835+ ```
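Taken together, per element these two steps compute roughly y = (y + D*x) * silu(z); a sketch under the assumption that D (ssm_d) is one scalar per head and that z acts as the gate:
```c++
#include <cmath>
#include <cstdio>

float silu(float v) { return v / (1.0f + std::exp(-v)); }

int main() {
    // One head, a few made-up values per dimension.
    const int head_dim = 4;
    float x[head_dim] = {1, 2, 3, 4};              // SSM input
    float y[head_dim] = {0.1f, 0.2f, 0.3f, 0.4f};  // SSM output
    float z[head_dim] = {0.5f, -1.0f, 2.0f, 0.0f}; // gate values
    float D = 0.5f;                                // per-head skip scale (assumed)

    for (int d = 0; d < head_dim; ++d) {
        y[d] = (y[d] + D * x[d]) * silu(z[d]);     // skip connection + gating
        printf("y[%d] = %f\n", d, y[d]);
    }
    return 0;
}
```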
836+ After that we have the (grouped) RMS normalization, if the model has ssm_norm weights:
837+ ``` c++
838+ if (model.layers[il].ssm_norm) {
839+ y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
840+ y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
841+ }
842+ ```
843+ ```console
844+ (gdb) p model.layers[il].ssm_norm->ne
845+ $81 = {3072, 1, 1, 1}
846+ ```
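With n_group = 1 this normalizes each token's 3072 values as a single group. A rough sketch of what an RMS norm does per group (the eps value here is a made-up default):
```c++
#include <cmath>
#include <cstdio>
#include <vector>

// RMS-normalize one group of values and apply the learned weight, which is
// roughly what build_norm with LLM_NORM_RMS does per group.
void rms_norm(std::vector<float> & v, const std::vector<float> & w, float eps = 1e-6f) {
    float sum_sq = 0.0f;
    for (float x : v) sum_sq += x * x;
    float scale = 1.0f / std::sqrt(sum_sq / v.size() + eps);
    for (size_t i = 0; i < v.size(); ++i) v[i] = v[i] * scale * w[i];
}

int main() {
    // Toy group of 4 values instead of d_inner/n_group = 3072.
    std::vector<float> y = {1.0f, -2.0f, 3.0f, -4.0f};
    std::vector<float> w = {1.0f, 1.0f, 1.0f, 1.0f};

    rms_norm(y, w);
    for (float v : y) printf("%f ", v);
    printf("\n");
    return 0;
}
```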
847+ Then y is reshaped to 3d for the output of the layer:
848+ ``` c++
849+ y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);
850+ ```
851+ ``` console
852+ (gdb) p y->ne
853+ $86 = {3072, 7, 1, 1}
854+ ```
855+
856+ And then the output is projected back to the embedding dimension:
857+ ``` c++
858+ cur = build_lora_mm(model.layers[il].ssm_out, y);
859+ ```
860+ ``` console
861+ (gdb) p model.layers[il].ssm_out->ne
862+ $88 = {3072, 1536, 1, 1}
863+
864+ (gdb) p cur->ne
865+ $89 = {1536, 7, 1, 1}
866+ ```
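The shapes work out because ggml_mul_mat contracts over ne[0] of both tensors: {3072, 1536} with {3072, 7} gives {1536, 7}. A small sketch of that shape rule (assuming build_lora_mm boils down to a plain ggml_mul_mat here):
```c++
#include <cstdio>

int main() {
    // ggml_mul_mat(a, b): a->ne[0] must equal b->ne[0],
    // and the result has ne = {a->ne[1], b->ne[1], ...}.
    long ssm_out[2] = {3072, 1536};  // a: weight {n_in, n_out}
    long y[2]       = {3072, 7};     // b: input  {n_in, n_tokens}

    if (ssm_out[0] == y[0]) {
        printf("cur->ne = {%ld, %ld}\n", ssm_out[1], y[1]);  // {1536, 7}
    }
    return 0;
}
```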
867+ And then reshaped to 2d, but since we only have a single sequence this does not
868+ change anything:
869+ ``` c++
870+ cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
871+ cb(cur, "mamba_out", il);
872+ ```
873+ ```console
874+ (gdb) p cur->ne
875+ $90 = {1536, 7, 1, 1}
876+ ```
553877
554- _ wip _
878+ And that was the complete mamba2 layer!