# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

+ import pytest
import torch

from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from tests.unit_tests.test_utilities import Utils


- def test_fim_gpt_dataset():
+ @pytest.mark.parametrize("spm_rate", [0.0, 1.0])
+ @pytest.mark.parametrize("split_sample", [None, "python"])
+ def test_fim_gpt_dataset(spm_rate, split_sample):
    if torch.distributed.is_available():
        Utils.initialize_distributed()
        if torch.distributed.get_rank() == 0:
@@ -19,14 +22,22 @@ def test_fim_gpt_dataset():
        compile_helpers()

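+     # the FIM sentinel strings are registered as additional special tokens so the
+     # tokenizer can map each of them to its own id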
    tokenizer = MegatronTokenizer.from_pretrained(
-         metadata_path={"library": "null"}, vocab_size=131072
+         tokenizer_path="/opt/data/tokenizers/huggingface",
+         metadata_path={"library": "huggingface"},
+         additional_special_tokens=["<prefix>", "<middle>", "<suffix>", "<pad>", "<eod>"],
+         include_special_tokens=True,
    )
-     blend = get_blend_from_list(["/opt/data/datasets/train/test_text_document"])
-     extra_tokens = {"prefix": "777", "middle": "888", "suffix": "999", "pad": "666", "eod": "000"}
-     seq_length = 8
-     rate = 0.2
-     spm_rate = 0.2
-     fragment_rate = 0.5
+     blend = get_blend_from_list(["/home/data/fim/fim_text_document"])
+     extra_tokens = {
+         "prefix": "<prefix>",
+         "middle": "<middle>",
+         "suffix": "<suffix>",
+         "pad": "<pad>",
+         "eod": "<eod>",
+     }
+     seq_length = 32
+     rate = 1.0
+     fragment_rate = 1.0
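+     # rate=1.0 and fragment_rate=1.0 apply the FIM transform to every sample and every
+     # fragment, so the sentinel-token assertions below hold deterministically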
    config = GPTFIMDatasetConfig(
        blend=blend,
        random_seed=1234,
@@ -40,18 +51,36 @@ def test_fim_gpt_dataset():
        rate=rate,
        spm_rate=spm_rate,
        fragment_rate=fragment_rate,
-         no_prefix="111214",
+         split_sample=split_sample,
    )

    datasets = BlendedMegatronDatasetBuilder(
        GPTFIMDataset, [10, 10, 10], lambda: True, config
    ).build()

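+     # look up the ids the tokenizer assigns to the FIM sentinel tokens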
+     prefix_id = tokenizer.tokenize("<prefix>")[1]
+     suffix_id = tokenizer.tokenize("<suffix>")[1]
+     middle_id = tokenizer.tokenize("<middle>")[1]
+
    dataset = datasets[0]
    assert dataset.fim_rate == rate
    assert dataset.fim_spm_rate == spm_rate
-     assert dataset.fragment_fim_rate == 0.5
-     assert dataset[0]["tokens"].tolist() == [343, 54365900, 77, 131072, 111214, 343, 54365900, 77]
+     assert dataset.fragment_fim_rate == fragment_rate
+
+     tokens = dataset[0]["tokens"].tolist()
+     if split_sample:
+         split_sample_id = tokenizer.tokenize(split_sample)[1]
+         split_sample_index = tokens.index(split_sample_id)
+         assert prefix_id == tokens[split_sample_index + 1]
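+     # spm_rate=0.0 keeps PSM ordering (prefix sentinel first, suffix before middle);
+     # spm_rate=1.0 switches to SPM ordering, where the prefix and suffix sentinels lead the sample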
+     if spm_rate == 0.0:
+         assert prefix_id == tokens[0]
+         assert suffix_id in tokens
+         assert middle_id in tokens
+         assert tokens.index(suffix_id) < tokens.index(middle_id)
+     else:
+         assert prefix_id == tokens[0]
+         assert suffix_id == tokens[1]
+         assert middle_id in tokens


if __name__ == "__main__":