1111import tensorflow as tf
1212import time
1313
14- class PBTBenchmarkExample ():
14+
15+ class PBTBenchmarkExample :
1516 """Toy PBT problem for benchmarking adaptive learning rate.
1617 The goal is to optimize this trainable's accuracy. The accuracy increases
1718 fastest at the optimal lr, which is a function of the current accuracy.
@@ -36,24 +37,23 @@ def __init__(self, lr, log_dir: str, log_interval: int, checkpoint: str):
3637 self ._log_interval = log_interval
3738 self ._lr = lr
3839
39- self ._checkpoint_file = os .path .join (checkpoint , ' training.ckpt' )
40+ self ._checkpoint_file = os .path .join (checkpoint , " training.ckpt" )
4041 if os .path .exists (self ._checkpoint_file ):
41- with open (self ._checkpoint_file , 'rb' ) as fin :
42+ with open (self ._checkpoint_file , "rb" ) as fin :
4243 checkpoint_data = pickle .load (fin )
43- self ._accuracy = checkpoint_data [' accuracy' ]
44- self ._step = checkpoint_data [' step' ]
44+ self ._accuracy = checkpoint_data [" accuracy" ]
45+ self ._step = checkpoint_data [" step" ]
4546 else :
4647 os .makedirs (checkpoint , exist_ok = True )
4748 self ._step = 1
4849 self ._accuracy = 0.0
49-
5050
5151 def save_checkpoint (self ):
52- with open (self ._checkpoint_file , 'wb' ) as fout :
53- pickle .dump ({' step' : self ._step , ' accuracy' : self ._accuracy }, fout )
52+ with open (self ._checkpoint_file , "wb" ) as fout :
53+ pickle .dump ({" step" : self ._step , " accuracy" : self ._accuracy }, fout )
5454
5555 def step (self ):
56- midpoint = 100 # lr starts decreasing after acc > midpoint
56+ midpoint = 100 # lr starts decreasing after acc > midpoint
5757 q_tolerance = 3 # penalize exceeding lr by more than this multiple
5858 noise_level = 2 # add gaussian noise to the acc increase
5959 # triangle wave:
@@ -80,32 +80,53 @@ def step(self):
8080 if not self ._writer :
8181 self ._writer = tf .summary .create_file_writer (self ._log_dir )
8282 with self ._writer .as_default ():
83- tf .summary .scalar ("Validation-accuracy" , self ._accuracy , step = self ._step )
83+ tf .summary .scalar (
84+ "Validation-accuracy" , self ._accuracy , step = self ._step
85+ )
8486 tf .summary .scalar ("lr" , self ._lr , step = self ._step )
8587 self ._writer .flush ()
8688
8789 self ._step += 1
8890
8991 def __repr__ (self ):
90- return "epoch {}:\n lr={:0.4f}\n Validation-accuracy={:0.4f}" .format (self ._step , self ._lr , self ._accuracy )
92+ return "epoch {}:\n lr={:0.4f}\n Validation-accuracy={:0.4f}" .format (
93+ self ._step , self ._lr , self ._accuracy
94+ )
9195
9296
9397if __name__ == "__main__" :
9498 # Parse CLI arguments
95- parser = argparse .ArgumentParser (description = 'PBT Basic Test' )
96- parser .add_argument ('--lr' , type = float , default = 0.0001 ,
97- help = 'learning rate (default: 0.0001)' )
98- parser .add_argument ('--epochs' , type = int , default = 20 ,
99- help = 'number of epochs to train (default: 20)' )
100- parser .add_argument ('--log-interval' , type = int , default = 10 , metavar = 'N' ,
101- help = 'how many batches to wait before logging training status (default: 1)' )
102- parser .add_argument ('--log-path' , type = str , default = "/var/log/katib/tfevent/" ,
103- help = 'tfevent output path (default: /var/log/katib/tfevent/)' )
104- parser .add_argument ('--checkpoint' , type = str , default = "/var/log/katib/checkpoints/" ,
105- help = 'checkpoint directory (resume and save)' )
99+ parser = argparse .ArgumentParser (description = "PBT Basic Test" )
100+ parser .add_argument (
101+ "--lr" , type = float , default = 0.0001 , help = "learning rate (default: 0.0001)"
102+ )
103+ parser .add_argument (
104+ "--epochs" , type = int , default = 20 , help = "number of epochs to train (default: 20)"
105+ )
106+ parser .add_argument (
107+ "--log-interval" ,
108+ type = int ,
109+ default = 10 ,
110+ metavar = "N" ,
111+ help = "how many batches to wait before logging training status (default: 1)" ,
112+ )
113+ parser .add_argument (
114+ "--log-path" ,
115+ type = str ,
116+ default = "/var/log/katib/tfevent/" ,
117+ help = "tfevent output path (default: /var/log/katib/tfevent/)" ,
118+ )
119+ parser .add_argument (
120+ "--checkpoint" ,
121+ type = str ,
122+ default = "/var/log/katib/checkpoints/" ,
123+ help = "checkpoint directory (resume and save)" ,
124+ )
106125 opt = parser .parse_args ()
107126
108- benchmark = PBTBenchmarkExample (opt .lr , opt .log_path , opt .log_interval , opt .checkpoint )
127+ benchmark = PBTBenchmarkExample (
128+ opt .lr , opt .log_path , opt .log_interval , opt .checkpoint
129+ )
109130 for i in range (opt .epochs ):
110131 benchmark .step ()
111132 time .sleep (0.2 )
0 commit comments