
Commit 8cf6576

Refactored atari low-level examples to use jsonargparse
1 parent b47ca52 commit 8cf6576

File tree

8 files changed: +954, -1037 lines


examples/atari/atari_c51.py

Lines changed: 108 additions & 114 deletions
@@ -1,11 +1,13 @@
-import argparse
+#!/usr/bin/env python3
+
 import datetime
 import os
 import pprint
 import sys

 import numpy as np
 import torch
+from sensai.util import logging

 from tianshou.algorithm import C51
 from tianshou.algorithm.algorithm_base import Algorithm
@@ -17,111 +19,103 @@
 from tianshou.highlevel.logger import LoggerFactoryDefault
 from tianshou.trainer import OffPolicyTrainerParams

+log = logging.getLogger(__name__)
+
+
+def main(
+    task: str = "PongNoFrameskip-v4",
+    seed: int = 0,
+    scale_obs: int = 0,
+    eps_test: float = 0.005,
+    eps_train: float = 1.0,
+    eps_train_final: float = 0.05,
+    buffer_size: int = 100000,
+    lr: float = 0.0001,
+    gamma: float = 0.99,
+    num_atoms: int = 51,
+    v_min: float = -10.0,
+    v_max: float = 10.0,
+    n_step: int = 3,
+    target_update_freq: int = 500,
+    epoch: int = 100,
+    epoch_num_steps: int = 100000,
+    collection_step_num_env_steps: int = 10,
+    update_per_step: float = 0.1,
+    batch_size: int = 32,
+    num_train_envs: int = 10,
+    num_test_envs: int = 10,
+    persistence_base_dir: str = "log",
+    render: float = 0.0,
+    device: str | None = None,
+    frames_stack: int = 4,
+    resume_path: str | None = None,
+    resume_id: str | None = None,
+    logger_type: str = "tensorboard",
+    wandb_project: str = "atari.benchmark",
+    watch: bool = False,
+    save_buffer_name: str | None = None,
+) -> None:
+    # Set defaults for mutable arguments
+    if device is None:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # Get all local variables as config (excluding internal/temporary ones)
+    params_log_info = locals()
+    log.info(f"Starting training with config:\n{params_log_info}")

-def get_args() -> argparse.Namespace:
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
-    parser.add_argument("--seed", type=int, default=0)
-    parser.add_argument("--scale_obs", type=int, default=0)
-    parser.add_argument("--eps_test", type=float, default=0.005)
-    parser.add_argument("--eps_train", type=float, default=1.0)
-    parser.add_argument("--eps_train_final", type=float, default=0.05)
-    parser.add_argument("--buffer_size", type=int, default=100000)
-    parser.add_argument("--lr", type=float, default=0.0001)
-    parser.add_argument("--gamma", type=float, default=0.99)
-    parser.add_argument("--num_atoms", type=int, default=51)
-    parser.add_argument("--v_min", type=float, default=-10.0)
-    parser.add_argument("--v_max", type=float, default=10.0)
-    parser.add_argument("--n_step", type=int, default=3)
-    parser.add_argument("--target_update_freq", type=int, default=500)
-    parser.add_argument("--epoch", type=int, default=100)
-    parser.add_argument("--epoch_num_steps", type=int, default=100000)
-    parser.add_argument("--collection_step_num_env_steps", type=int, default=10)
-    parser.add_argument("--update_per_step", type=float, default=0.1)
-    parser.add_argument("--batch_size", type=int, default=32)
-    parser.add_argument("--num_train_envs", type=int, default=10)
-    parser.add_argument("--num_test_envs", type=int, default=10)
-    parser.add_argument("--logdir", type=str, default="log")
-    parser.add_argument("--render", type=float, default=0.0)
-    parser.add_argument(
-        "--device",
-        type=str,
-        default="cuda" if torch.cuda.is_available() else "cpu",
-    )
-    parser.add_argument("--frames_stack", type=int, default=4)
-    parser.add_argument("--resume_path", type=str, default=None)
-    parser.add_argument("--resume_id", type=str, default=None)
-    parser.add_argument(
-        "--logger",
-        type=str,
-        default="tensorboard",
-        choices=["tensorboard", "wandb"],
-    )
-    parser.add_argument("--wandb_project", type=str, default="atari.benchmark")
-    parser.add_argument(
-        "--watch",
-        default=False,
-        action="store_true",
-        help="watch the play of pre-trained policy only",
-    )
-    parser.add_argument("--save_buffer_name", type=str, default=None)
-    return parser.parse_args()
-
-
-def main(args: argparse.Namespace = get_args()) -> None:
     env, train_envs, test_envs = make_atari_env(
-        args.task,
-        args.seed,
-        args.num_train_envs,
-        args.num_test_envs,
-        scale=args.scale_obs,
-        frame_stack=args.frames_stack,
+        task,
+        seed,
+        num_train_envs,
+        num_test_envs,
+        scale=scale_obs,
+        frame_stack=frames_stack,
     )
-    args.state_shape = env.observation_space.shape or env.observation_space.n  # type: ignore
-    args.action_shape = env.action_space.shape or env.action_space.n  # type: ignore
+    state_shape = env.observation_space.shape or env.observation_space.n  # type: ignore
+    action_shape = env.action_space.shape or env.action_space.n  # type: ignore
     # should be N_FRAMES x H x W
-    print("Observations shape:", args.state_shape)
-    print("Actions shape:", args.action_shape)
+    log.info(f"Observations shape: {state_shape}")
+    log.info(f"Actions shape: {action_shape}")
     # seed
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)

     # define model
-    c, h, w = args.state_shape
-    net = C51Net(c=c, h=h, w=w, action_shape=args.action_shape, num_atoms=args.num_atoms)
+    c, h, w = state_shape
+    net = C51Net(c=c, h=h, w=w, action_shape=action_shape, num_atoms=num_atoms)

     # define policy and algorithm
-    optim = AdamOptimizerFactory(lr=args.lr)
+    optim = AdamOptimizerFactory(lr=lr)
     policy = C51Policy(
         model=net,
         action_space=env.action_space,
-        num_atoms=args.num_atoms,
-        v_min=args.v_min,
-        v_max=args.v_max,
-        eps_training=args.eps_train,
-        eps_inference=args.eps_test,
+        num_atoms=num_atoms,
+        v_min=v_min,
+        v_max=v_max,
+        eps_training=eps_train,
+        eps_inference=eps_test,
     )
     algorithm: C51 = C51(
         policy=policy,
         optim=optim,
-        gamma=args.gamma,
-        n_step_return_horizon=args.n_step,
-        target_update_freq=args.target_update_freq,
-    ).to(args.device)
+        gamma=gamma,
+        n_step_return_horizon=n_step,
+        target_update_freq=target_update_freq,
+    ).to(device)

     # load a previous model
-    if args.resume_path:
-        algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device))
-        print("Loaded agent from: ", args.resume_path)
+    if resume_path:
+        algorithm.load_state_dict(torch.load(resume_path, map_location=device))
+        log.info(f"Loaded agent from: {resume_path}")

     # replay buffer: `save_last_obs` and `stack_num` can be removed together
     # when you have enough RAM
     buffer = VectorReplayBuffer(
-        args.buffer_size,
+        buffer_size,
         buffer_num=len(train_envs),
         ignore_obs_next=True,
         save_only_last_obs=True,
-        stack_num=args.frames_stack,
+        stack_num=frames_stack,
     )

     # collectors
@@ -130,23 +124,23 @@ def main(args: argparse.Namespace = get_args()) -> None:

     # log
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
-    args.algo_name = "c51"
-    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
-    log_path = os.path.join(args.logdir, log_name)
+    algo_name = "c51"
+    log_name = os.path.join(task, algo_name, str(seed), now)
+    log_path = os.path.join(persistence_base_dir, log_name)

     # logger
     logger_factory = LoggerFactoryDefault()
-    if args.logger == "wandb":
+    if logger_type == "wandb":
         logger_factory.logger_type = "wandb"
-        logger_factory.wandb_project = args.wandb_project
+        logger_factory.wandb_project = wandb_project
     else:
         logger_factory.logger_type = "tensorboard"

     logger = logger_factory.create_logger(
         log_dir=log_path,
         experiment_name=log_name,
-        run_id=args.resume_id,
-        config_dict=vars(args),
+        run_id=resume_id,
+        config_dict=params_log_info,
     )

     def save_best_fn(policy: Algorithm) -> None:
@@ -155,74 +149,74 @@ def save_best_fn(policy: Algorithm) -> None:
     def stop_fn(mean_rewards: float) -> bool:
         if env.spec.reward_threshold:  # type: ignore
             return mean_rewards >= env.spec.reward_threshold  # type: ignore
-        if "Pong" in args.task:
+        if "Pong" in task:
             return mean_rewards >= 20
         return False

     def train_fn(epoch: int, env_step: int) -> None:
         # nature DQN setting, linear decay in the first 1M steps
         if env_step <= 1e6:
-            eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final)
+            eps = eps_train - env_step / 1e6 * (eps_train - eps_train_final)
         else:
-            eps = args.eps_train_final
+            eps = eps_train_final
         policy.set_eps_training(eps)
         if env_step % 1000 == 0:
             logger.write("train/env_step", env_step, {"train/eps": eps})

-    def watch() -> None:
-        print("Setup test envs ...")
-        test_envs.seed(args.seed)
-        if args.save_buffer_name:
-            print(f"Generate buffer with size {args.buffer_size}")
+    def watch_fn() -> None:
+        log.info("Setup test envs ...")
+        test_envs.seed(seed)
+        if save_buffer_name:
+            log.info(f"Generate buffer with size {buffer_size}")
             buffer = VectorReplayBuffer(
-                args.buffer_size,
+                buffer_size,
                 buffer_num=len(test_envs),
                 ignore_obs_next=True,
                 save_only_last_obs=True,
-                stack_num=args.frames_stack,
+                stack_num=frames_stack,
             )
             collector = Collector[CollectStats](
                 algorithm, test_envs, buffer, exploration_noise=True
            )
-            result = collector.collect(n_step=args.buffer_size, reset_before_collect=True)
-            print(f"Save buffer into {args.save_buffer_name}")
+            result = collector.collect(n_step=buffer_size, reset_before_collect=True)
+            log.info(f"Save buffer into {save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
-            buffer.save_hdf5(args.save_buffer_name)
+            buffer.save_hdf5(save_buffer_name)
         else:
-            print("Testing agent ...")
+            log.info("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(n_episode=args.num_test_envs, render=args.render)
+            result = test_collector.collect(n_episode=num_test_envs, render=render)
         result.pprint_asdict()

-    if args.watch:
-        watch()
+    if watch:
+        watch_fn()
         sys.exit(0)

     # test train_collector and start filling replay buffer
     train_collector.reset()
-    train_collector.collect(n_step=args.batch_size * args.num_train_envs)
+    train_collector.collect(n_step=batch_size * num_train_envs)
     # trainer
     result = algorithm.run_training(
         OffPolicyTrainerParams(
             train_collector=train_collector,
             test_collector=test_collector,
-            max_epochs=args.epoch,
-            epoch_num_steps=args.epoch_num_steps,
-            collection_step_num_env_steps=args.collection_step_num_env_steps,
-            test_step_num_episodes=args.num_test_envs,
-            batch_size=args.batch_size,
+            max_epochs=epoch,
+            epoch_num_steps=epoch_num_steps,
+            collection_step_num_env_steps=collection_step_num_env_steps,
+            test_step_num_episodes=num_test_envs,
+            batch_size=batch_size,
             train_fn=train_fn,
             stop_fn=stop_fn,
             save_best_fn=save_best_fn,
             logger=logger,
-            update_step_num_gradient_steps_per_sample=args.update_per_step,
+            update_step_num_gradient_steps_per_sample=update_per_step,
             test_in_train=False,
         )
     )

     pprint.pprint(result)
-    watch()
+    watch_fn()


 if __name__ == "__main__":
-    main(get_args())
+    result = logging.run_cli(main, level=logging.INFO)
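
The new entry point hands main to sensai.util.logging.run_cli, which, per the commit message, uses jsonargparse to derive the command-line interface from the function signature instead of a hand-written argparse parser: defaults now live in the parameter list, logdir is renamed to persistence_base_dir, logger to logger_type, and the device default is resolved inside main. Below is a minimal sketch of that mechanism using jsonargparse directly; the train function and its parameters are illustrative placeholders, not part of this commit, and run_cli's exact behavior may differ.

from jsonargparse import CLI


def train(
    task: str = "PongNoFrameskip-v4",
    lr: float = 0.0001,
    epoch: int = 100,
) -> None:
    # Each annotated parameter with a default becomes an optional flag
    # (--task, --lr, --epoch); values are parsed and cast to the annotated type.
    print(f"task={task}, lr={lr}, epoch={epoch}")


if __name__ == "__main__":
    # Builds a parser from train's signature, parses sys.argv, and calls train,
    # e.g.: python sketch.py --task BreakoutNoFrameskip-v4 --epoch 50
    CLI(train)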
