#!/usr/bin/env python3

- import argparse
import datetime
import os
import pprint
+ from typing import Literal

import numpy as np
import torch
from mujoco_env import make_mujoco_env
+ from sensai.util import logging
from torch import nn
from torch.distributions import Distribution, Independent, Normal

from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic

+ log = logging.getLogger(__name__)
+
+
+ def main(
+     task: str = "Ant-v4",
+     persistence_base_dir: str = "log",
+     seed: int = 0,
+     buffer_size: int = 4096,
+     hidden_sizes: list | None = None,
+     lr: float = 7e-4,
+     gamma: float = 0.99,
+     epoch: int = 100,
+     epoch_num_steps: int = 30000,
+     collection_step_num_env_steps: int = 80,
+     update_step_num_repetitions: int = 1,
+     batch_size: int | None = None,
+     num_train_envs: int = 16,
+     num_test_envs: int = 10,
+     return_scaling: bool = True,
+     vf_coef: float = 0.5,
+     ent_coef: float = 0.01,
+     gae_lambda: float = 0.95,
+     action_bound_method: Literal["clip", "tanh"] | None = "clip",
+     lr_decay: bool = True,
+     max_grad_norm: float = 0.5,
+     render: float = 0.0,
+     device: str | None = None,
+     resume_path: str | None = None,
+     resume_id: str | None = None,
+     logger_type: str = "tensorboard",
+     wandb_project: str = "mujoco.benchmark",
+     watch: bool = False,
+ ) -> None:
+     # Set defaults for mutable arguments
+     if hidden_sizes is None:
+         hidden_sizes = [64, 64]
+     if device is None:
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     # Get all local variables as config (excluding internal/temporary ones)
+     params_log_info = locals()
+     log.info(f"Starting training with config:\n{params_log_info}")

- def get_args() -> argparse.Namespace:
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--task", type=str, default="Ant-v4")
-     parser.add_argument("--seed", type=int, default=0)
-     parser.add_argument("--buffer_size", type=int, default=4096)
-     parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64])
-     parser.add_argument("--lr", type=float, default=7e-4)
-     parser.add_argument("--gamma", type=float, default=0.99)
-     parser.add_argument("--epoch", type=int, default=100)
-     parser.add_argument("--epoch_num_steps", type=int, default=30000)
-     parser.add_argument("--collection_step_num_env_steps", type=int, default=80)
-     parser.add_argument("--update_step_num_repetitions", type=int, default=1)
-     # batch-size >> step-per-collect means calculating all data in one single forward.
-     parser.add_argument("--batch_size", type=int, default=None)
-     parser.add_argument("--num_train_envs", type=int, default=16)
-     parser.add_argument("--num_test_envs", type=int, default=10)
-     # a2c special
-     parser.add_argument("--return_scaling", type=int, default=True)
-     parser.add_argument("--vf_coef", type=float, default=0.5)
-     parser.add_argument("--ent_coef", type=float, default=0.01)
-     parser.add_argument("--gae_lambda", type=float, default=0.95)
-     parser.add_argument("--bound_action_method", type=str, default="clip")
-     parser.add_argument("--lr_decay", type=int, default=True)
-     parser.add_argument("--max_grad_norm", type=float, default=0.5)
-     parser.add_argument("--logdir", type=str, default="log")
-     parser.add_argument("--render", type=float, default=0.0)
-     parser.add_argument(
-         "--device",
-         type=str,
-         default="cuda" if torch.cuda.is_available() else "cpu",
-     )
-     parser.add_argument("--resume_path", type=str, default=None)
-     parser.add_argument("--resume_id", type=str, default=None)
-     parser.add_argument(
-         "--logger",
-         type=str,
-         default="tensorboard",
-         choices=["tensorboard", "wandb"],
-     )
-     parser.add_argument("--wandb_project", type=str, default="mujoco.benchmark")
-     parser.add_argument(
-         "--watch",
-         default=False,
-         action="store_true",
-         help="watch the play of pre-trained policy only",
-     )
-     return parser.parse_args()
-
-
- def main(args: argparse.Namespace = get_args()) -> None:
    env, train_envs, test_envs = make_mujoco_env(
-         args.task,
-         args.seed,
-         args.num_train_envs,
-         args.num_test_envs,
+         task,
+         seed,
+         num_train_envs,
+         num_test_envs,
        obs_norm=True,
    )
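    # obs_norm=True (above) wraps the environments with running observation normalization;
    # the running stats are saved and restored together with the policy (see save_best_fn
    # and the resume_path branch below).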
-     args.state_shape = env.observation_space.shape or env.observation_space.n
-     args.action_shape = env.action_space.shape or env.action_space.n
-     args.max_action = env.action_space.high[0]
-     print("Observations shape:", args.state_shape)
-     print("Actions shape:", args.action_shape)
-     print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
+     state_shape = env.observation_space.shape or env.observation_space.n
+     action_shape = env.action_space.shape or env.action_space.n
+     max_action = env.action_space.high[0]
+     log.info(f"Observations shape: {state_shape}")
+     log.info(f"Actions shape: {action_shape}")
+     log.info(f"Action range: {np.min(env.action_space.low)}, {np.max(env.action_space.high)}")
    # seed
-     np.random.seed(args.seed)
-     torch.manual_seed(args.seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
    # model
    net_a = Net(
-         state_shape=args.state_shape,
-         hidden_sizes=args.hidden_sizes,
+         state_shape=state_shape,
+         hidden_sizes=hidden_sizes,
        activation=nn.Tanh,
    )
    actor = ContinuousActorProbabilistic(
        preprocess_net=net_a,
-         action_shape=args.action_shape,
+         action_shape=action_shape,
        unbounded=True,
-     ).to(args.device)
+     ).to(device)
    net_c = Net(
-         state_shape=args.state_shape,
-         hidden_sizes=args.hidden_sizes,
+         state_shape=state_shape,
+         hidden_sizes=hidden_sizes,
        activation=nn.Tanh,
    )
-     critic = ContinuousCritic(preprocess_net=net_c).to(args.device)
+     critic = ContinuousCritic(preprocess_net=net_c).to(device)
    actor_critic = ActorCritic(actor, critic)

    torch.nn.init.constant_(actor.sigma_param, -0.5)
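    # sigma_param (set above) is the state-independent log-std of the Gaussian policy head;
    # initializing it to -0.5 gives an initial action std of about exp(-0.5) ≈ 0.61,
    # i.e. moderate exploration noise at the start of training.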
@@ -122,17 +115,17 @@ def main(args: argparse.Namespace = get_args()) -> None:
            m.weight.data.copy_(0.01 * m.weight.data)

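    # RMSprop is the optimizer used in the classic A2C/A3C setup; eps=1e-5 and alpha=0.99
    # are the commonly used values. When lr_decay is enabled, the learning rate is annealed
    # linearly over the course of training.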
    optim = RMSpropOptimizerFactory(
-         lr=args.lr,
+         lr=lr,
        eps=1e-5,
        alpha=0.99,
    )

-     if args.lr_decay:
+     if lr_decay:
        optim.with_lr_scheduler_factory(
            LRSchedulerFactoryLinear(
-                 max_epochs=args.epoch,
-                 epoch_num_steps=args.epoch_num_steps,
-                 collection_step_num_env_steps=args.collection_step_num_env_steps,
+                 max_epochs=epoch,
+                 epoch_num_steps=epoch_num_steps,
+                 collection_step_num_env_steps=collection_step_num_env_steps,
            )
        )

@@ -144,75 +137,75 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
        actor=actor,
        dist_fn=dist,
        action_scaling=True,
-         action_bound_method=args.bound_action_method,
+         action_bound_method=action_bound_method,
        action_space=env.action_space,
    )
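    # A2C: the actor is updated with a policy-gradient loss on GAE(gae_lambda) advantages,
    # the critic with a value loss weighted by vf_coef, plus an entropy bonus weighted by
    # ent_coef; gradients are clipped to max_grad_norm.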
    algorithm: A2C = A2C(
        policy=policy,
        critic=critic,
        optim=optim,
-         gamma=args.gamma,
-         gae_lambda=args.gae_lambda,
-         max_grad_norm=args.max_grad_norm,
-         vf_coef=args.vf_coef,
-         ent_coef=args.ent_coef,
-         return_scaling=args.return_scaling,
+         gamma=gamma,
+         gae_lambda=gae_lambda,
+         max_grad_norm=max_grad_norm,
+         vf_coef=vf_coef,
+         ent_coef=ent_coef,
+         return_scaling=return_scaling,
    )

    # load a previous policy
-     if args.resume_path:
-         ckpt = torch.load(args.resume_path, map_location=args.device)
+     if resume_path:
+         ckpt = torch.load(resume_path, map_location=device)
        algorithm.load_state_dict(ckpt["model"])
        train_envs.set_obs_rms(ckpt["obs_rms"])
        test_envs.set_obs_rms(ckpt["obs_rms"])
-         print("Loaded agent from: ", args.resume_path)
+         print("Loaded agent from: ", resume_path)

    # collector
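    # On-policy training only consumes the data gathered in the current collection step;
    # VectorReplayBuffer keeps one segment per parallel env so that transitions from
    # different workers are not interleaved.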
    buffer: VectorReplayBuffer | ReplayBuffer
-     if args.num_train_envs > 1:
-         buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
+     if num_train_envs > 1:
+         buffer = VectorReplayBuffer(buffer_size, len(train_envs))
    else:
-         buffer = ReplayBuffer(args.buffer_size)
+         buffer = ReplayBuffer(buffer_size)
    train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True)
    test_collector = Collector[CollectStats](algorithm, test_envs)

    # log
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
-     args.algo_name = "a2c"
-     log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
-     log_path = os.path.join(args.logdir, log_name)
+     algo_name = "a2c"
+     log_name = os.path.join(task, algo_name, str(seed), now)
+     log_path = os.path.join(persistence_base_dir, log_name)

    # logger
    logger_factory = LoggerFactoryDefault()
-     if args.logger == "wandb":
+     if logger_type == "wandb":
        logger_factory.logger_type = "wandb"
-         logger_factory.wandb_project = args.wandb_project
+         logger_factory.wandb_project = wandb_project
    else:
        logger_factory.logger_type = "tensorboard"

    logger = logger_factory.create_logger(
        log_dir=log_path,
        experiment_name=log_name,
-         run_id=args.resume_id,
-         config_dict=vars(args),
+         run_id=resume_id,
+         config_dict=params_log_info,
    )

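    # The best-model checkpoint bundles the algorithm's state dict with the running
    # observation-normalization stats, so loading it later via resume_path restores both.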
    def save_best_fn(policy: Algorithm) -> None:
        state = {"model": policy.state_dict(), "obs_rms": train_envs.get_obs_rms()}
        torch.save(state, os.path.join(log_path, "policy.pth"))

-     if not args.watch:
+     if not watch:
        # train
        result = algorithm.run_training(
            OnPolicyTrainerParams(
                train_collector=train_collector,
                test_collector=test_collector,
-                 max_epochs=args.epoch,
-                 epoch_num_steps=args.epoch_num_steps,
-                 update_step_num_repetitions=args.update_step_num_repetitions,
-                 test_step_num_episodes=args.num_test_envs,
-                 batch_size=args.batch_size,
-                 collection_step_num_env_steps=args.collection_step_num_env_steps,
+                 max_epochs=epoch,
+                 epoch_num_steps=epoch_num_steps,
+                 update_step_num_repetitions=update_step_num_repetitions,
+                 test_step_num_episodes=num_test_envs,
+                 batch_size=batch_size,
+                 collection_step_num_env_steps=collection_step_num_env_steps,
                save_best_fn=save_best_fn,
                logger=logger,
                test_in_train=False,
@@ -221,11 +214,11 @@ def save_best_fn(policy: Algorithm) -> None:
        pprint.pprint(result)

    # Let's watch its performance!
-     test_envs.seed(args.seed)
+     test_envs.seed(seed)
    test_collector.reset()
-     collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render)
+     collector_stats = test_collector.collect(n_episode=num_test_envs, render=render)
    print(collector_stats)


if __name__ == "__main__":
-     main()
+     result = logging.run_cli(main, level=logging.INFO)
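# A minimal usage sketch (assuming sensai's logging.run_cli exposes main()'s keyword
# parameters as command-line options; the exact flag syntax is an assumption, not
# confirmed by this diff):
#
#   python <path-to-this-script> --task HalfCheetah-v4 --epoch 100 --num_train_envs 16
#
# The function can also be called directly from Python, e.g.
#
#   main(task="Ant-v4", epoch=10, watch=False)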