diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py index d396d5b5..6445a20b 100644 --- a/rsl_rl/runners/on_policy_runner.py +++ b/rsl_rl/runners/on_policy_runner.py @@ -352,6 +352,15 @@ def log(self, locs: dict, width: int = 80, pad: int = 35): str = f" \033[1m Learning iteration {locs['it']}/{locs['tot_iter']} \033[0m " + # Upload video to wandb + if self.logger_type == "wandb" and not self.disable_logs: + # use video_fps from cfg if available or default to 30 + if "video_fps" in self.cfg: + video_fps = self.cfg["video_fps"] + else: + video_fps = 30 + self.writer.add_video_files(self.log_dir, fps=video_fps, step=locs["it"]) + if len(locs["rewbuffer"]) > 0: log_string = ( f"""{'#' * width}\n""" diff --git a/rsl_rl/utils/wandb_utils.py b/rsl_rl/utils/wandb_utils.py index 243e82d4..e4eec795 100644 --- a/rsl_rl/utils/wandb_utils.py +++ b/rsl_rl/utils/wandb_utils.py @@ -5,7 +5,7 @@ from __future__ import annotations -import os +import os, pathlib, json from dataclasses import asdict from torch.utils.tensorboard import SummaryWriter @@ -44,6 +44,7 @@ def __init__(self, log_dir: str, flush_secs: int, cfg): "Train/mean_reward/time": "Train/mean_reward_time", "Train/mean_episode_length/time": "Train/mean_episode_length_time", } + self.video_files = [] def store_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg): wandb.config.update({"runner_cfg": runner_cfg}) @@ -76,6 +77,22 @@ def save_model(self, model_path, iter): def save_file(self, path, iter=None): wandb.save(path, base_path=os.path.dirname(path)) + def add_video_files(self, log_dir: str, step: int, fps: int = 30): + # Check if there are video files in the video directory + if os.path.exists(log_dir): + # append the new video files to the existing list + for root, dirs, files in os.walk(log_dir): + for video_file in files: + if video_file.endswith(".mp4") and video_file not in self.video_files: + self.video_files.append(video_file) + # add the new video file to wandb only if video file is not updating + video_path = os.path.join(root, video_file) + wandb.log( + {"Video": wandb.Video(video_path, fps=fps, format="mp4")}, + step = step + ) + + """ Private methods. """