diff --git a/tools/patch/patchv3.py b/tools/patch/patchv3.py
new file mode 100644
index 000000000..0ac8e2a1a
--- /dev/null
+++ b/tools/patch/patchv3.py
@@ -0,0 +1,589 @@
+# -*- coding: utf-8 -*-
+import argparse
+import copy
+import os
+import re
+import shutil
+import tempfile
+
+import git
+import yaml
+
+from encryption_utils import encrypt_file
+from git.repo import Repo
+from git_utils import (
+    check_git_user_info,
+    get_diff_between_commit_and_now,
+    get_file_statuses_for_staged_or_unstaged,
+    get_file_statuses_for_untracked,
+    get_submodule_commit,
+)
+from logger_utils import get_patch_logger
+
+FLAGSCALE_BACKEND = "FlagScale"
+logger = get_patch_logger()
+
+
+def generate_and_save_patch(sub_repo, base_commit, file_path, status, src_dir):
+    patch_content = ""
+    try:
+        if status in ['A', 'UT']:
+            patch_content = sub_repo.git.diff('--no-index', '/dev/null', file_path)
+        elif status in ['M', 'T', 'D']:
+            patch_content = sub_repo.git.diff(base_commit, '--', file_path)
+    except git.exc.GitCommandError as e:
+        # `git diff --no-index` exits with status 1 when the files differ; treat that as success.
+        if e.status == 1:
+            raw_output = str(e.stdout)
+            start_marker = "diff --git"
+            start_index = raw_output.find(start_marker)
+            end_index = raw_output.rfind("'")
+            patch_content = raw_output[start_index:end_index]
+        else:
+            raise e
+
+    if patch_content:
+        target_patch_path = os.path.join(src_dir, file_path + ".patch")
+        os.makedirs(os.path.dirname(target_patch_path), exist_ok=True)
+
+        with open(target_patch_path, 'w', encoding='utf-8') as f:
+            content = patch_content
+            if content and not content.endswith('\n'):
+                content += '\n'
+            f.write(content)
+        logger.info(f"Generated patch for '{file_path}' (Status: {status})")
+    else:
+        logger.warning(f"No patch content generated for '{file_path}' (Status: {status})")
+
+
+def patch(main_path, submodule_name, src, dst):
+    """
+    Sync the submodule modifications to the corresponding backend in FlagScale.
+    Args:
+        main_path (str): The path to the repository.
+        submodule_name (str): The name of the submodule to be patched, e.g., "Megatron-LM".
+        src (str): The source directory of the submodule, e.g., "flagscale/backends/Megatron-LM".
+        dst (str): The destination directory of the submodule, e.g., "third_party/Megatron-LM".
+    """
+
+    submodule_path = os.path.join("third_party", submodule_name)
+    logger.info(f"Patching backend {submodule_path}...")
+
+    # Get the submodule repo and the commit in FlagScale.
+    main_repo = Repo(main_path)
+    submodule = main_repo.submodule(submodule_path)
+    sub_repo = submodule.module()
+    # Get the submodule commit recorded at FlagScale HEAD instead of the submodule HEAD.
+    submodule_commit_in_fs = main_repo.head.commit.tree[submodule_path].hexsha
+    logger.info(f"Base commit hash of submodule {submodule_path} is {submodule_commit_in_fs}.")
+
+    # Get all differences between the specified submodule commit and now.
+    # The differences include staged, working directory, and untracked files.
+    staged_diff, unstaged_diff, untracked_files = get_diff_between_commit_and_now(
+        sub_repo, submodule_commit_in_fs
+    )
+
+    file_statuses = {}
+    # Process the diff between the staged area and the base commit
+    staged_file_statuses = get_file_statuses_for_staged_or_unstaged(staged_diff)
+    file_statuses.update(staged_file_statuses)
+    # Process the diff between the working directory and the staged area
+    unstaged_file_statuses = get_file_statuses_for_staged_or_unstaged(unstaged_diff)
+    file_statuses.update(unstaged_file_statuses)
+    # Process the untracked files
+    untracked_file_statuses = get_file_statuses_for_untracked(untracked_files)
+    file_statuses.update(untracked_file_statuses)
+
+    try:
+        # Create a temporary backup of the existing patch directory.
+        if os.path.exists(src):
+            temp_path = tempfile.mkdtemp()
+            shutil.copytree(src, temp_path, dirs_exist_ok=True)
+            logger.info(f"Created a temporary backup of '{src}' at '{temp_path}'")
+
+        logger.info(f"Cleaning up old patch directory: {src}")
+        shutil.rmtree(src, ignore_errors=True)
+        os.makedirs(src)
+
+        if not file_statuses:
+            logger.info("No file changes detected. Nothing to patch.")
+        else:
+            logger.info(f"Found {len(file_statuses)} file change(s). Generating patches...")
+            for file_path, status_info in file_statuses.items():
+                status = status_info[0]
+                generate_and_save_patch(sub_repo, submodule_commit_in_fs, file_path, status, src)
+            logger.info("Patch generation completed successfully!")
+
+    except Exception as e:
+        logger.error(f"An error occurred during patch generation: {e}", exc_info=True)
+        # Restore the backup if one was taken.
+        if "temp_path" in locals():
+            shutil.rmtree(src, ignore_errors=True)
+            shutil.copytree(temp_path, src, dirs_exist_ok=True)
+
+    finally:
+        if "temp_path" in locals() and os.path.exists(temp_path):
+            logger.info(f"Cleaning up temp path: {temp_path}")
+            shutil.rmtree(temp_path, ignore_errors=True)
+
+
+def patch_hardware(main_path, commit, backends, device_type, tasks, key_path=None):
+    assert commit is not None, "The commit hash must be specified for hardware patch."
+    assert backends is not None, "The backends must be specified for hardware patch."
+    assert device_type is not None, "The device type must be specified for hardware patch."
+    assert tasks is not None, "The tasks must be specified for hardware patch."
+
+    # ============================ MODIFICATION 1 =================================
+    patch_info = prompt_info(main_path, backends, device_type, tasks, flagscale_base_commit=commit)
+    generate_patch_file(main_path, commit, patch_info, key_path=key_path)
+
+
+def prompt_info(main_path, backends, device_type, tasks, flagscale_base_commit=None):
+    logger.info("Prompting for patch information: ")
+
+    backends_version = {}
+    logger.info("1. Please enter backends version: ")
+    for backend in backends:
+        version = input(f" {backend} version: ").strip()
+        while not version:
+            logger.info(f"Version for {backend} cannot be empty.")
+            version = input(f" {backend} version: ").strip()
+        backends_version[backend] = version
+
+    backends_commit = {}
+
+    # ============================== MODIFICATION 2 =================================
+    logger.info(f"2. Resolving backend commits from FlagScale base commit: {flagscale_base_commit}...")
+
+    main_repo = Repo(main_path)
+    if not flagscale_base_commit:
+        raise ValueError("FlagScale base commit is required to resolve submodule versions.")
+
+    try:
+        base_commit_obj = main_repo.commit(flagscale_base_commit)
+    except Exception:
+        raise ValueError(f"Commit {flagscale_base_commit} not found in FlagScale repository.")
+
+    for backend in backends:
+        if backend == FLAGSCALE_BACKEND:
+            continue
+
+        submodule_path = os.path.join("third_party", backend)
+        try:
+            original_submodule_commit = base_commit_obj.tree[submodule_path].hexsha
+            backends_commit[backend] = original_submodule_commit
+            logger.info(f" -> {backend}: Auto-detected official commit {original_submodule_commit}")
+        except KeyError:
+            logger.error(f" -> {backend}: Submodule path not found in commit {flagscale_base_commit}")
+            backends_commit[backend] = None
+        except Exception as e:
+            logger.error(f" -> {backend}: Failed to resolve commit. Error: {e}")
+            backends_commit[backend] = None
+    # ============================ MODIFICATION 2 END ===============================
+
+    # FlagScale-compatible models
+    model_input = input(
+        "3. Please enter FlagScale-compatible models (separated with commas): "
+    ).strip()
+    models = [m.strip() for m in model_input.split(",") if m.strip()]
+    while not models:
+        logger.info("At least one FlagScale-compatible model must be provided.")
+        model_input = input(
+            "3. Please enter FlagScale-compatible models (separated with commas): "
+        ).strip()
+        models = [m.strip() for m in model_input.split(",") if m.strip()]
+
+    # 4. Commit message
+    commit_msg = input("4. Please enter commit message: ").strip()
+    while not commit_msg:
+        logger.info("Commit message cannot be empty.")
+        commit_msg = input("4. Please enter commit message: ").strip()
+
+    # 5. Contact (optional)
+    contact_prompt = "5. Please enter email (optional): "
+    contact = input(contact_prompt).strip()
+
+    return {
+        "task": tasks,
+        "backends_version": backends_version,
+        "device_type": device_type,
+        "models": models,
+        "contact": contact,
+        "commit_msg": commit_msg,
+        "backends_commit": backends_commit,
+    }
+
+
+def _generate_patch_file_for_backend(
+    main_path: str,
+    commit: str,
+    backend: str,
+    patch_info: dict,
+    key_path=None,
+    flagscale_commit=None,
+):
+    repo = Repo(main_path)
+    assert not repo.bare
+    try:
+        patch_dir = os.path.join(main_path, "hardware", patch_info["device_type"], backend)
+        repo_path = (
+            os.path.join(main_path, "third_party", backend)
+            if backend != FLAGSCALE_BACKEND
+            else main_path
+        )
+        repo = Repo(repo_path)
+        current_branch = repo.active_branch.name
+        if repo.bare:
+            raise Exception("Repository is bare. Cannot proceed.")
+
+        logger.info(f"Generating the patch file for {repo_path}:")
+        # Step 1: Stash everything, including untracked files, and create temp_branch_for_hardware_patch.
+        logger.info(
+            "Step 1: Stashing current changes (including untracked) and creating the temp_branch_for_hardware_patch..."
+        )
+        temp_branch = "temp_branch_for_hardware_patch"
+        repo.git.stash("push", "--include-untracked")
+        stash_pop = False
+        if temp_branch in repo.heads:
+            logger.info(
+                "Temporary branch 'temp_branch_for_hardware_patch' already exists, deleting..."
+            )
+            repo.git.branch("-D", temp_branch)
+        repo.git.checkout("-b", temp_branch)
+
+        # Step 2: Apply the stash on the temp_branch_for_hardware_patch and stage the changes.
+        logger.info(
+            "Step 2: Applying stashed changes and staging them on the temp_branch_for_hardware_patch..."
+        )
+        repo.git.stash("apply")
+        (
+            repo.git.add(all=True)
+            if backend != FLAGSCALE_BACKEND
+            else repo.git.add('--all', ':!third_party')
+        )
+
+        # Step 3: Commit with the message "Patch for {commit}".
+        logger.info("Step 3: Committing changes on the temp_branch_for_hardware_patch...")
+        repo.git.commit("--no-verify", "-m", f"Patch for {commit}")
+
+        # Step 4: Generate the diff between the base commit and HEAD, writing each file diff into a temp file.
+        logger.info(f"Step 4: Generating diff patch from {commit} to HEAD...")
+
+        # The diff excludes the submodules.
+        flagscale_diff_args = None
+        if backend == FLAGSCALE_BACKEND:
+            backends = copy.deepcopy(list(patch_info["backends_version"].keys()))
+            flagscale_diff_args = [commit, "HEAD", "--binary", "--ignore-submodules=all", "--"]
+            for item in backends:
+                if item != FLAGSCALE_BACKEND:
+                    backend_dir = os.path.join(main_path, "flagscale", "backends", item)
+                    flagscale_diff_args.append(f':(exclude){backend_dir}')
+
+        diff_data = (
+            repo.git.diff(commit, "HEAD", "--binary")
+            if backend != FLAGSCALE_BACKEND
+            else repo.git.diff(*flagscale_diff_args)
+        )
+        if not diff_data:
+            raise ValueError(f"No changes in backend {backend}.")
+
+        splits = re.split(r'(?=^diff --git a/)', diff_data, flags=re.MULTILINE)
+
+        temp_patch_files_with_relpath = []
+        for file_diff in splits:
+            if not file_diff.strip():
+                continue
+
+            m = re.match(r'diff --git a/.+ b/(.+)', file_diff.splitlines()[0])
+            if not m:
+                continue
+            filepath = m.group(1)  # relative path like megatron/train.py
+
+            temp_file = tempfile.NamedTemporaryFile(
+                delete=False, mode="w", encoding="utf-8", suffix=".patch"
+            )
+            temp_file.write(file_diff)
+            temp_file.write("\n")
+            temp_file.flush()
+            temp_file.close()
+
+            temp_patch_files_with_relpath.append((temp_file.name, filepath))
+
+        if key_path is not None:
+            encrypted_files_with_relpath = []
+            for patch_path, rel_path in temp_patch_files_with_relpath:
+                encrypted_path = encrypt_file(patch_path, key_path)
+                encrypted_files_with_relpath.append((encrypted_path, rel_path))
+            temp_patch_files_with_relpath = encrypted_files_with_relpath
+
+        temp_yaml_file = tempfile.NamedTemporaryFile(
+            delete=False, mode="w", encoding="utf-8", suffix=".yaml"
+        )
+        temp_yaml_path = temp_yaml_file.name
+        data = copy.deepcopy(patch_info)
+        assert flagscale_commit is not None, "FlagScale commit must be specified."
+        data["commit"] = flagscale_commit
+        if "commit_msg" in data:
+            del data["commit_msg"]
+        yaml.dump(data, temp_yaml_file, sort_keys=True, allow_unicode=True)
+        temp_yaml_file.flush()
+
+        # Step 5: Check out the original branch and pop the stash.
+        logger.info("Step 5: Checking out the original branch and popping the stash...")
+        repo.git.checkout(current_branch)
+        repo.git.stash("pop")
+        stash_pop = True
+
+    except Exception as e:
+        logger.error(f"{e}", exc_info=True)
+        logger.info("Rolling back to the original branch...")
+        repo.git.checkout(current_branch)
+        if "stash_pop" in locals() and not stash_pop:
+            repo.git.stash("pop")
+            stash_pop = True
+
+    finally:
+        try:
+            # Clean up the temporary branch
+            if "temp_branch" in locals() and temp_branch in repo.heads:
+                repo.git.branch("-D", temp_branch)
+
+        except Exception as cleanup_error:
+            logger.error(f"Failed to roll back: {cleanup_error}", exc_info=True)
+            raise cleanup_error
+
+    return patch_dir, temp_patch_files_with_relpath, temp_yaml_path
+
+
+def generate_patch_file(main_path: str, commit: str, patch_info: dict, key_path=None):
+    """
+    Generate the hardware patch: create the per-backend patch files, check out the
+    target commit, stage the patch files and diff.yaml under
+    hardware/<device_type>/<backend>, and commit them.
+    """
+    repo = Repo(main_path)
+    assert not repo.bare
+    temp_patch_files = []
+
+    try:
+        backends = copy.deepcopy(list(patch_info["backends_version"].keys()))
+        patches = {}
+        for backend in backends:
+            # Generate the patch file for each backend.
+            if backend != FLAGSCALE_BACKEND:
+                # Get the submodule repo and the commit in FlagScale.
+                main_repo = Repo(main_path)
+                submodule_path = os.path.join("third_party", backend)
+                submodule = main_repo.submodule(submodule_path)
+
+                # ============================ MODIFICATION 3 ===========================
+                submodule_commit_in_fs = submodule.module().head.commit.hexsha
+                patch_dir, patch_files_with_relpath, temp_yaml_path = (
+                    _generate_patch_file_for_backend(
+                        main_path,
+                        submodule_commit_in_fs,
+                        backend,
+                        patch_info,
+                        key_path=key_path,
+                        flagscale_commit=commit,
+                    )
+                )
+                # ============================ MODIFICATION 3 END =======================
+            else:
+                patch_dir, patch_files_with_relpath, temp_yaml_path = (
+                    _generate_patch_file_for_backend(
+                        main_path,
+                        commit,
+                        backend,
+                        patch_info,
+                        key_path=key_path,
+                        flagscale_commit=commit,
+                    )
+                )
+            patches[backend] = [patch_dir, patch_files_with_relpath, temp_yaml_path]
+
+        logger.info(f"Checking out {commit}...")
+        repo.git.checkout(commit)
+
+        logger.info("Staging the generated patch files...")
+        patch_dir_need_to_clean = []
+        for backend in patches:
+            patch_dir, patch_files_with_relpath, temp_yaml_path = patches[backend]
+            patch_dir_exist = os.path.exists(patch_dir)
+            if not patch_dir_exist:
+                patch_dir_need_to_clean.append(patch_dir)
+
+            # Remove the old patch directory for the backend.
+            if os.path.exists(patch_dir):
+                shutil.rmtree(patch_dir)
+                repo.git.rm('-r', patch_dir, ignore_unmatch=True)
+            os.makedirs(patch_dir, exist_ok=True)
+
+            # Copy each patch file, preserving its relative path inside patch_dir.
+            for temp_patch_path, rel_path in patch_files_with_relpath:
+                dst_patch_path = (
+                    os.path.join(patch_dir, f"{rel_path}.patch")
+                    if key_path is None
+                    else os.path.join(patch_dir, f"{rel_path}.patch.encrypted")
+                )
+                os.makedirs(os.path.dirname(dst_patch_path), exist_ok=True)
+                shutil.copy(temp_patch_path, dst_patch_path)
+                repo.git.add(dst_patch_path)
+                temp_patch_files.append(temp_patch_path)
+
+            # Copy the yaml file.
+            yaml_file_name = "diff.yaml"
+            dst_yaml_path = os.path.join(patch_dir, yaml_file_name)
+            shutil.copy(temp_yaml_path, dst_yaml_path)
+            repo.git.add(dst_yaml_path)
+            temp_patch_files.append(temp_yaml_path)
+
+        # Commit the patch files with the user-provided message.
+        logger.info("Committing the patch file...")
+        commit_msg = patch_info["commit_msg"]
+        repo.git.commit("--no-verify", "-m", commit_msg)
+
+        logger.info(
+            "Committed successfully! If you want to push, try 'git push origin HEAD:refs/heads/<your branch>' or 'git push --force origin HEAD:refs/heads/<your branch>'"
+        )
+
+    except Exception as e:
+        logger.error(f"{e}", exc_info=True)
+
+    finally:
+        try:
+            # Clean up the temp files.
+            if "temp_patch_files" in locals():
+                for temp_patch_path in temp_patch_files:
+                    if os.path.exists(temp_patch_path):
+                        os.remove(temp_patch_path)
+                        logger.debug(f"Temporary patch file {temp_patch_path} deleted.")
+
+            # Clean up the patch dirs that did not exist before.
+            if "patch_dir_need_to_clean" in locals():
+                for patch_dir in patch_dir_need_to_clean:
+                    if os.path.exists(patch_dir):
+                        shutil.rmtree(patch_dir)
+                        logger.debug(f"Temporary patch dir {patch_dir} deleted.")
+
+        except Exception as cleanup_error:
+            logger.error(f"Failed to delete temporary files: {cleanup_error}", exc_info=True)
+            raise cleanup_error
+
+
+def validate_patch_args(device_type, task, commit, main_path):
+    main_repo = Repo(main_path)
+    if commit:
+        # Check if the commit exists in FlagScale.
+        try:
+            main_repo.commit(commit)
+        except ValueError:
+            raise ValueError(f"Commit {commit} does not exist in FlagScale.")
+
+    if device_type:
+        if (
+            device_type.count("_") != 1
+            or len(device_type.split("_")) != 2
+            or not device_type.split("_")[0]
+            or not device_type.split("_")[0][0].isupper()
+        ):
+            raise ValueError("Device type is invalid!")
+
+    if commit or device_type or task:
+        assert (
+            commit and device_type and task
+        ), "The args commit, device_type and task must all be provided together."
+
+
+def normalize_backend(backend):
+    """
+    Normalize a backend name to its standard form.
+
+    Args:
+        backend (str): Backend name provided by the user.
+
+    Returns:
+        str: Standardized backend name.
+    """
+
+    input_lower = backend.lower()
+
+    if input_lower in ["megatron", "megatron-lm"]:
+        return "Megatron-LM"
+    elif input_lower in ["energon", "megatron-energon"]:
+        return "Megatron-Energon"
+    elif input_lower in ["fs", "flagscale"]:
+        return "FlagScale"
+    elif input_lower == "vllm":
+        return "vllm"
+    elif input_lower == "sglang":
+        return "sglang"
+    elif input_lower in ["llama.cpp", "llama_cpp"]:
+        return "llama.cpp"
+    elif input_lower in ["omniinfer", "omni_infer"]:
+        return "omniinfer"
+    elif input_lower in ["verl"]:
+        return "verl"
+
+    raise ValueError(f'Unsupported backend {backend}')
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Sync submodule modifications to the corresponding backend in FlagScale."
+    )
+    parser.add_argument(
+        "--backend",
+        nargs="+",
+        type=normalize_backend,
+        default=["Megatron-LM"],
+        help="Backend to patch (default: Megatron-LM)",
+    )
+    parser.add_argument(
+        "--commit", type=str, default=None, help="Patch based on this commit. Default is None."
+    )
+    parser.add_argument(
+        "--device-type", type=str, default=None, help="Device type. Default is None."
+    )
+    parser.add_argument("--task", nargs="+", default=None, help="Task. Default is None.")
+    parser.add_argument(
+        "--key-path",
+        type=str,
+        default=None,
+        help="The path for storing public and private keys. Be careful not to upload it to the Git repository.",
+    )
+
+    args = parser.parse_args()
+    backends = args.backend
+    commit = args.commit
+    tasks = args.task
+    device_type = args.device_type
+    key_path = args.key_path
+
+    if not isinstance(backends, list):
+        backends = [backends]
+
+    if tasks is not None and not isinstance(tasks, list):
+        tasks = [tasks]
+
+    # FlagScale/tools/patch
+    script_dir = os.path.dirname(os.path.realpath(__file__))
+    # FlagScale/tools
+    script_dir = os.path.dirname(script_dir)
+    # FlagScale
+    main_path = os.path.dirname(script_dir)
+
+    check_git_user_info(main_path)
+
+    validate_patch_args(device_type, tasks, commit, main_path)
+
+    if FLAGSCALE_BACKEND in backends:
+        assert commit is not None, "The FlagScale patch can only be generated as part of a hardware patch."
+
+    if commit:
+        # Hardware patch
+        patch_hardware(main_path, commit, backends, device_type, tasks, key_path=key_path)
+    else:
+        for backend in backends:
+            dst = os.path.join(main_path, "third_party", backend)
+            src = os.path.join(main_path, "flagscale", "backends", backend)
+            patch(main_path, backend, src, dst)
diff --git a/tools/patch/unpatchv3.py b/tools/patch/unpatchv3.py
new file mode 100644
index 000000000..e35e90f50
--- /dev/null
+++ b/tools/patch/unpatchv3.py
@@ -0,0 +1,499 @@
+import argparse
+import os
+import shutil
+import tempfile
+
+import git
+import yaml
+
+from encryption_utils import decrypt_file
+from git.repo import Repo
+from logger_utils import get_unpatch_logger
+from patch import normalize_backend
+
+FLAGSCALE_BACKEND = "FlagScale"
+logger = get_unpatch_logger()
+
+
+def apply_patches_from_directory(src_dir, dst_dir):
+    if not os.path.isdir(src_dir):
+        logger.warning(f"Patch directory '{src_dir}' does not exist. Nothing to apply.")
+        return
+
+    try:
+        repo = Repo(dst_dir)
+        for root, _, files in os.walk(src_dir):
+            for file in sorted(files):
+                if file.endswith(".patch"):
+                    patch_file_path = os.path.join(root, file)
+                    logger.info(f"Applying patch: {patch_file_path}")
+                    try:
+                        # Dry-run with --check first, then actually apply the patch.
+                        repo.git.apply(patch_file_path, check=True)
+                        repo.git.apply(patch_file_path)
+                    except git.exc.GitCommandError as e:
+                        logger.error(
+                            f"Failed to apply patch '{patch_file_path}'. Error: {e.stderr}"
+                        )
+                        raise e
+
+    except Exception as e:
+        logger.error(f"An error occurred while setting up patch application for '{dst_dir}': {e}")
+        raise e
+
+
+def unpatch(main_path, src, dst, submodule_name, force=False, backend_commit={}, fs_extension=True):
+    """Unpatch the backend with the stored patches."""
+    if submodule_name != FLAGSCALE_BACKEND:
+        logger.info(f"Unpatching backend {submodule_name}...")
+        submodule_commit = None
+        if backend_commit and backend_commit[submodule_name] is not None:
+            submodule_commit = backend_commit[submodule_name]
+        init_submodule(main_path, dst, submodule_name, force=force, commit=submodule_commit)
+        if fs_extension:
+            apply_patches_from_directory(src, dst)
+            logger.info(f"Successfully applied all patches to backend {submodule_name}")
+
+            # ============================ MODIFICATION 1 ======================================
+            try:
+                sub_repo = Repo(dst)
+                if sub_repo.is_dirty(untracked_files=True):
+                    logger.info(f"Committing FlagScale patches to submodule {submodule_name}...")
+                    sub_repo.git.add(all=True)
+                    sub_repo.git.commit('-m', f'Apply FlagScale base patches for {submodule_name}')
+                    new_sub_commit = sub_repo.head.commit.hexsha
+                    logger.info(f"Committed FlagScale patches in submodule. New commit: {new_sub_commit}")
+
+                    logger.info("Updating the main FlagScale repo to point to the new submodule commit...")
+                    main_repo = Repo(main_path)
+
+                    try:
+                        main_repo.config_reader().get_value('user', 'name')
+                        main_repo.config_reader().get_value('user', 'email')
+                    except Exception:
+                        logger.warning("Git user info not set. Using a default bot identity for this local commit.")
+                        with main_repo.config_writer() as writer:
+                            writer.set_value('user', 'name', 'flagscale-bot')
+                            writer.set_value('user', 'email', 'bot@flagscale.com')
+
+                    main_repo.git.add(dst)
+                    main_repo.git.commit('-m', f'Update {submodule_name} with FlagScale base patches')
+                    logger.info(f"Main repo commit created. HEAD is now at {main_repo.head.commit.hexsha}")
+
+                else:
+                    logger.info(f"No FlagScale patches to commit in submodule {submodule_name} (clean).")
+
+            except Exception as e:
+                logger.error(f"Failed to auto-commit submodule changes: {e}", exc_info=True)
+                raise e
+            # ==================================================================
+
+        else:
+            logger.info(
+                f"FlagScale extension for {submodule_name} is disabled, skipping unpatching..."
+            )
+
+
+def init_submodule(main_path, dst, submodule_name, force=False, commit=None):
+    if os.path.lexists(dst) and len(os.listdir(dst)) > 0 and not force:
+        logger.info(f"Skipping {submodule_name} initialization, as it already exists.")
+        return
+    logger.info(f"Initializing submodule {submodule_name}...")
+    logger.warning(
+        "When you perform unpatch, the specified submodule will be fully restored to its initial state, regardless of any modifications you may have made within the submodule."
+    )
+    repo = Repo(main_path)
+    submodule_name = os.path.join("third_party", submodule_name)
+    submodule = repo.submodule(submodule_name)
+    retry_times = 2
+    for _ in range(retry_times):
+        try:
+            git_modules_path = os.path.join(main_path, ".git", "modules", submodule_name)
+            if os.path.exists(git_modules_path):
+                shutil.rmtree(git_modules_path)
+            submodule_worktree_path = os.path.join(main_path, submodule_name)
+            if os.path.exists(submodule_worktree_path):
+                shutil.rmtree(submodule_worktree_path)
+            submodule.update(init=True, force=force)
+            if commit:
+                sub_repo = submodule.module()
+                sub_repo.git.reset('--hard', commit)
+                logger.info(f"Reset {submodule_name} to commit {commit}.")
+            logger.info(f"Initialized {submodule_name} submodule.")
+            break
+
+        except Exception as e:
+            logger.error(f"Exception occurred: {e}", exc_info=True)
+            logger.info(f"Retrying to initialize submodule {submodule_name}...")
+
+
+def commit_to_checkout(main_path, device_type=None, tasks=None, backends=None, commit=None):
+    if commit:
+        return commit
+
+    newest_flagscale_commit = None
+    main_repo = Repo(main_path)
+    if device_type and tasks:
+        # Check if device_type is in the format xxx_yyy.
+        if device_type.count("_") != 1 or len(device_type.split("_")) != 2:
+            raise ValueError("Invalid format. Device type must be in the format xxx_yyy.")
+
+        assert backends
+        history_yaml = os.path.join(main_path, "hardware", "patch_history.yaml")
+        if not os.path.exists(history_yaml):
+            logger.warning(
+                f"Yaml {history_yaml} does not exist. Please check the hardware/patch_history.yaml."
+            )
+            logger.warning("Trying to use the current commit to unpatch.")
+            return main_repo.head.commit.hexsha
+
+        # Backend key
+        backends_key = "+".join(sorted(backends))
+        # Newest FlagScale commit to checkout and unpatch
+        newest_flagscale_commit = None
+        with open(history_yaml, 'r') as f:
+            history = yaml.safe_load(f)
+            if device_type not in history:
+                logger.warning(f"Device type {device_type} not found in {history_yaml}.")
+                logger.warning("Trying to use the current commit to unpatch.")
+                return main_repo.head.commit.hexsha
+
+            # Find the newest FlagScale commit in the history.
+            for task in tasks:
+                if task not in history[device_type]:
+                    continue
+                if backends_key not in history[device_type][task]:
+                    continue
+                if (
+                    not isinstance(history[device_type][task][backends_key], list)
+                    or not history[device_type][task][backends_key]
+                ):
+                    continue
+                newest_flagscale_commit = history[device_type][task][backends_key][-1]
+                try:
+                    main_repo.commit(newest_flagscale_commit)
+                    break
+                except ValueError:
+                    raise ValueError(
+                        f"The commit ID {newest_flagscale_commit} does not exist in FlagScale. Please check the {history_yaml}"
+                    )
+
+        if not newest_flagscale_commit:
+            logger.warning(
+                f"No valid commit found for device type {device_type} and tasks {tasks} in {history_yaml}. Trying to use the current commit to unpatch."
+            )
+            return main_repo.head.commit.hexsha
+    return newest_flagscale_commit
+
+
+def apply_hardware_patch(
+    device_type, backends, commit, main_path, need_init_submodule, key_path=None
+):
+    build_path = os.path.join(main_path, "build", device_type)
+    final_path = os.path.join(build_path, os.path.basename(main_path))
+
+    try:
+        # Remove the existing build directory if present.
+        if os.path.exists(build_path):
+            logger.info(f"Removing existing build path: {build_path}")
+            shutil.rmtree(build_path)
+
+        temp_path = tempfile.mkdtemp()
+        logger.info(f"Step 1: Copying {main_path} to temp path {temp_path}")
+        shutil.copytree(main_path, temp_path, dirs_exist_ok=True)
+
+        repo = Repo(temp_path)
+        # Stash first to prevent the checkout from failing.
+        repo.git.stash("push", "--include-untracked")
+        logger.info(f"Step 2: Checking out {commit} in temp path {temp_path}")
+        repo.git.checkout(commit)
+
+        # Check the device path.
+        device_path = os.path.join(temp_path, "hardware", device_type)
+        if not os.path.exists(device_path):
+            raise ValueError(f"{device_path} is not found.")
+
+        # Check the backend paths and collect the patch files.
+        all_base_commit_id = set()
+        patch_files = []
+        patch_backends = []
+        backends_commit = {}
+
+        for backend in backends:
+            backend_path = os.path.join(device_path, backend)
+            if not os.path.exists(backend_path):
+                raise ValueError(f"{backend_path} is not found.")
+
+            yaml_file = os.path.join(backend_path, "diff.yaml")
+            if not os.path.isfile(yaml_file):
+                raise ValueError(f"Missing diff.yaml in {backend_path}")
+
+            with open(yaml_file, "r") as f:
+                info = yaml.safe_load(f)
+                base_commit_id = info["commit"]
+                if "backends_commit" in info and backend in info["backends_commit"]:
+                    backends_commit[backend] = info["backends_commit"][backend]
+                assert base_commit_id
+                all_base_commit_id.add(base_commit_id)
+
+            backend_patch_files = []
+            for root, _, files in os.walk(backend_path):
+                for file in files:
+                    if file.endswith(".patch") or file.endswith(".patch.encrypted"):
+                        backend_patch_files.append(os.path.join(root, file))
+            if not backend_patch_files:
+                raise ValueError(f"No patch files found in {backend_path}")
+
+            patch_files.extend(backend_patch_files)
+            patch_backends.extend([backend] * len(backend_patch_files))
+
+        all_base_commit_id = list(all_base_commit_id)
+
+        # Sort the commits by their order of appearance in the history.
+        position = {}
+        rev_list = repo.git.rev_list('--topo-order', 'HEAD').splitlines()
+        for idx, rev_commit in enumerate(rev_list):
+            if rev_commit in all_base_commit_id:
+                position[rev_commit] = idx
+
+        # Check if all commits were found.
+        missing = set(all_base_commit_id) - set(position.keys())
+        if missing:
+            raise ValueError(f"The following commits were not found in rev-list: {missing}")
+
+        sorted_commits = sorted(all_base_commit_id, key=lambda x: position[x])
+        # Get the newest base_commit_id.
+        base_commit_id = sorted_commits[0]
+        logger.info(f"Step 3: Finding the newest base commit {base_commit_id} to checkout.")
+
+        temp_unpatch_path = tempfile.mkdtemp()
+        logger.info(f"Step 4: Copying {temp_path} to temp unpatch path {temp_unpatch_path}")
+        shutil.copytree(temp_path, temp_unpatch_path, dirs_exist_ok=True)
+        repo = Repo(temp_unpatch_path)
+        repo.git.checkout(base_commit_id)
+
+        logger.info("Step 5: Applying patches:")
+        initialized_backends = set()
+        for idx, patch_file in enumerate(patch_files):
+            # Check if the patch file is encrypted.
+            new_patch_file = patch_file
+            if patch_file.endswith(".encrypted"):
+                if key_path is not None:
+                    private_key_path = os.path.join(key_path, "private_key.pem")
+                    new_patch_file = decrypt_file(patch_file, private_key_path)
+                else:
+                    raise ValueError(
+                        f"Patch file {patch_file} is encrypted, but no key path provided."
+                    )
+            backend = patch_backends[idx]
+            if backend != FLAGSCALE_BACKEND:
+                # Initialize the submodule if needed.
+                if need_init_submodule:
+                    if backend not in initialized_backends:
+                        logger.info(
+                            f" Initializing submodule {backend} in temp unpatch path {temp_unpatch_path}..."
+                        )
+                        dst = os.path.join(temp_unpatch_path, "third_party", backend)
+                        src = os.path.join(temp_unpatch_path, "flagscale", "backends", backend)
+
+                        submodule_commit = None
+                        if backends_commit and backend in backends_commit:
+                            submodule_commit = backends_commit[backend]
+                        init_submodule(
+                            temp_unpatch_path, dst, backend, force=True, commit=submodule_commit
+                        )
+                        initialized_backends.add(backend)
+            submodule_path = (
+                os.path.join(temp_unpatch_path, "third_party", backend)
+                if backend != FLAGSCALE_BACKEND
+                else temp_unpatch_path
+            )
+
+            repo = Repo(submodule_path)
+            try:
+                repo.git.apply("--whitespace", "warn", new_patch_file)
+            except Exception as e:
+                logger.warning(
+                    f"Failed to apply patch cleanly, and the error is {e.stderr}. Retrying with --whitespace=fix."
+                )
+                repo.git.apply("--whitespace", "fix", new_patch_file)
+            logger.info(f"Patch {new_patch_file} has been applied.")
+
+        logger.info(f"Step 6: Moving patched temp path {temp_unpatch_path} to {final_path}")
+        os.makedirs(build_path, exist_ok=True)
+        shutil.move(temp_unpatch_path, final_path)
+        logger.info("Unpatch finished.")
+
+    except Exception as e:
+        logger.error(f"Exception occurred: {e}", exc_info=True)
+
+        # Clean up the build directory.
+        if os.path.exists(build_path):
+            logger.info(f"Cleaning up build path: {build_path}")
+            shutil.rmtree(build_path, ignore_errors=True)
+
+        raise ValueError("Error occurred during unpatching.")
+
+    finally:
+        # Clean up the temp directories.
+        if "temp_path" in locals() and os.path.exists(temp_path):
+            logger.info(f"Cleaning up temp path: {temp_path}")
+            shutil.rmtree(temp_path, ignore_errors=True)
+
+        if "temp_unpatch_path" in locals() and os.path.exists(temp_unpatch_path):
+            logger.info(f"Cleaning up temp path: {temp_unpatch_path}")
+            shutil.rmtree(temp_unpatch_path, ignore_errors=True)
+
+    return final_path
+
+
+def validate_unpatch_args(device_type, tasks, commit, main_path):
+    main_repo = Repo(main_path)
+    if commit:
+        # Check if the commit exists in FlagScale.
+        try:
+            main_repo.commit(commit)
+        except ValueError:
+            raise ValueError(f"Commit {commit} does not exist in FlagScale.")
+    if device_type:
+        if (
+            device_type.count("_") != 1
+            or len(device_type.split("_")) != 2
+            or not device_type.split("_")[0]
+            or not device_type.split("_")[0][0].isupper()
+        ):
+            raise ValueError("Device type is invalid!")
+
+    if device_type or tasks:
+        assert device_type and tasks, "The args device_type and task must both be provided."
+
+
+def backend_commit_mapping(backends, backends_commit):
+    backend_commit = {}
+    for idx, backend in enumerate(backends):
+        if backend == FLAGSCALE_BACKEND:
+            assert backends_commit == [None], "FlagScale backend commit must be None."
+        else:
+            if idx >= len(backends_commit):
+                backend_commit[backend] = None
+            else:
+                backend_commit[backend] = backends_commit[idx]
+
+    return backend_commit
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Patch or unpatch backend with patch files.")
+    parser.add_argument(
+        "--backend",
+        nargs="+",
+        type=normalize_backend,
+        default=["Megatron-LM"],
+        help="Backend to unpatch (default: Megatron-LM)",
+    )
+    parser.add_argument(
+        "--device-type", type=str, default=None, help="Device type. Default is None."
+    )
+    parser.add_argument(
+        "--task",
+        nargs="+",
+        default=None,
+        choices=["train", "inference", "post_train"],
+        help="Task. Default is None.",
+    )
+    parser.add_argument(
+        "--commit", type=str, default=None, help="Unpatch based on this commit. Default is None."
+    )
+    parser.add_argument(
+        '--no-force', dest='force', action='store_false', help='Do not force update the backend.'
+    )
+    parser.add_argument(
+        "--no-init-submodule",
+        action="store_false",
+        dest="init_submodule",
+        help="Do not initialize and update submodules. Default is True.",
+    )
+    parser.add_argument(
+        "--key-path",
+        type=str,
+        default=None,
+        help="The path for storing public and private keys. Be careful not to upload it to the Git repository.",
+    )
+    parser.add_argument(
+        "--no-fs-extension",
+        action="store_false",
+        dest="fs_extension",
+        help="Disable the FlagScale extension. Default is True.",
+    )
+    parser.add_argument(
+        "--backend-commit", nargs="+", default=[None], help="The backend commit to checkout."
+    )
+
+    args = parser.parse_args()
+    backends = args.backend
+    device_type = args.device_type
+    tasks = args.task
+    commit = args.commit
+    key_path = args.key_path
+    backends_commit = args.backend_commit
+    fs_extension = args.fs_extension
+
+    if not isinstance(backends, list):
+        backends = [backends]
+
+    if not isinstance(backends_commit, list):
+        backends_commit = [backends_commit]
+
+    if tasks is not None and not isinstance(tasks, list):
+        tasks = [tasks]
+
+    # FlagScale/tools/patch
+    script_dir = os.path.dirname(os.path.realpath(__file__))
+    # FlagScale/tools
+    script_dir = os.path.dirname(script_dir)
+    # FlagScale
+    main_path = os.path.dirname(script_dir)
+
+    validate_unpatch_args(device_type, tasks, commit, main_path)
+    backend_commit = backend_commit_mapping(backends, backends_commit)
+
+    if FLAGSCALE_BACKEND in backends:
+        assert (
+            device_type is not None
+        ), "The FlagScale unpatch can only be applied as part of a hardware unpatch."
+
+    # Check whether a patch exists for the requested device type and tasks.
+    commit = commit_to_checkout(main_path, device_type, tasks, backends, commit)
+    if commit is not None:
+        # Check out the commit and apply the patch to build FlagScale.
+        apply_hardware_patch(
+            device_type, backends, commit, main_path, args.init_submodule, key_path=key_path
+        )
+    else:
+        for backend in backends:
+            dst = os.path.join(main_path, "third_party", backend)
+            src = os.path.join(main_path, "flagscale", "backends", backend)
+            unpatch(
+                main_path,
+                src,
+                dst,
+                backend,
+                force=args.force,
+                backend_commit=backend_commit,
+                fs_extension=fs_extension,
+            )