diff --git a/src/codegen/extensions/langchain/tools.py b/src/codegen/extensions/langchain/tools.py index 0749384a4..9cfb0ca34 100644 --- a/src/codegen/extensions/langchain/tools.py +++ b/src/codegen/extensions/langchain/tools.py @@ -13,6 +13,7 @@ from codegen.extensions.linear.linear_client import LinearClient from codegen.extensions.tools.bash import run_bash_command from codegen.extensions.tools.github.checkout_pr import checkout_pr +from codegen.extensions.tools.github.edit_pr import edit_pr from codegen.extensions.tools.github.view_pr_checks import view_pr_checks from codegen.extensions.tools.global_replacement_edit import replacement_edit_global from codegen.extensions.tools.linear.linear import ( @@ -514,6 +515,46 @@ def _run(self, title: str, body: str) -> str: return result.render() +class GithubEditPRInput(BaseModel): + """Input for editing a PR.""" + + pr_number: int = Field(..., description="The PR number to edit") + title: Optional[str] = Field(None, description="The new title for the PR (optional)") + body: Optional[str] = Field(None, description="The new body/description for the PR (optional)") + state: Optional[str] = Field( + None, + description="The new state for the PR (optional, can be 'open', 'closed', 'draft', or 'ready_for_review')" + ) + + +class GithubEditPRTool(BaseTool): + """Tool for editing a PR's title, body, and/or state.""" + + name: ClassVar[str] = "edit_pr" + description: ClassVar[str] = "Edit a PR's title and/or body and/or state. The (optional) state parameter can be 'open', 'closed', 'draft', or 'ready_for_review'." + args_schema: ClassVar[type[BaseModel]] = GithubEditPRInput + codebase: Codebase = Field(exclude=True) + + def __init__(self, codebase: Codebase) -> None: + super().__init__(codebase=codebase) + + def _run( + self, + pr_number: int, + title: Optional[str] = None, + body: Optional[str] = None, + state: Optional[str] = None, + ) -> str: + result = edit_pr( + self.codebase, + pr_number=pr_number, + title=title, + body=body, + state=state, + ) + return result.render() + + class GithubSearchIssuesInput(BaseModel): """Input for searching GitHub issues.""" @@ -656,7 +697,7 @@ class GithubViewPRCheckTool(BaseTool): name: ClassVar[str] = "view_pr_checks" description: ClassVar[str] = "View the check suites for a PR" - args_schema: ClassVar[type[BaseModel]] = GithubCreatePRReviewCommentInput + args_schema: ClassVar[type[BaseModel]] = GithubViewPRCheckInput codebase: Codebase = Field(exclude=True) def __init__(self, codebase: Codebase) -> None: @@ -875,9 +916,12 @@ def get_workspace_tools(codebase: Codebase) -> list["BaseTool"]: ReflectionTool(codebase), # Github GithubCreatePRTool(codebase), + GithubEditPRTool(codebase), GithubCreatePRCommentTool(codebase), GithubCreatePRReviewCommentTool(codebase), GithubViewPRTool(codebase), + GithubViewPRCheckTool(codebase), + GithubCheckoutPRTool(codebase), GithubSearchIssuesTool(codebase), # Linear LinearGetIssueTool(codebase), diff --git a/src/codegen/extensions/tools/github/edit_pr.py b/src/codegen/extensions/tools/github/edit_pr.py new file mode 100644 index 000000000..03c90505c --- /dev/null +++ b/src/codegen/extensions/tools/github/edit_pr.py @@ -0,0 +1,94 @@ +"""Tool for editing a PR's title, body, and/or state.""" + +from typing import TYPE_CHECKING, Optional + +from codegen.extensions.tools.observation import Observation +from codegen.sdk.core.codebase import Codebase + +if TYPE_CHECKING: + from github.PullRequest import PullRequest + + +def edit_pr( + codebase: Codebase, + pr_number: int, + title: Optional[str] = None, + body: Optional[str] = None, + state: Optional[str] = None, +) -> Observation: + """Edit a PR's title, body, and/or state. + + Args: + codebase: The codebase to operate on + pr_number: The PR number to edit + title: The new title for the PR (optional) + body: The new body/description for the PR (optional) + state: The new state for the PR (optional, can be 'open', 'closed', 'draft', or 'ready_for_review') + + Returns: + Observation with the result of the operation + """ + repo = codebase.git_client.get_repo() + if not repo: + return Observation( + success=False, + message=f"Failed to get repository for PR #{pr_number}", + ) + + try: + pr: PullRequest = repo.get_pull(pr_number) + except Exception as e: + return Observation( + success=False, + message=f"Failed to get PR #{pr_number}: {e}", + ) + + # Track what was updated + updates = [] + + # Update title if provided + if title is not None: + pr.edit(title=title) + updates.append("title") + + # Update body if provided + if body is not None: + pr.edit(body=body) + updates.append("body") + + # Update state if provided + if state is not None: + state = state.lower() + if state == "closed": + pr.edit(state="closed") + updates.append("state (closed)") + elif state == "open": + pr.edit(state="open") + updates.append("state (opened)") + elif state == "draft": + pr.as_draft() + updates.append("state (converted to draft)") + elif state == "ready_for_review": + pr.ready_for_review() + updates.append("state (marked ready for review)") + else: + return Observation( + success=False, + message=f"Invalid state '{state}'. Must be one of: 'open', 'closed', 'draft', or 'ready_for_review'", + ) + + if not updates: + return Observation( + success=True, + message=f"No changes were made to PR #{pr_number}. Please provide at least one of: title, body, or state.", + ) + + return Observation( + success=True, + message=( + f"Successfully updated PR #{pr_number} ({', '.join(updates)}). " + "Note that this tool only updates PR metadata and does not push code changes to the PR branch. " + "To add code changes to a PR, make your edits and then use the `create_pr` tool while on the PR branch." + ), + url=pr.html_url, + )