22import json
33from typing import Optional , Any
44import papermill as pm
5+ import functools as ft
6+ from concurrent .futures import ProcessPoolExecutor
57
68
79def execute_batch (
@@ -13,6 +15,7 @@ def execute_batch(
1315 recursive : bool = False ,
1416 exclude_glob_pattern : Optional [str ] = None ,
1517 include_glob_pattern : Optional [str ] = None ,
18+ multiprocessing : bool = False ,
1619 ** kwargs ,
1720) -> list [pathlib .Path ] | None :
1821 """
@@ -86,20 +89,14 @@ def execute_batch(
8689 output_dir .mkdir (exist_ok = True , parents = True )
8790
8891 if notebook_filename is not None :
89- for notebook_params in bulk_params_list :
90- output_name = get_output_name (
91- notebook_filename ,
92- output_prepend_components ,
93- output_append_components ,
94- notebook_params
95- )
96- notebook_file = notebook_dir_or_file
97- pm .execute_notebook (
98- notebook_dir / notebook_file ,
99- output_path = output_dir / output_name ,
100- parameters = notebook_params ,
101- ** kwargs
102- )
92+ execute_notebooks (
93+ notebook_dir / notebook_filename ,
94+ bulk_params_list ,
95+ output_prepend_components ,
96+ output_append_components ,
97+ output_dir ,
98+ multiprocessing
99+ )
103100 else :
104101 glob_method = notebook_dir .glob
105102 if recursive :
@@ -117,22 +114,14 @@ def execute_batch(
117114 notebook_paths = included_paths - excluded_paths
118115
119116 for notebook_path in notebook_paths :
120- print (f"{ notebook_path = } " )
121- for notebook_params in bulk_params_list :
122- output_name_template = get_output_name (
123- notebook_path .name ,
124- output_prepend_components ,
125- output_append_components ,
126- notebook_params
127- )
128- output_name = output_name_template .format (nb = notebook_params )
129- notebook_file = notebook_dir_or_file
130- pm .execute_notebook (
131- notebook_path ,
132- output_path = output_dir / output_name ,
133- parameters = notebook_params ,
134- ** kwargs
135- )
117+ execute_notebooks (
118+ notebook_path ,
119+ bulk_params_list ,
120+ output_prepend_components ,
121+ output_append_components ,
122+ output_dir ,
123+ multiprocessing
124+ )
136125
137126def check_unequal_value_lengths (bulk_params : dict [str , list ]) -> bool | dict :
138127 """
@@ -184,4 +173,55 @@ def get_output_name(
184173 prepend_str = "-" .join (prepends )
185174 append_str = "-" .join (appends )
186175 notebook_filename = pathlib .Path (notebook_filename )
187- return "-" .join ([elem for elem in [prepend_str , notebook_filename .stem , append_str ] if elem ]) + notebook_filename .suffix
176+ return "-" .join ([elem for elem in [prepend_str , notebook_filename .stem , append_str ] if elem ]) + notebook_filename .suffix
177+
178+
179+ def execute_notebooks (
180+ notebook_filename : pathlib .Path ,
181+ bulk_params_list : dict [str , Any ],
182+ output_prepend_components : list [str ],
183+ output_append_components : list [str ],
184+ output_dir : pathlib .Path ,
185+ multiprocessing : bool = False ,
186+ ** kwargs ,
187+ ):
188+ mp_execute_notebook = ft .partial (
189+ execute_notebook ,
190+ notebook_filename = notebook_filename ,
191+ output_prepend_components = output_prepend_components ,
192+ output_append_components = output_append_components ,
193+ output_dir = output_dir ,
194+ )
195+ # print(mp_execute_notebook(notebook_params=bulk_params_list))
196+ if multiprocessing :
197+ with ProcessPoolExecutor () as executor :
198+ for result in executor .map (mp_execute_notebook , bulk_params_list ):
199+ pass
200+ else :
201+ for result in map (mp_execute_notebook , bulk_params_list ):
202+ pass
203+
204+
205+
206+ def execute_notebook (
207+ notebook_params : dict ,
208+ notebook_filename : pathlib .Path ,
209+ output_prepend_components : list [str ],
210+ output_append_components : list [str ],
211+ output_dir : pathlib .Path ,
212+ ** kwargs ,
213+ ):
214+ print (notebook_filename )
215+ output_name = get_output_name (
216+ notebook_filename ,
217+ output_prepend_components ,
218+ output_append_components ,
219+ notebook_params
220+ )
221+ pm .execute_notebook (
222+ notebook_filename ,
223+ output_path = output_dir / output_name ,
224+ parameters = notebook_params ,
225+ progress_bar = True ,
226+ ** kwargs
227+ )
0 commit comments