
[WIP] Multimodal SSM + TP #338


Draft: wants to merge 44 commits into base: raymond/debug_mm
Commits (44)
a4509d7
fix dim name (#331)
RaymondLi0 Jul 18, 2025
82eed2b
TP mamba
jlamypoirier Jul 21, 2025
4e310c7
TP mamba
jlamypoirier Jul 22, 2025
3cc4118
fix
jlamypoirier Jul 22, 2025
9f7f75c
fix
jlamypoirier Jul 22, 2025
4054e04
fixes
jlamypoirier Jul 23, 2025
0014cc6
fix
jlamypoirier Jul 23, 2025
47ad548
fixes
jlamypoirier Jul 23, 2025
6a074fa
fixes
jlamypoirier Jul 23, 2025
d66651f
Update external
jlamypoirier Jul 23, 2025
4e67fbf
Adds lm_eval to evaluations (#282)
bigximik Jul 24, 2025
50083ba
SSM debugging
jlamypoirier Jul 24, 2025
5006328
Merge branch 'main' into tp_mamba
jlamypoirier Jul 24, 2025
13176bd
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
7b32699
stuff
jlamypoirier Jul 24, 2025
73f591f
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
1feccc8
stuff
jlamypoirier Jul 24, 2025
e528b50
misc
jlamypoirier Jul 24, 2025
b49c42f
misc
jlamypoirier Jul 24, 2025
bb4dcd9
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
c1b7f44
misc
jlamypoirier Jul 24, 2025
31f5d41
misc
jlamypoirier Jul 24, 2025
051bb07
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
0a9ff25
misc
jlamypoirier Jul 24, 2025
e7d9636
Parallel discrete mamba 2
jlamypoirier Jul 24, 2025
60093cd
Merge branch 'tp_mamba' into raymond/debug_mm_tp_mamba
RaymondLi0 Jul 25, 2025
f88fb2f
rename vit layer to block
RaymondLi0 Jul 25, 2025
22296b3
block_index
RaymondLi0 Jul 25, 2025
c14b764
Mamba 2, misc
jlamypoirier Jul 25, 2025
fa21174
flexible import
RaymondLi0 Jul 25, 2025
a3e5bde
Merge branch 'tp_mamba' into raymond/debug_mm_tp_mamba
RaymondLi0 Jul 25, 2025
d3cc158
update import
RaymondLi0 Jul 28, 2025
6d245c0
fix automodel export
RaymondLi0 Jul 28, 2025
61ecb5d
try: remove assert for TP and distillation
RaymondLi0 Jul 28, 2025
2565dac
more verbose config
RaymondLi0 Jul 28, 2025
7a7f12c
use local token_ids instead of modifying batch
RaymondLi0 Jul 29, 2025
743b42c
fix allreduce
RaymondLi0 Jul 29, 2025
c7247dc
fix
RaymondLi0 Jul 29, 2025
3074ec9
revert images_sizes conversion to np array
RaymondLi0 Jul 29, 2025
c4cdd86
debug logs
RaymondLi0 Jul 30, 2025
24d7a05
rm debug logs
RaymondLi0 Aug 5, 2025
37ddef4
changes for stp reverse-kl
RaymondLi0 Aug 5, 2025
a0d7a09
reverse kl: add clamping
RaymondLi0 Aug 6, 2025
72945bc
add loss mask for vision. should also handle padded sequences
RaymondLi0 Aug 6, 2025
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
@@ -32,7 +32,7 @@ jobs:
          pip install pybind11
          FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \
          MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \
          pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,DEV,DOCS]"
          pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV,DOCS]"
      - name: Run tests
        run: pytest -v -ra .

2 changes: 1 addition & 1 deletion .github/workflows/docs.yaml
@@ -34,7 +34,7 @@ jobs:
          pip install pybind11
          FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \
          MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \
          pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,DEV,DOCS]"
          pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV,DOCS]"
      - name: Build the documentation
        run: mkdocs build

2 changes: 1 addition & 1 deletion Dockerfile
@@ -37,7 +37,7 @@ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/
COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/

# Install dependencies within the virtual environment.
RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,DEV]" triton==3.1.0
RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" triton==3.1.0

# Copy the remaining source code with universal write permissions.
COPY --chmod=777 ./Megatron-LM Megatron-LM
2 changes: 1 addition & 1 deletion Megatron-LM
134 changes: 134 additions & 0 deletions docs/user_guide/evaluators.md
@@ -0,0 +1,134 @@
# Evaluations

Fast-LLM allows you to perform various evaluations during training or as a separate evaluation step. In both cases, you need to use your training config with `training.evaluators` specified.

For evaluators used during training, both `interval` and `offset` must be specified. Then, start training as usual with:

`fast-llm train gpt --config path/to/training/config.yaml`

To perform evaluation as a separate step, use the same training config. Depending on the training progress, either the start model or the latest checkpoint will be loaded, and `interval` and `offset` will be ignored. To start evaluation:

`fast-llm evaluate gpt --config path/to/training/config.yaml`

## Currently Supported Evaluators

- `loss`
- `lm_eval`

## Loss Evaluator

To set up loss evaluation, specify a dataset to be used in the `data.datasets` section of the config. You must also define the loss evaluator in the `training.evaluators` config section. See example below.

```yaml
training:
  evaluations:
    stack_3b:
      interval: 10
      evaluator:
        type: loss
        iterations: 10
        dataset_name: stack_3b
    fineweb:
      evaluator:
        type: loss
        iterations: 10
        dataset_name: fineweb
      interval: 10
data:
  datasets:
    stack_3b:
      type: memmap
      path: path/to/memmap/dataset
    fineweb:
      type: memmap
      path: path/to/memmap/dataset1
```

## Evaluation Harness (`lm_eval`) Evaluator

**Note:** Only data parallelism is currently supported for the `lm_eval` evaluator.

To run `lm_eval` evaluations, version `0.4.9` of `lm_eval` must be installed along with all dependencies required for your evaluation tasks.

The following environment variables may need to be set:

- `HF_HOME`: Path for Hugging Face data caching
- `WANDB_API_KEY_PATH`: Path to a file containing your Weights & Biases API key (if logging to W&B)
- `HUGGINGFACE_API_KEY_PATH`: Path to a file containing your Hugging Face hub token
- `NLTK_DATA`: Path to a directory that will contain downloaded NLTK packages (needed for some tasks)
- `HF_ALLOW_CODE_EVAL=1`: Required for some evaluation tasks

You may need to specify additional environment variables depending on the `lm_eval` tasks you want to run.
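For example, a minimal shell setup before launching training might look like the sketch below; the paths are illustrative placeholders, not required locations:

```bash
# Hugging Face cache location (illustrative path)
export HF_HOME=/data/hf_cache
# Files holding the W&B API key and Hugging Face hub token (illustrative paths)
export WANDB_API_KEY_PATH=/secrets/wandb_api_key
export HUGGINGFACE_API_KEY_PATH=/secrets/hf_token
# Download target for NLTK packages needed by some tasks
export NLTK_DATA=/data/nltk_data
# Required by some code-evaluation tasks
export HF_ALLOW_CODE_EVAL=1

fast-llm train gpt --config path/to/training/config.yaml
```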

To configure an `lm_eval` task, the evaluator config provides the following fields:

### Model Config

The model instantiated for training is reused for evaluation, so you don't need to specify it separately. However, there are some parameters specific to `lm_eval`. See `fast_llm/engine/evaluation/config.EvaluatorLmEvalConfig` for details.

### CLI Parameters for `lm_eval`

All other parameters are specified as if you were calling the `lm_eval` CLI, using a list of strings. Some CLI parameters are ignored or restricted, specifically those related to model loading, W&B, batch sizes, and device setup, as these are managed by the rest of the Fast-LLM configuration.

Also, the tokenizer must be specified in `data.tokenizer`. If the tokenizer does not have a `bos_token`, it must be specified explicitly in `data.tokenizer.bos_token`. Although `lm_eval` does not use the `bos_token` directly, it is still required because the same tokenizer is used by other Fast-LLM components.

Below is an example of the config:

```yaml
training:
  evaluations:
    lm_eval_tasks1:
      interval: 10
      evaluator:
        type: lm_eval
        cli_args:
          - --tasks
          - gsm8k,xnli_en,wikitext,ifeval
          - --output_path
          - /path/to/lm_eval/output
data:
  tokenizer:
    path: path/to/the/tokenizer
```

It is also possible to run different tasks with different intervals and offsets, for example to run slower or more comprehensive tasks less frequently:

```yaml
training:
  evaluations:
    gsm8k:
      interval: 20
      evaluator:
        type: lm_eval
        cli_args:
          - --tasks
          - gsm8k
          - --output_path
          - /path/to/lm_eval/output
          - --limit
          - "64"
    ifeval:
      offset: 10
      interval: 40
      evaluator:
        type: lm_eval
        cli_args:
          - --tasks
          - ifeval
          - --output_path
          - /path/to/lm_eval/output
          - --limit
          - "32"
    faster_tasks:
      interval: 10
      evaluator:
        type: lm_eval
        cli_args:
          - --tasks
          - xnli_en,wikitext
          - --output_path
          - /path/to/lm_eval/output
data:
  tokenizer:
    path: path/to/the/tokenizer
```
3 changes: 3 additions & 0 deletions fast_llm/cli.py
@@ -7,6 +7,7 @@
from fast_llm.engine.config_utils.logging import configure_logging
from fast_llm.engine.config_utils.run import log_main_rank
from fast_llm.engine.config_utils.runnable import RunnableConfig
from fast_llm.utils import set_global_variables

# Import these submodules to ensure classes are added to the dynamic class registry.
import fast_llm.data.auto # isort: skip
@@ -20,6 +21,8 @@
def fast_llm_main_wrapper():
    # (Pre-)configure logging
    configure_logging()
    # Set global and environment variables before third-party imports.
    set_global_variables()
    try:
        yield
    except Exception as e:
2 changes: 1 addition & 1 deletion fast_llm/config.py
@@ -735,7 +735,7 @@ def _get_class_name(cls) -> str:
    @classmethod
    def from_dict(
        cls,
        default: "Config| dict[str, typing.Any]]",
        default: "Config| dict[str, typing.Any]",
        *updates: "Config| dict[str | tuple[str, ...], typing.Any]",
        strict: bool = True,
        update_type: UpdateType = UpdateType.override,
121 changes: 121 additions & 0 deletions fast_llm/core/distributed.py
@@ -8,10 +8,13 @@

import contextlib
import datetime
import io
import logging
import pickle
import typing

import torch
import torch.monitor
from torch._C._distributed_c10d import Work
from torch.distributed import ( # noqa
    ProcessGroup,
@@ -46,6 +49,7 @@ def broadcast(
        return work
    else:
        work.wait()
        return None


def check_parallel_match(tensor: torch.Tensor, group: ProcessGroup | None, name: str) -> None:
@@ -110,6 +114,7 @@ def send(tensor: torch.Tensor, dst: int, group: ProcessGroup, async_op=False, ta
        return work
    else:
        work.wait()
        return None


def recv(tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, tag: int = 0) -> Work | None:
@@ -119,6 +124,7 @@ def recv(tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, ta
        return work
    else:
        work.wait()
        return None


@contextlib.contextmanager
@@ -133,3 +139,118 @@ def set_generator(generator: torch.Generator) -> typing.Generator[None, None, No
    finally:
        generator.set_state(default_generator.get_state())
        default_generator.set_state(old_state)


def gather(
    tensor: torch.Tensor,
    gather_list: list[torch.Tensor] | None = None,
    group: ProcessGroup | None = None,
    async_op: bool = False,
    dst: int = 0,
):
    assert group is not None
    opts = torch.distributed.GatherOptions()
    opts.rootRank = dst
    work = group.gather([gather_list] if dst == group.rank() else [], [tensor], opts)

    if async_op:
        return work
    elif work is not None:
        work.wait()
    return None


def scatter(
    tensor: torch.Tensor,
    scatter_list: list[torch.Tensor] | None = None,
    group: ProcessGroup | None = None,
    async_op: bool = False,
    src: int = 0,
):
    assert group is not None
    opts = torch.distributed.ScatterOptions()
    opts.rootRank = src
    opts.asyncOp = async_op
    work = group.scatter(
        [tensor if not tensor.is_complex() else torch.view_as_real(tensor)],
        [[t if not t.is_complex() else torch.view_as_real(t) for t in scatter_list]] if src == group.rank() else [],
        opts,
    )
    if async_op:
        return work
    elif work is not None:
        work.wait()
    return None


def _object_to_tensor(obj: typing.Any) -> torch.Tensor:
    f = io.BytesIO()
    pickle.Pickler(f).dump(obj)
    return torch.tensor(torch.UntypedStorage.from_buffer(f.getvalue(), dtype=torch.uint8), dtype=torch.uint8)


def _tensor_to_object(tensor: torch.Tensor) -> typing.Any:
    return pickle.Unpickler(io.BytesIO(tensor.numpy(force=True).tobytes())).load()


def gather_object(
    obj: typing.Any,
    group: ProcessGroup | None = None,
    dst: int = 0,
) -> list[typing.Any] | None:
    assert group is not None
    group_rank = group.rank()
    group_size = group.size()
    device = torch.cuda.current_device()

    obj_tensor = _object_to_tensor(None if group_rank == dst else obj)
    sizes = torch.full([group.size()], len(obj_tensor), dtype=torch.int64, device=device)
    all_gather_into_tensor(sizes, sizes[group.rank()], group=group)
    sizes = sizes.tolist()
    max_size = max(sizes)

    input_tensor = torch.empty(max_size, dtype=torch.uint8, device=device)

    if group_rank == dst:
        output_tensors = list(torch.empty(max_size * group_size, dtype=torch.uint8, device=device).chunk(group_size))
        gather(input_tensor, output_tensors, dst=dst, group=group)
        return [
            obj if rank_ == dst else _tensor_to_object(tensor[:size])
            for rank_, (tensor, size) in enumerate(zip(output_tensors, sizes, strict=True))
        ]
    else:
        input_tensor[: obj_tensor.numel()].copy_(obj_tensor)
        gather(input_tensor, None, dst=dst, group=group)
        return None


def scatter_object(
    scatter_object_input_list: typing.Optional[list[typing.Any]] = None,
    group: ProcessGroup | None = None,
    src: int = 0,
) -> typing.Any:
    assert group is not None
    group_rank = group.rank()
    group_size = group.size()
    device = torch.cuda.current_device()

    if group_rank == src:
        tensor_list = [
            _object_to_tensor(None if rank_ == src else obj) for rank_, obj in enumerate(scatter_object_input_list)
        ]
        sizes = [tensor.numel() for tensor in tensor_list]
        max_size = max(sizes)
        size_tensor = torch.tensor([[size, max_size] for size in sizes], dtype=torch.int64, device=device)
        scatter(size_tensor[group_rank], list(size_tensor.unbind()), src=src, group=group)
        scatter_list = list(torch.empty(max_size * group_size, dtype=torch.uint8, device=device).chunk(group_size))
        for scatter_tensor, tensor, size in zip(scatter_list, tensor_list, sizes, strict=True):
            scatter_tensor[:size].copy_(tensor)
        scatter(scatter_list[src], scatter_list, src=src, group=group)
        return scatter_object_input_list[src]
    else:
        size_tensor = torch.empty(2, dtype=torch.int64, device=device)
        scatter(size_tensor, None, src=src, group=group)
        size, max_size = size_tensor.tolist()
        output_tensor = torch.empty(max_size, dtype=torch.uint8, device=device)
        scatter(output_tensor, None, src=src, group=group)
        return _tensor_to_object(output_tensor[:size])
1 change: 0 additions & 1 deletion fast_llm/data/dataset/gpt/memmap.py
@@ -177,7 +177,6 @@ def _init(
        assert self._num_pixels == num_pixels
        if num_tokens is not None:
            assert self._num_tokens == num_tokens
        self._image_sizes = np.array(self._image_sizes, dtype=np.int32)

    def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]:
        return (self._name, self._prefix, self._num_documents, self._num_tokens, self._num_pixels)
2 changes: 1 addition & 1 deletion fast_llm/data/dataset/gpt/sampled.py
@@ -143,7 +143,7 @@ def _sample(self) -> None:
        # Get the document sizes, the main information needed for sampling.
        document_sizes, image_sizes = self._indexed_dataset.get_document_sizes()
        document_sizes = torch.from_numpy(document_sizes).to(self._device)
        if image_sizes.any():
        if image_sizes:
            image_token_sizes = []
            for i, sizes in enumerate(image_sizes):
                image_token_sizes.append(
2 changes: 1 addition & 1 deletion fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -458,7 +458,7 @@ def _split_and_blend_dataset_configs(
        text_sizes, image_sizes = dataset.get_document_sizes()
        tokens_cumsum = text_sizes.cumsum()
        Assert.eq(tokens_cumsum[-1], dataset_config.num_tokens)
        if image_sizes.any():
        if image_sizes:
            num_pixels_cumsum = np.cumsum([x.prod(axis=1).sum() for x in image_sizes])
            # We use the patch sizes only for the purposes of even splitting and blending weights.
            # We can always use a different patch size for training without any significant impact