Commit 7bc9923

NouamaneTazi, c8ef, eliebak, grewalsk, and Kabir Grewal authored
SmolLM3 training 🚀 (#376)
* can only merge to main from dev
* Fix UnBoundLocalError in `clm_collator.py` (#339)
* Update clm_collator.py
* can only merge to main from dev (#348)
---------
Co-authored-by: Nouamane Tazi <[email protected]>
* fix init and init scaling factor and run evals in background (#349)
* InitScalingMethod
* InitScalingMethod
* run evals in background (#352)
* eval
* try adding lightevalrunner to trainer
* amend
* amend
* amend
* amend
* amend
* amend
* .
* amend
* amend
* .
* qos to low
* add nanotron_path
* some fix: logs, and config
* cp instead of sync
* eval_interval
* serialize sanity checks
* add output dir and s3_save path in the config
* fix s3 only if define
* fixes
---------
Co-authored-by: elie <[email protected]>
Co-authored-by: “eliebak” <[email protected]>
---------
Co-authored-by: elie <[email protected]>
Co-authored-by: “eliebak” <[email protected]>
* [Feature] Implement CUDA event-based timing for improved GPU performa… (#346)
* [Feature] Implement CUDA event-based timing for improved GPU performance measurement
* can only merge to main from dev (#348)
* Fix timer decorator logic: Support both CPU and CUDA timers and update docs
* Fix timer decorator logic: support both CPU and CUDA; update docs
---------
Co-authored-by: Kabir Grewal <[email protected]>
Co-authored-by: Nouamane Tazi <[email protected]>
Co-authored-by: Kabir Grewal <[email protected]>
* amend previous pr (#354)
* MoE without token dropping (#355)
* can only merge to main from dev (#348)
* move moe from qwen modeling to src/nn
* add groupedmlp
* add token permute and unpermute
* fix num_tokens_per_expert counting < num_experts
* fix init and init scaling factor and run evals in background (#353)
* can only merge to main from dev
* Fix UnBoundLocalError in `clm_collator.py` (#339)
* Update clm_collator.py
* can only merge to main from dev (#348)
---------
Co-authored-by: Nouamane Tazi <[email protected]>
* fix init and init scaling factor and run evals in background (#349)
* InitScalingMethod
* InitScalingMethod
* run evals in background (#352)
* eval
* try adding lightevalrunner to trainer
* amend
* amend
* amend
* amend
* amend
* amend
* .
* amend
* amend
* .
* qos to low
* add nanotron_path
* some fix: logs, and config
* cp instead of sync
* eval_interval
* serialize sanity checks
* add output dir and s3_save path in the config
* fix s3 only if define
* fixes
---------
Co-authored-by: elie <[email protected]>
Co-authored-by: “eliebak” <[email protected]>
---------
Co-authored-by: elie <[email protected]>
Co-authored-by: “eliebak” <[email protected]>
---------
Co-authored-by: Connector Switch <[email protected]>
Co-authored-by: elie <[email protected]>
Co-authored-by: “eliebak” <[email protected]>
* inference qwen moe seems to work inference seems good rn
* update readme
* fix router's weight initialization and wrong hidden size for non-moe mlp in qwen
* add source for router weight and router logits in float32
* fixes
* .
* .
* add parametrize grouped mlp in column and row linear
* add logging per-param grad norm
* fix conversation fail due to buffer on cpu
* config_qwen
* .
* .
* fix moe convert config
---------
Co-authored-by: Nouamane Tazi <[email protected]>
Co-authored-by: Connector Switch <[email protected]>
Co-authored-by: elie <[email protected]>
Co-authored-by: “eliebak” <[email protected]>
Co-authored-by: zzhhjjj <[email protected]>
Co-authored-by: nouamanetazi <[email protected]>
* Nouamane/lighteval (#356)
* InitScalingMethod
* InitScalingMethod
* eval
* try adding lightevalrunner to trainer
* amend
* amend
* amend
* amend
* amend
* amend
* .
* amend
* amend
* .
* qos to low
* add nanotron_path
* some fix: logs, and config
* cp instead of sync
* eval_interval
* serialize sanity checks
* add output dir and s3_save path in the config
* fix s3 only if define
* fixes
* add requeue
* add wandb with lighteval and fix eval interval
* fix this little space :(
* folder_path should always have s3 when using s3 (fix consumed tokens issue)
* config qwen
* .
---------
Co-authored-by: elie <[email protected]>
Co-authored-by: “eliebak” <[email protected]>
* SmoLM3 training 🚀 (#375)
* InitScalingMethod
* InitScalingMethod
* eval
* try adding lightevalrunner to trainer
* amend
* amend
* amend
* amend
* amend
* amend
* .
* amend
* amend
* .
* qos to low
* add nanotron_path
* some fix: logs, and config
* cp instead of sync
* eval_interval
* serialize sanity checks
* add output dir and s3_save path in the config
* fix s3 only if define
* fixes
* add requeue
* add wandb with lighteval and fix eval interval
* fix this little space :(
* folder_path should always have s3 when using s3 (fix consumed tokens issue)
* fix resuming with new data mixture
* offsets must be in samples not tokens
* sanity check local files when dataset_read_path
* better error for new stage
* rmsnorm
* sliding window
* causal SWA
* Revert "rmsnorm" (this reverts commit 17dad0a)
* rope_seq_len_interpolation_factor
* logmixin for intermediate tensors + CP + consumed_token shenanigans when resuming training (#365)
* logmixin
* context parallelism (llama3 ring attn) + consumed_token shenanigans (#366)
* training works
* llama3 ring attn
* llama3 ring attn
* llama3 ring attn
* fix position_ids (make them global)
* rope_seq_len_interpolation_factor assert
* .
* .
* fix rope and cp_pg
* fixed consumed_tokens log
---------
Co-authored-by: elie <[email protected]>
Co-authored-by: “eliebak” <[email protected]>
---------
Co-authored-by: Connector Switch <[email protected]>
Co-authored-by: elie <[email protected]>
Co-authored-by: “eliebak” <[email protected]>
Co-authored-by: grewalsk <[email protected]>
Co-authored-by: Kabir Grewal <[email protected]>
Co-authored-by: Kabir Grewal <[email protected]>
Co-authored-by: XλRI-U5 <[email protected]>
Co-authored-by: zzhhjjj <[email protected]>
1 parent c737f00 commit 7bc9923

42 files changed, +2637 -319 lines

README.md

Lines changed: 2 additions & 1 deletion
@@ -100,7 +100,7 @@ CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=8 run_train.py --config-
 The model will be saved in the `checkpoints` directory as specified in the config file.

 > [!NOTE]
-> You can use `examples/config_tiny_llama.py` to generate your own training config
+> You can use `examples/config_tiny_llama.py` to generate your own training config

 For detailed instructions on training your first model, check out our [Your First Training guide](docs/your-first-training.md). For multi-node training with Slurm, see our [Multi-Node Training guide](docs/multi-node-training.md).

@@ -175,6 +175,7 @@ We currently support the following features:
 - [x] Custom module checkpointing for large models
 - [x] Spectral µTransfer parametrization for scaling up neural networks
 - [x] Mamba example
+- [x] CUDA event-based timing for accurate GPU performance measurement

 And we have on our roadmap:
 - [ ] FP8 training

docs/cuda_event_timing.md

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
# CUDA Event-Based Timing in Nanotron

## Overview

Nanotron now uses CUDA events for timing GPU operations instead of CPU-based timing with `time.time()`. This change provides several benefits:

1. **More accurate measurement of GPU execution time**: CUDA events are recorded directly on the GPU timeline, providing more precise timing of GPU operations.
2. **Reduced need for explicit CUDA synchronization**: CPU-based timing requires synchronization between CPU and GPU to get accurate measurements, which can introduce overhead and affect performance.
3. **Lower overhead**: CUDA event-based timing has minimal impact on the execution of GPU operations.
4. **Better performance monitoring**: More accurate timing leads to better performance analysis and optimization.

## Implementation Details

The implementation uses `torch.cuda.Event` with `enable_timing=True` to create start and end events that are recorded on the GPU timeline. The elapsed time is then calculated using `start_event.elapsed_time(end_event)`, which returns the time in milliseconds.

### Key Changes

1. **Default Timer Type**: The default timer type in `nanotron/src/nanotron/logging/timers.py` has been changed from `TimerType.CPU` to `TimerType.CUDA`.

2. **Iteration Timing**: The iteration timing in `trainer.py` now uses CUDA events instead of `time.time()`.

3. **Synchronization Control**: By default, CUDA event-based timers do not force synchronization unless explicitly requested with `cuda_sync=True`.

## Usage

### Basic Usage

```python
# Create and use a CUDA timer (default)
with nanotron_timer("my_operation"):
    # Your GPU operation here
    ...

# Explicitly specify CUDA timing
with nanotron_timer("my_operation", timer_type="cuda"):
    # Your GPU operation here
    ...

# For CPU-only operations, you can still use CPU-based timing
with nanotron_timer("cpu_operation", timer_type="cpu"):
    # Your CPU operation here
    ...

# As a decorator with default CUDA timing
@nanotron_timer
def my_function():
    # Your GPU operation here
    ...

# As a decorator with custom name
@nanotron_timer("custom_name")
def my_function():
    # Your GPU operation here
    ...

# As a decorator with CPU timing
@nanotron_timer(timer_type=TimerType.CPU)
def my_cpu_function():
    # Your CPU operation here
    ...
```

### Advanced Usage

```python
# Start and end a timer manually
timer = nanotron_timer("my_operation")
timer.start()
# Your operation here
timer.end()

# Get the elapsed time in seconds
elapsed_time = timer.elapsed

# Get the total time across all calls
total_time = timer.total_time

# Get the average time per call
avg_time = timer.average_time
```

## Considerations

1. **Synchronization**: By default, CUDA event-based timers do not force synchronization to avoid overhead. If you need more accurate timing at the cost of performance, you can set `cuda_sync=True`.

2. **Units**: CUDA events measure time in milliseconds, but the timer API converts this to seconds for consistency with the previous CPU-based timing.

3. **Fallback**: If CUDA is not available, the timer will automatically fall back to CPU-based timing.

## Performance Impact

Using CUDA events for timing instead of CPU-based timing with synchronization can significantly reduce overhead, especially in distributed training scenarios with thousands of GPUs.

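For orientation, here is a minimal standalone sketch of the raw CUDA-event pattern the document above describes, written against plain PyTorch rather than the nanotron timer API (the `time_gpu_op` helper and the matmul workload are illustrative assumptions, not code from this commit):

```python
import torch

def time_gpu_op(fn, warmup: int = 3, iters: int = 10) -> float:
    """Time a GPU callable with CUDA events; returns average seconds per call."""
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()                       # recorded on the GPU timeline, no host sync needed here
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()             # sync only when reading the measurement
    return start.elapsed_time(end) / 1000.0 / iters  # elapsed_time() reports milliseconds

if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device="cuda")
    print(f"matmul: {time_gpu_op(lambda: x @ x) * 1e3:.3f} ms")
```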
examples/config_qwen.py

Lines changed: 15 additions & 7 deletions
@@ -30,7 +30,7 @@
     "410m": (24, 1024, 16, 16, 4096),  # ~410M params
     # Small to medium models
     "1b": (16, 2048, 16, 16, 5632),  # ~1B params
-    "3b": (28, 2048, 16, 2, 11008),  # ~3B params
+    "3b": (36, 2048, 16, 4, 11008),  # ~3B params
     # Standard sizes
     "7b": (32, 4096, 32, 32, 11008),  # ~7B params
     "13b": (40, 5120, 40, 40, 13824),
@@ -47,7 +47,7 @@ def get_args():
     parser.add_argument(
         "--model",
         choices=MODEL_SIZES.keys(),
-        default="custom",
+        default="3b",
         help="Model size to generate config for (e.g., 7b, 13b)",
     )
     parser.add_argument(
@@ -76,6 +76,10 @@ def get_args():
     tokens_group.add_argument("--mbs", type=int, default=3, help="Micro batch size")
     tokens_group.add_argument("--acc", type=int, default=1, help="Batch accumulation per replica")

+    # checkpoints
+    checkpoints_group = parser.add_argument_group("checkpoints")
+    checkpoints_group.add_argument("--ckpt-save", type=int, default=10, help="Checkpoint save interval")
+
     args = parser.parse_args()
     return args
@@ -108,7 +112,7 @@ def get_model_config(model_size: str) -> Qwen2Config:
         is_qwen2_config=True,
         pad_token_id=None,
         _attn_implementation="flash_attention_2",
-        sliding_window_size=20,
+        _use_doc_masking=True,
     )

@@ -154,7 +158,7 @@ def calculate_parameters(model_config: Qwen2Config) -> str:

 def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config:
     learning_rate = LRSchedulerArgs(
-        learning_rate=3e-4, lr_warmup_steps=2, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=1e-5
+        learning_rate=3e-4, lr_warmup_steps=2000, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=0
     )
     parallelism = ParallelismArgs(
         dp=args.dp,
@@ -175,7 +179,7 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config
     )
     optimizer = OptimizerArgs(
         zero_stage=args.zero,
-        weight_decay=0.01,
+        weight_decay=0.1,
         clip_grad=1.0,
         accumulate_grad_in_fp32=True,
         learning_rate_scheduler=learning_rate,
@@ -192,7 +196,7 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config

     return Config(
         general=GeneralArgs(project="debug", run=args.run, seed=seed, ignore_sanity_checks=args.no_sanity),
-        checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=10),
+        checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=args.ckpt_save),
         parallelism=parallelism,
         model=ModelArgs(init_method=RandomInit(std=0.025), model_config=model_config),
         # tokenizer=TokenizerArgs("HuggingFaceTB/cosmo2-tokenizer"),
@@ -219,7 +223,11 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config
     world_size = args.dp * args.tp * args.pp * args.cp
     if world_size <= 8:
         print(
-            f"CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node={world_size} run_train.py --config-file {args.out}"
+            f"ENABLE_TIMERS=1 DEBUG_CPU=1 STATS_SAMPLING_INTERVAL_IN_SEC=1 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node={world_size} run_train.py --config-file {args.out}"
         )
+        print("You can also use environment variables for more debugging:")
+        print(" - ENABLE_TIMERS=1: Enable detailed timing information")
+        print(" - DEBUG_CPU=1: Log CPU and memory usage statistics")
+        print(" - STATS_SAMPLING_INTERVAL_IN_SEC=1: Set sampling interval for metrics collection")
     else:
         print("Checkout slurm_launcher.py to launch a multi-node job")

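As a quick sanity check on the new `"3b"` entry above, whose tuple reads (num_layers, hidden_size, num_attention_heads, num_key_value_heads, intermediate_size) = (36, 2048, 16, 4, 11008), here is a rough parameter estimate assuming a Qwen2-style GQA attention block, gated MLP, and tied embeddings (an illustrative sketch, not the repo's `calculate_parameters` helper):

```python
def approx_params(layers=36, hidden=2048, heads=16, kv_heads=4,
                  intermediate=11008, vocab=128256) -> float:
    """Back-of-the-envelope parameter count in billions (norms and biases ignored)."""
    head_dim = hidden // heads
    # attention: q and o projections are hidden x hidden, k and v are hidden x (kv_heads * head_dim)
    attn = 2 * hidden * hidden + 2 * hidden * kv_heads * head_dim
    # gated MLP: gate, up, and down projections
    mlp = 3 * hidden * intermediate
    # tied input/output embedding counted once
    emb = vocab * hidden
    return (layers * (attn + mlp) + emb) / 1e9

print(f"~{approx_params():.2f}B parameters")  # ≈ 3.07B, consistent with the "~3B params" comment
```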
examples/config_qwen.yaml

Lines changed: 10 additions & 10 deletions
@@ -1,5 +1,5 @@
 checkpoints:
-  checkpoint_interval: 10
+  checkpoint_interval: 100000
   checkpoints_path: checkpoints
   checkpoints_path_is_shared_file_system: false
   load_lr_scheduler: true
@@ -30,9 +30,9 @@ data_stages:
 general:
   benchmark_csv_path: null
   consumed_train_samples: null
-  ignore_sanity_checks: false
+  ignore_sanity_checks: true
   project: debug
-  run: qwen_20250410_014907_16027793
+  run: qwen_20250424_120835_16423158
   seed: 42
   step: null
 lighteval: null
@@ -45,6 +45,7 @@ model:
   ddp_bucket_cap_mb: 25
   dtype: bfloat16
   init_method:
+    scaling_method: NUM_LAYERS
     std: 0.025
   make_vocab_size_divisible_by: 1
   model_config:
@@ -58,23 +59,23 @@ model:
     eos_token_id: 2
     flex_attention_mask: null
     hidden_act: silu
-    hidden_size: 256
+    hidden_size: 2048
     initializer_range: 0.02
-    intermediate_size: 768
+    intermediate_size: 11008
     is_qwen2_config: true
     max_position_embeddings: 4096
     moe_config: null
     no_rope_layer: null
-    num_attention_heads: 4
-    num_hidden_layers: 12
+    num_attention_heads: 16
+    num_hidden_layers: 36
     num_key_value_heads: 4
     pad_token_id: null
     pretraining_tp: 1
     rms_norm_eps: 1.0e-06
     rope_interleaved: false
     rope_scaling: null
     rope_theta: 10000.0
-    sliding_window_size: 20
+    sliding_window_size: null
    tie_word_embeddings: true
     use_cache: true
     vocab_size: 128256
@@ -104,11 +105,10 @@ parallelism:
   context_parallel_size: 1
   dp: 2
   expert_parallel_size: 1
-  moe_layer_recompute: false
   pp: 1
   pp_engine: 1f1b
   recompute_layer: false
-  tp: 1
+  tp: 2
   tp_linear_async_communication: true
   tp_mode: REDUCE_SCATTER
   tp_recompute_allgather: true

examples/config_qwen_with_moe.yaml

Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
checkpoints:
  checkpoint_interval: 1000
  checkpoints_path: /fsx/phuc/new_workspace/experiments/qwen2_moe_test
  checkpoints_path_is_shared_file_system: false
  load_lr_scheduler: true
  load_optimizer: true
  resume_checkpoint_path: null
  save_final_state: true
  save_initial_state: false
data_stages:
- data:
    dataset:
      dataset_folder:
      - /fsx/loubna/datasets/llama_tokenized/fineweb-edu/merged
      dataset_max_tokens: null
      dataset_read_path: null
      dataset_weights: null
      pad_samples_to_global_batch_size: false
      return_positions: true
      shuffle_files: false
      skip_in_stream: false
      token_size_in_bytes: 4
      tokenizer_name: meta-llama/Llama-3.2-1B
      use_old_brrr_dataloader: false
      vocab_size: 128256
    num_loading_workers: 1
    seed: 42
  name: Stable Training Stage
  start_training_step: 1
general:
  benchmark_csv_path: null
  consumed_train_samples: null
  ignore_sanity_checks: false
  project: qwen_moe
  run: qwen_20250410_014907_16027793
  seed: 42
  step: null
lighteval: null
logging:
  iteration_step_info_interval: 1
  log_level: info
  log_level_replica: info
  metrics_logging: null
model:
  ddp_bucket_cap_mb: 25
  dtype: bfloat16
  init_method:
    std: 0.025
  make_vocab_size_divisible_by: 1
  model_config:
    _attn_implementation: flash_attention_2
    _fused_rms_norm: true
    _fused_rotary_emb: true
    _use_doc_masking: true
    _use_qkv_packed: true
    attention_bias: false
    bos_token_id: 1
    eos_token_id: 2
    flex_attention_mask: null
    hidden_act: silu
    hidden_size: 256
    initializer_range: 0.02
    intermediate_size: 768
    is_qwen2_config: true
    max_position_embeddings: 4096
    moe_config: null
    no_rope_layer: null
    num_attention_heads: 4
    num_hidden_layers: 12
    num_key_value_heads: 4
    pad_token_id: null
    pretraining_tp: 1
    rms_norm_eps: 1.0e-06
    rope_interleaved: false
    rope_scaling: null
    rope_theta: 10000.0
    sliding_window_size: 20
    tie_word_embeddings: true
    use_cache: true
    vocab_size: 128256
    z_loss_coefficient: 0.0001
    z_loss_enabled: false
  moe_config:
    num_experts: 8
    top_k: 1
    enable_shared_expert: true
    token_dispatcher_type: alltoall
optimizer:
  accumulate_grad_in_fp32: true
  clip_grad: 1.0
  learning_rate_scheduler:
    learning_rate: 0.0003
    lr_decay_starting_step: null
    lr_decay_steps: 31998
    lr_decay_style: cosine
    lr_warmup_steps: 2
    lr_warmup_style: linear
    min_decay_lr: 1.0e-05
  optimizer_factory:
    adam_beta1: 0.9
    adam_beta2: 0.95
    adam_eps: 1.0e-08
    name: adamW
    torch_adam_is_fused: true
  weight_decay: 0.01
  weight_decay_exclude_named_params: []
  zero_stage: 0
parallelism:
  context_parallel_size: 1
  dp: 2
  expert_parallel_size: 1
  pp: 1
  pp_engine: 1f1b
  recompute_layer: false
  tp: 1
  tp_linear_async_communication: true
  tp_mode: REDUCE_SCATTER
  tp_recompute_allgather: true
profiler: null
s3_upload: null
tokenizer:
  tokenizer_max_length: null
  tokenizer_name_or_path: meta-llama/Llama-3.2-1B
  tokenizer_revision: null
tokens:
  batch_accumulation_per_replica: 1
  limit_test_batches: 0
  limit_val_batches: 0
  micro_batch_size: 3
  sequence_length: 4096
  train_steps: 32000
  val_check_interval: -1

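The `moe_config` block above (8 experts, `top_k: 1`, a shared expert, all-to-all dispatch) lines up with the commit-message notes about computing router logits in float32 and training without token dropping. Below is a minimal routing sketch for that step; it is illustrative plain PyTorch under those assumptions, not the MoE code this commit adds under `src/nn`:

```python
import torch
import torch.nn.functional as F

def route_tokens(hidden_states: torch.Tensor, router_weight: torch.Tensor, top_k: int = 1):
    """Pick top_k experts per token; router logits/probs are computed in float32 for stability."""
    # hidden_states: (num_tokens, hidden_size); router_weight: (num_experts, hidden_size)
    logits = F.linear(hidden_states.float(), router_weight.float())  # (num_tokens, num_experts)
    probs = logits.softmax(dim=-1)
    topk_probs, topk_experts = probs.topk(top_k, dim=-1)             # (num_tokens, top_k)
    # per-expert token counts: with no token dropping, every token gets dispatched
    tokens_per_expert = torch.bincount(topk_experts.flatten(),
                                       minlength=router_weight.shape[0])
    return topk_probs, topk_experts, tokens_per_expert

# toy usage: 16 tokens, hidden size 256, 8 experts, top-1 routing
h = torch.randn(16, 256)
w = torch.randn(8, 256)
probs, experts, counts = route_tokens(h, w, top_k=1)
```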