
Commit ce14058

mamtsing authored and quic-mamta committed
Update logging_utils and log for zero rank
Signed-off-by: Mamta Singh <[email protected]>
1 parent 7bf2b24 commit ce14058

File tree

7 files changed: +121 −114 lines changed

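Note: only six of the seven changed files are reproduced below; the seventh is presumably QEfficient/utils/logging_utils.py, which the commit title says was updated. Judging from the call sites in this diff — logger.log_rank_zero(msg), logger.log_rank_zero(msg, logger.WARNING), and logger.raise_runtimeerror(msg) — the new helpers might look roughly like the sketch below. This is a minimal reconstruction, assuming torch.distributed supplies the rank; everything except the two method names and the logger object is inferred, not taken from the commit.

# Hypothetical sketch of the rank-zero helpers in QEfficient/utils/logging_utils.py.
import logging

import torch.distributed as dist


class QEffLogger(logging.Logger):
    # Re-export the level so call sites can write logger.WARNING.
    WARNING = logging.WARNING

    def log_rank_zero(self, msg: str, level: int = logging.INFO) -> None:
        # In a DDP run, only rank 0 emits the message; single-process runs always log.
        if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
            return
        self.log(level, msg)

    def raise_runtimeerror(self, msg: str) -> None:
        # Log once on rank 0, then raise on every rank so all processes stop.
        self.log_rank_zero(msg, logging.ERROR)
        raise RuntimeError(msg)


logging.setLoggerClass(QEffLogger)
logger = logging.getLogger("QEfficient")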

QEfficient/cloud/finetune.py

Lines changed: 25 additions & 18 deletions
@@ -32,17 +32,22 @@
     get_custom_data_collator,
     get_preprocessed_dataset,
 )
-from QEfficient.finetune.utils.train_utils import get_longest_seq_length, print_model_size, train
+from QEfficient.finetune.utils.train_utils import (
+    get_longest_seq_length,
+    print_model_size,
+    print_trainable_parameters,
+    train,
+)
 from QEfficient.utils._utils import login_and_download_hf_lm
-from QEfficient.utils.logging_utils import ft_logger as logger
+from QEfficient.utils.logging_utils import logger
+
+logger.setLevel(logging.INFO)

 # Try importing QAIC-specific module, proceed without it if unavailable
 try:
     import torch_qaic  # noqa: F401
 except ImportError as e:
-    logger.warning(f"{e}. Moving ahead without these qaic modules.")
-
-logger.setLevel(logging.INFO)
+    logger.log_rank_zero(f"{e}. Moving ahead without these qaic modules.")


 # Suppress all warnings
@@ -121,7 +126,7 @@ def load_model_and_tokenizer(
     )

     if not hasattr(model, "base_model_prefix"):
-        raise RuntimeError("Given huggingface model does not have 'base_model_prefix' attribute.")
+        logger.raise_runtimeerror("Given huggingface model does not have 'base_model_prefix' attribute.")

     for param in getattr(model, model.base_model_prefix).parameters():
         param.requires_grad = False
@@ -146,7 +151,7 @@ def load_model_and_tokenizer(
     # If there is a mismatch between tokenizer vocab size and embedding matrix,
     # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
-        logger.warning("Resizing the embedding matrix to match the tokenizer vocab size.")
+        logger.log_rank_zero("Resizing the embedding matrix to match the tokenizer vocab size.", logger.WARNING)
         model.resize_token_embeddings(len(tokenizer))

     # FIXME (Meet): Cover below line inside the logger once it is implemented.
@@ -162,7 +167,9 @@ def load_model_and_tokenizer(
     if hasattr(model, "supports_gradient_checkpointing") and model.supports_gradient_checkpointing:
         model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False})
     else:
-        raise RuntimeError("Given model doesn't support gradient checkpointing. Please disable it and run it.")
+        logger.raise_runtimeerror(
+            "Given model doesn't support gradient checkpointing. Please disable it and run it."
+        )

     model = apply_peft(model, train_config, peft_config_file, **kwargs)
@@ -197,7 +204,7 @@ def apply_peft(
     else:
         peft_config = generate_peft_config(train_config, peft_config_file, **kwargs)
         model = get_peft_model(model, peft_config)
-        model.print_trainable_parameters()
+        print_trainable_parameters(model)

     return model
@@ -222,7 +229,7 @@ def setup_dataloaders(
         - Length of longest sequence in the dataset.

     Raises:
-        ValueError: If validation is enabled but the validation set is too small.
+        RuntimeError: If validation is enabled but the validation set is too small.

     Notes:
         - Applies a custom data collator if provided by get_custom_data_collator.
@@ -246,12 +253,12 @@ def setup_dataloaders(
     # )
     ##
     train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
-    logger.info(f"length of dataset_train = {len(dataset_train)}")
+    logger.log_rank_zero(f"Length of dataset_train = {len(dataset_train)}")

     # FIXME (Meet): Add custom data collator registration from the outside by the user.
     custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
     if custom_data_collator:
-        logger.info("custom_data_collator is used")
+        logger.log_rank_zero("Custom_data_collator is used")
         train_dl_kwargs["collate_fn"] = custom_data_collator

     # Create DataLoaders for the training and validation dataset
@@ -261,7 +268,7 @@ def setup_dataloaders(
         pin_memory=True,
         **train_dl_kwargs,
     )
-    logger.info(f"Num of Training Set Batches loaded = {len(train_dataloader)}")
+    logger.log_rank_zero(f"Number of Training Set Batches loaded = {len(train_dataloader)}")

     eval_dataloader = None
     if train_config.run_validation:
@@ -281,11 +288,11 @@ def setup_dataloaders(
             **val_dl_kwargs,
         )
         if len(eval_dataloader) == 0:
-            raise ValueError(
+            logger.raise_runtimeerror(
                 f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
             )
         else:
-            logger.info(f"Num of Validation Set Batches loaded = {len(eval_dataloader)}")
+            logger.log_rank_zero(f"Number of Validation Set Batches loaded = {len(eval_dataloader)}")

         longest_seq_length, _ = get_longest_seq_length(
             torch.utils.data.ConcatDataset([train_dataloader.dataset, eval_dataloader.dataset])
@@ -329,7 +336,7 @@ def main(peft_config_file: str = None, **kwargs) -> None:

     # Create DataLoaders for the training and validation dataset
     train_dataloader, eval_dataloader, longest_seq_length = setup_dataloaders(train_config, dataset_config, tokenizer)
-    logger.info(
+    logger.log_rank_zero(
         f"The longest sequence length in the train data is {longest_seq_length}, "
         f"passed context length is {train_config.context_length} and overall model's context length is "
         f"{model.config.max_position_embeddings}"
@@ -340,7 +347,7 @@ def main(peft_config_file: str = None, **kwargs) -> None:
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
     if train_config.enable_ddp:
         model = nn.parallel.DistributedDataParallel(model, device_ids=[dist.get_rank()])
-    results = train(
+    _ = train(
         model,
         tokenizer,
         train_dataloader,
@@ -352,7 +359,7 @@ def main(peft_config_file: str = None, **kwargs) -> None:
     )
     if train_config.enable_ddp:
         dist.destroy_process_group()
-    return results
+    return


 if __name__ == "__main__":

QEfficient/finetune/configs/training.py

Lines changed: 2 additions & 0 deletions
@@ -105,3 +105,5 @@ class TrainConfig:
     grad_scaler: bool = True
     dump_root_dir: str = "meta-llama-samsum-mismatches/step_"
    opByOpVerifier: bool = False
+
+    dump_logs: bool = True
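The new dump_logs flag is an ordinary annotated field with a default, so it should be overridable like any other TrainConfig knob. A minimal sketch, assuming TrainConfig is a dataclass (as the field syntax suggests) and accepts keyword overrides:

from QEfficient.finetune.configs.training import TrainConfig

# Hypothetical override: turn off log dumping while keeping every other default.
config = TrainConfig(dump_logs=False)
assert config.dump_logs is False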

QEfficient/finetune/dataset/custom_dataset.py

Lines changed: 7 additions & 8 deletions
@@ -8,7 +8,7 @@
 import importlib
 from pathlib import Path

-from QEfficient.utils.logging_utils import ft_logger as logger
+from QEfficient.utils.logging_utils import logger


 def load_module_from_py_file(py_file: str) -> object:
@@ -32,20 +32,19 @@ def get_custom_dataset(dataset_config, tokenizer, split: str):
     module_path, func_name = dataset_config.file, "get_custom_dataset"

     if not module_path.endswith(".py"):
-        raise ValueError(f"Dataset file {module_path} is not a .py file.")
+        logger.raise_runtimeerror(f"Dataset file {module_path} is not a .py file.")

     module_path = Path(module_path)
     if not module_path.is_file():
-        raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
+        logger.raise_runtimeerror(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")

     module = load_module_from_py_file(module_path.as_posix())
     try:
         return getattr(module, func_name)(dataset_config, tokenizer, split)
-    except AttributeError as e:
-        logger.error(
+    except AttributeError:
+        logger.raise_runtimeerror(
             f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()})."
         )
-        raise e


 def get_data_collator(dataset_processer, dataset_config):
@@ -55,11 +54,11 @@ def get_data_collator(dataset_processer, dataset_config):
     module_path, func_name = dataset_config.file, "get_data_collator"

     if not module_path.endswith(".py"):
-        raise ValueError(f"Dataset file {module_path} is not a .py file.")
+        logger.raise_runtimeerror(f"Dataset file {module_path} is not a .py file.")

     module_path = Path(module_path)
     if not module_path.is_file():
-        raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
+        logger.raise_runtimeerror(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")

     module = load_module_from_py_file(module_path.as_posix())
     try:
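Both loaders expect the user-supplied .py file to define a function with the same name as the loader (get_custom_dataset or get_data_collator) and matching signature, looked up via getattr after load_module_from_py_file. A minimal sketch of such a file, with hypothetical contents:

# my_dataset.py -- hypothetical user-supplied dataset module.
# QEfficient loads this file and calls these two entry points by name.

def get_custom_dataset(dataset_config, tokenizer, split: str):
    # Return a dataset for the requested split; here, a toy tokenized list.
    texts = ["hello world", "fine tuning"] if split == "train" else ["eval text"]
    return [tokenizer(t) for t in texts]


def get_data_collator(dataset_processer, dataset_config):
    # Optional: return a callable that collates samples into a batch,
    # or None to fall back to the default collator.
    return None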

QEfficient/finetune/eval.py

Lines changed: 2 additions & 2 deletions
@@ -109,13 +109,13 @@ def main(**kwargs):
         pin_memory=True,
         **val_dl_kwargs,
     )
-    logger.info(f"Num of Validation Set Batches loaded = {len(eval_dataloader)}")
+    logger.log_rank_zero(f"Num of Validation Set Batches loaded = {len(eval_dataloader)}")
     if len(eval_dataloader) == 0:
         raise ValueError(
             f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
         )
     else:
-        logger.info(f"Num of Validation Set Batches loaded = {len(eval_dataloader)}")
+        logger.log_rank_zero(f"Num of Validation Set Batches loaded = {len(eval_dataloader)}")

     model.to(device)
     _ = evaluation(model, train_config, eval_dataloader, None, tokenizer, device)

QEfficient/finetune/utils/train_utils.py

Lines changed: 34 additions & 25 deletions
@@ -19,7 +19,7 @@
 from tqdm import tqdm

 from QEfficient.finetune.configs.training import TrainConfig
-from QEfficient.utils.logging_utils import ft_logger as logger
+from QEfficient.utils.logging_utils import logger

 try:
     import torch_qaic  # noqa: F401
@@ -28,7 +28,7 @@
     import torch_qaic.utils as qaic_utils  # noqa: F401
     from torch.qaic.amp import GradScaler as QAicGradScaler
 except ImportError as e:
-    logger.warning(f"{e}. Moving ahead without these qaic modules.")
+    logger.log_rank_zero(f"{e}. Moving ahead without these qaic modules.")

 from torch.amp import GradScaler
@@ -110,22 +110,21 @@ def train(
     # Start the training loop
     for epoch in range(train_config.num_epochs):
         if loss_0_counter.item() == train_config.convergence_counter:
-            if (not train_config.enable_ddp) or (train_config.enable_ddp and local_rank == 0):
-                logger.info(
-                    f"Skipping epoch {epoch + 1} since loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps."
-                )
-                break
+            logger.log_rank_zero(
+                f"Skipping epoch {epoch + 1} since loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps."
+            )
+            break

         if train_config.use_peft and train_config.from_peft_checkpoint:
             intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
             if epoch < intermediate_epoch:
-                logger.info(f"Skipping epoch {epoch + 1} since fine tuning has already completed for it.")
+                logger.log_rank_zero(f"Skipping epoch {epoch + 1} since fine tuning has already completed for it.")
                 # to bring the count of train_step in sync with where it left off
                 total_train_steps += len(train_dataloader)
                 continue

-        logger.info(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
-        logger.info(f"train_config.max_train_step: {train_config.max_train_step}")
+        logger.log_rank_zero(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
+        logger.log_rank_zero(f"train_config.max_train_step: {train_config.max_train_step}")
         # stop when the maximum number of training steps is reached
         if max_steps_reached:
             break
@@ -152,7 +151,7 @@ def train(
                 # to bring the count of train_step in sync with where it left off
                 if epoch == intermediate_epoch and step == 0:
                     total_train_steps += intermediate_step
-                    logger.info(
+                    logger.log_rank_zero(
                         f"Skipping first {intermediate_step} steps for epoch {epoch + 1}, since fine tuning has already completed for it."
                     )
                 if epoch == intermediate_epoch and step < intermediate_step:
@@ -264,12 +263,11 @@ def train(
                     val_step_metric,
                     val_metric,
                 )
-            if (not train_config.enable_ddp) or (train_config.enable_ddp and local_rank == 0):
-                if loss_0_counter.item() == train_config.convergence_counter:
-                    logger.info(
-                        f"Loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps.Hence,stopping the fine tuning."
-                    )
-                    break
+            if loss_0_counter.item() == train_config.convergence_counter:
+                logger.log_rank_zero(
+                    f"Loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps.Hence,stopping the fine tuning."
+                )
+                break

         pbar.close()
         epoch_end_time = time.perf_counter() - epoch_start_time
@@ -328,15 +326,15 @@ def train(
         if train_config.run_validation:
             if eval_epoch_loss < best_val_loss:
                 best_val_loss = eval_epoch_loss
-                logger.info(f"best eval loss on epoch {epoch + 1} is {best_val_loss}")
+                logger.log_rank_zero(f"best eval loss on epoch {epoch + 1} is {best_val_loss}")
             val_loss.append(float(eval_epoch_loss))
             val_metric.append(float(eval_metric))
         if train_config.task_type == "seq_classification":
-            logger.info(
+            logger.log_rank_zero(
                 f"Epoch {epoch + 1}: train_acc={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
             )
         else:
-            logger.info(
+            logger.log_rank_zero(
                 f"Epoch {epoch + 1}: train_metric={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
             )

@@ -440,7 +438,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
     eval_metric = torch.exp(eval_epoch_loss)

     # Print evaluation metrics
-    logger.info(f"{eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")
+    logger.log_rank_zero(f"{eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")

     return eval_epoch_loss, eval_metric, val_step_loss, val_step_metric
@@ -467,12 +465,23 @@ def print_model_size(model, config) -> None:

     Args:
         model: The PyTorch model.
-        model_name (str): Name of the model.
+        config : Config of the model.
     """
-
-    logger.info(f"Model : {config.model_name}")
     total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    logger.info(f"{config.model_name} has {total_params / 1e6} Million params\n")
+    logger.log_rank_zero(f"{config.model_name} has {total_params / 1e6} Million params.")
+
+
+def print_trainable_parameters(model) -> None:
+    """
+    Print the number of trainable parameters, all params and percentage of trainable params.
+
+    Args:
+        model: The PyTorch model.
+    """
+    trainable_params, all_param = model.get_nb_trainable_parameters()
+    logger.log_rank_zero(
+        f"trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param:.4f}"
+    )


 def save_to_json(
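The new print_trainable_parameters helper replaces the direct call to PEFT's PeftModel.print_trainable_parameters with a rank-zero-aware equivalent; it relies on model.get_nb_trainable_parameters(), which PeftModel provides and which returns a (trainable, total) tuple. A hedged usage sketch, assuming model is the PEFT-wrapped model returned by apply_peft:

from QEfficient.finetune.utils.train_utils import print_trainable_parameters

print_trainable_parameters(model)
# Example rank-0 output (illustrative numbers only):
# trainable params: 4,194,304 || all params: 6,742,609,920 || trainable%: 0.0622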

QEfficient/utils/_utils.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ class DownloadRetryLimitExceeded(Exception):


 def login_and_download_hf_lm(model_name, *args, **kwargs):
-    logger.info(f"loading HuggingFace model for {model_name}")
+    logger.log_rank_zero(f"loading HuggingFace model for {model_name}")
     hf_token = kwargs.pop("hf_token", None)
     cache_dir = kwargs.pop("cache_dir", None)
     if hf_token is not None:
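For reference, login_and_download_hf_lm logs on rank zero and pops the optional hf_token and cache_dir kwargs before performing the login and download (the rest of the function is not shown in this hunk). A hypothetical call, with placeholder values:

from QEfficient.utils._utils import login_and_download_hf_lm

# hf_token is optional; cache_dir is a placeholder path.
# Presumably returns the local model path (return value not shown in this hunk).
model_path = login_and_download_hf_lm("gpt2", cache_dir="/tmp/hf_cache")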
