diff --git a/.gitignore b/.gitignore index f6bd932..969c17b 100644 --- a/.gitignore +++ b/.gitignore @@ -146,4 +146,5 @@ config_files/wa/test_webarena.json config_files/wa/test_webarena/* cache/* -agents/prompts/jsons/* \ No newline at end of file +agents/prompts/jsons/* +log.txt \ No newline at end of file diff --git a/agent/__init__.py b/agent/__init__.py deleted file mode 100644 index 9028d30..0000000 --- a/agent/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from .agent import ( - Agent, - PromptAgent, - TeacherForcingAgent, - construct_agent, -) - -__all__ = ["Agent", "TeacherForcingAgent", "PromptAgent", "construct_agent"] diff --git a/agent/agent.py b/agent/agents.py similarity index 69% rename from agent/agent.py rename to agent/agents.py index 5fbcba9..9d79c85 100644 --- a/agent/agent.py +++ b/agent/agents.py @@ -1,8 +1,16 @@ +import os +import sys +parent_dir = os.path.dirname(os.path.abspath(__file__)) +up_dir = parent_dir +for i in range(3): + sys.path.append(up_dir) + up_dir = os.path.dirname(up_dir) +from kutils import DEBUG, INFO, WARN, ERROR +import utils as u import argparse import json +import importlib from typing import Any, Optional - -import tiktoken from beartype import beartype from PIL import Image @@ -14,36 +22,23 @@ create_id_based_action, create_none_action, create_playwright_action, + create_vision_action, + create_mas_action, ) from browser_env.utils import Observation, StateInfo from llms import ( - call_llm, - generate_from_huggingface_completion, - generate_from_openai_chat_completion, - generate_from_openai_completion, + # call_llm, + # generate_from_huggingface_completion, + # generate_from_openai_chat_completion, + # generate_from_openai_completion, lm_config, ) from llms.tokenizers import Tokenizer - - -class Agent: - """Base class for the agent""" - - def __init__(self, *args: Any) -> None: - pass - - def next_action( - self, trajectory: Trajectory, intent: str, meta_data: Any - ) -> Action: - """Predict the next action given the observation""" - raise NotImplementedError - - def reset( - self, - test_config_file: str, - ) -> None: - raise NotImplementedError - +from prompts.prompt_constructor import MultimodalCoTPromptConstructor, MultimodalCoTPromptConstructor +from llms.ais_requestor import get_lm_requestor +from base_agent import Agent +from multi_agent import MultiAgent +from vision_agent import VisionAgent class TeacherForcingAgent(Agent): """Agent that follows a pre-defined action sequence""" @@ -106,7 +101,7 @@ def __init__( self, action_set_tag: str, lm_config: lm_config.LMConfig, - prompt_constructor: PromptConstructor, + prompt_constructor: PromptConstructor | MultimodalCoTPromptConstructor | MultimodalCoTPromptConstructor, captioning_fn = None, ) -> None: super().__init__() @@ -114,19 +109,19 @@ def __init__( self.prompt_constructor = prompt_constructor self.action_set_tag = action_set_tag self.captioning_fn = captioning_fn - - # Check if the model is multimodal. - if ("gemini" in lm_config.model or "gpt-4" in lm_config.model and "vision" in lm_config.model) and type(prompt_constructor) == MultimodalCoTPromptConstructor: - self.multimodal_inputs = True - else: - self.multimodal_inputs = False + self.multimodal_inputs = True + self.kq = get_lm_requestor(lm_config.model) def set_action_set_tag(self, tag: str) -> None: self.action_set_tag = tag @beartype def next_action( - self, trajectory: Trajectory, intent: str, meta_data: dict[str, Any], images: Optional[list[Image.Image]] = None, + self, + trajectory: Trajectory, + intent: str, + meta_data: dict[str, Any], + images: Optional[list[Image.Image]] = None, output_response: bool = False ) -> Action: # Create page screenshot image for multimodal models. @@ -155,17 +150,29 @@ def next_action( ) if self.multimodal_inputs: - prompt = self.prompt_constructor.construct( + messages = self.prompt_constructor.construct( trajectory, intent, page_screenshot_img, images, meta_data ) else: - prompt = self.prompt_constructor.construct( + messages = self.prompt_constructor.construct( trajectory, intent, meta_data ) lm_config = self.lm_config n = 0 while True: - response = call_llm(lm_config, prompt) + # response = call_llm(lm_config, prompt) + # u.write_json(f'{u.get_time()}.json', messages) + + try: + model = lm_config.model + if 'qwen' in model: + response = self.kq.infer_messages(messages) + else: + raise ValueError(lm_config) + except Exception as e: + ERROR(e) + response = 'stop []' + force_prefix = self.prompt_constructor.instruction[ "meta_data" ].get("force_prefix", "") @@ -204,14 +211,13 @@ def reset(self, test_config_file: str) -> None: def construct_agent(args: argparse.Namespace, captioning_fn=None) -> Agent: llm_config = lm_config.construct_llm_config(args) + default_provider = 'openai' + default_model = 'gpt-3.5-turbo-1106' + tokenizer = Tokenizer(default_provider, default_model) + agent: Agent - if args.agent_type == "teacher_forcing": - agent = TeacherForcingAgent() - elif args.agent_type == "prompt": - with open(args.instruction_path) as f: - constructor_type = json.load(f)["meta_data"]["prompt_constructor"] - tokenizer = Tokenizer(args.provider, args.model) - prompt_constructor = eval(constructor_type)( + if args.mode == "som": + prompt_constructor = MultimodalCoTPromptConstructor( args.instruction_path, lm_config=llm_config, tokenizer=tokenizer ) agent = PromptAgent( @@ -220,8 +226,29 @@ def construct_agent(args: argparse.Namespace, captioning_fn=None) -> Agent: prompt_constructor=prompt_constructor, captioning_fn=captioning_fn ) + elif args.mode == "mas": + prompt_constructor = MultimodalCoTPromptConstructor( + args.instruction_path, lm_config=llm_config, tokenizer=tokenizer + ) + agent = MultiAgent( + action_set_tag=args.action_set_tag, + lm_config=llm_config, + prompt_constructor=prompt_constructor, + captioning_fn=captioning_fn + ) + elif args.mode == 'url_mas': + from url_infer.url_multi_agent import URLMultiAgent + prompt_constructor = MultimodalCoTPromptConstructor( + args.instruction_path, lm_config=llm_config, tokenizer=tokenizer + ) + agent = URLMultiAgent( + action_set_tag=args.action_set_tag, + lm_config=llm_config, + prompt_constructor=prompt_constructor, + captioning_fn=captioning_fn + ) else: raise NotImplementedError( - f"agent type {args.agent_type} not implemented" + f"agent type {args.mode} not implemented" ) return agent diff --git a/agent/base_agent.py b/agent/base_agent.py new file mode 100644 index 0000000..ef56009 --- /dev/null +++ b/agent/base_agent.py @@ -0,0 +1,22 @@ +from typing import Any, Optional +from browser_env import Trajectory +from browser_env.actions import Action + +class Agent: + """Base class for the agent""" + + def __init__(self, *args: Any) -> None: + pass + + def next_action( + self, trajectory: Trajectory, intent: str, meta_data: Any + ) -> Action: + """Predict the next action given the observation""" + raise NotImplementedError + + def reset( + self, + test_config_file: str, + ) -> None: + raise NotImplementedError + diff --git a/agent/multi_agent.py b/agent/multi_agent.py new file mode 100644 index 0000000..dece622 --- /dev/null +++ b/agent/multi_agent.py @@ -0,0 +1,130 @@ +import os +import sys +parent_dir = os.path.dirname(os.path.abspath(__file__)) +up_dir = parent_dir +for i in range(3): + sys.path.append(up_dir) + up_dir = os.path.dirname(up_dir) +from kutils import DEBUG, INFO, WARN, ERROR +import utils as u +from typing import Any, Optional +from beartype import beartype +from PIL import Image +from agent.prompts import * +from browser_env import Trajectory +from browser_env.actions import ( + Action, + ActionParsingError, + create_id_based_action, + create_none_action, + create_playwright_action, + create_vision_action, + create_mas_action, +) +from browser_env.utils import Observation, StateInfo +from llms import lm_config +from llms.tokenizers import Tokenizer +from prompts.prompt_constructor import MultimodalCoTPromptConstructor, MultimodalCoTPromptConstructor +from base_agent import Agent +from AWorld.examples.visualwebarena.action_team import ActionTeam + +class MultiAgent(Agent): + @beartype + def __init__( + self, + action_set_tag: str, + lm_config: lm_config.LMConfig, + prompt_constructor: PromptConstructor | + MultimodalCoTPromptConstructor | + MultimodalCoTPromptConstructor, + captioning_fn = None, + ) -> None: + super().__init__() + self.lm_config = lm_config + self.prompt_constructor = prompt_constructor + self.action_set_tag = action_set_tag + self.captioning_fn = captioning_fn + self.multimodal_inputs = True + self.bu_agent = ActionTeam(self.lm_config.model) + + def set_action_set_tag(self, tag: str) -> None: + self.action_set_tag = tag + + @beartype + def next_action( + self, + trajectory: Trajectory, + intent: str, + meta_data: dict[str, Any], + images: Optional[list[Image.Image]] = None, + output_response: bool = False + ) -> Action: + # Create page screenshot image for multimodal models. + ori_page_screenshot_arr = trajectory[-1]["observation"]["ori_image"] + ori_page_screenshot_img = Image.fromarray(ori_page_screenshot_arr) + som_page_screenshot_arr = trajectory[-1]["observation"]["image"] + som_page_screenshot_img = Image.fromarray(som_page_screenshot_arr) + + last_som_img = None + last_ori_img = None + for i in range(-2, -len(trajectory), -1): + if 'observation' in trajectory[i].keys(): + last_som_img = trajectory[i]['observation']['image'] + last_som_img = Image.fromarray(last_som_img) + last_ori_img = trajectory[i]['observation']['ori_image'] + last_ori_img = Image.fromarray(last_ori_img) + break + + # Caption the input image, if provided. + image_input_caption = '' + input_img = None + try: + if images is not None and len(images) > 0: + if self.captioning_fn is not None: + for image_i, image in enumerate(images): + if image_i == 0: + image_input_caption += f'Input image {image_i+1}: "{self.captioning_fn([image])[0]}"' + else: + image_input_caption += f'input image {image_i+1}: "{self.captioning_fn([image])[0]}"' + if len(images) > 1: + image_input_caption += ", " + # Update intent to include captions of input images. + # intent = f"{image_input_caption}\nTask: {intent}" + elif not self.multimodal_inputs: + print("WARNING: Input image provided but no image captioner available.") + input_img = images[0] + except Exception as e: + ERROR(f'caption function {self.lm_config.caption_model} {e}') + + page_text = trajectory[-1]['observation']['text'] + page_texts = page_text.split('\n') + page_texts = [a.replace(' ', '') for a in page_texts if not a.startswith('[]')] + page_text = '\n'.join(page_texts) + + state_info: StateInfo = trajectory[-1] # type: ignore[assignment] + obs = state_info["observation"]['text'] + page = meta_data["page"] + url = page.url + action_history = meta_data["action_history"] + tabs = meta_data['tabs'] + site_name = self.lm_config.domain + + action_info, response = self.bu_agent.next_action(output_response, + intent, action_history, site_name, url, obs, tabs, input_img, som_page_screenshot_img, ori_page_screenshot_img, page) + force_prefix = self.prompt_constructor.instruction["meta_data"].get("force_prefix", "") + response = f"{force_prefix}{response}" + + try: + parsed_response = self.prompt_constructor.extract_action(response) + action = create_mas_action(parsed_response, obs) + action["raw_prediction"] = response + except Exception as e: + action = create_none_action() + action["raw_prediction"] = response + + action['action_info'] = action_info + action['domain'] = self.lm_config.domain + return action + + def reset(self, test_config_file: str) -> None: + self.called_closing_agents = [] \ No newline at end of file diff --git a/agent/prompts/jsons/mas.json b/agent/prompts/jsons/mas.json new file mode 100644 index 0000000..0fe6f35 --- /dev/null +++ b/agent/prompts/jsons/mas.json @@ -0,0 +1,34 @@ +{ + "intro": "You are an autonomous intelligent agent tasked with navigating a web browser. You will be given web-based tasks. These tasks will be accomplished through the use of specific actions you can issue.\n\nHere's the information you'll have:\nThe user's objective: This is the task you're trying to complete.\nThe current web page screenshot: This is a screenshot of the webpage, with each interactable element assigned a unique numerical id. Each bounding box and its respective id shares the same color.\nThe observation, which lists the IDs of all interactable elements on the current web page with their text content if any, in the format [id] [tagType] [text content]. tagType is the type of the element, such as button, link, or textbox. text content is the text content of the element. For example, [1234] [button] ['Add to Cart'] means that there is a button with id 1234 and text content 'Add to Cart' on the current web page. [] [StaticText] [text] means that the element is of some text that is not interactable.\nThe current web page's URL: This is the page you're currently navigating.\nThe open tabs: These are the tabs you have open.\nThe previous action: This is the action you just performed. It may be helpful to track your progress.\n\nThe actions you can perform fall into several categories:\n\nPage Operation Actions:\n```click [id]```: This action clicks on an element with a specific id on the webpage.\n```type [id] [content]```: Use this to type the content into the field with id. By default, the \"Enter\" key is pressed after typing unless press_enter_after is set to 0, i.e., ```type [id] [content] [0]```.\n```hover [id]```: Hover over an element with id.\n```press [key_comb]```: Simulates the pressing of a key combination on the keyboard (e.g., Ctrl+v).\n```scroll [down]``` or ```scroll [up]```: Scroll the page up or down.\n\nTab Management Actions:\n```new_tab```: Open a new, empty browser tab.\n```tab_focus [tab_index]```: Switch the browser's focus to a specific tab using its index.\n```close_tab```: Close the currently active tab.\n\nURL Navigation Actions:\n```goto [url]```: Navigate to a specific URL.\n```go_back```: Navigate to the previously viewed page.\n```go_forward```: Navigate to the next page (if a previous 'go_back' action was performed).\n\nCompletion Action:\n```stop [answer]```: Issue this action when you believe the task is complete. If the objective is to find a text-based answer, provide the answer in the bracket.\n\nHomepage:\nIf you want to visit other websites, check out the homepage at http://homepage.com. It has a list of websites you can visit.\nhttp://homepage.com/password.html lists all the account name and password for the websites. You can use them to log in to the websites.\n\nTo be successful, it is very important to follow the following rules:\n1. You should only issue an action that is valid given the current observation\n2. You should only issue one action at a time.\n3. You should follow the examples to reason step by step and then issue the next action.\n4. Generate the action in the correct format. Start with a \"In summary, the next action I will perform is\" phrase, followed by action inside ``````. For example, \"In summary, the next action I will perform is ```click [1234]```\".\n5. Issue stop action when you think you have achieved the objective. Don't generate anything after stop.", + "examples": [ + [ + "OBSERVATION:\n[31] [IMG] [Image, description: hp fx-7010dn fax machine, url: http://ec2-3-13-232-171.us-east-2.compute.amazonaws.com:7770/media/catalog/product/cache/89ff578b9cd87e0600daac45c9e1ea98/B/0/B08GKZ3ZKD.0.jpg]\n[32] [A] [HP CB782A#ABA 640 Inkjet Fax Machine (Renewed)]\n[] [StaticText] [$279.49]\n[33] [BUTTON] [Add to Cart]\n[34] [A] [Add to Wish List]\n[35] [A] [Add to Compare]\nURL: http://onestopmarket.com/office-products/office-electronics.html\nOBJECTIVE: What is the price of HP Inkjet Fax Machine?\nPREVIOUS ACTION: None", + "Let's think step-by-step. This page list the information of HP Inkjet Fax Machine, which is the product identified in the objective. Its price is $279.49. I think I have achieved the objective. I will issue the stop action with the answer. In summary, the next action I will perform is ```stop [$279.49]```", + "agent/prompts/som_examples/som_example1.png" + ], + [ + "OBSERVATION:\n[] [StaticText] [/f/food]\n[] [StaticText] [[homemade] Obligatory Halloween Pumpkin Loaf!\tSubmitted by\tkneechalice\tt3_yid9lu\t1 year ago]\n[9] [IMG] []\n[] [StaticText] [Submitted by\tkneechalice\tt3_yid9lu\t1 year ago]\n[10] [A] [kneechalice]\n[11] [A] [45 comments]\n[] [StaticText] [[I ate] Maple Pecan Croissant\tSubmitted by\tAccordingtoJP\tt3_y3hrpn\t1 year ago]\n[14] [IMG] []\n[] [StaticText] [Submitted by\tAccordingtoJP\tt3_y3hrpn\t1 year ago]\n[15] [A] [AccordingtoJP]\n[16] [A] [204 comments]\nURL: http://reddit.com\nOBJECTIVE: Tell me what the top comment on the croissant post says.\nPREVIOUS ACTION: None", + "Let's think step-by-step. This page has a post titled '[I ate] Maple Pecan Croissant', which is the post mentioned in the objective. In order to find the top comment, I will navigate into the comments section of the post. In summary, the next action I will perform is ```click [11]```", + "agent/prompts/som_examples/som_example2.png" + ], + [ + "OBSERVATION:\n[] [StaticText] [What are you looking for today?]\n[5] [INPUT] []\n[6] [SELECT] [Select a category]\n[7] [BUTTON] [Search]\n[] [StaticText] [Latest Listings]\n[] [StaticText] [Atlas Powered Audio System w/ Tripod\t150.00 $\tMusic instruments\tBorough of Red Lion (Pennsylvania)\t2023/11/16]\n[8] [IMG] [Atlas Powered Audio System w/ Tripod]\n[9] [A] [Atlas Powered Audio System w/ Tripod]\n[] [StaticText] [150.00 $]\n[] [StaticText] [Neptune Gaming Console\t350.00 $\tVideo gaming\tPennwyn (Pennsylvania)\t2023/11/16]\n[10] [IMG] [Neptune Gaming Console]\n[11] [A] [Neptune Gaming Console]\n[] [StaticText] [350.00 $]\nURL: http://classifieds.com\nOBJECTIVE: Help me find the cheapest dark colored guitar.\nPREVIOUS ACTION: None", + "Let's think step-by-step. The objective is to find the cheapest dark colored guitar on the site. The site has a search box whose ID is [5]. I can search for guitars by entering \"guitar\". I can submit this by pressing the Enter afterwards. In summary, the next action I will perform is ```type [5] [guitar] [1]```", + "agent/prompts/som_examples/som_example3.png" + ] + ], + "template": "OBSERVATION: {observation}\nURL: {url}\nOBJECTIVE: {objective}\nPREVIOUS ACTION: {previous_action}", + "meta_data": { + "observation": "image_som", + "action_type": "som", + "keywords": [ + "url", + "objective", + "observation", + "previous_action" + ], + "prompt_constructor": "MultimodalCoTPromptConstructor", + "answer_phrase": "In summary, the next action I will perform is", + "action_splitter": "```" + } +} \ No newline at end of file diff --git a/agent/prompts/jsons/pure_prompt.md b/agent/prompts/jsons/pure_prompt.md new file mode 100644 index 0000000..fad5d3e --- /dev/null +++ b/agent/prompts/jsons/pure_prompt.md @@ -0,0 +1,41 @@ +"You are an autonomous intelligent agent tasked with navigating a web browser. You will be given web-based tasks. These tasks will be accomplished through the use of specific actions you can issue. + +Here's the information you'll have: +The user's objective: This is the task you're trying to complete. +The current web page's accessibility tree: This is a simplified representation of the webpage, providing key information. +The current web page's URL: This is the page you're currently navigating. +The open tabs: These are the tabs you have open. +The previous action: This is the action you just performed. It may be helpful to track your progress. + +The actions you can perform fall into several categories: + +Page Operation Actions: +```click [id]```: This action clicks on an element with a specific id on the webpage. +```type [id] [content]```: Use this to type the content into the field with id. By default, the \"Enter\" key is pressed after typing unless press_enter_after is set to 0, i.e., ```type [id] [content] [0]```. +```hover [id]```: Hover over an element with id. +```press [key_comb]```: Simulates the pressing of a key combination on the keyboard (e.g., Ctrl+v). +```scroll [down]``` or ```scroll [up]```: Scroll the page up or down. + +Tab Management Actions: +```new_tab```: Open a new, empty browser tab. +```tab_focus [tab_index]```: Switch the browser's focus to a specific tab using its index. The current tab is index 0, and the wiki tab index is 1. +```close_tab```: Close the currently active tab. + +URL Navigation Actions: +```goto [url]```: Navigate to a specific URL. +```go_back```: Navigate to the previously viewed page. +```go_forward```: Navigate to the next page (if a previous 'go_back' action was performed). + +Completion Action: +```stop [answer]```: Issue this action when you believe the task is complete. If the objective is to find a text-based answer, provide the answer in the bracket. + +Homepage: +If you want to visit other websites, check out the homepage at http://homepage.com. It has a list of websites you can visit. +http://homepage.com/password.html lists all the account name and password for the websites. You can use them to log in to the websites. + +To be successful, it is very important to follow the following rules: +1. You should only issue an action that is valid given the current observation +2. You should only issue one action at a time. +3. You should follow the examples to reason step by step and then issue the next action. +4. Generate the action in the correct format. Start with a \"In summary, the next action I will perform is\" phrase, followed by action inside ``````. For example, \"In summary, the next action I will perform is ```click [1234]```\". +5. Issue stop action when you think you have achieved the objective. Don't generate anything after stop.", \ No newline at end of file diff --git a/agent/prompts/jsons/som.json b/agent/prompts/jsons/som.json new file mode 100644 index 0000000..0788788 --- /dev/null +++ b/agent/prompts/jsons/som.json @@ -0,0 +1,34 @@ +{ + "intro": "You are an autonomous intelligent agent tasked with navigating a web browser. You will be given web-based tasks. These tasks will be accomplished through the use of specific actions you can issue.\n\nHere's the information you'll have:\nThe user's objective: This is the task you're trying to complete.\nThe current web page screenshot: This is a screenshot of the webpage, with each interactable element assigned a unique numerical id. Each bounding box and its respective id shares the same color.\nThe observation, which lists the IDs of all interactable elements on the current web page with their text content if any, in the format [id] [tagType] [text content]. tagType is the type of the element, such as button, link, or textbox. text content is the text content of the element. For example, [1234] [button] ['Add to Cart'] means that there is a button with id 1234 and text content 'Add to Cart' on the current web page. [] [StaticText] [text] means that the element is of some text that is not interactable.\nThe current web page's URL: This is the page you're currently navigating.\nThe open tabs: These are the tabs you have open.\nThe previous action: This is the action you just performed. It may be helpful to track your progress.\n\nThe actions you can perform fall into several categories:\n\nPage Operation Actions:\n```click [id]```: This action clicks on an element with a specific id on the webpage.\n```type [id] [content]```: Use this to type the content into the field with id. By default, the \"Enter\" key is pressed after typing unless press_enter_after is set to 0, i.e., ```type [id] [content] [0]```. When you wanna type something in to search bar, you don't need to click it first.\n```hover [id]```: Hover over an element with id.\n```press [key_comb]```: Simulates the pressing of a key combination on the keyboard (e.g., Ctrl+v).\n```scroll [down]``` or ```scroll [up]```: Scroll the page up or down.\n\nTab Management Actions:\n```new_tab```: Open a new, empty browser tab.\n```tab_focus [tab_index]```: Switch the browser's focus to a specific tab using its index. The current tab is index 0, and the wiki tab index is 1.\n```close_tab```: Close the currently active tab.\n\nURL Navigation Actions:\n```goto [url]```: Navigate to a specific URL.\n```go_back```: Navigate to the previously viewed page.\n```go_forward```: Navigate to the next page (if a previous 'go_back' action was performed).\n\nCompletion Action:\n```stop [answer]```: Issue this action when you believe the task is complete. If the objective is to find a text-based answer, provide the answer in the bracket.\n\nHomepage:\nIf you want to visit other websites, check out the homepage at http://homepage.com. It has a list of websites you can visit.\nhttp://homepage.com/password.html lists all the account name and password for the websites. You can use them to log in to the websites.\n\nTo be successful, it is very important to follow the following rules:\n1. You should only issue an action that is valid given the current observation\n2. You should only issue one action at a time.\n3. The search box only accepts input of less than three keywords.\n4. Generate the action in the correct format. Start with a \"In summary, the next action I will perform is\" phrase, followed by action inside ``````. For example, \"In summary, the next action I will perform is ```click [1234]```\".\n5. Issue stop action when you think you have achieved the objective. Don't generate anything after stop.", + "examples": [ + [ + "OBSERVATION:\n[31] [IMG] [Image, description: hp fx-7010dn fax machine, url: http://ec2-3-13-232-171.us-east-2.compute.amazonaws.com:7770/media/catalog/product/cache/89ff578b9cd87e0600daac45c9e1ea98/B/0/B08GKZ3ZKD.0.jpg]\n[32] [A] [HP CB782A#ABA 640 Inkjet Fax Machine (Renewed)]\n[] [StaticText] [$279.49]\n[33] [BUTTON] [Add to Cart]\n[34] [A] [Add to Wish List]\n[35] [A] [Add to Compare]\nURL: http://onestopmarket.com/office-products/office-electronics.html\nOBJECTIVE: What is the price of HP Inkjet Fax Machine?\nPREVIOUS ACTION: None", + "Let's think step-by-step. This page list the information of HP Inkjet Fax Machine, which is the product identified in the objective. Its price is $279.49. I think I have achieved the objective. I will issue the stop action with the answer. In summary, the next action I will perform is ```stop [$279.49]```", + "agent/prompts/som_examples/som_example1.png" + ], + [ + "OBSERVATION:\n[] [StaticText] [/f/food]\n[] [StaticText] [[homemade] Obligatory Halloween Pumpkin Loaf!\tSubmitted by\tkneechalice\tt3_yid9lu\t1 year ago]\n[9] [IMG] []\n[] [StaticText] [Submitted by\tkneechalice\tt3_yid9lu\t1 year ago]\n[10] [A] [kneechalice]\n[11] [A] [45 comments]\n[] [StaticText] [[I ate] Maple Pecan Croissant\tSubmitted by\tAccordingtoJP\tt3_y3hrpn\t1 year ago]\n[14] [IMG] []\n[] [StaticText] [Submitted by\tAccordingtoJP\tt3_y3hrpn\t1 year ago]\n[15] [A] [AccordingtoJP]\n[16] [A] [204 comments]\nURL: http://reddit.com\nOBJECTIVE: Tell me what the top comment on the croissant post says.\nPREVIOUS ACTION: None", + "Let's think step-by-step. This page has a post titled '[I ate] Maple Pecan Croissant', which is the post mentioned in the objective. In order to find the top comment, I will navigate into the comments section of the post. In summary, the next action I will perform is ```click [11]```", + "agent/prompts/som_examples/som_example2.png" + ], + [ + "OBSERVATION:\n[] [StaticText] [What are you looking for today?]\n[5] [INPUT] []\n[6] [SELECT] [Select a category]\n[7] [BUTTON] [Search]\n[] [StaticText] [Latest Listings]\n[] [StaticText] [Atlas Powered Audio System w/ Tripod\t150.00 $\tMusic instruments\tBorough of Red Lion (Pennsylvania)\t2023/11/16]\n[8] [IMG] [Atlas Powered Audio System w/ Tripod]\n[9] [A] [Atlas Powered Audio System w/ Tripod]\n[] [StaticText] [150.00 $]\n[] [StaticText] [Neptune Gaming Console\t350.00 $\tVideo gaming\tPennwyn (Pennsylvania)\t2023/11/16]\n[10] [IMG] [Neptune Gaming Console]\n[11] [A] [Neptune Gaming Console]\n[] [StaticText] [350.00 $]\nURL: http://classifieds.com\nOBJECTIVE: Help me find the cheapest dark colored guitar.\nPREVIOUS ACTION: None", + "Let's think step-by-step. The objective is to find the cheapest dark colored guitar on the site. The site has a search box whose ID is [5]. I can search for guitars by entering \"guitar\". I can submit this by pressing the Enter afterwards. In summary, the next action I will perform is ```type [5] [guitar] [1]```", + "agent/prompts/som_examples/som_example3.png" + ] + ], + "template": "OBSERVATION: {observation}\nURL: {url}\nOBJECTIVE: {objective}\nPREVIOUS ACTION: {previous_action}", + "meta_data": { + "observation": "image_som", + "action_type": "som", + "keywords": [ + "url", + "objective", + "observation", + "previous_action" + ], + "prompt_constructor": "MultimodalCoTPromptConstructor", + "answer_phrase": "In summary, the next action I will perform is", + "action_splitter": "```" + } +} \ No newline at end of file diff --git a/agent/prompts/jsons/vision.json b/agent/prompts/jsons/vision.json new file mode 100644 index 0000000..f65c66e --- /dev/null +++ b/agent/prompts/jsons/vision.json @@ -0,0 +1,34 @@ +{ + "intro": "You are an autonomous intelligent agent tasked with navigating a web browser. You will be given web-based tasks. These tasks will be accomplished through the use of specific actions you can issue.\n\nHere's the information you'll have:\nThe user's objective: This is the task you're trying to complete.\nThe current web page's accessibility tree: This is a simplified representation of the webpage, providing key information.\nThe current web page's URL: This is the page you're currently navigating.\nThe open tabs: These are the tabs you have open.\nThe previous action: This is the action you just performed. It may be helpful to track your progress.\n\nThe actions you can perform fall into several categories:\n\nPage Operation Actions:\n```click [id]```: This action clicks on an element with a specific id on the webpage.\n```type [id] [content]```: Use this to type the content into the field with id. By default, the \"Enter\" key is pressed after typing unless press_enter_after is set to 0, i.e., ```type [id] [content] [0]```.\n```hover [id]```: Hover over an element with id.\n```press [key_comb]```: Simulates the pressing of a key combination on the keyboard (e.g., Ctrl+v).\n```scroll [down]``` or ```scroll [up]```: Scroll the page up or down.\n\nTab Management Actions:\n```new_tab```: Open a new, empty browser tab.\n```tab_focus [tab_index]```: Switch the browser's focus to a specific tab using its index.\n```close_tab```: Close the currently active tab.\n\nURL Navigation Actions:\n```goto [url]```: Navigate to a specific URL.\n```go_back```: Navigate to the previously viewed page.\n```go_forward```: Navigate to the next page (if a previous 'go_back' action was performed).\n\nCompletion Action:\n```stop [answer]```: Issue this action when you believe the task is complete. If the objective is to find a text-based answer, provide the answer in the bracket.\n\nHomepage:\nIf you want to visit other websites, check out the homepage at http://homepage.com. It has a list of websites you can visit.\nhttp://homepage.com/password.html lists all the account name and password for the websites. You can use them to log in to the websites.\n\nTo be successful, it is very important to follow the following rules:\n1. You should only issue an action that is valid given the current observation\n2. You should only issue one action at a time.\n3. You should follow the examples to reason step by step and then issue the next action.\n4. Generate the action in the correct format. Start with a \"In summary, the next action I will perform is\" phrase, followed by action inside ``````. For example, \"In summary, the next action I will perform is ```click [1234]```\".\n5. Issue stop action when you think you have achieved the objective. Don't generate anything after stop.", + "examples": [ + [ + "OBSERVATION:\n[1744] link 'HP CB782A#ABA 640 Inkjet Fax Machine (Renewed)'\n[1749] StaticText '$279.49'\n[1757] button 'Add to Cart'\n[1760] button 'Add to Wish List'\n[1761] button 'Add to Compare'\nURL: http://onestopmarket.com/office-products/office-electronics.html\nOBJECTIVE: What is the price of HP Inkjet Fax Machine?\nPREVIOUS ACTION: None", + "Let's think step-by-step. This page list the information of HP Inkjet Fax Machine, which is the product identified in the objective. Its price is $279.49. I think I have achieved the objective. I will issue the stop action with the answer. In summary, the next action I will perform is ```stop [$279.49]```", + "agent/prompts/multimodal_examples/multimodal_example1.png" + ], + [ + "OBSERVATION:\n[204] heading '/f/food'\n[593] heading '[homemade] Obligatory Halloween Pumpkin Loaf!'\n\t[942] link '[homemade] Obligatory Halloween Pumpkin Loaf!'\n[945] StaticText 'Submitted by '\n[30] link 'kneechalice' expanded: False\n[1484] StaticText 't3_yid9lu'\n[949] time 'October 31, 2022 at 10:10:03 AM EDT'\n\t[1488] StaticText '1 year ago'\n[1489] link '45 comments'\n[605] heading '[I ate] Maple Pecan Croissant'\n\t[963] link '[I ate] Maple Pecan Croissant'\n[966] StaticText 'Submitted by '\n[37] link 'AccordingtoJP' expanded: False\n[1494] StaticText 't3_y3hrpn'\n[970] time 'October 13, 2022 at 10:41:09 PM EDT'\n\t[1498] StaticText '1 year ago'\n[1499] link '204 comments'\nURL: http://reddit.com\nOBJECTIVE: Tell me what the top comment on the croissant post says.\nPREVIOUS ACTION: None", + "Let's think step-by-step. This page has a post titled '[I ate] Maple Pecan Croissant', which is the post mentioned in the objective. In order to find the top comment, I will navigate into the comments section of the post. In summary, the next action I will perform is ```click [1499]```", + "agent/prompts/multimodal_examples/multimodal_example2.png" + ], + [ + "OBSERVATION:\n[42] link 'My account'\n[43] link 'Logout'\n[44] link 'Publish Ad'\n[25] heading 'What are you looking for today?'\n[143] StaticText 'Keyword'\n[81] textbox 'e.g., a blue used car' required: False\n[146] StaticText 'Category'\n[28] heading 'Latest Listings'\n[86] link 'Atlas Powered Audio System w/ Tripod'\n\t[176] img 'Atlas Powered Audio System w/ Tripod'\n[511] StaticText '150.00 $'\n[88] link 'Neptune Gaming Console'\n\t[178] img 'Neptune Gaming Console'\n[515] StaticText '350.00 $'\nURL: http://classifieds.com\nOBJECTIVE: Help me find the cheapest dark colored guitar.\nPREVIOUS ACTION: None", + "Let's think step-by-step. The objective is to find the cheapest dark colored guitar on the site. The site has a search box whose ID is [81]. I can search for guitars by entering \"guitar\". I can submit this by pressing the Enter afterwards. In summary, the next action I will perform is ```type [81] [guitar] [1]```", + "agent/prompts/multimodal_examples/multimodal_example3.png" + ] + ], + "template": "OBSERVATION:\n{observation}\nURL: {url}\nOBJECTIVE: {objective}\nPREVIOUS ACTION: {previous_action}", + "meta_data": { + "observation": "accessibility_tree", + "action_type": "id_accessibility_tree", + "keywords": [ + "url", + "objective", + "observation", + "previous_action" + ], + "prompt_constructor": "MultimodalCoTPromptConstructor", + "answer_phrase": "In summary, the next action I will perform is", + "action_splitter": "```" + } +} \ No newline at end of file diff --git a/agent/prompts/prompt_constructor.py b/agent/prompts/prompt_constructor.py index 5c50c9d..b446cbb 100644 --- a/agent/prompts/prompt_constructor.py +++ b/agent/prompts/prompt_constructor.py @@ -1,3 +1,4 @@ +import os import json import re from pathlib import Path @@ -32,7 +33,13 @@ def __init__( self.obs_modality = "text" self.lm_config = lm_config instruction = json.load(open(self.instruction_path)) - instruction["examples"] = [tuple(e) for e in instruction["examples"]] + instruction["examples"] = [list(e) for e in instruction["examples"]] + + self.example_path = os.path.dirname(os.path.abspath(__file__)) + '/json/' + self.example_path = self.example_path[:self.example_path.find('visualwebarena')] + '/visualwebarena/' + for i in range(len(instruction["examples"])): + instruction["examples"][i][-1] = self.example_path + instruction["examples"][i][-1] + self.instruction: Instruction = instruction self.tokenizer = tokenizer @@ -306,7 +313,10 @@ def construct( page = state_info["info"]["page"] url = page.url - previous_action_str = meta_data["action_history"][-1] + if meta_data["action_history"]: + previous_action_str = '\n'.join(meta_data["action_history"]) + else: + previous_action_str = "None" current = template.format( objective=intent, url=self.map_url_to_real(url), @@ -331,116 +341,68 @@ def get_lm_api_input( ) -> APIInput: """Return the require format for an API""" message: list[dict[str, str]] | str | list[str | Image.Image] - if "openai" in self.lm_config.provider: - if self.lm_config.mode == "chat": - message = [ - { - "role": "system", - "content": [{"type": "text", "text": intro}], - } - ] - for (x, y, z) in examples: - example_img = Image.open(z) - message.append( + # if "openai" in self.lm_config.provider: + # if self.lm_config.mode == "chat": + message = [ + { + "role": "system", + "content": [{"type": "text", "text": intro}], + } + ] + for (x, y, z) in examples: + example_img = Image.open(z) + message.append( + { + "role": "user", + "name": "example_user", + "content": [ + {"type": "text", "text": x}, { - "role": "system", - "name": "example_user", - "content": [ - {"type": "text", "text": x}, - { - "type": "text", - "text": "IMAGES: (1) current page screenshot", - }, - { - "type": "image_url", - "image_url": { - "url": pil_to_b64(example_img) - }, - }, - ], - } - ) - message.append( + "type": "text", + "text": "IMAGES: (1) current page screenshot", + }, { - "role": "system", - "name": "example_assistant", - "content": [{"type": "text", "text": y}], - } - ) + "type": "image_url", + "image_url": { + "url": pil_to_b64(example_img) + }, + }, + ], + } + ) + message.append( + { + "role": "system", + "name": "example_assistant", + "content": [{"type": "text", "text": y}], + } + ) - # Encode images and page_screenshot_img as base64 strings. - current_prompt = current - content = [ + # Encode images and page_screenshot_img as base64 strings. + current_prompt = current + content = [ + { + "type": "text", + "text": "IMAGES: (1) current page screenshot", + }, + { + "type": "image_url", + "image_url": {"url": pil_to_b64(page_screenshot_img)}, + }, + ] + for image_i, image in enumerate(images): + content.extend( + [ { "type": "text", - "text": "IMAGES: (1) current page screenshot", + "text": f"({image_i+2}) input image {image_i+1}", }, { "type": "image_url", - "image_url": {"url": pil_to_b64(page_screenshot_img)}, + "image_url": {"url": pil_to_b64(image)}, }, ] - for image_i, image in enumerate(images): - content.extend( - [ - { - "type": "text", - "text": f"({image_i+2}) input image {image_i+1}", - }, - { - "type": "image_url", - "image_url": {"url": pil_to_b64(image)}, - }, - ] - ) - content = [{"type": "text", "text": current_prompt}] + content - - message.append({"role": "user", "content": content}) - return message - else: - raise ValueError( - f"GPT-4V models do not support mode {self.lm_config.mode}" - ) - elif "google" in self.lm_config.provider: - if self.lm_config.mode == "completion": - message = [ - intro, - "Here are a few examples:", - ] - for (x, y, z) in examples: - example_img = Image.open(z) - message.append(f"Observation\n:{x}\n") - message.extend( - [ - "IMAGES:", - "(1) current page screenshot:", - pil_to_vertex(example_img), - ] - ) - message.append(f"Action: {y}") - message.append("Now make prediction given the observation") - message.append(f"Observation\n:{current}\n") - message.extend( - [ - "IMAGES:", - "(1) current page screenshot:", - pil_to_vertex(page_screenshot_img), - ] - ) - for image_i, image in enumerate(images): - message.extend( - [ - f"({image_i+2}) input image {image_i+1}", - pil_to_vertex(image), - ] - ) - message.append("Action:") - return message - else: - raise ValueError( - f"Gemini models do not support mode {self.lm_config.mode}" - ) - else: - raise NotImplementedError( - f"Provider {self.lm_config.provider} not implemented" ) + content = [{"type": "text", "text": current_prompt}] + content + message.append({"role": "user", "content": content}) + return message diff --git a/agent/vision_agent.py b/agent/vision_agent.py new file mode 100644 index 0000000..c2b3bf4 --- /dev/null +++ b/agent/vision_agent.py @@ -0,0 +1,138 @@ +import os +import sys +parent_dir = os.path.dirname(os.path.abspath(__file__)) +up_dir = parent_dir +for i in range(3): + sys.path.append(up_dir) + up_dir = os.path.dirname(up_dir) +from kutils import DEBUG, INFO, WARN, ERROR +import utils as u +import json +import importlib +from typing import Any, Optional +from beartype import beartype +from PIL import Image + +from agent.prompts import * +from browser_env import Trajectory +from browser_env.actions import ( + Action, + ActionParsingError, + create_id_based_action, + create_none_action, + create_playwright_action, + create_vision_action, + create_mas_action, +) +from llms import ( + # call_llm, + # generate_from_huggingface_completion, + # generate_from_openai_chat_completion, + # generate_from_openai_completion, + lm_config, +) +from llms.tokenizers import Tokenizer +from prompts.prompt_constructor import MultimodalCoTPromptConstructor, MultimodalCoTPromptConstructor +from llms.ais_requestor import get_lm_requestor +from base_agent import Agent + +class VisionAgent(Agent): + """prompt-based agent that emits action given the history""" + + @beartype + def __init__( + self, + action_set_tag: str, + lm_config: lm_config.LMConfig, + prompt_constructor: PromptConstructor | MultimodalCoTPromptConstructor | MultimodalCoTPromptConstructor, + captioning_fn = None, + ) -> None: + super().__init__() + self.lm_config = lm_config + self.prompt_constructor = prompt_constructor + self.action_set_tag = action_set_tag + self.captioning_fn = captioning_fn + self.multimodal_inputs = True + self.kq = get_lm_requestor(self.lm_config.model) + + def set_action_set_tag(self, tag: str) -> None: + self.action_set_tag = tag + + def map_to_vwa_actions(self, action_info): + pred_action_history = action_info['pred_action_history'] + pred_action_description = action_info['pred_action_description'] + pred_action = action_info['pred_action'] + pred_action_type = action_info['pred_action_type'] + pred_bbox = action_info['pred_bbox'] + pred_type_value = action_info['pred_type_value'] + pred_click_point = action_info['pred_click_point'] + parse_error_msg = action_info['parse_error_msg'] + return pred_action + + def filter_bboxes(self, bboxes, w, h): + to_be_del = [] + for idx, box in bboxes.items(): + if box['x'] < 0 or box['x'] > w or box['y'] < 0 or box['y'] > h: + to_be_del.append(idx) + for idx in to_be_del: + del bboxes[idx] + return bboxes + + @beartype + def next_action( + self, + trajectory: Trajectory, + intent: str, + meta_data: dict[str, Any], + images: Optional[list[Image.Image]] = None, + output_response: bool = False + ) -> Action: + # Create page screenshot image for multimodal models. + if self.multimodal_inputs: + page_screenshot_arr = trajectory[-1]["observation"]["ori_image"] + page_screenshot_img = Image.fromarray(page_screenshot_arr) # size = (viewport_width, viewport_width) + + bboxes = meta_data['bbox'] + + n = 0 + while True: + # task, action history, memo, tabs, hint + prompt = prompt_f.format(intent, meta_data["action_history"], '', '', meta_data['hint']) + if images: + input_img = images[0] + response = self.kq.infer_with_input_img( + sys_prompt, + prompt, + input_img, + page_screenshot_img) + else: + response = self.kq.infer( + sys_prompt, + prompt, + page_screenshot_img) + + action_info = parse_response(response, 'bbox') + force_prefix = self.prompt_constructor.instruction["meta_data"].get("force_prefix", "") + response = f"{force_prefix}{response}" + if output_response: print(f'{response}', flush=True) + + n += 1 + try: + action = create_vision_action(action_info, bboxes) + action["raw_prediction"] = response + action['action_info'] = action_info + break + except ActionParsingError as e: + INFO(e) + if n >= self.lm_config.gen_config["max_retry"]: + action = create_none_action() + action["raw_prediction"] = response + break + + action['prompt'] = prompt + action['action_info'] = action_info + return action + + def reset(self, test_config_file: str) -> None: + pass + diff --git a/browser_env/actions.py b/browser_env/actions.py index 13aaf1b..69c7de6 100644 --- a/browser_env/actions.py +++ b/browser_env/actions.py @@ -6,6 +6,7 @@ import random import re import string +import time from enum import IntEnum from itertools import chain from typing import Any, TypedDict, Union, cast @@ -55,6 +56,20 @@ class ParsedPlaywrightCode(TypedDict): ) +def extract_text(text, from_text, to_text = None) -> list: + if to_text == None: + return [text[text.find(from_text):]] + + if from_text == None: + return [text[:text.find(to_text)]] + + pattern = re.escape(from_text) + r'(.*?)' + re.escape(to_text) + matches = re.findall(pattern, text, re.DOTALL) + if matches: + return [match.strip() for match in matches] + + return [] + @beartype def is_in_viewport( element: Locator, viewport: ViewportSize, threshold: float = 0.3 @@ -99,7 +114,7 @@ class Action(TypedDict): coords: npt.NDArray[np.float32] element_role: int element_name: str - text: list[int] + text: list page_number: int url: str nth: int @@ -154,6 +169,10 @@ def action2str( action_str = f"clear [{element_id}] where [{element_id}] is {semantic_element}" case ActionTypes.UPLOAD: action_str = f"upload [{action['text']}] to [{element_id}]" + case ActionTypes.WAIT: + action_str = f"wait" + case ActionTypes.DRAG: + action_str = f"drag [{element_id}] [{action['text'][0]}]" case ActionTypes.STOP: action_str = f"stop [{action['answer']}]" case ActionTypes.NONE: @@ -197,6 +216,10 @@ def action2str( action_str = f"stop [{action['answer']}]" case ActionTypes.UPLOAD: action_str = f"upload [{action['text']}] to [{element_id}]" + case ActionTypes.WAIT: + action_str = "wait" + case ActionTypes.DRAG: + action_str = f"drag [{element_id}] [{action['text'][0]}]" case ActionTypes.NONE: action_str = "none" case _: @@ -341,6 +364,8 @@ class ActionTypes(IntEnum): STOP = 17 CLEAR = 18 UPLOAD = 19 + WAIT = 20 + DRAG = 21 def __str__(self) -> str: return f"ACTION_TYPES.{self.name}" @@ -391,6 +416,10 @@ def is_equivalent(a: Action, b: Action) -> bool: return a["pw_code"] == b["pw_code"] case ActionTypes.STOP: return a["answer"] == b["answer"] + case ActionTypes.WAIT: + return True + case ActionTypes.DRAG: + return a['element_id'] == b['element_id'] and a['text'][0] == b['text'][0] case _: raise ValueError(f"Unknown action type: {a['action_type']}") @@ -519,6 +548,24 @@ def create_none_action() -> Action: def create_stop_action(answer: str) -> Action: action = create_none_action() action.update({"action_type": ActionTypes.STOP, "answer": answer}) + pred_action_history = [] # TODO + pred_action_description = '' + pred_action = '' + pred_action_type = 'UNKNOWN' + pred_bbox = [0, 0, 0, 0] + pred_type_value = '' + pred_click_point = [0, 0] + res = { + 'pred_action_history': pred_action_history, + 'pred_action_description': pred_action_description, + 'pred_action': pred_action, + 'pred_action_type': pred_action_type, + 'pred_bbox': pred_bbox, + 'pred_type_value': pred_type_value, + 'pred_click_point': pred_click_point, + 'parse_error_msg': '', + } + action['action_info'] = res return action @@ -635,6 +682,31 @@ def create_goto_url_action(url: str) -> Action: ) return action +@beartype +def create_wait_action() -> Action: + """Return a valid action object with type WAIT.""" + action = create_none_action() + action.update( + { + "action_type": ActionTypes.WAIT, + } + ) + return action + +@beartype +def create_drag_action(element_id: int, value: str, element_name) -> Action: + """Return a valid action object with type WAIT.""" + action = create_none_action() + action["action_type"] = ActionTypes.DRAG + action['element_id'] = str(element_id) + def to_int_if_possible(s: str): + s = s.strip() + if re.fullmatch(r'[+-]?\d+', s): + return int(s) # Python 的 int 无溢出问题 + return None # 不是整数 + action['text'].append(to_int_if_possible(value)) + action['element_name'] = element_name + return action @beartype def create_page_close_action() -> Action: @@ -1052,11 +1124,13 @@ async def aexecute_type(keys: list[int], page: APage) -> None: @beartype def execute_focus( element_role: int, element_name: str, nth: int, page: Page -) -> None: +) -> bool: """Click the specified DOM element.""" element_role_str = _id2role[element_role] if page.viewport_size is None: - raise ValueError("Viewport size is not set for the current page") + # raise ValueError("Viewport size is not set for the current page") + print("Viewport size is not set for the current page") + return False element_location_list: list[tuple[Locator, float, float]] = [] for frame in page.frames: match element_role_str: @@ -1079,11 +1153,16 @@ def execute_focus( (locator, bounding_box["x"], bounding_box["y"]) ) if len(element_location_list) <= nth: - raise ValueError( - f"There are only {len(element_location_list)} elements found in viewport, but {nth + 1} is requested" - ) + # raise ValueError( + # f"There are only {len(element_location_list)} elements found in viewport, but {nth + 1} is requested" + # ) + print(f'ExecuteFocus: There are only {len(element_location_list)} elements found in viewport, but {nth + 1} is requested') + return False + + if len(element_location_list) > 5: return False element_location_list.sort(key=lambda x: (x[2], x[1])) # row major order element_location_list[nth][0].focus() + return True @beartype @@ -1262,6 +1341,21 @@ async def aexecute_playwright_check( # perform the action await locator.check() +def execute_drag(start_x: float, start_y: float, end_x: float, end_y: float, page: Page): + page.mouse.move(start_x, start_y) + page.mouse.down() + # steps 越大越平滑,某些 UI 框架更稳定 + page.mouse.move(end_x, end_y, steps=20) + page.mouse.up() + +def set_range_value(locator_code: str, value: float, page: Page): + locator = eval(locator_code) + locator.scroll_into_view_if_needed() + locator.evaluate("""(el, val) => { + el.value = val; + el.dispatchEvent(new Event('input', { bubbles: true })); + el.dispatchEvent(new Event('change', { bubbles: true })); + }""", value) @beartype def execute_action( @@ -1272,7 +1366,19 @@ def execute_action( sleep_after_execution: float = 0.0, ) -> Page: """Execute the action on the ChromeDriver.""" - action_type = action["action_type"] + + action_type = action['action_type'] + element_id = action['element_id'] + element_role = action['element_role'] + element_name = action['element_name'] + pw_code = action['pw_code'] + nth = action['nth'] + text = action['text'] + print(100 * '-') + print(f'execute_action in actions.py: \naction_type = {action_type}\nelement_id = {element_id}\nelement_role = {element_role}\nelement_name = {element_name}\npw_code = {pw_code}\nnth = {nth}\ntext = {text}') + + action_method = 'normal' + num_tabs_before = len(browser_ctx.pages) match action_type: case ActionTypes.NONE: @@ -1300,22 +1406,41 @@ def execute_action( case ActionTypes.CLICK: # check each kind of locator in order # TODO[shuyanzh]: order is temp now - if action["element_id"]: - element_id = action["element_id"] - element_center = obseration_processor.get_element_center(element_id) # type: ignore[attr-defined] - execute_mouse_click(element_center[0], element_center[1], page) - elif action["element_role"] and action["element_name"]: + + flag = False + + if action['domain'] != 'reddit' and action['domain'] != 'shopping' and action["element_role"] and action["element_name"]: element_role = int(action["element_role"]) element_name = action["element_name"] nth = action["nth"] - execute_focus(element_role, element_name, nth, page) - execute_click_current(page) - elif action["pw_code"]: + try: + flag = execute_focus(element_role, element_name, nth, page) + except Exception as e: + flag = False + print(f'execute_action in actions.py: {e}') + if flag: + try: + execute_click_current(page) + action_method = 'role, name, focus, click' + except Exception as e: + flag = False + print(f'execute_action in actions.py: {e}') + + if action["element_id"] and not flag: + element_id = action["element_id"] + element_center = obseration_processor.get_element_center(element_id) # type: ignore[attr-defined] + execute_mouse_click(element_center[0], element_center[1], page) + flag = True + action_method = 'element id mouse click' + + if action["pw_code"] and not flag: parsed_code = parse_playwright_code(action["pw_code"]) locator_code = parsed_code[:-1] # [shuyanzh], don't support action args and kwargs now execute_playwright_click(locator_code=locator_code, page=page) - else: + flag = True + + if not flag: raise ValueError("No proper locator found for click action") case ActionTypes.HOVER: if action["element_id"]: @@ -1337,25 +1462,38 @@ def execute_action( "No proper locator found for hover action" ) case ActionTypes.TYPE: - if action["element_id"]: - element_id = action["element_id"] - element_center = obseration_processor.get_element_center(element_id) # type: ignore[attr-defined] - execute_mouse_click(element_center[0], element_center[1], page) - execute_type(action["text"], page) - elif action["element_role"] and action["element_name"]: + flag = False + + if action["element_role"] and action["element_name"]: element_role = int(action["element_role"]) element_name = action["element_name"] nth = action["nth"] - execute_focus(element_role, element_name, nth, page) + try: + flag = execute_focus(element_role, element_name, nth, page) + except Exception as e: + flag = False + print(f'execute_action in actions.py: {e}') + if flag: + try: + execute_type(action["text"], page) + action_method = 'role, name, focus, type' + except Exception as e: + flag = False + print(f'execute_action in actions.py: {e}') + + if action["element_id"] and not flag: + element_id = action["element_id"] + element_center = obseration_processor.get_element_center(element_id) # type: ignore[attr-defined] + execute_mouse_click(element_center[0], element_center[1], page) execute_type(action["text"], page) - elif action["pw_code"]: + action_method = 'element id mouse click type' + + elif action["pw_code"] and not flag: parsed_code = parse_playwright_code(action["pw_code"]) locator_code = parsed_code[:-1] text = parsed_code[-1]["arguments"][0] # [shuyanzh], don't support action args and kwargs now - execute_playwright_type( - text=text, locator_code=locator_code, page=page - ) + execute_playwright_type(text=text, locator_code=locator_code, page=page) else: raise NotImplementedError( "No proper locator found for type action" @@ -1401,6 +1539,18 @@ def execute_action( element_id = action["element_id"] element_center = obseration_processor.get_element_center(element_id) # type: ignore[attr-defined] execute_upload(element_center[0], element_center[1], action["text"], page) + case ActionTypes.WAIT: + time.sleep(1) + case ActionTypes.DRAG: + value = action['text'][0] + element_name = action["element_name"] + slider = page.locator(f'#{element_name}') + slider.evaluate( +f"""(el) => {{ + el.value = {value}; + el.dispatchEvent(new Event('input', {{ bubbles: true }})); + el.dispatchEvent(new Event('change', {{ bubbles: true }})); +}}""") case _: raise ValueError(f"Unknown action type: {action_type}") @@ -1411,6 +1561,8 @@ def execute_action( page = browser_ctx.pages[-1] page.bring_to_front() + print(f'execute_action in actions.py: action method = {action_method}') + print(100 * '-') return page @@ -1698,6 +1850,157 @@ def create_playwright_action(playwright_code: str) -> Action: raise ActionParsingError(f"Unknown playwright action {action}") +@beartype +def create_mas_action(action_str: str, obs) -> Action: + """Main function to return individual id based action""" + obs_list = obs.split('\n') + obs_dict = {} + for content in obs_list: + try: + idx = int(content[1:content.find(']')]) + matches = re.findall(r'\[(.*?)\]', content) + if len(matches) == 3: + content = matches[2] + obs_dict[idx] = content + except Exception as e: + print(f'create_mas_action in actions.py: {e}') + + action_str = action_str.strip() + if "[" in action_str: + action = action_str.split("[")[0].strip() + else: + actions = action_str.split() + if actions: + action = actions[0].strip() + else: + raise ActionParsingError(f"No action specified: {action_str}") + + match action: + case "click": + match = re.search(r"click ?\[(\d+)\]", action_str) + if not match: + print(f'create_mas_action in actions.py: invalid click action {action_str}') + return create_none_action() + element_id = match.group(1) + if int(element_id) > len(obs_list): + last_element_id = obs_list[-1] + last_element_id = last_element_id[last_element_id.find('[')+1:last_element_id.find(']')] + last_element_id = int(last_element_id) + if int(element_id) > last_element_id: + return create_none_action() + + element_name = obs_dict[int(element_id)] + element_name = extract_text(element_name, 'element_name = ', ']')[0] + element_name = f'#{element_name}' + return create_click_action(element_id=element_id, element_name=element_name) + case "clear": + match = re.search(r"clear ?\[(\d+)\]", action_str) + if not match: + raise ActionParsingError(f"Invalid clear action {action_str}") + element_id = match.group(1) + return create_clear_action(element_id=element_id) + case "upload": + # add default enter flag + if not (action_str.endswith("[0]") or action_str.endswith("[1]")): + action_str += " [1]" + + match = re.search( + r"type ?\[(\d+)\] ?\[(.+)\] ?\[(\d+)\]", action_str + ) + if not match: + raise ActionParsingError(f"Invalid type action {action_str}") + element_id, text, enter_flag = ( + match.group(1), + match.group(2), + match.group(3), + ) + if enter_flag == "1": + text += "\n" + return create_upload_action(text=text, element_id=element_id) + case "hover": + match = re.search(r"hover ?\[(\d+)\]", action_str) + if not match: + raise ActionParsingError(f"Invalid hover action {action_str}") + element_id = match.group(1) + return create_hover_action(element_id=element_id) + case "type": + # add default enter flag + if not (action_str.endswith("[0]") or action_str.endswith("[1]")): + action_str += " [1]" + + match = re.search( + r"type ?\[(\d+)\] ?\[(.+)\] ?\[(\d+)\]", action_str + ) + if not match: + raise ActionParsingError(f"Invalid type action {action_str}") + element_id, text, enter_flag = ( + match.group(1), + match.group(2), + match.group(3), + ) + if enter_flag == "1": + text += "\n" + return create_type_action(text=text, element_id=element_id) + case "press": + match = re.search(r"press ?\[(.+)\]", action_str) + if not match: + raise ActionParsingError(f"Invalid press action {action_str}") + key_comb = match.group(1) + return create_key_press_action(key_comb=key_comb) + case "scroll": + # up or down + match = re.search(r"scroll ?\[?(up|down)\]?", action_str) + if not match: + raise ActionParsingError(f"Invalid scroll action {action_str}") + direction = match.group(1) + return create_scroll_action(direction=direction) + case "goto": + match = re.search(r"goto ?\[(.+)\]", action_str) + if not match: + raise ActionParsingError(f"Invalid goto action {action_str}") + url = match.group(1) + return create_goto_url_action(url=url) + case "new_tab": + return create_new_tab_action() + case "go_back": + return create_go_back_action() + case "go_forward": + return create_go_forward_action() + case "tab_focus": + match = re.search(r"tab_focus ?\[(\d+)\]", action_str) + if not match: + raise ActionParsingError( + f"Invalid tab_focus action {action_str}" + ) + page_number = int(match.group(1)) + return create_page_focus_action(page_number) + case "close_tab": + return create_page_close_action() + case "wait": + return create_wait_action() + case "drag": + match = re.search(r"drag ?\[(\d+)\] ?\[(\d+)\]", action_str) + if not match: + print(f'create_mas_action in actions.py: invalid drag action {action_str}') + return create_none_action() + element_id, text = ( + int(match.group(1)), + match.group(2), + ) + if element_id not in obs_dict.keys(): + return create_none_action() + element_name = obs_dict[element_id] + element_name = extract_text(element_name, 'element_name = ', ']')[0] + return create_drag_action(element_id, text, element_name) + case "stop": # stop answer + match = re.search(r"stop ?\[(.+)\]", action_str) + if not match: # some tasks don't require an answer + answer = "" + else: + answer = match.group(1) + return create_stop_action(answer) + + raise ActionParsingError(f"Invalid action {action_str}") @beartype def create_id_based_action(action_str: str) -> Action: @@ -1810,3 +2113,74 @@ def create_id_based_action(action_str: str) -> Action: return create_stop_action(answer) raise ActionParsingError(f"Invalid action {action_str}") + +def get_element_id(click_point, bboxes, th = 5): + for box_idx, box in bboxes.items(): + x, y = click_point + if isinstance(box, dict): + x1 = box['left'] + y1 = box['top'] + x2 = box['right'] + y2 = box['bottom'] + width = box['width'] + height = box['height'] + left, right = min(x1, x2), max(x1, x2) + top, bottom = min(y1, y2), max(y1, y2) + # print(box_idx, x, y, left, right, top, bottom, width, height) + if left <= x <= right and top <= y <= bottom: + return str(box_idx) + for box_idx, box in bboxes.items(): + dx = max(left - x, 0, x - right) + dy = max(top - y, 0, y - bottom) + distance = (dx ** 2 + dy ** 2) ** 0.5 + if distance < th: return str(box_idx) + else: + x1 = box[0] + y1 = box[1] + x2 = box[0] + box[2] + y2 = box[1] + box[3] + left, right = min(x1, x2), max(x1, x2) + top, bottom = min(y1, y2), max(y1, y2) + if left <= x <= right and top <= y <= bottom: + return str(int(box_idx) + 1) + for box_idx, box in bboxes.items(): + dx = max(left - x, 0, x - right) + dy = max(top - y, 0, y - bottom) + distance = (dx ** 2 + dy ** 2) ** 0.5 + if distance < th: return str(int(box_idx) + 1) + + return "-1" + +@beartype +def create_vision_action(action_info: dict, bboxes: dict) -> Action: + action_str = action_info['pred_action'] + action_type = action_info['pred_action_type'] + type_value = action_info['pred_type_value'] + click_point = action_info['pred_click_point'] + + match action_type: + case "CLICK": + element_id = get_element_id(click_point, bboxes) + if element_id == '-1': return create_none_action() + return create_click_action(element_id=element_id) + case "SELECT": + element_id = get_element_id(click_point, bboxes) + if element_id == '-1': return create_none_action() + return create_click_action(element_id=element_id) + case "TYPE": + element_id = get_element_id(click_point, bboxes) + if element_id == '-1': return create_none_action() + type_value += "\n" + return create_type_action(text=type_value, element_id=element_id) + case "SCROLL": + return create_scroll_action(direction=type_value) + case "GOTO": + return create_goto_url_action(url=type_value) + case "BACK": + return create_go_back_action() + case "FINISH": # stop answer + return create_stop_action(type_value) + case "WAIT": + return create_none_action() + + raise ActionParsingError(f"Invalid action {action_str}") \ No newline at end of file diff --git a/browser_env/auto_login.py b/browser_env/auto_login.py index 67a22c9..3828b26 100644 --- a/browser_env/auto_login.py +++ b/browser_env/auto_login.py @@ -1,7 +1,10 @@ """Script to automatically login each website""" +import os +import sys +parent_dir = os.path.dirname(os.path.abspath(__file__)) +if parent_dir not in sys.path: sys.path.insert(0, parent_dir) import argparse import glob -import os import time from concurrent.futures import ThreadPoolExecutor from itertools import combinations diff --git a/browser_env/env_config.py b/browser_env/env_config.py index 3ec81a1..bda3ab4 100644 --- a/browser_env/env_config.py +++ b/browser_env/env_config.py @@ -82,16 +82,16 @@ ACCOUNTS = { - "reddit": {"username": "MarvelsGrantMan136", "password": "test1234"}, + "reddit": {"username": "MarvelsGrantMan136", "password": os.getenv('REDDIT_PSW', '')}, "shopping": { "username": "emma.lopez@gmail.com", - "password": "Password.123", + "password": os.getenv('SHOPPING_PSW', ''), }, "classifieds": { "username": "blake.sullivan@gmail.com", - "password": "Password.123", + "password": os.getenv('CLASSIFIEDS_PSW', ''), }, - "shopping_site_admin": {"username": "admin", "password": "admin1234"}, - "shopping_admin": {"username": "admin", "password": "admin1234"}, - "gitlab": {"username": "byteblaze", "password": "hello1234"}, + "shopping_site_admin": {"username": "admin", "password": os.getenv('SHOPPING_SITE_ADMIN_PSW', '')}, + "shopping_admin": {"username": "admin", "password": os.getenv('SHOPPING_ADMIN_PSW', '')}, + "gitlab": {"username": "byteblaze", "password": os.getenv('GITLAB_PSW', '')}, } \ No newline at end of file diff --git a/browser_env/envs.py b/browser_env/envs.py index ef326bb..ff470bb 100644 --- a/browser_env/envs.py +++ b/browser_env/envs.py @@ -7,7 +7,9 @@ from dataclasses import dataclass from pathlib import Path from typing import Any, Union +from io import BytesIO, StringIO +import pandas as pd import numpy as np import numpy.typing as npt import requests @@ -202,7 +204,7 @@ def setup(self, config_file: Path | None = None) -> None: client = page.context.new_cdp_session(page) client.send("Accessibility.enable") client.detach() - page.goto(url) + page.goto(url, timeout=120000) # set the first page as the current page self.page = self.context.pages[0] self.page.bring_to_front() @@ -218,12 +220,95 @@ def setup(self, config_file: Path | None = None) -> None: def _get_obs(self) -> dict[str, Observation]: obs = self.observation_handler.get_observation(self.page) + tabs = self.context.pages + tabs_dict = {} + for i, tab in enumerate(tabs): + tabs_dict[f'Tab {i}: {tab.title()}'] = tab.url + obs['tabs'] = tabs_dict return obs def _get_obs_metadata(self) -> dict[str, ObservationMetadata]: metadata = self.observation_handler.get_observation_metadata() return metadata + def get_som_img(self): + return self.observation_handler.action_processor.get_som_img(self.page) + + def get_bboxes(self) -> dict: + browser_info = self.observation_handler.action_processor.fetch_browser_info(self.page) + browser_config = browser_info["config"] + data_string = self.observation_handler.action_processor.get_page_bboxes(self.page) + df = pd.read_csv(StringIO(data_string), delimiter=",", quotechar='"') + df["Area"] = df["Width"] * df["Height"] + # Remove bounding boxes that are clipped. + b_x, b_y = (browser_config["win_left_bound"], browser_config["win_upper_bound"]) + df = df[ + (df["Bottom"] - b_y >= 0) + & (df["Top"] - b_y <= self.viewport_size["height"]) + & (df["Right"] - b_x >= 0) + & (df["Left"] - b_x <= self.viewport_size["width"]) + ] + viewport_area = self.viewport_size["width"] * self.viewport_size["height"] + # Filter out bounding boxes that too large (more than 80% of the viewport) + df = df[df["Area"] <= 0.8 * viewport_area] + + bboxes = {} + index = 0 + for _, row in df.iterrows(): + if not row["Interactable"]: + content = "" + # Add image alt-text to the text representation. + if row["Element"] == "IMG" and pd.notna(row["Alt"]): + content += row["Alt"] + # Add HTML textContent (if any) to the text representation. + if pd.notna(row["TextContent"]): + content += ( + row["TextContent"].strip().replace("\n", "").replace("\t", "") + )[ + :200 + ] # Limit to 200 characters to avoid having too much text + continue + + unique_id = str(index + 1) + top, right, bottom, left, width, height = ( + row["Top"], + row["Right"], + row["Bottom"], + row["Left"], + row["Width"], + row["Height"], + ) + left, right, top, bottom = left - b_x, right - b_x, top - b_y, bottom - b_y + + ori_idx = row['ID'] + element = row['Element'] + alt = row['Alt'] + cls = row['Class'] + idx2 = row['Id'] + text_content = row['TextContent'] + interactable = row['Interactable'] + area = row['Area'] + + bboxes[unique_id] = { + 'idx': ori_idx, + 'element': element, + 'top': top, + 'right': right, + 'bottom': bottom, + 'left': left, + 'width': width, + 'height': height, + 'alt': alt, + 'cls': cls, + 'idx2': idx2, + 'text_content': text_content, + 'interactable': interactable, + 'area': area, + } + + index += 1 + return bboxes + @beartype def reset( self, diff --git a/browser_env/helper_functions.py b/browser_env/helper_functions.py index 54dce12..99f4978 100644 --- a/browser_env/helper_functions.py +++ b/browser_env/helper_functions.py @@ -34,6 +34,131 @@ """ +import re +from PIL import Image, ImageDraw, ImageFont + +def add_text_top( + img, + text, + font_path=None, + font_size=36, + text_color=(0, 0, 0), + background_color="white", # None 表示在 RGBA 模式下使用透明背景 + padding=(24, 24, 24, 24), # (left, top, right, bottom) + align="center", # "left" | "center" | "right" + line_spacing=1.3 # 行距系数 +): + # 统一模式,便于透明和非透明处理 + base_mode = "RGBA" if img.mode in ("RGBA", "LA") else "RGB" + img_conv = img.convert(base_mode) + W = img_conv.width + + # 字体 + if font_path: + font = ImageFont.truetype(font_path, font_size) + else: + # 默认字体不一定支持中文,建议传入中文字体路径,如 NotoSansSC、SourceHanSans、微软雅黑等 + font = ImageFont.load_default() + + # 用于测量文字宽度 + draw_tmp = ImageDraw.Draw(img_conv) + def measure(s): + try: + return draw_tmp.textlength(s, font=font) + except Exception: + bbox = font.getbbox(s) + return bbox[2] - bbox[0] + + # 最大文本宽度 + left, top, right, bottom = padding + max_text_width = W - left - right + if max_text_width <= 0: + raise ValueError("Padding 太大,已超过图片宽度。") + + # 自动换行:兼顾中英文,优先按词(空白分隔)换行,不可行时按字符切分 + def wrap(text): + lines = [] + for para in text.split('\n'): + if para == '': + lines.append('') + continue + tokens = re.findall(r'\S+|\s+', para) # 保留空白作为可能的断点 + current = '' + for tok in tokens: + test = current + tok + if measure(test) <= max_text_width: + current = test + else: + if current.strip() != '' or current != '': + lines.append(current.rstrip()) + current = tok + # 若单个 token 仍超宽,按字符拆分 + if measure(current) > max_text_width: + buf = '' + for ch in current: + t2 = buf + ch + if measure(t2) <= max_text_width: + buf = t2 + else: + lines.append(buf.rstrip()) + buf = ch + current = buf + else: + # 极端情况:一开始就超宽,按字符拆分 + buf = '' + for ch in tok: + t2 = buf + ch + if measure(t2) <= max_text_width: + buf = t2 + else: + lines.append(buf.rstrip()) + buf = ch + current = buf + if current != '': + lines.append(current.rstrip()) + return lines + + lines = wrap(text) + + # 行高计算 + try: + ascent, descent = font.getmetrics() + base_line_height = ascent + descent + except Exception: + b = font.getbbox("Ay") + base_line_height = b[3] - b[1] + line_h = int(base_line_height * line_spacing) + + # 计算需要的顶部空白高度 + text_block_height = top + bottom + (line_h * len(lines) if lines else 0) + + # 背景色处理 + if background_color is None: + bg = (255, 255, 255, 0) if base_mode == "RGBA" else (255, 255, 255) + else: + bg = background_color + + # 生成新图(上方空白 + 原图) + new_img = Image.new(base_mode, (W, img_conv.height + text_block_height), bg) + new_img.paste(img_conv, (0, text_block_height)) + + # 绘制文字 + draw = ImageDraw.Draw(new_img) + y = top + for line in lines: + lw = measure(line) + if align == "center": + x = (W - lw) // 2 + elif align == "right": + x = W - right - lw + else: + x = left + draw.text((x, y), line, font=font, fill=text_color) + y += line_h + + return new_img + + def get_render_action( action: Action, observation_metadata: dict[str, ObservationMetadata], @@ -94,6 +219,8 @@ def get_action_description( ActionTypes.TYPE, ]: action_name = str(action["action_type"]).split(".")[1].lower() + print(f'get_action_description in helper_functions.py: {action["element_id"]}') + print(f'get_action_description in helper_functions.py: {text_meta_data["obs_nodes_info"]}') if action["element_id"] in text_meta_data["obs_nodes_info"]: node_content = text_meta_data["obs_nodes_info"][ action["element_id"] @@ -158,26 +285,137 @@ class RenderHelper(object): """Helper class to render text and image observations and meta data in the trajectory""" def __init__( - self, config_file: str, result_dir: str, action_set_tag: str + self, + _config, + result_file: str, + action_set_tag: str, + input_images = [], ) -> None: - with open(config_file, "r") as f: - _config = json.load(f) - _config_str = "" - for k, v in _config.items(): - _config_str += f"{k}: {v}\n" - _config_str = f"
{_config_str}\n"
- task_id = _config["task_id"]
+ _config_str = ""
+ for k, v in _config.items():
+ _config_str += f"{k}: {v}\n"
+ _config_str = f"{_config_str}\n"
+ if input_images:
+ for input_image in input_images:
+ byte_io = io.BytesIO()
+ input_image.save(byte_io, format="PNG")
+ byte_io.seek(0)
+ image_bytes = base64.b64encode(byte_io.read())
+ image_str = image_bytes.decode("utf-8")
+ _config_str += f"{text_obs}{text_obs}{text_obs}
-
+
A classifieds website for people to sell and buy things
@@ -82,7 +82,7 @@
-
+
An online shopping site
@@ -90,7 +90,7 @@
-
+
A social news aggregation and discussion website
@@ -114,7 +114,7 @@
-
+
An online encyclopedia
diff --git a/evaluation_harness/__init__.py b/evaluation_harness/__init__.py deleted file mode 100644 index fd0b27d..0000000 --- a/evaluation_harness/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -from .evaluators import * -from .helper_functions import ( - get_query_text, - get_query_text_lowercase, - reddit_get_latest_comment_content_by_username, - reddit_get_latest_comment_obj_by_username, - reddit_get_parent_comment_username_of_latest_comment_by_username, - shopping_get_latest_order_url, - shopping_get_num_reviews, - shopping_get_order_product_name_list, - shopping_get_order_product_option, - shopping_get_order_product_quantity, - shopping_get_product_attributes, - shopping_get_product_price, - shopping_get_rating_as_percentage, - shopping_get_sku_latest_review_author, - shopping_get_sku_latest_review_rating, - shopping_get_sku_latest_review_text, -) diff --git a/evaluation_harness/evaluators.py b/evaluation_harness/evaluators.py index 03224fd..e972250 100644 --- a/evaluation_harness/evaluators.py +++ b/evaluation_harness/evaluators.py @@ -18,7 +18,7 @@ from playwright.sync_api import CDPSession, Page from browser_env.actions import Action -from browser_env.utils import StateInfo +from browser_env.utils import StateInfo, DetachedPage from evaluation_harness import image_utils from evaluation_harness.helper_functions import ( PseudoPage, @@ -45,6 +45,7 @@ ) Trajectory = list[Union[Action, StateInfo]] +from llms.ais_requestor import AISRequestor @beartype @@ -56,7 +57,7 @@ def __call__( self, trajectory: Trajectory, config_file: Path | str, - page: Page | PseudoPage + page: Page | PseudoPage | None ) -> float: raise NotImplementedError @@ -143,6 +144,8 @@ class StringEvaluator(Evaluator): must include: each phrase in the reference answer must be included in the answer fuzzy match: the answer is similar to the reference answer, using LLM judge """ + def __init__(self, kq_config): + self.kq = AISRequestor(kq_config) @staticmethod @beartype @@ -172,6 +175,7 @@ def must_include(ref: str, pred: str) -> float: # prevent false positive (e.g, 0) if len(word_tokenize(clean_ref)) == 1: tok_pred = word_tokenize(clean_pred) + tok_pred = ''.join(tok_pred) return float(clean_ref in tok_pred) else: return float(clean_ref in clean_pred) @@ -192,19 +196,19 @@ def must_exclude(ref: str, pred: str) -> float: @staticmethod @beartype - def fuzzy_match(ref: str, pred: str, intent: str) -> float: - return llm_fuzzy_match(pred, ref, intent) + def fuzzy_match(ref: str, pred: str, intent: str, llm) -> float: + return llm_fuzzy_match(pred, ref, intent, llm) @staticmethod @beartype - def ua_match(ref: str, pred: str, intent: str) -> float: - return llm_ua_match(pred, ref, intent) + def ua_match(ref: str, pred: str, intent: str, llm) -> float: + return llm_ua_match(pred, ref, intent, llm) def __call__( self, trajectory: Trajectory, config_file: Path | str, - page: Page | PseudoPage | None = None + page: Page | PseudoPage | DetachedPage | None = None ) -> float: with open(config_file, "r") as f: configs = json.load(f) @@ -238,7 +242,8 @@ def __call__( assert isinstance(value, list) for must_value in value: value_or = must_value.split(" |OR| ") - score *= any([self.must_include(ref=v, pred=pred) for v in value_or]) + lst = [self.must_include(ref=v, pred=pred) for v in value_or] + score *= any(lst) case "must_exclude": assert isinstance(value, list) for must_excl_value in value: @@ -267,12 +272,14 @@ def __call__( intent=configs["intent"], ref=configs["eval"]["string_note"], pred=pred, + llm = self.kq.infer_messages ) else: assert isinstance(value, list) for reference in value: score *= self.fuzzy_match( - ref=reference, pred=pred, intent=intent + ref=reference, pred=pred, intent=intent, + llm=self.kq.infer_messages ) return score @@ -307,7 +314,7 @@ def __call__( self, trajectory: Trajectory, config_file: Path | str, - page: Page | PseudoPage + page: Page | PseudoPage | DetachedPage ) -> float: with open(config_file, "r") as f: configs = json.load(f) @@ -342,11 +349,14 @@ def clean_url(url: str) -> str: class HTMLContentExactEvaluator(Evaluator): """Check whether the contents appear in the page""" + def __init__(self, kq_config): + self.kq = AISRequestor(kq_config) + def __call__( self, trajectory: Trajectory, config_file: Path | str, - page: Page | PseudoPage + page: Page | PseudoPage | None ) -> float: with open(config_file, "r") as f: configs = json.load(f) @@ -365,6 +375,7 @@ def __call__( # navigate to that url if target_url != "last": + print(target_url) page.goto(target_url) time.sleep(3) # TODO [shuyanzh]: fix this hard-coded sleep @@ -469,6 +480,7 @@ def __call__( ref=target, pred=selected_element, intent="NOT USED", + llm = self.kq.infer_messages ) ] ) @@ -612,7 +624,7 @@ def __call__( self, trajectory: Trajectory, config_file: Path | str, - page: Page | PseudoPage + page: Page | PseudoPage | None ) -> float: score = 1.0 @@ -624,23 +636,16 @@ def __call__( @beartype -def evaluator_router( - config_file: Path | str, captioning_fn=None -) -> EvaluatorComb: - """Router to get the evaluator class""" - with open(config_file, "r") as f: - configs = json.load(f) - - eval_types = configs["eval"]["eval_types"] - evaluators: list[Evaluator | EvaluatorPartial] = [] +def evaluator_router(eval_types, kq_config, captioning_fn=None) -> EvaluatorComb: + evaluators: list[Evaluator] = [] for eval_type in eval_types: match eval_type: case "string_match": - evaluators.append(StringEvaluator()) + evaluators.append(StringEvaluator(kq_config)) case "url_match": evaluators.append(URLExactEvaluator()) case "program_html": - evaluators.append(HTMLContentExactEvaluator()) + evaluators.append(HTMLContentExactEvaluator(kq_config)) case "page_image_query": evaluators.append(PageImageEvaluator(captioning_fn)) case _: diff --git a/evaluation_harness/helper_functions.py b/evaluation_harness/helper_functions.py index c934c16..4b3737f 100644 --- a/evaluation_harness/helper_functions.py +++ b/evaluation_harness/helper_functions.py @@ -15,10 +15,6 @@ SHOPPING, WIKIPEDIA, ) -from llms.providers.openai_utils import ( - generate_from_openai_chat_completion, -) - class PseudoPage: def __init__(self, original_page: Page, url: str): @@ -424,28 +420,27 @@ def reddit_get_post_url(url: str) -> str: @beartype def reddit_get_post_comment_tree(page: Page | PseudoPage) -> Dict[str, Any]: try: - comment_tree = page.evaluate( - f"""(function buildCommentTree(node, data_level) {{ - let tree = {{ - "username": node.querySelector(".fg-inherit").outerText, - "net_score": parseInt(node.querySelector(".vote__net-score").outerText), - "content": node.querySelector(".comment__content").outerText, - "time": new Date(node.querySelector('.comment__main > header > h1 > span > time').dateTime), - "children": [] - }}; - node.querySelectorAll(".comment").forEach((child) => {{ - if (parseInt(child.getAttribute('data-level')) === data_level+1) {{ - tree['children'].push(buildCommentTree(child, data_level+1)); - }} - }}) - + js_code='''(() => Array.from(document.querySelectorAll('#main > .comment[data-level="1"]')).map(n => (function buildCommentTree(node, data_level) { + let tree = { + username: node.querySelector(".fg-inherit").outerText, + net_score: parseInt(node.querySelector(".vote__net-score").outerText), + content: node.querySelector(".comment__content").outerText, + time: new Date(node.querySelector('.comment__main > header > h1 > span > time').dateTime), + children: [] + }; + node.querySelectorAll('.comment').forEach(child => { + if (parseInt(child.getAttribute('data-level')) === data_level+1) { + tree.children.push(buildCommentTree(child, data_level + 1)); + } + }); return tree; -}})(document.querySelector("#main"), 0)""" - ) - except Exception: + })(n, 1)))()''' + comment_tree = page.evaluate(js_code) + except Exception as e: + print(e) comment_tree = {} - return comment_tree + return {"comment_trees":comment_tree} @beartype @@ -453,8 +448,8 @@ def reddit_get_latest_comment_obj_by_username( page: Page | PseudoPage, username: str ) -> Dict[str, Any]: try: - comment_tree = reddit_get_post_comment_tree(page) - latest_time = datetime.min.replace(tzinfo=timezone.utc) + comment_trees = reddit_get_post_comment_tree(page).get("comment_trees") + latest_time = datetime.min comment = {} def dfs(node): @@ -472,10 +467,11 @@ def dfs(node): for child in node["children"]: dfs(child) - - dfs(comment_tree) + for comment_tree in comment_trees: + dfs(comment_tree) except Exception as e: + print(e) comment = {} return comment @@ -499,7 +495,7 @@ def reddit_get_parent_comment_obj_of_latest_comment_by_username( page: Page | PseudoPage, username: str ) -> Dict[str, Any]: try: - comment_tree = reddit_get_post_comment_tree(page) + comment_trees = reddit_get_post_comment_tree(page).get("comment_trees") latest_time = datetime.min.replace(tzinfo=timezone.utc) comment = {} @@ -519,9 +515,11 @@ def dfs(node): else: dfs(child) - dfs(comment_tree) + for comment_tree in comment_trees: + dfs(comment_tree) - except Exception: + except Exception as e: + print(e) comment = {} return comment @@ -577,7 +575,7 @@ def gitlab_get_project_memeber_role( @beartype -def llm_fuzzy_match(pred: str, reference: str, question: str) -> float: +def llm_fuzzy_match(pred: str, reference: str, question: str, llm) -> float: """Check whether the prediction matches the reference with GPT-4-turbo""" messages: list[dict[str, Any]] = [] # construct the question to ask @@ -592,22 +590,26 @@ def llm_fuzzy_match(pred: str, reference: str, question: str) -> float: {"role": "user", "content": message}, ] - response = generate_from_openai_chat_completion( - model="gpt-4-1106-preview", - messages=messages, - temperature=0, - max_tokens=768, - top_p=1.0, - context_length=0, - ).lower() + # response = generate_from_openai_chat_completion( + # model="gpt-4-1106-preview", + # messages=messages, + # temperature=0, + # max_tokens=768, + # top_p=1.0, + # context_length=0, + # ).lower() + + response = llm(messages).lower() + if "partially correct" in response or "incorrect" in response: return 0.0 - else: - assert "correct" in response, response + elif "correct" in response: return 1.0 + else: + return 0.0 -def llm_ua_match(pred: str, reference: str, question: str) -> float: +def llm_ua_match(pred: str, reference: str, question: str, llm) -> float: """Check whether the prediction matches the reference with GPT-4-turbo""" messages: list[dict[str, Any]] = [] # construct the question to ask @@ -627,14 +629,17 @@ def llm_ua_match(pred: str, reference: str, question: str) -> float: {"role": "user", "content": message}, ] - response = generate_from_openai_chat_completion( - model="gpt-4-1106-preview", - messages=messages, - temperature=0, - max_tokens=768, - top_p=1.0, - context_length=0, - ).lower() + # response = generate_from_openai_chat_completion( + # model="gpt-4-1106-preview", + # messages=messages, + # temperature=0, + # max_tokens=768, + # top_p=1.0, + # context_length=0, + # ).lower() + + response = llm(messages).lower() + if "different" in response: return 0.0 else: diff --git a/evaluation_harness/image_utils.py b/evaluation_harness/image_utils.py index da6782e..383e855 100644 --- a/evaluation_harness/image_utils.py +++ b/evaluation_harness/image_utils.py @@ -3,58 +3,36 @@ import numpy as np from PIL import Image from skimage.metrics import structural_similarity as ssim -from transformers import ( - Blip2ForConditionalGeneration, - Blip2Processor, -) +# from transformers import ( +# Blip2ForConditionalGeneration, +# Blip2Processor, +# ) +from llms.ais_requestor import get_lm_requestor -def get_captioning_fn( - device, dtype, model_name: str = "Salesforce/blip2-flan-t5-xl" -) -> callable: - if "blip2" in model_name: - captioning_processor = Blip2Processor.from_pretrained(model_name) - captioning_model = Blip2ForConditionalGeneration.from_pretrained( - model_name, torch_dtype=dtype - ) - else: - raise NotImplementedError( - "Only BLIP-2 models are currently supported" - ) - captioning_model.to(device) +def get_captioning_fn(caption_model) -> callable: + if caption_model == "": + return None + kq = get_lm_requestor(caption_model) def caption_images( images: List[Image.Image], prompt: List[str] = None, max_new_tokens: int = 32, ) -> List[str]: - if prompt is None: - # Perform VQA - inputs = captioning_processor( - images=images, return_tensors="pt" - ).to(device, dtype) - generated_ids = captioning_model.generate( - **inputs, max_new_tokens=max_new_tokens - ) - captions = captioning_processor.batch_decode( - generated_ids, skip_special_tokens=True - ) - else: - # Regular captioning. Prompt is a list of strings, one for each image - assert len(images) == len( - prompt - ), "Number of images and prompts must match, got {} and {}".format( - len(images), len(prompt) - ) - inputs = captioning_processor( - images=images, text=prompt, return_tensors="pt" - ).to(device, dtype) - generated_ids = captioning_model.generate( - **inputs, max_new_tokens=max_new_tokens - ) - captions = captioning_processor.batch_decode( - generated_ids, skip_special_tokens=True - ) + + prompt = len(images) * [''] + + assert len(images) == len(prompt), "Number of images and prompts must match, got {} and {}".format(len(images), len(prompt)) + + captions = [] + for question, img in zip(prompt, images): + try: + caption = kq.infer('', question, img) + except Exception as e: + print("Error during captioning:", e) + caption = "a picture" + captions.append(caption) return captions diff --git a/llms/ais_requestor.py b/llms/ais_requestor.py new file mode 100644 index 0000000..0bd7e07 --- /dev/null +++ b/llms/ais_requestor.py @@ -0,0 +1,182 @@ +import os +import sys +parent_dir = os.path.dirname(os.path.abspath(__file__)) +up_dir = parent_dir +for i in range(3): + sys.path.append(up_dir) + up_dir = os.path.dirname(up_dir) +from kutils import DEBUG, INFO, WARN, ERROR +from tqdm import tqdm +from PIL import Image +import pandas as pd +import json +import io +import requests +import base64 +from io import BytesIO +from openai import OpenAI + +def get_git(): + current_file = os.path.abspath(__file__) + current_folder = os.path.dirname(current_file) + return current_folder + '/' + +def pil_image_to_base64(img, format='png'): + import PIL, PIL.JpegImagePlugin + if isinstance(img, PIL.JpegImagePlugin.JpegImageFile): format = 'jpg' + if format == 'jpg': format = 'JPEG' + elif format == 'png': format = 'PNG' + output_buffer = BytesIO() + img.save(output_buffer, format=format) + byte_data = output_buffer.getvalue() + base64_str = base64.b64encode(byte_data).decode('utf-8') + return base64_str + +class AISRequestor(): + def __init__(self, config): + if 'models' in config.keys(): + self.model = config['model'] + else: + self.model = config['model'] + self.temperature = config['temperature'] + self.max_tokens = config['max_tokens'] + self.client = OpenAI( + api_key = config['api_key'], + base_url = config['base_url'], + timeout = 600000, + ) + + def infer(self, system, prompt, pil_img): + messages = \ + [ + { + 'role': 'system', + 'content': system + }, + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': prompt + }, + ] + } + ] + + if pil_img: + image_msg = { + 'type': 'image_url', + 'image_url': { + 'url': 'data:image/png;base64,' + } + } + messages[1]['content'].append(image_msg) + image_base64 = pil_image_to_base64(pil_img) + messages[1]['content'][1]['image_url']['url'] += image_base64 + + return self.infer_messages(messages) + + def infer_with_input_img(self, system, prompt, input_img, screenshot): + messages = \ + [ + { + 'role': 'system', + 'content': system + }, + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': prompt + }, + { + "type": "text", + "text": "IMAGES: (0) user input image" + }, + { + 'type': 'image_url', + 'image_url': { + 'url': 'data:image/png;base64,' + } + }, + { + "type": "text", + "text": "(1) current page screenshot" + }, + { + 'type': 'image_url', + 'image_url': { + 'url': 'data:image/png;base64,' + } + } + ] + } + ] + input_image_base64 = pil_image_to_base64(input_img) + messages[1]['content'][2]['image_url']['url'] += input_image_base64 + ss_image_base64 = pil_image_to_base64(screenshot) + messages[1]['content'][4]['image_url']['url'] += ss_image_base64 + return self.infer_messages(messages) + + def infer_messages(self, messages) -> str: + response = self.client.chat.completions.create( + model = self.model, + messages = messages, + temperature = self.temperature, + max_tokens = self.max_tokens, + ) + response = response.choices[0].message.content + return response + +def check_models(models): + kq_config = { + 'api_key': os.getenv("AGI_API_KEY"), + 'model': '', + 'base_url': "https://agi-pre.alipay.com/api", + 'temperature': 0.0, + 'max_tokens': 4096, + } + for model in models: + kq_config['model'] = model + kq = AISRequestor(kq_config) + prompt = "" + image_file = f'{get_git()}/asset/unireco_bird_example.jpg' + img = Image.open(image_file) + try: + response = kq.infer('', prompt, img) + if isinstance(response, str): INFO(f'model {model} good') + except Exception as e: + DEBUG(f'model {model} {e}') + return False + return True + +def get_lm_requestor(model): + kq_config = {} + if 'gpt' in model: + kq_config = { + 'api_key': os.getenv("OPENROUTER_KEY"), + 'model': f'openai/{model}', + 'base_url': os.getenv("BASE_URL"), + 'temperature': 0.0, + 'max_tokens': 4096, + } + else: + kq_config = { + 'api_key': os.getenv('AGI_API_KEY'), + 'model': model, + 'base_url': os.getenv("AGI_URL"), + 'temperature': 0.0, + 'max_tokens': 4096, + } + kq = AISRequestor(kq_config) + return kq + +if __name__ == "__main__": + models = [ + 'KevinBlip', + 'KevinBlip2' + ] + + check_models(models) \ No newline at end of file diff --git a/llms/lm_config.py b/llms/lm_config.py index c02d60b..00b195f 100644 --- a/llms/lm_config.py +++ b/llms/lm_config.py @@ -24,34 +24,40 @@ class LMConfig: provider: str model: str + caption_model: str model_cls: type | None = None tokenizer_cls: type | None = None mode: str | None = None gen_config: dict[str, Any] = dataclasses.field(default_factory=dict) + domain: str = "reddit" def construct_llm_config(args: argparse.Namespace) -> LMConfig: llm_config = LMConfig( - provider=args.provider, model=args.model, mode=args.mode + provider=args.provider, + model=args.model, + mode=args.mode, + caption_model=args.caption_model, + domain=args.domain, ) - if args.provider in ["openai", "google"]: - llm_config.gen_config["temperature"] = args.temperature - llm_config.gen_config["top_p"] = args.top_p - llm_config.gen_config["context_length"] = args.context_length - llm_config.gen_config["max_tokens"] = args.max_tokens - llm_config.gen_config["stop_token"] = args.stop_token - llm_config.gen_config["max_obs_length"] = args.max_obs_length - llm_config.gen_config["max_retry"] = args.max_retry - elif args.provider == "huggingface": - llm_config.gen_config["temperature"] = args.temperature - llm_config.gen_config["top_p"] = args.top_p - llm_config.gen_config["max_new_tokens"] = args.max_tokens - llm_config.gen_config["stop_sequences"] = ( - [args.stop_token] if args.stop_token else None - ) - llm_config.gen_config["max_obs_length"] = args.max_obs_length - llm_config.gen_config["model_endpoint"] = args.model_endpoint - llm_config.gen_config["max_retry"] = args.max_retry - else: - raise NotImplementedError(f"provider {args.provider} not implemented") + # if args.provider in ["openai", "google"]: + llm_config.gen_config["temperature"] = args.temperature + llm_config.gen_config["top_p"] = args.top_p + llm_config.gen_config["context_length"] = args.context_length + llm_config.gen_config["max_tokens"] = args.max_tokens + llm_config.gen_config["stop_token"] = args.stop_token + llm_config.gen_config["max_obs_length"] = args.max_obs_length + llm_config.gen_config["max_retry"] = args.max_retry + # elif args.provider == "huggingface": + # llm_config.gen_config["temperature"] = args.temperature + # llm_config.gen_config["top_p"] = args.top_p + # llm_config.gen_config["max_new_tokens"] = args.max_tokens + # llm_config.gen_config["stop_sequences"] = ( + # [args.stop_token] if args.stop_token else None + # ) + # llm_config.gen_config["max_obs_length"] = args.max_obs_length + # llm_config.gen_config["model_endpoint"] = args.model_endpoint + # llm_config.gen_config["max_retry"] = args.max_retry + # else: + # raise NotImplementedError(f"provider {args.provider} not implemented") return llm_config diff --git a/prepare.sh b/prepare.sh index 09885ad..9308afb 100644 --- a/prepare.sh +++ b/prepare.sh @@ -1,4 +1,4 @@ #!/bin/bash # re-validate login information -mkdir -p ./.auth +# mkdir -p ./.auth python browser_env/auto_login.py \ No newline at end of file diff --git a/run.py b/run.py deleted file mode 100644 index 7a48a2e..0000000 --- a/run.py +++ /dev/null @@ -1,539 +0,0 @@ -"""Script to run end-to-end evaluation on the benchmark. - -Modified from https://github.com/web-arena-x/webarena/blob/main/run.py. -""" -import argparse -import glob -import json -import logging -import os -import random -import subprocess -import tempfile -import time -from pathlib import Path -from typing import List - -import openai -import requests -import torch -from PIL import Image - -from agent import ( - PromptAgent, - construct_agent, -) -from agent.prompts import * -from browser_env import ( - Action, - ActionTypes, - ScriptBrowserEnv, - StateInfo, - Trajectory, - create_stop_action, -) -from browser_env.actions import is_equivalent -from browser_env.auto_login import get_site_comb_from_filepath -from browser_env.helper_functions import ( - RenderHelper, - get_action_description, -) -from evaluation_harness import evaluator_router, image_utils - -DATASET = os.environ["DATASET"] - -LOG_FOLDER = "log_files" -Path(LOG_FOLDER).mkdir(parents=True, exist_ok=True) -LOG_FILE_NAME = f"{LOG_FOLDER}/log_{time.strftime('%Y%m%d%H%M%S', time.localtime())}_{random.randint(0, 10000)}.log" - -logger = logging.getLogger("logger") -logger.setLevel(logging.INFO) - -console_handler = logging.StreamHandler() -console_handler.setLevel(logging.DEBUG) -logger.addHandler(console_handler) - -file_handler = logging.FileHandler(LOG_FILE_NAME) -file_handler.setLevel(logging.DEBUG) -logger.addHandler(file_handler) - -# Set the log format -formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") -console_handler.setFormatter(formatter) -file_handler.setFormatter(formatter) - - -def config() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Run end-to-end evaluation on the benchmark" - ) - parser.add_argument( - "--render", action="store_true", help="Render the browser" - ) - - parser.add_argument( - "--slow_mo", - type=int, - default=0, - help="Slow down the browser by the specified amount", - ) - parser.add_argument( - "--action_set_tag", default="id_accessibility_tree", help="Action type" - ) - parser.add_argument( - "--observation_type", - choices=[ - "accessibility_tree", - "accessibility_tree_with_captioner", - "html", - "image", - "image_som", - ], - default="accessibility_tree", - help="Observation type", - ) - parser.add_argument( - "--current_viewport_only", - action="store_true", - help="Only use the current viewport for the observation", - ) - parser.add_argument("--viewport_width", type=int, default=1280) - parser.add_argument("--viewport_height", type=int, default=2048) - parser.add_argument("--save_trace_enabled", action="store_true") - parser.add_argument("--sleep_after_execution", type=float, default=0.0) - - parser.add_argument("--max_steps", type=int, default=30) - - # agent config - parser.add_argument("--agent_type", type=str, default="prompt") - parser.add_argument( - "--instruction_path", - type=str, - default="agents/prompts/state_action_agent.json", - ) - parser.add_argument( - "--parsing_failure_th", - help="When consecutive parsing failures exceed this threshold, the agent will terminate early.", - type=int, - default=3, - ) - parser.add_argument( - "--repeating_action_failure_th", - help="When consecutive repeated actions exceed this threshold, the agent will terminate early.", - type=int, - default=5, - ) - - parser.add_argument("--test_config_base_dir", type=str) - - parser.add_argument( - "--eval_captioning_model_device", - type=str, - default="cpu", - choices=["cpu", "cuda"], - help="Device to run eval captioning model on. By default, runs it on CPU.", - ) - parser.add_argument( - "--eval_captioning_model", - type=str, - default="Salesforce/blip2-flan-t5-xl", - choices=["Salesforce/blip2-flan-t5-xl"], - help="Captioning backbone for VQA-type evals.", - ) - parser.add_argument( - "--captioning_model", - type=str, - default="Salesforce/blip2-flan-t5-xl", - choices=["Salesforce/blip2-flan-t5-xl", "llava-hf/llava-1.5-7b-hf"], - help="Captioning backbone for accessibility tree alt text.", - ) - - # lm config - parser.add_argument("--provider", type=str, default="openai") - parser.add_argument("--model", type=str, default="gpt-3.5-turbo-0613") - parser.add_argument("--mode", type=str, default="chat") - parser.add_argument("--temperature", type=float, default=1.0) - parser.add_argument("--top_p", type=float, default=0.9) - parser.add_argument("--context_length", type=int, default=0) - parser.add_argument("--max_tokens", type=int, default=384) - parser.add_argument("--stop_token", type=str, default=None) - parser.add_argument( - "--max_retry", - type=int, - help="max retry times to perform generations when parsing fails", - default=1, - ) - parser.add_argument( - "--max_obs_length", - type=int, - help="when not zero, will truncate the observation to this length before feeding to the model", - default=3840, - ) - - # example config - parser.add_argument("--test_start_idx", type=int, default=0) - parser.add_argument("--test_end_idx", type=int, default=910) - - # logging related - parser.add_argument("--result_dir", type=str, default="") - args = parser.parse_args() - - # check the whether the action space is compatible with the observation space - if ( - args.action_set_tag == "id_accessibility_tree" - and args.observation_type - not in [ - "accessibility_tree", - "accessibility_tree_with_captioner", - "image_som", - ] - ): - raise ValueError( - f"Action type {args.action_set_tag} is incompatible with the observation type {args.observation_type}" - ) - - return args - - -def early_stop( - trajectory: Trajectory, max_steps: int, thresholds: dict[str, int] -) -> tuple[bool, str]: - """Check whether need to stop early""" - - # reach the max step - num_steps = (len(trajectory) - 1) / 2 - if num_steps >= max_steps: - return True, f"Reach max steps {max_steps}" - - last_k_actions: list[Action] - action_seq: list[Action] - - # Case: parsing failure for k times - k = thresholds["parsing_failure"] - last_k_actions = trajectory[1::2][-k:] # type: ignore[assignment] - if len(last_k_actions) >= k: - if all( - [ - action["action_type"] == ActionTypes.NONE - for action in last_k_actions - ] - ): - return True, f"Failed to parse actions for {k} times" - - # Case: same action for k times - k = thresholds["repeating_action"] - last_k_actions = trajectory[1::2][-k:] # type: ignore[assignment] - action_seq = trajectory[1::2] # type: ignore[assignment] - - if len(action_seq) == 0: - return False, "" - - last_action: Action = action_seq[-1] - - if last_action["action_type"] != ActionTypes.TYPE: - if len(last_k_actions) >= k: - if all( - [ - is_equivalent(action, last_action) - for action in last_k_actions - ] - ): - return True, f"Same action for {k} times" - - else: - # check the action sequence - if ( - sum([is_equivalent(action, last_action) for action in action_seq]) - >= k - ): - return True, f"Same typing action for {k} times" - - return False, "" - - -def test( - args: argparse.Namespace, - config_file_list: list[str] -) -> None: - scores = [] - max_steps = args.max_steps - - early_stop_thresholds = { - "parsing_failure": args.parsing_failure_th, - "repeating_action": args.repeating_action_failure_th, - } - - if args.observation_type in [ - "accessibility_tree_with_captioner", - "image_som", - ]: - device = torch.device("cuda") if torch.cuda.is_available() else "cpu" - dtype = torch.float16 if torch.cuda.is_available() else torch.float32 - caption_image_fn = image_utils.get_captioning_fn( - device, dtype, args.captioning_model - ) - else: - caption_image_fn = None - - # Load a (possibly different) captioning model for running VQA evals. - if DATASET == 'visualwebarena': - if ( - caption_image_fn - and args.eval_captioning_model == args.captioning_model - ): - eval_caption_image_fn = caption_image_fn - else: - eval_caption_image_fn = image_utils.get_captioning_fn( - args.eval_captioning_model_device, - torch.float16 - if ( - torch.cuda.is_available() - and args.eval_captioning_model_device == "cuda" - ) - else torch.float32, - args.eval_captioning_model, - ) - else: - caption_image_fn = None - eval_caption_image_fn = None - - agent = construct_agent( - args, - captioning_fn=caption_image_fn - if args.observation_type == "accessibility_tree_with_captioner" - else None, - ) # NOTE: captioning_fn here is used for captioning input images. - - env = ScriptBrowserEnv( - headless=not args.render, - slow_mo=args.slow_mo, - observation_type=args.observation_type, - current_viewport_only=args.current_viewport_only, - viewport_size={ - "width": args.viewport_width, - "height": args.viewport_height, - }, - save_trace_enabled=args.save_trace_enabled, - sleep_after_execution=args.sleep_after_execution, - # NOTE: captioning_fn here is used for LLM + captioning baselines. - # This can be different from the captioning model used for evals. - captioning_fn=caption_image_fn, - ) - - for config_file in config_file_list: - try: - render_helper = RenderHelper( - config_file, args.result_dir, args.action_set_tag - ) - - # Load task. - with open(config_file) as f: - _c = json.load(f) - intent = _c["intent"] - task_id = _c["task_id"] - image_paths = _c.get("image", None) - images = [] - - # automatically login - if _c["storage_state"]: - cookie_file_name = os.path.basename(_c["storage_state"]) - comb = get_site_comb_from_filepath(cookie_file_name) - temp_dir = tempfile.mkdtemp() - # subprocess to renew the cookie - subprocess.run( - [ - "python", - "browser_env/auto_login.py", - "--auth_folder", - temp_dir, - "--site_list", - *comb, - ] - ) - _c["storage_state"] = f"{temp_dir}/{cookie_file_name}" - assert os.path.exists(_c["storage_state"]) - # update the config file - config_file = f"{temp_dir}/{os.path.basename(config_file)}" - with open(config_file, "w") as f: - json.dump(_c, f) - - # Load input images for the task, if any. - if image_paths is not None: - if isinstance(image_paths, str): - image_paths = [image_paths] - for image_path in image_paths: - # Load image either from the web or from a local path. - if image_path.startswith("http"): - headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'} - input_image = Image.open(requests.get(image_path, stream=True, headers = headers).raw) - else: - input_image = Image.open(image_path) - - images.append(input_image) - - logger.info(f"[Config file]: {config_file}") - logger.info(f"[Intent]: {intent}") - - agent.reset(config_file) - trajectory: Trajectory = [] - obs, info = env.reset(options={"config_file": config_file}) - state_info: StateInfo = {"observation": obs, "info": info} - trajectory.append(state_info) - - meta_data = {"action_history": ["None"]} - while True: - early_stop_flag, stop_info = early_stop( - trajectory, max_steps, early_stop_thresholds - ) - - if early_stop_flag: - action = create_stop_action(f"Early stop: {stop_info}") - else: - try: - action = agent.next_action( - trajectory, - intent, - images=images, - meta_data=meta_data, - ) - except ValueError as e: - # get the error message - action = create_stop_action(f"ERROR: {str(e)}") - - trajectory.append(action) - - action_str = get_action_description( - action, - state_info["info"]["observation_metadata"], - action_set_tag=args.action_set_tag, - prompt_constructor=agent.prompt_constructor - if isinstance(agent, PromptAgent) - else None, - ) - render_helper.render( - action, state_info, meta_data, args.render_screenshot - ) - meta_data["action_history"].append(action_str) - - if action["action_type"] == ActionTypes.STOP: - break - - obs, _, terminated, _, info = env.step(action) - state_info = {"observation": obs, "info": info} - trajectory.append(state_info) - - if terminated: - # add a action place holder - trajectory.append(create_stop_action("")) - break - - # NOTE: eval_caption_image_fn is used for running eval_vqa functions. - evaluator = evaluator_router( - config_file, captioning_fn=eval_caption_image_fn - ) - score = evaluator( - trajectory=trajectory, - config_file=config_file, - page=env.page - ) - - scores.append(score) - - if score == 1: - logger.info(f"[Result] (PASS) {config_file}") - else: - logger.info(f"[Result] (FAIL) {config_file}") - - if args.save_trace_enabled: - env.save_trace( - Path(args.result_dir) / "traces" / f"{task_id}.zip" - ) - except openai.OpenAIError as e: - logger.info(f"[OpenAI Error] {repr(e)}") - except Exception as e: - logger.info(f"[Unhandled Error] {repr(e)}]") - import traceback - - # write to error file - with open(Path(args.result_dir) / "error.txt", "a") as f: - f.write(f"[Config file]: {config_file}\n") - f.write(f"[Unhandled Error] {repr(e)}\n") - f.write(traceback.format_exc()) # write stack trace to file - - render_helper.close() - - env.close() - if len(scores): - logger.info(f"Average score: {sum(scores) / len(scores)}") - - -def prepare(args: argparse.Namespace) -> None: - # convert prompt python files to json - from agent.prompts import to_json - - to_json.run() - - # prepare result dir - result_dir = args.result_dir - if not result_dir: - result_dir = ( - f"cache/results_{time.strftime('%Y%m%d%H%M%S', time.localtime())}" - ) - if not Path(result_dir).exists(): - Path(result_dir).mkdir(parents=True, exist_ok=True) - args.result_dir = result_dir - logger.info(f"Create result dir: {result_dir}") - - if not (Path(result_dir) / "traces").exists(): - (Path(result_dir) / "traces").mkdir(parents=True) - - # log the log file - with open(os.path.join(result_dir, "log_files.txt"), "a+") as f: - f.write(f"{LOG_FILE_NAME}\n") - - -def get_unfinished(config_files: list[str], result_dir: str) -> list[str]: - result_files = glob.glob(f"{result_dir}/*.html") - task_ids = [ - os.path.basename(f).split(".")[0].split("_")[1] for f in result_files - ] - unfinished_configs = [] - for config_file in config_files: - task_id = os.path.basename(config_file).split(".")[0] - if task_id not in task_ids: - unfinished_configs.append(config_file) - return unfinished_configs - - -def dump_config(args: argparse.Namespace) -> None: - config_file = Path(args.result_dir) / "config.json" - if not config_file.exists(): - with open(config_file, "w") as f: - json.dump(vars(args), f, indent=4) - logger.info(f"Dump config to {config_file}") - - -if __name__ == "__main__": - os.environ["TOKENIZERS_PARALLELISM"] = "false" - - args = config() - args.sleep_after_execution = 2.5 - prepare(args) - - test_config_base_dir = args.test_config_base_dir - - test_file_list = [] - st_idx = args.test_start_idx - ed_idx = args.test_end_idx - for i in range(st_idx, ed_idx): - test_file_list.append(os.path.join(test_config_base_dir, f"{i}.json")) - test_file_list = get_unfinished(test_file_list, args.result_dir) - print(f"Total {len(test_file_list)} tasks left") - args.render = False - args.render_screenshot = True - args.save_trace_enabled = True - - args.current_viewport_only = True - dump_config(args) - - test(args, test_file_list) diff --git a/run_demo.py b/run_demo.py deleted file mode 100644 index 4c0ea95..0000000 --- a/run_demo.py +++ /dev/null @@ -1,456 +0,0 @@ -"""Script to run end-to-end evaluation on the benchmark. - -Modified from https://github.com/web-arena-x/webarena/blob/main/run.py. -""" -import argparse -import json -import logging -import os -import random -import time -import tempfile -from pathlib import Path - -import openai -import requests -import torch -from beartype import beartype -from PIL import Image - -from agent import ( - PromptAgent, - construct_agent, -) -from agent.prompts import * -from browser_env import ( - Action, - ActionTypes, - ScriptBrowserEnv, - StateInfo, - Trajectory, - create_stop_action, -) -from browser_env.actions import is_equivalent -from browser_env.helper_functions import ( - RenderHelper, - get_action_description, -) -from evaluation_harness import image_utils - -LOG_FOLDER = "log_files" -Path(LOG_FOLDER).mkdir(parents=True, exist_ok=True) -LOG_FILE_NAME = f"{LOG_FOLDER}/log_{time.strftime('%Y%m%d%H%M%S', time.localtime())}_{random.randint(0, 10000)}.log" - -logger = logging.getLogger("logger") -logger.setLevel(logging.INFO) - -console_handler = logging.StreamHandler() -console_handler.setLevel(logging.DEBUG) -logger.addHandler(console_handler) - -file_handler = logging.FileHandler(LOG_FILE_NAME) -file_handler.setLevel(logging.DEBUG) -logger.addHandler(file_handler) - -# Set the log format -formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") -console_handler.setFormatter(formatter) -file_handler.setFormatter(formatter) - - -def config() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Run end-to-end evaluation on the benchmark" - ) - parser.add_argument( - "--render", action="store_true", help="Render the browser" - ) - - parser.add_argument( - "--slow_mo", - type=int, - default=0, - help="Slow down the browser by the specified amount", - ) - parser.add_argument( - "--action_set_tag", default="som", help="Action type" - ) - parser.add_argument( - "--observation_type", - choices=[ - "accessibility_tree", - "accessibility_tree_with_captioner", - "html", - "image", - "image_som", - ], - default="image_som", - help="Observation type", - ) - parser.add_argument( - "--current_viewport_only", - action="store_true", - help="Only use the current viewport for the observation", - ) - parser.add_argument("--viewport_width", type=int, default=1280) - parser.add_argument("--viewport_height", type=int, default=2048) - parser.add_argument("--save_trace_enabled", action="store_true") - parser.add_argument("--sleep_after_execution", type=float, default=0.0) - - parser.add_argument("--max_steps", type=int, default=30) - - # agent config - parser.add_argument("--agent_type", type=str, default="prompt") - parser.add_argument( - "--instruction_path", - type=str, - default="agent/prompts/jsons/p_som_cot_id_actree_3s.json", - ) - parser.add_argument( - "--parsing_failure_th", - help="When consecutive parsing failures exceed this threshold, the agent will terminate early.", - type=int, - default=3, - ) - parser.add_argument( - "--repeating_action_failure_th", - help="When consecutive repeated actions exceed this threshold, the agent will terminate early.", - type=int, - default=5, - ) - - parser.add_argument( - "--eval_captioning_model_device", - type=str, - default="cpu", - choices=["cpu", "cuda"], - help="Device to run eval captioning model on. By default, runs it on CPU.", - ) - parser.add_argument( - "--eval_captioning_model", - type=str, - default="Salesforce/blip2-flan-t5-xl", - choices=["Salesforce/blip2-flan-t5-xl"], - help="Captioning backbone for VQA-type evals.", - ) - parser.add_argument( - "--captioning_model", - type=str, - default="Salesforce/blip2-flan-t5-xl", - choices=["Salesforce/blip2-flan-t5-xl", "llava-hf/llava-1.5-7b-hf"], - help="Captioning backbone for accessibility tree alt text.", - ) - - # lm config - parser.add_argument("--provider", type=str, default="openai") - parser.add_argument("--model", type=str, default="gpt-4-vision-preview") - parser.add_argument("--mode", type=str, default="chat") - parser.add_argument("--temperature", type=float, default=1.0) - parser.add_argument("--top_p", type=float, default=0.9) - parser.add_argument("--context_length", type=int, default=0) - parser.add_argument("--max_tokens", type=int, default=384) - parser.add_argument("--stop_token", type=str, default=None) - parser.add_argument( - "--max_retry", - type=int, - help="max retry times to perform generations when parsing fails", - default=1, - ) - parser.add_argument( - "--max_obs_length", - type=int, - help="when not zero, will truncate the observation to this length before feeding to the model", - default=3840, - ) - - - # example config - parser.add_argument("--start_url", type=str, default="https://google.com") - parser.add_argument("--intent", type=str, required=True) - parser.add_argument("--image", type=str, default="", help="url of images, seperated by |AND|") - - # logging related - parser.add_argument("--result_dir", type=str, default="") - args = parser.parse_args() - - # check the whether the action space is compatible with the observation space - if ( - args.action_set_tag == "id_accessibility_tree" - and args.observation_type - not in [ - "accessibility_tree", - "accessibility_tree_with_captioner", - "image_som", - ] - ): - raise ValueError( - f"Action type {args.action_set_tag} is incompatible with the observation type {args.observation_type}" - ) - - return args - - -@beartype -def early_stop( - trajectory: Trajectory, max_steps: int, thresholds: dict[str, int] -) -> tuple[bool, str]: - """Check whether need to stop early""" - - # reach the max step - num_steps = (len(trajectory) - 1) / 2 - if num_steps >= max_steps: - return True, f"Reach max steps {max_steps}" - - last_k_actions: list[Action] - action_seq: list[Action] - - # Case: parsing failure for k times - k = thresholds["parsing_failure"] - last_k_actions = trajectory[1::2][-k:] # type: ignore[assignment] - if len(last_k_actions) >= k: - if all( - [ - action["action_type"] == ActionTypes.NONE - for action in last_k_actions - ] - ): - return True, f"Failed to parse actions for {k} times" - - # Case: same action for k times - k = thresholds["repeating_action"] - last_k_actions = trajectory[1::2][-k:] # type: ignore[assignment] - action_seq = trajectory[1::2] # type: ignore[assignment] - - if len(action_seq) == 0: - return False, "" - - last_action: Action = action_seq[-1] - - if last_action["action_type"] != ActionTypes.TYPE: - if len(last_k_actions) >= k: - if all( - [ - is_equivalent(action, last_action) - for action in last_k_actions - ] - ): - return True, f"Same action for {k} times" - - else: - # check the action sequence - if ( - sum([is_equivalent(action, last_action) for action in action_seq]) - >= k - ): - return True, f"Same typing action for {k} times" - - return False, "" - - -@beartype -def test( - args: argparse.Namespace, - config_file: str -) -> None: - scores = [] - max_steps = args.max_steps - - early_stop_thresholds = { - "parsing_failure": args.parsing_failure_th, - "repeating_action": args.repeating_action_failure_th, - } - - caption_image_fn = None # Don't use captioning for the demo, due to extra resources required to run BLIP-2. - - - agent = construct_agent( - args, - captioning_fn=caption_image_fn - if args.observation_type == "accessibility_tree_with_captioner" - else None, - ) # NOTE: captioning_fn here is used for captioning input images. - - assert args.render, "Rendering is required for end-to-end evaluation" - - env = ScriptBrowserEnv( - headless=not args.render, - slow_mo=args.slow_mo, - observation_type=args.observation_type, - current_viewport_only=args.current_viewport_only, - viewport_size={ - "width": args.viewport_width, - "height": args.viewport_height, - }, - save_trace_enabled=args.save_trace_enabled, - sleep_after_execution=args.sleep_after_execution, - # NOTE: captioning_fn here is used for LLM + captioning baselines. - # This can be different from the captioning model used for evals. - captioning_fn=caption_image_fn, - ) - - try: - render_helper = RenderHelper( - config_file, args.result_dir, args.action_set_tag - ) - - # Load task. - with open(config_file, 'r') as f: - _c = json.load(f) - intent = _c["intent"] - image_paths = _c.get("image", None) - images = [] - - # Load input images for the task, if any. - if image_paths is not None: - if isinstance(image_paths, str): - image_paths = [image_paths] - for image_path in image_paths: - # Load image either from the web or from a local path. - if image_path.startswith("http"): - input_image = Image.open(requests.get(image_path, stream=True).raw) - else: - input_image = Image.open(image_path) - - images.append(input_image) - - logger.info(f"[Config file]: {config_file}") - logger.info(f"[Intent]: {intent}") - - agent.reset(config_file) - trajectory: Trajectory = [] - obs, info = env.reset(options={"config_file": config_file}) - state_info: StateInfo = {"observation": obs, "info": info} - trajectory.append(state_info) - - meta_data = {"action_history": ["None"]} - while True: - early_stop_flag, stop_info = early_stop( - trajectory, max_steps, early_stop_thresholds - ) - - if early_stop_flag: - action = create_stop_action(f"Early stop: {stop_info}") - else: - try: - print('=' * 30) - print('Agent: Thinking...') - action = agent.next_action( - trajectory, - intent, - images=images, - meta_data=meta_data, - output_response=True - ) - except ValueError as e: - # get the error message - action = create_stop_action(f"ERROR: {str(e)}") - - trajectory.append(action) - - action_str = get_action_description( - action, - state_info["info"]["observation_metadata"], - action_set_tag=args.action_set_tag, - prompt_constructor=agent.prompt_constructor - if isinstance(agent, PromptAgent) - else None, - ) - render_helper.render( - action, state_info, meta_data, args.render_screenshot - ) - meta_data["action_history"].append(action_str) - - if action["action_type"] == ActionTypes.STOP: - break - - obs, _, terminated, _, info = env.step(action) - state_info = {"observation": obs, "info": info} - trajectory.append(state_info) - - if terminated: - # add a action place holder - trajectory.append(create_stop_action("")) - break - - if args.save_trace_enabled: - env.save_trace( - Path(args.result_dir) / "trace.zip" - ) - except openai.OpenAIError as e: - logger.info(f"[OpenAI Error] {repr(e)}") - except Exception as e: - logger.info(f"[Unhandled Error] {repr(e)}]") - import traceback - - # write to error file - with open(Path(args.result_dir) / "error.txt", "a") as f: - f.write(f"[Config file]: {config_file}\n") - f.write(f"[Unhandled Error] {repr(e)}\n") - f.write(traceback.format_exc()) # write stack trace to file - - render_helper.close() - - env.close() - - -def prepare(args: argparse.Namespace) -> None: - # convert prompt python files to json - from agent.prompts import to_json - - to_json.run() - - # prepare result dir - result_dir = args.result_dir - if not result_dir: - result_dir = ( - f"cache/results_{time.strftime('%Y%m%d%H%M%S', time.localtime())}" - ) - if not Path(result_dir).exists(): - Path(result_dir).mkdir(parents=True, exist_ok=True) - args.result_dir = result_dir - logger.info(f"Create result dir: {result_dir}") - - if not (Path(result_dir) / "traces").exists(): - (Path(result_dir) / "traces").mkdir(parents=True) - - # log the log file - with open(os.path.join(result_dir, "log_files.txt"), "a+") as f: - f.write(f"{LOG_FILE_NAME}\n") - - -@beartype -def dump_config(args: argparse.Namespace) -> None: - config_file = Path(args.result_dir) / "config.json" - if not config_file.exists(): - with open(config_file, "w") as f: - json.dump(vars(args), f, indent=4) - logger.info(f"Dump config to {config_file}") - - -if __name__ == "__main__": - os.environ["TOKENIZERS_PARALLELISM"] = "false" - args = config() - args.sleep_after_execution = 2.5 - prepare(args) - - _, tmp_config_file = tempfile.mkstemp(text=True) - images_url = None - if args.image: - images_url = args.image.split('|AND|') - with open(tmp_config_file, 'w') as f: - json.dump({ - "task_id": 0, - "start_url": args.start_url, - "intent": args.intent, - "image": images_url - }, f) - - args.render_screenshot = True - args.save_trace_enabled = True - - args.current_viewport_only = True - dump_config(args) - - test(args, tmp_config_file) - - os.remove(tmp_config_file) diff --git a/run_recon_act_infer.py b/run_recon_act_infer.py new file mode 100644 index 0000000..eb38d25 --- /dev/null +++ b/run_recon_act_infer.py @@ -0,0 +1,56 @@ +import os +import sys +parent_dir = os.path.dirname(os.path.abspath(__file__)) +if parent_dir not in sys.path: sys.path.insert(0, parent_dir) +import AWorld.examples.visualwebarena.utils as u +from vwa_tester import VWATester, VWAPathHandler, VWAConfig + +def get_default_config(domain, start_index = 0, end_index = 466) -> VWAConfig: + args = VWAConfig() + args.action_set_tag = "som" # action type + args.vwa_code_path = f'{parent_dir}/' + args.vwa_data_path = f'{u.get_nas()}/gui_dataset/visualwebarena/' + + args.mode = 'mas' # som, vision, mas + if args.mode == 'vision': + args.instruction_path = f'{args.vwa_code_path}/agent/prompts/jsons/vision.json' + args.observation_type = "image" + elif args.mode == 'som': + args.instruction_path = f'{args.vwa_code_path}/agent/prompts/jsons/som.json' + args.observation_type = "image_som" + elif args.mode == 'mas': + args.instruction_path = f'{args.vwa_code_path}/agent/prompts/jsons/mas.json' + args.observation_type = "image_som" + + args.model = os.getenv('LLM_MODEL_NAME') + args.eval_provider = os.getenv('LLM_MODEL_NAME') + + args.caption_model = 'KevinBlip' + + args.domain = domain + args.print_time = True + args.output_response = True + args.render = True + args.flush = True + args.save_trace_enabled = False + args.render_fail_only = False + args.test_start_idx = start_index + args.test_end_idx = end_index + + return args + +if __name__ == "__main__": + domain = sys.argv[1] + os.environ["TOKENIZERS_PARALLELISM"] = "false" + args = get_default_config(domain) + ph = VWAPathHandler(args) + metrics_files = u.list_files(ph.metrics_path) + test_config_base_dir = f'{args.vwa_data_path}/config_files/vwa/test_{args.domain}' + test_config_files = u.list_files(test_config_base_dir) + test_config_files = sorted(test_config_files, key=lambda x: int(x.split('.')[0])) + if not args.flush: + test_config_files = [a for a in test_config_files if a not in metrics_files] + test_config_files = [test_config_base_dir + '/' + a for a in test_config_files] + args.test_config_files = test_config_files + tester = VWATester(args) + tester.test() \ No newline at end of file diff --git a/scripts/collect_obs.py b/scripts/collect_obs.py index 49317bc..a50a2c0 100644 --- a/scripts/collect_obs.py +++ b/scripts/collect_obs.py @@ -41,7 +41,6 @@ def get_observation( for action in action_seq: action = action.strip() obs, success, _, _, info = env.step(create_playwright_action(action)) - print(obs["text"]) _ = input("Press enter to continue") diff --git a/scripts/generate_test_data.py b/scripts/generate_test_data.py index 06ecbf1..8a62d35 100644 --- a/scripts/generate_test_data.py +++ b/scripts/generate_test_data.py @@ -5,8 +5,16 @@ from browser_env.env_config import * +def get_name(file_path, pure = False): + filename_with_ext = os.path.basename(file_path) + if pure: + last_dot_index = filename_with_ext.rfind('.') + name = filename_with_ext[:last_dot_index] + ext = filename_with_ext[last_dot_index+1:] + return name + return filename_with_ext -def main() -> None: +if __name__ == "__main__": DATASET = os.environ["DATASET"] if DATASET == "webarena": print("DATASET: webarena") @@ -45,8 +53,13 @@ def main() -> None: } else: raise ValueError(f"Dataset not implemented: {DATASET}") - + + data_path = f'{os.path.expanduser("~")}/data/gui_dataset/visualwebarena/' + auth_path = f'{data_path}/auth/' + os.makedirs(auth_path, exist_ok=True) + for inp_path in inp_paths: + inp_path = data_path + inp_path output_dir = inp_path.replace('.raw.json', '') os.makedirs(output_dir, exist_ok=True) with open(inp_path, "r") as f: @@ -59,8 +72,16 @@ def main() -> None: data = json.loads(raw) for idx, item in enumerate(data): with open(os.path.join(output_dir, f"{idx}.json"), "w") as f: + if 'image' in item.keys(): + image_path = item['image'] + if isinstance(image_path, str): + item['image'] = data_path + image_path[image_path.find('static'):] + elif isinstance(image_path, list): + if image_path: + for i in range(len(image_path)): + item['image'][i] = data_path + image_path[i][image_path[i].find('static'):] + storage_state = item['storage_state'] + storage_name = get_name(storage_state) + storage_state = f'{auth_path}/{storage_name}' + item['storage_state'] = storage_state json.dump(item, f, indent=2) - - -if __name__ == "__main__": - main() diff --git a/test_utils.py b/test_utils.py new file mode 100644 index 0000000..76140f6 --- /dev/null +++ b/test_utils.py @@ -0,0 +1,137 @@ +import os +import sys +parent_dir = os.path.dirname(os.path.abspath(__file__)) +up_dir = parent_dir +for i in range(3): + sys.path.append(up_dir) + up_dir = os.path.dirname(up_dir) +from kutils import DEBUG, INFO, WARN, ERROR +import utils as u +from PIL import Image, ImageChops +import numpy as np +import argparse +import time +import glob +import json +from pathlib import Path + +from browser_env.actions import is_equivalent +from browser_env import ( + Action, + ActionTypes, + ScriptBrowserEnv, + StateInfo, + Trajectory, + create_stop_action, +) + +def early_stop( + trajectory: Trajectory, max_steps: int, thresholds: dict[str, int] +) -> tuple[bool, str]: + """Check whether need to stop early""" + + # reach the max step + num_steps = (len(trajectory) - 1) / 2 + if num_steps >= max_steps: + return True, f"Reach max steps {max_steps}" + + last_k_actions: list[Action] + action_seq: list[Action] + + # Case: parsing failure for k times + k = thresholds["parsing_failure"] + last_k_actions = trajectory[1::2][-k:] # type: ignore[assignment] + if len(last_k_actions) >= k: + if all( + [ + action["action_type"] == ActionTypes.NONE + for action in last_k_actions + ] + ): + return True, f"Failed to parse actions for {k} times" + + # Case: same action for k times + k = thresholds["repeating_action"] + last_k_actions = trajectory[1::2][-k:] # type: ignore[assignment] + action_seq = trajectory[1::2] # type: ignore[assignment] + + if len(action_seq) == 0: + return False, "" + + last_action: Action = action_seq[-1] + + if last_action["action_type"] != ActionTypes.TYPE: + if len(last_k_actions) >= k: + if all( + [ + is_equivalent(action, last_action) + for action in last_k_actions + ] + ): + return True, f"Same action for {k} times" + + else: + # check the action sequence + if ( + sum([is_equivalent(action, last_action) for action in action_seq]) + >= k + ): + return True, f"Same typing action for {k} times" + + return False, "" + +def resize_image_proportional(image, max_width, max_height): + original_width, original_height = image.size + width_ratio = max_width / original_width + height_ratio = max_height / original_height + min_ratio = min(width_ratio, height_ratio) + new_width = int(original_width * min_ratio) + new_height = int(original_height * min_ratio) + resized_image = image.resize((new_width, new_height)) + return resized_image + +def is_images_same(img1, img2, ratio = 0.05): + wh = 100 + img1 = img1.resize((wh, wh)) + img2 = img2.resize((wh, wh)) + t = u.get_time() + if img1.size != img2.size: + raise ValueError("Should be same size") + diff = ImageChops.difference(img1, img2).convert('L') # 转为灰度图 + diff_array = np.array(diff) + diff_pixels = np.count_nonzero(diff_array) + total_pixels = diff_array.size + diff_ratio = diff_pixels / total_pixels + return diff_ratio < ratio + +def prepare(args: argparse.Namespace) -> None: + # convert prompt python files to json + from agent.prompts import to_json + + to_json.run() + + # prepare result dir + result_dir = args.result_dir + if not result_dir: + result_dir = ( + f"cache/results_{time.strftime('%Y%m%d%H%M%S', time.localtime())}" + ) + if not Path(result_dir).exists(): + Path(result_dir).mkdir(parents=True, exist_ok=True) + args.result_dir = result_dir + INFO(f"Create result dir: {result_dir}") + + if not (Path(result_dir) / "traces").exists(): + (Path(result_dir) / "traces").mkdir(parents=True) + +def get_unfinished(config_files: list[str], result_dir: str) -> list[str]: + result_files = glob.glob(f"{result_dir}/*.html") + task_ids = [ + os.path.basename(f).split(".")[0].split("_")[1] for f in result_files + ] + unfinished_configs = [] + for config_file in config_files: + task_id = os.path.basename(config_file).split(".")[0] + if task_id not in task_ids: + unfinished_configs.append(config_file) + return unfinished_configs diff --git a/tests/test_evaluation_harness/test_exact_evaluators.py b/tests/test_evaluation_harness/test_exact_evaluators.py index 0fbf735..0623875 100644 --- a/tests/test_evaluation_harness/test_exact_evaluators.py +++ b/tests/test_evaluation_harness/test_exact_evaluators.py @@ -127,7 +127,6 @@ def test_url_exact_match_fail(script_browser_env: ScriptBrowserEnv) -> None: score = evalutor( trajectory, config_file, env.page ) - print(env.page.url) assert score == 0.0 diff --git a/vwa_tester.py b/vwa_tester.py new file mode 100644 index 0000000..e9e5ce3 --- /dev/null +++ b/vwa_tester.py @@ -0,0 +1,494 @@ +import os +import sys +parent_dir = os.path.dirname(os.path.abspath(__file__)) +up_dir = parent_dir +for i in range(3): + sys.path.append(up_dir) + up_dir = os.path.dirname(up_dir) +from kutils import DEBUG, INFO, WARN, ERROR +import utils as u +import argparse +import requests +import threading +import torch +import subprocess +import multiprocessing +import queue +import time +from tqdm import tqdm +from PIL import Image, ImageChops +from pathlib import Path +from browser_env import ( + Action, + ActionTypes, + ScriptBrowserEnv, + StateInfo, + Trajectory, + create_stop_action, + create_none_action, +) +from browser_env.utils import DetachedPage +from agent.agents import ( + PromptAgent, + construct_agent, +) +from evaluation_harness.evaluators import evaluator_router +from evaluation_harness import image_utils +from browser_env.auto_login import get_site_comb_from_filepath +from browser_env.helper_functions import ( + RenderHelper, + get_action_description, +) +import test_utils as tu +from typing import Optional, Literal + +class VWAConfig: + # Environment + render: bool = False + render_screenshot: bool = True + render_fail_only: bool = True + slow_mo: int = 0 + action_set_tag: str = "id_accessibility_tree" + observation_type: Literal[ + "accessibility_tree", + "accessibility_tree_with_captioner", + "html", + "image", + "image_som", + ] = "accessibility_tree" + current_viewport_only: bool = True + viewport_width: int = 1280 + viewport_height: int = 2100 + sleep_after_execution: float = 0.0 + output_response: bool = False + save_trace_enabled: bool = False + + # Task + max_steps: int = 20 + single_site_mode = False + flush: bool = False + + # Agent + instruction_path: str = "" + parsing_failure_th: int = 3 + repeating_action_failure_th: int = 4 + test_config_base_dir: Optional[str] = None + + # Captioning + caption_model: str = '' + + # Language Model + provider: str = "openai" + eval_provider: str = '' + model: str = "qwen25vl72b" + mode: str = "som" + temperature: float = 1.0 + top_p: float = 0.9 + context_length: int = 0 + max_tokens: int = 32768 + stop_token: Optional[str] = None + vwa_code_path: Optional[str] = None + vwa_data_path: Optional[str] = None + domain: Optional[str] = None + print_time: bool = False + max_retry: int = 30 + max_obs_length: int = 3840 + + # Example range + test_start_idx: int = 0 + test_end_idx: int = 910 + + test_config_files = [] + +class VWAPathHandler(): + args: VWAConfig + dataset_result_path : str + out_model_path : str + output_path_model : str + result_path : str + traj_path : str + render_dir : str + cache_dir : str + auth_dir: str + config_dir : str + metrics_path : str + domain: str + + def __init__(self, args): + self.args = args + self.dataset_result_path = f'{args.vwa_data_path}/results/' + u.mkdir(self.dataset_result_path) + self.out_model_path = f'{args.model}_{args.mode}' + self.output_path_model = f'{self.dataset_result_path}/{self.out_model_path}' + u.mkdir(self.output_path_model) + self.result_path = f'{self.output_path_model}/{args.domain}/' + u.mkdir(self.result_path) + self.traj_path = f'{self.result_path}/traj/' + u.mkdir(self.traj_path) + self.render_dir = f'{self.result_path}/render/' + u.mkdir(self.render_dir) + self.cache_dir = f'{self.result_path}/cache/' + u.mkdir(self.cache_dir) + self.auth_dir = f'{self.result_path}/auth/' + u.mkdir(self.auth_dir) + self.config_dir = f'{self.result_path}/config/' + u.mkdir(self.config_dir) + self.metrics_path = f'{self.result_path}/results/' + u.mkdir(self.metrics_path) + + def reset_output_folder(self, folder_name, result_folder = None): + if result_folder: + self.dataset_result_path = f'{self.args.vwa_data_path}/{result_folder}/' + self.output_path_model = f'{self.dataset_result_path}/{folder_name}' + self.result_path = f'{self.output_path_model}/{self.args.domain}/' + self.traj_path = f'{self.result_path}/traj/' + self.render_dir = f'{self.result_path}/render/' + self.cache_dir = f'{self.result_path}/cache/' + self.auth_dir = f'{self.result_path}/auth/' + self.config_dir = f'{self.result_path}/config/' + self.metrics_path = f'{self.result_path}/results/' + +class VWATester(): + def __init__(self, args: VWAConfig): + self.args = args + self.max_steps = args.max_steps + self.print_time = self.args.print_time + # self.domains = ['reddit', 'classifieds', 'shopping'] + # end_idxs = [210, 234, 466] + if self.args.domain == None or self.args.domain == 'None': + exit() + + self.kq_config = { + 'api_key': os.getenv("OPENROUTER_KEY"), + 'model': self.args.eval_provider, + 'base_url': os.getenv("BASE_URL"), + 'temperature': 0.0, + 'max_tokens': 4096, + } + + self.early_stop_thresholds = { + "parsing_failure": args.parsing_failure_th, + "repeating_action": args.repeating_action_failure_th, + } + + caption_image_fn = image_utils.get_captioning_fn(self.args.caption_model) + self.eval_caption_image_fn = caption_image_fn + + self.agent = construct_agent( + args, + captioning_fn=caption_image_fn + ) # NOTE: captioning_fn here is used for captioning input images. + + self.env = ScriptBrowserEnv( + headless=True, + slow_mo=args.slow_mo, + observation_type=args.observation_type, + current_viewport_only=args.current_viewport_only, + viewport_size={ + "width": args.viewport_width, + "height": args.viewport_height, + }, + save_trace_enabled=True, + sleep_after_execution=args.sleep_after_execution, + # NOTE: captioning_fn here is used for LLM + captioning baselines. + # This can be different from the captioning model used for evals. + captioning_fn=caption_image_fn, + ) + + self.ph = VWAPathHandler(args) + + def handle_meta_bf(self, meta_data, trajectory): + if self.args.mode == 'vision': + try: + bboxes = self.env.get_bboxes() + except Exception as e: + ERROR(e) + meta_data['bbox'] = {} + return meta_data + + meta_data['bbox'] = bboxes + if len(trajectory) > 3: + last_img_str = trajectory[-3]["observation"]["ori_image"] + last_img = Image.fromarray(last_img_str) # size = (viewport_width, viewport_width) + curr_img_str = trajectory[-1]["observation"]["ori_image"] + curr_img = Image.fromarray(curr_img_str) # size = (viewport_width, viewport_width) + f_same = tu.is_images_same(curr_img, last_img) + if f_same and trajectory[-2]['action_info']['pred_action_type'] == 'SCROLL': + meta_data['hint'] += 'You have scrolled to the end of this page.' + + return meta_data + elif self.args.mode == 'som': + return meta_data + else: + return meta_data + + def handle_meta_af(self, meta_data, action, action_str): + if self.args.mode == 'vision': + last_action = action['action_info']['pred_action_description'] + ' ' + action['action_info']['pred_action'] + if action_str == "None" or action_str == 'none': + meta_data['hint'] = 'Last step you clicked an uninteractable area, you should try to change the element you click this time.' + else: + meta_data['hint'] = '' + meta_data["action_history"].append(last_action) + elif self.args.mode == 'som': + raw_response = action['raw_prediction'] + try: + if 'Let\'s think step-by-step. ' in raw_response: + key_content = u.extract_text(raw_response, 'Let\'s think step-by-step. ', ' In summary, the next action I will perform is')[0] + else: + key_content = u.extract_text(raw_response, None, ' In summary, the next action I will perform is')[0] + meta_data["action_history"].append(key_content) + except Exception as e: + # ERROR(f'{e}, response format error') + meta_data["action_history"].append(raw_response) + else: + last_action = action['action_info']['pred_action_description'] + meta_data["action_history"].append(last_action) + return meta_data + + def auto_loging(self, _c, config_file): + # automatically login + if _c["storage_state"]: + cookie_file_name = os.path.basename(_c["storage_state"]) + comb = get_site_comb_from_filepath(cookie_file_name) + # temp_dir = tempfile.mkdtemp() + # subprocess to renew the cookie + print(f'auto login for {comb} ...') + subprocess.run( + [ + "python", + f"{self.args.vwa_code_path}browser_env/auto_login.py", + "--auth_folder", + self.ph.auth_dir, + "--site_list", + *comb, + ] + ) + print('auto login done') + _c["storage_state"] = f"{self.ph.auth_dir}/{cookie_file_name}" + assert os.path.exists(_c["storage_state"]) + # update the config file + config_file = f"{self.ph.config_dir}/{os.path.basename(config_file)}" + u.write_json(config_file, _c) + return config_file + + def load_input_image(self, image_paths): + # Load input images for the task, if any. + images = [] + if image_paths is not None: + if isinstance(image_paths, str): + image_paths = [image_paths] + for image_path in image_paths: + # Load image either from the web or from a local path. + if image_path.startswith("http"): + headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'} + input_image = Image.open(requests.get(image_path, stream=True, headers = headers).raw) + else: + input_image = Image.open(image_path) + input_image = tu.resize_image_proportional(input_image, 500, 500) # TODO + images.append(input_image) + return images + + def rollout(self, config_file, intent, images, render_helper = RenderHelper): + self.agent.reset(config_file) + trajectory: Trajectory = [] + try: + obs, info = self.env.reset(options={"config_file": config_file}) + except Exception as e: + ERROR(f'reset error {e}, waiting for re-setup') + self.__reset() + obs, info = self.env.reset(options={"config_file": config_file}) + + state_info: StateInfo = {"observation": obs, "info": info} + + trajectory.append(state_info) + meta_data = {"action_history": [], 'hint': '', 'tabs': obs['tabs']} + + # for step in range(self.max_steps): + i_step = -1 + while 1: + step_start_time = time.time() + i_step += 1 + early_stop_flag, stop_info = tu.early_stop( + trajectory, self.max_steps, self.early_stop_thresholds + ) + + if early_stop_flag: + WARN('early stop') + action = create_stop_action(f"Early stop: {stop_info}") + else: + meta_data = self.handle_meta_bf(meta_data, trajectory) + meta_data["page"] = self.env.page + # try: + start_time = time.time() + action = self.agent.next_action( + trajectory, + intent, + images=images, + meta_data=meta_data, + output_response=self.args.output_response) + end_time = time.time() + if self.print_time: + INFO(f'step {i_step} model infer time = {round(end_time - start_time, 2)}') + # except Exception as e: + # # get the error message + # ERROR(e) + # action = create_stop_action(f"ERROR: {str(e)}") + + trajectory.append(action) + observation_metadata = state_info["info"]["observation_metadata"] + + start_time = time.time() + action_str = get_action_description( + action, + observation_metadata, + action_set_tag = self.args.action_set_tag, + prompt_constructor=self.agent.prompt_constructor + if isinstance(self.agent, PromptAgent) + else None, + ) + end_time = time.time() + if self.print_time: + INFO(f'step {i_step} get action str time = {round(end_time - start_time, 2)}') + + meta_data = self.handle_meta_af(meta_data, action, action_str) + + if render_helper: + start_time = time.time() + render_helper.render(action, state_info, meta_data, self.args.render_screenshot) + end_time = time.time() + if self.print_time: + INFO(f'step {i_step} render time = {round(end_time - start_time, 2)}') + + if action["action_type"] == ActionTypes.STOP: break + + start_time = time.time() + try: + obs, _, terminated, _, info = self.env.step(action) + state_info = {"observation": obs, "info": info} + trajectory.append(state_info) + u.wait(1) + except Exception as e: + ERROR(f'execution error {e}') + trajectory.append(create_stop_action("")) + break + + end_time = time.time() + if self.print_time: + INFO(f'step {i_step} action exe time = {round(end_time - start_time, 2)}') + + if terminated: + DEBUG(terminated) + # add a action place holder + trajectory.append(create_stop_action("")) + break + + step_end_time = time.time() + if self.print_time: + INFO(f'step {i_step} total time = {round(step_end_time - step_start_time, 2)}') + + return trajectory + + def __find_last_ob(self, lst): + result = None + i = len(lst) + for d in reversed(lst): + i -= 1 + for key in d.keys(): + if "observation" in key: + result = d + break + if result is not None: + break + return i + + def __reset(self): + if self.args.domain == 'reddit': + u.execute(f'bash {u.get_git()}/dataset/visualwebarena/scripts/reset_reddit.sh') + elif self.args.domain == 'classifieds': + u.execute(f'curl -X POST http://localhost:9980/index.php?page=reset -d "token=4b61655535e7ed388f0d40a93600254c"') + elif self.args.domain == 'shopping': + u.execute(f'bash {u.get_git()}/dataset/visualwebarena/scripts/reset_shopping.sh') + + def test(self): + for config_file in tqdm(self.args.test_config_files, multiprocessing.current_process().name): + _c = u.read_json(config_file) + sites = _c['sites'] + if self.args.single_site_mode and len(sites) != 1: + WARN(f'{u.get_name(config_file)} is multi sites task: {sites}') + continue + intent = _c["intent"] + task_id = _c["task_id"] + result_file = f'{self.ph.metrics_path}/{task_id}.json' + + if u.is_file_exist(result_file) and not self.args.flush: + INFO('skip') + continue + if task_id < self.args.test_start_idx or task_id > self.args.test_end_idx: continue + + require_reset = _c['require_reset'] + if require_reset: self.__reset() + + u.write_json(f'{self.ph.config_dir}/{task_id}.json', vars(self.args)) + render_file = f'{self.ph.render_dir}/{task_id}.html' + cache_config_file = self.auto_loging(_c, config_file) + image_paths = _c.get("image", None) + images = self.load_input_image(image_paths) + + traj_file = f'{self.ph.traj_path}/{task_id}.json' + trajectory = [] + + render_helper = None + if self.args.render: + render_helper = RenderHelper(_c, render_file, self.args.action_set_tag, images) + + i_retry = 0 + while i_retry < self.args.max_retry: + trajectory = self.rollout(cache_config_file, intent, images, render_helper) + if trajectory: break + i_retry += 1 + if i_retry == self.args.max_retry: + return + + last_ob_idx = self.__find_last_ob(trajectory) + for i in range(len(trajectory)): + step = trajectory[i] + if i == last_ob_idx: + last_page = trajectory[i]['info']['page'] + trajectory[i]['info']['page'] = {'url': last_page.url, 'content': last_page.content} + if 'observation' in step.keys() and i != last_ob_idx: + del trajectory[i]['info']['page'] + if 'observation' in step.keys(): + del trajectory[i]['observation'] + if 'coords' in step.keys(): + trajectory[i]['coords'] = trajectory[i]['coords'].tolist() + u.write_json(traj_file, trajectory) + + last_page = self.env.page + if render_helper: render_helper.close() + + eval_types = _c["eval"]["eval_types"] + evaluator = evaluator_router( + eval_types, + self.kq_config, + captioning_fn = self.eval_caption_image_fn + ) + start_time = time.time() + score = evaluator( + trajectory=trajectory, + config_file=config_file, + page=last_page + ) + end_time = time.time() + + if self.args.save_trace_enabled: + traj_file = f'{self.ph.traj_path}/{task_id}.zip' + self.env.save_trace(traj_file) + + if self.args.render_fail_only and score: u.execute(f'rm {render_file}') + if self.print_time: + INFO(f'eval time = {round(end_time - start_time, 2)}, result = {score}') + + u.write_json(result_file, {task_id: score}) + + self.env.close()