
Commit 4d9ee92

committed
update unit tests
Signed-off-by: dimapihtar <[email protected]>
1 parent bd9d689 commit 4d9ee92

File tree

1 file changed (+40, −11 lines)


tests/unit_tests/data/test_fim_dataset.py

Lines changed: 40 additions & 11 deletions
@@ -1,5 +1,6 @@
 # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 
+import pytest
 import torch
 
 from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
@@ -9,7 +10,9 @@
 from tests.unit_tests.test_utilities import Utils
 
 
-def test_fim_gpt_dataset():
+@pytest.mark.parametrize("spm_rate", [0.0, 1.0])
+@pytest.mark.parametrize("split_sample", [None, "python"])
+def test_fim_gpt_dataset(spm_rate, split_sample):
     if torch.distributed.is_available():
         Utils.initialize_distributed()
         if torch.distributed.get_rank() == 0:
@@ -19,14 +22,22 @@ def test_fim_gpt_dataset():
             compile_helpers()
 
     tokenizer = MegatronTokenizer.from_pretrained(
-        metadata_path={"library": "null"}, vocab_size=131072
+        tokenizer_path="/opt/data/tokenizers/huggingface",
+        metadata_path={"library": "huggingface"},
+        additional_special_tokens=["<prefix>", "<middle>", "<suffix>", "<pad>", "<eod>"],
+        include_special_tokens=True,
     )
-    blend = get_blend_from_list(["/opt/data/datasets/train/test_text_document"])
-    extra_tokens = {"prefix": "777", "middle": "888", "suffix": "999", "pad": "666", "eod": "000"}
-    seq_length = 8
-    rate = 0.2
-    spm_rate = 0.2
-    fragment_rate = 0.5
+    blend = get_blend_from_list(["/home/data/fim/fim_text_document"])
+    extra_tokens = {
+        "prefix": "<prefix>",
+        "middle": "<middle>",
+        "suffix": "<suffix>",
+        "pad": "<pad>",
+        "eod": "<eod>",
+    }
+    seq_length = 32
+    rate = 1.0
+    fragment_rate = 1.0
     config = GPTFIMDatasetConfig(
         blend=blend,
         random_seed=1234,
@@ -40,18 +51,36 @@ def test_fim_gpt_dataset():
         rate=rate,
         spm_rate=spm_rate,
         fragment_rate=fragment_rate,
-        no_prefix="111214",
+        split_sample=split_sample,
     )
 
     datasets = BlendedMegatronDatasetBuilder(
         GPTFIMDataset, [10, 10, 10], lambda: True, config
     ).build()
 
+    prefix_id = tokenizer.tokenize("<prefix>")[1]
+    suffix_id = tokenizer.tokenize("<suffix>")[1]
+    middle_id = tokenizer.tokenize("<middle>")[1]
+
     dataset = datasets[0]
     assert dataset.fim_rate == rate
     assert dataset.fim_spm_rate == spm_rate
-    assert dataset.fragment_fim_rate == 0.5
-    assert dataset[0]["tokens"].tolist() == [343, 54365900, 77, 131072, 111214, 343, 54365900, 77]
+    assert dataset.fragment_fim_rate == fragment_rate
+
+    tokens = dataset[0]["tokens"].tolist()
+    if split_sample:
+        split_sample_id = tokenizer.tokenize(split_sample)[1]
+        split_sample_index = tokens.index(split_sample_id)
+        assert prefix_id == tokens[split_sample_index + 1]
+    if spm_rate == 0.0:
+        assert prefix_id == tokens[0]
+        assert suffix_id in tokens
+        assert middle_id in tokens
+        assert tokens.index(suffix_id) < tokens.index(middle_id)
+    else:
+        assert prefix_id == tokens[0]
+        assert suffix_id == tokens[1]
+        assert middle_id in tokens
 
 
 if __name__ == "__main__":
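
For context, a minimal sketch of the token orderings the new assertions expect. This is not part of the commit: the integer IDs are hypothetical placeholders, and the prefix/middle/suffix split of the document is assumed rather than taken from the real tokenizer or dataset.

# Sketch of the sample layouts the new assertions check (hypothetical IDs).
PREFIX, SUFFIX, MIDDLE = 1, 2, 3       # stand-ins for the special-token IDs
pre, mid, suf = [10, 11], [12], [13]   # a document split into three spans

# spm_rate == 0.0: the sample starts with the prefix token and the suffix
# token appears before the middle token.
psm = [PREFIX, *pre, SUFFIX, *suf, MIDDLE, *mid]
assert psm[0] == PREFIX
assert psm.index(SUFFIX) < psm.index(MIDDLE)

# spm_rate == 1.0: the sample starts with the prefix token immediately
# followed by the suffix token; the middle token appears later.
spm = [PREFIX, SUFFIX, *suf, MIDDLE, *pre, *mid]
assert spm[0] == PREFIX and spm[1] == SUFFIX and MIDDLE in spm

The test exercises both layouts through the spm_rate parametrization and, when split_sample is set, additionally checks that the prefix token directly follows the split token in the produced sequence.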
