
Commit 7994405

add FIM dataset support (#2291)
Signed-off-by: dimapihtar <[email protected]>
1 parent: 29a810e

File tree

9 files changed: +866 additions, -20 deletions


megatron/core/tokenizers/text/libraries/null_tokenizer.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -25,6 +25,14 @@ def ids_to_text(self, ids):
         text = [str(x) for x in ids]
         return ' '.join(text)
 
+    def tokens_to_ids(self, tokens):
+        """Converts tokens to ids."""
+        return [int(x) for x in tokens]
+
+    def ids_to_tokens(self, ids):
+        """Converts ids to tokens."""
+        return [str(x) for x in ids]
+
     def offsets(self, ids: list[int], text: str) -> list[int]:
         """Returns offsets."""
         offsets, start_idx = [], 0
```
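
The two new methods are exact inverses on the null tokenizer's string-number "vocabulary". A minimal round-trip sketch of the behavior added above (standalone, not importing Megatron):

```python
# Round-trip sketch: the null tokenizer treats each token as the decimal
# string form of its id, so the two conversions invert each other cleanly.
tokens = ["101", "7", "42"]
ids = [int(x) for x in tokens]    # tokens_to_ids -> [101, 7, 42]
back = [str(x) for x in ids]      # ids_to_tokens -> ["101", "7", "42"]
assert back == tokens
```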

megatron/training/arguments.py

Lines changed: 35 additions & 0 deletions
```diff
@@ -1067,6 +1067,20 @@ def validate_args(args, defaults={}):
         any([args.train_data_path, args.valid_data_path, args.test_data_path]) \
             <= 1, "A single data source must be provided in training mode, else None"
 
+    if args.fim_data:
+        extra_tokens = [
+            args.fim_prefix_token,
+            args.fim_middle_token,
+            args.fim_suffix_token,
+            args.fim_pad_token,
+            args.fim_eod_token,
+        ]
+        assert not args.mock_data, "Mock dataset is not supported with FIM dataset."
+        assert not args.legacy_tokenizer, "FIM dataset is not supported with legacy tokenizers."
+        assert args.fim_rate, "--fim-rate should be specified."
+        assert args.fim_spm_rate, "--fim-spm-rate should be specified."
+        assert all(token is not None for token in extra_tokens), "FIM extra tokens should be specified."
+
     # Deterministic mode
     if args.deterministic_mode:
         assert not args.use_flash_attn, "Flash attention can not be used in deterministic mode."
```
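
This block gates every FIM prerequisite at startup: real (non-mock) data, a non-legacy tokenizer, both rates, and all five special tokens. A condensed, standalone restatement (using `SimpleNamespace` as a stand-in for the parsed args; this is not Megatron code):

```python
from types import SimpleNamespace

# A configuration that satisfies the FIM validation block above.
args = SimpleNamespace(
    fim_data=True, mock_data=False, legacy_tokenizer=False,
    fim_rate=0.5, fim_spm_rate=0.5,
    fim_prefix_token='<fim_prefix>', fim_middle_token='<fim_middle>',
    fim_suffix_token='<fim_suffix>', fim_pad_token='<fim_pad>',
    fim_eod_token='<|endoftext|>',
)
extra_tokens = [args.fim_prefix_token, args.fim_middle_token,
                args.fim_suffix_token, args.fim_pad_token, args.fim_eod_token]
assert not args.mock_data and not args.legacy_tokenizer
assert args.fim_rate and args.fim_spm_rate
assert all(token is not None for token in extra_tokens)
```

Note that `assert args.fim_rate` is a truthiness check, so it also rejects an explicit `--fim-rate 0.0`, since `0.0` is falsy.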
```diff
@@ -2915,6 +2929,27 @@ def _add_data_args(parser):
                        'If instead this argument is set, the training flow will treat all tokens '
                        'that share the same id as the pad token as true pad tokens, potentially '
                        'causing severe training instability.')
+    group.add_argument('--fim-data', action='store_true', help='Whether to use the FIM dataset.')
+    group.add_argument('--fim-rate', type=float, default=0.5,
+                       help='Probability to convert a training sample into a FIM format.')
+    group.add_argument('--fim-spm-rate', type=float, default=0.5,
+                       help='Probability that a FIM sample uses the SPM format over the PSM format.')
+    group.add_argument('--fim-split-sample', type=str, default=None,
+                       help='String around which to split the sample for FIM.')
+    group.add_argument('--fim-fragment-rate', type=float, default=None,
+                       help='Rate of FIM on each fragment when --fim-split-sample is not None.')
+    group.add_argument('--fim-no-prefix', type=str, default=None,
+                       help='Do not apply FIM to fragments that start with this prefix')
+    group.add_argument('--fim-prefix-token', type=str, default='<fim_prefix>',
+                       help='FIM prefix token')
+    group.add_argument('--fim-middle-token', type=str, default='<fim_middle>',
+                       help='FIM middle token')
+    group.add_argument('--fim-suffix-token', type=str, default='<fim_suffix>',
+                       help='FIM suffix token')
+    group.add_argument('--fim-pad-token', type=str, default='<fim_pad>',
+                       help='FIM PAD token')
+    group.add_argument('--fim-eod-token', type=str, default='<|endoftext|>',
+                       help='FIM EOD token')
     return parser
```

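As a quick check of the new flags, a stripped-down parser containing only a subset of them shows the defaults a run inherits when nothing is passed explicitly (a sketch, not the full Megatron parser):

```python
import argparse

# Only a few of the FIM flags from this diff, wired the same way.
parser = argparse.ArgumentParser()
group = parser.add_argument_group('data')
group.add_argument('--fim-data', action='store_true')
group.add_argument('--fim-rate', type=float, default=0.5)
group.add_argument('--fim-spm-rate', type=float, default=0.5)
group.add_argument('--fim-prefix-token', type=str, default='<fim_prefix>')

args = parser.parse_args(['--fim-data', '--fim-rate', '0.9'])
assert args.fim_data and args.fim_rate == 0.9
assert args.fim_spm_rate == 0.5              # default kept
assert args.fim_prefix_token == '<fim_prefix>'
```
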
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@ (new file)

# Data Pipeline

## FIM dataset

`GPTFIMDataset` extends Megatron-Core’s `GPTDataset` to support **Fill-in-the-Middle (FIM)** data augmentation.
It probabilistically converts samples into FIM format using configurable rates, with support for both PSM and SPM patterns, fragment-level splitting, and length-preserving output.

`GPTFIMDatasetConfig` is the configuration object extending `GPTDatasetConfig` that provides everything needed to enable FIM preprocessing.

**Attributes**

- `rate`: Probability of converting a sample into a FIM example. A value of `1.0` means FIM is always applied; a value of `0.0` means FIM is never applied.
- `spm_rate`: Probability of using the SPM FIM pattern (vs. PSM). The remaining probability (`1 - spm_rate`) selects the PSM (prefix-suffix-middle) pattern instead. For example, if `spm_rate = 0.3`: 30% SPM, 70% PSM.
- `extra_tokens`: Dictionary containing the FIM special tokens: {"prefix", "middle", "suffix", "pad", "eod"}.
- `split_sample`: Optional token around which samples are split before applying FIM. If provided, the input sequence is divided at every occurrence of this token, and FIM is applied independently to each fragment: `A B C <SPLIT_SAMPLE> D E F <SPLIT_SAMPLE> G H` -> `FIM(Fragment 1) <SPLIT_SAMPLE> FIM(Fragment 2) <SPLIT_SAMPLE> FIM(Fragment 3)`.
- `fragment_rate`: Probability of applying FIM to each fragment when `split_sample` is used.
- `no_prefix`: If the decoded sequence starts with this prefix, FIM is skipped.

`GPTFIMDataset` is the dataset class that loads token sequences from an `IndexedDataset` and applies the FIM transformation before returning each sample.

**PSM Format**
```
[prefix_tok] prefix [suffix_tok] suffix [middle_tok] middle
```

**SPM Format**
```
[prefix_tok, suffix_tok] suffix [middle_tok] prefix middle
```

**Special cases:**

- If the sequence starts with `no_prefix`, FIM is skipped.
- If FIM is not applied, the sample is returned unchanged.
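
To make the two layouts concrete, here is a self-contained sketch of the core rearrangement (a hypothetical `apply_fim` helper operating on strings for readability; the real `GPTFIMDataset` works on token ids and additionally keeps the output length equal to the input length, which this sketch does not attempt):

```python
import numpy as np

def apply_fim(tokens, rng, rate=0.5, spm_rate=0.5,
              prefix_tok="<fim_prefix>", middle_tok="<fim_middle>",
              suffix_tok="<fim_suffix>"):
    """Illustrative PSM/SPM rearrangement; not the Megatron implementation."""
    if rng.random() >= rate:
        return tokens  # with probability (1 - rate), leave the sample unchanged
    # Pick two cut points, splitting the sample into prefix / middle / suffix.
    lo, hi = sorted(rng.integers(0, len(tokens) + 1, size=2))
    prefix, middle, suffix = tokens[:lo], tokens[lo:hi], tokens[hi:]
    if rng.random() < spm_rate:
        # SPM: [prefix_tok, suffix_tok] suffix [middle_tok] prefix middle
        return [prefix_tok, suffix_tok] + suffix + [middle_tok] + prefix + middle
    # PSM: [prefix_tok] prefix [suffix_tok] suffix [middle_tok] middle
    return [prefix_tok] + prefix + [suffix_tok] + suffix + [middle_tok] + middle

rng = np.random.default_rng(0)
print(apply_fim(list("ABCDEFGH"), rng, rate=1.0, spm_rate=0.0))
```

With `rate=1.0` and `spm_rate=0.0` the call always produces the PSM layout; setting `spm_rate=1.0` always produces SPM.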
