-import argparse
+#!/usr/bin/env python3
+
 import datetime
 import os
 import pprint
 import sys

 import numpy as np
 import torch
+from sensai.util import logging

 from tianshou.algorithm import C51
 from tianshou.algorithm.algorithm_base import Algorithm
 from tianshou.highlevel.logger import LoggerFactoryDefault
 from tianshou.trainer import OffPolicyTrainerParams

+log = logging.getLogger(__name__)
+
+
+def main(
+    task: str = "PongNoFrameskip-v4",
+    seed: int = 0,
+    scale_obs: int = 0,
+    eps_test: float = 0.005,
+    eps_train: float = 1.0,
+    eps_train_final: float = 0.05,
+    buffer_size: int = 100000,
+    lr: float = 0.0001,
+    gamma: float = 0.99,
+    num_atoms: int = 51,
+    v_min: float = -10.0,
+    v_max: float = 10.0,
+    n_step: int = 3,
+    target_update_freq: int = 500,
+    epoch: int = 100,
+    epoch_num_steps: int = 100000,
+    collection_step_num_env_steps: int = 10,
+    update_per_step: float = 0.1,
+    batch_size: int = 32,
+    num_train_envs: int = 10,
+    num_test_envs: int = 10,
+    persistence_base_dir: str = "log",
+    render: float = 0.0,
+    device: str | None = None,
+    frames_stack: int = 4,
+    resume_path: str | None = None,
+    resume_id: str | None = None,
+    logger_type: str = "tensorboard",
+    wandb_project: str = "atari.benchmark",
+    watch: bool = False,
+    save_buffer_name: str | None = None,
+) -> None:
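+    # Train C51 on the given Atari task; the keyword arguments above make up
+    # the complete run configuration.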
+    # Resolve the device default at call time (it depends on CUDA availability)
+    if device is None:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # Capture the function arguments as the run configuration; at this point
+    # locals() contains only the parameters above
+    params_log_info = locals()
+    log.info(f"Starting training with config:\n{params_log_info}")

-def get_args() -> argparse.Namespace:
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
-    parser.add_argument("--seed", type=int, default=0)
-    parser.add_argument("--scale_obs", type=int, default=0)
-    parser.add_argument("--eps_test", type=float, default=0.005)
-    parser.add_argument("--eps_train", type=float, default=1.0)
-    parser.add_argument("--eps_train_final", type=float, default=0.05)
-    parser.add_argument("--buffer_size", type=int, default=100000)
-    parser.add_argument("--lr", type=float, default=0.0001)
-    parser.add_argument("--gamma", type=float, default=0.99)
-    parser.add_argument("--num_atoms", type=int, default=51)
-    parser.add_argument("--v_min", type=float, default=-10.0)
-    parser.add_argument("--v_max", type=float, default=10.0)
-    parser.add_argument("--n_step", type=int, default=3)
-    parser.add_argument("--target_update_freq", type=int, default=500)
-    parser.add_argument("--epoch", type=int, default=100)
-    parser.add_argument("--epoch_num_steps", type=int, default=100000)
-    parser.add_argument("--collection_step_num_env_steps", type=int, default=10)
-    parser.add_argument("--update_per_step", type=float, default=0.1)
-    parser.add_argument("--batch_size", type=int, default=32)
-    parser.add_argument("--num_train_envs", type=int, default=10)
-    parser.add_argument("--num_test_envs", type=int, default=10)
-    parser.add_argument("--logdir", type=str, default="log")
-    parser.add_argument("--render", type=float, default=0.0)
-    parser.add_argument(
-        "--device",
-        type=str,
-        default="cuda" if torch.cuda.is_available() else "cpu",
-    )
-    parser.add_argument("--frames_stack", type=int, default=4)
-    parser.add_argument("--resume_path", type=str, default=None)
-    parser.add_argument("--resume_id", type=str, default=None)
-    parser.add_argument(
-        "--logger",
-        type=str,
-        default="tensorboard",
-        choices=["tensorboard", "wandb"],
-    )
-    parser.add_argument("--wandb_project", type=str, default="atari.benchmark")
-    parser.add_argument(
-        "--watch",
-        default=False,
-        action="store_true",
-        help="watch the play of pre-trained policy only",
-    )
-    parser.add_argument("--save_buffer_name", type=str, default=None)
-    return parser.parse_args()
-
-
-def main(args: argparse.Namespace = get_args()) -> None:
     env, train_envs, test_envs = make_atari_env(
-        args.task,
-        args.seed,
-        args.num_train_envs,
-        args.num_test_envs,
-        scale=args.scale_obs,
-        frame_stack=args.frames_stack,
+        task,
+        seed,
+        num_train_envs,
+        num_test_envs,
+        scale=scale_obs,
+        frame_stack=frames_stack,
     )
-    args.state_shape = env.observation_space.shape or env.observation_space.n  # type: ignore
-    args.action_shape = env.action_space.shape or env.action_space.n  # type: ignore
+    state_shape = env.observation_space.shape or env.observation_space.n  # type: ignore
+    action_shape = env.action_space.shape or env.action_space.n  # type: ignore
     # should be N_FRAMES x H x W
-    print("Observations shape:", args.state_shape)
-    print("Actions shape:", args.action_shape)
+    log.info(f"Observations shape: {state_shape}")
+    log.info(f"Actions shape: {action_shape}")
     # seed
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)

     # define model
-    c, h, w = args.state_shape
-    net = C51Net(c=c, h=h, w=w, action_shape=args.action_shape, num_atoms=args.num_atoms)
+    c, h, w = state_shape
+    net = C51Net(c=c, h=h, w=w, action_shape=action_shape, num_atoms=num_atoms)

     # define policy and algorithm
-    optim = AdamOptimizerFactory(lr=args.lr)
+    optim = AdamOptimizerFactory(lr=lr)
     policy = C51Policy(
         model=net,
         action_space=env.action_space,
-        num_atoms=args.num_atoms,
-        v_min=args.v_min,
-        v_max=args.v_max,
-        eps_training=args.eps_train,
-        eps_inference=args.eps_test,
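+        # C51 represents the return as a categorical distribution over
+        # num_atoms atoms evenly spaced in [v_min, v_max]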
+        num_atoms=num_atoms,
+        v_min=v_min,
+        v_max=v_max,
+        eps_training=eps_train,
+        eps_inference=eps_test,
     )
     algorithm: C51 = C51(
         policy=policy,
         optim=optim,
-        gamma=args.gamma,
-        n_step_return_horizon=args.n_step,
-        target_update_freq=args.target_update_freq,
-    ).to(args.device)
+        gamma=gamma,
+        n_step_return_horizon=n_step,
+        target_update_freq=target_update_freq,
+    ).to(device)

     # load a previous model
-    if args.resume_path:
-        algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device))
-        print("Loaded agent from: ", args.resume_path)
+    if resume_path:
+        algorithm.load_state_dict(torch.load(resume_path, map_location=device))
+        log.info(f"Loaded agent from: {resume_path}")

     # replay buffer: `save_last_obs` and `stack_num` can be removed together
     # when you have enough RAM
     buffer = VectorReplayBuffer(
-        args.buffer_size,
+        buffer_size,
         buffer_num=len(train_envs),
         ignore_obs_next=True,
         save_only_last_obs=True,
-        stack_num=args.frames_stack,
+        stack_num=frames_stack,
     )

     # collectors
@@ -130,23 +124,23 @@ def main(args: argparse.Namespace = get_args()) -> None:

     # log
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
-    args.algo_name = "c51"
-    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
-    log_path = os.path.join(args.logdir, log_name)
+    algo_name = "c51"
+    log_name = os.path.join(task, algo_name, str(seed), now)
+    log_path = os.path.join(persistence_base_dir, log_name)

     # logger
     logger_factory = LoggerFactoryDefault()
-    if args.logger == "wandb":
+    if logger_type == "wandb":
         logger_factory.logger_type = "wandb"
-        logger_factory.wandb_project = args.wandb_project
+        logger_factory.wandb_project = wandb_project
     else:
         logger_factory.logger_type = "tensorboard"

     logger = logger_factory.create_logger(
         log_dir=log_path,
         experiment_name=log_name,
-        run_id=args.resume_id,
-        config_dict=vars(args),
+        run_id=resume_id,
+        config_dict=params_log_info,
     )

     def save_best_fn(policy: Algorithm) -> None:
@@ -155,74 +149,74 @@ def save_best_fn(policy: Algorithm) -> None:
     def stop_fn(mean_rewards: float) -> bool:
         if env.spec.reward_threshold:  # type: ignore
             return mean_rewards >= env.spec.reward_threshold  # type: ignore
-        if "Pong" in args.task:
+        if "Pong" in task:
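+            # Pong's maximum score is 21, so a mean reward of 20 counts as solved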
             return mean_rewards >= 20
         return False

     def train_fn(epoch: int, env_step: int) -> None:
         # nature DQN setting, linear decay in the first 1M steps
         if env_step <= 1e6:
-            eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final)
+            eps = eps_train - env_step / 1e6 * (eps_train - eps_train_final)
         else:
-            eps = args.eps_train_final
+            eps = eps_train_final
         policy.set_eps_training(eps)
         if env_step % 1000 == 0:
             logger.write("train/env_step", env_step, {"train/eps": eps})

-    def watch() -> None:
-        print("Setup test envs ...")
-        test_envs.seed(args.seed)
-        if args.save_buffer_name:
-            print(f"Generate buffer with size {args.buffer_size}")
+    def watch_fn() -> None:
+        log.info("Setup test envs ...")
+        test_envs.seed(seed)
+        if save_buffer_name:
+            log.info(f"Generate buffer with size {buffer_size}")
             buffer = VectorReplayBuffer(
-                args.buffer_size,
+                buffer_size,
                 buffer_num=len(test_envs),
                 ignore_obs_next=True,
                 save_only_last_obs=True,
-                stack_num=args.frames_stack,
+                stack_num=frames_stack,
             )
             collector = Collector[CollectStats](
                 algorithm, test_envs, buffer, exploration_noise=True
             )
-            result = collector.collect(n_step=args.buffer_size, reset_before_collect=True)
-            print(f"Save buffer into {args.save_buffer_name}")
+            result = collector.collect(n_step=buffer_size, reset_before_collect=True)
+            log.info(f"Save buffer into {save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
-            buffer.save_hdf5(args.save_buffer_name)
+            buffer.save_hdf5(save_buffer_name)
         else:
-            print("Testing agent ...")
+            log.info("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(n_episode=args.num_test_envs, render=args.render)
+            result = test_collector.collect(n_episode=num_test_envs, render=render)
             result.pprint_asdict()

-    if args.watch:
-        watch()
+    if watch:
+        watch_fn()
         sys.exit(0)

     # test train_collector and start filling replay buffer
     train_collector.reset()
-    train_collector.collect(n_step=args.batch_size * args.num_train_envs)
+    train_collector.collect(n_step=batch_size * num_train_envs)
     # trainer
     result = algorithm.run_training(
         OffPolicyTrainerParams(
             train_collector=train_collector,
             test_collector=test_collector,
-            max_epochs=args.epoch,
-            epoch_num_steps=args.epoch_num_steps,
-            collection_step_num_env_steps=args.collection_step_num_env_steps,
-            test_step_num_episodes=args.num_test_envs,
-            batch_size=args.batch_size,
+            max_epochs=epoch,
+            epoch_num_steps=epoch_num_steps,
+            collection_step_num_env_steps=collection_step_num_env_steps,
+            test_step_num_episodes=num_test_envs,
+            batch_size=batch_size,
             train_fn=train_fn,
             stop_fn=stop_fn,
             save_best_fn=save_best_fn,
             logger=logger,
-            update_step_num_gradient_steps_per_sample=args.update_per_step,
+            update_step_num_gradient_steps_per_sample=update_per_step,
             test_in_train=False,
         )
     )

     pprint.pprint(result)
-    watch()
+    watch_fn()


 if __name__ == "__main__":
-    main(get_args())
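+    # run_cli (from sensai.util.logging) is assumed to configure logging at the
+    # given level and expose main's keyword arguments as command-line options,
+    # replacing the removed argparse setup. Hypothetical invocation (assuming the
+    # file is named atari_c51.py; exact flag syntax depends on sensai's CLI handling):
+    #   python atari_c51.py --task BreakoutNoFrameskip-v4 --epoch 50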
+    result = logging.run_cli(main, level=logging.INFO)