Skip to content

Commit ad51995

Browse files
bhackpytorchmergebot
authored andcommitted
Add a nightly hotpatch utils for python only PR (pytorch#136535)
I think this could help many teams, especially compile/export teams (/cc @ezyang), to let end user/bug reporters to quickly test WIP PR when reporting a related bug. This could quickly run in an official nightly Docker container or in a nightly venv/coda env. Let me know what do you think. Pull Request resolved: pytorch#136535 Approved by: https://github.com/ezyang
1 parent 9d72f74 commit ad51995

File tree

1 file changed

+218
-0
lines changed

1 file changed

+218
-0
lines changed

tools/nightly_hotpatch.py

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
#!/usr/bin/env python3
2+
3+
import argparse
4+
import os
5+
import shutil
6+
import subprocess
7+
import sys
8+
import tempfile
9+
import urllib.request
10+
from typing import cast, List, NoReturn, Optional
11+
12+
13+
def parse_arguments() -> argparse.Namespace:
14+
"""
15+
Parses command-line arguments using argparse.
16+
17+
Returns:
18+
argparse.Namespace: The parsed arguments containing the PR number, optional target directory, and strip count.
19+
"""
20+
parser = argparse.ArgumentParser(
21+
description=(
22+
"Download and apply a Pull Request (PR) patch from the PyTorch GitHub repository "
23+
"to your local PyTorch installation.\n\n"
24+
"Best Practice: Since this script involves hot-patching PyTorch, it's recommended to use "
25+
"a disposable environment like a Docker container or a dedicated Python virtual environment (venv). "
26+
"This ensures that if the patching fails, you can easily recover by resetting the environment."
27+
),
28+
epilog=(
29+
"Example:\n"
30+
" python nightly_hotpatch.py 12345\n"
31+
" python nightly_hotpatch.py 12345 --directory /path/to/pytorch --strip 1\n\n"
32+
"These commands will download the patch for PR #12345 and apply it to your local "
33+
"PyTorch installation."
34+
),
35+
formatter_class=argparse.RawDescriptionHelpFormatter,
36+
)
37+
38+
parser.add_argument(
39+
"PR_NUMBER",
40+
type=int,
41+
help="The number of the Pull Request (PR) from the PyTorch GitHub repository to download and apply as a patch.",
42+
)
43+
44+
parser.add_argument(
45+
"--directory",
46+
"-d",
47+
type=str,
48+
default=None,
49+
help="Optional. Specify the target directory to apply the patch. "
50+
"If not provided, the script will use the PyTorch installation path.",
51+
)
52+
53+
parser.add_argument(
54+
"--strip",
55+
"-p",
56+
type=int,
57+
default=1,
58+
help="Optional. Specify the strip count to remove leading directories from file paths in the patch. Default is 1.",
59+
)
60+
61+
return parser.parse_args()
62+
63+
64+
def get_pytorch_path() -> str:
65+
"""
66+
Retrieves the installation path of PyTorch in the current environment.
67+
68+
Returns:
69+
str: The directory of the PyTorch installation.
70+
71+
Exits:
72+
If PyTorch is not installed in the current Python environment, the script will exit.
73+
"""
74+
try:
75+
import torch
76+
77+
torch_paths: List[str] = cast(List[str], torch.__path__)
78+
torch_path: str = torch_paths[0]
79+
parent_path: str = os.path.dirname(torch_path)
80+
print(f"PyTorch is installed at: {torch_path}")
81+
print(f"Parent directory for patching: {parent_path}")
82+
return parent_path
83+
except ImportError:
84+
handle_import_error()
85+
86+
87+
def handle_import_error() -> NoReturn:
88+
"""
89+
Handle the case where PyTorch is not installed and exit the program.
90+
91+
Exits:
92+
NoReturn: This function will terminate the program.
93+
"""
94+
print("Error: PyTorch is not installed in the current Python environment.")
95+
sys.exit(1)
96+
97+
98+
def download_patch(pr_number: int, repo_url: str, download_dir: str) -> str:
99+
"""
100+
Downloads the patch file for a given PR from the specified GitHub repository.
101+
102+
Args:
103+
pr_number (int): The pull request number.
104+
repo_url (str): The URL of the repository where the PR is hosted.
105+
download_dir (str): The directory to store the downloaded patch.
106+
107+
Returns:
108+
str: The path to the downloaded patch file.
109+
110+
Exits:
111+
If the download fails, the script will exit.
112+
"""
113+
patch_url = f"{repo_url}/pull/{pr_number}.diff"
114+
patch_file = os.path.join(download_dir, f"pr-{pr_number}.patch")
115+
print(f"Downloading PR #{pr_number} patch from {patch_url}...")
116+
try:
117+
with urllib.request.urlopen(patch_url) as response, open(
118+
patch_file, "wb"
119+
) as out_file:
120+
shutil.copyfileobj(response, out_file)
121+
if not os.path.isfile(patch_file):
122+
print(f"Failed to download patch for PR #{pr_number}")
123+
sys.exit(1)
124+
print(f"Patch downloaded to {patch_file}")
125+
return patch_file
126+
except urllib.error.HTTPError as e:
127+
print(f"HTTP Error: {e.code} when downloading patch for PR #{pr_number}")
128+
sys.exit(1)
129+
except Exception as e:
130+
print(f"An error occurred while downloading the patch: {e}")
131+
sys.exit(1)
132+
133+
134+
def apply_patch(patch_file: str, target_dir: Optional[str], strip_count: int) -> None:
135+
"""
136+
Applies the downloaded patch to the specified directory using the given strip count.
137+
138+
Args:
139+
patch_file (str): The path to the patch file.
140+
target_dir (Optional[str]): The directory to apply the patch to. If None, uses PyTorch installation path.
141+
strip_count (int): The number of leading directories to strip from file paths in the patch.
142+
143+
Exits:
144+
If the patch command fails or the 'patch' utility is not available, the script will exit.
145+
"""
146+
if target_dir:
147+
print(f"Applying patch in directory: {target_dir}")
148+
else:
149+
print("No target directory specified. Using PyTorch installation path.")
150+
151+
print(f"Applying patch with strip count: {strip_count}")
152+
try:
153+
# Construct the patch command with -d and -p options
154+
patch_command = ["patch", f"-p{strip_count}", "-i", patch_file]
155+
156+
if target_dir:
157+
patch_command.insert(
158+
1, f"-d{target_dir}"
159+
) # Insert -d option right after 'patch'
160+
print(f"Running command: {' '.join(patch_command)}")
161+
result = subprocess.run(patch_command, capture_output=True, text=True)
162+
else:
163+
patch_command.insert(1, f"-d{target_dir}")
164+
print(f"Running command: {' '.join(patch_command)}")
165+
result = subprocess.run(patch_command, capture_output=True, text=True)
166+
167+
# Check if the patch was applied successfully
168+
if result.returncode != 0:
169+
print("Failed to apply patch.")
170+
print("Patch output:")
171+
print(result.stdout)
172+
print(result.stderr)
173+
sys.exit(1)
174+
else:
175+
print("Patch applied successfully.")
176+
except FileNotFoundError:
177+
print("Error: The 'patch' utility is not installed or not found in PATH.")
178+
sys.exit(1)
179+
except Exception as e:
180+
print(f"An error occurred while applying the patch: {e}")
181+
sys.exit(1)
182+
183+
184+
def main() -> None:
185+
"""
186+
Main function to orchestrate the patch download and application process.
187+
188+
Steps:
189+
1. Parse command-line arguments to get the PR number, optional target directory, and strip count.
190+
2. Retrieve the local PyTorch installation path or use the provided target directory.
191+
3. Download the patch for the provided PR number.
192+
4. Apply the patch to the specified directory with the given strip count.
193+
"""
194+
args = parse_arguments()
195+
pr_number = args.PR_NUMBER
196+
custom_target_dir = args.directory
197+
strip_count = args.strip
198+
199+
if custom_target_dir:
200+
if not os.path.isdir(custom_target_dir):
201+
print(
202+
f"Error: The specified target directory '{custom_target_dir}' does not exist."
203+
)
204+
sys.exit(1)
205+
target_dir = custom_target_dir
206+
print(f"Using custom target directory: {target_dir}")
207+
else:
208+
target_dir = get_pytorch_path()
209+
210+
repo_url = "https://github.com/pytorch/pytorch"
211+
212+
with tempfile.TemporaryDirectory() as tmpdirname:
213+
patch_file = download_patch(pr_number, repo_url, tmpdirname)
214+
apply_patch(patch_file, target_dir, strip_count)
215+
216+
217+
if __name__ == "__main__":
218+
main()

0 commit comments

Comments
 (0)