Skip to content

Commit 5752802

Browse files
committed
separation of concerns: set_seed in another function
Signed-off-by: Jack Luar <[email protected]>
1 parent 4e81220 commit 5752802

File tree

2 files changed

+62
-39
lines changed

2 files changed

+62
-39
lines changed

tools/AutoTuner/src/autotuner/distributed.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,12 @@
6363
import json
6464
import os
6565
import sys
66-
import random
6766
from itertools import product
6867
from uuid import uuid4 as uuid
6968
from collections import namedtuple
7069
from multiprocessing import cpu_count
7170

7271
import numpy as np
73-
import torch
7472

7573
import ray
7674
from ray import tune
@@ -94,6 +92,7 @@
9492
prepare_ray_server,
9593
CONSTRAINTS_SDC,
9694
FASTROUTE_TCL,
95+
set_seed,
9796
)
9897

9998
# Name of the final metric
@@ -449,19 +448,7 @@ def set_algorithm(
449448
"""
450449
Configure search algorithm.
451450
"""
452-
# Pre-set seed if user sets seed to 0
453-
if seed == 0:
454-
print(
455-
"Warning: you have chosen not to set a seed. Do you wish to continue? (y/n)"
456-
)
457-
if input().lower() != "y":
458-
sys.exit(0)
459-
seed = None
460-
else:
461-
torch.manual_seed(seed)
462-
np.random.seed(seed)
463-
random.seed(seed)
464-
451+
set_seed(seed)
465452
if algorithm_name == "hyperopt":
466453
algorithm = HyperOptSearch(
467454
points_to_evaluate=best_params,

tools/AutoTuner/src/autotuner/utils.py

Lines changed: 60 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
import sys
4343
import uuid
4444
import time
45-
from multiprocessing import cpu_count
4645
from datetime import datetime
4746

4847
import numpy as np
@@ -71,6 +70,56 @@
7170
DATE = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
7271

7372

73+
# --- General utilities ---
74+
def run_command(
75+
args, cmd, timeout=None, stderr_file=None, stdout_file=None, fail_fast=False
76+
):
77+
"""
78+
Wrapper for subprocess.run
79+
Allows to run shell command, control print and exceptions.
80+
"""
81+
process = subprocess.run(
82+
cmd, timeout=timeout, capture_output=True, text=True, check=False, shell=True
83+
)
84+
if stderr_file is not None and process.stderr != "":
85+
with open(stderr_file, "a") as file:
86+
file.write(f"\n\n{cmd}\n{process.stderr}")
87+
if stdout_file is not None and process.stdout != "":
88+
with open(stdout_file, "a") as file:
89+
file.write(f"\n\n{cmd}\n{process.stdout}")
90+
if args.verbose >= 1:
91+
print(process.stderr)
92+
if args.verbose >= 2:
93+
print(process.stdout)
94+
95+
if fail_fast and process.returncode != 0:
96+
raise RuntimeError
97+
98+
99+
def set_seed(seed: int):
100+
"""Set seed for reproducibility."""
101+
import torch
102+
import random
103+
104+
# TODO: shift seed validation into validate_args during parse_arguments
105+
# Pre-set seed if user sets seed to 0
106+
if seed == 0:
107+
print(
108+
"Warning: you have chosen not to set a seed. Do you wish to continue? (y/n)"
109+
)
110+
if input().lower() != "y":
111+
sys.exit(0)
112+
seed = None
113+
else:
114+
torch.manual_seed(seed)
115+
np.random.seed(seed)
116+
random.seed(seed)
117+
118+
119+
# --- End General utilities ---
120+
121+
122+
# --- OpenROAD: write file utilities ---
74123
def write_sdc(variables, path, sdc_original, constraints_sdc):
75124
"""
76125
Create a SDC file with parameters for current tuning iteration.
@@ -160,6 +209,10 @@ def write_fast_route(variables, path, platform, fr_original, fastroute_tcl):
160209
return file_name
161210

162211

212+
# --- End OpenROAD: write file utilities ---
213+
214+
215+
# --- OpenROAD: parse utilities ---
163216
def parse_flow_variables(base_dir, platform):
164217
"""
165218
Parse the flow variables from source
@@ -262,31 +315,10 @@ def parse_config(
262315
return options
263316

264317

265-
def run_command(
266-
args, cmd, timeout=None, stderr_file=None, stdout_file=None, fail_fast=False
267-
):
268-
"""
269-
Wrapper for subprocess.run
270-
Allows to run shell command, control print and exceptions.
271-
"""
272-
process = subprocess.run(
273-
cmd, timeout=timeout, capture_output=True, text=True, check=False, shell=True
274-
)
275-
if stderr_file is not None and process.stderr != "":
276-
with open(stderr_file, "a") as file:
277-
file.write(f"\n\n{cmd}\n{process.stderr}")
278-
if stdout_file is not None and process.stdout != "":
279-
with open(stdout_file, "a") as file:
280-
file.write(f"\n\n{cmd}\n{process.stdout}")
281-
if args.verbose >= 1:
282-
print(process.stderr)
283-
if args.verbose >= 2:
284-
print(process.stdout)
285-
286-
if fail_fast and process.returncode != 0:
287-
raise RuntimeError
318+
# --- End OpenROAD: parse utilities ---
288319

289320

321+
# --- OpenROAD specific functions ---
290322
def openroad(
291323
args,
292324
base_dir,
@@ -602,6 +634,7 @@ def prepare_ray_server(args):
602634
return local_dir, orfs_flow_dir, install_path
603635

604636

637+
# --- Ray: OpenROAD wrapper utilities ---
605638
@ray.remote
606639
def openroad_distributed(
607640
args,
@@ -645,3 +678,6 @@ def consumer(queue):
645678
print(f"[INFO TUN-0007] Scheduling run for parameter {name}.")
646679
ray.get(openroad_distributed.remote(*next_item))
647680
print(f"[INFO TUN-0008] Finished run for parameter {name}.")
681+
682+
683+
# --- End Ray: OpenROAD wrapper utilities ---

0 commit comments

Comments
 (0)