66from types import ModuleType , SimpleNamespace
77from typing import Any , Dict
88
9- import nltk
10- import pytest
11-
129try :
1310 import boto3
1411 import botocore .exceptions as exceptions
1512except ModuleNotFoundError :
13+ # Create mock boto3 module
1614 boto3 = ModuleType ("boto3" )
17- sys .modules [boto3 .__name__ ] = boto3
15+
16+ # Create mock botocore.exceptions submodule
1817 exceptions = ModuleType ("botocore.exceptions" )
18+
19+ # Register the mock modules in sys.modules
20+ sys .modules [boto3 .__name__ ] = boto3
1921 sys .modules [exceptions .__name__ ] = exceptions
2022
2123try :
@@ -43,6 +45,8 @@ def __init__(self, offset: int, size: int):
4345 sys .modules [msc .__name__ ] = msc
4446 sys .modules [types_module .__name__ ] = types_module
4547
48+ import torch
49+
4650from megatron .core .datasets .indexed_dataset import (
4751 IndexedDataset ,
4852 ObjectStorageConfig ,
@@ -58,9 +62,11 @@ def __init__(self, offset: int, size: int):
5862 gpt2_merge ,
5963 gpt2_vocab ,
6064)
65+ from tests .unit_tests .test_utilities import Utils
66+
6167
6268##
63- # Overload client from boto3
69+ # Mock boto3
6470##
6571
6672
@@ -72,7 +78,8 @@ def __init__(self, *args: Any) -> None:
7278
def download_file(self, Bucket: str, Key: str, Filename: str) -> None:
    """Copy a local file that stands in for a downloaded object.

    The "remote" object is the local file at ``/<Bucket>/<Key>``; it is
    copied to ``Filename``.

    Args:
        Bucket: Path component(s) joined directly under the filesystem root.
        Key: Path of the object relative to ``Bucket``.
        Filename: Destination path; missing parent directories are created.

    Raises:
        OSError: If the source cannot be read or the destination written.
    """
    import shutil  # local import: only the copy below needs it

    # A bare file name has an empty dirname, which os.makedirs rejects.
    parent = os.path.dirname(Filename)
    if parent:
        os.makedirs(parent, exist_ok=True)

    remote_path = os.path.join("/", Bucket, Key)
    # Copy in-process instead of shelling out to ``cp``: paths containing
    # spaces or shell metacharacters are handled, and failures raise.
    shutil.copy(remote_path, Filename)
    assert os.path.exists(Filename)
7784
7885 def upload_file (self , Filename : str , Bucket : str , Key : str ) -> None :
@@ -104,27 +111,28 @@ def close(self) -> None:
104111
105112
106113##
107- # Overload ClientError from botocore.exceptions
114+ # Mock botocore.exceptions
108115##
109116
110117
111118class _LocalClientError (Exception ):
112- """ " Local test client error"""
119+ """Local test client error"""
113120
114121 pass
115122
116123
117124setattr (exceptions , "ClientError" , _LocalClientError )
118125
119126##
120- # Mock multistorageclient module
127+ # Mock msc.open, msc.download_file, msc.resolve_storage_client
121128##
122129
123130
def _msc_download_file(remote_path, local_path):
    """Local-filesystem replacement installed as ``msc.download_file``.

    ``MSC_PREFIX`` is stripped from ``remote_path`` and the remainder is
    re-rooted at ``/``; that local file is then copied to ``local_path``.

    Args:
        remote_path: Source path, optionally starting with ``MSC_PREFIX``.
        local_path: Destination path; missing parent directories are created.

    Raises:
        OSError: If the source cannot be read or the destination written.
    """
    import shutil  # local import: only the copy below needs it

    remote_path = os.path.join("/", remote_path.removeprefix(MSC_PREFIX))

    # A bare file name has an empty dirname, which os.makedirs rejects.
    parent = os.path.dirname(local_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    # Copy in-process instead of shelling out to ``cp``: paths containing
    # spaces or shell metacharacters are handled, and failures raise.
    shutil.copy(remote_path, local_path)
    assert os.path.exists(local_path)
128136
129137
130138def _msc_resolve_storage_client (path ):
@@ -134,28 +142,29 @@ def read(self, path, byte_range):
134142 f .seek (byte_range .offset )
135143 return f .read (byte_range .size )
136144
137- return StorageClient (), path .removeprefix ( MSC_PREFIX + "default" )
145+ return StorageClient (), os . path .join ( "/" , path . removeprefix ( MSC_PREFIX ) )
138146
139147
140148setattr (msc , "open" , open )
141149setattr (msc , "download_file" , _msc_download_file )
142150setattr (msc , "resolve_storage_client" , _msc_resolve_storage_client )
143151
144152
145- @pytest .mark .flaky
146- @pytest .mark .flaky_in_dev
147153def test_bin_reader ():
148- with tempfile . TemporaryDirectory () as temp_dir :
149- # set the default nltk data path
150- os . environ [ "NLTK_DATA" ] = os . path . join ( temp_dir , "nltk_data" )
151- nltk . data . path . append ( os . environ [ "NLTK_DATA" ])
154+ if torch . distributed . is_available () :
155+ Utils . initialize_distributed ()
156+ if torch . distributed . get_rank () != 0 :
157+ return
152158
159+ with tempfile .TemporaryDirectory () as temp_dir :
153160 path_to_raws = os .path .join (temp_dir , "sample_raws" )
154161 path_to_data = os .path .join (temp_dir , "sample_data" )
155- path_to_object_storage_cache = os .path .join (temp_dir , "object_storage_cache" )
162+ path_to_object_storage_cache_msc = os .path .join (temp_dir , "object_storage_cache_msc" )
163+ path_to_object_storage_cache_s3 = os .path .join (temp_dir , "object_storage_cache_s3" )
156164 os .mkdir (path_to_raws )
157165 os .mkdir (path_to_data )
158- os .mkdir (path_to_object_storage_cache )
166+ os .mkdir (path_to_object_storage_cache_msc )
167+ os .mkdir (path_to_object_storage_cache_s3 )
159168
160169 # create the dummy resources
161170 dummy_jsonl (path_to_raws )
@@ -195,27 +204,26 @@ def test_bin_reader():
195204 assert isinstance (indexed_dataset_mmap .bin_reader , _MMapBinReader )
196205
197206 indexed_dataset_msc = IndexedDataset (
198- MSC_PREFIX + "default" + prefix , # use the default profile to access the filesystem
207+ MSC_PREFIX + prefix . lstrip ( "/" ),
199208 multimodal = False ,
200209 mmap = False ,
201210 object_storage_config = ObjectStorageConfig (
202- path_to_idx_cache = path_to_object_storage_cache
211+ path_to_idx_cache = path_to_object_storage_cache_msc
203212 ),
204213 )
205214 assert isinstance (indexed_dataset_msc .bin_reader , _MultiStorageClientBinReader )
206215 assert len (indexed_dataset_msc ) == len (indexed_dataset_file )
207216 assert len (indexed_dataset_msc ) == len (indexed_dataset_mmap )
208217
209218 indexed_dataset_s3 = IndexedDataset (
210- S3_PREFIX + prefix ,
219+ S3_PREFIX + prefix . lstrip ( "/" ) ,
211220 multimodal = False ,
212221 mmap = False ,
213222 object_storage_config = ObjectStorageConfig (
214- path_to_idx_cache = path_to_object_storage_cache
223+ path_to_idx_cache = path_to_object_storage_cache_s3
215224 ),
216225 )
217226 assert isinstance (indexed_dataset_s3 .bin_reader , _S3BinReader )
218-
219227 assert len (indexed_dataset_s3 ) == len (indexed_dataset_file )
220228 assert len (indexed_dataset_s3 ) == len (indexed_dataset_mmap )
221229
@@ -226,6 +234,7 @@ def test_bin_reader():
226234 for idx in indices :
227235 assert (indexed_dataset_s3 [idx ] == indexed_dataset_file [idx ]).all ()
228236 assert (indexed_dataset_s3 [idx ] == indexed_dataset_mmap [idx ]).all ()
237+ assert (indexed_dataset_s3 [idx ] == indexed_dataset_msc [idx ]).all ()
229238
230239
231240if __name__ == "__main__" :
0 commit comments