Skip to content

Commit 853f207

Browse files
btabacopybara-github
authored andcommitted
Add sps to training metric logger.
PiperOrigin-RevId: 835120429 Change-Id: I3e339bf360aa3d3600d7cb323879e6558a7a3ddc
1 parent d96ab73 commit 853f207

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

brax/training/logger.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import collections
1818
import logging
19+
import time
1920
from jax import numpy as jnp
2021
import numpy as np
2122

@@ -38,6 +39,7 @@ def __init__(
3839
self._last_log_steps = 0
3940
self._log_count = 0
4041
self._progress_fn = progress_fn
42+
self._last_log_time = time.time()
4143

4244
def update_episode_metrics(self, episode_metrics, dones, train_metrics):
4345
self._num_steps += np.prod(dones.shape)
@@ -54,10 +56,16 @@ def update_episode_metrics(self, episode_metrics, dones, train_metrics):
5456
def log_metrics(self, pad=35):
5557
"""Log metrics to console."""
5658
self._log_count += 1
59+
now = time.time()
60+
steps_per_second = (self._num_steps - self._last_log_steps) / (
61+
now - self._last_log_time + 1e-8
62+
)
63+
self._last_log_time = now
5764
log_string = (
5865
f"\n{'Steps':>{pad}} Env: {self._num_steps} Log: {self._log_count}\n"
5966
)
60-
mean_metrics = {}
67+
mean_metrics = {'sps': steps_per_second}
68+
log_string += f"{'Steps per second:':>{pad}} {steps_per_second:.0f}\n"
6169
for metric_name in self._ep_metrics_buffer:
6270
mean_metrics[metric_name] = np.mean(self._ep_metrics_buffer[metric_name])
6371
log_string += (

0 commit comments

Comments
 (0)