Merged
8 changes: 8 additions & 0 deletions megatron/core/tokenizers/text/libraries/null_tokenizer.py
@@ -25,6 +25,14 @@ def ids_to_text(self, ids):
text = [str(x) for x in ids]
return ' '.join(text)

def tokens_to_ids(self, tokens):
"""Converts tokens to ids."""
return [int(x) for x in tokens]

def ids_to_tokens(self, ids):
"""Converts ids to tokens."""
return [str(x) for x in ids]

def offsets(self, ids: list[int], text: str) -> list[int]:
"""Returns offsets."""
offsets, start_idx = [], 0
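The two new conversion methods are simple enough to exercise standalone. A minimal sketch, assuming only the behavior shown in the diff above (`NullTokenizerSketch` is an illustrative stand-in, not the real class):

```python
# Standalone sketch of the two new NullTokenizer methods.
# NullTokenizerSketch is a hypothetical stand-in for the real class.
class NullTokenizerSketch:
    def tokens_to_ids(self, tokens):
        """Converts tokens to ids."""
        return [int(x) for x in tokens]

    def ids_to_tokens(self, ids):
        """Converts ids to tokens."""
        return [str(x) for x in ids]

tok = NullTokenizerSketch()
ids = tok.tokens_to_ids(["101", "7", "42"])
print(ids)                     # [101, 7, 42]
print(tok.ids_to_tokens(ids))  # ['101', '7', '42'] -- round-trip is lossless
```

Because the null tokenizer treats each token string as its own id, `ids_to_tokens(tokens_to_ids(x))` recovers the input exactly.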
35 changes: 35 additions & 0 deletions megatron/training/arguments.py
@@ -1067,6 +1067,20 @@ def validate_args(args, defaults={}):
any([args.train_data_path, args.valid_data_path, args.test_data_path]) \
<= 1, "A single data source must be provided in training mode, else None"

if args.fim_data:
extra_tokens = [
args.fim_prefix_token,
args.fim_middle_token,
args.fim_suffix_token,
args.fim_pad_token,
args.fim_eod_token,
]
assert not args.mock_data, "Mock dataset is not supported with FIM dataset."
assert not args.legacy_tokenizer, "FIM dataset is not supported with legacy tokenizers."
assert args.fim_rate, "--fim-rate should be specified."
assert args.fim_spm_rate, "--fim-spm-rate should be specified."
assert all(token is not None for token in extra_tokens), "FIM extra tokens should be specified."
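The validation branch above can be checked in isolation with a hand-built namespace. A hedged sketch, assuming the defaults from the argument definitions below (the `Namespace` here is hypothetical test input, not what Megatron constructs):

```python
from argparse import Namespace

# Hypothetical parsed args mimicking what the FIM branch of validate_args sees.
args = Namespace(
    fim_data=True, mock_data=False, legacy_tokenizer=False,
    fim_rate=0.5, fim_spm_rate=0.5,
    fim_prefix_token='<fim_prefix>', fim_middle_token='<fim_middle>',
    fim_suffix_token='<fim_suffix>', fim_pad_token='<fim_pad>',
    fim_eod_token='<|endoftext|>',
)

extra_tokens = [args.fim_prefix_token, args.fim_middle_token,
                args.fim_suffix_token, args.fim_pad_token, args.fim_eod_token]

# Mirrors the assertions above: when --fim-data is set, mock data and legacy
# tokenizers are disallowed, both rates must be truthy, and all five special
# tokens must be specified.
fim_config_ok = (args.fim_data
                 and not args.mock_data
                 and not args.legacy_tokenizer
                 and bool(args.fim_rate) and bool(args.fim_spm_rate)
                 and all(token is not None for token in extra_tokens))
print(fim_config_ok)  # True
```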

# Deterministic mode
if args.deterministic_mode:
assert not args.use_flash_attn, "Flash attention can not be used in deterministic mode."
@@ -2915,6 +2929,27 @@ def _add_data_args(parser):
'If instead this argument is set, the training flow will treat all tokens '
'that share the same id as the pad token as true pad tokens, potentially '
'causing severe training instability.')
group.add_argument('--fim-data', action='store_true', help='Whether to use the FIM dataset.')
group.add_argument('--fim-rate', type=float, default=0.5,
help='Probability to convert a training sample into a FIM format.')
group.add_argument('--fim-spm-rate', type=float, default=0.5,
help='Probability that a FIM sample uses the SPM format over the PSM format.')
group.add_argument('--fim-split-sample', type=str, default=None,
help='String around which to split the sample for FIM.')
group.add_argument('--fim-fragment-rate', type=float, default=None,
help='Rate of FIM on each fragment when --fim-split-sample is not None.')
group.add_argument('--fim-no-prefix', type=str, default=None,
help='Do not apply FIM to fragments that start with this prefix')
group.add_argument('--fim-prefix-token', type=str, default='<fim_prefix>',
help='FIM prefix token')
group.add_argument('--fim-middle-token', type=str, default='<fim_middle>',
help='FIM middle token')
group.add_argument('--fim-suffix-token', type=str, default='<fim_suffix>',
help='FIM suffix token')
group.add_argument('--fim-pad-token', type=str, default='<fim_pad>',
help='FIM PAD token')
group.add_argument('--fim-eod-token', type=str, default='<|endoftext|>',
help='FIM EOD token')
return parser
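The new flags plug into a standard `argparse` group. A minimal sketch of how a subset of them parses (only the flags shown in the diff above are real; the two-flag command line is just an example):

```python
import argparse

# Re-create a subset of the FIM flags added above and parse a sample command line.
parser = argparse.ArgumentParser()
group = parser.add_argument_group('data')
group.add_argument('--fim-data', action='store_true',
                   help='Whether to use the FIM dataset.')
group.add_argument('--fim-rate', type=float, default=0.5)
group.add_argument('--fim-spm-rate', type=float, default=0.5)
group.add_argument('--fim-prefix-token', type=str, default='<fim_prefix>')

args = parser.parse_args(['--fim-data', '--fim-rate', '0.9'])
# argparse maps dashes to underscores: --fim-rate becomes args.fim_rate.
print(args.fim_data, args.fim_rate, args.fim_spm_rate)  # True 0.9 0.5
```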


34 changes: 34 additions & 0 deletions megatron/training/datasets/README.md
@@ -0,0 +1,34 @@
# Data Pipeline

## FIM dataset
> **Contributor:** This maybe doesn't necessarily belong in this PR, but we should at least note in the README that in order to use FIM training, your pretraining dataset needs to be preprocessed with the special tokens. We might need to add support for a data preprocessing script as well (could be a separate PR).

> **Contributor Author:** Agree, let's add it later in a separate PR.

> **Contributor:** Yeah, something like: `fim_tokens["prefix"]`, `fim_tokens["middle"]`, `fim_tokens["suffix"]`, `fim_tokens["pad"]`, `fim_tokens["eod"]`. These tokens must exist in the tokenizer vocab, and the dataset must be pre-tokenized with those tokens.


`GPTFIMDataset` extends Megatron-Core’s `GPTDataset` to support **Fill-in-the-Middle (FIM)** data augmentation.
It probabilistically converts samples into FIM format using configurable rates, with support for both PSM and SPM patterns, fragment-level splitting, and length-preserving output.

`GPTFIMDatasetConfig` is a configuration object extending `GPTDatasetConfig` that provides the settings needed to enable FIM preprocessing.

**Attributes**

- `rate`: Probability of converting a sample into a FIM example. A value of `1.0` means FIM is always applied; a value of `0.0` means FIM is never applied.
- `spm_rate`: Probability of using the SPM FIM pattern (vs PSM). The remaining probability (`1 - spm_rate`) selects the PSM (prefix-suffix-middle) pattern instead. For example, if `spm_rate = 0.3`: 30% SPM, 70% PSM.
- `extra_tokens`: Dictionary containing the FIM special tokens: {"prefix", "middle", "suffix", "pad", "eod"}.
- `split_sample`: Optional string around which samples are split before applying FIM. If provided, the input sequence is divided at every occurrence of this token, and FIM is applied independently to each fragment. `A B C <SPLIT_SAMPLE> D E F <SPLIT_SAMPLE> G H` -> `FIM(Fragment 1) <SPLIT_SAMPLE> FIM(Fragment 2) <SPLIT_SAMPLE> FIM(Fragment 3)`.
- `fragment_rate`: Probability of applying FIM to each fragment when split_sample is used.
- `no_prefix`: If the decoded sequence starts with this prefix, FIM is skipped.
> **Contributor:** `fim_skip_prefix` is better?

> **Contributor Author:** I'd leave it as it is, since it's the implementation from NeMo1.
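The fragment-level behavior described for `split_sample` and `fragment_rate` can be sketched on plain strings. A hedged illustration, not the real token-level implementation (`split_and_fim` and the `str.upper` stand-in transform are hypothetical names):

```python
import random

def split_and_fim(sample, split_sample, fragment_rate, fim_fn, rng):
    """Apply fim_fn independently to each fragment with probability
    fragment_rate; fragments are delimited by split_sample."""
    fragments = sample.split(split_sample)
    out = [fim_fn(f) if rng.random() < fragment_rate else f
           for f in fragments]
    return split_sample.join(out)

rng = random.Random(0)
# With fragment_rate=1.0 every fragment is transformed; str.upper stands in
# for the real FIM transform so the effect is visible.
print(split_and_fim("a b<S>c d<S>e f", "<S>", 1.0, str.upper, rng))
# A B<S>C D<S>E F
```

With `fragment_rate=0.0` no fragment is touched and the sample comes back unchanged, matching the gating semantics described above.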

`GPTFIMDataset` is a dataset class that loads token sequences from an `IndexedDataset` and applies FIM transformations before returning each sample.
> **Contributor:** How will the seq length change in this case? In many cases `IndexedDataset` has a constant seq length.

> **Contributor Author:** Seq length will not be changed.


**PSM Format**
```
[prefix_tok] prefix [suffix_tok] suffix [middle_tok] middle
```

**SPM Format**
```
[prefix_tok, suffix_tok] suffix [middle_tok] prefix middle
```
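The two layouts can be built from a `(prefix, middle, suffix)` split of a sample. A minimal string-level sketch using the default special tokens from the CLI flags (`to_psm`/`to_spm` are illustrative helper names, not part of the dataset API):

```python
# Default special tokens, as defined by the --fim-*-token flags.
PRE, MID, SUF = '<fim_prefix>', '<fim_middle>', '<fim_suffix>'

def to_psm(prefix, middle, suffix):
    # [prefix_tok] prefix [suffix_tok] suffix [middle_tok] middle
    return f"{PRE}{prefix}{SUF}{suffix}{MID}{middle}"

def to_spm(prefix, middle, suffix):
    # [prefix_tok, suffix_tok] suffix [middle_tok] prefix middle
    return f"{PRE}{SUF}{suffix}{MID}{prefix}{middle}"

print(to_psm("p", "m", "s"))  # <fim_prefix>p<fim_suffix>s<fim_middle>m
print(to_spm("p", "m", "s"))  # <fim_prefix><fim_suffix>s<fim_middle>pm
```

In both layouts the middle span is moved to the end, so the model learns to generate it conditioned on the surrounding prefix and suffix.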

**Special cases:**

- If the sequence starts with `no_prefix`, FIM is skipped.
- If FIM is not applied, the sample is returned unchanged.
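The two special cases amount to per-sample gating before any transformation. A hedged sketch of that control flow (`maybe_fim` and the pass-through `fim_fn` argument are illustrative, not the real method names):

```python
import random

def maybe_fim(sample, rate, no_prefix, fim_fn, rng):
    """Gate the FIM transform per sample, as described above."""
    if no_prefix is not None and sample.startswith(no_prefix):
        return sample        # protected prefix: FIM is skipped
    if rng.random() >= rate:
        return sample        # FIM not applied: sample returned unchanged
    return fim_fn(sample)    # otherwise apply the FIM transform

rng = random.Random(0)
# str.upper stands in for the real FIM transform; rate=1.0 forces application
# except where the no_prefix guard fires.
print(maybe_fim("# keep-me intact", 1.0, "# keep-me", str.upper, rng))
# '# keep-me intact'
print(maybe_fim("transform me", 1.0, "# keep-me", str.upper, rng))
# 'TRANSFORM ME'
```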