Near-term roadmap #164

Open

Description

@jan-wassenberg

We're sharing a roadmap of ideas for improving and speeding up Gemma. If you'd like to join in and help us get there faster, please reach out so we can coordinate :)

Pending

MatMul2

  • (jan-wassenberg) Bind tensors to NUMA node
  • De-templatize matmul; wrap type combinations such as (bf16, bf16, float) into plain functions to reduce compile time and enable per-tensor type dispatch
  • Support bf16 output from matmul
  • Add bf16 support to AddFrom and MulByConstAndAdd
  • ... then change most activations to bf16: C1, C2, ffw_out, maybe att_sums and att_out
  • For F32 input, use F32 mul instead of forcing conversion to bf16
  • (jan-wassenberg) Batched GEMM interface
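The de-templatizing item could look roughly like the following sketch (names and signatures are hypothetical, not gemma.cpp's actual API): the generic template is instantiated once per supported type combination and wrapped in a plain function, so call sites dispatch per tensor type at runtime without instantiating the template in every translation unit.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

namespace detail {
// Generic implementation, instantiated only inside this translation unit.
template <typename WeightT, typename ActT, typename OutT>
void MatMulImpl(const WeightT* w, const ActT* x, OutT* out,
                size_t rows, size_t cols) {
  for (size_t r = 0; r < rows; ++r) {
    float sum = 0.0f;
    for (size_t c = 0; c < cols; ++c) {
      sum += static_cast<float>(w[r * cols + c]) * static_cast<float>(x[c]);
    }
    out[r] = static_cast<OutT>(sum);
  }
}
}  // namespace detail

// Plain (non-template) wrapper: one such function per
// (weight, activation, output) type combination.
void MatMulF32F32F32(const float* w, const float* x, float* out,
                     size_t rows, size_t cols) {
  detail::MatMulImpl(w, x, out, rows, cols);
}
```

Call sites would then pick the wrapper matching the tensor's stored type tag, rather than being templated themselves.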

Faster startup

  • IOBatch interface calling preadv on Linux
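A minimal sketch of what an IOBatch built on Linux preadv could look like (class and method names are hypothetical): collect destination buffers for contiguous regions of a file, then issue a single preadv() instead of one read per blob.

```cpp
#include <fcntl.h>
#include <sys/uio.h>
#include <unistd.h>

#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical IOBatch: queues destination buffers for consecutive byte
// ranges starting at `offset`, then reads them all with one syscall.
class IOBatch {
 public:
  IOBatch(int fd, off_t offset) : fd_(fd), offset_(offset) {}

  void Add(void* dest, size_t bytes) {
    iovec v;
    v.iov_base = dest;
    v.iov_len = bytes;
    iov_.push_back(v);
  }

  // Returns total bytes read, or -1 on error (see preadv(2)).
  ssize_t Read() {
    return preadv(fd_, iov_.data(), static_cast<int>(iov_.size()), offset_);
  }

 private:
  int fd_;
  off_t offset_;
  std::vector<iovec> iov_;
};
```

preadv scatters one contiguous file range into multiple buffers without moving the file offset, which fits loading several tensors from a single BlobStore region.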

Infra improvements/simplification

  • KVCache: use MatPtr/Row() instead of CachePosSize
  • In ops.h, pass RowVectorBatch to functions rather than pointers
  • Replace RowVectorBatch with MatStorageT
  • Replace MMAutoTune with AutoTune from Highway; update the Highway version
  • Rename compression/shared.h -> types.h

Optimizations

  • Replace attention matVec with matmul - requires reshaping a matrix
  • Use MatMul in EmbedImagePatches
  • Fuse softmax and sampling
  • Vectorize RoPE
  • Unroll WeightedSumV (4 V rows at a time)
  • Flash attention
  • Improved KV and att layout
  • SFP embedding instead of bf16 - convert at load-time until the exporter is updated
  • Vectorize RMSNorm
  • Smaller KVCache: bf16, possibly reorder for better locality
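For reference, a scalar version of the RMSNorm that the vectorization item targets might look like the following (standard formula; Gemma-family norms apply the weight as 1 + w, but this is a sketch, not gemma.cpp's actual code). Both loops are straightforward to vectorize, e.g. with Highway.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Scalar RMSNorm reference:
//   y[i] = x[i] * (1 + w[i]) / sqrt(mean(x^2) + eps)
void RMSNorm(const float* x, const float* w, float* out, size_t n,
             float eps = 1e-6f) {
  float ss = 0.0f;
  for (size_t i = 0; i < n; ++i) ss += x[i] * x[i];
  const float inv_rms = 1.0f / std::sqrt(ss / static_cast<float>(n) + eps);
  for (size_t i = 0; i < n; ++i) out[i] = x[i] * (1.0f + w[i]) * inv_rms;
}
```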

Usability

  • Warn if unknown arguments are given; perhaps a std::map of known arg names
  • Multiple .cc files to speed up builds
  • Move eval/test files to tests/
  • Ctrl+C signal handler to ensure profiler results are printed without requiring %q input
  • Add a --prompt flag to run.cc
  • Random prompt generation for debug_prompt.cc
  • gemma_test: ensure deterministic output (the same output given two identical prompts)
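The unknown-argument warning could be as simple as the sketch below (flag names and the helper are illustrative, not gemma.cpp's actual arg parser): keep a std::map of known flag names, here mapping to a help string, and report any --flag not in it.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Returns the names of --flags that are not in `known`.
// Positional args (no leading "--") are ignored.
std::vector<std::string> UnknownArgs(
    const std::vector<std::string>& args,
    const std::map<std::string, std::string>& known) {
  std::vector<std::string> unknown;
  for (const std::string& arg : args) {
    if (arg.rfind("--", 0) != 0) continue;
    // Strip "--" prefix and anything from '=' onward.
    const std::string name = arg.substr(2, arg.find('=') - 2);
    if (known.find(name) == known.end()) unknown.push_back(name);
  }
  return unknown;
}
```

Mapping to a help string (rather than using a bare set) would also let the same table drive --help output.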

Threading

  • (jan-wassenberg) Detect: total #logical; per-logical: package, chiplet, core, SMT
  • (jan-wassenberg) Detect: CPU name, L2D/L3 size
  • (Z.A.) CCX-aware pinning - ready, awaiting the Highway 1.2 release
  • (jan-wassenberg) More efficient ThreadPool (collaborative work requesting, not stealing)
  • Command-line arg to disable pinning
  • Detect NUMA

Done

[x] Compression

  • (pculliton, A.R.) Eval infrastructure
  • (A.R.) Arbiter model for eval
  • (Ray) add metadata to tensors, remove RawWeights
  • add TOC to BlobStore

[x] File format

  • store ModelInfo in weights BlobStore
  • store tensor info in BlobStore
  • store tokenizer in BlobStore

[x] New models

  • (Daniel) Support PaliGemma
  • Split Model into ModelFamily and ModelSize
  • (jan-wassenberg) Land single-file format

[x] General infra

  • (pculliton) Python wrapper
  • (pculliton, ...) Improved CI: run on Kaggle infra
  • AuxOut to hold timing info instead of printing in GenerateImpl.
  • Sampling struct holds rng and temperature, to reduce length of args
  • (P. C.) use new HWY_EXPORT_T to simplify dispatch - ready, awaiting Highway 1.2 release

[x] Dot product

  • Add _mm*_dpbf16_ps to HWY_AVX3_SPR and HWY_AVX3_ZEN4 targets, plus define HWY_NATIVE_DOT_BF16 in set_macros-inl.h
  • Faster SFP decode via table lookup
  • Add new NEON_* target that uses vbfdot for ReorderWidenMulAccumulate
  • If !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16, decompress bf16->f32 to a temp array before MatVec (idea by Samuel, thank you!) - see "Factor out deinterleaving of bf16 vectors for MatVecs" (#166)
  • Apply even/odd trick to SFP
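The decompress-to-temp item relies on bf16 being the upper 16 bits of an IEEE-754 f32, so widening is a 16-bit left shift plus a bit-cast. A scalar sketch of that step (the real code uses Highway vectors, and the shift-based widening rounds nothing because bf16 values are exact f32 prefixes):

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <cstddef>
#include <vector>

// Widen bf16 (stored as uint16_t) to f32: place the 16 bits in the top
// half of a 32-bit word and reinterpret as float.
void DecompressBF16ToF32(const uint16_t* in, float* out, size_t n) {
  for (size_t i = 0; i < n; ++i) {
    const uint32_t bits = static_cast<uint32_t>(in[i]) << 16;
    std::memcpy(&out[i], &bits, sizeof(float));
  }
}
```

Targets without native bf16 dot support can run this into a temp row and then use the plain f32 kernel.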

[x] Matmul

  • (pculliton) implement basic matmul and test. Not using BLAS because we want to fuse matmul and decompression.
  • (pculliton) 4x4 unrolled and vectorized matmul
  • (szabadka, B.B.) Update Prefill to use matmul (activation @ weights) instead of MatVec. Almost there.
  • Fused decompression inside matmul
  • Support offsets within the matrix, required by some call sites
  • (jan-wassenberg) Decompress weights to bf16 when native
  • (jan-wassenberg) Cache-aware tiling/packing
  • (jan-wassenberg) NUMA aware
  • (jan-wassenberg) 64-bit precision
  • (B.B.) Larger batch size
  • (A.V.) Avoid allocations for decompression
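The 4x4 unrolled kernel item boils down to register tiling; a toy scalar illustration of the idea (the real kernel additionally vectorizes the k loop and fuses weight decompression): compute a 4x4 block of C = A * B at a time, keeping 16 accumulators live so each loaded element of A and B is reused four times.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy 4x4-blocked matmul: C (MxN) = A (MxK) * B (KxN), row-major.
// Assumes M and N are multiples of 4 for brevity.
void MatMul4x4Block(const float* A, const float* B, float* C,
                    size_t M, size_t K, size_t N) {
  for (size_t i = 0; i < M; i += 4) {
    for (size_t j = 0; j < N; j += 4) {
      float acc[4][4] = {};  // 16 accumulators held in registers
      for (size_t k = 0; k < K; ++k) {
        for (size_t ii = 0; ii < 4; ++ii) {
          for (size_t jj = 0; jj < 4; ++jj) {
            acc[ii][jj] += A[(i + ii) * K + k] * B[k * N + (j + jj)];
          }
        }
      }
      for (size_t ii = 0; ii < 4; ++ii) {
        for (size_t jj = 0; jj < 4; ++jj) {
          C[(i + ii) * N + (j + jj)] = acc[ii][jj];
        }
      }
    }
  }
}
```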
