@@ -12,6 +12,7 @@ def compare_tree_values(
1212 leaf_b : Hashable ,
1313 compare_func : Union [str , callable , None ],
1414 comparison_key : Optional [Hashable ]= None ,
15+ comparison_label : Optional [Hashable ] = None ,
1516 * args ,
1617 ** kwargs ,
1718) -> dict :
@@ -39,6 +40,10 @@ def compare_tree_values(
3940 'comparison_key' is ignored.
4041 'comparison_key': If provided, will serve as the key for the comparison value.
4142 If None, then the name of the comparison operator will used instead.
43+ 'comparison_label': If provided, will add an extra nested layer to the resulting
44+ dictionary keyed with 'comparison_label'. This is useful if you are going
45+ to be merging comparison trees and you wish to uniquely identify each
46+ comparison.
4247 """
4348 ops = {
4449 "div" : operator .truediv ,
@@ -59,19 +64,24 @@ def compare_tree_values(
5964
6065 branch_a = trim_branches (subtree_a , levels_a )
6166 branch_b = trim_branches (subtree_b , levels_b )
62-
6367 for trunk in branch_a .keys ():
6468 value_a = branch_a [trunk ]
69+ if trunk not in branch_b : continue
6570 value_b = branch_b [trunk ]
6671 env_acc .setdefault (trunk , {})
67- env_acc [trunk ].setdefault (leaf_a , value_a )
68- env_acc [trunk ].setdefault (leaf_b , value_b )
72+ if comparison_label is not None :
73+ env_acc [trunk ].setdefault (comparison_label , {})
74+ compare_acc = env_acc [trunk ][comparison_label ]
75+ else :
76+ compare_acc = env_acc [trunk ]
77+ compare_acc .setdefault (leaf_a , value_a )
78+ compare_acc .setdefault (leaf_b , value_b )
6979 comparison_operator = ops .get (compare_func , compare_func )
7080 if comparison_operator is not None :
7181 compared_value = comparison_operator (value_a , value_b )
7282 if comparison_key is None :
7383 comparison_key = comparison_operator .__name__
74- env_acc [ trunk ] .setdefault (comparison_key , compared_value )
84+ compare_acc .setdefault (comparison_key , compared_value )
7585 return env_acc
7686
7787
@@ -185,7 +195,7 @@ def filter_keys(
185195 tree : dict ,
186196 include_keys : Optional [list [str ]] = None ,
187197 exclude_keys : Optional [list [str ]] = None ,
188- include_keys_startswith : Optional [str ] = None
198+ include_keys_startswith : Optional [str | list [ str ] ] = None
189199 ) -> dict :
190200 """
191201 Returns a copy of 'tree' that has had some of its top-level
@@ -197,7 +207,9 @@ def filter_keys(
197207 - exclude_keys: Provide a list of keys to exclude
198208 - include_keys_startswith: Provide a substring that
199209 occurs at the start of the keys. All matches will
200- be included.
210+ be included. If a list of substrings are provided,
211+ then all substrings will be checked for a match
212+ and included.
201213
202214 These filters are additive and are applied in the following
203215 order:
@@ -208,14 +220,18 @@ def filter_keys(
208220 """
209221 include_keys = include_keys or []
210222 exclude_keys = exclude_keys or []
223+ if include_keys_startswith is not None and isinstance (include_keys_startswith , str ):
224+ include_keys_startswith = [include_keys_startswith ]
211225 filtered_keys = []
212226 for key in tree .keys ():
213227 if key in exclude_keys : continue
214- if key .startswith (include_keys_startswith ):
215- filtered_keys .append (key )
228+ if include_keys_startswith is not None :
229+ for include_startswith in include_keys_startswith :
230+ if key .startswith (include_startswith ):
231+ filtered_keys .append (key )
216232 elif key in include_keys :
217233 filtered_keys .append (key )
218- filtered_tree = {k : v for k , v in filtered_tree .items () if k in filtered_keys }
234+ filtered_tree = {k : v for k , v in tree .items () if k in filtered_keys }
219235 return filtered_tree
220236
221237
0 commit comments