Skip to content

Commit 56ef25e

Browse files
committed
Added low-level API example for multiple experiments with rliable eval
1 parent 3f00231 commit 56ef25e

File tree

1 file changed

+268
-0
lines changed

1 file changed

+268
-0
lines changed
Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
"""
2+
A low-level counterpart of `mujoco_ppo_hl_multi.py`. The directory structure of the persisted logs
3+
mirrors that of the high-level example.
4+
5+
Rollout of multiple experiments with different seeds is done manually in a for loop,
6+
the results are aggregated and evaluated with rliable.
7+
"""
8+
9+
import argparse
10+
import os
11+
from pathlib import Path
12+
13+
import numpy as np
14+
import torch
15+
from mujoco_env import make_mujoco_env
16+
from sensai.util import logging
17+
from sensai.util.logging import datetime_tag
18+
from torch import nn
19+
from torch.distributions import Distribution, Independent, Normal
20+
21+
from tianshou.algorithm import PPO
22+
from tianshou.algorithm.algorithm_base import Algorithm
23+
from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy
24+
from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear
25+
from tianshou.data import (
26+
Collector,
27+
CollectStats,
28+
ReplayBuffer,
29+
VectorReplayBuffer,
30+
)
31+
from tianshou.evaluation.rliable_evaluation_hl import RLiableExperimentResult
32+
from tianshou.highlevel.experiment import Experiment
33+
from tianshou.highlevel.logger import LoggerFactoryDefault
34+
from tianshou.trainer import OnPolicyTrainerParams
35+
from tianshou.utils.net.common import ActorCritic, Net
36+
from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic
37+
38+
# Single timestamp captured once at import time and shared by all experiments
# launched in this script run, so the per-seed log directories of every
# experiment land under the same parent directory (which rliable later
# aggregates over).
DATETIME_TAG = datetime_tag()
39+
40+
41+
def get_args() -> argparse.Namespace:
    """Parse the command-line configuration for the PPO Mujoco experiments.

    :return: a namespace holding environment/experiment setup, PPO
        hyperparameters, and logging/persistence options.
    """
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    p = argparse.ArgumentParser()
    # environment and experiment setup
    p.add_argument("--task", type=str, default="Ant-v4")
    p.add_argument("--num_experiments", type=int, default=5)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--buffer_size", type=int, default=4096)
    p.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64])
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--gamma", type=float, default=0.99)
    p.add_argument("--epoch", type=int, default=1)
    p.add_argument("--epoch_num_steps", type=int, default=2000)
    p.add_argument("--collection_step_num_env_steps", type=int, default=2048)
    p.add_argument("--update_step_num_repetitions", type=int, default=10)
    p.add_argument("--batch_size", type=int, default=64)
    p.add_argument("--num_train_envs", type=int, default=8)
    p.add_argument("--num_test_envs", type=int, default=10)
    # ppo special
    p.add_argument("--return_scaling", type=int, default=True)
    # In theory, `vf-coef` will not make any difference if using Adam optimizer.
    p.add_argument("--vf_coef", type=float, default=0.25)
    p.add_argument("--ent_coef", type=float, default=0.0)
    p.add_argument("--gae_lambda", type=float, default=0.95)
    p.add_argument("--bound_action_method", type=str, default="clip")
    p.add_argument("--lr_decay", type=int, default=True)
    p.add_argument("--max_grad_norm", type=float, default=0.5)
    p.add_argument("--eps_clip", type=float, default=0.2)
    p.add_argument("--dual_clip", type=float, default=None)
    p.add_argument("--value_clip", type=int, default=0)
    p.add_argument("--advantage_normalization", type=int, default=0)
    p.add_argument("--recompute_adv", type=int, default=1)
    # logging and persistence
    p.add_argument("--logdir", type=str, default="log")
    p.add_argument("--render", type=float, default=0.0)
    p.add_argument("--device", type=str, default=default_device)
    p.add_argument("--resume_path", type=str, default=None)
    p.add_argument("--resume_id", type=str, default=None)
    p.add_argument(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb"],
    )
    p.add_argument("--wandb_project", type=str, default="mujoco.benchmark")
    p.add_argument(
        "--watch",
        default=False,
        action="store_true",
        help="watch the play of pre-trained policy only",
    )
    return p.parse_args()
94+
95+
96+
def get_persistence_dir(args: argparse.Namespace) -> str:
    """Return the log directory for one experiment run.

    The layout is ``<logdir>/<task>/ppo_<timestamp>/<seeding-info>``; all seeds of
    this script run share the same ``ppo_<timestamp>`` parent, which is what the
    rliable evaluation at the end of the script loads from disk.
    """
    seed_component = Experiment.seeding_info_str_static(args.seed)
    return os.path.join(args.logdir, args.task, f"ppo_{DATETIME_TAG}", seed_component)
102+
103+
104+
def main(args: argparse.Namespace | None = None) -> None:
    """Run a single PPO training experiment on a Mujoco task for one seed.

    Logs are persisted under :func:`get_persistence_dir`, mirroring the directory
    structure of the high-level example so that the runs of all seeds can later be
    aggregated with rliable.

    :param args: the parsed CLI configuration; if None, it is parsed from
        ``sys.argv`` at call time. (The original signature used
        ``args = get_args()`` as the parameter default, which made argument
        parsing a side effect of merely importing this module — defaults are
        evaluated once at function-definition time.)
    """
    if args is None:
        args = get_args()

    print("Creating envs")
    env, train_envs, test_envs = make_mujoco_env(
        args.task,
        args.seed,
        args.num_train_envs,
        args.num_test_envs,
        obs_norm=True,
    )
    # Derive shapes/bounds from the env and stash them on args so they are
    # included in the logged config (config_dict=vars(args) below).
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # model: separate tanh MLP torsos for actor and critic
    net_a = Net(
        state_shape=args.state_shape,
        hidden_sizes=args.hidden_sizes,
        activation=nn.Tanh,
    )
    actor = ContinuousActorProbabilistic(
        preprocess_net=net_a,
        action_shape=args.action_shape,
        unbounded=True,
    ).to(args.device)
    net_c = Net(
        state_shape=args.state_shape,
        hidden_sizes=args.hidden_sizes,
        activation=nn.Tanh,
    )
    critic = ContinuousCritic(preprocess_net=net_c).to(args.device)
    actor_critic = ActorCritic(actor, critic)

    torch.nn.init.constant_(actor.sigma_param, -0.5)
    for m in actor_critic.modules():
        if isinstance(m, torch.nn.Linear):
            # orthogonal initialization
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            torch.nn.init.zeros_(m.bias)
    # do last policy layer scaling, this will make initial actions have (close to)
    # 0 mean and std, and will help boost performances,
    # see https://arxiv.org/abs/2006.05990, Fig.24 for details
    for m in actor.mu.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.zeros_(m.bias)
            m.weight.data.copy_(0.01 * m.weight.data)

    optim = AdamOptimizerFactory(lr=args.lr)
    if args.lr_decay:
        # linearly anneal the learning rate over the full training budget
        optim.with_lr_scheduler_factory(
            LRSchedulerFactoryLinear(
                max_epochs=args.epoch,
                epoch_num_steps=args.epoch_num_steps,
                collection_step_num_env_steps=args.collection_step_num_env_steps,
            )
        )

    def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
        """Build a diagonal Gaussian action distribution from (mean, std)."""
        loc, scale = loc_scale
        return Independent(Normal(loc, scale), 1)

    policy = ProbabilisticActorPolicy(
        actor=actor,
        dist_fn=dist,
        action_scaling=True,
        action_bound_method=args.bound_action_method,
        action_space=env.action_space,
    )
    algorithm: PPO = PPO(
        policy=policy,
        critic=critic,
        optim=optim,
        gamma=args.gamma,
        gae_lambda=args.gae_lambda,
        max_grad_norm=args.max_grad_norm,
        vf_coef=args.vf_coef,
        ent_coef=args.ent_coef,
        return_scaling=args.return_scaling,
        eps_clip=args.eps_clip,
        value_clip=args.value_clip,
        dual_clip=args.dual_clip,
        advantage_normalization=args.advantage_normalization,
        recompute_advantage=args.recompute_adv,
    )

    # load a previous policy (model weights plus observation-normalization stats)
    if args.resume_path:
        ckpt = torch.load(args.resume_path, map_location=args.device)
        algorithm.load_state_dict(ckpt["model"])
        train_envs.set_obs_rms(ckpt["obs_rms"])
        test_envs.set_obs_rms(ckpt["obs_rms"])
        print("Loaded agent from: ", args.resume_path)

    # collector
    buffer: VectorReplayBuffer | ReplayBuffer
    if args.num_train_envs > 1:
        buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
    else:
        buffer = ReplayBuffer(args.buffer_size)
    train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True)
    test_collector = Collector[CollectStats](algorithm, test_envs)

    # log
    persistence_dir = get_persistence_dir(args)
    experiment_subpath = Path(persistence_dir).relative_to(args.logdir)

    # logger
    logger_factory = LoggerFactoryDefault()
    if args.logger == "wandb":
        logger_factory.logger_type = "wandb"
        logger_factory.wandb_project = args.wandb_project
    else:
        logger_factory.logger_type = "tensorboard"

    logger = logger_factory.create_logger(
        log_dir=persistence_dir,
        experiment_name=str(experiment_subpath),
        run_id=args.resume_id,
        config_dict=vars(args),
    )

    def save_best_fn(policy: Algorithm) -> None:
        """Persist the current best model together with the obs-normalization stats."""
        # NOTE(review): assumes persistence_dir already exists (presumably created
        # by the logger factory) — verify before relying on this in new setups.
        state = {"model": policy.state_dict(), "obs_rms": train_envs.get_obs_rms()}
        torch.save(state, os.path.join(persistence_dir, "policy.pth"))

    print("Running the training")
    algorithm.run_training(
        OnPolicyTrainerParams(
            train_collector=train_collector,
            test_collector=test_collector,
            max_epochs=args.epoch,
            epoch_num_steps=args.epoch_num_steps,
            update_step_num_repetitions=args.update_step_num_repetitions,
            test_step_num_episodes=args.num_test_envs,
            batch_size=args.batch_size,
            collection_step_num_env_steps=args.collection_step_num_env_steps,
            save_best_fn=save_best_fn,
            logger=logger,
            test_in_train=False,
        )
    )
250+
251+
252+
if __name__ == "__main__":
    args = get_args()
    base_seed = args.seed

    # Roll out the experiments sequentially, varying only the seed.
    # Parallelization (e.g., via joblib) is possible but usually gains little on a
    # single machine: each experiment already uses multiple cores for its
    # parallelized environment rollouts.
    for offset in range(args.num_experiments):
        args.seed = base_seed + offset
        print(f"Running experiment {offset+1}/{args.num_experiments} with seed {args.seed}")
        logging.run_main(lambda: main(args=args), level=logging.INFO)

    # Aggregate and evaluate the per-seed results with rliable. The parent of the
    # last run's persistence dir contains the subdirectories of all seeds.
    all_seeds_dir = str(Path(get_persistence_dir(args)).parent)
    result = RLiableExperimentResult.load_from_disk(all_seeds_dir)
    result.eval_results(save_plots=True)

0 commit comments

Comments
 (0)