
Commit d41e8ce

Merge pull request #4 from microsoft/train
Adding required files and restructuring directories
2 parents 24a1b39 + 8523d5f commit d41e8ce

14 files changed (+522, -148 lines)


.gitignore

Lines changed: 4 additions & 1 deletion
@@ -1,2 +1,5 @@
 *.ipynb
-*.parquet
+*.parquet
+dataset/
+models/
+local_util/

CONTRIBUTING.md

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+# Contributing
+
+This project welcomes contributions and suggestions. Most contributions require you to
+agree to a Contributor License Agreement (CLA) declaring that you have the right to,
+and actually do, grant us the rights to use your contribution. For details, visit
+https://cla.microsoft.com.
+
+When you submit a pull request, a CLA-bot will automatically determine whether you need
+to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the
+instructions provided by the bot. You will only need to do this once across all repositories using our CLA.
+
+This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
+For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
+or contact [[email protected]](mailto:[email protected]) with any additional questions or comments.

README.md

Lines changed: 12 additions & 0 deletions
@@ -112,3 +112,15 @@ trademarks or logos is subject to and must follow
 [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
 Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
 Any use of third-party trademarks or logos are subject to those third-party's policies.
+
+## Citation
+
+```bibtex
+@inproceedings{aggarwal2025nextcoder,
+  author = {Aggarwal, Tushar and Singh, Swayam and Awasthi, Abhijeet and Kanade, Aditya and Natarajan, Nagarajan},
+  title = {NextCoder: Robust Adaptation of Code LMs to Diverse Code Edits},
+  booktitle = {International Conference on Machine Learning},
+  year = {2025},
+  url = {https://www.microsoft.com/en-us/research/publication/nextcoder-robust-adaptation-of-code-lms-to-diverse-code-edits/},
+}
+```

src/train/README.md

Lines changed: 18 additions & 119 deletions
@@ -1,12 +1,10 @@
 # Model Training scripts
 
 ## Folder Structure
-- `ds_config.json` contains the deepspeed configuration
-- `general_acc.yaml` contains the accelerate configuration (might need to be modified as per desired system)
-- `lora.py` contains the code for training model with LoRA
-- `merge_lora.py` contains the code for merging trained LoRA adapters back to model for inference
-- `seletkt.py` contains the code for training model with our algorithm explained in our paper
-- `sft.py` contains the code for training model with Full Supervised Finetuning
+- `configs` contains the deepspeed and accelerate configurations (modifiable as per the system)
+- `lora` contains the code for training the model with LoRA
+- `selekt` contains the code for training the model with the SeleKT algorithm explained in our paper
+- `sft` contains the code for training the model with Full Supervised Finetuning
 
 ## Usage
 ### Preparing the dataset
@@ -23,122 +21,23 @@
 ### Training with SFT
 - modify or replace the `general_acc.yaml` file as per the desired system configuration
 - set the `zero_optimization-stage` to `3` and `overlap_comm` to `false` in `ds_config` for better memory optimizations
-- Run the following command to start training
-```bash
-deepspeed sft.py \
-    --model_name_or_path "path to pretrained LLM" \
-    --train_data_path "path to training data" \
-    --output_dir "path to output dir" \
-    --num_train_epochs 3 \
-    --model_max_length 8192 \
-    --per_device_train_batch_size 4 \
-    --gradient_accumulation_steps 4 \
-    --save_strategy "epoch" \
-    --save_steps 760 \
-    --save_total_limit 25 \
-    --learning_rate 1e-5 \
-    --warmup_ratio 0.1 \
-    --logging_steps 5 \
-    --report_to "wandb" \
-    --gradient_checkpointing True \
-    --deepspeed ds_config.json \
-    --bf16 True \
-    --run_name "Run name for logs" \
-    --debug True \
-```
-Update the above command as per the model
-- To train on conversation data by only applying loss on the response, uncomment the lines 175, 176 and 185 and run the same command with proper conversational dataset path
-```python
-response_template = "#RESPONSE\n"
-collator = DataCollatorForCompletionOnlyLM(response_template=response_template, tokenizer=tokenizer)
-
-# Initialize trainer
-trainer = SFTTrainer(
-    model=model,
-    processing_class=tokenizer,
-    train_dataset=dataset,
-    args=training_config,
-    callbacks=[Callback(flush_steps=1)],
-    data_collator=collator, # pass the collator in the trainer
-)
-```
+- Add the respective variables like `MODEL_PATH`, `TRAIN_DATA`, `OUTPUT_DIR` etc. in the `run.sh` script and run
+```bash
+bash ./sft/run.sh
+```
 
 ### Training with LoRA
 - modify or replace the `general_acc.yaml` file as per the desired system configuration
-- set the `zero_optimization-stage` to `2` and `overlap_comm` to `false` in `ds_config` for better memory optimizations
-- Run the following command to start training
-```bash
-deepspeed lora.py \
-    --model_name_or_path "path to pretrained LLM" \
-    --train_data_path "path to training data" \
-    --output_dir "path to output dir" \
-    --num_train_epochs 3 \
-    --model_max_length 8192 \
-    --per_device_train_batch_size 4 \
-    --gradient_accumulation_steps 4 \
-    --save_strategy "epoch" \
-    --save_steps 760 \
-    --save_total_limit 25 \
-    --learning_rate 1e-5 \
-    --warmup_ratio 0.1 \
-    --logging_steps 5 \
-    --report_to "wandb" \
-    --gradient_checkpointing True \
-    --deepspeed ds_config.json \
-    --bf16 True \
-    --run_name "Run name for logs" \
-    --debug True \
-```
-Update the above command as per the model
-- Put the path of output LoRA adapters inside `merge_lora.py` and run following to get the final checkpoints
-```bash
-python merge_lora.py
-```
+- set the `zero_optimization-stage` to `2` and `overlap_comm` to `false` in `ds_config`
+- Add the respective variables like `MODEL_PATH`, `TRAIN_DATA`, `OUTPUT_DIR` etc. in the `run.sh` script and run
+```bash
+bash ./lora/run.sh
+```
+> `lora/lora.py` uses `use_reentrant: True` for gradient checkpointing, which can allow using deepspeed zero-3 optimization for large models.
 
 ### Training with SeleKT
 - modify or replace the `general_acc.yaml` file as per the desired system configuration
-- set the `zero_optimization-stage` to `2` and `overlap_comm` to `false` in `ds_config` for better memory optimizations
-- Run the following command to start training
-```bash
-accelerate launch \
-    --config_file=general_acc.yaml \
-    selekt.py \
-    --model_name_or_path "path to pretrained LLM" \
-    --base_model_path "path to pretrained LLM" \
-    --train_data_path "path to training data" \
-    --output_dir "path to output directory" \
-    --num_train_epochs 3 \
-    --model_max_length 8192 \
-    --per_device_train_batch_size 4 \
-    --gradient_accumulation_steps 4 \
-    --save_strategy "steps" \
-    --save_steps "Enter the periodicity value M for seleKT" \
-    --save_total_limit 50 \
-    --learning_rate 1e-5 \
-    --warmup_ratio 0.1 \
-    --logging_steps 5 \
-    --report_to "wandb" \
-    --gradient_checkpointing True \
-    --deepspeed ds_config.json \
-    --bf16 True \
-    --run_name "Name for logs" \
-    --debug True \
-    --alpha "Enter value for desired alpha parameter for SeleKT" \
-```
-Update the above command as per the model
-- To train on conversation data by only applying loss on the response, uncomment the lines 291, 292 and 301 and run the same command with proper conversational dataset path
-```python
-```python
-response_template = "#RESPONSE\n"
-collator = DataCollatorForCompletionOnlyLM(response_template=response_template, tokenizer=tokenizer)
-
-# Initialize trainer
-trainer = SFTTrainer(
-    model=model,
-    processing_class=tokenizer,
-    train_dataset=dataset,
-    args=training_config,
-    callbacks=[Callback(flush_steps=1)],
-    data_collator=collator, # pass the collator in the trainer
-)
-```
+- set the `zero_optimization-stage` to `3` and `overlap_comm` to `false` in `ds_config` for better memory optimizations
+- Add the respective variables like `MODEL_PATH`, `TRAIN_DATA`, `OUTPUT_DIR` etc. in the `run.sh` script and run
+```bash
+bash ./selekt/run.sh
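
The three sections of this README all point at the same two DeepSpeed knobs: `zero_optimization.stage` (2 for LoRA, 3 for SFT/SeleKT) and `overlap_comm: false`. As a rough illustration only, and not the repository's actual `configs/ds_config.json`, the sketch below shows how those settings might be emitted; every key other than the two named in the README is an assumption.

```python
# Sketch only (keys other than zero_optimization.stage / overlap_comm are assumed):
# emit a minimal DeepSpeed config matching the README's instructions.
import json

def write_ds_config(path: str, stage: int = 3) -> None:
    config = {
        "bf16": {"enabled": "auto"},             # assumed; run.sh passes --bf16 True
        "zero_optimization": {
            "stage": stage,                      # 3 for SFT/SeleKT, 2 for LoRA per the README
            "overlap_comm": False,               # README: set overlap_comm to false
        },
        "gradient_accumulation_steps": "auto",   # assumed; let the HF/DeepSpeed integration sync these
        "train_micro_batch_size_per_gpu": "auto",
    }
    with open(path, "w") as f:
        json.dump(config, f, indent=2)

if __name__ == "__main__":
    write_ds_config("configs/ds_config.json", stage=3)
```

ZeRO stage 3 shards optimizer state, gradients, and parameters across GPUs, which is why the README pairs it with full fine-tuning, while stage 2 (no parameter sharding) is enough for the LoRA run.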

src/train/SeleKT/run.sh

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
+#!/bin/bash
+
+export MODEL_NAME=""
+export DESC=""
+
+# Stage 1: Instruction Training
+OUTPUT_DIR_STAGE1="./output/selekt_stage1_instruction"
+TRAIN_DATA_STAGE1=""
+MODEL_PATH=""
+
+# Stage 2: Conversational Training
+OUTPUT_DIR_STAGE2="./output/selekt_stage2_conversational"
+TRAIN_DATA_STAGE2=""
+
+find_latest_checkpoint() {
+    local output_dir=$1
+    local latest_checkpoint=$(find "$output_dir" -name "checkpoint-*" -type d | sort -V | tail -1)
+    echo "$latest_checkpoint"
+}
+
+echo "Starting Stage 1: SeleKT Instruction Training..."
+echo "Model: $MODEL_PATH"
+echo "Training data: $TRAIN_DATA_STAGE1"
+echo "Output directory: $OUTPUT_DIR_STAGE1"
+
+mkdir -p $OUTPUT_DIR_STAGE1
+
+# Stage 1: Instruction Training
+accelerate launch \
+    --config_file=../configs/general_acc.yaml \
+    selekt.py \
+    --model_name_or_path "$MODEL_PATH" \
+    --train_data_path "$TRAIN_DATA_STAGE1" \
+    --output_dir ${OUTPUT_DIR_STAGE1} \
+    --num_train_epochs 3 \
+    --model_max_length 16384 \
+    --per_device_train_batch_size 1 \
+    --gradient_accumulation_steps 4 \
+    --save_strategy "epoch" \
+    --save_steps 760 \
+    --save_total_limit 25 \
+    --learning_rate 1e-5 \
+    --warmup_ratio 0.1 \
+    --weight_decay 0.1 \
+    --logging_steps 5 \
+    --lr_scheduler_type "cosine" \
+    --report_to "wandb" \
+    --gradient_checkpointing True \
+    --deepspeed ../configs/ds_config.json \
+    --bf16 True \
+    --run_name "${MODEL_NAME}_stage1_instruction" \
+    --alpha 0.05 \
+
+if [ $? -ne 0 ]; then
+    echo "Error: Stage 1 training failed!"
+    exit 1
+fi
+
+echo "Stage 1 completed successfully!"
+
+LATEST_CHECKPOINT=$(find_latest_checkpoint "$OUTPUT_DIR_STAGE1")
+
+if [ -z "$LATEST_CHECKPOINT" ]; then
+    echo "Error: No checkpoint found in $OUTPUT_DIR_STAGE1"
+    exit 1
+fi
+
+echo "Found latest checkpoint: $LATEST_CHECKPOINT"
+echo "Starting Stage 2: SeleKT Conversational Training..."
+echo "Model: $LATEST_CHECKPOINT"
+echo "Training data: $TRAIN_DATA_STAGE2"
+echo "Output directory: $OUTPUT_DIR_STAGE2"
+
+mkdir -p $OUTPUT_DIR_STAGE2
+
+# Stage 2: Conversational Training
+accelerate launch \
+    --config_file=../configs/general_acc.yaml \
+    selekt.py \
+    --model_name_or_path "${LATEST_CHECKPOINT}" \
+    --train_data_path "$TRAIN_DATA_STAGE2" \
+    --output_dir ${OUTPUT_DIR_STAGE2} \
+    --num_train_epochs 3 \
+    --model_max_length 16384 \
+    --per_device_train_batch_size 1 \
+    --gradient_accumulation_steps 4 \
+    --save_strategy "epoch" \
+    --save_steps 760 \
+    --save_total_limit 25 \
+    --learning_rate 1e-5 \
+    --warmup_ratio 0.1 \
+    --weight_decay 0.1 \
+    --logging_steps 5 \
+    --lr_scheduler_type "cosine" \
+    --report_to "wandb" \
+    --gradient_checkpointing True \
+    --deepspeed ../configs/ds_config.json \
+    --bf16 True \
+    --run_name "${MODEL_NAME}_stage2_conversational" \
+    --alpha 0.05 \
+    --is_conversational_training \
+
+
+# Check if stage 2 completed successfully
+if [ $? -ne 0 ]; then
+    echo "Error: Stage 2 training failed!"
+    exit 1
+fi
+
+echo "Stage 2 training completed!"
+echo "Both training stages completed successfully!"
+echo "Final model saved in: $OUTPUT_DIR_STAGE2"

src/train/selekt.py renamed to src/train/SeleKT/selekt.py

Lines changed: 8 additions & 5 deletions
@@ -70,10 +70,11 @@ def parse_args():
                         help="Whether to use bf16 mixed precision training")
     parser.add_argument("--run_name", type=str, default=None)
     parser.add_argument("--use_liger", type=bool, default=False)
-    parser.add_argument("--debug", type=bool, default=False)
     parser.add_argument("--packing", type=bool, default=True,
                         help="Whether to use packing for training")
-    parser.add_argument("--alpha", type=float, default=0.05,)
+    parser.add_argument("--alpha", type=float, default=0.05, help="Alpha value for SeleKT")
+    parser.add_argument("--is_conversational_training", action='store_true',
+                        help="Whether to use conversational training format")
 
     args, _ = parser.parse_known_args()
     return args
@@ -300,8 +301,10 @@ def train(args):
         print(f'Resuming from checkpoint: {last_checkpoint}')
 
 
-    # response_template = "#RESPONSE\n"
-    # collator = DataCollatorForCompletionOnlyLM(response_template=response_template, tokenizer=tokenizer)
+    collator = None
+    if args.is_conversational_training:
+        response_template = "#RESPONSE\n"
+        collator = DataCollatorForCompletionOnlyLM(response_template=response_template, tokenizer=tokenizer)
 
     callback = Callback(base_model_path=args.base_model_path, flush_steps=1, alpha=args.alpha)
     trainer = SFTTrainer(
@@ -310,7 +313,7 @@ def train(args):
         train_dataset=dataset,
         args=training_config,
         callbacks=[callback],
-        # data_collator=collator,
+        data_collator=collator,
     )
     callback.set_trainer(trainer)
     print(f"Starting training for epoch {args.num_train_epochs}")
File renamed without changes.
File renamed without changes.
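
The net effect of the `selekt.py` diff above is that passing `--is_conversational_training` makes the trainer compute loss only on tokens after the `#RESPONSE\n` marker, via trl's `DataCollatorForCompletionOnlyLM`. A minimal, self-contained sketch of that pattern follows; the model name and toy dataset are placeholders, and the repository's own `Callback`, configs, and data pipeline are deliberately left out.

```python
# Minimal sketch of response-only loss with trl (placeholder model/data; not the repo's script).
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM, SFTConfig, SFTTrainer

model_name = "Qwen/Qwen2.5-Coder-0.5B-Instruct"   # placeholder small model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Toy example in the "#RESPONSE\n" format the diff assumes; loss is masked before the marker.
dataset = Dataset.from_list([
    {"text": "#INSTRUCTION\nFix the off-by-one bug in the loop.\n#RESPONSE\nUse range(n) instead of range(n + 1).\n"},
])

response_template = "#RESPONSE\n"
collator = DataCollatorForCompletionOnlyLM(response_template=response_template, tokenizer=tokenizer)

trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset,
    # packing must stay off when using a completion-only collator
    args=SFTConfig(output_dir="./tmp_out", max_steps=1, per_device_train_batch_size=1, packing=False),
    data_collator=collator,
)
trainer.train()
```

Without the flag, `collator` stays `None` and `SFTTrainer` falls back to its default collator, i.e., standard next-token loss over the whole sequence.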

src/train/lora.py renamed to src/train/lora/lora.py

Lines changed: 10 additions & 2 deletions
@@ -66,9 +66,10 @@ def parse_args():
                         help="Whether to use bf16 mixed precision training")
     parser.add_argument("--run_name", type=str, default=None)
     parser.add_argument("--use_liger", type=bool, default=False)
-    parser.add_argument("--debug", type=bool, default=False)
     parser.add_argument("--packing", type=bool, default=True,
                         help="Whether to use packing for training")
+    parser.add_argument("--is_conversational_training", action='store_true',
+                        help="Whether to use conversational training format")
 
     args, _ = parser.parse_known_args()
     return args
@@ -151,12 +152,13 @@ def main():
         output_dir=args.output_dir,
         report_to="none",
         gradient_checkpointing=args.gradient_checkpointing,
-        gradient_checkpointing_kwargs={"use_reentrant": False},
+        gradient_checkpointing_kwargs={"use_reentrant": True},
         deepspeed=args.deepspeed,
         dataset_num_proc=80,
         run_name=args.run_name,
         use_liger=args.use_liger,
     )
+
     lora_config = LoraConfig(
         r=64,
         # target_modules= ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],
@@ -175,13 +177,19 @@ def main():
 
     dataset = setup_training_data(args, local_rank, tokenizer)
 
+    collator = None
+    if args.is_conversational_training:
+        response_template = "#RESPONSE\n"
+        collator = DataCollatorForCompletionOnlyLM(response_template=response_template, tokenizer=tokenizer)
+
     trainer = SFTTrainer(
         model=model,
         processing_class=tokenizer,
         train_dataset=dataset,
         args=training_config,
         peft_config=lora_config,
         callbacks=[Callback(flush_steps=1)],
+        data_collator=collator
     )
 
     print("Starting LoRA training...")
