1+ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
13import os
24import random
35import sys
68from types import ModuleType , SimpleNamespace
79from typing import Any , Dict
810
9- import nltk
10- import pytest
11-
1211try :
1312 import boto3
1413 import botocore .exceptions as exceptions
1514except ModuleNotFoundError :
15+ # Create mock boto3 module
1616 boto3 = ModuleType ("boto3" )
17- sys .modules [boto3 .__name__ ] = boto3
17+
18+ # Create mock botocore.exceptions submodule
1819 exceptions = ModuleType ("botocore.exceptions" )
20+
21+ # Register the mock modules in sys.modules
22+ sys .modules [boto3 .__name__ ] = boto3
1923 sys .modules [exceptions .__name__ ] = exceptions
2024
2125try :
@@ -43,6 +47,8 @@ def __init__(self, offset: int, size: int):
4347 sys .modules [msc .__name__ ] = msc
4448 sys .modules [types_module .__name__ ] = types_module
4549
50+ import torch
51+
4652from megatron .core .datasets .indexed_dataset import (
4753 IndexedDataset ,
4854 ObjectStorageConfig ,
@@ -58,9 +64,11 @@ def __init__(self, offset: int, size: int):
5864 gpt2_merge ,
5965 gpt2_vocab ,
6066)
67+ from tests .unit_tests .test_utilities import Utils
68+
6169
6270##
63- # Overload client from boto3
71+ # Mock boto3
6472##
6573
6674
@@ -72,7 +80,8 @@ def __init__(self, *args: Any) -> None:
7280
def download_file(self, Bucket: str, Key: str, Filename: str) -> None:
    """Copy the local file ``/<Bucket>/<Key>`` to ``Filename``.

    Args:
        Bucket: Path component joined directly under ``/``.
        Key: Path of the file relative to ``Bucket``.
        Filename: Destination path; missing parent directories are created.
    """
    # Local import keeps this block self-contained.
    import shutil

    # A bare file name has an empty dirname, and os.makedirs("") raises.
    parent = os.path.dirname(Filename)
    if parent:
        os.makedirs(parent, exist_ok=True)

    remote_path = os.path.join("/", Bucket, Key)
    # Copy in-process instead of shelling out to `cp`: paths containing
    # spaces or shell metacharacters work, and a failed copy raises
    # immediately instead of being reported only by the check below.
    shutil.copy(remote_path, Filename)
    assert os.path.exists(Filename)
7786
7887 def upload_file (self , Filename : str , Bucket : str , Key : str ) -> None :
@@ -104,27 +113,28 @@ def close(self) -> None:
104113
105114
106115##
107- # Overload ClientError from botocore.exceptions
116+ # Mock botocore.exceptions
108117##
109118
110119
111120class _LocalClientError (Exception ):
112- """ " Local test client error"""
121+ """Local test client error"""
113122
114123 pass
115124
116125
117126setattr (exceptions , "ClientError" , _LocalClientError )
118127
119128##
120- # Mock multistorageclient module
129+ # Mock msc.open, msc.download_file, msc.resolve_storage_client
121130##
122131
123132
def _msc_download_file(remote_path, local_path):
    """Copy a file addressed by an MSC-style path to ``local_path``.

    ``MSC_PREFIX`` is stripped from ``remote_path`` and the remainder is
    resolved against ``/``. Missing parent directories of ``local_path``
    are created. This function is installed as ``msc.download_file``.
    """
    # Local import keeps this block self-contained.
    import shutil

    source = os.path.join("/", remote_path.removeprefix(MSC_PREFIX))

    # A bare file name has an empty dirname, and os.makedirs("") raises.
    parent = os.path.dirname(local_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    # Copy in-process instead of shelling out to `cp`, so paths containing
    # spaces or shell metacharacters are handled and failures raise.
    shutil.copy(source, local_path)
    assert os.path.exists(local_path)
128138
129139
130140def _msc_resolve_storage_client (path ):
@@ -134,28 +144,29 @@ def read(self, path, byte_range):
134144 f .seek (byte_range .offset )
135145 return f .read (byte_range .size )
136146
137- return StorageClient (), path .removeprefix ( MSC_PREFIX + "default" )
147+ return StorageClient (), os . path .join ( "/" , path . removeprefix ( MSC_PREFIX ) )
138148
139149
140150setattr (msc , "open" , open )
141151setattr (msc , "download_file" , _msc_download_file )
142152setattr (msc , "resolve_storage_client" , _msc_resolve_storage_client )
143153
144154
145- @pytest .mark .flaky
146- @pytest .mark .flaky_in_dev
147155def test_bin_reader ():
148- with tempfile . TemporaryDirectory () as temp_dir :
149- # set the default nltk data path
150- os . environ [ "NLTK_DATA" ] = os . path . join ( temp_dir , "nltk_data" )
151- nltk . data . path . append ( os . environ [ "NLTK_DATA" ])
156+ if torch . distributed . is_available () :
157+ Utils . initialize_distributed ()
158+ if torch . distributed . get_rank () != 0 :
159+ return
152160
161+ with tempfile .TemporaryDirectory () as temp_dir :
153162 path_to_raws = os .path .join (temp_dir , "sample_raws" )
154163 path_to_data = os .path .join (temp_dir , "sample_data" )
155- path_to_object_storage_cache = os .path .join (temp_dir , "object_storage_cache" )
164+ path_to_object_storage_cache_msc = os .path .join (temp_dir , "object_storage_cache_msc" )
165+ path_to_object_storage_cache_s3 = os .path .join (temp_dir , "object_storage_cache_s3" )
156166 os .mkdir (path_to_raws )
157167 os .mkdir (path_to_data )
158- os .mkdir (path_to_object_storage_cache )
168+ os .mkdir (path_to_object_storage_cache_msc )
169+ os .mkdir (path_to_object_storage_cache_s3 )
159170
160171 # create the dummy resources
161172 dummy_jsonl (path_to_raws )
@@ -195,27 +206,26 @@ def test_bin_reader():
195206 assert isinstance (indexed_dataset_mmap .bin_reader , _MMapBinReader )
196207
197208 indexed_dataset_msc = IndexedDataset (
198- MSC_PREFIX + "default" + prefix , # use the default profile to access the filesystem
209+ MSC_PREFIX + prefix . lstrip ( "/" ),
199210 multimodal = False ,
200211 mmap = False ,
201212 object_storage_config = ObjectStorageConfig (
202- path_to_idx_cache = path_to_object_storage_cache
213+ path_to_idx_cache = path_to_object_storage_cache_msc
203214 ),
204215 )
205216 assert isinstance (indexed_dataset_msc .bin_reader , _MultiStorageClientBinReader )
206217 assert len (indexed_dataset_msc ) == len (indexed_dataset_file )
207218 assert len (indexed_dataset_msc ) == len (indexed_dataset_mmap )
208219
209220 indexed_dataset_s3 = IndexedDataset (
210- S3_PREFIX + prefix ,
221+ S3_PREFIX + prefix . lstrip ( "/" ) ,
211222 multimodal = False ,
212223 mmap = False ,
213224 object_storage_config = ObjectStorageConfig (
214- path_to_idx_cache = path_to_object_storage_cache
225+ path_to_idx_cache = path_to_object_storage_cache_s3
215226 ),
216227 )
217228 assert isinstance (indexed_dataset_s3 .bin_reader , _S3BinReader )
218-
219229 assert len (indexed_dataset_s3 ) == len (indexed_dataset_file )
220230 assert len (indexed_dataset_s3 ) == len (indexed_dataset_mmap )
221231
@@ -226,6 +236,7 @@ def test_bin_reader():
226236 for idx in indices :
227237 assert (indexed_dataset_s3 [idx ] == indexed_dataset_file [idx ]).all ()
228238 assert (indexed_dataset_s3 [idx ] == indexed_dataset_mmap [idx ]).all ()
239+ assert (indexed_dataset_s3 [idx ] == indexed_dataset_msc [idx ]).all ()
229240
230241
231242if __name__ == "__main__" :
0 commit comments