Masked Diffusion Training with Shift #294

Draft: wants to merge 49 commits into base: main

Commits (49)
db28a11
changes for basic LLaDA style diffusion masking support
gopeshh Apr 21, 2025
3d44671
tests for masking and MLM loss
gopeshh Apr 22, 2025
46dd535
temp fixes
nitsanluke Jun 4, 2025
aa8ab4d
tmp fix
nitsanluke Jun 4, 2025
9f348e7
including masked diffusion training setup
nitsanluke Jun 7, 2025
cdc9c96
adding weighted loss
nitsanluke Jun 11, 2025
d71e693
clean up
nitsanluke Jun 11, 2025
072e6c4
add loss weight
nitsanluke Jun 13, 2025
6127544
adding updates to p_mask
nitsanluke Jun 13, 2025
1cf15a8
update error mgs
nitsanluke Jun 16, 2025
f7a46d7
add comments and clean up
nitsanluke Jun 16, 2025
b80024e
Merge branch 'main' into luke/gopeshh/masked_diffusion
nitsanluke Jun 16, 2025
01a683b
fx merge errors
nitsanluke Jun 18, 2025
ba913e1
fix merge issues
nitsanluke Jun 18, 2025
6c0c72d
register mask config
nitsanluke Jun 18, 2025
26aa13a
Merge branch 'main' into luke/gopeshh/masked_diffusion
nitsanluke Jun 18, 2025
3245496
fx merge issues
nitsanluke Jun 18, 2025
5198310
Merge branch 'main' into luke/gopeshh/masked_diffusion
nitsanluke Jun 23, 2025
4ad0bc1
fix labels
nitsanluke Jun 23, 2025
acacfe3
drop old tests
nitsanluke Jun 23, 2025
2a06ed4
tmp fix
nitsanluke Jun 24, 2025
dd68d28
fx tests
nitsanluke Jun 24, 2025
e0a7c80
update missing rotery export
nitsanluke Jun 25, 2025
0306e36
reset attention_factor to old behaviour
nitsanluke Jun 25, 2025
6bcb38d
setting attention to _flash_attn_func
nitsanluke Jun 27, 2025
093aa33
debug
nitsanluke Jun 28, 2025
141ed88
avg only non-zero loss
nitsanluke Jun 28, 2025
8bb00ed
debug remove
nitsanluke Jun 28, 2025
38737d4
remove non-zero weight
nitsanluke Jul 4, 2025
b043efe
revert to mean loss on all tokens
nitsanluke Jul 4, 2025
0c221fd
tmp
nitsanluke Jul 4, 2025
d29af35
adding fused attn
nitsanluke Jul 4, 2025
aa0d08c
include ar+masking
nitsanluke Jul 6, 2025
014b92e
Merge branch 'main' into luke/gopeshh/masked_diffusion
nitsanluke Jul 6, 2025
632dc7c
main update cr loss
nitsanluke Jul 6, 2025
0b469fb
include ar+diff option as a seperate style
nitsanluke Jul 8, 2025
068138f
minor
nitsanluke Jul 8, 2025
a573cd7
attn verificiation checks
nitsanluke Jul 8, 2025
fa952fe
tmp updates
nitsanluke Jul 8, 2025
1d687ea
temp
nitsanluke Jul 8, 2025
2eb5e84
adding updates loss avg on none-zero
nitsanluke Jul 11, 2025
f11c07f
avg across all tokens
nitsanluke Jul 11, 2025
2afd95b
update loss
nitsanluke Jul 14, 2025
e9af787
Merge branch 'main' into luke/gopeshh/masked_diffusion
nitsanluke Jul 14, 2025
4fd1ba5
temp fixes
nitsanluke Jul 16, 2025
e92f567
clean up
nitsanluke Jul 17, 2025
a312607
adding tests and cleanup
nitsanluke Jul 24, 2025
e1fcc9f
adding masking test-case and data cleanup
nitsanluke Jul 28, 2025
0e11f8b
cleaup and move attention to preprocessing
nitsanluke Jul 28, 2025
10 changes: 10 additions & 0 deletions fast_llm/config.py
@@ -1099,3 +1099,13 @@ def pop_nested_dict_value[
return d.pop(keys[-1])
else:
return d.pop(keys)


class DiffusionStyle(str, enum.Enum):
"""
Type of diffusion masking to use.
"""

masked = "masked" # masked diffusion with shift
ar_masked = "autoregressive_masked" # autoregressive context with masked diffusion and shift
none = None
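
A minimal standalone sketch (not part of the diff) of how the two active styles resolve from their string values, assuming the config loader looks enum members up by value:

from fast_llm.config import DiffusionStyle

assert DiffusionStyle("masked") is DiffusionStyle.masked
assert DiffusionStyle("autoregressive_masked") is DiffusionStyle.ar_masked
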
179 changes: 178 additions & 1 deletion fast_llm/data/data/gpt/data.py
@@ -9,6 +9,7 @@
import torch
import torch.utils.data

from fast_llm.config import DiffusionStyle
from fast_llm.core.distributed import safe_barrier
from fast_llm.data.data.abstract import Data
from fast_llm.data.data.gpt.config import GPTDataConfig
@@ -34,12 +35,181 @@ class GPTBatch:
sequence_lengths: list[torch.Tensor] | None = None
chosen_spans: list[torch.Tensor] | None = None
rejected_spans: list[torch.Tensor] | None = None
mask_indexes: torch.Tensor | None = None
mask_probabilities: torch.Tensor | None = None
masked_token_ids: torch.Tensor | None = None
loss_weights: torch.Tensor | None = None
in_context_length: torch.Tensor | None = None
in_context: torch.Tensor | None = None


def _do_mask(x, mask, mask_token_id):
x = x.clone()
x[mask] = mask_token_id
return x


def _do_uniform(x, is_uniform, vocab_size):
x = x.clone()
uniform = torch.randint(0, vocab_size, x.size())
x[is_uniform] = uniform[is_uniform]
return x


def prepare_masked_batch(
data_ids: torch.Tensor,
positions: torch.Tensor,
padded: torch.Tensor,
mask_token_id: int,
vocab_size: int,
context_length: torch.Tensor,
p_mask: torch.Tensor,
p_uniform: float = 0.0,
ar_factor: float = 1.0,
un_factor: float = 1.0,
last_factor: float = 0.0,
in_mask: torch.Tensor = None,
in_uniform: torch.Tensor = None,
) -> dict[str, torch.Tensor]:
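"""
Partition each position into context / masked / uniform-noise / clean, corrupt the
input ids accordingly (mask token at masked positions, random ids at uniform ones),
and build per-token loss weights. Returns the boolean partitions together with the
corrupted `input_ids` and the `loss_weights`, both of width seq_len - 1 (the shifted inputs).
"""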

B, L = positions.size()
context_length = context_length.unsqueeze(1).expand(B, L)
p_mask = p_mask.unsqueeze(1)

# Reminder: a context_length of zero still has one in_context token (<BOS>)
in_context = positions <= context_length
if in_mask is None:
in_mask = (~in_context) & (torch.rand(B, L) < p_mask)

if in_uniform is None:
in_uniform = (~in_context) & (~in_mask) & (torch.rand(B, L) < p_uniform)
in_clean = (~in_context) & (~in_mask) & (~in_uniform)
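# At most one term in the loss-weight expression below is non-zero per position,
# since the context / mask / uniform / clean indicators are mutually exclusive:
#   context targets -> ar_factor
#   masked targets  -> 1 / p_mask
#   uniform targets -> un_factor * (1 - p_uniform) / (1 - p_mask)
#   clean targets   -> un_factor * p_uniform / (1 - p_mask)
# The final target position gets last_factor, and padded targets get weight 0.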

loss_weights = (~padded)[:, 1:] * torch.cat(
[
ar_factor * in_context[:, 1:]
+ in_mask[:, 1:] / p_mask
+ un_factor * ((1 - p_uniform) * in_uniform[:, 1:] + p_uniform * in_clean[:, 1:]) / (1 - p_mask),
last_factor * torch.ones(B, 1),
],
dim=1,
)

input_ids = _do_uniform(_do_mask(data_ids[:, :-1], in_mask, mask_token_id), in_uniform, vocab_size)

return {
"in_context": in_context,
"in_mask": in_mask,
"in_uniform": in_uniform,
"in_clean": in_clean,
"input_ids": input_ids,
"loss_weights": loss_weights,
}
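
As a quick shape check for this helper, a minimal standalone sketch (not part of the diff; the import path is taken from this file, the toy sizes and values are illustrative):

import torch

from fast_llm.data.data.gpt.data import prepare_masked_batch

batch_size, seq_len = 2, 8
token_ids = torch.randint(0, 1000, (batch_size, seq_len))
positions = torch.arange(seq_len - 1).unsqueeze(0).expand(batch_size, seq_len - 1)
padded = torch.zeros_like(token_ids, dtype=torch.bool)
p_mask = torch.full((batch_size,), 0.5)

out = prepare_masked_batch(
data_ids=token_ids,
positions=positions,
padded=padded,
mask_token_id=103,
vocab_size=1000,
context_length=-torch.ones(batch_size, dtype=torch.int),  # no AR context, pure masked diffusion
p_mask=p_mask,
)

# Inputs are the first seq_len - 1 tokens with masked positions replaced by the mask token.
assert out["input_ids"].shape == (batch_size, seq_len - 1)
assert out["loss_weights"].shape == (batch_size, seq_len - 1)
assert not out["in_context"].any()  # a context_length of -1 marks nothing as context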


def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch:

stacked_ids = np.stack([sample.token_ids for sample in batch])
stacked_spans = None
sequence_lengths = None
mask_indexes = None
mask_probabilities = None
masked_token_ids = None

loss_weights = None
in_context_length = None
in_context = None

token_ids = torch.from_numpy(stacked_ids)

if sampling_parameters.diffusion.style == DiffusionStyle.masked:
diffusion_config = sampling_parameters.diffusion

batch_size, seq_len = token_ids.shape
positions = torch.arange(seq_len - 1).unsqueeze(0).expand(batch_size, seq_len - 1)
padded = torch.zeros_like(token_ids, dtype=torch.bool)

# Draw one random value per sequence to seed the masking probabilities
t = torch.rand((batch_size,))

# Compute the mask probabilities for every sequence in the batch
p_mask = (1 - (2 * diffusion_config.epsilon)) * t + diffusion_config.epsilon
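# With t ~ U[0, 1), p_mask lies in [epsilon, 1 - epsilon), so the masking probability
# never collapses to (nearly) 0 or 1 for any sequence.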

# Do we need to clamp at max_mask_prob?
# p_mask = torch.min(p_mask, torch.tensor(diffusion_config.max_mask_prob))

# The input carries one extra token to support the shift. The labels are already
# left-shifted relative to the inputs, and the last input token is dropped:
#   tokens x = [A, B, C, D, E, F]
#   inputs   = [A, B, C, D, E]
#   labels   = [B, C, D, E, F]

# TODO: Padding - have ~1% of the data use partial sequences with padding, as in LLaDA

batch_data = prepare_masked_batch(
data_ids=token_ids,
positions=positions,
padded=padded,
mask_token_id=diffusion_config.mask_token_id,
vocab_size=sampling_parameters.vocab_size,
context_length=-torch.ones(batch_size, dtype=torch.int), # no auto-regressive context tokens
p_mask=p_mask,
p_uniform=0.0, # no uniform shuffling of tokens
ar_factor=0.0,
un_factor=0.0,
last_factor=0.0,
)

masked_token_ids = batch_data["input_ids"]
mask_indexes = batch_data["in_mask"]
loss_weights = batch_data["loss_weights"]
in_context = batch_data["in_context"]

elif sampling_parameters.diffusion.style == DiffusionStyle.ar_masked:
diffusion_config = sampling_parameters.diffusion
batch_size, seq_len = token_ids.shape
data_ids = token_ids
padded = torch.zeros_like(data_ids, dtype=torch.bool)
positions = torch.arange(seq_len - 1).unsqueeze(0).expand(batch_size, seq_len - 1)

# e.g. context_sampler=0.1: ~90% of batches draw C uniformly from [0, seq_len // 4),
# ~10% draw it from [0, seq_len - 2); a single draw of `prob` picks the branch for the whole batch
prob = torch.rand(1)
context_length = torch.where(
prob > diffusion_config.context_sampler,
torch.randint(0, seq_len // 4, (batch_size,), dtype=torch.long),
torch.randint(0, seq_len - 2, (batch_size,), dtype=torch.long),
)
# Draw one random value per sequence to seed the masking probabilities
t = torch.rand((batch_size,))
# Compute the mask probabilities for every sequence in the batch, staying away from the extremes 0 and 1
p_mask = (1 - (2 * diffusion_config.epsilon)) * t + diffusion_config.epsilon

batch_data = prepare_masked_batch(
data_ids=data_ids,
positions=positions,
padded=padded,
mask_token_id=diffusion_config.mask_token_id,
vocab_size=sampling_parameters.vocab_size,
context_length=context_length,
p_mask=p_mask,
p_uniform=0.0, # no uniform shuffling of tokens
ar_factor=diffusion_config.ar_factor,
un_factor=0.0,
last_factor=0.0,
)

masked_token_ids = batch_data["input_ids"]
mask_indexes = batch_data["in_mask"]
loss_weights = batch_data["loss_weights"]
in_context_length = context_length
in_context = batch_data["in_context"]

if sampling_parameters.use_loss_masking_spans:
stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch]

stacked_chosen_spans = None
stacked_rejected_spans = None
if sampling_parameters.use_loss_masking_spans:
@@ -49,12 +219,19 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling
stacked_rejected_spans = [torch.from_numpy(sample.rejected_span) for sample in batch]
if not sampling_parameters.cross_document_attention:
sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch]

return GPTBatch(
token_ids=torch.from_numpy(stacked_ids),
token_ids=token_ids,
loss_masking_spans=stacked_spans,
sequence_lengths=sequence_lengths,
chosen_spans=stacked_chosen_spans,
rejected_spans=stacked_rejected_spans,
mask_indexes=mask_indexes,
mask_probabilities=mask_probabilities,
masked_token_ids=masked_token_ids,
loss_weights=loss_weights,
in_context_length=in_context_length,
in_context=in_context,
)


42 changes: 41 additions & 1 deletion fast_llm/data/dataset/gpt/config.py
@@ -8,7 +8,16 @@

import yaml

from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none
from fast_llm.config import (
Config,
DiffusionStyle,
Field,
FieldHint,
FieldUpdate,
check_field,
config_class,
skip_valid_if_none,
)
from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset
from fast_llm.data.dataset.config import (
BlendedDatasetConfig,
@@ -44,6 +53,32 @@ class ShufflingType(str, enum.Enum):
legacy = "legacy"


@config_class(registry=True)
class DiffusionMaskingConfig(Config):
"""Configuration for diffusion-based masking during data preparation."""

style: DiffusionStyle = Field(
default=DiffusionStyle.none, desc="Which diffusion masking style to use during training.", hint=FieldHint.feature
)

epsilon: float = Field(
default=1e-3, desc="Minimum masking probability", hint=FieldHint.performance, valid=check_field(Assert.gt, 0)
)

mask_token_id: int = Field(default=103, desc="Token ID to use for masking", hint=FieldHint.optional)
ar_factor: float = Field(
default=1.0,
desc="Weighting factor for the AR (context) tokens in the overall loss.",
hint=FieldHint.optional,
)
context_sampler: float = Field(
default=1.0, desc="Probability of sampling the context length C from the full range rather than from the first quarter of the sequence.", hint=FieldHint.optional
)

def _validate(self) -> None:
super()._validate()


@config_class()
class GPTSamplingConfig(SamplingConfig):
"""
@@ -62,6 +97,10 @@ class GPTSamplingConfig(SamplingConfig):
desc="Shuffling strategy.",
hint=FieldHint.feature,
)
diffusion: DiffusionMaskingConfig = Field(
desc="Configuration for diffusion-based masking during data preparation.",
hint=FieldHint.feature,
)
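
For illustration, a hypothetical sampling-config fragment using the fields defined above (field names are taken from DiffusionMaskingConfig; the exact nesting and dict layout are assumptions, not something this diff pins down):

diffusion_fragment = {
"diffusion": {
"style": "autoregressive_masked",  # or "masked"; parsed into DiffusionStyle
"epsilon": 1.0e-3,  # keeps p_mask inside [epsilon, 1 - epsilon)
"mask_token_id": 103,
"ar_factor": 1.0,  # weight of context (AR) tokens in the loss
"context_sampler": 0.1,  # probability of drawing the context length from the full range
}
}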


@dataclasses.dataclass(kw_only=True)
@@ -79,6 +118,7 @@ class GPTSamplingParameters(SamplingParameters):
# How many extra tokens to add to the sequence length.
# This is used to provide labels even for the last tokens in the sequence.
extra_tokens: int = 1
diffusion: DiffusionMaskingConfig


@dataclasses.dataclass(kw_only=True)
3 changes: 1 addition & 2 deletions fast_llm/engine/base_model/base_model.py
@@ -137,9 +137,8 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]:
# The name (dict key) is used to insert the weight in the kwargs of the forward pass.
return {}

@property
@abc.abstractmethod
def loss_defs(self) -> list[LossDef]:
def get_loss_defs(self) -> list[LossDef]:
pass

def add_reference_model(self, name: str, inference_runner: "InferenceRunner") -> None:
2 changes: 1 addition & 1 deletion fast_llm/engine/evaluation/evaluator.py
@@ -119,7 +119,7 @@ def setup(
phase=PhaseType.validation,
)

self._loss_defs = self._multi_stage.base_model.loss_defs
self._loss_defs = self._multi_stage.base_model.get_loss_defs()
self._evaluation_iterator = None
self._is_setup = True

2 changes: 1 addition & 1 deletion fast_llm/engine/schedule/runner.py
@@ -93,7 +93,7 @@ def __init__(
self._stages: list[Stage] = self._multi_stage.stages
self._tied_parameters = self._multi_stage.tied_parameters
self._num_stages = len(self._stages)
self._loss_defs = {loss_def.name: loss_def for loss_def in self._multi_stage.base_model.loss_defs}
self._loss_defs = {loss_def.name: loss_def for loss_def in self._multi_stage.base_model.get_loss_defs()}

def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) -> None:
assert not self._is_setup
2 changes: 1 addition & 1 deletion fast_llm/engine/training/trainer.py
@@ -152,7 +152,7 @@ def __init__(self, config: TrainerConfig):
multi_stage=self._multi_stage,
distributed_config=self._config.model.distributed,
)
self._loss_defs = self._multi_stage.base_model.loss_defs
self._loss_defs = self._multi_stage.base_model.get_loss_defs()

if not self._is_evaluation_only:
steps_per_split = {