From ba54bc9f572d7204e88644a22fb211a380755d40 Mon Sep 17 00:00:00 2001 From: dimapihtar Date: Tue, 18 Nov 2025 05:29:32 -0800 Subject: [PATCH 01/22] add FIM dataset support Signed-off-by: dimapihtar --- megatron/training/arguments.py | 11 + megatron/training/datasets/fim_dataset.py | 289 ++++++++++++++++++++++ pretrain_gpt.py | 58 +++-- 3 files changed, 338 insertions(+), 20 deletions(-) create mode 100644 megatron/training/datasets/fim_dataset.py diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 8c533e36f7..24d52ecf89 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -2907,6 +2907,17 @@ def _add_data_args(parser): 'If instead this argument is set, the training flow will treat all tokens ' 'that share the same id as the pad token as true pad tokens, potentially ' 'causing severe training instability.') + group.add_argument('--fim-data', action='store_true', help='Whether to use the FIM dataset.') + group.add_argument('--fim-rate', type=float, default=0.5, + help='Probability to convert a training sample into a FIM format.') + group.add_argument('--fim-spm-rate', type=float, default=0.5, + help='Probability that the a FIM sample uses the SPM format over the PSM format.') + group.add_argument('--fim-split-sample', type=str, default=None, + help='String around which to split the sample for FIM.') + group.add_argument('--fim-fragment-rate', type=float, default=None, + help='Rate of FIM on each fragment when --fim-split-sample is not None.') + group.add_argument('--fim-no-prefix', action='store_true', + help='Do not apply FIM to fragments that start with this prefix') return parser diff --git a/megatron/training/datasets/fim_dataset.py b/megatron/training/datasets/fim_dataset.py new file mode 100644 index 0000000000..dd7e153e71 --- /dev/null +++ b/megatron/training/datasets/fim_dataset.py @@ -0,0 +1,289 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
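+# Fill-in-the-Middle (FIM) dataset support: GPTFIMDataset wraps GPTDataset and, with the
+# configured probabilities, reorders samples into PSM or SPM form using dedicated
+# prefix/middle/suffix special tokens while keeping the original sample length.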
+ +from typing import Tuple, Optional +from dataclasses import dataclass + +import numpy as np +import logging +from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig +from megatron.core.datasets.indexed_dataset import IndexedDataset +from megatron.core.datasets.utils import Split + +logger = logging.getLogger(__name__) + + +@dataclass +class GPTFIMDatasetConfig(GPTDatasetConfig): + """Configuration object for Megatron Core GPT FIM datasets""" + + rate: float = 0.5 + """Probability to convert a training sample into a FIM format""" + + spm_rate: float = 0.5 + """Probability that the a FIM sample uses the SPM format over the PSM format""" + + split_sample: Optional[str] = None + """String around which to split the sample for FIM""" + + fragment_rate: Optional[float] = None + """Rate of FIM on each fragment when split_sample is not None""" + + no_prefix: Optional[bool] = None + """Do not apply FIM to fragments that start with this prefix""" + + +class GPTFIMDataset(GPTDataset): + """The base GPT dataset + + Args: + indexed_dataset (IndexedDataset): The IndexedDataset around which to build the + MegatronDataset + + indexed_indices (np.ndarray): The set of the documents indices to expose + + num_samples (int): The number of samples to draw from the indexed dataset + + index_split (Split): The indexed_indices Split + + config (GPTFIMDatasetConfig): The GPT-specific container for all config sourced parameters + """ + + def __init__( + self, + indexed_dataset: IndexedDataset, + dataset_path: str, + indexed_indices: np.ndarray, + num_samples: int, + index_split: Split, + config: GPTFIMDatasetConfig, + ) -> None: + super().__init__(indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config) + + self.indexed_dataset = indexed_dataset + self.np_rng = np.random.RandomState(seed=self.config.random_seed) + logger.info(f"Initialized FIM RNG with seed = {self.config.random_seed}") + # get FIM params + self.fim_rate = self.config.rate + self.fim_spm_rate = self.config.spm_rate + self.fragment_fim_rate = self.config.fragment_rate + split_sample = self.config.split_sample + self.fim_split_sample = self.config.tokenizer.tokens_to_ids(split_sample) if split_sample else None + self.no_fim_prefix = self.config.fim.no_prefix + + # get extra tokens ids + fim_tokens = self.config.fim.extra_tokens + fim_tokens = [fim_tokens.prefix, fim_tokens.middle, fim_tokens.suffix, fim_tokens.pad, fim_tokens.eod] + fim_tokens_ids = self.config.tokenizer.tokens_to_ids(fim_tokens) + ( + self.prefix_tok_id, + self.middle_tok_id, + self.suffix_tok_id, + self.pad_tok_id, + self.eod_tok_id, + ) = fim_tokens_ids + + def _query_document_sample_shuffle_indices(self, idx: int) -> Tuple[np.ndarray, np.ndarray]: + """Get the text (token ids) and document ids for a given index + + Args: + idx (int): The index into the dataset + + Returns: + Tuple[np.ndarray, np.ndarray]: The text ids and document ids + """ + # Do the shuffle mapping + idx = self.shuffle_index[idx] + + # Get the beginning and end documents and offsets + doc_index_beg, doc_index_beg_offset = self.sample_index[idx] + doc_index_end, doc_index_end_offset = self.sample_index[idx + 1] + + document_ids = [] + sample_parts = [] + + # Sample spans a single document + if doc_index_beg == doc_index_end: + # Add the document id + document_ids.append(self.document_index[doc_index_beg]) + + # Add the entire sample + sample_parts.append( + self.indexed_dataset.get( + self.document_index[doc_index_beg], + offset=doc_index_beg_offset, + 
length=doc_index_end_offset - doc_index_beg_offset + 1, + ) + ) + + # Sample spans multiple documents + else: + for i in range(doc_index_beg, doc_index_end + 1): + # Add the document id + document_ids.append(self.document_index[i]) + + # Add the sample part + offset = 0 if i > doc_index_beg else doc_index_beg_offset + length = None if i < doc_index_end else doc_index_end_offset + 1 + sample_parts.append(self.indexed_dataset.get(self.document_index[i], offset=offset, length=length)) + + sample = np.concatenate(sample_parts) + + sample_len = sample.shape[0] + segment_breaks = np.argwhere(sample == self.eod_tok_id) + np_rng = self.np_rng + + if segment_breaks.shape != (0, 1): # then there is an EOD token in this example + curr_start_position = 0 + new_samples = [] + for loc in np.nditer(segment_breaks): + # Only permute non-empty segments. + if loc - curr_start_position > 0: + # permute {prefix, suffix, middle} or {suffix, prefix, middle} + permuted = self._fim_split_and_permute_sequence(sample[curr_start_position:loc], np_rng) + new_samples += [permuted, [self.eod_tok_id]] + + curr_start_position = loc + 1 # jump over the EOD token + # Permute the segment after the last EOD + permuted = self._fim_split_and_permute_sequence(sample[curr_start_position:], np_rng) + new_samples.append(permuted) + + sample = np.concatenate(new_samples) + else: + sample = self._fim_split_and_permute_sequence(sample, np_rng) + + diff = sample.shape[0] - sample_len + if diff > 0: # too long + sample = sample[:sample_len] + elif diff < 0: # too short + sample = np.concatenate([sample, np.full((-1 * diff), self.pad_tok_id)]) + + assert sample.shape[0] == sample_len + + return ( + np.array(sample, dtype=np.int64), + np.array(document_ids, dtype=np.int64), + ) + + def _fim_permute_sequence(self, sequence, np_rng, rate): + return self._permute( + sequence, + np_rng, + rate, + self.fim_spm_rate, + self.config.tokenizer, + truncate_or_pad=False, + suffix_tok_id=self.suffix_tok_id, + prefix_tok_id=self.prefix_tok_id, + middle_tok_id=self.middle_tok_id, + pad_tok_id=self.pad_tok_id, + no_fim_prefix=self.no_fim_prefix, + ) + + def _fim_split_and_permute_sequence(self, sequence, np_rng): + """ + If self.fim_split_sample is not None, split the sequence. + Then apply FIM on the fragments, or the whole sequence if self.fim_split_sample is None. + """ + if self.fim_split_sample is None: + return self._fim_permute_sequence(sequence, np_rng, self.fim_rate) + # fim_split_sample is set: split the sample on this token and permute each fragment separately. + # Typically, if each sample is a repository, then we split again on the file level. + # Each fragment is a file, and we permute the files. 
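+        # A single Bernoulli draw with fim_rate decides whether this sample is transformed at all;
+        # if it is, each fragment between split tokens is then permuted independently with
+        # fragment_fim_rate in the loop below.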
+ fragment_breaks = np.argwhere(sequence == self.fim_split_sample) + if fragment_breaks.shape == (0, 1): + # no split token in this sample + return self._fim_permute_sequence(sequence, np_rng, self.fim_rate) + if not np_rng.binomial(1, self.fim_rate): + # don't do FIM preproc + return sequence + # Do FIM on each fragment + curr_start_position = 0 + new_samples = [] + for loc in np.nditer(fragment_breaks): + if loc - curr_start_position > 0: + permuted = self._fim_permute_sequence( + sequence[curr_start_position:loc], np_rng, self.fragment_fim_rate + ) + new_samples += [permuted, [self.fim_split_sample]] + curr_start_position = loc + 1 # Jump over the split token + # Permute the segment after the last split token + permuted = self._fim_permute_sequence(sequence[curr_start_position:], np_rng, self.fragment_fim_rate) + new_samples.append(permuted) + + return np.concatenate(new_samples) + + def _permute( + self, + sample, + np_rng, + fim_rate, + fim_spm_rate, + tokenizer, + truncate_or_pad=True, + suffix_tok_id=None, + prefix_tok_id=None, + middle_tok_id=None, + pad_tok_id=None, + no_fim_prefix=None, + ): + """ + Take in a sample (np array w/ size (0,chunklength)) and perform a FIM transformation on it. + Maintain the same sample length (if transform creates a few extra tokens, drop them). + """ + if np_rng.binomial(1, fim_rate): # sample bernoulli dist + + contents = tokenizer.ids_to_text(sample) + + # Do not apply FIM if the sample starts with no_fim_prefix + if no_fim_prefix is not None and contents.startswith(no_fim_prefix): + return sample + + try: + # A boundary can be =0 (prefix will be empty) + # a boundary can be =len(contents) (suffix will be empty) + # The two boundaries can be equal (middle will be empty) + boundaries = list(np_rng.randint(low=0, high=len(contents) + 1, size=2)) + boundaries.sort() + except ValueError as e: + print(len(contents), contents) + print(e) + raise e + + prefix = contents[: boundaries[0]] + middle = contents[boundaries[0] : boundaries[1]] + suffix = contents[boundaries[1] :] + + prefix = np.array([*tokenizer.text_to_ids(prefix)], dtype=np.int64) + middle = np.array([*tokenizer.text_to_ids(middle)], dtype=np.int64) + suffix = np.array([*tokenizer.text_to_ids(suffix)], dtype=np.int64) + + # here we truncate each given segment to fit the same length as it was before + # A consequence is that we never reach the end of a file? + # we should rather truncate at the context-level + if truncate_or_pad: + # need to make same length as the input. Take the 3 sentinel tokens into account + new_length = suffix.shape[0] + prefix.shape[0] + middle.shape[0] + 3 + diff = new_length - sample.shape[0] + if diff > 0: # too long + if ( + suffix.shape[0] <= diff + ): # if there's no space to truncate the suffix: stop and report it. 
atm i should have stopped this from happening + return sample, np_rng + suffix = suffix[: suffix.shape[0] - diff] + elif diff < 0: # too short + suffix = np.concatenate([suffix, np.full((-1 * diff), pad_tok_id)]) + + if np_rng.binomial(1, fim_spm_rate): + # SPM (variant 2 from FIM paper) + new_sample = np.concatenate([[prefix_tok_id, suffix_tok_id], suffix, [middle_tok_id], prefix, middle]) + else: + # PSM + new_sample = np.concatenate( + [[prefix_tok_id], prefix, [suffix_tok_id], suffix, [middle_tok_id], middle] + ) + + else: + # don't do FIM preproc + new_sample = sample + + return new_sample diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 69f26f3271..d48534e7b5 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -18,6 +18,7 @@ from megatron.core.utils import StragglerDetector, get_attr_wrapped_model from megatron.training import get_args, get_timers, get_tokenizer, inprocess_restart, pretrain, print_rank_0 from megatron.training.datasets.sft_dataset import SFTDataset +from megatron.training.datasets.fim_dataset import GPTFIMDataset, GPTFIMDatasetConfig from megatron.training.utils import ( get_batch_on_this_cp_rank, get_batch_on_this_tp_rank, @@ -172,26 +173,41 @@ def core_gpt_dataset_config_from_args(args): blend_per_split: Optional[List[Optional[Tuple[List[str], Optional[List[float]]]]]] blend, blend_per_split = get_blend_and_blend_per_split(args) - return GPTDatasetConfig( - random_seed=args.seed, - sequence_length=args.seq_length, - blend=blend, - blend_per_split=blend_per_split, - split=args.split, - multiple_validation_sets=args.multiple_validation_sets, - full_validation=args.full_validation, - num_dataset_builder_threads=args.num_dataset_builder_threads, - path_to_cache=args.data_cache_path, - mmap_bin_files=args.mmap_bin_files, - tokenizer=tokenizer, - reset_position_ids=args.reset_position_ids, - reset_attention_mask=args.reset_attention_mask, - eod_mask_loss=args.eod_mask_loss, - create_attention_mask=args.create_attention_mask_in_dataloader, - object_storage_cache_path=args.object_storage_cache_path, - mid_level_dataset_surplus=args.mid_level_dataset_surplus, - allow_ambiguous_pad_tokens=args.allow_ambiguous_pad_tokens, - ) + data_args = { + "random_seed": args.seed, + "sequence_length": args.seq_length, + "blend": blend, + "blend_per_split": blend_per_split, + "split": args.split, + "multiple_validation_sets": args.multiple_validation_sets, + "full_validation": args.full_validation, + "num_dataset_builder_threads": args.num_dataset_builder_threads, + "path_to_cache": args.data_cache_path, + "mmap_bin_files": args.mmap_bin_files, + "tokenizer": tokenizer, + "reset_position_ids": args.reset_position_ids, + "reset_attention_mask": args.reset_attention_mask, + "eod_mask_loss": args.eod_mask_loss, + "create_attention_mask": args.create_attention_mask_in_dataloader, + "object_storage_cache_path": args.object_storage_cache_path, + "mid_level_dataset_surplus": args.mid_level_dataset_surplus, + "allow_ambiguous_pad_tokens": args.allow_ambiguous_pad_tokens, + } + + # add FIM args to the config + if args.fim_data: + data_args.update( + { + "rate": args.fim_rate, + "spm_rate": args.fim_spm_rate, + "spli_sample": args.fim_split_sample, + "fragment_rate": args.fragment_rate, + "no_prefix": args.no_prefix, + } + ) + return GPTFIMDatasetConfig(**data_args) + + return GPTDatasetConfig(**data_args) def train_valid_test_datasets_provider(train_val_test_num_samples, vp_stage=None): @@ -209,6 +225,8 @@ def train_valid_test_datasets_provider(train_val_test_num_samples, vp_stage=None else: 
if args.mock_data: dataset_type = MockGPTDataset + elif args.fim_data: + dataset_type = GPTFIMDataset else: dataset_type = GPTDataset From 1f1cba72ae24f230ef00cc4f510600acf3f3c322 Mon Sep 17 00:00:00 2001 From: dimapihtar Date: Tue, 18 Nov 2025 06:24:00 -0800 Subject: [PATCH 02/22] fix issues Signed-off-by: dimapihtar --- megatron/training/arguments.py | 12 ++++++++- megatron/training/datasets/fim_dataset.py | 31 +++++++++++++---------- pretrain_gpt.py | 14 +++++++--- 3 files changed, 39 insertions(+), 18 deletions(-) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 24d52ecf89..3e8d6fca8f 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -2916,8 +2916,18 @@ def _add_data_args(parser): help='String around which to split the sample for FIM.') group.add_argument('--fim-fragment-rate', type=float, default=None, help='Rate of FIM on each fragment when --fim-split-sample is not None.') - group.add_argument('--fim-no-prefix', action='store_true', + group.add_argument('--fim-no-prefix', type=str, default=None, help='Do not apply FIM to fragments that start with this prefix') + group.add_argument('--fim-prefix-token', type=str, default='', + help='FIM prefix token') + group.add_argument('--fim-middle-token', type=str, default='', + help='FIM middle token') + group.add_argument('--fim-suffix-token', type=str, default='', + help='FIM suffix token') + group.add_argument('--fim-pad-token', type=str, default='', + help='FIM PAD token') + group.add_argument('--fim-eod-token', type=str, default='<|endoftext|>', + help='FIM EOD token') return parser diff --git a/megatron/training/datasets/fim_dataset.py b/megatron/training/datasets/fim_dataset.py index dd7e153e71..f4289259f9 100644 --- a/megatron/training/datasets/fim_dataset.py +++ b/megatron/training/datasets/fim_dataset.py @@ -1,7 +1,7 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -from typing import Tuple, Optional -from dataclasses import dataclass +from typing import Dict, Tuple, Optional +from dataclasses import dataclass, field import numpy as np import logging @@ -16,19 +16,22 @@ class GPTFIMDatasetConfig(GPTDatasetConfig): """Configuration object for Megatron Core GPT FIM datasets""" - rate: float = 0.5 + rate: float = None """Probability to convert a training sample into a FIM format""" - spm_rate: float = 0.5 + spm_rate: float = None """Probability that the a FIM sample uses the SPM format over the PSM format""" + extra_tokens: Dict = None + """FIM extra tokens. 
Should consist of prefix, middle, suffix, PAD, and EOD tokens.""" + split_sample: Optional[str] = None """String around which to split the sample for FIM""" fragment_rate: Optional[float] = None """Rate of FIM on each fragment when split_sample is not None""" - no_prefix: Optional[bool] = None + no_prefix: Optional[str] = None """Do not apply FIM to fragments that start with this prefix""" @@ -67,13 +70,13 @@ def __init__( self.fim_spm_rate = self.config.spm_rate self.fragment_fim_rate = self.config.fragment_rate split_sample = self.config.split_sample - self.fim_split_sample = self.config.tokenizer.tokens_to_ids(split_sample) if split_sample else None - self.no_fim_prefix = self.config.fim.no_prefix + self.fim_split_sample = self.config.tokenizer._tokenizer.tokens_to_ids(split_sample) if split_sample else None + self.no_fim_prefix = self.config.no_prefix # get extra tokens ids - fim_tokens = self.config.fim.extra_tokens - fim_tokens = [fim_tokens.prefix, fim_tokens.middle, fim_tokens.suffix, fim_tokens.pad, fim_tokens.eod] - fim_tokens_ids = self.config.tokenizer.tokens_to_ids(fim_tokens) + fim_tokens = self.config.extra_tokens + fim_tokens = [fim_tokens["prefix"], fim_tokens["middle"], fim_tokens["suffix"], fim_tokens["pad"], fim_tokens["eod"]] + fim_tokens_ids = self.config.tokenizer._tokenizer.tokens_to_ids(fim_tokens) ( self.prefix_tok_id, self.middle_tok_id, @@ -232,7 +235,7 @@ def _permute( """ if np_rng.binomial(1, fim_rate): # sample bernoulli dist - contents = tokenizer.ids_to_text(sample) + contents = tokenizer._tokenizer.ids_to_text(sample) # Do not apply FIM if the sample starts with no_fim_prefix if no_fim_prefix is not None and contents.startswith(no_fim_prefix): @@ -253,9 +256,9 @@ def _permute( middle = contents[boundaries[0] : boundaries[1]] suffix = contents[boundaries[1] :] - prefix = np.array([*tokenizer.text_to_ids(prefix)], dtype=np.int64) - middle = np.array([*tokenizer.text_to_ids(middle)], dtype=np.int64) - suffix = np.array([*tokenizer.text_to_ids(suffix)], dtype=np.int64) + prefix = np.array([*tokenizer._tokenizer.text_to_ids(prefix)], dtype=np.int64) + middle = np.array([*tokenizer._tokenizer.text_to_ids(middle)], dtype=np.int64) + suffix = np.array([*tokenizer._tokenizer.text_to_ids(suffix)], dtype=np.int64) # here we truncate each given segment to fit the same length as it was before # A consequence is that we never reach the end of a file? 
diff --git a/pretrain_gpt.py b/pretrain_gpt.py index d48534e7b5..2b1a72c042 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -196,13 +196,21 @@ def core_gpt_dataset_config_from_args(args): # add FIM args to the config if args.fim_data: + extra_tokens = { + "prefix": args.fim_prefix_token, + "middle": args.fim_middle_token, + "suffix": args.fim_suffix_token, + "pad": args.fim_pad_token, + "eod": args.fim_eod_token, + } data_args.update( { "rate": args.fim_rate, "spm_rate": args.fim_spm_rate, - "spli_sample": args.fim_split_sample, - "fragment_rate": args.fragment_rate, - "no_prefix": args.no_prefix, + "extra_tokens": extra_tokens, + "split_sample": args.fim_split_sample, + "fragment_rate": args.fim_fragment_rate, + "no_prefix": args.fim_no_prefix, } ) return GPTFIMDatasetConfig(**data_args) From c03fd16b9e196634fae3cd9bf890bcb2bd48c182 Mon Sep 17 00:00:00 2001 From: dimapihtar Date: Tue, 18 Nov 2025 07:04:09 -0800 Subject: [PATCH 03/22] add assertions Signed-off-by: dimapihtar --- megatron/training/arguments.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 3e8d6fca8f..91d40f5c15 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1067,6 +1067,19 @@ def validate_args(args, defaults={}): any([args.train_data_path, args.valid_data_path, args.test_data_path]) \ <= 1, "A single data source must be provided in training mode, else None" + if args.fim_data: + extra_tokens = [ + args.fim_prefix_token, + args.fim_middle_token, + args.fim_suffix_token, + args.fim_pad_token, + args.fim_eod_token, + ] + assert not args.mock_data, "Mock dataset is not supported with FIM dataset." + assert args.fim_rate, "--fim-rate should be specified." + assert args.fim_spm_rate, "--fim-spm-rate should be specified." + assert all(token is not None for token in extra_tokens), "FIM extra tokens should be specified." + # Deterministic mode if args.deterministic_mode: assert not args.use_flash_attn, "Flash attention can not be used in deterministic mode." 
From efa6ac77250edfbf9fef9a759e5b6482e90f34cb Mon Sep 17 00:00:00 2001 From: dimapihtar Date: Tue, 18 Nov 2025 10:58:04 -0800 Subject: [PATCH 04/22] add unit & functional tests Signed-off-by: dimapihtar --- .../text/libraries/null_tokenizer.py | 8 +++ megatron/training/arguments.py | 1 + .../model_config.yaml | 56 ++++++++++++++++ tests/test_utils/recipes/gpt.yaml | 5 ++ tests/unit_tests/data/test_fim_dataset.py | 66 +++++++++++++++++++ 5 files changed, 136 insertions(+) create mode 100644 tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp1_pp1_dist_optimizer_fim_dataset/model_config.yaml create mode 100644 tests/unit_tests/data/test_fim_dataset.py diff --git a/megatron/core/tokenizers/text/libraries/null_tokenizer.py b/megatron/core/tokenizers/text/libraries/null_tokenizer.py index 13d5643619..4ddf77fc77 100644 --- a/megatron/core/tokenizers/text/libraries/null_tokenizer.py +++ b/megatron/core/tokenizers/text/libraries/null_tokenizer.py @@ -25,6 +25,14 @@ def ids_to_text(self, ids): text = [str(x) for x in ids] return ' '.join(text) + def tokens_to_ids(self, tokens): + """Converts tokens to ids.""" + return [int(x) for x in tokens] + + def ids_to_tokens(self, ids): + """Converts ids to tokens.""" + return [str(x) for x in ids] + def offsets(self, ids: list[int], text: str) -> list[int]: """Returns offsets.""" offsets, start_idx = [], 0 diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 91d40f5c15..8689ed3bd7 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1076,6 +1076,7 @@ def validate_args(args, defaults={}): args.fim_eod_token, ] assert not args.mock_data, "Mock dataset is not supported with FIM dataset." + assert not args.legacy_tokenizer, "FIM dataset is not supported with legacy tokenizers." assert args.fim_rate, "--fim-rate should be specified." assert args.fim_spm_rate, "--fim-spm-rate should be specified." assert all(token is not None for token in extra_tokens), "FIM extra tokens should be specified." 
diff --git a/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp1_pp1_dist_optimizer_fim_dataset/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp1_pp1_dist_optimizer_fim_dataset/model_config.yaml new file mode 100644 index 0000000000..b0b0807085 --- /dev/null +++ b/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp1_pp1_dist_optimizer_fim_dataset/model_config.yaml @@ -0,0 +1,56 @@ +ENV_VARS: + CUDA_DEVICE_MAX_CONNECTIONS: 1 + NVTE_ALLOW_NONDETERMINISTIC_ALGO: 0 + NCCL_ALGO: Ring + CUBLAS_WORKSPACE_CONFIG: :4096:8 +MODEL_ARGS: + --num-layers: 12 + --hidden-size: 512 + --num-attention-heads: 8 + --log-params-norm: true + --log-num-zeros-in-grad: true + --log-validation-ppl-to-tensorboard: true + --log-timers-to-tensorboard: true + --tensorboard-dir: ${TENSORBOARD_PATH} + --micro-batch-size: 4 + --global-batch-size: 32 + --seq-length: 1024 + --max-position-embeddings: 1024 + --train-iters: 50 + --timing-log-level: 0 + --lr-decay-iters: 320000 + --save: ${CHECKPOINT_SAVE_PATH} + --load: ${CHECKPOINT_LOAD_PATH} + --data-path: ${DATA_PATH}/text/the_pile/shard00/my-gpt3_00_text_document + --vocab-file: ${DATA_PATH}/text/the_pile/shard00/bpe/vocab.json + --merge-file: ${DATA_PATH}/text/the_pile/shard00/bpe/merges.txt + --split: 949,50,1 + --distributed-backend: nccl + --lr: 0.00015 + --lr-decay-style: cosine + --min-lr: 1.0e-5 + --weight-decay: 1e-2 + --clip-grad: 1.0 + --lr-warmup-fraction: .01 + --log-interval: 1 + --save-interval: 10000 + --eval-interval: 1000 + --eval-iters: 10 + --transformer-impl: transformer_engine + --tensor-model-parallel-size: 1 + --pipeline-model-parallel-size: 1 + --use-distributed-optimizer: true + --deterministic-mode: true + --no-gradient-accumulation-fusion: true + --attention-softmax-in-fp32: true + --use-mcore-models: true + --ckpt-format: torch_dist + --dist-ckpt-strictness: log_all # backward compatibility for TE changes + --data-cache-path: ${DATA_CACHE_PATH} + --bf16: true + --attention-backend: unfused + --log-memory-to-tensorboard: true + --fim-data: true + --fim-rate: 0.2 + --fim-spm-rate: 0.2 +TEST_TYPE: regular diff --git a/tests/test_utils/recipes/gpt.yaml b/tests/test_utils/recipes/gpt.yaml index 1f40711246..4b0c89e13f 100644 --- a/tests/test_utils/recipes/gpt.yaml +++ b/tests/test_utils/recipes/gpt.yaml @@ -114,6 +114,11 @@ products: platforms: [dgx_h100] - environment: [lts] scope: [nightly] + - test_case: [gpt3_mcore_te_tp1_pp1_dist_optimizer_fim_dataset] + products: + - environment: [dev] + scope: [mr, mr-github] + platforms: [dgx_h100] - test_case: [gpt3_mcore_te_tp1_pp1_resume_torch_dist_dist_optimizer] products: - environment: [dev] diff --git a/tests/unit_tests/data/test_fim_dataset.py b/tests/unit_tests/data/test_fim_dataset.py new file mode 100644 index 0000000000..8a6363e6cb --- /dev/null +++ b/tests/unit_tests/data/test_fim_dataset.py @@ -0,0 +1,66 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
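+# Builds a GPTFIMDataset through BlendedMegatronDatasetBuilder and checks that the configured
+# FIM parameters are propagated to the dataset and applied to the produced samples.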
+ +import torch + +from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder +from megatron.training.datasets.fim_dataset import GPTFIMDatasetConfig, GPTFIMDataset +from megatron.core.datasets.utils import compile_helpers +from megatron.core.tokenizers import MegatronTokenizer +from tests.unit_tests.test_utilities import Utils +from megatron.core.datasets.utils import get_blend_from_list + + +def test_fim_gpt_dataset(): + if torch.distributed.is_available(): + Utils.initialize_distributed() + if torch.distributed.get_rank() == 0: + compile_helpers() + torch.distributed.barrier() + else: + compile_helpers() + + tokenizer = MegatronTokenizer.from_pretrained( + metadata_path={"library": "null"}, + vocab_size=131072, + ) + blend = get_blend_from_list(["/opt/data/datasets/train/test_text_document"]) + extra_tokens = { + "prefix": "777", + "middle": "888", + "suffix": "999", + "pad": "666", + "eod": "000", + } + seq_length = 8 + rate = 0.2 + spm_rate = 0.2 + fragment_rate = 0.5 + config = GPTFIMDatasetConfig( + blend=blend, + random_seed=1234, + sequence_length=seq_length, + split="990,9,1", + tokenizer=tokenizer, + reset_position_ids=True, + reset_attention_mask=True, + eod_mask_loss=True, + extra_tokens=extra_tokens, + rate=rate, + spm_rate=spm_rate, + fragment_rate=fragment_rate, + no_prefix="111214", + ) + + datasets = BlendedMegatronDatasetBuilder( + GPTFIMDataset, [10, 10, 10], lambda: True, config + ).build() + + dataset = datasets[0] + assert dataset.fim_rate == rate + assert dataset.fim_spm_rate == spm_rate + assert dataset.fragment_fim_rate == 0.5 + assert dataset[0]["tokens"].tolist() == [343, 54365900, 77, 131072, 111214, 343, 54365900,77] + + +if __name__ == "__main__": + test_fim_gpt_dataset() \ No newline at end of file From 5b8d7eb2b304b932dfe3d5ef916de46bde6c9227 Mon Sep 17 00:00:00 2001 From: dimapihtar Date: Tue, 18 Nov 2025 11:02:05 -0800 Subject: [PATCH 05/22] fix code style Signed-off-by: dimapihtar --- tests/unit_tests/data/test_fim_dataset.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/tests/unit_tests/data/test_fim_dataset.py b/tests/unit_tests/data/test_fim_dataset.py index 8a6363e6cb..bbfb0706d2 100644 --- a/tests/unit_tests/data/test_fim_dataset.py +++ b/tests/unit_tests/data/test_fim_dataset.py @@ -20,17 +20,10 @@ def test_fim_gpt_dataset(): compile_helpers() tokenizer = MegatronTokenizer.from_pretrained( - metadata_path={"library": "null"}, - vocab_size=131072, + metadata_path={"library": "null"}, vocab_size=131072 ) blend = get_blend_from_list(["/opt/data/datasets/train/test_text_document"]) - extra_tokens = { - "prefix": "777", - "middle": "888", - "suffix": "999", - "pad": "666", - "eod": "000", - } + extra_tokens = {"prefix": "777", "middle": "888", "suffix": "999", "pad": "666", "eod": "000"} seq_length = 8 rate = 0.2 spm_rate = 0.2 @@ -59,7 +52,7 @@ def test_fim_gpt_dataset(): assert dataset.fim_rate == rate assert dataset.fim_spm_rate == spm_rate assert dataset.fragment_fim_rate == 0.5 - assert dataset[0]["tokens"].tolist() == [343, 54365900, 77, 131072, 111214, 343, 54365900,77] + assert dataset[0]["tokens"].tolist() == [343, 54365900, 77, 131072, 111214, 343, 54365900, 77] if __name__ == "__main__": From 0a2ad73b020b925a742aeeab2b3d7a86c4f7b27b Mon Sep 17 00:00:00 2001 From: dimapihtar Date: Tue, 18 Nov 2025 11:04:42 -0800 Subject: [PATCH 06/22] fix code style Signed-off-by: dimapihtar --- tests/unit_tests/data/test_fim_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/tests/unit_tests/data/test_fim_dataset.py b/tests/unit_tests/data/test_fim_dataset.py index bbfb0706d2..20fae22874 100644 --- a/tests/unit_tests/data/test_fim_dataset.py +++ b/tests/unit_tests/data/test_fim_dataset.py @@ -56,4 +56,4 @@ def test_fim_gpt_dataset(): if __name__ == "__main__": - test_fim_gpt_dataset() \ No newline at end of file + test_fim_gpt_dataset() From 26145dd24db9bd9dd8d458c89f267ab5a37ebe5d Mon Sep 17 00:00:00 2001 From: dimapihtar Date: Tue, 18 Nov 2025 11:07:50 -0800 Subject: [PATCH 07/22] fix code style Signed-off-by: dimapihtar --- tests/unit_tests/data/test_fim_dataset.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/unit_tests/data/test_fim_dataset.py b/tests/unit_tests/data/test_fim_dataset.py index 20fae22874..f8ee89f403 100644 --- a/tests/unit_tests/data/test_fim_dataset.py +++ b/tests/unit_tests/data/test_fim_dataset.py @@ -3,11 +3,10 @@ import torch from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder -from megatron.training.datasets.fim_dataset import GPTFIMDatasetConfig, GPTFIMDataset -from megatron.core.datasets.utils import compile_helpers +from megatron.core.datasets.utils import compile_helpers, get_blend_from_list from megatron.core.tokenizers import MegatronTokenizer +from megatron.training.datasets.fim_dataset import GPTFIMDataset, GPTFIMDatasetConfig from tests.unit_tests.test_utilities import Utils -from megatron.core.datasets.utils import get_blend_from_list def test_fim_gpt_dataset(): From 7a793e77cb0aa503da80e1b05a15ef39df072107 Mon Sep 17 00:00:00 2001 From: dimapihtar Date: Tue, 18 Nov 2025 11:23:03 -0800 Subject: [PATCH 08/22] add readme Signed-off-by: dimapihtar --- megatron/training/datasets/README.md | 41 ++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 megatron/training/datasets/README.md diff --git a/megatron/training/datasets/README.md b/megatron/training/datasets/README.md new file mode 100644 index 0000000000..91eb47440a --- /dev/null +++ b/megatron/training/datasets/README.md @@ -0,0 +1,41 @@ +# Data Pipeline + +## FIM dataset + +`GPTFIMDataset` extends Megatron-Core’s `GPTDataset` to support **Fill-in-the-Middle (FIM)** data augmentation. +It probabilistically converts samples into FIM format using configurable rates, with support for both PSM and SPM patterns, fragment-level splitting, and length-preserving output. + +`GPTFIMDatasetConfig` provides the configuration needed to enable this behavior. +`GPTFIMDatasetConfig` configuration object extending `GPTDatasetConfig` to enable FIM preprocessing. + +**Attributes** + +- `rate`: Probability of converting a sample into a FIM example. + +- `spm_rate`: Probability of using the SPM FIM pattern (vs PSM). + +- `extra_tokens`: Dictionary containing the FIM special tokens: {"prefix", "middle", "suffix", "pad", "eod"}. + +- `split_sample`: Optional token around which samples are split before applying FIM. + +- `fragment_rate`: Probability of applying FIM to each fragment when split_sample is used. + +- `no_prefix`: If the decoded sequence starts with this prefix, FIM is skipped. + +`GPTFIMDataset` dataset class that loads token sequences from an `IndexedDataset` and applies FIM transformations before returning each sample. 
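+
+During pretraining this dataset is selected with the `--fim-data` flag together with the related `--fim-*` arguments (see `megatron/training/arguments.py`). Below is a minimal construction sketch mirroring the unit test in `tests/unit_tests/data/test_fim_dataset.py`; the blend path is a placeholder, and a real run would use a tokenizer whose vocabulary contains the FIM special tokens.
+
+```
+from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
+from megatron.core.datasets.utils import get_blend_from_list
+from megatron.core.tokenizers import MegatronTokenizer
+from megatron.training.datasets.fim_dataset import GPTFIMDataset, GPTFIMDatasetConfig
+
+# Null tokenizer as used in the unit test; with it, the FIM special tokens are numeric strings.
+tokenizer = MegatronTokenizer.from_pretrained(metadata_path={"library": "null"}, vocab_size=131072)
+
+config = GPTFIMDatasetConfig(
+    blend=get_blend_from_list(["/path/to/my_text_document"]),  # placeholder dataset prefix
+    random_seed=1234,
+    sequence_length=1024,
+    split="990,9,1",
+    tokenizer=tokenizer,
+    reset_position_ids=True,
+    reset_attention_mask=True,
+    eod_mask_loss=True,
+    rate=0.5,
+    spm_rate=0.5,
+    extra_tokens={"prefix": "777", "middle": "888", "suffix": "999", "pad": "666", "eod": "000"},
+)
+
+# Build train/valid/test splits the same way the unit test does.
+train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
+    GPTFIMDataset, [10, 10, 10], lambda: True, config
+).build()
+```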
+ +**PSM Format** +``` +[prefix_tok] prefix [suffix_tok] suffix [middle_tok] middle +``` + +**SPM Format** +``` +[prefix_tok, suffix_tok] suffix [middle_tok] prefix middle +``` + +**Special cases:** + +- If the sequence starts with no_prefix, FIM is skipped. + +- If FIM is not applied, the sample is returned unchanged. \ No newline at end of file From e17dc5bddea71c26a9fbf9bb8c9c77e0dff08ca7 Mon Sep 17 00:00:00 2001 From: dimapihtar Date: Tue, 18 Nov 2025 11:26:11 -0800 Subject: [PATCH 09/22] fix readme Signed-off-by: dimapihtar --- megatron/training/datasets/README.md | 7 ------- 1 file changed, 7 deletions(-) diff --git a/megatron/training/datasets/README.md b/megatron/training/datasets/README.md index 91eb47440a..7538d6d2be 100644 --- a/megatron/training/datasets/README.md +++ b/megatron/training/datasets/README.md @@ -11,17 +11,11 @@ It probabilistically converts samples into FIM format using configurable rates, **Attributes** - `rate`: Probability of converting a sample into a FIM example. - - `spm_rate`: Probability of using the SPM FIM pattern (vs PSM). - - `extra_tokens`: Dictionary containing the FIM special tokens: {"prefix", "middle", "suffix", "pad", "eod"}. - - `split_sample`: Optional token around which samples are split before applying FIM. - - `fragment_rate`: Probability of applying FIM to each fragment when split_sample is used. - - `no_prefix`: If the decoded sequence starts with this prefix, FIM is skipped. - `GPTFIMDataset` dataset class that loads token sequences from an `IndexedDataset` and applies FIM transformations before returning each sample. **PSM Format** @@ -37,5 +31,4 @@ It probabilistically converts samples into FIM format using configurable rates, **Special cases:** - If the sequence starts with no_prefix, FIM is skipped. - - If FIM is not applied, the sample is returned unchanged. \ No newline at end of file From 092ecb95f87191ba109b94236052607be0857ec6 Mon Sep 17 00:00:00 2001 From: dimapihtar Date: Wed, 19 Nov 2025 05:33:03 -0800 Subject: [PATCH 10/22] fix np_rng usage Signed-off-by: dimapihtar --- megatron/training/datasets/fim_dataset.py | 32 +++++++++++------------ 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/megatron/training/datasets/fim_dataset.py b/megatron/training/datasets/fim_dataset.py index f4289259f9..15e5d26e5b 100644 --- a/megatron/training/datasets/fim_dataset.py +++ b/megatron/training/datasets/fim_dataset.py @@ -133,7 +133,6 @@ def _query_document_sample_shuffle_indices(self, idx: int) -> Tuple[np.ndarray, sample_len = sample.shape[0] segment_breaks = np.argwhere(sample == self.eod_tok_id) - np_rng = self.np_rng if segment_breaks.shape != (0, 1): # then there is an EOD token in this example curr_start_position = 0 @@ -142,17 +141,17 @@ def _query_document_sample_shuffle_indices(self, idx: int) -> Tuple[np.ndarray, # Only permute non-empty segments. 
if loc - curr_start_position > 0: # permute {prefix, suffix, middle} or {suffix, prefix, middle} - permuted = self._fim_split_and_permute_sequence(sample[curr_start_position:loc], np_rng) + permuted = self._fim_split_and_permute_sequence(sample[curr_start_position:loc]) new_samples += [permuted, [self.eod_tok_id]] curr_start_position = loc + 1 # jump over the EOD token # Permute the segment after the last EOD - permuted = self._fim_split_and_permute_sequence(sample[curr_start_position:], np_rng) + permuted = self._fim_split_and_permute_sequence(sample[curr_start_position:]) new_samples.append(permuted) sample = np.concatenate(new_samples) else: - sample = self._fim_split_and_permute_sequence(sample, np_rng) + sample = self._fim_split_and_permute_sequence(sample) diff = sample.shape[0] - sample_len if diff > 0: # too long @@ -167,10 +166,10 @@ def _query_document_sample_shuffle_indices(self, idx: int) -> Tuple[np.ndarray, np.array(document_ids, dtype=np.int64), ) - def _fim_permute_sequence(self, sequence, np_rng, rate): + def _fim_permute_sequence(self, sequence, rate): return self._permute( sequence, - np_rng, + self.np_rng, rate, self.fim_spm_rate, self.config.tokenizer, @@ -182,21 +181,21 @@ def _fim_permute_sequence(self, sequence, np_rng, rate): no_fim_prefix=self.no_fim_prefix, ) - def _fim_split_and_permute_sequence(self, sequence, np_rng): + def _fim_split_and_permute_sequence(self, sequence): """ If self.fim_split_sample is not None, split the sequence. Then apply FIM on the fragments, or the whole sequence if self.fim_split_sample is None. """ if self.fim_split_sample is None: - return self._fim_permute_sequence(sequence, np_rng, self.fim_rate) + return self._fim_permute_sequence(sequence, self.fim_rate) # fim_split_sample is set: split the sample on this token and permute each fragment separately. # Typically, if each sample is a repository, then we split again on the file level. # Each fragment is a file, and we permute the files. fragment_breaks = np.argwhere(sequence == self.fim_split_sample) if fragment_breaks.shape == (0, 1): # no split token in this sample - return self._fim_permute_sequence(sequence, np_rng, self.fim_rate) - if not np_rng.binomial(1, self.fim_rate): + return self._fim_permute_sequence(sequence, self.fim_rate) + if not self.np_rng.binomial(1, self.fim_rate): # don't do FIM preproc return sequence # Do FIM on each fragment @@ -205,12 +204,12 @@ def _fim_split_and_permute_sequence(self, sequence, np_rng): for loc in np.nditer(fragment_breaks): if loc - curr_start_position > 0: permuted = self._fim_permute_sequence( - sequence[curr_start_position:loc], np_rng, self.fragment_fim_rate + sequence[curr_start_position:loc], self.fragment_fim_rate ) new_samples += [permuted, [self.fim_split_sample]] curr_start_position = loc + 1 # Jump over the split token # Permute the segment after the last split token - permuted = self._fim_permute_sequence(sequence[curr_start_position:], np_rng, self.fragment_fim_rate) + permuted = self._fim_permute_sequence(sequence[curr_start_position:], self.fragment_fim_rate) new_samples.append(permuted) return np.concatenate(new_samples) @@ -218,7 +217,6 @@ def _fim_split_and_permute_sequence(self, sequence, np_rng): def _permute( self, sample, - np_rng, fim_rate, fim_spm_rate, tokenizer, @@ -233,7 +231,7 @@ def _permute( Take in a sample (np array w/ size (0,chunklength)) and perform a FIM transformation on it. Maintain the same sample length (if transform creates a few extra tokens, drop them). 
""" - if np_rng.binomial(1, fim_rate): # sample bernoulli dist + if self.np_rng.binomial(1, fim_rate): # sample bernoulli dist contents = tokenizer._tokenizer.ids_to_text(sample) @@ -245,7 +243,7 @@ def _permute( # A boundary can be =0 (prefix will be empty) # a boundary can be =len(contents) (suffix will be empty) # The two boundaries can be equal (middle will be empty) - boundaries = list(np_rng.randint(low=0, high=len(contents) + 1, size=2)) + boundaries = list(self.np_rng.randint(low=0, high=len(contents) + 1, size=2)) boundaries.sort() except ValueError as e: print(len(contents), contents) @@ -271,12 +269,12 @@ def _permute( if ( suffix.shape[0] <= diff ): # if there's no space to truncate the suffix: stop and report it. atm i should have stopped this from happening - return sample, np_rng + return sample, self.np_rng suffix = suffix[: suffix.shape[0] - diff] elif diff < 0: # too short suffix = np.concatenate([suffix, np.full((-1 * diff), pad_tok_id)]) - if np_rng.binomial(1, fim_spm_rate): + if self.np_rng.binomial(1, fim_spm_rate): # SPM (variant 2 from FIM paper) new_sample = np.concatenate([[prefix_tok_id, suffix_tok_id], suffix, [middle_tok_id], prefix, middle]) else: From c7766e907a5dc02494a9b7db0fd99d373f7c9a8e Mon Sep 17 00:00:00 2001 From: dimapihtar Date: Wed, 19 Nov 2025 05:35:41 -0800 Subject: [PATCH 11/22] remove self.indexed_dataset Signed-off-by: dimapihtar --- megatron/training/datasets/fim_dataset.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/megatron/training/datasets/fim_dataset.py b/megatron/training/datasets/fim_dataset.py index 15e5d26e5b..5162141452 100644 --- a/megatron/training/datasets/fim_dataset.py +++ b/megatron/training/datasets/fim_dataset.py @@ -62,7 +62,6 @@ def __init__( ) -> None: super().__init__(indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config) - self.indexed_dataset = indexed_dataset self.np_rng = np.random.RandomState(seed=self.config.random_seed) logger.info(f"Initialized FIM RNG with seed = {self.config.random_seed}") # get FIM params @@ -111,7 +110,7 @@ def _query_document_sample_shuffle_indices(self, idx: int) -> Tuple[np.ndarray, # Add the entire sample sample_parts.append( - self.indexed_dataset.get( + self.dataset.get( self.document_index[doc_index_beg], offset=doc_index_beg_offset, length=doc_index_end_offset - doc_index_beg_offset + 1, @@ -127,7 +126,7 @@ def _query_document_sample_shuffle_indices(self, idx: int) -> Tuple[np.ndarray, # Add the sample part offset = 0 if i > doc_index_beg else doc_index_beg_offset length = None if i < doc_index_end else doc_index_end_offset + 1 - sample_parts.append(self.indexed_dataset.get(self.document_index[i], offset=offset, length=length)) + sample_parts.append(self.dataset.get(self.document_index[i], offset=offset, length=length)) sample = np.concatenate(sample_parts) From bd9d68906d7264f81a75ab1ddd9e2d162c115960 Mon Sep 17 00:00:00 2001 From: dimapihtar Date: Wed, 19 Nov 2025 08:50:51 -0800 Subject: [PATCH 12/22] minor fix Signed-off-by: dimapihtar --- megatron/training/datasets/fim_dataset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/megatron/training/datasets/fim_dataset.py b/megatron/training/datasets/fim_dataset.py index 5162141452..8ad56d2721 100644 --- a/megatron/training/datasets/fim_dataset.py +++ b/megatron/training/datasets/fim_dataset.py @@ -168,7 +168,6 @@ def _query_document_sample_shuffle_indices(self, idx: int) -> Tuple[np.ndarray, def _fim_permute_sequence(self, sequence, rate): return self._permute( sequence, - 
self.np_rng, rate, self.fim_spm_rate, self.config.tokenizer, From 4d9ee92bb776d99faa2d378bec4acc22e6d29139 Mon Sep 17 00:00:00 2001 From: dimapihtar Date: Wed, 19 Nov 2025 11:28:24 -0800 Subject: [PATCH 13/22] update unit tests Signed-off-by: dimapihtar --- tests/unit_tests/data/test_fim_dataset.py | 51 ++++++++++++++++++----- 1 file changed, 40 insertions(+), 11 deletions(-) diff --git a/tests/unit_tests/data/test_fim_dataset.py b/tests/unit_tests/data/test_fim_dataset.py index f8ee89f403..14cfe84e3f 100644 --- a/tests/unit_tests/data/test_fim_dataset.py +++ b/tests/unit_tests/data/test_fim_dataset.py @@ -1,5 +1,6 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +import pytest import torch from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder @@ -9,7 +10,9 @@ from tests.unit_tests.test_utilities import Utils -def test_fim_gpt_dataset(): +@pytest.mark.parametrize("spm_rate", [0.0, 1.0]) +@pytest.mark.parametrize("split_sample", [None, "python"]) +def test_fim_gpt_dataset(spm_rate, split_sample): if torch.distributed.is_available(): Utils.initialize_distributed() if torch.distributed.get_rank() == 0: @@ -19,14 +22,22 @@ def test_fim_gpt_dataset(): compile_helpers() tokenizer = MegatronTokenizer.from_pretrained( - metadata_path={"library": "null"}, vocab_size=131072 + tokenizer_path="/opt/data/tokenizers/huggingface", + metadata_path={"library": "huggingface"}, + additional_special_tokens=["", "", "", "", ""], + include_special_tokens=True, ) - blend = get_blend_from_list(["/opt/data/datasets/train/test_text_document"]) - extra_tokens = {"prefix": "777", "middle": "888", "suffix": "999", "pad": "666", "eod": "000"} - seq_length = 8 - rate = 0.2 - spm_rate = 0.2 - fragment_rate = 0.5 + blend = get_blend_from_list(["/home/data/fim/fim_text_document"]) + extra_tokens = { + "prefix": "", + "middle": "", + "suffix": "", + "pad": "", + "eod": "", + } + seq_length = 32 + rate = 1.0 + fragment_rate = 1.0 config = GPTFIMDatasetConfig( blend=blend, random_seed=1234, @@ -40,18 +51,36 @@ def test_fim_gpt_dataset(): rate=rate, spm_rate=spm_rate, fragment_rate=fragment_rate, - no_prefix="111214", + split_sample=split_sample, ) datasets = BlendedMegatronDatasetBuilder( GPTFIMDataset, [10, 10, 10], lambda: True, config ).build() + prefix_id = tokenizer.tokenize("")[1] + suffix_id = tokenizer.tokenize("")[1] + middle_id = tokenizer.tokenize("")[1] + dataset = datasets[0] assert dataset.fim_rate == rate assert dataset.fim_spm_rate == spm_rate - assert dataset.fragment_fim_rate == 0.5 - assert dataset[0]["tokens"].tolist() == [343, 54365900, 77, 131072, 111214, 343, 54365900, 77] + assert dataset.fragment_fim_rate == fragment_rate + + tokens = dataset[0]["tokens"].tolist() + if split_sample: + split_sample_id = tokenizer.tokenize(split_sample)[1] + split_sample_index = tokens.index(split_sample_id) + assert prefix_id == tokens[split_sample_index + 1] + if spm_rate == 0.0: + assert prefix_id == tokens[0] + assert suffix_id in tokens + assert middle_id in tokens + assert tokens.index(suffix_id) < tokens.index(middle_id) + else: + assert prefix_id == tokens[0] + assert suffix_id == tokens[1] + assert middle_id in tokens if __name__ == "__main__": From 7dba91638a0327a62b5bd74fb77a56ec8ea9afb9 Mon Sep 17 00:00:00 2001 From: dimapihtar Date: Thu, 20 Nov 2025 04:08:33 -0800 Subject: [PATCH 14/22] add assertion Signed-off-by: dimapihtar --- megatron/training/datasets/fim_dataset.py | 40 +++++++++++++++++------ tests/unit_tests/data/test_fim_dataset.py | 2 +- 
2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/megatron/training/datasets/fim_dataset.py b/megatron/training/datasets/fim_dataset.py index 8ad56d2721..bf7b72f080 100644 --- a/megatron/training/datasets/fim_dataset.py +++ b/megatron/training/datasets/fim_dataset.py @@ -60,7 +60,9 @@ def __init__( index_split: Split, config: GPTFIMDatasetConfig, ) -> None: - super().__init__(indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config) + super().__init__( + indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config + ) self.np_rng = np.random.RandomState(seed=self.config.random_seed) logger.info(f"Initialized FIM RNG with seed = {self.config.random_seed}") @@ -69,12 +71,27 @@ def __init__( self.fim_spm_rate = self.config.spm_rate self.fragment_fim_rate = self.config.fragment_rate split_sample = self.config.split_sample - self.fim_split_sample = self.config.tokenizer._tokenizer.tokens_to_ids(split_sample) if split_sample else None self.no_fim_prefix = self.config.no_prefix + if split_sample: + fim_split_sample_ids = self.config.tokenizer._tokenizer.tokens_to_ids(split_sample) + assert isinstance(fim_split_sample_ids, int) or len(fim_split_sample_ids) == 1 + self.fim_split_sample = ( + fim_split_sample_ids + if isinstance(fim_split_sample_ids, int) + else fim_split_sample_ids[0] + ) + else: + self.fim_split_sample = None # get extra tokens ids fim_tokens = self.config.extra_tokens - fim_tokens = [fim_tokens["prefix"], fim_tokens["middle"], fim_tokens["suffix"], fim_tokens["pad"], fim_tokens["eod"]] + fim_tokens = [ + fim_tokens["prefix"], + fim_tokens["middle"], + fim_tokens["suffix"], + fim_tokens["pad"], + fim_tokens["eod"], + ] fim_tokens_ids = self.config.tokenizer._tokenizer.tokens_to_ids(fim_tokens) ( self.prefix_tok_id, @@ -126,7 +143,9 @@ def _query_document_sample_shuffle_indices(self, idx: int) -> Tuple[np.ndarray, # Add the sample part offset = 0 if i > doc_index_beg else doc_index_beg_offset length = None if i < doc_index_end else doc_index_end_offset + 1 - sample_parts.append(self.dataset.get(self.document_index[i], offset=offset, length=length)) + sample_parts.append( + self.dataset.get(self.document_index[i], offset=offset, length=length) + ) sample = np.concatenate(sample_parts) @@ -160,10 +179,7 @@ def _query_document_sample_shuffle_indices(self, idx: int) -> Tuple[np.ndarray, assert sample.shape[0] == sample_len - return ( - np.array(sample, dtype=np.int64), - np.array(document_ids, dtype=np.int64), - ) + return (np.array(sample, dtype=np.int64), np.array(document_ids, dtype=np.int64)) def _fim_permute_sequence(self, sequence, rate): return self._permute( @@ -207,7 +223,9 @@ def _fim_split_and_permute_sequence(self, sequence): new_samples += [permuted, [self.fim_split_sample]] curr_start_position = loc + 1 # Jump over the split token # Permute the segment after the last split token - permuted = self._fim_permute_sequence(sequence[curr_start_position:], self.fragment_fim_rate) + permuted = self._fim_permute_sequence( + sequence[curr_start_position:], self.fragment_fim_rate + ) new_samples.append(permuted) return np.concatenate(new_samples) @@ -274,7 +292,9 @@ def _permute( if self.np_rng.binomial(1, fim_spm_rate): # SPM (variant 2 from FIM paper) - new_sample = np.concatenate([[prefix_tok_id, suffix_tok_id], suffix, [middle_tok_id], prefix, middle]) + new_sample = np.concatenate( + [[prefix_tok_id, suffix_tok_id], suffix, [middle_tok_id], prefix, middle] + ) else: # PSM new_sample = np.concatenate( diff --git 
a/tests/unit_tests/data/test_fim_dataset.py b/tests/unit_tests/data/test_fim_dataset.py index 14cfe84e3f..8939a5cd3e 100644 --- a/tests/unit_tests/data/test_fim_dataset.py +++ b/tests/unit_tests/data/test_fim_dataset.py @@ -27,7 +27,7 @@ def test_fim_gpt_dataset(spm_rate, split_sample): additional_special_tokens=["", "", "", "", ""], include_special_tokens=True, ) - blend = get_blend_from_list(["/home/data/fim/fim_text_document"]) + blend = get_blend_from_list(["/opt/data/datasets/fim/fim_text_document"]) extra_tokens = { "prefix": "", "middle": "", From ae4e01e9142c1d195214bb15251ef41574c80e7a Mon Sep 17 00:00:00 2001 From: dimapihtar Date: Thu, 20 Nov 2025 04:12:33 -0800 Subject: [PATCH 15/22] change fim rate Signed-off-by: dimapihtar --- .../model_config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp1_pp1_dist_optimizer_fim_dataset/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp1_pp1_dist_optimizer_fim_dataset/model_config.yaml index b0b0807085..ddc8286573 100644 --- a/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp1_pp1_dist_optimizer_fim_dataset/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp1_pp1_dist_optimizer_fim_dataset/model_config.yaml @@ -51,6 +51,6 @@ MODEL_ARGS: --attention-backend: unfused --log-memory-to-tensorboard: true --fim-data: true - --fim-rate: 0.2 - --fim-spm-rate: 0.2 + --fim-rate: 0.5 + --fim-spm-rate: 0.5 TEST_TYPE: regular From ec47085fb96fdc47865d70168b5d1a7f209a4f12 Mon Sep 17 00:00:00 2001 From: dimapihtar Date: Thu, 20 Nov 2025 06:24:21 -0800 Subject: [PATCH 16/22] add golden values Signed-off-by: dimapihtar --- .../golden_values_dev_dgx_h100.json | 287 ++++++++++++++++++ 1 file changed, 287 insertions(+) create mode 100644 tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp1_pp1_dist_optimizer_fim_dataset/golden_values_dev_dgx_h100.json diff --git a/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp1_pp1_dist_optimizer_fim_dataset/golden_values_dev_dgx_h100.json b/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp1_pp1_dist_optimizer_fim_dataset/golden_values_dev_dgx_h100.json new file mode 100644 index 0000000000..cd90888e65 --- /dev/null +++ b/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp1_pp1_dist_optimizer_fim_dataset/golden_values_dev_dgx_h100.json @@ -0,0 +1,287 @@ +{ + "lm loss": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": 10.89074, + "2": 10.89234, + "3": 10.89032, + "4": 10.89221, + "5": 10.89416, + "6": 10.90226, + "7": 10.8884, + "8": 10.90211, + "9": 10.90202, + "10": 10.88512, + "11": 10.87636, + "12": 10.89499, + "13": 10.89837, + "14": 10.89182, + "15": 10.85125, + "16": 10.8534, + "17": 10.82862, + "18": 10.83653, + "19": 10.82847, + "20": 10.74583, + "21": 10.73117, + "22": 10.61256, + "23": 10.72616, + "24": 10.62932, + "25": 10.59394, + "26": 10.63357, + "27": 10.63137, + "28": 10.58201, + "29": 10.58671, + "30": 10.40936, + "31": 10.15873, + "32": 10.48319, + "33": 10.46977, + "34": 10.23978, + "35": 10.28144, + "36": 10.23894, + "37": 10.35198, + "38": 10.20565, + "39": 10.40496, + "40": 10.09271, + "41": 10.16148, + "42": 10.2231, + "43": 9.84152, + "44": 9.97329, + "45": 9.84544, + "46": 9.82102, + "47": 10.14261, + "48": 9.86553, + "49": 9.54033, + "50": 9.9169 + } + }, + "num-zeros": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": 1544.0, + "2": 1729.0, + "3": 1672.0, + "4": 1807.0, + "5": 
1942.0, + "6": 1736.0, + "7": 1956.0, + "8": 1716.0, + "9": 2011.0, + "10": 1385.0, + "11": 1864.0, + "12": 1767.0, + "13": 2019.0, + "14": 1787.0, + "15": 1828.0, + "16": 1908.0, + "17": 1718.0, + "18": 1602.0, + "19": 1785.0, + "20": 1679.0, + "21": 1917.0, + "22": 1712.0, + "23": 2034.0, + "24": 1752.0, + "25": 1645.0, + "26": 1820.0, + "27": 1915.0, + "28": 1996.0, + "29": 2051.0, + "30": 1890.0, + "31": 1577.0, + "32": 1886.0, + "33": 2116.0, + "34": 1912.0, + "35": 2037.0, + "36": 1924.0, + "37": 2462.0, + "38": 2241.0, + "39": 2321.0, + "40": 2221.0, + "41": 2345.0, + "42": 2386.0, + "43": 2027.0, + "44": 2211.0, + "45": 2096.0, + "46": 2285.0, + "47": 2536.0, + "48": 2289.0, + "49": 2270.0, + "50": 2421.0 + } + }, + "mem-allocated-bytes": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": 581489664.0, + "2": 581489664.0, + "3": 581489664.0, + "4": 581489664.0, + "5": 581489664.0, + "6": 581489664.0, + "7": 581489664.0, + "8": 581489664.0, + "9": 581489664.0, + "10": 581489664.0, + "11": 581489664.0, + "12": 581489664.0, + "13": 581489664.0, + "14": 581489664.0, + "15": 581489664.0, + "16": 581489664.0, + "17": 581489664.0, + "18": 581489664.0, + "19": 581489664.0, + "20": 581489664.0, + "21": 581489664.0, + "22": 581489664.0, + "23": 581489664.0, + "24": 581489664.0, + "25": 581489664.0, + "26": 581489664.0, + "27": 581489664.0, + "28": 581489664.0, + "29": 581489664.0, + "30": 581489664.0, + "31": 581489664.0, + "32": 581489664.0, + "33": 581489664.0, + "34": 581489664.0, + "35": 581489664.0, + "36": 581489664.0, + "37": 581489664.0, + "38": 581489664.0, + "39": 581489664.0, + "40": 581489664.0, + "41": 581489664.0, + "42": 581489664.0, + "43": 581489664.0, + "44": 581489664.0, + "45": 581489664.0, + "46": 581489664.0, + "47": 581489664.0, + "48": 581489664.0, + "49": 581489664.0, + "50": 581489664.0 + } + }, + "mem-max-allocated-bytes": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": 4605814272.0, + "2": 4702430720.0, + "3": 4702430720.0, + "4": 4702430720.0, + "5": 4702430720.0, + "6": 4702430720.0, + "7": 4702430720.0, + "8": 4702430720.0, + "9": 4702430720.0, + "10": 4702430720.0, + "11": 4702430720.0, + "12": 4702430720.0, + "13": 4702430720.0, + "14": 4702430720.0, + "15": 4702430720.0, + "16": 4702430720.0, + "17": 4702430720.0, + "18": 4702430720.0, + "19": 4702430720.0, + "20": 4702430720.0, + "21": 4702430720.0, + "22": 4702430720.0, + "23": 4702430720.0, + "24": 4702430720.0, + "25": 4702430720.0, + "26": 4702430720.0, + "27": 4702430720.0, + "28": 4702430720.0, + "29": 4702430720.0, + "30": 4702430720.0, + "31": 4702430720.0, + "32": 4702430720.0, + "33": 4702430720.0, + "34": 4702430720.0, + "35": 4702430720.0, + "36": 4702430720.0, + "37": 4702430720.0, + "38": 4702430720.0, + "39": 4702430720.0, + "40": 4702430720.0, + "41": 4702430720.0, + "42": 4702430720.0, + "43": 4702430720.0, + "44": 4702430720.0, + "45": 4702430720.0, + "46": 4702430720.0, + "47": 4702430720.0, + "48": 4702430720.0, + "49": 4702430720.0, + "50": 4702430720.0 + } + }, + "iteration-time": { + "start_step": 1, + "end_step": 50, + "step_interval": 1, + "values": { + "1": 6.95394, + "2": 0.0878, + "3": 0.06953, + "4": 0.07916, + "5": 0.06775, + "6": 0.07681, + "7": 0.06695, + "8": 0.0786, + "9": 0.0664, + "10": 0.08059, + "11": 0.06554, + "12": 0.07501, + "13": 0.06663, + "14": 0.06608, + "15": 0.06585, + "16": 0.06738, + "17": 0.067, + "18": 0.06553, + "19": 0.06755, + "20": 0.06723, + "21": 0.06559, + "22": 0.0664, + "23": 0.06722, 
+ "24": 0.06553, + "25": 0.06829, + "26": 0.06873, + "27": 0.06733, + "28": 0.06731, + "29": 0.06824, + "30": 0.06696, + "31": 0.06661, + "32": 0.06587, + "33": 0.06588, + "34": 0.06564, + "35": 0.06761, + "36": 0.06655, + "37": 0.06712, + "38": 0.06601, + "39": 0.06661, + "40": 0.06632, + "41": 0.0691, + "42": 0.06551, + "43": 0.06839, + "44": 0.06528, + "45": 0.06744, + "46": 0.0675, + "47": 0.06698, + "48": 0.0649, + "49": 0.06596, + "50": 0.06581 + } + } +} \ No newline at end of file From 70dda917003bcb475d5e5e4af48e1129dada6633 Mon Sep 17 00:00:00 2001 From: dimapihtar Date: Thu, 20 Nov 2025 07:14:03 -0800 Subject: [PATCH 17/22] fix typo Signed-off-by: dimapihtar --- megatron/training/datasets/fim_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/training/datasets/fim_dataset.py b/megatron/training/datasets/fim_dataset.py index bf7b72f080..f8b130eb49 100644 --- a/megatron/training/datasets/fim_dataset.py +++ b/megatron/training/datasets/fim_dataset.py @@ -285,7 +285,7 @@ def _permute( if ( suffix.shape[0] <= diff ): # if there's no space to truncate the suffix: stop and report it. atm i should have stopped this from happening - return sample, self.np_rng + return sample suffix = suffix[: suffix.shape[0] - diff] elif diff < 0: # too short suffix = np.concatenate([suffix, np.full((-1 * diff), pad_tok_id)]) From 0bdec43deb276ec5d71b8f5ba8aa3faa3e8fbb81 Mon Sep 17 00:00:00 2001 From: dimapihtar Date: Thu, 20 Nov 2025 14:10:27 -0800 Subject: [PATCH 18/22] fix param description Signed-off-by: dimapihtar --- megatron/training/datasets/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/training/datasets/README.md b/megatron/training/datasets/README.md index 7538d6d2be..55ea34f29b 100644 --- a/megatron/training/datasets/README.md +++ b/megatron/training/datasets/README.md @@ -10,7 +10,7 @@ It probabilistically converts samples into FIM format using configurable rates, **Attributes** -- `rate`: Probability of converting a sample into a FIM example. +- `rate`: Probability of converting a sample into a FIM example. A value of `1.0` means FIM is always applied. a value of `0.0` means FIM is never applied. - `spm_rate`: Probability of using the SPM FIM pattern (vs PSM). - `extra_tokens`: Dictionary containing the FIM special tokens: {"prefix", "middle", "suffix", "pad", "eod"}. - `split_sample`: Optional token around which samples are split before applying FIM. From 352593871fe96078285acf0828b5669af2ebba42 Mon Sep 17 00:00:00 2001 From: dimapihtar Date: Thu, 20 Nov 2025 14:12:13 -0800 Subject: [PATCH 19/22] fix param description Signed-off-by: dimapihtar --- megatron/training/datasets/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/training/datasets/README.md b/megatron/training/datasets/README.md index 55ea34f29b..523e30e7b9 100644 --- a/megatron/training/datasets/README.md +++ b/megatron/training/datasets/README.md @@ -11,7 +11,7 @@ It probabilistically converts samples into FIM format using configurable rates, **Attributes** - `rate`: Probability of converting a sample into a FIM example. A value of `1.0` means FIM is always applied. a value of `0.0` means FIM is never applied. -- `spm_rate`: Probability of using the SPM FIM pattern (vs PSM). +- `spm_rate`: Probability of using the SPM FIM pattern (vs PSM). The remaining probability (`1 - spm_rate`) selects the PSM (prefix-suffix-middle) pattern instead. For example, if `spm_rate = 0.3`: 30% SPM, 70% PSM. 
- `extra_tokens`: Dictionary containing the FIM special tokens: {"prefix", "middle", "suffix", "pad", "eod"}. - `split_sample`: Optional token around which samples are split before applying FIM. - `fragment_rate`: Probability of applying FIM to each fragment when split_sample is used. From d01590d6f6132075a702c6d22f15df2c42afda26 Mon Sep 17 00:00:00 2001 From: dimapihtar Date: Thu, 20 Nov 2025 14:16:33 -0800 Subject: [PATCH 20/22] fix param description Signed-off-by: dimapihtar --- megatron/training/datasets/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/training/datasets/README.md b/megatron/training/datasets/README.md index 523e30e7b9..d5543c3d1b 100644 --- a/megatron/training/datasets/README.md +++ b/megatron/training/datasets/README.md @@ -13,7 +13,7 @@ It probabilistically converts samples into FIM format using configurable rates, - `rate`: Probability of converting a sample into a FIM example. A value of `1.0` means FIM is always applied. a value of `0.0` means FIM is never applied. - `spm_rate`: Probability of using the SPM FIM pattern (vs PSM). The remaining probability (`1 - spm_rate`) selects the PSM (prefix-suffix-middle) pattern instead. For example, if `spm_rate = 0.3`: 30% SPM, 70% PSM. - `extra_tokens`: Dictionary containing the FIM special tokens: {"prefix", "middle", "suffix", "pad", "eod"}. -- `split_sample`: Optional token around which samples are split before applying FIM. +- `split_sample`: Optional token around which samples are split before applying FIM. If provided, the input sequence is divided at every occurrence of this token, and FIM is applied independently to each fragment. `A B C D E F G H` -> `FIM(Fragment 1) FIM(Fragment 2) FIM(Fragment 3)`. - `fragment_rate`: Probability of applying FIM to each fragment when split_sample is used. - `no_prefix`: If the decoded sequence starts with this prefix, FIM is skipped. `GPTFIMDataset` dataset class that loads token sequences from an `IndexedDataset` and applies FIM transformations before returning each sample. From c2578e4d4656839f84067bb0610eb35b42058929 Mon Sep 17 00:00:00 2001 From: dimapihtar Date: Thu, 20 Nov 2025 14:25:43 -0800 Subject: [PATCH 21/22] fix params names Signed-off-by: dimapihtar --- megatron/training/datasets/fim_dataset.py | 26 +++++++++++------------ pretrain_gpt.py | 12 +++++------ tests/unit_tests/data/test_fim_dataset.py | 10 ++++----- 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/megatron/training/datasets/fim_dataset.py b/megatron/training/datasets/fim_dataset.py index f8b130eb49..2d686d47ef 100644 --- a/megatron/training/datasets/fim_dataset.py +++ b/megatron/training/datasets/fim_dataset.py @@ -16,22 +16,22 @@ class GPTFIMDatasetConfig(GPTDatasetConfig): """Configuration object for Megatron Core GPT FIM datasets""" - rate: float = None + fim_rate: float = None """Probability to convert a training sample into a FIM format""" - spm_rate: float = None + fim_spm_rate: float = None """Probability that the a FIM sample uses the SPM format over the PSM format""" - extra_tokens: Dict = None + fim_extra_tokens: Dict = None """FIM extra tokens. 
Should consist of prefix, middle, suffix, PAD, and EOD tokens.""" - split_sample: Optional[str] = None + fim_split_sample: Optional[str] = None """String around which to split the sample for FIM""" - fragment_rate: Optional[float] = None + fim_fragment_rate: Optional[float] = None """Rate of FIM on each fragment when split_sample is not None""" - no_prefix: Optional[str] = None + fim_no_prefix: Optional[str] = None """Do not apply FIM to fragments that start with this prefix""" @@ -67,13 +67,13 @@ def __init__( self.np_rng = np.random.RandomState(seed=self.config.random_seed) logger.info(f"Initialized FIM RNG with seed = {self.config.random_seed}") # get FIM params - self.fim_rate = self.config.rate - self.fim_spm_rate = self.config.spm_rate - self.fragment_fim_rate = self.config.fragment_rate - split_sample = self.config.split_sample - self.no_fim_prefix = self.config.no_prefix - if split_sample: - fim_split_sample_ids = self.config.tokenizer._tokenizer.tokens_to_ids(split_sample) + self.fim_rate = self.config.fim_rate + self.fim_spm_rate = self.config.fim_spm_rate + self.fragment_fim_rate = self.config.fim_fragment_rate + fim_split_sample = self.config.fim_split_sample + self.no_fim_prefix = self.config.fim_no_prefix + if fim_split_sample: + fim_split_sample_ids = self.config.tokenizer._tokenizer.tokens_to_ids(fim_split_sample) assert isinstance(fim_split_sample_ids, int) or len(fim_split_sample_ids) == 1 self.fim_split_sample = ( fim_split_sample_ids diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 2b1a72c042..6b602d3324 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -205,12 +205,12 @@ def core_gpt_dataset_config_from_args(args): } data_args.update( { - "rate": args.fim_rate, - "spm_rate": args.fim_spm_rate, - "extra_tokens": extra_tokens, - "split_sample": args.fim_split_sample, - "fragment_rate": args.fim_fragment_rate, - "no_prefix": args.fim_no_prefix, + "fim_rate": args.fim_rate, + "fim_spm_rate": args.fim_spm_rate, + "fim_extra_tokens": extra_tokens, + "fim_split_sample": args.fim_split_sample, + "fim_fragment_rate": args.fim_fragment_rate, + "fim_no_prefix": args.fim_no_prefix, } ) return GPTFIMDatasetConfig(**data_args) diff --git a/tests/unit_tests/data/test_fim_dataset.py b/tests/unit_tests/data/test_fim_dataset.py index 8939a5cd3e..7022a4b5fa 100644 --- a/tests/unit_tests/data/test_fim_dataset.py +++ b/tests/unit_tests/data/test_fim_dataset.py @@ -47,11 +47,11 @@ def test_fim_gpt_dataset(spm_rate, split_sample): reset_position_ids=True, reset_attention_mask=True, eod_mask_loss=True, - extra_tokens=extra_tokens, - rate=rate, - spm_rate=spm_rate, - fragment_rate=fragment_rate, - split_sample=split_sample, + fim_extra_tokens=extra_tokens, + fim_rate=rate, + fim_spm_rate=spm_rate, + fim_fragment_rate=fragment_rate, + fim_split_sample=split_sample, ) datasets = BlendedMegatronDatasetBuilder( From 2f4e2fa5f467a3e17670d296cf11b7fe126e8e53 Mon Sep 17 00:00:00 2001 From: dimapihtar Date: Thu, 20 Nov 2025 15:29:45 -0800 Subject: [PATCH 22/22] fix typo Signed-off-by: dimapihtar --- megatron/training/datasets/fim_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/training/datasets/fim_dataset.py b/megatron/training/datasets/fim_dataset.py index 2d686d47ef..730b7e033a 100644 --- a/megatron/training/datasets/fim_dataset.py +++ b/megatron/training/datasets/fim_dataset.py @@ -84,7 +84,7 @@ def __init__( self.fim_split_sample = None # get extra tokens ids - fim_tokens = self.config.extra_tokens + fim_tokens = self.config.fim_extra_tokens 
fim_tokens = [ fim_tokens["prefix"], fim_tokens["middle"],
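
---

Note for reviewers unfamiliar with the PSM/SPM terminology used in the README changes above: the sketch below shows how the two layouts are typically assembled from the prefix/middle/suffix splits and the FIM special-token ids. It is an illustrative assumption, not code from this patch series; the function name, signature, and exact token ordering here follow common FIM implementations and may differ from the `_permute` method added in fim_dataset.py.

    # Illustrative sketch only -- not part of the patches above.
    # prefix/middle/suffix are 1-D numpy arrays of token ids; the *_tok_id
    # arguments are the special tokens resolved from fim_extra_tokens.
    import numpy as np

    def build_fim_sample(prefix, middle, suffix, np_rng,
                         prefix_tok_id, middle_tok_id, suffix_tok_id,
                         fim_spm_rate=0.5):
        """Arrange (prefix, middle, suffix) token arrays into a FIM sample."""
        if np_rng.binomial(1, fim_spm_rate):
            # SPM layout: [PRE][SUF] suffix [MID] prefix middle
            return np.concatenate(
                [[prefix_tok_id, suffix_tok_id], suffix,
                 [middle_tok_id], prefix, middle]
            )
        # PSM layout: [PRE] prefix [SUF] suffix [MID] middle
        return np.concatenate(
            [[prefix_tok_id], prefix,
             [suffix_tok_id], suffix,
             [middle_tok_id], middle]
        )

With fim_spm_rate = 0.3, roughly 30% of FIM-converted samples would use the SPM layout and 70% the PSM layout, matching the README description of `spm_rate`.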