Commit 541b254

log total tokens in esm2 recipe (#1309)
When comparing THD and BSHD runs, it's helpful to have a view of the total batch size in terms of number of tokens. This change logs the summed unpadded tokens per batch to wandb.

Signed-off-by: Peter St. John <[email protected]>
1 parent: a4d56c0 · commit: 541b254

File tree

1 file changed: +2 additions, 0 deletions


bionemo-recipes/recipes/esm2_native_te/perf_logger.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -58,6 +58,7 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig):
             "train/step_time": torchmetrics.MeanMetric(),
             "train/tokens_per_second": torchmetrics.MeanMetric(),
             "train/unpadded_tokens_per_second": torchmetrics.MeanMetric(),
+            "train/total_unpadded_tokens_per_batch": torchmetrics.SumMetric(),
             "train/perplexity": torchmetrics.text.Perplexity(ignore_index=-100),
             "train/gpu_memory_allocated_max_gb": torchmetrics.MaxMetric(),
             "train/gpu_memory_allocated_mean_gb": torchmetrics.MeanMetric(),
@@ -103,6 +104,7 @@ def log_step(
         self.metrics["train/step_time"].update(step_time)
         self.metrics["train/tokens_per_second"].update(num_tokens / step_time)
         self.metrics["train/unpadded_tokens_per_second"].update(num_unpadded_tokens / step_time)
+        self.metrics["train/total_unpadded_tokens_per_batch"].update(num_unpadded_tokens / self.logging_frequency)
 
         # Handle sequence packing for torchmetrics calculation.
         if outputs.logits.dim() < 3:
```
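
Note the division by `self.logging_frequency`: a minimal sketch of why this works is below, assuming the metric is computed and reset once every `logging_frequency` steps (as wandb-style periodic logging typically does). Under that assumption, summing `num_unpadded_tokens / logging_frequency` over the window yields the average unpadded tokens per batch for that logging interval. The window size and token counts here are illustrative, not taken from the recipe.

```python
import torchmetrics

# Hypothetical logging window; in the recipe this would be self.logging_frequency.
logging_frequency = 4
metric = torchmetrics.SumMetric()

# One update per training step, each contributing tokens / window_size.
for num_unpadded_tokens in [1000, 1200, 900, 1100]:
    metric.update(num_unpadded_tokens / logging_frequency)

# Sum of per-step contributions == mean tokens per batch over the window.
print(metric.compute())  # tensor(1050.) == (1000 + 1200 + 900 + 1100) / 4
metric.reset()           # start a fresh window for the next logging interval
```

Using a `SumMetric` this way (rather than a `MeanMetric` on raw token counts) keeps the reported value interpretable as a per-batch token count even though the metric object itself accumulates a sum.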
