diff --git a/sqlite/util/graph_net_sample_groups_util.py b/sqlite/util/graph_net_sample_groups_util.py new file mode 100644 index 000000000..39297fe48 --- /dev/null +++ b/sqlite/util/graph_net_sample_groups_util.py @@ -0,0 +1,49 @@ +# graph_net_sample_groups_util.py +from typing import List, Set, Dict +from collections import defaultdict + +from sqlite.orm_models import get_session, GraphNetSampleGroup + + +def get_all_graph_net_sample_groups( + db_path: str, + group_types: List[str], + group_policies: List[str], + versions: List[str], +) -> List[Set[str]]: + """ + Get all graph_net sample groups from database. + + Viba: + get_all_graph_net_sample_groups := + list[set[$sample_uid str]] + <- $group_net_db_file_path str + <- $group_type list[str] + <- $group_policy list[str] + <- $version list[str] + + Args: + db_path: Path to the SQLite database file. + group_types: List of group types to filter (e.g., ["shape_diversity", "dtype_diversity"]). + group_policies: List of group policies to filter (e.g., ["by_bucket"]). + versions: List of policy versions to filter (e.g., ["v0.1"]). + + Returns: + List of sets, each set contains sample UIDs belonging to one group. + """ + session = get_session(db_path) + + query = session.query(GraphNetSampleGroup).filter( + GraphNetSampleGroup.deleted.is_(False), + GraphNetSampleGroup.group_type.in_(group_types), + GraphNetSampleGroup.group_policy.in_(group_policies), + GraphNetSampleGroup.policy_version.in_(versions), + ) + + groups_dict: Dict[str, List[str]] = defaultdict(list) + for row in query.all(): + groups_dict[row.group_uid].append(row.sample_uid) + + session.close() + + return [set(uids) for uids in groups_dict.values()] diff --git a/sqlite/util/graph_net_sample_util.py b/sqlite/util/graph_net_sample_util.py new file mode 100644 index 000000000..e2d2c30ff --- /dev/null +++ b/sqlite/util/graph_net_sample_util.py @@ -0,0 +1,127 @@ +# graph_net_sample_util.py +import json +from typing import Dict, List + +from sqlite.orm_models import get_session, GraphSample, SampleOpNameList + + +class GraphNetSampleTypeGetter: + """ + Get sample_type for a given sample_uid. + + Viba: + GraphNetSampleTypeGetter := + # __call__ + $sample_type str + <- $sample_uid str + # __init__ + <- $group_net_db_file_path str + <- $fetch_cache dict[$sample_uid str, $sample_type str] + """ + + def __init__(self, db_path: str): + self.db_path = db_path + self._cache: Dict[str, str] = {} + + def __call__(self, sample_uid: str) -> str: + """Get sample_type for the given sample_uid.""" + if sample_uid in self._cache: + return self._cache[sample_uid] + + session = get_session(self.db_path) + sample = ( + session.query(GraphSample).filter(GraphSample.uuid == sample_uid).first() + ) + session.close() + + sample_type = sample.sample_type if sample else "" + self._cache[sample_uid] = sample_type + return sample_type + + def bulk_get(self, sample_uids: List[str]) -> Dict[str, str]: + """Bulk get sample_types for multiple sample UIDs.""" + session = get_session(self.db_path) + + samples = ( + session.query(GraphSample).filter(GraphSample.uuid.in_(sample_uids)).all() + ) + + result = {} + for s in samples: + result[s.uuid] = s.sample_type + self._cache[s.uuid] = s.sample_type + + for uid in sample_uids: + if uid not in result: + result[uid] = "" + + session.close() + return result + + +class GraphNetSampleOpSeqGetter: + """ + Get op_seq for a given sample_uid. + + Viba: + GraphNetSampleOpSeqGetter := + # __call__ + $sample_op_seq list[str] + <- $sample_uid str + # __init__ + <- $group_net_db_file_path str + <- $fetch_cache dict[$sample_uid str, $sample_op_seq list[str]] + """ + + def __init__(self, db_path: str): + self.db_path = db_path + self._cache: Dict[str, List[str]] = {} + + def __call__(self, sample_uid: str) -> List[str]: + """Get op_seq for the given sample_uid.""" + if sample_uid in self._cache: + return self._cache[sample_uid] + + session = get_session(self.db_path) + op_list = ( + session.query(SampleOpNameList) + .filter(SampleOpNameList.sample_uuid == sample_uid) + .first() + ) + session.close() + + if op_list and op_list.op_names_json: + op_data = json.loads(op_list.op_names_json) + op_seq = [op["op_name"] for op in op_data] + else: + op_seq = [] + + self._cache[sample_uid] = op_seq + return op_seq + + def bulk_get(self, sample_uids: List[str]) -> Dict[str, List[str]]: + """Bulk get op_seqs for multiple sample UIDs.""" + session = get_session(self.db_path) + + op_lists = ( + session.query(SampleOpNameList) + .filter(SampleOpNameList.sample_uuid.in_(sample_uids)) + .all() + ) + + result = {} + for op_list in op_lists: + if op_list.op_names_json: + op_data = json.loads(op_list.op_names_json) + op_seq = [op["op_name"] for op in op_data] + else: + op_seq = [] + result[op_list.sample_uuid] = op_seq + self._cache[op_list.sample_uuid] = op_seq + + for uid in sample_uids: + if uid not in result: + result[uid] = [] + + session.close() + return result