    get_custom_data_collator,
    get_preprocessed_dataset,
)
-from QEfficient.finetune.utils.train_utils import get_longest_seq_length, print_model_size, train
+from QEfficient.finetune.utils.train_utils import (
+    get_longest_seq_length,
+    print_model_size,
+    print_trainable_parameters,
+    train,
+)
from QEfficient.utils._utils import login_and_download_hf_lm
-from QEfficient.utils.logging_utils import ft_logger as logger
+from QEfficient.utils.logging_utils import logger
+
+logger.setLevel(logging.INFO)


# Try importing QAIC-specific module, proceed without it if unavailable
try:
    import torch_qaic  # noqa: F401
except ImportError as e:
-    logger.warning(f"{e}. Moving ahead without these qaic modules.")
-
-logger.setLevel(logging.INFO)
+    logger.log_rank_zero(f"{e}. Moving ahead without these qaic modules.")


# Suppress all warnings
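The recurring change in this diff is replacing direct logger.info / logger.warning calls with logger.log_rank_zero, so that under DDP only one process emits each message. The helper itself lives in QEfficient.utils.logging_utils and is not shown here; a minimal sketch of the idea, with assumed names and signature rather than the actual implementation, might look like:

```python
import logging

import torch.distributed as dist

logger = logging.getLogger("finetune")  # stand-in for the project logger


def log_rank_zero(msg: str, level: int = logging.INFO) -> None:
    # Emit the message only from rank 0 so multi-process (DDP) runs do not
    # print duplicate lines; single-process runs log unconditionally.
    if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
        return
    logger.log(level, msg)
```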
@@ -121,7 +126,7 @@ def load_model_and_tokenizer(
    )

    if not hasattr(model, "base_model_prefix"):
-        raise RuntimeError("Given huggingface model does not have 'base_model_prefix' attribute.")
+        logger.raise_runtimeerror("Given huggingface model does not have 'base_model_prefix' attribute.")

    for param in getattr(model, model.base_model_prefix).parameters():
        param.requires_grad = False
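Bare raise RuntimeError(...) statements are likewise routed through logger.raise_runtimeerror. Its implementation is also outside this diff; assuming it reuses the log_rank_zero sketch above, it could be as simple as:

```python
import logging


def raise_runtimeerror(msg: str) -> None:
    # Record the failure once via the rank-zero helper, then raise on every
    # rank so that all processes stop with the same error.
    log_rank_zero(msg, logging.ERROR)
    raise RuntimeError(msg)
```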
@@ -146,7 +151,7 @@ def load_model_and_tokenizer(
    # If there is a mismatch between tokenizer vocab size and embedding matrix,
    # throw a warning and then expand the embedding matrix
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
-        logger.warning("Resizing the embedding matrix to match the tokenizer vocab size.")
+        logger.log_rank_zero("Resizing the embedding matrix to match the tokenizer vocab size.", logger.WARNING)
        model.resize_token_embeddings(len(tokenizer))

    # FIXME (Meet): Cover below line inside the logger once it is implemented.
@@ -162,7 +167,9 @@ def load_model_and_tokenizer(
        if hasattr(model, "supports_gradient_checkpointing") and model.supports_gradient_checkpointing:
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False})
        else:
-            raise RuntimeError("Given model doesn't support gradient checkpointing. Please disable it and run it.")
+            logger.raise_runtimeerror(
+                "Given model doesn't support gradient checkpointing. Please disable it and run it."
+            )

    model = apply_peft(model, train_config, peft_config_file, **kwargs)

@@ -197,7 +204,7 @@ def apply_peft(
    else:
        peft_config = generate_peft_config(train_config, peft_config_file, **kwargs)
        model = get_peft_model(model, peft_config)
-        model.print_trainable_parameters()
+        print_trainable_parameters(model)

    return model

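PEFT's built-in model.print_trainable_parameters() reports via a plain print, so each DDP process would emit it; that is presumably why it is swapped for a print_trainable_parameters helper imported from train_utils. That helper is not part of this diff either; a rough sketch, again assuming it reports through the rank-zero logger sketched earlier, could be:

```python
def print_trainable_parameters(model) -> None:
    # Count trainable vs. total parameters and report the ratio once (rank 0).
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    log_rank_zero(
        f"trainable params: {trainable:,} || all params: {total:,} "
        f"|| trainable%: {100 * trainable / total:.4f}"
    )
```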
@@ -222,7 +229,7 @@ def setup_dataloaders(
            - Length of longest sequence in the dataset.

    Raises:
-        ValueError: If validation is enabled but the validation set is too small.
+        RuntimeError: If validation is enabled but the validation set is too small.

    Notes:
        - Applies a custom data collator if provided by get_custom_data_collator.
@@ -246,12 +253,12 @@ def setup_dataloaders(
    # )
    ##
    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
-    logger.info(f"length of dataset_train = {len(dataset_train)}")
+    logger.log_rank_zero(f"Length of dataset_train = {len(dataset_train)}")

    # FIXME (Meet): Add custom data collator registration from the outside by the user.
    custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
    if custom_data_collator:
-        logger.info("custom_data_collator is used")
+        logger.log_rank_zero("Custom_data_collator is used")
        train_dl_kwargs["collate_fn"] = custom_data_collator

    # Create DataLoaders for the training and validation dataset
@@ -261,7 +268,7 @@ def setup_dataloaders(
        pin_memory=True,
        **train_dl_kwargs,
    )
-    logger.info(f"Num of Training Set Batches loaded = {len(train_dataloader)}")
+    logger.log_rank_zero(f"Number of Training Set Batches loaded = {len(train_dataloader)}")

    eval_dataloader = None
    if train_config.run_validation:
@@ -281,11 +288,11 @@ def setup_dataloaders(
            **val_dl_kwargs,
        )
        if len(eval_dataloader) == 0:
-            raise ValueError(
+            logger.raise_runtimeerror(
                f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
            )
        else:
-            logger.info(f"Num of Validation Set Batches loaded = {len(eval_dataloader)}")
+            logger.log_rank_zero(f"Number of Validation Set Batches loaded = {len(eval_dataloader)}")

    longest_seq_length, _ = get_longest_seq_length(
        torch.utils.data.ConcatDataset([train_dataloader.dataset, eval_dataloader.dataset])
@@ -329,7 +336,7 @@ def main(peft_config_file: str = None, **kwargs) -> None:

    # Create DataLoaders for the training and validation dataset
    train_dataloader, eval_dataloader, longest_seq_length = setup_dataloaders(train_config, dataset_config, tokenizer)
-    logger.info(
+    logger.log_rank_zero(
        f"The longest sequence length in the train data is {longest_seq_length}, "
        f"passed context length is {train_config.context_length} and overall model's context length is "
        f"{model.config.max_position_embeddings}"
@@ -340,7 +347,7 @@ def main(peft_config_file: str = None, **kwargs) -> None:
    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
    if train_config.enable_ddp:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[dist.get_rank()])
-    results = train(
+    _ = train(
        model,
        tokenizer,
        train_dataloader,
@@ -352,7 +359,7 @@ def main(peft_config_file: str = None, **kwargs) -> None:
    )
    if train_config.enable_ddp:
        dist.destroy_process_group()
-    return results
+    return


if __name__ == "__main__":