diff --git a/rsl_rl/modules/__init__.py b/rsl_rl/modules/__init__.py
index b0a7c1d5..9c651ed8 100644
--- a/rsl_rl/modules/__init__.py
+++ b/rsl_rl/modules/__init__.py
@@ -6,8 +6,9 @@
 """Definitions for neural-network components for RL-agents."""
 
 from .actor_critic import ActorCritic
+from .actor_critic_conv2d import ActorCriticConv2d
 from .actor_critic_recurrent import ActorCriticRecurrent
 from .normalizer import EmpiricalNormalization
 from .rnd import RandomNetworkDistillation
 
-__all__ = ["ActorCritic", "ActorCriticRecurrent", "EmpiricalNormalization", "RandomNetworkDistillation"]
+__all__ = ["ActorCritic", "ActorCriticConv2d", "ActorCriticRecurrent", "EmpiricalNormalization", "RandomNetworkDistillation"]
diff --git a/rsl_rl/modules/actor_critic_conv2d.py b/rsl_rl/modules/actor_critic_conv2d.py
new file mode 100755
index 00000000..dabb51bf
--- /dev/null
+++ b/rsl_rl/modules/actor_critic_conv2d.py
@@ -0,0 +1,229 @@
+# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+import torch
+import torch.nn as nn
+from torch.distributions import Normal
+
+from rsl_rl.utils import resolve_nn_activation
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, channels):
+        super().__init__()
+        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+        self.bn1 = nn.BatchNorm2d(channels)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+        self.bn2 = nn.BatchNorm2d(channels)
+
+    def forward(self, x):
+        residual = x
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out += residual
+        out = self.relu(out)
+        return out
+
+
+class ConvolutionalNetwork(nn.Module):
+    def __init__(
+        self,
+        proprio_input_dim,
+        output_dim,
+        image_input_shape,
+        conv_layers_params,
+        hidden_dims,
+        activation_fn,
+        conv_linear_output_size,
+    ):
+        super().__init__()
+
+        self.image_input_shape = image_input_shape  # (C, H, W)
+        self.image_obs_size = torch.prod(torch.tensor(self.image_input_shape)).item()
+        self.proprio_obs_size = proprio_input_dim
+        self.input_dim = self.proprio_obs_size + self.image_obs_size
+        self.activation_fn = activation_fn
+
+        # Build conv network and get its output size
+        self.conv_net = self.build_conv_net(conv_layers_params)
+        with torch.no_grad():
+            dummy_image = torch.zeros(1, *self.image_input_shape)
+            conv_output = self.conv_net(dummy_image)
+            self.image_feature_size = conv_output.view(1, -1).shape[1]
+
+        # Build the connection layers between conv net and mlp
+        self.conv_linear = nn.Linear(self.image_feature_size, conv_linear_output_size)
+        self.layernorm = nn.LayerNorm(conv_linear_output_size)
+
+        # Build the mlp
+        self.mlp = nn.Sequential(
+            nn.Linear(self.proprio_obs_size + conv_linear_output_size, hidden_dims[0]),
+            self.activation_fn,
+            *[
+                layer
+                for dim in zip(hidden_dims[:-1], hidden_dims[1:])
+                for layer in (nn.Linear(dim[0], dim[1]), self.activation_fn)
+            ],
+            nn.Linear(hidden_dims[-1], output_dim),
+        )
+
+        # Initialize the weights
+        self._initialize_weights()
+
+    def build_conv_net(self, conv_layers_params):
+        layers = []
+        in_channels = self.image_input_shape[0]
+        for idx, params in enumerate(conv_layers_params[:-1]):
+            layers.extend([
+                nn.Conv2d(
+                    in_channels,
+                    params["out_channels"],
+                    kernel_size=params.get("kernel_size", 3),
+                    stride=params.get("stride", 1),
+                    padding=params.get("padding", 0),
+                ),
+                nn.BatchNorm2d(params["out_channels"]),
+                nn.ReLU(inplace=True),
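+                # refine features with a residual block after every conv stage except the first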
ResidualBlock(params["out_channels"]) if idx > 0 else nn.Identity(), + ]) + in_channels = params["out_channels"] + last_params = conv_layers_params[-1] + layers.append( + nn.Conv2d( + in_channels, + last_params["out_channels"], + kernel_size=last_params.get("kernel_size", 3), + stride=last_params.get("stride", 1), + padding=last_params.get("padding", 0), + ) + ) + layers.append(nn.BatchNorm2d(last_params["out_channels"])) + return nn.Sequential(*layers) + + def _initialize_weights(self): + for m in self.conv_net.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + nn.init.kaiming_normal_(self.conv_linear.weight, mode="fan_out", nonlinearity="tanh") + nn.init.constant_(self.conv_linear.bias, 0) + nn.init.constant_(self.layernorm.weight, 1.0) + nn.init.constant_(self.layernorm.bias, 0.0) + + for layer in self.mlp: + if isinstance(layer, nn.Linear): + nn.init.orthogonal_(layer.weight, gain=0.01) + nn.init.zeros_(layer.bias) if layer.bias is not None else None + + def forward(self, observations): + proprio_obs = observations[:, : -self.image_obs_size] + image_obs = observations[:, -self.image_obs_size :] + + batch_size = image_obs.size(0) + image = image_obs.view(batch_size, *self.image_input_shape) + + conv_features = self.conv_net(image) + flattened_conv_features = conv_features.view(batch_size, -1) + normalized_conv_output = self.layernorm(self.conv_linear(flattened_conv_features)) + combined_input = torch.cat([proprio_obs, normalized_conv_output], dim=1) + output = self.mlp(combined_input) + return output + + +class ActorCriticConv2d(nn.Module): + is_recurrent = False + + def __init__( + self, + num_actor_obs, + num_critic_obs, + num_actions, + image_input_shape, + conv_layers_params, + conv_linear_output_size, + actor_hidden_dims, + critic_hidden_dims, + activation="elu", + init_noise_std=1.0, + **kwargs, + ): + super().__init__() + + self.image_input_shape = image_input_shape # (C, H, W) + self.activation_fn = resolve_nn_activation(activation) + + self.actor = ConvolutionalNetwork( + proprio_input_dim=num_actor_obs, + output_dim=num_actions, + image_input_shape=image_input_shape, + conv_layers_params=conv_layers_params, + hidden_dims=actor_hidden_dims, + activation_fn=self.activation_fn, + conv_linear_output_size=conv_linear_output_size, + ) + + self.critic = ConvolutionalNetwork( + proprio_input_dim=num_critic_obs, + output_dim=1, + image_input_shape=image_input_shape, + conv_layers_params=conv_layers_params, + hidden_dims=critic_hidden_dims, + activation_fn=self.activation_fn, + conv_linear_output_size=conv_linear_output_size, + ) + + print(f"Modified Actor Network: {self.actor}") + print(f"Modified Critic Network: {self.critic}") + + # Action noise + self.std = nn.Parameter(init_noise_std * torch.ones(num_actions)) + # Action distribution (populated in update_distribution) + self.distribution = None + # disable args validation for speedup + Normal.set_default_validate_args(False) + + def reset(self, dones=None): + pass + + def forward(self): + raise NotImplementedError + + @property + def action_mean(self): + return self.distribution.mean + + @property + def action_std(self): + return self.distribution.stddev + + @property + def entropy(self): + return self.distribution.entropy().sum(dim=-1) + + def update_distribution(self, observations): + mean = self.actor(observations) + self.distribution = Normal(mean, self.std) + 
+    def act(self, observations, **kwargs):
+        self.update_distribution(observations)
+        return self.distribution.sample()
+
+    def get_actions_log_prob(self, actions):
+        return self.distribution.log_prob(actions).sum(dim=-1)
+
+    def act_inference(self, observations):
+        actions_mean = self.actor(observations)
+        return actions_mean
+
+    def evaluate(self, critic_observations, **kwargs):
+        value = self.critic(critic_observations)
+        return value
diff --git a/rsl_rl/runners/__init__.py b/rsl_rl/runners/__init__.py
index e1713b2c..1534181e 100644
--- a/rsl_rl/runners/__init__.py
+++ b/rsl_rl/runners/__init__.py
@@ -6,5 +6,6 @@
 """Implementation of runners for environment-agent interaction."""
 
 from .on_policy_runner import OnPolicyRunner
+from .on_policy_runner_conv2d import OnPolicyRunnerConv2d
 
-__all__ = ["OnPolicyRunner"]
+__all__ = ["OnPolicyRunner", "OnPolicyRunnerConv2d"]
diff --git a/rsl_rl/runners/on_policy_runner_conv2d.py b/rsl_rl/runners/on_policy_runner_conv2d.py
new file mode 100755
index 00000000..a8df5cde
--- /dev/null
+++ b/rsl_rl/runners/on_policy_runner_conv2d.py
@@ -0,0 +1,228 @@
+# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from __future__ import annotations
+
+import os
+import time
+import torch
+from collections import deque
+
+import rsl_rl
+from rsl_rl.algorithms import PPO
+from rsl_rl.env import VecEnv
+from rsl_rl.modules import ActorCriticConv2d, EmpiricalNormalization
+from rsl_rl.runners import OnPolicyRunner
+from rsl_rl.utils import store_code_state
+
+
+class OnPolicyRunnerConv2d(OnPolicyRunner):
+    """Custom on-policy runner for training and evaluation with convolutional actor-critic."""
+
+    def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"):
+        self.cfg = train_cfg
+        self.alg_cfg = train_cfg["algorithm"]
+        self.policy_cfg = train_cfg["policy"]
+        self.device = device
+        self.env = env
+        obs, extras = self.env.get_observations()
+        num_obs = obs.shape[1]
+        if "critic" in extras["observations"]:
+            num_critic_obs = extras["observations"]["critic"].shape[1]
+        else:
+            num_critic_obs = num_obs
+        # Convert from [N, H, W, C] to [C, H, W]
+        input_image_shape = extras["observations"]["sensor"].permute(0, 3, 1, 2).shape[1:]
+        num_image_obs = torch.prod(torch.tensor(input_image_shape)).item()
+
+        # init the actor-critic networks
+        actor_critic: ActorCriticConv2d = ActorCriticConv2d(
+            num_obs, num_critic_obs, self.env.num_actions, input_image_shape, **self.policy_cfg
+        ).to(self.device)
+
+        # init the ppo algorithm
+        alg_class = eval(self.alg_cfg.pop("class_name"))  # PPO
+        self.alg: PPO = alg_class(actor_critic, device=self.device, **self.alg_cfg)
+        self.num_steps_per_env = self.cfg["num_steps_per_env"]
+        self.save_interval = self.cfg["save_interval"]
+        self.empirical_normalization = self.cfg["empirical_normalization"]
+        if self.empirical_normalization:
+            self.obs_normalizer = EmpiricalNormalization(shape=[num_obs], until=1.0e8).to(self.device)
+            self.critic_obs_normalizer = EmpiricalNormalization(shape=[num_critic_obs], until=1.0e8).to(self.device)
+        else:
+            self.obs_normalizer = torch.nn.Identity().to(self.device)  # no normalization
+            self.critic_obs_normalizer = torch.nn.Identity().to(self.device)  # no normalization
+
+        # init storage and model
+        self.alg.init_storage(
+            self.env.num_envs,
+            self.num_steps_per_env,
+            [num_obs + num_image_obs],
+            [num_critic_obs + num_image_obs],
+            [self.env.num_actions],
+        )
+
+        # Log
+        self.log_dir = log_dir
+        self.writer = None
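+        # bookkeeping counters for the total number of collected timesteps and total training time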
+        self.tot_timesteps = 0
+        self.tot_time = 0
+        self.current_learning_iteration = 0
+        self.git_status_repos = [rsl_rl.__file__]
+
+    def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = False):  # noqa: C901
+        # initialize writer
+        if self.log_dir is not None and self.writer is None:
+            # Launch either Tensorboard or Neptune & Tensorboard summary writer(s), default: Tensorboard.
+            self.logger_type = self.cfg.get("logger", "tensorboard")
+            self.logger_type = self.logger_type.lower()
+
+            if self.logger_type == "neptune":
+                from rsl_rl.utils.neptune_utils import NeptuneSummaryWriter
+
+                self.writer = NeptuneSummaryWriter(log_dir=self.log_dir, flush_secs=10, cfg=self.cfg)
+                self.writer.log_config(self.env.cfg, self.cfg, self.alg_cfg, self.policy_cfg)
+            elif self.logger_type == "wandb":
+                from rsl_rl.utils.wandb_utils import WandbSummaryWriter
+
+                self.writer = WandbSummaryWriter(log_dir=self.log_dir, flush_secs=10, cfg=self.cfg)
+                self.writer.log_config(self.env.cfg, self.cfg, self.alg_cfg, self.policy_cfg)
+            elif self.logger_type == "tensorboard":
+                from torch.utils.tensorboard import SummaryWriter
+
+                self.writer = SummaryWriter(log_dir=self.log_dir, flush_secs=10)
+            else:
+                raise ValueError("Logger type not found. Please choose 'neptune', 'wandb' or 'tensorboard'.")
+
+        # randomize initial episode lengths (for exploration)
+        if init_at_random_ep_len:
+            self.env.episode_length_buf = torch.randint_like(
+                self.env.episode_length_buf, high=int(self.env.max_episode_length)
+            )
+
+        # start learning
+        obs, extras = self.env.get_observations()
+        critic_obs = extras["observations"].get("critic", obs)
+
+        image_obs = extras["observations"]["sensor"].permute(0, 3, 1, 2).flatten(start_dim=1)
+
+        obs = torch.cat([obs, image_obs], dim=1)
+        critic_obs = torch.cat([critic_obs, image_obs], dim=1)
+        obs, critic_obs = obs.to(self.device), critic_obs.to(self.device)
+
+        self.train_mode()  # switch to train mode (for dropout for example)
+
+        # Book keeping
+        ep_infos = []
+        rewbuffer = deque(maxlen=100)
+        lenbuffer = deque(maxlen=100)
+        cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
+        cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
+        # create buffers for logging extrinsic and intrinsic rewards
+        if self.alg.rnd:
+            erewbuffer = deque(maxlen=100)
+            irewbuffer = deque(maxlen=100)
+            cur_ereward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
+            cur_ireward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
+
+        start_iter = self.current_learning_iteration
+        tot_iter = start_iter + num_learning_iterations
+        for it in range(start_iter, tot_iter):
+            start = time.time()
+            # Rollout
+            with torch.inference_mode():
+                for _ in range(self.num_steps_per_env):
+                    # Sample actions from policy
+                    actions = self.alg.act(obs, critic_obs)
+                    # Step environment
+                    obs, rewards, dones, infos = self.env.step(actions.to(self.env.device))
+                    # Move to the agent device
+                    obs, rewards, dones = obs.to(self.device), rewards.to(self.device), dones.to(self.device)
+
+                    # Normalize observations
+                    obs = self.obs_normalizer(obs)
+                    # Extract critic observations and normalize
+                    if "critic" in infos["observations"]:
+                        critic_obs = self.critic_obs_normalizer(infos["observations"]["critic"].to(self.device))
+                    else:
+                        critic_obs = obs
+
+                    # Concatenate image observations with proprioceptive observations
+                    image_obs = infos["observations"]["sensor"].permute(0, 3, 1, 2).flatten(start_dim=1).to(self.device)
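+                    # only the proprioceptive part is normalized above; raw image pixels are appended as-is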
+                    obs = torch.cat([obs, image_obs], dim=1)
+                    critic_obs = torch.cat([critic_obs, image_obs], dim=1)
+
+                    # Process env step and store in buffer
+                    self.alg.process_env_step(rewards, dones, infos)
+
+                    # Intrinsic rewards (extracted here only for logging)!
+                    intrinsic_rewards = self.alg.intrinsic_rewards if self.alg.rnd else None
+
+                    if self.log_dir is not None:
+                        # Book keeping
+                        if "episode" in infos:
+                            ep_infos.append(infos["episode"])
+                        elif "log" in infos:
+                            ep_infos.append(infos["log"])
+                        # Update rewards
+                        if self.alg.rnd:
+                            cur_ereward_sum += rewards
+                            cur_ireward_sum += intrinsic_rewards  # type: ignore
+                            cur_reward_sum += rewards + intrinsic_rewards
+                        else:
+                            cur_reward_sum += rewards
+                        # Update episode length
+                        cur_episode_length += 1
+                        # Clear data for completed episodes
+                        # -- common
+                        new_ids = (dones > 0).nonzero(as_tuple=False)
+                        rewbuffer.extend(cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist())
+                        lenbuffer.extend(cur_episode_length[new_ids][:, 0].cpu().numpy().tolist())
+                        cur_reward_sum[new_ids] = 0
+                        cur_episode_length[new_ids] = 0
+                        # -- intrinsic and extrinsic rewards
+                        if self.alg.rnd:
+                            erewbuffer.extend(cur_ereward_sum[new_ids][:, 0].cpu().numpy().tolist())
+                            irewbuffer.extend(cur_ireward_sum[new_ids][:, 0].cpu().numpy().tolist())
+                            cur_ereward_sum[new_ids] = 0
+                            cur_ireward_sum[new_ids] = 0
+
+                stop = time.time()
+                collection_time = stop - start
+
+                # Learning step
+                start = stop
+                self.alg.compute_returns(critic_obs)
+
+            # Update policy
+            # Note: we keep arguments here since locals() loads them
+            mean_value_loss, mean_surrogate_loss, mean_entropy, mean_rnd_loss, mean_symmetry_loss = self.alg.update()
+            stop = time.time()
+            learn_time = stop - start
+            self.current_learning_iteration = it
+
+            # Logging info and save checkpoint
+            if self.log_dir is not None:
+                # Log information
+                self.log(locals())
+                # Save model
+                if it % self.save_interval == 0:
+                    self.save(os.path.join(self.log_dir, f"model_{it}.pt"))
+
+            # Clear episode infos
+            ep_infos.clear()
+
+            # Save code state
+            if it == start_iter:
+                # obtain all the diff files
+                git_file_paths = store_code_state(self.log_dir, self.git_status_repos)
+                # if possible store them to wandb
+                if self.logger_type in ["wandb", "neptune"] and git_file_paths:
+                    for path in git_file_paths:
+                        self.writer.save_file(path)
+
+        # Save the final model after training
+        if self.log_dir is not None:
+            self.save(os.path.join(self.log_dir, f"model_{self.current_learning_iteration}.pt"))
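
Usage sketch (not part of the diff): a minimal example of how the new runner might be wired up, assuming an environment whose extras["observations"]["sensor"] returns [N, H, W, C] image tensors alongside the flat proprioceptive observations. The environment construction, all numeric values, and the PPO hyperparameters below are placeholders; only the config keys are taken from the code above ("policy", "algorithm", "num_steps_per_env", "save_interval", "empirical_normalization", and the optional "logger").

from rsl_rl.runners import OnPolicyRunnerConv2d

# assumed: a VecEnv that provides extras["observations"]["sensor"] as [N, H, W, C] images
env = ...

train_cfg = {
    "num_steps_per_env": 24,
    "save_interval": 50,
    "empirical_normalization": True,
    "logger": "tensorboard",
    "policy": {
        # forwarded to ActorCriticConv2d; image_input_shape is inferred from the sensor observation
        "conv_layers_params": [
            {"out_channels": 16, "kernel_size": 5, "stride": 2},
            {"out_channels": 32, "kernel_size": 3, "stride": 2},
            {"out_channels": 32, "kernel_size": 3, "stride": 2},
        ],
        "conv_linear_output_size": 64,
        "actor_hidden_dims": [256, 128, 64],
        "critic_hidden_dims": [256, 128, 64],
        "activation": "elu",
        "init_noise_std": 1.0,
    },
    "algorithm": {
        "class_name": "PPO",
        # remaining PPO kwargs (learning_rate, gamma, lam, clip_param, ...) as in a standard rsl_rl config
    },
}

runner = OnPolicyRunnerConv2d(env, train_cfg, log_dir="logs/conv2d_ppo", device="cuda:0")
runner.learn(num_learning_iterations=1500)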