
Commit 9f4b8c9

Refactored mujoco low-level examples to use jsonargparse
1 parent 203ea48 commit 9f4b8c9
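
The change replaces each example's hand-rolled argparse boilerplate with a typed main() whose keyword arguments define the CLI. In this commit the actual entry point is sensai.util.logging.run_cli; the sketch below illustrates only the underlying jsonargparse pattern, using a hypothetical train() function rather than the real example code.

#!/usr/bin/env python3
# Minimal sketch of the jsonargparse pattern this commit moves to
# (hypothetical train() function, not the tianshou example itself).
from jsonargparse import CLI


def train(
    task: str = "Ant-v4",
    seed: int = 0,
    lr: float = 7e-4,
    hidden_sizes: list | None = None,
) -> None:
    # Mutable defaults are resolved inside the body, as in the refactored scripts.
    if hidden_sizes is None:
        hidden_sizes = [64, 64]
    print(f"task={task} seed={seed} lr={lr} hidden_sizes={hidden_sizes}")


if __name__ == "__main__":
    # jsonargparse derives --task, --seed, --lr and --hidden_sizes (with types and
    # defaults) from the signature, e.g.: python train.py --task Hopper-v4 --lr 3e-4
    CLI(train)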

File tree

9 files changed (+900, -975 lines changed)


examples/mujoco/mujoco_a2c.py

Lines changed: 98 additions & 105 deletions
@@ -1,13 +1,14 @@
 #!/usr/bin/env python3

-import argparse
 import datetime
 import os
 import pprint
+from typing import Literal

 import numpy as np
 import torch
 from mujoco_env import make_mujoco_env
+from sensai.util import logging
 from torch import nn
 from torch.distributions import Distribution, Independent, Normal

@@ -21,90 +22,82 @@
 from tianshou.utils.net.common import ActorCritic, Net
 from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic

+log = logging.getLogger(__name__)
+
+
+def main(
+    task: str = "Ant-v4",
+    persistence_base_dir: str = "log",
+    seed: int = 0,
+    buffer_size: int = 4096,
+    hidden_sizes: list | None = None,
+    lr: float = 7e-4,
+    gamma: float = 0.99,
+    epoch: int = 100,
+    epoch_num_steps: int = 30000,
+    collection_step_num_env_steps: int = 80,
+    update_step_num_repetitions: int = 1,
+    batch_size: int | None = None,
+    num_train_envs: int = 16,
+    num_test_envs: int = 10,
+    return_scaling: bool = True,
+    vf_coef: float = 0.5,
+    ent_coef: float = 0.01,
+    gae_lambda: float = 0.95,
+    action_bound_method: Literal["clip", "tanh"] | None = "clip",
+    lr_decay: bool = True,
+    max_grad_norm: float = 0.5,
+    render: float = 0.0,
+    device: str | None = None,
+    resume_path: str | None = None,
+    resume_id: str | None = None,
+    logger_type: str = "tensorboard",
+    wandb_project: str = "mujoco.benchmark",
+    watch: bool = False,
+) -> None:
+    # Set defaults for mutable arguments
+    if hidden_sizes is None:
+        hidden_sizes = [64, 64]
+    if device is None:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # Get all local variables as config (excluding internal/temporary ones)
+    params_log_info = locals()
+    log.info(f"Starting training with config:\n{params_log_info}")

-def get_args() -> argparse.Namespace:
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="Ant-v4")
-    parser.add_argument("--seed", type=int, default=0)
-    parser.add_argument("--buffer_size", type=int, default=4096)
-    parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64])
-    parser.add_argument("--lr", type=float, default=7e-4)
-    parser.add_argument("--gamma", type=float, default=0.99)
-    parser.add_argument("--epoch", type=int, default=100)
-    parser.add_argument("--epoch_num_steps", type=int, default=30000)
-    parser.add_argument("--collection_step_num_env_steps", type=int, default=80)
-    parser.add_argument("--update_step_num_repetitions", type=int, default=1)
-    # batch-size >> step-per-collect means calculating all data in one singe forward.
-    parser.add_argument("--batch_size", type=int, default=None)
-    parser.add_argument("--num_train_envs", type=int, default=16)
-    parser.add_argument("--num_test_envs", type=int, default=10)
-    # a2c special
-    parser.add_argument("--return_scaling", type=int, default=True)
-    parser.add_argument("--vf_coef", type=float, default=0.5)
-    parser.add_argument("--ent_coef", type=float, default=0.01)
-    parser.add_argument("--gae_lambda", type=float, default=0.95)
-    parser.add_argument("--bound_action_method", type=str, default="clip")
-    parser.add_argument("--lr_decay", type=int, default=True)
-    parser.add_argument("--max_grad_norm", type=float, default=0.5)
-    parser.add_argument("--logdir", type=str, default="log")
-    parser.add_argument("--render", type=float, default=0.0)
-    parser.add_argument(
-        "--device",
-        type=str,
-        default="cuda" if torch.cuda.is_available() else "cpu",
-    )
-    parser.add_argument("--resume_path", type=str, default=None)
-    parser.add_argument("--resume_id", type=str, default=None)
-    parser.add_argument(
-        "--logger",
-        type=str,
-        default="tensorboard",
-        choices=["tensorboard", "wandb"],
-    )
-    parser.add_argument("--wandb_project", type=str, default="mujoco.benchmark")
-    parser.add_argument(
-        "--watch",
-        default=False,
-        action="store_true",
-        help="watch the play of pre-trained policy only",
-    )
-    return parser.parse_args()
-
-
-def main(args: argparse.Namespace = get_args()) -> None:
     env, train_envs, test_envs = make_mujoco_env(
-        args.task,
-        args.seed,
-        args.num_train_envs,
-        args.num_test_envs,
+        task,
+        seed,
+        num_train_envs,
+        num_test_envs,
         obs_norm=True,
     )
-    args.state_shape = env.observation_space.shape or env.observation_space.n
-    args.action_shape = env.action_space.shape or env.action_space.n
-    args.max_action = env.action_space.high[0]
-    print("Observations shape:", args.state_shape)
-    print("Actions shape:", args.action_shape)
-    print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
+    state_shape = env.observation_space.shape or env.observation_space.n
+    action_shape = env.action_space.shape or env.action_space.n
+    max_action = env.action_space.high[0]
+    log.info(f"Observations shape: {state_shape}")
+    log.info(f"Actions shape: {action_shape}")
+    log.info(f"Action range: {np.min(env.action_space.low)}, {np.max(env.action_space.high)}")
     # seed
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
     # model
     net_a = Net(
-        state_shape=args.state_shape,
-        hidden_sizes=args.hidden_sizes,
+        state_shape=state_shape,
+        hidden_sizes=hidden_sizes,
         activation=nn.Tanh,
     )
     actor = ContinuousActorProbabilistic(
         preprocess_net=net_a,
-        action_shape=args.action_shape,
+        action_shape=action_shape,
         unbounded=True,
-    ).to(args.device)
+    ).to(device)
     net_c = Net(
-        state_shape=args.state_shape,
-        hidden_sizes=args.hidden_sizes,
+        state_shape=state_shape,
+        hidden_sizes=hidden_sizes,
         activation=nn.Tanh,
     )
-    critic = ContinuousCritic(preprocess_net=net_c).to(args.device)
+    critic = ContinuousCritic(preprocess_net=net_c).to(device)
     actor_critic = ActorCritic(actor, critic)

     torch.nn.init.constant_(actor.sigma_param, -0.5)
@@ -122,17 +115,17 @@ def main(args: argparse.Namespace = get_args()) -> None:
             m.weight.data.copy_(0.01 * m.weight.data)

     optim = RMSpropOptimizerFactory(
-        lr=args.lr,
+        lr=lr,
         eps=1e-5,
         alpha=0.99,
     )

-    if args.lr_decay:
+    if lr_decay:
         optim.with_lr_scheduler_factory(
             LRSchedulerFactoryLinear(
-                max_epochs=args.epoch,
-                epoch_num_steps=args.epoch_num_steps,
-                collection_step_num_env_steps=args.collection_step_num_env_steps,
+                max_epochs=epoch,
+                epoch_num_steps=epoch_num_steps,
+                collection_step_num_env_steps=collection_step_num_env_steps,
             )
         )

@@ -144,75 +137,75 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
         actor=actor,
         dist_fn=dist,
         action_scaling=True,
-        action_bound_method=args.bound_action_method,
+        action_bound_method=action_bound_method,
         action_space=env.action_space,
     )
     algorithm: A2C = A2C(
         policy=policy,
         critic=critic,
         optim=optim,
-        gamma=args.gamma,
-        gae_lambda=args.gae_lambda,
-        max_grad_norm=args.max_grad_norm,
-        vf_coef=args.vf_coef,
-        ent_coef=args.ent_coef,
-        return_scaling=args.return_scaling,
+        gamma=gamma,
+        gae_lambda=gae_lambda,
+        max_grad_norm=max_grad_norm,
+        vf_coef=vf_coef,
+        ent_coef=ent_coef,
+        return_scaling=return_scaling,
     )

     # load a previous policy
-    if args.resume_path:
-        ckpt = torch.load(args.resume_path, map_location=args.device)
+    if resume_path:
+        ckpt = torch.load(resume_path, map_location=device)
         algorithm.load_state_dict(ckpt["model"])
         train_envs.set_obs_rms(ckpt["obs_rms"])
         test_envs.set_obs_rms(ckpt["obs_rms"])
-        print("Loaded agent from: ", args.resume_path)
+        print("Loaded agent from: ", resume_path)

     # collector
     buffer: VectorReplayBuffer | ReplayBuffer
-    if args.num_train_envs > 1:
-        buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
+    if num_train_envs > 1:
+        buffer = VectorReplayBuffer(buffer_size, len(train_envs))
     else:
-        buffer = ReplayBuffer(args.buffer_size)
+        buffer = ReplayBuffer(buffer_size)
     train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True)
     test_collector = Collector[CollectStats](algorithm, test_envs)

     # log
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
-    args.algo_name = "a2c"
-    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
-    log_path = os.path.join(args.logdir, log_name)
+    algo_name = "a2c"
+    log_name = os.path.join(task, algo_name, str(seed), now)
+    log_path = os.path.join(persistence_base_dir, log_name)

     # logger
     logger_factory = LoggerFactoryDefault()
-    if args.logger == "wandb":
+    if logger_type == "wandb":
         logger_factory.logger_type = "wandb"
-        logger_factory.wandb_project = args.wandb_project
+        logger_factory.wandb_project = wandb_project
     else:
         logger_factory.logger_type = "tensorboard"

     logger = logger_factory.create_logger(
         log_dir=log_path,
         experiment_name=log_name,
-        run_id=args.resume_id,
-        config_dict=vars(args),
+        run_id=resume_id,
+        config_dict=params_log_info,
     )

     def save_best_fn(policy: Algorithm) -> None:
         state = {"model": policy.state_dict(), "obs_rms": train_envs.get_obs_rms()}
         torch.save(state, os.path.join(log_path, "policy.pth"))

-    if not args.watch:
+    if not watch:
         # train
         result = algorithm.run_training(
             OnPolicyTrainerParams(
                 train_collector=train_collector,
                 test_collector=test_collector,
-                max_epochs=args.epoch,
-                epoch_num_steps=args.epoch_num_steps,
-                update_step_num_repetitions=args.update_step_num_repetitions,
-                test_step_num_episodes=args.num_test_envs,
-                batch_size=args.batch_size,
-                collection_step_num_env_steps=args.collection_step_num_env_steps,
+                max_epochs=epoch,
+                epoch_num_steps=epoch_num_steps,
+                update_step_num_repetitions=update_step_num_repetitions,
+                test_step_num_episodes=num_test_envs,
+                batch_size=batch_size,
+                collection_step_num_env_steps=collection_step_num_env_steps,
                 save_best_fn=save_best_fn,
                 logger=logger,
                 test_in_train=False,
@@ -221,11 +214,11 @@ def save_best_fn(policy: Algorithm) -> None:
         pprint.pprint(result)

     # Let's watch its performance!
-    test_envs.seed(args.seed)
+    test_envs.seed(seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render)
+    collector_stats = test_collector.collect(n_episode=num_test_envs, render=render)
     print(collector_stats)


 if __name__ == "__main__":
-    main()
+    result = logging.run_cli(main, level=logging.INFO)
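
run_cli itself comes from sensai.util.logging and is not shown in this diff; assuming it is essentially a jsonargparse wrapper that also configures logging at the given level, its role can be approximated as below (the helper name and behavior are an assumption for illustration, not the sensai API).

# Hypothetical approximation of a run_cli-style helper built on jsonargparse;
# the real sensai.util.logging.run_cli may differ.
import logging
from typing import Any, Callable

from jsonargparse import CLI


def run_cli_sketch(main_fn: Callable[..., Any], level: int = logging.INFO) -> Any:
    # Configure root logging, then expose main_fn's parameters as CLI options and invoke it.
    logging.basicConfig(level=level)
    return CLI(main_fn)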
