Skip to content

Commit 4619fd0

Browse files
committed
AR compute should include base token
1 parent 2ad45b7 commit 4619fd0

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

megatron/post_training/non_loss_data_func.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def report_draft_acceptance_length(model, osl: int = 64, draft_steps: int = 7):
4040
total_steps += steps
4141
if torch.distributed.get_rank() == 0:
4242
al = actual_osl / steps
43-
ar = al / (draft_steps + parallel_draft_step - 1)
43+
ar = al / (draft_steps + parallel_draft_step)
4444
print(
4545
"Rank {:3}/{:3} {:12} AL {:.1f} AR {:.2f} STEPS {:5}/{:5} DRAFT {:2} PARALLEL {:2}".format(
4646
torch.distributed.get_rank(),
@@ -57,7 +57,7 @@ def report_draft_acceptance_length(model, osl: int = 64, draft_steps: int = 7):
5757
)
5858
if torch.distributed.get_rank() == 0:
5959
al = total_osl / total_steps
60-
ar = al / (draft_steps + parallel_draft_step - 1)
60+
ar = al / (draft_steps + parallel_draft_step)
6161
print(
6262
"Rank {:3}/{:3} {:12} AL {:.1f} AR {:.2f} STEPS {:5}/{:5} DRAFT {:2} PARALLEL {:2}".format(
6363
torch.distributed.get_rank(),

0 commit comments

Comments
 (0)