Skip to content

Commit 895a806

Browse files
committed
Multi-processing implementation
1 parent 46f9bb6 commit 895a806

File tree

2 files changed

+86
-50
lines changed

2 files changed

+86
-50
lines changed

src/millrun/cli.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,6 @@
66
import typer
77
from .millrun import execute_batch
88

9-
DATA_STORE_FILE = pathlib.Path(__file__).parents[0] / "DATA_STORE.json"
10-
ROW_SELECTIONS: list[int] = []
11-
DATA_STORE: dict = {}
129

1310
def _parse_json(filepath: str) -> dict:
1411
with open(filepath, 'r') as file:
@@ -26,12 +23,13 @@ def _parse_json(filepath: str) -> dict:
2623
add_completion=False,
2724
no_args_is_help=True,
2825
help=APP_INTRO,
26+
# pretty_exceptions_enable=False,
27+
pretty_exceptions_show_locals=False
2928
)
3029

31-
3230
@app.command(
3331
name='run',
34-
help="PROOFExecutes a notebook or directory of notebooks using the provided bulk parameters JSON file",
32+
help="Executes a notebook or directory of notebooks using the provided bulk parameters JSON file",
3533
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
3634
)
3735
def run(
@@ -71,20 +69,18 @@ def run(
7169
output_dir = pathlib.Path(output_dir)
7270
else:
7371
output_dir = pathlib.Path.cwd()
74-
try:
75-
execute_batch(
76-
notebook_dir_or_file,
77-
params,
78-
output_dir,
79-
prepend,
80-
append,
81-
# recursive,
82-
# exclude_glob_pattern,
83-
# include_glob_pattern,
84-
# **kwargs
85-
)
86-
except Exception as e:
87-
print(f"Error! {e.msg}")
72+
execute_batch(
73+
notebook_dir_or_file,
74+
params,
75+
output_dir,
76+
prepend,
77+
append,
78+
recursive,
79+
exclude_glob_pattern,
80+
include_glob_pattern,
81+
multiprocessing=True
82+
# **kwargs
83+
)
8884

8985

9086
if __name__ == "__main__":

src/millrun/millrun.py

Lines changed: 71 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import json
33
from typing import Optional, Any
44
import papermill as pm
5+
import functools as ft
6+
from concurrent.futures import ProcessPoolExecutor
57

68

79
def execute_batch(
@@ -13,6 +15,7 @@ def execute_batch(
1315
recursive: bool = False,
1416
exclude_glob_pattern: Optional[str] = None,
1517
include_glob_pattern: Optional[str] = None,
18+
multiprocessing: bool = False,
1619
**kwargs,
1720
) -> list[pathlib.Path] | None:
1821
"""
@@ -86,20 +89,14 @@ def execute_batch(
8689
output_dir.mkdir(exist_ok=True, parents=True)
8790

8891
if notebook_filename is not None:
89-
for notebook_params in bulk_params_list:
90-
output_name = get_output_name(
91-
notebook_filename,
92-
output_prepend_components,
93-
output_append_components,
94-
notebook_params
95-
)
96-
notebook_file = notebook_dir_or_file
97-
pm.execute_notebook(
98-
notebook_dir / notebook_file,
99-
output_path=output_dir / output_name,
100-
parameters=notebook_params,
101-
**kwargs
102-
)
92+
execute_notebooks(
93+
notebook_dir / notebook_filename,
94+
bulk_params_list,
95+
output_prepend_components,
96+
output_append_components,
97+
output_dir,
98+
multiprocessing
99+
)
103100
else:
104101
glob_method = notebook_dir.glob
105102
if recursive:
@@ -117,22 +114,14 @@ def execute_batch(
117114
notebook_paths = included_paths - excluded_paths
118115

119116
for notebook_path in notebook_paths:
120-
print(f"{notebook_path=}")
121-
for notebook_params in bulk_params_list:
122-
output_name_template = get_output_name(
123-
notebook_path.name,
124-
output_prepend_components,
125-
output_append_components,
126-
notebook_params
127-
)
128-
output_name = output_name_template.format(nb=notebook_params)
129-
notebook_file = notebook_dir_or_file
130-
pm.execute_notebook(
131-
notebook_path,
132-
output_path=output_dir / output_name,
133-
parameters=notebook_params,
134-
**kwargs
135-
)
117+
execute_notebooks(
118+
notebook_path,
119+
bulk_params_list,
120+
output_prepend_components,
121+
output_append_components,
122+
output_dir,
123+
multiprocessing
124+
)
136125

137126
def check_unequal_value_lengths(bulk_params: dict[str, list]) -> bool | dict:
138127
"""
@@ -184,4 +173,55 @@ def get_output_name(
184173
prepend_str = "-".join(prepends)
185174
append_str = "-".join(appends)
186175
notebook_filename = pathlib.Path(notebook_filename)
187-
return "-".join([elem for elem in [prepend_str, notebook_filename.stem, append_str] if elem]) + notebook_filename.suffix
176+
return "-".join([elem for elem in [prepend_str, notebook_filename.stem, append_str] if elem]) + notebook_filename.suffix
177+
178+
179+
def execute_notebooks(
180+
notebook_filename: pathlib.Path,
181+
bulk_params_list: dict[str, Any],
182+
output_prepend_components: list[str],
183+
output_append_components: list[str],
184+
output_dir: pathlib.Path,
185+
multiprocessing: bool = False,
186+
**kwargs,
187+
):
188+
mp_execute_notebook = ft.partial(
189+
execute_notebook,
190+
notebook_filename=notebook_filename,
191+
output_prepend_components=output_prepend_components,
192+
output_append_components=output_append_components,
193+
output_dir=output_dir,
194+
)
195+
# print(mp_execute_notebook(notebook_params=bulk_params_list))
196+
if multiprocessing:
197+
with ProcessPoolExecutor() as executor:
198+
for result in executor.map(mp_execute_notebook, bulk_params_list):
199+
pass
200+
else:
201+
for result in map(mp_execute_notebook, bulk_params_list):
202+
pass
203+
204+
205+
206+
def execute_notebook(
207+
notebook_params: dict,
208+
notebook_filename: pathlib.Path,
209+
output_prepend_components: list[str],
210+
output_append_components: list[str],
211+
output_dir: pathlib.Path,
212+
**kwargs,
213+
):
214+
print(notebook_filename)
215+
output_name = get_output_name(
216+
notebook_filename,
217+
output_prepend_components,
218+
output_append_components,
219+
notebook_params
220+
)
221+
pm.execute_notebook(
222+
notebook_filename,
223+
output_path=output_dir / output_name,
224+
parameters=notebook_params,
225+
progress_bar=True,
226+
**kwargs
227+
)

0 commit comments

Comments
 (0)