Skip to content

Commit 46f9bb6

Browse files
committed
initial working implementation
1 parent 60f8423 commit 46f9bb6

File tree

4 files changed

+628
-2
lines changed

4 files changed

+628
-2
lines changed

pyproject.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,17 @@ authors = [
99
requires-python = ">=3.10"
1010
dependencies = [
1111
"papermill>=2.6.0",
12+
"typer>=0.16.0",
1213
]
1314

1415
[project.scripts]
15-
millrun = "millrun:main"
16+
millrun = "millrun.cli:app"
1617

1718
[build-system]
1819
requires = ["flit_core >=3.2,<4"]
1920
build-backend = "flit_core.buildapi"
21+
22+
[dependency-groups]
23+
dev = [
24+
"ipykernel>=6.29.5",
25+
]

src/millrun/cli.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import json
2+
from typing import Optional, Any
3+
from typing_extensions import Annotated
4+
import pathlib
5+
6+
import typer
7+
from .millrun import execute_batch
8+
9+
# Path to a "DATA_STORE.json" file located in the same directory as this module.
# NOTE(review): neither this path nor the ROW_SELECTIONS / DATA_STORE
# module-level containers defined right after it are referenced anywhere else
# in the visible code - presumably leftovers; confirm before removing.
DATA_STORE_FILE = pathlib.Path(__file__).parents[0] / "DATA_STORE.json"
10+
ROW_SELECTIONS: list[int] = []
11+
DATA_STORE: dict = {}
12+
13+
def _parse_json(filepath: str) -> dict:
14+
with open(filepath, 'r') as file:
15+
return json.load(file)
16+
17+
# Banner shown at the top of the CLI's --help output. The previous text
# described an unrelated tool (a steel-section selector), so it is replaced
# with a description of what this CLI actually does.
APP_INTRO = typer.style(
    """
millrun: execute Jupyter notebooks in bulk with papermill
""",
    fg=typer.colors.BRIGHT_YELLOW,
    bold=True,
)

# Root Typer application. Shell-completion installation options are hidden
# and invoking the program without arguments prints the help text.
app = typer.Typer(
    add_completion=False,
    no_args_is_help=True,
    help=APP_INTRO,
)
30+
31+
32+
@app.command(
    name='run',
    help="Executes a notebook or directory of notebooks using the provided bulk parameters JSON file",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def run(
    notebook_dir_or_file: Annotated[str, typer.Argument(
        help="Path to a notebook file or a directory containing notebooks.")
    ],
    params: Annotated[str, typer.Argument(
        help=("JSON file that contains parameters for notebook execution. "
              "Can either be a 'list of dict' or 'dict of list'."),
        callback=lambda value: _parse_json(value),
        )
    ],
    output_dir: Annotated[Optional[str], typer.Option(
        help=("Directory to place output files into. If not provided"
              " the current working directory will be used."),
        )
    ] = None,
    prepend: Annotated[Optional[str], typer.Option(
        help=("Prepend components to use on output filename. "
              "Can use dict keys from 'params' which will be evaluated. "
              "(Comma-separated values)."),
        callback=lambda x: x.split(",") if x else None
        )
    ] = None,
    append: Annotated[Optional[str], typer.Option(
        help=("Append components to use on output filename. "
              "Can use dict keys from 'params' which will be evaluated. "
              "(Comma-separated values)."),
        callback=lambda x: x.split(",") if x else None
        )
    ] = None,
    recursive: bool = False,
    exclude_glob_pattern: Optional[str] = None,
    include_glob_pattern: Optional[str] = None,
):
    """
    CLI entry point: runs execute_batch with the values given on the command line.

    By the time this body runs the Typer callbacks have already transformed
    some arguments: 'params' is the parsed content of the JSON file (not the
    path), and 'prepend' / 'append' are lists of str (or None), split on commas.

    'output_dir': When omitted, the current working directory is used.
    'recursive', 'exclude_glob_pattern', 'include_glob_pattern': Forwarded
        unchanged to execute_batch.

    Any exception raised by execute_batch is reported on stderr and the
    command terminates with exit code 1.
    """
    if output_dir is not None:
        output_path = pathlib.Path(output_dir)
    else:
        output_path = pathlib.Path.cwd()
    try:
        execute_batch(
            notebook_dir_or_file,
            params,
            output_path,
            prepend,
            append,
            # Previously these options were accepted but silently dropped.
            recursive=recursive,
            exclude_glob_pattern=exclude_glob_pattern,
            include_glob_pattern=include_glob_pattern,
        )
    except Exception as e:
        # Most exceptions have no '.msg' attribute, so format the exception
        # itself; otherwise the handler would fail with AttributeError.
        typer.echo(f"Error! {e}", err=True)
        # Signal the failure to the calling shell instead of exiting with 0.
        raise typer.Exit(code=1) from e
88+
89+
90+
if __name__ == "__main__":
91+
app()

src/millrun/millrun.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
import pathlib
2+
import json
3+
from typing import Optional, Any
4+
import papermill as pm
5+
6+
7+
def execute_batch(
    notebook_dir_or_file: pathlib.Path | str,
    bulk_params: list | dict,
    output_dir: Optional[pathlib.Path | str] = None,
    output_prepend_components: Optional[list[str]] = None,
    output_append_components: Optional[list[str]] = None,
    recursive: bool = False,
    exclude_glob_pattern: Optional[str] = None,
    include_glob_pattern: Optional[str] = None,
    **kwargs,
) -> list[pathlib.Path] | None:
    """
    Executes the notebooks contained in the notebook_dir_or_file using the parameters in
    'bulk_params'. Every notebook is executed once per parameter set.

    'notebook_dir_or_file': If a directory, then will execute all of the notebooks
        within the directory. Otherwise it is treated as a single notebook file.
    'bulk_params':
        Either a dict in the following format:

        {
            "key1": ["list", "of", "values"], # All lists of values must be same length
            "key2": ["list", "of", "values"],
            "key3": ["list", "of", "values"],
            ...
        },

        -or-

        A list in the following format:

        [
            {"key1": "value", "key2": "value"},
            {"key1": "value", "key2": "value"},
            {"key1": "value", "key2": "value"},
            ...
        ]
    'output_dir': The directory for all output files. It is created if it does
        not exist. If None, files will be output in the directory of the
        notebook file (or in 'notebook_dir_or_file' itself when that is a
        directory).
    'output_prepend_components': A list of str representing the keys used
        in 'bulk_params'. These keys will be used to retrieve the value
        for the key in each iteration of 'bulk_params' and they will
        be used to name the output file by prepending them to the
        original filename. The name is built by get_output_name, which
        also decides how a component that is not a key is handled.
    'output_append_components': Same as the prepend components but
        these components will be used at the end of the original filename.
    'recursive': If True, and if 'notebook_dir_or_file' is a directory,
        then will execute notebooks within all sub-directories.
    'exclude_glob_pattern': A glob-style pattern of files to exclude.
        If None, then no files will be excluded.
    'include_glob_pattern': A glob-style pattern of files to include. If
        None then all "*.ipynb" files will be included.
    '**kwargs': Passed on to papermill.execute_notebook

    Returns the list of output paths handed to papermill, in execution order.

    Raises ValueError if 'bulk_params' is a dict whose lists differ in length.
    """
    notebook_dir_or_file = pathlib.Path(notebook_dir_or_file)

    if isinstance(bulk_params, dict):
        unequal_lengths = check_unequal_value_lengths(bulk_params)
        if unequal_lengths:
            raise ValueError(
                f"All lists in the bulk_params dict must be of equal length.\n"
                f"The following keys have unequal length: {unequal_lengths}"
            )
        bulk_params_list = convert_bulk_params_to_list(bulk_params)
    else:
        bulk_params_list = bulk_params

    notebook_filename = None
    if notebook_dir_or_file.is_dir():
        notebook_dir = notebook_dir_or_file
    else:
        notebook_dir = notebook_dir_or_file.parent.resolve()
        notebook_filename = notebook_dir_or_file.name

    # Only convert when a value was given: pathlib.Path(None) raises TypeError,
    # which previously made the documented default unusable.
    if output_dir is None:
        output_dir = notebook_dir
    else:
        output_dir = pathlib.Path(output_dir)
    if not output_dir.exists():
        output_dir.mkdir(exist_ok=True, parents=True)

    output_paths: list[pathlib.Path] = []

    if notebook_filename is not None:
        # Join the resolved parent with the bare file name. Joining it with
        # the full relative input path would repeat the directory part
        # (e.g. "sub/nb.ipynb" -> "<cwd>/sub/sub/nb.ipynb").
        notebook_path = notebook_dir / notebook_filename
        for notebook_params in bulk_params_list:
            output_name = get_output_name(
                notebook_filename,
                output_prepend_components,
                output_append_components,
                notebook_params
            )
            output_path = output_dir / output_name
            pm.execute_notebook(
                notebook_path,
                output_path=output_path,
                parameters=notebook_params,
                **kwargs
            )
            output_paths.append(output_path)
        return output_paths

    glob_method = notebook_dir.rglob if recursive else notebook_dir.glob

    excluded_paths = set()
    if exclude_glob_pattern is not None:
        excluded_paths = set(glob_method(exclude_glob_pattern))

    glob_pattern = "*.ipynb" if include_glob_pattern is None else include_glob_pattern
    included_paths = set(glob_method(glob_pattern))

    # Sorted so that notebooks run in a reproducible order; iterating the
    # raw set difference gives an arbitrary order.
    notebook_paths = sorted(included_paths - excluded_paths)

    for notebook_path in notebook_paths:
        print(f"{notebook_path=}")
        for notebook_params in bulk_params_list:
            output_name_template = get_output_name(
                notebook_path.name,
                output_prepend_components,
                output_append_components,
                notebook_params
            )
            # The generated name is run through str.format, so a
            # "{nb[key]}" placeholder in it is filled from the parameters.
            # NOTE(review): the single-file branch above does not do this -
            # confirm whether that asymmetry is intended.
            output_name = output_name_template.format(nb=notebook_params)
            output_path = output_dir / output_name
            pm.execute_notebook(
                notebook_path,
                output_path=output_path,
                parameters=notebook_params,
                **kwargs
            )
            output_paths.append(output_path)
    return output_paths
136+
137+
def check_unequal_value_lengths(bulk_params: dict[str, list]) -> bool | dict:
138+
"""
139+
Returns False if all list values are equal length. Returns a dict
140+
of value lengths otherwise.
141+
"""
142+
acc = {}
143+
for k, v in bulk_params.items():
144+
try:
145+
acc.update({k: len(v)})
146+
except TypeError:
147+
raise ValueError(f"The values of the bulk_param keys must be lists, not: '{k: v}'")
148+
all_values = set(acc.values())
149+
if len(all_values) == 1:
150+
return False
151+
else:
152+
return acc
153+
154+
155+
def convert_bulk_params_to_list(bulk_params: dict[str, list]):
156+
"""
157+
Converts a dict of lists into a list of dicts.
158+
"""
159+
iter_length = len(list(bulk_params.values())[0])
160+
bulk_params_list = []
161+
for idx in range(iter_length):
162+
inner_acc = {}
163+
for parameter_name, parameter_values in bulk_params.items():
164+
inner_acc.update({parameter_name: parameter_values[idx]})
165+
bulk_params_list.append(inner_acc)
166+
return bulk_params_list
167+
168+
169+
def get_output_name(
170+
notebook_filename: str,
171+
output_prepend_components: list[str] | None,
172+
output_append_components: list[str] | None,
173+
notebook_params: dict[str, Any]
174+
) -> str:
175+
"""
176+
Returns the output name given the included components.
177+
"""
178+
if output_prepend_components is None:
179+
output_prepend_components = []
180+
if output_append_components is None:
181+
output_append_components = []
182+
prepends = [notebook_params[comp] for comp in output_prepend_components]
183+
appends = [notebook_params[comp] for comp in output_append_components]
184+
prepend_str = "-".join(prepends)
185+
append_str = "-".join(appends)
186+
notebook_filename = pathlib.Path(notebook_filename)
187+
return "-".join([elem for elem in [prepend_str, notebook_filename.stem, append_str] if elem]) + notebook_filename.suffix

0 commit comments

Comments
 (0)