diff --git a/taskweaver/role/role.py b/taskweaver/role/role.py index 21872edeb..d91b4967a 100644 --- a/taskweaver/role/role.py +++ b/taskweaver/role/role.py @@ -2,7 +2,7 @@ import os.path from dataclasses import dataclass from datetime import timedelta -from typing import List, Optional, Set, Tuple, Union +from typing import List, Literal, Optional, Set, Tuple, Union from injector import Module, inject, provider @@ -153,13 +153,62 @@ def format_experience( else "" ) + def prepare_loading( + self, + use_flag: bool, + dynamic_sub_path: bool, + base_path: str, + memory: Optional[Memory], + loaded_from_attr: str, + item_type: Literal["experience", "example"], + ) -> Optional[str]: + """Prepare for loading by checking configurations and memory, and return load_from path if applicable.""" + if not use_flag: + setattr(self, f"{item_type}s", []) + return None + + if not os.path.exists(base_path): + raise FileNotFoundError( + f"The default {item_type} base path {base_path} does not exist." + f"The original {item_type} base paths have been changed to `{item_type}s` folder." + f"Please migrate the {item_type}s to the new base path.", + ) + + sub_path = "" + if dynamic_sub_path: + assert memory is not None, f"Memory should be provided when dynamic_{item_type}_sub_path is True" + sub_paths = memory.get_shared_memory_entries(entry_type=f"{item_type}_sub_path") + if sub_paths: + self.tracing.set_span_attribute(f"{item_type}_sub_path", str(sub_paths)) + # todo: handle multiple sub paths + sub_path = sub_paths[0].content + else: + self.logger.info(f"No {item_type} sub path found in memory.") + setattr(self, f"{item_type}s", []) + return None + + load_from = os.path.join(base_path, sub_path) + if getattr(self, loaded_from_attr) is not None and getattr(self, loaded_from_attr) == load_from: + self.logger.info(f"{item_type.capitalize()} already loaded from {load_from}.") + return None + + setattr(self, loaded_from_attr, load_from) + return sub_path + def role_load_experience( self, query: str, memory: Optional[Memory] = None, ) -> None: - if not self.config.use_experience: - self.experiences = [] + sub_path = self.prepare_loading( + self.config.use_experience, + self.config.dynamic_experience_sub_path, + self.config.experience_dir, + memory, + "experience_loaded_from", + "experience", + ) + if sub_path is None: return if self.experience_generator is None: @@ -167,87 +216,50 @@ def role_load_experience( "Experience generator is not initialized. Each role instance should have its own generator.", ) - experience_sub_path = "" - if self.config.dynamic_experience_sub_path: - assert memory is not None, "Memory should be provided when dynamic_experience_sub_path is True" - experience_sub_paths = memory.get_shared_memory_entries(entry_type="experience_sub_path") - if experience_sub_paths: - self.tracing.set_span_attribute("experience_sub_path", str(experience_sub_paths)) - # todo: handle multiple experience sub paths - experience_sub_path = experience_sub_paths[0].content - else: - self.logger.info("No experience sub path found in memory.") - self.experiences = [] - return - - load_from = os.path.join(self.config.experience_dir, experience_sub_path) - if self.experience_loaded_from is None or self.experience_loaded_from != load_from: - self.experience_loaded_from = load_from - self.experience_generator.set_experience_dir(self.config.experience_dir) - self.experience_generator.set_sub_path(experience_sub_path) - self.experience_generator.refresh() - self.experience_generator.load_experience() - self.logger.info( - "Experience loaded successfully for {}, there are {} experiences with filter [{}]".format( - self.alias, - len(self.experience_generator.experience_list), - experience_sub_path, - ), - ) - else: - self.logger.info(f"Experience already loaded from {load_from}.") + self.experience_generator.set_experience_dir(self.config.experience_dir) + self.experience_generator.set_sub_path(sub_path) + self.experience_generator.refresh() + self.experience_generator.load_experience() + self.logger.info( + "Experience loaded successfully for {}, there are {} experiences with filter [{}]".format( + self.alias, + len(self.experience_generator.experience_list), + sub_path, + ), + ) experiences = self.experience_generator.retrieve_experience(query) self.logger.info(f"Retrieved {len(experiences)} experiences for query [{query}]") self.experiences = [exp for exp, _ in experiences] - # todo: `role_load_example` is similar to `role_load_experience`, consider refactoring def role_load_example( self, role_set: Set[str], memory: Optional[Memory] = None, ) -> None: - if not self.config.use_example: - self.examples = [] + sub_path = self.prepare_loading( + self.config.use_example, + self.config.dynamic_example_sub_path, + self.config.example_base_path, + memory, + "example_loaded_from", + "example", + ) + if sub_path is None: return - if not os.path.exists(self.config.example_base_path): - raise FileNotFoundError( - f"The default example base path {self.config.example_base_path} does not exist." - "The original example base paths have been changed to `examples` folder." - "Please migrate the examples to the new base path.", - ) - - example_sub_path = "" - if self.config.dynamic_example_sub_path: - assert memory is not None, "Memory should be provided when dynamic_example_sub_path is True" - example_sub_paths = memory.get_shared_memory_entries(entry_type="example_sub_path") - if example_sub_paths: - self.tracing.set_span_attribute("example_sub_path", str(example_sub_paths)) - # todo: handle multiple sub paths - example_sub_path = example_sub_paths[0].content - else: - self.logger.info("No example sub path found in memory.") - self.examples = [] - return - - load_from = os.path.join(self.config.example_base_path, example_sub_path) - if self.example_loaded_from is None or self.example_loaded_from != load_from: - self.example_loaded_from = load_from - self.examples = load_examples( - folder=self.config.example_base_path, - sub_path=example_sub_path, - role_set=role_set, - ) - self.logger.info( - "Example loaded successfully for {}, there are {} examples with filter [{}]".format( - self.alias, - len(self.examples), - example_sub_path, - ), - ) - else: - self.logger.info(f"Example already loaded from {load_from}.") + self.examples = load_examples( + folder=self.config.example_base_path, + sub_path=sub_path, + role_set=role_set, + ) + self.logger.info( + "Example loaded successfully for {}, there are {} examples with filter [{}]".format( + self.alias, + len(self.examples), + sub_path, + ), + ) class RoleModuleConfig(ModuleConfig):