diff --git a/fast_llm/config.py b/fast_llm/config.py index 0004501bd..92f47136b 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -1099,3 +1099,13 @@ def pop_nested_dict_value[ return d.pop(keys[-1]) else: return d.pop(keys) + + +class DiffusionStyle(str, enum.Enum): + """ + Type of diffusion masking to use. + """ + + masked = "masked" # masked diffusion with shift + ar_masked = "autoregressive_masked" # autoregressive context with masked diffusion and shift + none = None diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 6724afb59..b1ad015aa 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -9,6 +9,7 @@ import torch import torch.utils.data +from fast_llm.config import DiffusionStyle from fast_llm.core.distributed import safe_barrier from fast_llm.data.data.abstract import Data from fast_llm.data.data.gpt.config import GPTDataConfig @@ -34,12 +35,181 @@ class GPTBatch: sequence_lengths: list[torch.Tensor] | None = None chosen_spans: list[torch.Tensor] | None = None rejected_spans: list[torch.Tensor] | None = None + mask_indexes: torch.Tensor | None = None + mask_probabilities: torch.Tensor | None = None + masked_token_ids: torch.Tensor | None = None + loss_weights: torch.Tensor | None = None + in_context_length: torch.Tensor | None = None + in_context: torch.Tensor | None = None + + +def _do_mask(x, mask, mask_token_id): + x = x.clone() + x[mask] = mask_token_id + return x + + +def _do_uniform(x, is_uniform, vocab_size): + x = x.clone() + uniform = torch.randint(0, vocab_size, x.size()) + x[is_uniform] = uniform[is_uniform] + return x + + +def prepare_masked_batch( + data_ids: torch.Tensor, + positions: torch.Tensor, + padded: torch.Tensor, + mask_token_id: int, + vocab_size: int, + context_length: torch.Tensor, + p_mask: torch.Tensor, + p_uniform: float = 0.0, + ar_factor: float = 1.0, + un_factor: float = 1.0, + last_factor: float = 0.0, + in_mask: torch.Tensor = None, + in_uniform: torch.Tensor = None, +) -> dict[str, torch.Tensor]: + + B, L = positions.size() + context_length = context_length.unsqueeze(1).expand(B, L) + p_mask = p_mask.unsqueeze(1) + + # Reminder: a context_length of zero still has one in_context token () + in_context = positions <= context_length + if in_mask is None: + in_mask = (~in_context) & (torch.rand(B, L) < p_mask) + + if in_uniform is None: + in_uniform = (~in_context) & (~in_mask) & (torch.rand(B, L) < p_uniform) + in_clean = (~in_context) & (~in_mask) & (~in_uniform) + + loss_weights = (~padded)[:, 1:] * torch.cat( + [ + ar_factor * in_context[:, 1:] + + in_mask[:, 1:] / p_mask + + un_factor * ((1 - p_uniform) * in_uniform[:, 1:] + p_uniform * in_clean[:, 1:]) / (1 - p_mask), + last_factor * torch.ones(B, 1), + ], + dim=1, + ) + + input_ids = _do_uniform(_do_mask(data_ids[:, :-1], in_mask, mask_token_id), in_uniform, vocab_size) + + return { + "in_context": in_context, + "in_mask": in_mask, + "in_uniform": in_uniform, + "in_clean": in_clean, + "input_ids": input_ids, + "loss_weights": loss_weights, + } def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch: + stacked_ids = np.stack([sample.token_ids for sample in batch]) stacked_spans = None sequence_lengths = None + mask_indexes = None + mask_probabilities = None + masked_token_ids = None + + loss_weights = None + in_context_length = None + in_context = None + + token_ids = torch.from_numpy(stacked_ids) + + if sampling_parameters.diffusion.style == DiffusionStyle.masked: + 
        diffusion_config = sampling_parameters.diffusion
+
+        batch_size, seq_len = token_ids.shape
+        positions = torch.arange(seq_len - 1).unsqueeze(0).expand(batch_size, seq_len - 1)
+        padded = torch.zeros_like(token_ids, dtype=torch.bool)
+
+        # Generate a random tensor of batch size to seed masking probabilities
+        t = torch.rand((batch_size,))
+
+        # Compute the mask probabilities for every sequence in the batch
+        p_mask = (1 - (2 * diffusion_config.epsilon)) * t + diffusion_config.epsilon
+
+        # Do we need to clamp at max_mask_prob?
+        # p_mask = torch.min(p_mask, torch.tensor(diffusion_config.max_mask_prob))
+
+        # The input has an additional token for shifting: [0, 1, 2, 3, 4] -> [1, 2, 3, 4]
+
+        # index [0, 1, 2, 3, 4, 5] ->
+        # The labels are already left-shifted: x = [A, B, C, D, E, F] ->
+        # embd  = [A, B, C, D, E]
+        # label = [B, C, D, E, F]
+        # The last input token is dropped from processing.
+
+        # TODO: Padding - have 1% of the data contain partial sequences with padding, as in LLaDA.
+
+        batch_data = prepare_masked_batch(
+            data_ids=token_ids,
+            positions=positions,
+            padded=padded,
+            mask_token_id=diffusion_config.mask_token_id,
+            vocab_size=sampling_parameters.vocab_size,
+            context_length=-torch.ones(batch_size, dtype=torch.int),  # no auto-regressive context tokens
+            p_mask=p_mask,
+            p_uniform=0.0,  # no uniform shuffling of tokens
+            ar_factor=0.0,
+            un_factor=0.0,
+            last_factor=0.0,
+        )
+
+        masked_token_ids = batch_data["input_ids"]
+        mask_indexes = batch_data["in_mask"]
+        loss_weights = batch_data["loss_weights"]
+        in_context = batch_data["in_context"]
+
+    elif sampling_parameters.diffusion.style == DiffusionStyle.ar_masked:
+        diffusion_config = sampling_parameters.diffusion
+        batch_size, seq_len = token_ids.shape
+        data_ids = token_ids
+        padded = torch.zeros_like(data_ids, dtype=torch.bool)
+        positions = torch.arange(seq_len - 1).unsqueeze(0).expand(batch_size, seq_len - 1)
+
+        # e.g. context_sampler=0.1: for 90% of batches C = random in [0, seq_len // 4), for 10% C = random in [0, seq_len - 2)
+        prob = torch.rand(1)
+        context_length = torch.where(
+            prob > diffusion_config.context_sampler,
+            torch.randint(0, seq_len // 4, (batch_size,), dtype=torch.long),
+            torch.randint(0, seq_len - 2, (batch_size,), dtype=torch.long),
+        )
+        # Generate a random tensor of batch size to seed masking probabilities
+        t = torch.rand((batch_size,))
+        # Compute the mask probabilities for every sequence in the batch, avoiding the extremes 0 & 1
+        p_mask = (1 - (2 * diffusion_config.epsilon)) * t + diffusion_config.epsilon
+
+        batch_data = prepare_masked_batch(
+            data_ids=data_ids,
+            positions=positions,
+            padded=padded,
+            mask_token_id=diffusion_config.mask_token_id,
+            vocab_size=sampling_parameters.vocab_size,
+            context_length=context_length,
+            p_mask=p_mask,
+            p_uniform=0.0,  # no uniform shuffling of tokens
+            ar_factor=diffusion_config.ar_factor,
+            un_factor=0.0,
+            last_factor=0.0,
+        )
+
+        masked_token_ids = batch_data["input_ids"]
+        mask_indexes = batch_data["in_mask"]
+        loss_weights = batch_data["loss_weights"]
+        in_context_length = context_length
+        in_context = batch_data["in_context"]
+
+    if sampling_parameters.use_loss_masking_spans:
+        stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch]
+
     stacked_chosen_spans = None
     stacked_rejected_spans = None
     if sampling_parameters.use_loss_masking_spans:
@@ -49,12 +219,19 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling
         stacked_rejected_spans = [torch.from_numpy(sample.rejected_span) for sample in batch]
     if not sampling_parameters.cross_document_attention:
         sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch]
+
     return GPTBatch(
-        token_ids=torch.from_numpy(stacked_ids),
+        token_ids=token_ids,
         loss_masking_spans=stacked_spans,
         sequence_lengths=sequence_lengths,
         chosen_spans=stacked_chosen_spans,
         rejected_spans=stacked_rejected_spans,
+        mask_indexes=mask_indexes,
+        mask_probabilities=mask_probabilities,
+        masked_token_ids=masked_token_ids,
+        loss_weights=loss_weights,
+        in_context_length=in_context_length,
+        in_context=in_context,
     )
diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py
index ef2efedc9..aedc14ef2 100644
--- a/fast_llm/data/dataset/gpt/config.py
+++ b/fast_llm/data/dataset/gpt/config.py
@@ -8,7 +8,16 @@
 import yaml
 
-from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none
+from fast_llm.config import (
+    Config,
+    DiffusionStyle,
+    Field,
+    FieldHint,
+    FieldUpdate,
+    check_field,
+    config_class,
+    skip_valid_if_none,
+)
 from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset
 from fast_llm.data.dataset.config import (
     BlendedDatasetConfig,
@@ -44,6 +53,32 @@ class ShufflingType(str, enum.Enum):
     legacy = "legacy"
 
 
+@config_class(registry=True)
+class DiffusionMaskingConfig(Config):
+    """Configuration for diffusion-based masking during data preparation."""
+
+    style: DiffusionStyle = Field(
+        default=DiffusionStyle.none, desc="Which diffusion masking style to use during training.", hint=FieldHint.feature
+    )
+
+    epsilon: float = Field(
+        default=1e-3, desc="Minimum masking probability", hint=FieldHint.performance, valid=check_field(Assert.gt, 0)
+    )
+
+    mask_token_id: int = Field(default=103, desc="Token ID to use for masking", hint=FieldHint.optional)
+    ar_factor: float = Field(
+        default=1.0,
+        desc="Factor for the AR (context) weighting in the overall loss.",
+        hint=FieldHint.optional,
+    )
+    context_sampler: float = Field(
+        default=1.0, desc="Probability of sampling the context length C from the full sequence length rather than the first 25% of it.", hint=FieldHint.optional
+    )
+
+    def _validate(self) -> None:
+        super()._validate()
+
+
 @config_class()
 class GPTSamplingConfig(SamplingConfig):
     """
@@ -62,6 +97,10 @@ class GPTSamplingConfig(SamplingConfig):
         desc="Shuffling strategy.",
         hint=FieldHint.feature,
     )
+    diffusion: DiffusionMaskingConfig = Field(
+        desc="Configuration for diffusion-based masking during data preparation.",
+        hint=FieldHint.feature,
+    )
 
 
 @dataclasses.dataclass(kw_only=True)
@@ -79,6 +118,7 @@ class GPTSamplingParameters(SamplingParameters):
     # How many extra tokens to add to the sequence length.
     # This is used to provide labels even for the last tokens in the sequence.
     extra_tokens: int = 1
+    diffusion: DiffusionMaskingConfig
 
 
 @dataclasses.dataclass(kw_only=True)
diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py
index df603a910..51252f036 100644
--- a/fast_llm/engine/base_model/base_model.py
+++ b/fast_llm/engine/base_model/base_model.py
@@ -137,9 +137,8 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]:
         # The name (dict key) is used to insert the weight in the kwargs of the forward pass.
return {} - @property @abc.abstractmethod - def loss_defs(self) -> list[LossDef]: + def get_loss_defs(self) -> list[LossDef]: pass def add_reference_model(self, name: str, inference_runner: "InferenceRunner") -> None: diff --git a/fast_llm/engine/evaluation/evaluator.py b/fast_llm/engine/evaluation/evaluator.py index 78aad230f..96e71e947 100644 --- a/fast_llm/engine/evaluation/evaluator.py +++ b/fast_llm/engine/evaluation/evaluator.py @@ -119,7 +119,7 @@ def setup( phase=PhaseType.validation, ) - self._loss_defs = self._multi_stage.base_model.loss_defs + self._loss_defs = self._multi_stage.base_model.get_loss_defs() self._evaluation_iterator = None self._is_setup = True diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 8eca4559d..f2b302def 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -93,7 +93,7 @@ def __init__( self._stages: list[Stage] = self._multi_stage.stages self._tied_parameters = self._multi_stage.tied_parameters self._num_stages = len(self._stages) - self._loss_defs = {loss_def.name: loss_def for loss_def in self._multi_stage.base_model.loss_defs} + self._loss_defs = {loss_def.name: loss_def for loss_def in self._multi_stage.base_model.get_loss_defs()} def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) -> None: assert not self._is_setup diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 766398d01..4a6ae1612 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -152,7 +152,7 @@ def __init__(self, config: TrainerConfig): multi_stage=self._multi_stage, distributed_config=self._config.model.distributed, ) - self._loss_defs = self._multi_stage.base_model.loss_defs + self._loss_defs = self._multi_stage.base_model.get_loss_defs() if not self._is_evaluation_only: steps_per_split = { diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 513510ec7..d4f18d79c 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -15,6 +15,7 @@ def _torch_cross_entropy_forward_backward( grad_output: float | None, logits_scale_factor: float, target_format: TargetFormat, + loss_weight: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A wrapper for the pytorch implementation of cross-entropy. @@ -22,6 +23,8 @@ def _torch_cross_entropy_forward_backward( and separate forward and backward kernels lead to poor performance. TODO: loss masking only works for with labels format and if the masking index is set to -100. """ + assert loss_weight is None, "Loss weight not supported in torch cross-entropy implementation." + # Torch compile doesn't understand this. 
    with torch.set_grad_enabled(grad_output is not None):
        logits_ = logits.float().detach().requires_grad_(grad_output is not None)
@@ -82,6 +85,7 @@ def _fused_cross_entropy_forward_backward(
     grad_output: float | None,
     logits_scale_factor: float,
     target_format: TargetFormat,
+    loss_weight: torch.Tensor | None,
     group: ProcessGroup | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor | None]:
     """
@@ -144,14 +148,22 @@
     predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True)
     per_sample_loss = sum_exp_logits.log() - predicted_logits
-    if loss_mask is not None:
-        per_sample_loss = per_sample_loss * loss_mask
-    loss = per_sample_loss.mean()
-    if target_format != TargetFormat.labels and group is not None:
-        all_reduce(loss, op=ReduceOp.MEAN, group=group)
+    if loss_weight is None:
+        if loss_mask is not None:
+            per_sample_loss = per_sample_loss * loss_mask
-    return loss, grad
+        loss = per_sample_loss.mean()
+        if target_format != TargetFormat.labels and group is not None:
+            all_reduce(loss, op=ReduceOp.MEAN, group=group)
+        return loss, grad
+    else:
+        # Weight each token's loss by its loss weight before averaging.
+        per_sample_loss = per_sample_loss * loss_weight.view(-1, 1)
+        grad = grad * loss_weight.view(-1, 1) if grad is not None else None
+        denom = torch.clamp((loss_weight != 0).sum(), min=1)
+
+        return (per_sample_loss.sum() / denom), grad
 
 
 _CROSS_ENTROPY_IMPLEMENTATIONS = {
@@ -170,6 +183,7 @@ def cross_entropy_forward_backward(
     implementation: CrossEntropyImpl = CrossEntropyImpl.fused,
     logits_scale_factor: float = 1.0,
     target_format: TargetFormat = TargetFormat.labels,
+    loss_weight: torch.Tensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor | None]:
     """
     Select the appropriate implementation of cross-entropy.
@@ -177,6 +191,7 @@
     It doesn't have a tensor-parallel implementation, but can be computed in a sequence-tensor-parallel way,
     which is faster and has a relatively small memory overhead.
     """
+
     if target_format == TargetFormat.labels:
         Assert.eq(target.shape, logits.shape[:-1])
         Assert.eq(target.dtype, torch.int64)
@@ -193,5 +208,5 @@
         )
     else:
         return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation](
-            logits, target, loss_mask, grad_output, logits_scale_factor, target_format
+            logits, target, loss_mask, grad_output, logits_scale_factor, target_format, loss_weight=loss_weight
         )
diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py
index 8cb59c85c..c257edbb2 100644
--- a/fast_llm/functional/triton/cross_entropy.py
+++ b/fast_llm/functional/triton/cross_entropy.py
@@ -125,6 +125,7 @@ def triton_cross_entropy_forward_backward(
     grad_output: float | None,
     logits_scale_factor: float,
     target_format: TargetFormat,
+    loss_weight: torch.Tensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
     A fast triton implementation of cross-entropy, which combines the casting and forward and backward passes,
@@ -133,6 +134,8 @@
     TODO: Better handling of `grad_output = None`
     """
     assert TritonConfig.TRITON_ENABLED
+    assert loss_weight is None, "Loss weight not supported in triton cross-entropy implementation."
+
     # TODO: Improve assumptions.
    assert logits.is_contiguous()
     assert target.is_contiguous()
diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py
index c4776abe9..e0d22e54f 100644
--- a/fast_llm/layers/language_model/config.py
+++ b/fast_llm/layers/language_model/config.py
@@ -22,6 +22,7 @@ class LanguageModelDimNames:
 class LanguageModelLossNames:
     language_model_loss = "language_model_loss"
     z_loss = "z_loss"
+    mlm_loss = "masked_language_model_loss"
 
     @staticmethod
     def multi_token_prediction_loss(index: int) -> str:
@@ -38,7 +39,11 @@ class LanguageModelKwargs:
     chosen_spans = "chosen_spans"
     rejected_spans = "rejected_spans"
     loss_mask = "loss_mask"
+    mask_indexes = "mask_indexes"
+    mask_probabilities = "mask_probabilities"
     mask_inputs = "mask_inputs"
+    loss_weights = "loss_weights"
+    in_context = "in_context"
 
 
 @config_class()
diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py
index 88b0612bd..85549dcec 100644
--- a/fast_llm/layers/language_model/head.py
+++ b/fast_llm/layers/language_model/head.py
@@ -4,7 +4,7 @@
 from torch._C._distributed_c10d import ReduceOp  # noqa
 from torch.distributed import all_reduce
 
-from fast_llm.config import Configurable
+from fast_llm.config import Configurable, DiffusionStyle
 from fast_llm.core.ops import split_op
 from fast_llm.engine.base_model.base_model import Layer
 from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace
@@ -363,3 +363,81 @@ def _logits_cross_entropy_forward_backward(
         # TODO: de-allocate earlier.
         del logits
         return loss, output_parallel_linear_backward(grad, context) if self.training else None
+
+
+class MLMHead(LanguageModelHead):
+    """
+    A masked language model head for diffusion-based training.
+    """
+
+    def __init__(
+        self,
+        config: LanguageModelBaseConfig,
+        tensor_space: TensorSpace,
+        prediction_distance: int,
+    ):
+        super().__init__(config, tensor_space, prediction_distance)
+        if config.transformer.diffusion == DiffusionStyle.masked:
+            self._loss_name = LanguageModelLossNames.mlm_loss
+
+    def _logits_cross_entropy_forward_backward(
+        self,
+        input_: torch.Tensor,
+        target: torch.Tensor | None,
+        loss_mask: torch.Tensor | None,
+        weight: torch.Tensor,
+        grad_output: float,
+        kwargs: dict,
+        losses: dict | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+
+        assert target is not None, "MLM head requires target labels"
+        assert loss_mask is None, "MLM head does not support a loss mask"
+
+        logits, context = output_parallel_linear_forward(
+            input_=input_,
+            weight=weight,
+            bias=None,
+            group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None,
+            sequence_parallel=self._sequence_parallel and self._parallel_embeddings,
+        )
+
+        if self.config.transformer.diffusion != DiffusionStyle.none:
+            if self.config.transformer.diffusion == DiffusionStyle.masked:
+                loss_weights = kwargs[LanguageModelKwargs.loss_weights]
+                loss, grad = cross_entropy_forward_backward(
+                    logits=logits.flatten(0, -2),
+                    target=target,
+                    loss_mask=None,
+                    grad_output=grad_output,
+                    group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None,
+                    implementation=self._cross_entropy_impl,
+                    logits_scale_factor=self._logits_scale_factor,
+                    loss_weight=loss_weights,
+                )
+
+            elif self.config.transformer.diffusion == DiffusionStyle.ar_masked:
+
+                loss_weights = kwargs[LanguageModelKwargs.loss_weights]
+                context_index = kwargs[LanguageModelKwargs.in_context]
+                masked_index = kwargs[LanguageModelKwargs.mask_indexes]
+                B =
loss_weights.shape[0] + masked_index = torch.cat([masked_index[:, 1:], torch.zeros(B, 1, device=loss_weights.device)], dim=1) + context_index = torch.cat([context_index[:, 1:], torch.zeros(B, 1, device=loss_weights.device)], dim=1) + + # TODO: Need to update cross-entropy implementation to support per-token loss weights. + loss, grad, per_token_loss_b4_weight = cross_entropy_forward_backward( + logits.flatten(0, -2), + target=target, + group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, + grad_output=grad_output, + implementation=self._cross_entropy_impl, + logits_scale_factor=self._logits_scale_factor, + loss_weight=loss_weights, + ) + # Add these before weighting to display them separately + losses["loss_mask_tokens"].append((per_token_loss_b4_weight * masked_index).mean()) + losses["loss_in_context_tokens"].append((per_token_loss_b4_weight * context_index).mean()) + + del logits + return loss, output_parallel_linear_backward(grad, context) if self.training else None diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 3351c9906..c6f43a861 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -371,7 +371,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ max_seqlen_k=kwargs.get(TransformerKwargs.max_seqlen_k), dropout_p=self._config.attention_dropout if self.training else 0.0, window_size=(-1, -1) if window_size is None else (window_size - 1, 0), - causal=True, + causal=kwargs.get(TransformerKwargs.causal, True), softmax_scale=self._softmax_scale, ).view(*out_dims) else: @@ -381,10 +381,11 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ value, window_size=(-1, -1) if window_size is None else (window_size - 1, 0), dropout_p=self._config.attention_dropout if self.training else 0.0, - causal=True, + causal=kwargs.get(TransformerKwargs.causal, True), softmax_scale=self._softmax_scale, ) input_ = input_.flatten(-2) + else: # TODO: Avoid the flattens. input_ = self._attn_fused( diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index f6eaf5890..7e0155f61 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -6,7 +6,7 @@ import typing import warnings -from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import DiffusionStyle, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames @@ -82,6 +82,7 @@ class TransformerKwargs: sequence_length = "sequence_length" # TODO: Move grad_output = "grad_output" + causal = "causal" class TransformerLossNames: @@ -485,6 +486,11 @@ class TransformerConfig(LLMBlockConfig): " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. 
Not recommended.", hint=FieldHint.expert, ) + diffusion: DiffusionStyle = Field( + default=DiffusionStyle.none, + desc="Use masked-diffusion for training.", + hint=FieldHint.feature, + ) def _validate(self) -> None: with self._set_implicit_default(): diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index dc3ddeb52..db7c1897c 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -8,6 +8,9 @@ from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs from fast_llm.tensor import TensorMeta +# Import the new masked-bidirectional preprocessor for export +from .preprocessing_masked_bidirectional import MaskedBidirectionalAttentionPreprocessor + logger = logging.getLogger(__name__) @@ -160,3 +163,44 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: ) kwargs[TransformerKwargs.max_seqlen_q] = seqlens_q.max() kwargs[TransformerKwargs.max_seqlen_k] = seqlens_k.max() + + +class MaskedBidirectionalAttentionPreprocessor(BackupAttentionPreprocessor): + """ + Preprocessor for masked-bidirectional attention, as used in masked diffusion mode. + Sets up a bidirectional attention mask for the transformer. + """ + + def __init__(self, config: TransformerConfig, tensor_space): + self._config = config + self._tensor_space = tensor_space + self._distributed_config = self._tensor_space.distributed_config + self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + + def preprocess(self, batch, kwargs: dict[str, object]) -> None: + sequence_length = kwargs[TransformerKwargs.sequence_length] + sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size + sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size + device = self._tensor_space.distributed.device + dtype = self._tensor_space.distributed_config.training_dtype.torch + + # TODO: Attention masks are created but not used in the current implementation. Only flash attention is used for masked diffusion. + attention_mask = torch.ones( + (sequence_length, sequence_length), + dtype=torch.bool, + device=device, + ) + # Following BackupAttentionPreprocessor + # k and q are same so can use sequence_length + kwargs[TransformerKwargs.attention_mask] = attention_mask[ + None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k + ] + kwargs[TransformerKwargs.attention_mask_value] = torch.full( + [], + torch.finfo(dtype).min, + dtype=dtype, + device=device, + ) + + # Set causal to False, for flash attention function + kwargs[TransformerKwargs.causal] = False diff --git a/fast_llm/layers/transformer/rotary/config.py b/fast_llm/layers/transformer/rotary/config.py index ce7af88d5..8f2c9ab81 100644 --- a/fast_llm/layers/transformer/rotary/config.py +++ b/fast_llm/layers/transformer/rotary/config.py @@ -112,7 +112,7 @@ class YarnRotaryConfig(DefaultRotaryConfig): # TODO: Add descriptions. 
scale_factor: float = Field(default=8.0, hint=FieldHint.feature) - attention_factor: None | float = Field( + attention_factor: float | None = Field( default=None, hint=FieldHint.feature, ) @@ -127,9 +127,9 @@ class YarnRotaryConfig(DefaultRotaryConfig): original_context_length: int = Field(default=8192, hint=FieldHint.feature) def _validate(self) -> None: - if self.attention_factor is None: - with self._set_implicit_default(): - self.attention_factor = 0.1 * math.log(self.scale_factor) + 1.0 + # if self.attention_factor is None: + # # with self._set_implicit_default(): + # self.attention_factor = 0.1 * math.log(self.scale_factor) + 1.0 super()._validate() def _get_configurable_class(self) -> "type[YarnRotary]": diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index 056b9aa4c..867e1e33d 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/transformer/rotary/rotary.py @@ -181,7 +181,10 @@ class YarnRotary[ConfigType: YarnRotaryConfig](DefaultRotary[YarnRotaryConfig]): """ def _get_frequencies(self, sequence_length: int, kv_channels: int, device="cuda") -> torch.Tensor: - return super()._get_frequencies(sequence_length, kv_channels, device) * self._config.attention_factor + attention_factor = self._config.attention_factor + if attention_factor is None: + attention_factor = 0.1 * math.log(self._config.scale_factor) + 1.0 + return super()._get_frequencies(sequence_length, kv_channels, device) * attention_factor def _get_angle_scales(self, kv_channels: int, device="cuda") -> torch.Tensor: scales = super()._get_angle_scales(kv_channels, device) diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index c206ef406..8b1f16350 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -59,10 +59,9 @@ def preprocess( # TODO: Adjust or reimplement. return super().preprocess(batch, preprocessed_meta, phase=phase, iteration=iteration, metrics=metrics) - @property - def loss_defs(self) -> list[LossDef]: + def get_loss_defs(self) -> list[LossDef]: # TODO: Adjust or reimplement. 
- return super().loss_defs + return super().get_loss_defs() class CustomModel[ConfigType: CustomBaseModelConfig](GPTModel[ConfigType]): diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index d8425786d..dbd74e37a 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -404,6 +404,7 @@ def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing elif type(rotary_config) is YarnRotaryConfig: rotary_scaling = { "rope_type": "yarn", + "factor": rotary_config.scale_factor, "attention_factor": rotary_config.attention_factor, "beta_fast": rotary_config.beta_fast, "beta_slow": rotary_config.beta_slow, @@ -435,6 +436,7 @@ def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.A elif rotary_type == "yarn": rotary_config.update( { + "scale_factor": rope_scaling.get("factor", DEFAULT), "attention_factor": rope_scaling.get("attention_factor", DEFAULT), "beta_fast": rope_scaling.get("beta_fast", DEFAULT), "beta_slow": rope_scaling.get("beta_slow", DEFAULT), diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 00b4ee277..d832eb8ae 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -3,6 +3,7 @@ import torch +from fast_llm.config import DiffusionStyle from fast_llm.data.data.gpt.data import GPTBatch from fast_llm.engine.base_model.base_model import BaseModel, Layer, LossDef from fast_llm.engine.base_model.config import Preprocessor @@ -12,7 +13,7 @@ from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding -from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead +from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead, MLMHead from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor, PreferenceSpanPreprocessor from fast_llm.layers.transformer.config import ( RoutingType, @@ -20,7 +21,11 @@ TransformerKwargs, TransformerLossNames, ) -from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor +from fast_llm.layers.transformer.preprocessing import ( + BackupAttentionPreprocessor, + FlashAttnVarlenPreprocessor, + MaskedBidirectionalAttentionPreprocessor, +) from fast_llm.layers.transformer.transformer import TransformerLayer from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron @@ -55,13 +60,21 @@ def __init__( # We have multiple identical rotary modules/preprocessors, so it's simpler to make a new one here. # TODO: Find a better solution. 
        self._preprocessors.append(self._config.transformer.rotary.build(self._tensor_space))
-        if self._use_flash_attention:
-            self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space))
+
+        # --- Add new preprocessor for masked-bidirectional attention ---
+        if self._config.transformer.diffusion != DiffusionStyle.none:
+            if self._config.transformer.diffusion == DiffusionStyle.masked:
+                self._preprocessors.append(
+                    MaskedBidirectionalAttentionPreprocessor(self._config.transformer, self._tensor_space)
+                )
         else:
-            self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space))
+            if self._use_flash_attention:
+                self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space))
+            else:
+                self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space))
 
-        if self._config.enable_dpo:  # TODO better way to pass in?
-            self._preprocessors.append(PreferenceSpanPreprocessor(self._config, self._tensor_space))
+        if self._config.enable_dpo:  # TODO better way to pass in?
+            self._preprocessors.append(PreferenceSpanPreprocessor(self._config, self._tensor_space))
 
     def get_output_layers(self) -> list[Layer]:
         layers = []
@@ -78,13 +91,22 @@ def get_output_layers(self) -> list[Layer]:
                     return_input=i < self._config.prediction_heads - 1,
                 )
             )
-        layers.append(
-            LanguageModelHead(
-                self._config,
-                self._tensor_space,
-                prediction_distance=i,
+        if self._config.transformer.diffusion != DiffusionStyle.none:
+            layers.append(
+                MLMHead(
+                    self._config,
+                    self._tensor_space,
+                    prediction_distance=i,
+                )
+            )
+        else:
+            layers.append(
+                LanguageModelHead(
+                    self._config,
+                    self._tensor_space,
+                    prediction_distance=i,
+                )
             )
-        )
         return layers
 
     def get_layers(self) -> list[Layer]:
@@ -323,6 +345,118 @@ def preprocess(
                     kwargs[LanguageModelKwargs.loss_mask] = loss_mask
                     labels = torch.where(loss_mask, labels, -100)
                 kwargs[LanguageModelKwargs.labels] = labels
+
+            if self._config.transformer.diffusion != DiffusionStyle.none:
+
+                assert batch.loss_weights is not None, "masked-diffusion mode requires loss_weights to be set"
+                if self._config.transformer.diffusion == DiffusionStyle.masked:
+
+                    if not self._use_flash_attention:
+                        raise ValueError(
+                            f"Diffusion style '{DiffusionStyle.masked}' is only implemented with flash-attention."
+                        )
+
+                    kwargs[LanguageModelKwargs.loss_weights] = batch.loss_weights.to(
+                        device=self._tensor_space.distributed.device,
+                        dtype=self._tensor_space.distributed_config.training_dtype.torch,
+                    )
+
+                    # Set the token ids to the masked tokens.
+                    batch.token_ids = batch.masked_token_ids.to(
+                        device=self._tensor_space.distributed.device,
+                        dtype=torch.int64,
+                        non_blocking=True,
+                    )
+                    # IMPORTANT: Need to set both variables
+                    tokens = batch.token_ids
+
+                # TODO: Bi-directional attention with the fused-attention (_attn_fused) function needs to be implemented correctly.
+                elif self._config.transformer.diffusion == DiffusionStyle.ar_masked:
+
+                    if self._use_flash_attention:
+                        raise ValueError(
+                            f"Diffusion style '{DiffusionStyle.ar_masked}' does not use flash attention."
+ ) + + # We are in masked-diffusion mode, so we need to add the mask indexes and probabilities to kwargs + kwargs[LanguageModelKwargs.mask_indexes] = batch.mask_indexes.to( + device=self._tensor_space.distributed.device + ) + + kwargs[LanguageModelKwargs.loss_weights] = batch.loss_weights.to( + device=self._tensor_space.distributed.device + ) + + kwargs[LanguageModelKwargs.in_context] = batch.in_context.to( + device=self._tensor_space.distributed.device + ) + + # Setup bidirection attention for diffusion should we set this in a preprocessor? BackupAttentionPreprocessor? + batch_size, seq_len = batch.token_ids.shape + # seq_len -= 1 # last token is drop from the input + # # Compute attention mask for diffusion + C = batch.in_context_length.to(device=self._tensor_space.distributed.device) + row_idx = torch.arange(seq_len, device=self._tensor_space.distributed.device).view( + 1, seq_len, 1 + ) + col_idx = torch.arange(seq_len, device=self._tensor_space.distributed.device).view( + 1, 1, seq_len + ) + C_exp = C.view(batch_size, 1, 1) + + causal_mask = col_idx <= row_idx + row_idx < C_exp + col_idx < C_exp + + attn_mask = torch.zeros( + batch_size, + seq_len, + seq_len, + dtype=torch.bool, + device=self._tensor_space.distributed.device, + ) + + for b in range(batch_size): + C_val = C[b].item() + + if C_val > 0: + context_causal = causal_mask[0, :C_val, :C_val] + attn_mask[b, :C_val, :C_val] = context_causal + + if C_val > 0 and C_val < seq_len: + attn_mask[b, C_val:, :C_val] = True + + if C_val < seq_len: + attn_mask[b, C_val:, C_val:] = True + + # Handle padding if needed + if batch.sequence_lengths is not None: + padded = torch.zeros( + batch_size, seq_len, dtype=torch.bool, device=self._tensor_space.distributed.device + ) + for b in range(batch_size): + padded[b, batch.sequence_lengths[b] :] = True + not_padded = ~padded[:, 1:] + attn_mask = attn_mask & not_padded.unsqueeze(1) & not_padded.unsqueeze(2) + + # Reshape to match expected attention mask format + attention_mask = attn_mask.unsqueeze(1).unsqueeze(1) # Add additional dimension + # print(f"attention_mask shape: {attention_mask.shape}\n{attention_mask}") + kwargs[TransformerKwargs.attention_mask] = attention_mask + kwargs[TransformerKwargs.attention_mask_value] = torch.full( + [], + torch.finfo(self._tensor_space.distributed_config.training_dtype.torch).min, + dtype=self._tensor_space.distributed_config.training_dtype.torch, + device=self._tensor_space.distributed.device, + ) + + # set token ids to masked tokens + batch.token_ids = batch.masked_token_ids.to( + device=self._tensor_space.distributed.device, + dtype=torch.int64, + non_blocking=True, + ) + # IMPORTANT: Need to set both variables + tokens = batch.token_ids + kwargs.update(reference_logits[i]) for preprocessor in self._preprocessors: @@ -365,8 +501,7 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: else: return {} - @property - def loss_defs(self) -> list[LossDef]: + def get_loss_defs(self) -> list[LossDef]: loss_defs = [] if ( self._config.transformer.num_experts > 1 @@ -390,6 +525,10 @@ def loss_defs(self) -> list[LossDef]: if self._config.logit_z_loss: LossDef(name=LanguageModelLossNames.z_loss, formatted_name="logit z loss", count=1) + if self._config.transformer.diffusion: + # Masked LM Loss for masked-diffusion training + loss_defs.append(LossDef(name=LanguageModelLossNames.mlm_loss, formatted_name="MLM Loss", count=1)) + for i in range(self._config.prediction_heads): loss_defs.append( LossDef( diff --git 
a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 54508e8e1..9003b7234 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -31,6 +31,7 @@ def _get_sampling_parameters( "cross_document_attention": self._config.batch.cross_document_attention, "truncate_documents": self._config.batch.truncate_documents, "extra_tokens": self._config.model.base_model.prediction_heads, + "diffusion": self._config.data.sampling.diffusion, } ) return parameters if _return_dict else GPTSamplingParameters(**parameters) diff --git a/tests/data/test_prepare_masked_batch.py b/tests/data/test_prepare_masked_batch.py new file mode 100644 index 000000000..20c9e48db --- /dev/null +++ b/tests/data/test_prepare_masked_batch.py @@ -0,0 +1,113 @@ +import pytest +import torch + +from fast_llm.data.data.gpt.data import prepare_masked_batch + + +@pytest.mark.parametrize( + "data_ids, positions, padded, mask_token_id, vocab_size, context_length, p_mask, p_uniform, ar_factor, un_factor, last_factor, expected", + [ + # Shift + Masked diffissuion test case + ( + torch.tensor([[42, 67, 76, 14, 26]]), + torch.tensor([[0, 1, 2, 3]]), + torch.tensor([[False, False, False, False, False]]), + 100, + 100, + -torch.ones(1, dtype=torch.int), + torch.Tensor([0.51]), + 0.0, + 0.0, + 0.0, + 0.0, + { + "in_context": torch.tensor([[False, False, False, False]]), + "in_mask": torch.tensor([[False, False, True, False]]), + "in_uniform": torch.tensor([[False, False, False, False]]), + "in_clean": torch.tensor([[True, True, False, True]]), + "input_ids": torch.tensor([[42, 67, 100, 14]]), + "loss_weights": torch.tensor([[0.0, 1.9608, 0.0, 0.0]]), + }, + ), + # Shift + AR context + Masked diffusion test case + ( + torch.tensor([[42, 67, 76, 14, 26, 42, 67, 76, 26]]), + torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]), + torch.tensor([[False, False, False, False, False, False, False, False, False]]), + 100, + 100, + torch.ones(1, dtype=torch.int), + torch.Tensor([0.51]), + 0.0, + 1.0, + 0.0, + 0.0, + { + "in_context": torch.tensor([[True, True, False, False, False, False, False, False]]), + "in_mask": torch.tensor([[False, False, True, False, True, False, True, False]]), + "in_uniform": torch.tensor([[False, False, False, False, False, False, False, False]]), + "in_clean": torch.tensor([[False, False, False, True, False, True, False, True]]), + "input_ids": torch.tensor([[42, 67, 100, 14, 100, 42, 100, 76]]), + "loss_weights": torch.tensor([[1.0000, 1.9608, 0.0000, 1.9608, 0.0000, 1.9608, 0.0000, 0.0000]]), + }, + ), + # Shift + AR context + Masked diffusion + Uniform flip test case + ( + torch.tensor([[42, 67, 76, 14, 26, 42, 67, 76, 26]]), + torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]), + torch.tensor([[False, False, False, False, False, False, False, False, False]]), + 100, + 100, + torch.ones(1, dtype=torch.int), + torch.Tensor([0.51]), + 1.0, + 1.0, + 1.0, + 0.0, + { + "in_context": torch.tensor([[True, True, False, False, False, False, False, False]]), + "in_mask": torch.tensor([[False, False, True, False, True, False, True, False]]), + "in_uniform": torch.tensor([[False, False, False, True, False, True, False, True]]), + "in_clean": torch.tensor([[False, False, False, False, False, False, False, False]]), + "input_ids": torch.tensor( + [[42, 67, 100, 6, 100, 76, 100, 11]] + ), # new uniformly shuffled tokens 14->6 67->76 26->11 + "loss_weights": torch.tensor([[1.0000, 1.9608, 0.0000, 1.9608, 0.0000, 1.9608, 0.0000, 0.0000]]), + }, + ), + ], +) +def test_prepare_batch_basic( + data_ids, + positions, + 
padded, + mask_token_id, + vocab_size, + context_length, + p_mask, + p_uniform, + ar_factor, + un_factor, + last_factor, + expected, +): + torch.manual_seed(42) # For reproducibility + batch = prepare_masked_batch( + data_ids, + positions, + padded, + mask_token_id, + vocab_size, + context_length, + p_mask, + p_uniform, + ar_factor, + un_factor, + last_factor, + ) + + for key, value in expected.items(): + if key == "loss_weights": + assert torch.allclose(batch[key], value, atol=1e-4), f"{key} mismatch: {batch[key]} != {value}" + else: + assert torch.equal(batch[key], value), f"{key} mismatch: {batch[key]} != {value}"
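
Note (illustrative, not part of the patch): a minimal standalone sketch of where the 1.9608 values in the expected loss_weights above come from. With ar_factor = un_factor = last_factor = 0, prepare_masked_batch reduces to weighting the shifted mask in_mask[:, 1:] by 1 / p_mask = 1 / 0.51 and appending a zero for the dropped last position:

import torch

p_mask = torch.tensor([[0.51]])
in_mask = torch.tensor([[False, False, True, False]])  # mask from the first test case
# Shift by one (labels are left-shifted), weight masked positions by 1 / p_mask,
# and append last_factor * 1 = 0 for the final position.
loss_weights = torch.cat([in_mask[:, 1:].float() / p_mask, torch.zeros(1, 1)], dim=1)
print(loss_weights)  # tensor([[0.0000, 1.9608, 0.0000, 0.0000]])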