diff --git a/src/chronos/chronos2/pipeline.py b/src/chronos/chronos2/pipeline.py
index 3eddcd7f..e99d8d93 100644
--- a/src/chronos/chronos2/pipeline.py
+++ b/src/chronos/chronos2/pipeline.py
@@ -114,6 +114,7 @@ def fit(
         finetuned_ckpt_name: str = "finetuned-ckpt",
         callbacks: list["TrainerCallback"] | None = None,
         remove_printer_callback: bool = False,
+        disable_data_parallel: bool = True,
         **extra_trainer_kwargs,
     ) -> "Chronos2Pipeline":
         """
@@ -158,6 +159,8 @@ def fit(
         callbacks
             A list of `TrainerCallback`s which will be forwarded to the HuggingFace `Trainer`
         remove_printer_callback
             If True, all instances of `PrinterCallback` are removed from callbacks
+        disable_data_parallel
+            If True, ensures that DataParallel is disabled and training happens on a single GPU
         **extra_trainer_kwargs
             Extra kwargs are directly forwarded to `TrainingArguments`
@@ -319,6 +322,11 @@ def fit(

         training_args = TrainingArguments(**training_kwargs)

+        if disable_data_parallel and not use_cpu:
+            # This is a hack to disable the default `transformers` behavior of using DataParallel
+            training_args._n_gpu = 1
+            assert training_args.n_gpu == 1  # Ensure that the hack worked
+
         trainer = Chronos2Trainer(
             model=model,
             args=training_args,
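
For context, a minimal standalone sketch of the mechanism this hack relies on (not part of the diff; the `output_dir` value is illustrative). In recent `transformers` releases, `TrainingArguments.n_gpu` is a public property backed by the private `_n_gpu` attribute, and `Trainer` only wraps the model in `torch.nn.DataParallel` when `args.n_gpu > 1`, so forcing the attribute to 1 skips the wrap:

    # Sketch under the assumption above; `_n_gpu` is private API and may
    # change across `transformers` versions, which the assert guards against.
    from transformers import TrainingArguments

    args = TrainingArguments(output_dir="/tmp/ckpt")  # illustrative output dir

    _ = args.n_gpu          # accessing the property runs device setup and populates `args._n_gpu`
    args._n_gpu = 1         # the hack: report a single GPU regardless of how many are visible
    assert args.n_gpu == 1  # Trainer will now skip the nn.DataParallel wrap

The usual alternative is restricting device visibility via `CUDA_VISIBLE_DEVICES`, but that must be set before CUDA is initialized and affects the whole process, whereas overriding `_n_gpu` is scoped to this one `Trainer`. The trade-off is the dependence on a private attribute, hence the `assert` in the diff to fail loudly rather than silently fall back to DataParallel.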