1+ import pathlib
2+ import json
3+ from typing import Optional , Any
4+ import papermill as pm
5+
6+
7+ def execute_batch (
8+ notebook_dir_or_file : pathlib .Path | str ,
9+ bulk_params : list | dict ,
10+ output_dir : Optional [pathlib .Path | str ] = None ,
11+ output_prepend_components : Optional [list [str ]] = None ,
12+ output_append_components : Optional [list [str ]] = None ,
13+ recursive : bool = False ,
14+ exclude_glob_pattern : Optional [str ] = None ,
15+ include_glob_pattern : Optional [str ] = None ,
16+ ** kwargs ,
17+ ) -> list [pathlib .Path ] | None :
18+ """
19+ Executes the notebooks contained in the notebook_dir_or_file using the parameters in
20+ 'bulk_params'.
21+
22+ 'notebook_dir_or_file': If a directory, then will execute all of the notebooks
23+ within the directory
24+ 'bulk_params':
25+ Either a dict in the following format:
26+
27+ {
28+ "key1": ["list", "of", "values"], # All lists of values must be same length
29+ "key2": ["list", "of", "values"],
30+ "key3": ["list", "of", "values"],
31+ ...
32+ },
33+
34+ -or-
35+
36+ A list in the following format:
37+
38+ [
39+ {"key1": "value", "key2": "value"}, # All keys
40+ {"key1": "value", "key2": "value"},
41+ {"key1": "value", "key2": "value"},
42+ ...
43+ ]
44+ 'output_dir': The directory for all output files. If None,
45+ files will be output in the same directory as the source file.
46+ 'output_prepend_components': A list of str representing the keys used
47+ in 'bulk_params'. These keys will be used to retrieve the value
48+ for the key in each iteration of 'bulk_params' and they will
49+ be used to name the output file be prepending them to the
50+ original filename. If a key is not found then the key will
51+ be interpreted as a str literal and will be added asis.
52+ 'output_append_components': Same as the prepend components but
53+ these components will be used at the end of the original filename.
54+ 'recursive': If True, and if 'notebook_dir_or_file' is a directory,
55+ then will execute notebooks within all sub-directories.
56+ 'exclude_glob_pattern': A glob-style pattern of files to exclude.
57+ If None, then all files willbe included.
58+ 'include_glob_pattern': A glob-style pattern of files to include. If
59+ None then all files will be included.
60+ '**kwargs': Passed on to papermill.execute_notebook
61+ """
62+ notebook_dir_or_file = pathlib .Path (notebook_dir_or_file )
63+ output_dir = pathlib .Path (output_dir )
64+ if isinstance (bulk_params , dict ):
65+ unequal_lengths = check_unequal_value_lengths (bulk_params )
66+ if unequal_lengths :
67+ raise ValueError (
68+ f"All lists in the bulk_params dict must be of equal length.\n "
69+ f"The following keys have unequal length: { unequal_lengths } "
70+ )
71+ bulk_params_list = convert_bulk_params_to_list (bulk_params )
72+ else :
73+ bulk_params_list = bulk_params
74+
75+ notebook_dir = None
76+ notebook_filename = None
77+ if notebook_dir_or_file .is_dir ():
78+ notebook_dir = notebook_dir_or_file
79+ else :
80+ notebook_dir = notebook_dir_or_file .parent .resolve ()
81+ notebook_filename = notebook_dir_or_file .name
82+
83+ if output_dir is None :
84+ output_dir = notebook_dir
85+ if not output_dir .exists ():
86+ output_dir .mkdir (exist_ok = True , parents = True )
87+
88+ if notebook_filename is not None :
89+ for notebook_params in bulk_params_list :
90+ output_name = get_output_name (
91+ notebook_filename ,
92+ output_prepend_components ,
93+ output_append_components ,
94+ notebook_params
95+ )
96+ notebook_file = notebook_dir_or_file
97+ pm .execute_notebook (
98+ notebook_dir / notebook_file ,
99+ output_path = output_dir / output_name ,
100+ parameters = notebook_params ,
101+ ** kwargs
102+ )
103+ else :
104+ glob_method = notebook_dir .glob
105+ if recursive :
106+ glob_method = notebook_dir .rglob
107+
108+ excluded_paths = set ()
109+ if exclude_glob_pattern is not None :
110+ excluded_paths = set (glob_method (exclude_glob_pattern ))
111+
112+ glob_pattern = include_glob_pattern
113+ if include_glob_pattern is None :
114+ glob_pattern = "*.ipynb"
115+ included_paths = set (glob_method (glob_pattern ))
116+
117+ notebook_paths = included_paths - excluded_paths
118+
119+ for notebook_path in notebook_paths :
120+ print (f"{ notebook_path = } " )
121+ for notebook_params in bulk_params_list :
122+ output_name_template = get_output_name (
123+ notebook_path .name ,
124+ output_prepend_components ,
125+ output_append_components ,
126+ notebook_params
127+ )
128+ output_name = output_name_template .format (nb = notebook_params )
129+ notebook_file = notebook_dir_or_file
130+ pm .execute_notebook (
131+ notebook_path ,
132+ output_path = output_dir / output_name ,
133+ parameters = notebook_params ,
134+ ** kwargs
135+ )
136+
137+ def check_unequal_value_lengths (bulk_params : dict [str , list ]) -> bool | dict :
138+ """
139+ Returns False if all list values are equal length. Returns a dict
140+ of value lengths otherwise.
141+ """
142+ acc = {}
143+ for k , v in bulk_params .items ():
144+ try :
145+ acc .update ({k : len (v )})
146+ except TypeError :
147+ raise ValueError (f"The values of the bulk_param keys must be lists, not: '{ k : v} '" )
148+ all_values = set (acc .values ())
149+ if len (all_values ) == 1 :
150+ return False
151+ else :
152+ return acc
153+
154+
155+ def convert_bulk_params_to_list (bulk_params : dict [str , list ]):
156+ """
157+ Converts a dict of lists into a list of dicts.
158+ """
159+ iter_length = len (list (bulk_params .values ())[0 ])
160+ bulk_params_list = []
161+ for idx in range (iter_length ):
162+ inner_acc = {}
163+ for parameter_name , parameter_values in bulk_params .items ():
164+ inner_acc .update ({parameter_name : parameter_values [idx ]})
165+ bulk_params_list .append (inner_acc )
166+ return bulk_params_list
167+
168+
169+ def get_output_name (
170+ notebook_filename : str ,
171+ output_prepend_components : list [str ] | None ,
172+ output_append_components : list [str ] | None ,
173+ notebook_params : dict [str , Any ]
174+ ) -> str :
175+ """
176+ Returns the output name given the included components.
177+ """
178+ if output_prepend_components is None :
179+ output_prepend_components = []
180+ if output_append_components is None :
181+ output_append_components = []
182+ prepends = [notebook_params [comp ] for comp in output_prepend_components ]
183+ appends = [notebook_params [comp ] for comp in output_append_components ]
184+ prepend_str = "-" .join (prepends )
185+ append_str = "-" .join (appends )
186+ notebook_filename = pathlib .Path (notebook_filename )
187+ return "-" .join ([elem for elem in [prepend_str , notebook_filename .stem , append_str ] if elem ]) + notebook_filename .suffix
0 commit comments