src/chronos/chronos2/pipeline.py (8 additions, 0 deletions)
@@ -114,6 +114,7 @@ def fit(
    finetuned_ckpt_name: str = "finetuned-ckpt",
    callbacks: list["TrainerCallback"] | None = None,
    remove_printer_callback: bool = False,
    disable_data_parallel: bool = True,
    **extra_trainer_kwargs,
) -> "Chronos2Pipeline":
    """
@@ -158,6 +159,8 @@ def fit(
        A list of `TrainerCallback`s which will be forwarded to the HuggingFace `Trainer`
    remove_printer_callback
        If True, all instances of `PrinterCallback` are removed from callbacks
    disable_data_parallel
        If True, ensures that DataParallel is disabled and training happens on a single GPU
    **extra_trainer_kwargs
        Extra kwargs are directly forwarded to `TrainingArguments`

@@ -319,6 +322,11 @@ def fit(

    training_args = TrainingArguments(**training_kwargs)

    if disable_data_parallel and not use_cpu:
        # This is a hack to disable the default `transformers` behavior of using DataParallel
        training_args._n_gpu = 1
        assert training_args.n_gpu == 1  # Ensure that the hack worked
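
For context on why the hack works: the HuggingFace `Trainer` switches to `torch.nn.DataParallel` whenever `args.n_gpu > 1`, and the public `n_gpu` property simply returns the cached `_n_gpu` value populated during device setup. A minimal standalone sketch of the same trick, outside the pipeline (the `output_dir` value is just a placeholder):

from transformers import TrainingArguments

args = TrainingArguments(output_dir="/tmp/chronos2-sketch")

# Reading the property triggers transformers' device setup and caches the GPU count.
print("GPUs the Trainer would use:", args.n_gpu)

# The hack from the patch: overwrite the cached count so the Trainer sees a
# single device and does not wrap the model in torch.nn.DataParallel.
if args.n_gpu > 1:
    args._n_gpu = 1

assert args.n_gpu <= 1
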
Contributor:

There are some branches where `_n_gpu` is set to 0 (e.g. on the CPU):
https://github.com/huggingface/transformers/blob/40dc11cd3eb4126652aa41ef8272525affd4a636/src/transformers/training_args.py#L1778
Are we sure we don't break that case? Should we instead set either

training_args._n_gpu = min(1, training_args._n_gpu)

or

if disable_data_parallel and torch.cuda.device_count() > 1:
    training_args._n_gpu = min(1, training_args._n_gpu)
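
The difference between the two options only matters when `_n_gpu` starts at 0 (CPU-only runs): plain assignment would raise it to 1, while `min(1, ...)` leaves it at 0. A toy comparison with illustrative values:

# CPU-only, single GPU, and multi-GPU device counts
for n in (0, 1, 4):
    print(n, "force:", 1, "clamp:", min(1, n))
# force always yields 1, even where transformers expects 0;
# clamp yields 0, 1, 1: it caps the count but never raises it.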

Contributor (author):

I added a `not use_cpu` guard. Does that look good?

Contributor:

Looks good to me


    trainer = Chronos2Trainer(
        model=model,
        args=training_args,