We're sharing a roadmap of ideas for improving and speeding up Gemma. If you'd like to join in and help us get there faster, please reach out so we can coordinate :)
Pending
MatMul2
- (jan-wassenberg) Bind tensors to NUMA node
- De-templatize matmul; wrap combinations like bf16, bf16, float into normal functions to speed up compilation and enable per-tensor type dispatch (see the sketch after this list)
- Support bf16 output from matmul
- Add bf16 support to AddFrom and MulByConstAndAdd
- ... then change most activations to bf16: C1, C2, ffw_out, maybe att_sums and att_out
- For F32 input, use F32 mul instead of forcing conversion to bf16
- (jan-wassenberg) Batched GEMM interface
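A minimal sketch of the de-templatized matmul idea mentioned above: a templated reference kernel stays in one translation unit, plain wrapper functions cover the common type combinations, and a small tag-based dispatcher picks one at runtime. All names here (`Type`, `Mat`, `MatMulImpl`, `MatMulAny`) are placeholders, not the actual gemma.cpp API.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// bf16 stored as uint16_t: widen to f32 by shifting into the high half.
inline float ToF32(uint16_t v) {
  const uint32_t bits = uint32_t{v} << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}
inline float ToF32(float v) { return v; }

// Templated reference kernel; the heavy template is instantiated only here.
template <typename TA, typename TB>
void MatMulImpl(const TA* a, const TB* b, float* c, size_t m, size_t k,
                size_t n) {
  for (size_t i = 0; i < m; ++i) {
    for (size_t j = 0; j < n; ++j) {
      float sum = 0.0f;
      for (size_t p = 0; p < k; ++p) {
        sum += ToF32(a[i * k + p]) * ToF32(b[p * n + j]);
      }
      c[i * n + j] = sum;
    }
  }
}

enum class Type { kF32, kBF16 };
struct Mat {
  Type type;
  const void* data;
  size_t rows, cols;
};

// Plain (non-template) wrappers: cheap to declare in a header, so callers no
// longer instantiate the template, which shortens compile times.
void MatMulBF16BF16F32(const Mat& a, const Mat& b, float* c, size_t n) {
  MatMulImpl(static_cast<const uint16_t*>(a.data),
             static_cast<const uint16_t*>(b.data), c, a.rows, a.cols, n);
}
void MatMulF32F32F32(const Mat& a, const Mat& b, float* c, size_t n) {
  MatMulImpl(static_cast<const float*>(a.data),
             static_cast<const float*>(b.data), c, a.rows, a.cols, n);
}

// Runtime dispatch on the per-tensor type tags.
void MatMulAny(const Mat& a, const Mat& b, float* c, size_t n) {
  if (a.type == Type::kBF16 && b.type == Type::kBF16) {
    MatMulBF16BF16F32(a, b, c, n);
  } else {
    MatMulF32F32F32(a, b, c, n);
  }
}
```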
Faster startup
- IOBatch interface calling preadv on Linux
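A sketch of what the IOBatch item could look like: gather several destination buffers for contiguous blobs and read them with a single preadv() call on Linux. The interface is a guess at the roadmap item's intent, not existing gemma.cpp code, and the file name is only an example.

```cpp
#include <fcntl.h>
#include <sys/uio.h>  // preadv
#include <unistd.h>

#include <cstdio>
#include <vector>

struct IOBatch {
  off_t offset = 0;             // file offset of the first request
  std::vector<iovec> requests;  // destination buffers, in file order

  // Destinations must be contiguous in the file for a single preadv call.
  void Add(void* dest, size_t bytes) {
    requests.push_back(iovec{dest, bytes});
  }

  // Returns the number of bytes read, or -1 on error.
  ssize_t Read(int fd) const {
    return preadv(fd, requests.data(), static_cast<int>(requests.size()),
                  offset);
  }
};

int main() {
  int fd = open("weights.sbs", O_RDONLY);  // example path
  if (fd < 0) return 1;
  std::vector<char> a(1 << 20), b(1 << 20);
  IOBatch batch;
  batch.Add(a.data(), a.size());
  batch.Add(b.data(), b.size());  // read both blobs with one syscall
  const ssize_t got = batch.Read(fd);
  std::printf("read %zd bytes\n", got);
  close(fd);
  return 0;
}
```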
Infra improvements/simplification
- KVCache: use MatPtr and Row() instead of CachePosSize
- in ops.h, pass RowVectorBatch to functions rather than pointers
- Replace RowVectorBatch with MatStorageT
- Replace MMAutoTune with AutoTune from Highway, update Highway version
- rename compression/shared.h -> types.h
Optimizations
- Replace attention matVec with matmul - requires reshaping a matrix
- Use MatMul in EmbedImagePatches
- Fuse softmax and sampling (see the sketch after this list)
- Vectorize RoPE
- Unroll WeightedSumV (4 V rows at a time)
- Flash attention
- Improved KV and att layout
- SFP embedding instead of bf16 - convert at load-time until the exporter is updated
- Vectorize RMSNorm
- Smaller KVCache: bf16, possibly reorder for better locality
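One possible way to fuse softmax and sampling, shown here with the Gumbel-max trick: argmax over logit/T plus Gumbel noise samples from softmax(logit/T) without ever materializing the probability array. This is a sketch of the idea only; the in-tree fusion may be implemented differently.

```cpp
#include <cmath>
#include <cstddef>
#include <limits>
#include <random>
#include <vector>

size_t SampleFused(const std::vector<float>& logits, float temperature,
                   std::mt19937& rng) {
  std::uniform_real_distribution<float> uniform(1e-20f, 1.0f);
  size_t best = 0;
  float best_score = -std::numeric_limits<float>::infinity();
  for (size_t i = 0; i < logits.size(); ++i) {
    const float gumbel = -std::log(-std::log(uniform(rng)));
    const float score = logits[i] / temperature + gumbel;
    if (score > best_score) {
      best_score = score;
      best = i;
    }
  }
  return best;  // distributed as softmax(logits / temperature)
}
```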
Usability
- warn if unknown arguments are given; std::map of known arg names? (see the sketch after this list)
- multiple .cc files to speed up builds
- move eval/test files to tests/
- Ctrl+C signal handler to ensure profiler results are printed without requiring %q input
- add --prompt flag to run.cc
- random prompt generation for debug_prompt.cc
- gemma_test: ensure deterministic output (same output when the same prompt is given twice)
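A sketch of the "warn on unknown arguments" idea above: keep the known flag names in a std::set (a std::map works the same way if values are needed) and report anything on the command line that is not in it. The flag names are examples only.

```cpp
#include <cstdio>
#include <set>
#include <string>

void WarnUnknownArgs(int argc, char** argv,
                     const std::set<std::string>& known) {
  for (int i = 1; i < argc; ++i) {
    const std::string arg = argv[i];
    if (arg.rfind("--", 0) != 0) continue;  // skip flag values
    const std::string name = arg.substr(2, arg.find('=') - 2);
    if (known.count(name) == 0) {
      std::fprintf(stderr, "Warning: unknown argument --%s\n", name.c_str());
    }
  }
}

int main(int argc, char** argv) {
  const std::set<std::string> known = {"weights", "tokenizer", "prompt"};
  WarnUnknownArgs(argc, argv, known);
  return 0;
}
```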
Threading
- (jan-wassenberg) detect: total #logical, per-logical: package, chiplet, core, smt (see the sketch after this list)
- (jan-wassenberg) detect: CPU name, L2D/L3 size
- (Z.A.) CCX-aware pinning - ready, awaiting Highway 1.2 release
- (jan-wassenberg) more efficient ThreadPool (collaborative work requesting, not stealing)
- command line arg to disable pinning
- detect NUMA
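A sketch of the detection item above, assuming Linux: count logical CPUs and read each one's package and core IDs from sysfs. Chiplet (CCX/L3 domain) and SMT detection would read further sysfs/cpuid fields; this only shows the basic pattern and is not the gemma.cpp/Highway topology code.

```cpp
#include <cstdio>
#include <fstream>
#include <string>
#include <thread>

static int ReadInt(const std::string& path) {
  std::ifstream f(path);
  int value = -1;
  f >> value;
  return value;  // -1 if the file is missing (e.g. non-Linux)
}

int main() {
  const unsigned num_logical = std::thread::hardware_concurrency();
  std::printf("logical CPUs: %u\n", num_logical);
  for (unsigned cpu = 0; cpu < num_logical; ++cpu) {
    const std::string base =
        "/sys/devices/system/cpu/cpu" + std::to_string(cpu) + "/topology/";
    const int package = ReadInt(base + "physical_package_id");
    const int core = ReadInt(base + "core_id");
    std::printf("cpu %u: package %d core %d\n", cpu, package, core);
  }
  return 0;
}
```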
Done
[x] Compression
- (pculliton, A.R.) Eval infrastructure
- (A.R.) Arbiter model for eval
- (Ray) add metadata to tensors, remove RawWeights
- add TOC to BlobStore
[x] File format
- store ModelInfo in weights BlobStore
- store tensor info in BlobStore
- store tokenizer in BlobStore
[x] New models
- (Daniel) Support PaliGemma
- Split Model into ModelFamily and ModelSize
- (jan-wassenberg) Land single-file format
[x] General infra
- (pculliton) Python wrapper
- (pculliton, ...) Improved CI: run on Kaggle infra
- AuxOut to hold timing info instead of printing in GenerateImpl.
- Sampling struct holds rng and temperature, to shorten argument lists
- (P. C.) use new HWY_EXPORT_T to simplify dispatch - ready, awaiting Highway 1.2 release
[x] Dot product
- Add _mm*_dpbf16_ps to the HWY_AVX3_SPR and HWY_AVX3_ZEN4 targets, plus define HWY_NATIVE_DOT_BF16 in set_macros-inl.h
- Faster SFP decode via table lookup
- Add new NEON_* target that uses vbfdot for ReorderWidenMulAccumulate
- If !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16, decompress bf16->f32 to a temp array before MatVec (idea by Samuel, thank you!) - in "Factor out deinterleaving of bf16 vectors for MatVecs" (#166)
- Apply even/odd trick to SFP
[x] Matmul
- (pculliton) implement basic matmul and test. Not using BLAS because we want to fuse matmul and decompression.
- (pculliton) 4x4 unrolled and vectorized matmul (see the sketch after this list)
- (szabadka, B.B.) Update Prefill to use matmul (activation @ weights) instead of MatVec. Almost there.
- Fused decompression inside matmul
- Support offsets within the matrix, required by some call sites
- (jan-wassenberg) Decompress weights to bf16 when native
- (jan-wassenberg) Cache-aware tiling/packing
- (jan-wassenberg) NUMA aware
- (jan-wassenberg) 64-bit precision
- (B.B.) Larger batch size
- (A.V.) Avoid allocations for decompression
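A scalar illustration of the 4x4 tiling idea behind the unrolled matmul above: each (i, j) tile keeps a 4x4 block of accumulators live across the k loop, so every loaded element of A and B is reused four times. The real kernel vectorizes the inner statements with Highway; dimensions here are assumed to be multiples of 4 for brevity.

```cpp
#include <cstddef>

void MatMul4x4Tiled(const float* a, const float* b, float* c,
                    size_t m, size_t k, size_t n) {
  for (size_t i = 0; i < m; i += 4) {
    for (size_t j = 0; j < n; j += 4) {
      float acc[4][4] = {};  // accumulators stay in registers across k
      for (size_t p = 0; p < k; ++p) {
        for (size_t ii = 0; ii < 4; ++ii) {
          const float a_val = a[(i + ii) * k + p];
          for (size_t jj = 0; jj < 4; ++jj) {
            acc[ii][jj] += a_val * b[p * n + (j + jj)];
          }
        }
      }
      for (size_t ii = 0; ii < 4; ++ii) {
        for (size_t jj = 0; jj < 4; ++jj) {
          c[(i + ii) * n + (j + jj)] = acc[ii][jj];
        }
      }
    }
  }
}
```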