
Commit a2d8c80

yueshen2016 authored and Chen-Han Yu committed
ADLR/megatron-lm!4169 - [OMNIML-2921] GPT-OSS Modelopt support
Co-authored-by: Chen-Han Yu <[email protected]>
1 parent 4666de7 commit a2d8c80

6 files changed: +150 -17 lines changed
Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+#!/bin/bash
+
+if [ -z ${HF_MODEL_CKPT} ]; then
+    HF_MODEL_CKPT=openai/gpt-oss-20b
+    TOKENIZER_MODEL=openai/gpt-oss-20b
+else
+    TOKENIZER_MODEL=${HF_MODEL_CKPT}
+fi
+
+# WAR: enable-gpt-oss is a temporary workaround for using the default GPT-OSS config
+MODEL_ARGS=" \
+    --save-interval 100000 \
+    --micro-batch-size 1 \
+    --bf16 \
+    --no-masked-softmax-fusion \
+    --untie-embeddings-and-output-weights \
+    --no-rope-fusion \
+    --normalization RMSNorm \
+    --num-layers 36 \
+    --hidden-size 2880 \
+    --ffn-hidden-size 2880 \
+    --num-attention-heads 64 \
+    --group-query-attention \
+    --num-query-groups 8 \
+    --kv-channels 64 \
+    --num-experts 128 \
+    --moe-ffn-hidden-size 2880 \
+    --moe-router-dtype fp32 \
+    --moe-router-topk 4 \
+    --moe-aux-loss-coeff 0.0 \
+    --moe-token-dispatcher-type alltoall \
+    --moe-router-score-function softmax \
+    --moe-router-load-balancing-type aux_loss \
+    --seq-length 4096 \
+    --max-position-embeddings 40960 \
+    --tokenizer-type HuggingFaceTokenizer \
+    --make-vocab-size-divisible-by 128 \
+    --use-mcore-models \
+    --rotary-percent 1.0 \
+    --rotary-base 150000 \
+    --no-bias-gelu-fusion \
+    --sequence-parallel \
+    --export-force-local-attention \
+    --no-bias-dropout-fusion \
+    --padded-vocab-size 201088 \
+    --quick-geglu \
+    --glu-linear-offset 1.0 \
+    --softmax-type learnable \
+    --window-attn-skip-freq 2 \
+    --enable-gpt-oss \
+    --activation-func-clamp-value 7.0 \
+    --window-size 128,0 \
+"
Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+#!/bin/bash
+
+if [ -z ${HF_MODEL_CKPT} ]; then
+    HF_MODEL_CKPT=openai/gpt-oss-20b
+    TOKENIZER_MODEL=openai/gpt-oss-20b
+else
+    TOKENIZER_MODEL=${HF_MODEL_CKPT}
+fi
+
+# WAR: enable-gpt-oss is a temporary workaround for using the default GPT-OSS config
+MODEL_ARGS=" \
+    --save-interval 100000 \
+    --micro-batch-size 1 \
+    --bf16 \
+    --no-masked-softmax-fusion \
+    --untie-embeddings-and-output-weights \
+    --no-rope-fusion \
+    --normalization RMSNorm \
+    --num-layers 24 \
+    --hidden-size 2880 \
+    --ffn-hidden-size 2880 \
+    --num-attention-heads 64 \
+    --group-query-attention \
+    --num-query-groups 8 \
+    --kv-channels 64 \
+    --num-experts 32 \
+    --moe-ffn-hidden-size 2880 \
+    --moe-router-dtype fp32 \
+    --moe-router-topk 4 \
+    --moe-aux-loss-coeff 0.0 \
+    --moe-token-dispatcher-type alltoall \
+    --moe-router-score-function softmax \
+    --moe-router-load-balancing-type aux_loss \
+    --seq-length 4096 \
+    --max-position-embeddings 40960 \
+    --tokenizer-type HuggingFaceTokenizer \
+    --make-vocab-size-divisible-by 128 \
+    --use-mcore-models \
+    --rotary-percent 1.0 \
+    --rotary-base 150000 \
+    --no-bias-gelu-fusion \
+    --sequence-parallel \
+    --export-force-local-attention \
+    --no-bias-dropout-fusion \
+    --padded-vocab-size 201088 \
+    --quick-geglu \
+    --glu-linear-offset 1.0 \
+    --softmax-type learnable \
+    --window-attn-skip-freq 2 \
+    --enable-gpt-oss \
+    --activation-func-clamp-value 7.0 \
+    --window-size 128,0 \
+"

examples/post_training/modelopt/generate.py

Lines changed: 4 additions & 3 deletions
@@ -150,9 +150,10 @@ def get_conversations(example):
     input_ids = tokenizer.apply_chat_template(
         new_conversations, return_tensors="pt", add_generation_prompt=True
     )
-    output_ids = simple_generate(
-        unwrapped_model, input_ids.cuda(), osl=args.osl, disable_tqdm=args.disable_tqdm
-    )
+    with torch.no_grad():
+        output_ids = simple_generate(
+            unwrapped_model, input_ids.cuda(), osl=args.osl, disable_tqdm=args.disable_tqdm
+        )
     output_texts = tokenizer.batch_decode(output_ids)[0]
     print_rank_0("{}".format(output_texts))
     new_conversations.append({"role": "assistant", "content": output_texts})
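The only functional change here is wrapping generation in torch.no_grad(), so PyTorch stops recording an autograd graph across decoding steps; without it, activation memory grows with every generated token even though no backward pass is ever run. A minimal, generic sketch of the pattern (the model callable and greedy loop below are illustrative, not the repository's simple_generate):

import torch

@torch.no_grad()  # equivalent to the `with torch.no_grad():` block in the diff
def greedy_decode(model, input_ids, max_new_tokens=32):
    """Greedy decoding without building an autograd graph.

    `model` is any callable mapping token ids [B, T] to logits [B, T, V];
    this is a generic illustration, not the repo's simple_generate().
    """
    for _ in range(max_new_tokens):
        logits = model(input_ids)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    return input_ids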

examples/post_training/modelopt/mmlu.py

Lines changed: 17 additions & 14 deletions
@@ -28,6 +28,7 @@ def add_mmlu_args(parser):
     group.add_argument("--disable-tqdm", action="store_true", help="Disable tqdm.")
     group.add_argument("--fraction", type=float, default=1.0, help="Fraction of dataset to use.")
     group.add_argument("--lower-bound", type=float, default=None)
+    group.add_argument("--no-subject-prompt", action="store_true", help="Use empty prompt instead of subject-based prompt.")
     add_modelopt_args(parser)
     return parser
 

@@ -101,17 +102,20 @@ def format_example(example, include_answer: bool = True):
     for choice, answer in zip(["A", "B", "C", "D"], example["choices"]):
         prompt += "\n{}. {}".format(choice, answer)
     if include_answer:
-        prompt += "Answer: {}\n\n".format(example["answer"])
+        prompt += "\nAnswer: {}\n\n".format(["A", "B", "C", "D"][example["answer"]])
     else:
         prompt += "\nAnswer:"
     return prompt
 
 
-def generate_prompt(test_example, dev_examples, few_shots=0):
+def generate_prompt(test_example, dev_examples, few_shots=0, no_subject_prompt=False):
     """Generating few-shot prompts."""
-    prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format(
-        " ".join(test_example["subject"].split("_"))
-    )
+    if no_subject_prompt:
+        prompt = ""
+    else:
+        prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format(
+            " ".join(test_example["subject"].split("_"))
+        )
     for i in range(few_shots):
         prompt += format_example(dev_examples[i])
     prompt += format_example(test_example, include_answer=False)

@@ -147,11 +151,6 @@ def generate_prompt(test_example, dev_examples, few_shots=0):
     model = get_model(functools.partial(model_provider, parallel_output=True), wrap_with_ddp=False)
     report_current_memory_info()
 
-    # Materialize the model from meta device to gpu before loading the checkpoint.
-    unwrapped_model = unwrap_model(model)[0]
-    unwrapped_model.to_empty(device="cuda")
-    report_current_memory_info()
-
     disable_tqdm = args.disable_tqdm or torch.distributed.get_rank() > 0
 
     tokenizer = get_tokenizer()._tokenizer

@@ -160,6 +159,9 @@ def generate_prompt(test_example, dev_examples, few_shots=0):
     load_modelopt_checkpoint(model, strict=not args.untie_embeddings_and_output_weights)
     print_rank_0("Done loading checkpoint")
 
+    unwrapped_model = unwrap_model(model)[0]
+    unwrapped_model.eval()
+
     all_subjects = get_all_subjects()
 
     all_correct = {}

@@ -172,12 +174,13 @@ def generate_prompt(test_example, dev_examples, few_shots=0):
     for idx, test_example in enumerate(test_data):
         if idx > args.fraction * len(test_data):
             break
-        prompt = generate_prompt(test_example, dev_data, few_shots=0)
+        prompt = generate_prompt(test_example, dev_data, few_shots=0, no_subject_prompt=args.no_subject_prompt)
         label = ["A", "B", "C", "D"][test_example["answer"]]
         tokens = tokenizer(prompt, return_tensors="pt")
-        generated_ids = simple_generate(
-            unwrapped_model, tokens.input_ids.cuda(), osl=2, disable_tqdm=disable_tqdm
-        )
+        with torch.no_grad():
+            generated_ids = simple_generate(
+                unwrapped_model, tokens.input_ids.cuda(), osl=2, disable_tqdm=disable_tqdm
+            )
         predict = tokenizer.batch_decode(generated_ids)[0].strip()
         correct += [True] if predict.startswith(label) else [False]
         all_correct[subject] = correct
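The format_example() change fixes two things at once: the few-shot answer line was missing its leading newline, so "Answer:" was glued onto option D, and it printed the integer label instead of the letter. A standalone sketch of the corrected formatting on a made-up record (the question prefix is an assumption, since the hunk starts mid-function):

CHOICES = ["A", "B", "C", "D"]

def format_example(example, include_answer=True):
    """Mirror of the corrected MMLU prompt formatting shown in the diff above."""
    prompt = example["question"]   # assumption: the prompt begins with the question text
    for choice, answer in zip(CHOICES, example["choices"]):
        prompt += "\n{}. {}".format(choice, answer)
    if include_answer:
        # Map the integer label (e.g. 2) to its letter ("C") and keep the newline.
        prompt += "\nAnswer: {}\n\n".format(CHOICES[example["answer"]])
    else:
        prompt += "\nAnswer:"
    return prompt

sample = {
    "question": "Which planet is the largest in the solar system?",
    "choices": ["Mars", "Venus", "Jupiter", "Mercury"],
    "answer": 2,
}
print(format_example(sample))   # ends with "Answer: C" on its own line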

megatron/post_training/arguments.py

Lines changed: 8 additions & 0 deletions
@@ -122,5 +122,13 @@ def add_modelopt_args(parser):
         action="store_true",
         help='Will be set automatically when loading a ModelOpt checkpoint.',
     )
+
+    # GPT-OSS YaRN RoPE support
+    group.add_argument(
+        '--enable-gpt-oss',
+        action="store_true",
+        help='Enable GPT-OSS mode with YaRN RoPE configuration. When enabled, automatically '
+        'configures all YaRN parameters with GPT-OSS defaults.',
+    )
 
     return parser
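Because this is a plain store_true flag, it parses to args.enable_gpt_oss (dashes become underscores) and defaults to False when omitted; that attribute is what model_provider.py checks below. A minimal, standalone argparse sketch of the behaviour, not the actual Megatron parser:

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group("modelopt-generic")   # group name here is illustrative
group.add_argument(
    "--enable-gpt-oss",
    action="store_true",
    help="Enable GPT-OSS mode with YaRN RoPE configuration.",
)

args = parser.parse_args(["--enable-gpt-oss"])
print(args.enable_gpt_oss)   # True; False when the flag is not passed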

megatron/post_training/model_provider.py

Lines changed: 15 additions & 0 deletions
@@ -151,6 +151,21 @@ def model_provider(pre_process=True, post_process=True, parallel_output=True) ->
     # ModelOpt by default assumes none homogenous layers. This affect the storage format of the sharded checkpoint.
     config = core_transformer_config_from_args(args)
 
+    # Handle GPT-OSS mode with YaRN RoPE configuration
+    if hasattr(args, 'enable_gpt_oss') and args.enable_gpt_oss:
+        print_rank_0("GPT-OSS mode enabled: Configuring YaRN RoPE parameters")
+
+        # Set GPT-OSS YaRN values directly on the config
+        # These defaults are based on Huggingface GPT-OSS configurations
+        config.position_embedding_type = "yarn"
+        config.yarn_rotary_scaling_factor = 32.0
+        config.yarn_original_max_position_embeddings = 131072
+        config.yarn_beta_fast = 32.0
+        config.yarn_beta_slow = 1.0
+        config.yarn_mscale = 1.0
+        config.yarn_mscale_all_dim = 0.0
+        config.yarn_correction_range_round_to_int = False
+
     if args.use_legacy_models:
         raise ValueError(
             "ModelOpt integration only support MCore models. Use --use-mcore-modules instead."
