Changes from all commits (176 commits)
7709e65
WIP: multimodal support
sohamparikh Apr 8, 2025
0db2bd2
rough idea for memmap
sohamparikh Apr 9, 2025
0d89f68
faster image size reading
sohamparikh Apr 15, 2025
3866a53
solidify prepare
sohamparikh Apr 21, 2025
8413983
wip
sohamparikh Apr 24, 2025
6521e41
vision model
sohamparikh Apr 24, 2025
daf586f
wip
sohamparikh Apr 24, 2025
ef4488d
wip
sohamparikh Apr 25, 2025
6d9d595
missing files
sohamparikh Apr 28, 2025
6cb8f5d
make it work, barely
sohamparikh Apr 30, 2025
5761a2d
fix
sohamparikh Apr 30, 2025
d45d600
fixes
sohamparikh May 1, 2025
74a99b8
changes
sohamparikh May 6, 2025
99ad5d9
patches and fixes
sohamparikh May 7, 2025
bcb557a
fix dependency
sohamparikh May 7, 2025
a6f5364
remove for testing
sohamparikh May 7, 2025
73b431b
mising
sohamparikh May 7, 2025
6d65676
fix
sohamparikh May 8, 2025
46aefc1
Merge branch 'main' into soham/pixtral-support
sohamparikh May 9, 2025
66e7081
fixes
sohamparikh May 9, 2025
7f86a7f
fix
sohamparikh May 12, 2025
3a8a99d
more fixes after merge
sohamparikh May 12, 2025
d16284e
conv cleanup
sohamparikh May 12, 2025
b3134aa
more conv cleanup
sohamparikh May 12, 2025
c8aa66e
images + loss-masks
sohamparikh May 13, 2025
0baae59
minor fixes
sohamparikh May 13, 2025
48855be
cleanup
sohamparikh May 13, 2025
f35e003
cleanup
sohamparikh May 13, 2025
4eb34cb
cleanup
sohamparikh May 13, 2025
ebb9e27
cleanup
sohamparikh May 13, 2025
51098ef
fix
sohamparikh May 13, 2025
60b87fa
prepare cleanup
sohamparikh May 13, 2025
f8a5532
slightly better conversion
sohamparikh May 13, 2025
490651e
cleanup, sequence parallelism
sohamparikh May 14, 2025
24e1b83
fix conv
sohamparikh May 14, 2025
0f1612a
wip fixes
sohamparikh May 14, 2025
2e48c5f
fix
sohamparikh May 14, 2025
d529d37
fix image position
sohamparikh May 17, 2025
3c22dda
cleanup
sohamparikh May 17, 2025
f0c8d83
cleanup
sohamparikh May 20, 2025
ca33ee8
cleaner, extensible multimodal config
sohamparikh May 21, 2025
f3a4a74
cleanup
sohamparikh May 21, 2025
3b955b1
fixes for pixtral
sohamparikh May 21, 2025
49daf58
model fixes
sohamparikh May 21, 2025
b5ed9f4
more cleanup
sohamparikh May 22, 2025
dc888c8
image break token in sampling
sohamparikh May 22, 2025
af3e2db
minor fixes
sohamparikh May 23, 2025
6d56be0
fix img break
sohamparikh May 24, 2025
ce91646
fixes
sohamparikh May 27, 2025
204b3e9
fix image embeddings offset
sohamparikh May 28, 2025
fd08eac
heterogeneous data fixes
sohamparikh May 29, 2025
1e3652a
convert to rgb
sohamparikh May 29, 2025
2aabf35
fix sequence parallel image patches
sohamparikh May 30, 2025
b6d4858
fixes
sohamparikh May 31, 2025
25a650b
no compile for embeddings
sohamparikh May 31, 2025
c904da5
fix sampling
sohamparikh Jun 1, 2025
7a4701c
sampling and preprocessing bugs
sohamparikh Jun 2, 2025
067f901
speed up sampling
sohamparikh Jun 2, 2025
f24325e
cap image size reduction
sohamparikh Jun 2, 2025
0f37664
fix span offset with images
sohamparikh Jun 2, 2025
ff8fecc
fix span offset with images
sohamparikh Jun 2, 2025
c663cbb
move image logic to sampled
sohamparikh Jun 3, 2025
f52f02b
cleanup
sohamparikh Jun 3, 2025
5436357
merge main
sohamparikh Jun 4, 2025
02f6d8f
cleanup
sohamparikh Jun 5, 2025
6843129
jpeg dependency
sohamparikh Jun 5, 2025
b94b1ee
install libjpeg-dev in gh actions
sohamparikh Jun 5, 2025
9e4f14f
fix sampling test
sohamparikh Jun 5, 2025
d1c804f
fix
sohamparikh Jun 6, 2025
75d64a6
fix data cache reloading
sohamparikh Jun 9, 2025
cba6986
fix tokenization
sohamparikh Jun 9, 2025
275fefa
pixtral SFT (#296)
shruthan Jun 11, 2025
605cc7f
review comments
sohamparikh Jun 11, 2025
06aa740
simplified tokenization with spans
sohamparikh Jun 12, 2025
30e3d34
Update fast_llm/data/preparator/gpt_memmap/prepare.py
sohamparikh Jun 12, 2025
c1aa709
rename
sohamparikh Jun 12, 2025
0ada42b
Merge branch 'soham/pixtral-support' of github.com:ServiceNow/Fast-LL…
sohamparikh Jun 12, 2025
4e7afd8
merge main
sohamparikh Jun 12, 2025
8e106f7
fix conversion
sohamparikh Jun 12, 2025
080dcb5
fix sequence lengths, parallel conv
sohamparikh Jun 16, 2025
f186868
minor
sohamparikh Jun 16, 2025
6b9ea2e
fix image at beginning
sohamparikh Jun 16, 2025
ad18ea1
pixtral fix conversion (#315)
RaymondLi0 Jun 20, 2025
29e66d9
handle no image samples
sohamparikh Jun 25, 2025
06a0910
mask special image tokens
sohamparikh Jun 26, 2025
bbd71df
avoid multiple labels cloning
sohamparikh Jun 27, 2025
bdc138c
merge main
sohamparikh Jul 1, 2025
96a5fd8
fix training
sohamparikh Jul 8, 2025
8f93a27
fix prepare config
sohamparikh Jul 8, 2025
c3eda1c
fix imports
sohamparikh Jul 8, 2025
1cf0ea0
fix tests
sohamparikh Jul 8, 2025
77d294c
fix tests
sohamparikh Jul 9, 2025
8434b20
cleanup
sohamparikh Jul 9, 2025
a0b6e45
add assert
RaymondLi0 Jul 17, 2025
c03faa5
resolve merge conflicts
tscholak Jul 17, 2025
d35c685
move check to config validation
RaymondLi0 Jul 17, 2025
ef982c9
fix torchvision import
sohamparikh Jul 17, 2025
903e270
Merge branch 'soham/pixtral-support' into raymond/debug_mm
RaymondLi0 Jul 17, 2025
3345ab1
debug log
RaymondLi0 Jul 17, 2025
b0b52fa
fix device
RaymondLi0 Jul 17, 2025
956a8dd
Merge branch 'main' into soham/pixtral-support
RaymondLi0 Jul 17, 2025
2be288c
remove log
RaymondLi0 Jul 17, 2025
9b180b4
Merge branch 'soham/pixtral-support' into raymond/debug_mm
RaymondLi0 Jul 17, 2025
d1caa98
fix name
RaymondLi0 Jul 17, 2025
0fbe881
fix hybrid get_layers
RaymondLi0 Jul 18, 2025
854f305
debug
RaymondLi0 Jul 18, 2025
b7b8193
add llava hybrid format
RaymondLi0 Jul 18, 2025
a1589da
workaround init
RaymondLi0 Jul 18, 2025
a202b4c
update
RaymondLi0 Jul 18, 2025
b675ec2
update
RaymondLi0 Jul 18, 2025
a055d2a
update
RaymondLi0 Jul 18, 2025
8e4ef5d
refactoring attempt
RaymondLi0 Jul 18, 2025
4bfce67
update ssm conversion: use hf_prefix/offset
RaymondLi0 Jul 21, 2025
82eed2b
TP mamba
jlamypoirier Jul 21, 2025
f93c51f
draft llava hybrid
RaymondLi0 Jul 21, 2025
4e310c7
TP mamba
jlamypoirier Jul 22, 2025
3cc4118
fix
jlamypoirier Jul 22, 2025
9f7f75c
fix
jlamypoirier Jul 22, 2025
d10eaad
fix and add test config
RaymondLi0 Jul 23, 2025
4054e04
fixes
jlamypoirier Jul 23, 2025
7c8de47
conversion fixes and tests
RaymondLi0 Jul 23, 2025
0014cc6
fix
jlamypoirier Jul 23, 2025
47ad548
fixes
jlamypoirier Jul 23, 2025
6a074fa
fixes
jlamypoirier Jul 23, 2025
d66651f
Update external
jlamypoirier Jul 23, 2025
752274b
fix conversion
RaymondLi0 Jul 24, 2025
50083ba
SSM debugging
jlamypoirier Jul 24, 2025
5006328
Merge branch 'main' into tp_mamba
jlamypoirier Jul 24, 2025
13176bd
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
8c03b54
use hybrid cache, update test
RaymondLi0 Jul 24, 2025
7b32699
stuff
jlamypoirier Jul 24, 2025
73f591f
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
1feccc8
stuff
jlamypoirier Jul 24, 2025
d51f817
finish ssm-hybrid conversion
RaymondLi0 Jul 24, 2025
e528b50
misc
jlamypoirier Jul 24, 2025
f898ff2
fix architecture classvar
RaymondLi0 Jul 24, 2025
c447fc3
add llava test and m2 conversion test
RaymondLi0 Jul 24, 2025
b49c42f
misc
jlamypoirier Jul 24, 2025
bb4dcd9
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
c1b7f44
misc
jlamypoirier Jul 24, 2025
31f5d41
misc
jlamypoirier Jul 24, 2025
051bb07
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
0a9ff25
misc
jlamypoirier Jul 24, 2025
e7d9636
Parallel discrete mamba 2
jlamypoirier Jul 24, 2025
60093cd
Merge branch 'tp_mamba' into raymond/debug_mm_tp_mamba
RaymondLi0 Jul 25, 2025
f88fb2f
rename vit layer to block
RaymondLi0 Jul 25, 2025
22296b3
block_index
RaymondLi0 Jul 25, 2025
c14b764
Mamba 2, misc
jlamypoirier Jul 25, 2025
fa21174
flexible import
RaymondLi0 Jul 25, 2025
a3e5bde
Merge branch 'tp_mamba' into raymond/debug_mm_tp_mamba
RaymondLi0 Jul 25, 2025
d3cc158
update import
RaymondLi0 Jul 28, 2025
6d245c0
fix automodel export
RaymondLi0 Jul 28, 2025
61ecb5d
try: remove assert for TP and distillation
RaymondLi0 Jul 28, 2025
2565dac
more verbose config
RaymondLi0 Jul 28, 2025
7a7f12c
use local token_ids instead of modifying batch
RaymondLi0 Jul 29, 2025
743b42c
fix allreduce
RaymondLi0 Jul 29, 2025
c7247dc
fix
RaymondLi0 Jul 29, 2025
3074ec9
revert images_sizes conversion to np array
RaymondLi0 Jul 29, 2025
c4cdd86
debug logs
RaymondLi0 Jul 30, 2025
24d7a05
rm debug logs
RaymondLi0 Aug 5, 2025
37ddef4
changes for stp reverse-kl
RaymondLi0 Aug 5, 2025
a0d7a09
reverse kl: add clamping
RaymondLi0 Aug 6, 2025
72945bc
add loss mask for vision. should also handle padded sequences
RaymondLi0 Aug 6, 2025
26541c6
simplify loss-masking
RaymondLi0 Aug 6, 2025
d66942b
data time debug
RaymondLi0 Aug 6, 2025
f00509d
fix mm embedding indices
RaymondLi0 Aug 6, 2025
8a5e8f0
cleanup
RaymondLi0 Aug 7, 2025
01783de
Mamba2 to be merged (#349)
oleksost Aug 7, 2025
ad6c0c0
checkpoint format for llava
oleksost Aug 7, 2025
cbc94e0
fixes
RaymondLi0 Aug 8, 2025
5263996
[Dev Hybrid] Distill with loss_mask (SFT dataset) and sequence-TP (#350)
oleksost Aug 13, 2025
eb4ad2f
varlen maba (#352)
oleksost Aug 19, 2025
085991a
lr scale
oleksost Aug 20, 2025
d3e3aec
fixes
RaymondLi0 Aug 20, 2025
518ae8d
hybrid checkpoint creation script
oleksost Aug 27, 2025
4e860cc
make hybrid checkpoint script
oleksost Aug 27, 2025
3 changes: 2 additions & 1 deletion .github/workflows/ci.yaml
@@ -27,11 +27,12 @@ jobs:

- name: Install dependencies
run: |
sudo apt install libjpeg-dev
pip install "torch>=2.7.0"
pip install pybind11
FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \
MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \
pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV,DOCS]"
pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV,DOCS]"
- name: Run tests
run: pytest -v -ra .

4 changes: 3 additions & 1 deletion .github/workflows/docs.yaml
@@ -29,11 +29,12 @@ jobs:
restore-keys: |
mkdocs-material-
- run: |
sudo apt install libjpeg-dev
pip install "torch>=2.7.0"
pip install pybind11
FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \
MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \
pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV,DOCS]"
pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV,DOCS]"
- name: Build the documentation
run: mkdocs build

@@ -56,6 +57,7 @@ jobs:
restore-keys: |
mkdocs-material-
- run: |
sudo apt install libjpeg-dev
pip install "torch>=2.2.2"
pip install pybind11
FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \
5 changes: 3 additions & 2 deletions Dockerfile
@@ -29,15 +29,16 @@ ENV PIP_CONSTRAINT=""
# There is no pre-build mamba image for pytorch 2.8, we build it before the rest to avoid rebuilds.
# We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 (same for causal-conv1d)
# We set the number of workers to avoid OOM when compiling on laptop. (TODO: Can we make it configurable?)
# Using varlen_mamba for variable length sequence support
RUN MAX_JOBS=2 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1"
RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2"
RUN MAX_JOBS=2 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/jxiw/varlen_mamba@varlen_mamba"
# Copy dependency files with universal write permissions for all users.
COPY --chmod=777 setup.py setup.cfg pyproject.toml ./
COPY --chmod=777 ./fast_llm/__init__.py fast_llm/
COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/

# Install dependencies within the virtual environment.
RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV]" triton==3.1.0
RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" triton==3.1.0

# Copy the remaining source code with universal write permissions.
COPY --chmod=777 ./Megatron-LM Megatron-LM
2 changes: 1 addition & 1 deletion Megatron-LM
18 changes: 18 additions & 0 deletions fast_llm/data/data/gpt/data.py
@@ -32,6 +32,8 @@ class GPTBatch:
token_ids: torch.Tensor
loss_masking_spans: list[torch.Tensor] | None = None
sequence_lengths: list[torch.Tensor] | None = None
images: list[torch.Tensor] | None = None
image_positions: list[torch.Tensor] | None = None
chosen_spans: list[torch.Tensor] | None = None
rejected_spans: list[torch.Tensor] | None = None

@@ -49,12 +49,28 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling
stacked_rejected_spans = [torch.from_numpy(sample.rejected_span) for sample in batch]
if not sampling_parameters.cross_document_attention:
sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch]
has_images = False
batch_images = []
for sample in batch:
if sample.images is not None:
batch_images.append([torch.from_numpy(image) for image in sample.images])
has_images = True
else:
batch_images.append([])
batch_image_positions = []
for sample in batch:
if sample.image_positions is not None:
batch_image_positions.append(torch.from_numpy(sample.image_positions))
else:
batch_image_positions.append([])
return GPTBatch(
token_ids=torch.from_numpy(stacked_ids),
loss_masking_spans=stacked_spans,
sequence_lengths=sequence_lengths,
chosen_spans=stacked_chosen_spans,
rejected_spans=stacked_rejected_spans,
images=batch_images if has_images else None,
image_positions=batch_image_positions if has_images else None,
)


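For reference, a minimal sketch of the image-handling pattern this collate change introduces. ToySample below is a hypothetical stand-in for GPTSample; only the collation logic mirrors the diff above.

    # Toy illustration: per-sample image tensors are gathered into lists, and the
    # batch-level fields stay None unless at least one sample carries images.
    import dataclasses

    import numpy as np
    import torch


    @dataclasses.dataclass
    class ToySample:
        token_ids: np.ndarray
        images: list[np.ndarray] | None = None
        image_positions: np.ndarray | None = None


    def collate_images(batch: list[ToySample]):
        has_images = False
        batch_images, batch_image_positions = [], []
        for sample in batch:
            if sample.images is not None:
                batch_images.append([torch.from_numpy(image) for image in sample.images])
                has_images = True
            else:
                batch_images.append([])
            if sample.image_positions is not None:
                batch_image_positions.append(torch.from_numpy(sample.image_positions))
            else:
                batch_image_positions.append([])
        return (batch_images, batch_image_positions) if has_images else (None, None)


    samples = [
        ToySample(np.arange(8), images=[np.zeros((3, 16, 16), dtype=np.uint8)], image_positions=np.array([2])),
        ToySample(np.arange(8)),  # text-only sample: contributes empty lists
    ]
    images, positions = collate_images(samples)
    print(len(images[0]), images[1], positions[1])  # -> 1 [] []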
13 changes: 12 additions & 1 deletion fast_llm/data/dataset/gpt/config.py
@@ -76,6 +76,10 @@ class GPTSamplingParameters(SamplingParameters):
use_preference_loss_spans: bool = False
cross_document_attention: bool = True
truncate_documents: bool = True
patch_size: int | None = None
max_image_size: int | None = None
image_break_token: int | None = None
image_end_token: int | None = None
# How many extra tokens to add to the sequence length.
# This is used to provide labels even for the last tokens in the sequence.
extra_tokens: int = 1
@@ -142,11 +146,18 @@ class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig):
desc="Expected number of tokens in the dataset.",
hint=FieldHint.optional,
)
num_pixels: int | None = Field(
default=None,
desc="Expected number of pixels in the dataset.",
hint=FieldHint.optional,
)

def build(self) -> "GPTMemmapDataset":
from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset

return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens)
return GPTMemmapDataset(
str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens, self.num_pixels
)


@config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated"})
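As a rough sketch of how the new patch_size / image_break_token / image_end_token sampling parameters might interact when budgeting sequence length, assuming the Pixtral-style layout this PR targets (one break token per row of patches, with the final break swapped for the end token). This is an illustration under that assumption, not the project's actual sampling code.

    import math


    def image_token_count(height: int, width: int, patch_size: int) -> int:
        # Hypothetical helper: tokens contributed by one image under a Pixtral-style
        # layout, where each row of patches ends in an image-break token and the last
        # break is replaced by the image-end token (so break/end add one token per row).
        rows = math.ceil(height / patch_size)
        cols = math.ceil(width / patch_size)
        return rows * (cols + 1)


    print(image_token_count(512, 384, patch_size=16))  # 32 rows x (24 + 1) = 800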
6 changes: 3 additions & 3 deletions fast_llm/data/dataset/gpt/fim.py
@@ -158,9 +158,9 @@ def _fim_permute_sequence(
middle = contents[boundaries[0] : boundaries[1]]
suffix = contents[boundaries[1] :]

prefix = np.array([*self._tokenizer.tokenize(prefix, end=False)], dtype=np.int64)
middle = np.array([*self._tokenizer.tokenize(middle, begin=False, end=False)], dtype=np.int64)
suffix = np.array([*self._tokenizer.tokenize(suffix, begin=False)], dtype=np.int64)
prefix = np.array([*self._tokenizer._tokenize(prefix, end=False)], dtype=np.int64)
middle = np.array([*self._tokenizer._tokenize(middle, begin=False, end=False)], dtype=np.int64)
suffix = np.array([*self._tokenizer._tokenize(suffix, begin=False)], dtype=np.int64)

# here we truncate each given segment to fit the same length as it was before
# A consequence is that we never reach the end of a file?
26 changes: 24 additions & 2 deletions fast_llm/data/dataset/gpt/indexed.py
@@ -34,6 +34,14 @@ def sample(self, sampling: GPTSamplingData) -> "GPTSampledIndexedDataset":
else GPTSampledIndexedDataset(self, sampling)
)

@property
@abc.abstractmethod
def has_images(self) -> bool:
"""
Whether the dataset contains images.
This is used to determine whether to use image-related fields in the sampled data.
"""


class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[IndexedDatasetType], GPTIndexedDataset):
"""
@@ -44,11 +52,16 @@ class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[Indexe

def get_document_sizes(self) -> np.ndarray:
# TODO: This can be really big.
return self._dataset.get_document_sizes()[self._begin : self._end]
doc_sizes, im_sizes = self._dataset.get_document_sizes()
return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] if im_sizes else np.array([])

def get_document_size(self, index: int) -> int:
return self._dataset.get_document_size(self._begin + index)

@property
def has_images(self) -> bool:
return self._dataset.has_images


class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset](
ConcatenatedDataset[IndexedDatasetType], GPTIndexedDataset
@@ -57,8 +70,17 @@ class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset](

def get_document_sizes(self) -> np.ndarray:
# TODO: This can be really big.
return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets])
# return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets])
sizes = [dataset.get_document_sizes() for dataset in self._datasets]
return (
np.concatenate([size[0] for size in sizes]),
np.concatenate([size[1] for size in sizes]) if sizes[0][1] is not None else np.array([]),
)

def get_document_size(self, index: int) -> int:
dataset = np.searchsorted(self._dataset_splits[1:], index, side="right")
return self._datasets[dataset].get_document_size(index - self._dataset_splits[dataset].item())

@property
def has_images(self) -> bool:
return any(dataset.has_images for dataset in self._datasets)
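A standalone sketch of the (document_sizes, image_sizes) convention introduced above; the helper names are hypothetical, only the pair-and-slice pattern comes from the diff.

    import numpy as np


    def slice_sizes(doc_sizes: np.ndarray, im_sizes: np.ndarray | None, begin: int, end: int):
        # Mirrors GPTDatasetSlice.get_document_sizes: slice both halves of the pair,
        # falling back to an empty array when the dataset carries no image sizes.
        # (Checking `is None` / `.size` explicitly avoids the ambiguous truth value
        # of a multi-element NumPy array.)
        if im_sizes is None or im_sizes.size == 0:
            return doc_sizes[begin:end], np.array([])
        return doc_sizes[begin:end], im_sizes[begin:end]


    def concat_sizes(parts):
        # Mirrors GPTConcatenatedDataset.get_document_sizes: concatenate each half.
        doc_sizes = np.concatenate([part[0] for part in parts])
        image_sizes = (
            np.concatenate([part[1] for part in parts]) if parts[0][1] is not None else np.array([])
        )
        return doc_sizes, image_sizes


    print(slice_sizes(np.array([5, 7, 9]), np.array([1, 0, 2]), 1, 3))   # (array([7, 9]), array([0, 2]))
    print(concat_sizes([(np.array([5]), np.array([1])), (np.array([7, 9]), np.array([0, 2]))]))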