1+ from copy import copy
2+ from typing import Hashable , Union , Optional , Any
3+ import operator
4+ import deepmerge
5+
def compare_tree_values(
    tree_a: dict | list,
    tree_b: dict | list,
    levels_a: list[Hashable | None],
    levels_b: list[Hashable | None],
    leaf_a: Union[Hashable, list[Hashable]],
    leaf_b: Union[Hashable, list[Hashable]],
    compare_func: Union[str, callable],
    compared_key: Optional[Hashable] = None,
    *args,
    **kwargs,
) -> dict:
    """
    Returns a dictionary tree keyed by the branches ("trunks") that remain after
    the explicitly named levels have been trimmed from 'tree_a'. Each trunk maps
    to a dict of the form:

        {leaf_a: value_a, leaf_b: value_b, compared_key: compared_value}

    'tree_a': the first tree to compare
    'tree_b': the second tree to compare
    'levels_a': The levels to iterate through in order to access the leaf key
        'leaf_a'. If a level is listed as None, then all keys at that level will
        be iterated over.
    'levels_b': The levels to iterate through in order to access the leaf key
        'leaf_b'. If a level is listed as None, then all keys at that level will
        be iterated over.
    'leaf_a': the leaf key in 'tree_a' whose value is compared.
    'leaf_b': the leaf key in 'tree_b' whose value is compared.
    'compare_func': Either one of
        {'div', 'sub', 'add', 'mul', 'mult', 'ge', 'le', 'lt', 'gt', 'eq', 'ne'}
        or a user-supplied callable whose call signature takes the value of
        'leaf_a' as the first param and the value of 'leaf_b' as the second
        param. 'args' and 'kwargs' are passed on to a user-supplied callable
        (they are not passed to the named operators, which take two operands).
    'compared_key': the key under which the compared value is stored. Defaults
        to str(compare_func).

    Raises TypeError if 'compare_func' is neither a known operator name nor a
    callable. Raises KeyError if a trunk of 'tree_a' is missing from 'tree_b',
    or if a level/leaf key is missing (see 'retrieve_leaves').

    Notes:
    - Values are stored with "first one wins" semantics: if 'leaf_a', 'leaf_b'
      and/or 'compared_key' are equal, the earlier entry is kept.
    - NOTE(review): the annotations allow a list of leaf keys, but a list is not
      hashable and so cannot be stored as a result key; only a single hashable
      leaf key works here.
    """
    ops = {
        "div": operator.truediv,
        "sub": operator.sub,
        "add": operator.add,
        "mul": operator.mul,
        "mult": operator.mul,  # alias: historically the documented name
        "ge": operator.ge,
        "le": operator.le,
        "lt": operator.lt,
        "gt": operator.gt,
        "eq": operator.eq,
        "ne": operator.ne,
    }
    named_operator = ops.get(compare_func) if isinstance(compare_func, str) else None
    if named_operator is None and not callable(compare_func):
        # Fail with an actionable message instead of "'str' object is not callable".
        raise TypeError(
            f"'compare_func' must be one of {sorted(ops)} or a callable, "
            f"got {compare_func!r}"
        )
    if compared_key is None:
        compared_key = str(compare_func)

    subtree_a = retrieve_leaves(tree_a, levels_a, leaf_a)
    subtree_b = retrieve_leaves(tree_b, levels_b, leaf_b)

    # Pass copies of the level lists: trimming consumes the list it is given, and
    # the caller may legitimately pass the same list object for both trees.
    branch_a = trim_branches(subtree_a, list(levels_a))
    branch_b = trim_branches(subtree_b, list(levels_b))

    env_acc = {}
    for trunk, value_a in branch_a.items():
        value_b = branch_b[trunk]
        if named_operator is not None:
            compared_value = named_operator(value_a, value_b)
        else:
            compared_value = compare_func(value_a, value_b, *args, **kwargs)
        entry = env_acc.setdefault(trunk, {})
        # setdefault keeps the first value stored when keys collide.
        entry.setdefault(leaf_a, value_a)
        entry.setdefault(leaf_b, value_b)
        entry.setdefault(compared_key, compared_value)
    return env_acc
69+
70+
def trim_branches(
    tree: dict | list,
    levels: list[Hashable | None],
):
    """
    Returns a copy of the 'tree' but with the branches in
    'levels' trimmed off.

    The levels are processed from the deepest one upwards. Each explicitly
    named level is collapsed (replaced by the subtree stored under that key),
    while a level given as None is kept, so its keys remain in the result.

    'tree': the tree to trim.
    'levels': the levels of 'tree', outermost first. This list is not modified.

    Raises KeyError (from 'retrieve_leaves') if a named level is missing.
    """
    trimmed = tree.copy()
    # Work on a private copy: popping from the caller's list would silently empty
    # it, which breaks callers that reuse the same 'levels' for another tree.
    remaining = list(levels)
    while remaining:
        leaf = remaining.pop()
        trimmed = retrieve_leaves(trimmed, remaining, leaf=leaf)
    return trimmed
84+
85+
86+ def retrieve_leaves (
87+ tree : dict | list ,
88+ levels : list [Hashable | None ],
89+ leaf : list [Hashable ] | Hashable | None ,
90+ ) -> dict :
91+ """
92+ Envelopes the tree at the leaf node with 'agg_func'.
93+ """
94+ env_acc = {}
95+ key_error_msg = (
96+ "Key '{level}' does not exist at this level. Available keys: {keys}. "
97+ "Perhaps not all of your tree elements have the same keys. Try enveloping over trees "
98+ "that have the same branch structure and leaf names."
99+ )
100+ # If we are at the last branch...
101+ if not levels :
102+ if leaf is None :
103+ return tree
104+ if isinstance (leaf , list ):
105+ leaf_values = {}
106+ for leaf_elem in leaf :
107+ try :
108+ tree [leaf_elem ]
109+ except KeyError :
110+ raise KeyError (key_error_msg .format (level = leaf_elem , keys = list (tree .keys ())))
111+ leaf_values .update ({leaf_elem : tree [leaf_elem ]})
112+ else :
113+ try :
114+ tree [leaf ]
115+ except KeyError :
116+ raise KeyError (key_error_msg .format (level = leaf , keys = list (tree .keys ())))
117+ leaf_values = tree [leaf ]
118+ return leaf_values
119+ else :
120+ # Otherwise, pop the next level and dive into the tree on that branch
121+ level = levels [0 ]
122+ if level is not None :
123+ try :
124+ tree [level ]
125+ except KeyError :
126+ raise KeyError (key_error_msg .format (level = level , keys = list (tree .keys ())))
127+ env_acc .update ({level : retrieve_leaves (tree [level ], levels [1 :], leaf )})
128+ return env_acc
129+ else :
130+ # If None, then walk all branches of this node of the tree
131+ if isinstance (tree , list ):
132+ tree = {idx : leaf for idx , leaf in enumerate (tree )}
133+ for k , v in tree .items ():
134+ env_acc .update ({k : retrieve_leaves (v , levels [1 :], leaf )})
135+ return env_acc
136+
137+
def extract_keys(
    object: dict[str, Any],
    key_name: str,
    include_startswith: Optional[str] = None,
    exclude_startswith: Optional[str] = None,
) -> list[dict[str, Any]]:
    """
    Builds a list of single-entry dicts, one per retained key of 'object',
    each of the form {key_name: <key of object>}.

    e.g.
    object = {"key1": value, "key2": value, "key3": value}
    key_name = "label"

    extract_keys(object, key_name) # [{"label": "key1"}, {"label": "key2"}, {"label": "key3"}]

    'include_startswith': If provided, only keys that start with this string are retained.
    'exclude_startswith': If provided, every key that starts with this string is dropped.

    When both filters are provided, a key must pass the exclusion filter and
    then also match the inclusion prefix to be retained.
    """

    def is_retained(candidate) -> bool:
        # The exclusion filter is applied first, then the inclusion filter.
        if exclude_startswith is not None and candidate.startswith(exclude_startswith):
            return False
        if include_startswith is None:
            return True
        return candidate.startswith(include_startswith)

    return [{key_name: candidate} for candidate in object.keys() if is_retained(candidate)]
175+
176+
177+
def merge_trees(trees: list[dict[str, dict]]) -> dict[str, dict]:
    """
    Merges all of the trees (dictionaries) in 'trees' into a single new tree.

    This is different than a typical dictionary merge (e.g. a | b)
    which will merge dictionaries with different keys but will over-
    write values if two keys are the same.

    Instead, it crawls each branch of the tree and merges the data
    within each branch, no matter how deep the branches go. The merge
    rules are those of 'deepmerge.always_merger'.

    'trees': the trees to merge, in order. The input trees are not modified.

    Returns the merged tree; an empty dict if 'trees' is empty.
    """
    # Imported locally because the module only imports 'copy' from the copy module.
    from copy import deepcopy

    acc = {}
    for result_tree in trees:
        # deepmerge merges into its first argument in place and stores references
        # to the incoming values. Merging a deep copy keeps 'acc' from aliasing the
        # caller's nested containers, which a later merge would otherwise mutate.
        acc = deepmerge.always_merger.merge(acc, deepcopy(result_tree))
    return acc
0 commit comments