From 6df7eaeab3d454cf9f36d4f7652c323546a1ab0b Mon Sep 17 00:00:00 2001 From: bclavie Date: Mon, 17 Mar 2025 03:59:04 +0000 Subject: [PATCH 1/8] live evals --- checkpoint_eval_monitor.py | 146 +++++++++ create_random_init_model.py | 127 ++++++++ generate_eval_config.py | 36 ++- run_evals.py | 624 ++++++++++++++++++------------------ 4 files changed, 615 insertions(+), 318 deletions(-) create mode 100644 checkpoint_eval_monitor.py create mode 100755 create_random_init_model.py diff --git a/checkpoint_eval_monitor.py b/checkpoint_eval_monitor.py new file mode 100644 index 00000000..8e7e03a2 --- /dev/null +++ b/checkpoint_eval_monitor.py @@ -0,0 +1,146 @@ +import os +import time +import json +import sys +import logging +import argparse +from pathlib import Path +from typing import Set +from huggingface_hub import HfApi, list_repo_files +from run_evals import _main as eval_main + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(sys.stdout)] +) +logger = logging.getLogger("poller") + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Poll a Hugging Face repo for new .pt checkpoints (with 'rank' in filename); call run_evals." + ) + parser.add_argument("--repo_id", type=str, default="PLACEHOLDER", + help="Hugging Face repo ID to monitor for new checkpoints") + parser.add_argument("--token", type=str, default=None, + help="Optional HF API token for private repos") + parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints", + help="Local directory to store or download checkpoints; " + "this is passed to run_evals._main(..., checkpoints=...)") + parser.add_argument("--poll_interval", type=int, default=60, + help="How many seconds to wait between polls") + + parser.add_argument("--wandb_project", type=str, default=None, + help="Optional W&B project to pass to _main") + parser.add_argument("--wandb_entity", type=str, default=None, + help="Optional W&B entity to pass to _main") + parser.add_argument("--tasks", nargs="+", default=["mnli"], + help="Which tasks to evaluate. Will pass as a list of strings to _main.") + parser.add_argument("--seeds", nargs="+", type=int, default=[42, 314, 1234], + help="Random seeds to pass to _main") + parser.add_argument("--gpu_ids", nargs="+", type=int, default=[3,4,5], + help="Optional list of GPU IDs to use for evaluation") + parser.add_argument("--skip_generation", action="store_true", + help="If set, pass skip_generation=True to _main") + return parser.parse_args() + + +def load_processed(file_path: str = "processed_checkpoints.json") -> Set[str]: + """ + Load a set of checkpoint filenames we've already processed, so we don't re‐process them. + """ + if os.path.exists(file_path): + try: + with open(file_path, "r") as f: + return set(json.load(f)) + except Exception as e: + logger.warning(f"Could not parse {file_path}: {e}") + return set() + + +def save_processed(processed: Set[str], file_path: str = "processed_checkpoints.json"): + """ + Save a set of checkpoint filenames, so next time we skip them. + """ + try: + with open(file_path, "w") as f: + json.dump(list(processed), f) + except Exception as e: + logger.warning(f"Could not write to {file_path}: {e}") + + +def find_new_checkpoints(files_in_repo: list[str], processed: Set[str]) -> Set[str]: + """ + Return any .pt filenames containing 'rank' that are not yet in 'processed'. + E.g. 'my_run/epoch3-rank0.pt' + """ + new_ckpts = set() + for f in files_in_repo: + if f.endswith(".pt") and "rank" in f and f not in processed: + new_ckpts.add(f) + return new_ckpts + + +def poll_loop(args): + """ + Main polling loop: + - check the HF repo for new .pt files + - pass them to run_evals.programmatic_main + - record them in JSON + - sleep + """ + hf_api = HfApi(token=args.token) + processed = load_processed() + + logger.info(f"Starting poller for {args.repo_id}") + logger.info(f"Polling every {args.poll_interval} seconds.\n") + + while True: + try: + logger.info(f"Checking for new checkpoints in {args.repo_id}...") + repo_files = list_repo_files(args.repo_id, token=args.token) + new_ckpts = find_new_checkpoints(repo_files, processed) + + if not new_ckpts: + logger.info("No new checkpoints found.") + else: + for ckpt in new_ckpts: + logger.info(f"Found new checkpoint: {ckpt}") + logger.info("Calling run_evals.programmatic_main(...) on that checkpoint...") + + try: + eval_main( + checkpoints=args.checkpoint_dir, + hub_repo=args.repo_id, + hub_files=[ckpt], + hub_token=args.token, + wandb_project=args.wandb_project, + wandb_entity=args.wandb_entity, + tasks=args.tasks, + seeds=args.seeds, + skip_generation=args.skip_generation, + gpu_ids=args.gpu_ids, + verbose=True, + parallel=True, + ) + # Mark it processed + processed.add(ckpt) + save_processed(processed) + except Exception as e: + logger.error(f"Error running eval on {ckpt}: {e}", exc_info=True) + + except Exception as e: + logger.error(f"Error in poll loop: {e}", exc_info=True) + + logger.info(f"Sleeping {args.poll_interval} seconds...\n") + time.sleep(args.poll_interval) + + +def main(): + args = parse_args() + poll_loop(args) + + +if __name__ == "__main__": + main() diff --git a/create_random_init_model.py b/create_random_init_model.py new file mode 100755 index 00000000..ab709d29 --- /dev/null +++ b/create_random_init_model.py @@ -0,0 +1,127 @@ +import os +import torch +import yaml +import argparse +from pathlib import Path +from huggingface_hub import HfApi +from composer import Trainer +from composer.models import HuggingFaceModel +from src.flex_bert import create_flex_bert_mlm + +def parse_args(): + parser = argparse.ArgumentParser(description='Create a random init Composer model and upload to HF') + parser.add_argument('--config_path', type=str, required=True, + help='Path to the training config YAML file') + parser.add_argument('--output_dir', type=str, default='./checkpoints/random_init', + help='Directory to save the model checkpoints') + parser.add_argument('--repo_id', type=str, default='PLACEHOLDER', + help='HuggingFace repository ID to upload the model') + parser.add_argument('--token', type=str, default=None, + help='HuggingFace API token for private repos') + return parser.parse_args() + +def main(): + args = parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + + with open(args.config_path, 'r') as f: + config = yaml.safe_load(f) + + print(f"Creating model with config from {args.config_path}") + + model_config = config['model']['model_config'] + + valid_attention_types = ['base', 'parallel', 'rope', 'rope_parallel'] + if 'attention_layer' in model_config and model_config['attention_layer'] not in valid_attention_types: + print(f"Warning: Invalid attention_layer '{model_config['attention_layer']}', falling back to 'rope'") + model_config['attention_layer'] = 'rope' + + try: + model = create_flex_bert_mlm( + pretrained_model_name=config['model']['pretrained_model_name'], + tokenizer_name=config['tokenizer_name'], + model_config=model_config + ) + print("HF model created successfully.") + except Exception as e: + print(f"Error creating model: {e}") + print("Attempting with simplified config...") + + for key in list(model_config.keys()): + if key not in ['vocab_size', 'hidden_size', 'num_hidden_layers', + 'num_attention_heads', 'attention_layer', 'padding']: + model_config.pop(key, None) + + model_config['attention_layer'] = 'rope' + model_config['padding'] = 'unpadded' + + model = create_flex_bert_mlm( + pretrained_model_name=config['model']['pretrained_model_name'], + tokenizer_name=config['tokenizer_name'], + model_config=model_config + ) + print("HF model created with simplified config.") + + + composer_model = HuggingFaceModel( + model=model, + tokenizer=None, + use_logits=True + ) + print("Composer model created.") + + checkpoint_path = os.path.join(args.output_dir, "latest-rank0.pt") + + trainer = Trainer( + model=composer_model, + max_duration="1ba", + device="cpu" + ) + + print(f"Saving Composer checkpoint to {checkpoint_path}...") + trainer.save_checkpoint(checkpoint_path) + + config_path = os.path.join(args.output_dir, f"{Path(args.output_dir).name}.yaml") + with open(config_path, 'w') as f: + yaml.dump(config, f) + + print(f"Config saved at: {config_path}") + + if args.token: + print(f"Uploading to HuggingFace repo: {args.repo_id}") + api = HfApi(token=args.token) + + try: + api.repo_info(repo_id=args.repo_id) + print(f"Repository {args.repo_id} already exists") + except Exception: + print(f"Creating new repository: {args.repo_id}") + api.create_repo( + repo_id=args.repo_id, + private=True, + repo_type="model", + exist_ok=True + ) + print(f"Repository {args.repo_id} created successfully") + + api.upload_file( + path_or_fileobj=checkpoint_path, + path_in_repo=f"{Path(args.output_dir).name}/latest-rank0.pt", + repo_id=args.repo_id, + token=args.token + ) + + api.upload_file( + path_or_fileobj=config_path, + path_in_repo=f"{Path(args.output_dir).name}/{Path(args.output_dir).name}.yaml", + repo_id=args.repo_id, + token=args.token + ) + + print("Upload complete!") + else: + print("No HuggingFace token provided. Skipping upload.") + +if __name__ == "__main__": + main() diff --git a/generate_eval_config.py b/generate_eval_config.py index 81e480f5..4a2c80d7 100644 --- a/generate_eval_config.py +++ b/generate_eval_config.py @@ -41,6 +41,7 @@ def conf_callback(ctx: typer.Context, param: typer.CallbackParam, config: Option class ModelSize(str, Enum): BASE = "base" LARGE = "large" + HUGE = "huge" def get_model_defaults(model_size: ModelSize): @@ -58,6 +59,12 @@ def get_model_defaults(model_size: ModelSize): "intermediate_size": 2624, "num_attention_heads": 16, }, + "huge": { + "num_hidden_layers": 32, + "hidden_size": 1536, + "intermediate_size": 4096, + "num_attention_heads": 24, + }, } # Select the default model config based on the model_size argument @@ -225,6 +232,7 @@ def main( head_class_dropout: Annotated[float, Option(help="Classification head dropout rate", rich_help_panel="Model Options")] = 0.0, fast_ultrafeedback: Annotated[bool, Option("--fast-ultrafeedback", help="Use a shorter sequence length (1536) for the UltraFeedback eval", rich_help_panel="Task Settings")] = False, seeds: Annotated[List[int], Option(help="List of seeds to use for the eval", rich_help_panel="Task Settings")] = [1618, 42, 6033, 3145], + gpu_ids: Annotated[List[int], Option(help="List of GPU IDs to use for the eval", rich_help_panel="Task Settings")] = [0], parallel: Annotated[bool, Option("--parallel/--single", help="Run the evals in parallel on multiple GPUs or one GPU. Only use if evaluating a single checkpoint on multiple GPUs.", rich_help_panel="Task Settings")] = False, config: Annotated[Optional[Path], Option(callback=conf_callback, is_eager=True, help="Relative path to YAML config file for setting options. Passing CLI options will supersede config options.", case_sensitive=False, rich_help_panel="Options")] = None, ): # fmt: skip @@ -236,14 +244,25 @@ def main( ckpt = checkpoint.name # checkpoint ckpt_path = checkpoint.parent elif checkpoint.is_dir(): - ckpts = list(checkpoint.glob("*.pt")) + # Search recursively for checkpoint files + ckpts = list(checkpoint.glob("**/*.pt")) if len(ckpts) == 1: ckpt = ckpts[0].name + ckpt_path = ckpts[0].parent elif len(ckpts) > 1: - ckpt = "latest-rank0.pt" + # Look for latest-rank0.pt in any subfolder + latest_ckpts = list(checkpoint.glob("**/latest-rank0.pt")) + if latest_ckpts: + ckpt = latest_ckpts[0].name + ckpt_path = latest_ckpts[0].parent + else: + # Default to first checkpoint found + ckpt = ckpts[0].name + ckpt_path = ckpts[0].parent elif len(ckpts) == 0: - raise ValueError(f"No checkpoint found in the provided directory: {checkpoint}") - ckpt_path = checkpoint + raise ValueError(f"No checkpoint found in the provided directory or its subdirectories: {checkpoint}") + else: + ckpt_path = checkpoint else: raise ValueError(f"Invalid checkpoint path provided: {checkpoint}") @@ -388,7 +407,12 @@ def main( elif task_name == "mnli": task_config["seeds"] = seeds[:3] - task_config["trainer_kwargs"] = {"save_num_checkpoints_to_keep": 1, "max_duration": "2ep"} + task_config["trainer_kwargs"] = { + "save_num_checkpoints_to_keep": 1, + "max_duration": "2ep", + "batch_size": 2, + "device_train_microbatch_size": "auto", + } elif task_name == "boolq": task_config["seeds"] = seeds[:3] @@ -417,6 +441,8 @@ def main( task_config["seeds"] = seeds[:3] tasks_dict[task_name] = task_config + task_config["gpu_ids"] = gpu_ids + new_config["tasks"] = tasks_dict # Write the new configuration to a YAML file diff --git a/run_evals.py b/run_evals.py index 0585baf6..3621844a 100644 --- a/run_evals.py +++ b/run_evals.py @@ -13,7 +13,7 @@ from enum import Enum from multiprocessing import Process, Queue from pathlib import Path -from typing import Annotated, List, Optional +from typing import Annotated, List, Optional, Union import datasets import psutil @@ -30,18 +30,16 @@ warnings.simplefilter("ignore", category=FutureWarning) from eval import GLUE_TASKS, SUPERGLUE_TASKS, TASK_NAME_TO_CLASS - # Create TaskName enum dynamically from TASK_NAME_TO_CLASS keys TaskName = Enum("TaskName", {name: name for name in TASK_NAME_TO_CLASS.keys()}, type=str) - app = typer.Typer(context_settings={"help_option_names": ["-h", "--help"]}, pretty_exceptions_show_locals=False) class ModelSize(str, Enum): BASE = "base" LARGE = "large" - + HUGE = "huge" # from maxb2: https://github.com/tiangolo/typer/issues/86#issuecomment-996374166 def conf_callback(ctx: typer.Context, param: typer.CallbackParam, config: Optional[str] = None): @@ -57,19 +55,14 @@ def conf_callback(ctx: typer.Context, param: typer.CallbackParam, config: Option return config -# Global dictionary to keep track of GPUs with running jobs -# Changed to store more information per GPU gpus_in_use = {} -# Queue to keep track of GPUs that might be free potentially_free_gpus = deque() -# Global list to keep track of all running processes all_processes = [] +allowed_gpus = None -# Global list to specify which GPUs to use -allowed_gpus = None # Will be set to list of GPU IDs or None - +console = Console() -def kill_process_tree(pid): +def kill_process_tree(pid: int): try: parent = psutil.Process(pid) children = parent.children(recursive=True) @@ -86,49 +79,48 @@ def kill_process_tree(pid): def signal_handler(signum, frame): print("\nReceived termination signal. Cleaning up subprocesses...") - for process in all_processes: - if process.poll() is None: # If the process is still running - kill_process_tree(process.pid) - + for proc in all_processes: + if proc.poll() is None: + kill_process_tree(proc.pid) print("Cleanup completed. Exiting.") - os._exit(0) # Force exit without running cleanup handlers + os._exit(0) -def get_gpu_memory_usage(gpu_id): - """Get memory usage for a specific GPU.""" +def get_gpu_memory_usage(gpu_id: int) -> Optional[int]: try: - output = ( + out = ( subprocess.check_output( - f"nvidia-smi --query-gpu=memory.used --format=csv,nounits,noheader -i {gpu_id}", shell=True + f"nvidia-smi --query-gpu=memory.used --format=csv,nounits,noheader -i {gpu_id}", + shell=True ) .decode("utf-8") .strip() ) - return int(output) + return int(out) except subprocess.CalledProcessError: print(f"Failed to get memory usage for GPU {gpu_id}") return None -def get_free_gpu(): - """Check for free GPUs, prioritizing potentially free GPUs.""" +def get_free_gpu() -> Optional[int]: global allowed_gpus while potentially_free_gpus: gpu_id = potentially_free_gpus.popleft() if (allowed_gpus is None or gpu_id in allowed_gpus) and gpu_id not in gpus_in_use: - memory_used = get_gpu_memory_usage(gpu_id) - if memory_used is not None and memory_used < 100: + used = get_gpu_memory_usage(gpu_id) + if used is not None and used < 100: return gpu_id - # If no potentially free GPUs, check allowed GPUs try: gpu_output = subprocess.check_output( "nvidia-smi --query-gpu=index,memory.used --format=csv,nounits,noheader", shell=True ).decode("utf-8") for line in gpu_output.strip().split("\n"): - gpu_id, memory_used = map(int, line.split(",")) - if (allowed_gpus is None or gpu_id in allowed_gpus) and memory_used < 100 and gpu_id not in gpus_in_use: - return gpu_id + g_id_str, mem_str = line.split(",") + g_id_int = int(g_id_str) + mem_used = int(mem_str) + if (allowed_gpus is None or g_id_int in allowed_gpus) and mem_used < 100 and g_id_int not in gpus_in_use: + return g_id_int return None except subprocess.CalledProcessError: print("Failed to execute nvidia-smi") @@ -137,46 +129,37 @@ def get_free_gpu(): def run_subprocess(cmd: List[str], verbose: bool = False, show_errors: bool = False): stdout = None if verbose else subprocess.DEVNULL - stderr = None if verbose or show_errors else subprocess.DEVNULL - process = subprocess.Popen(cmd, stdout=stdout, stderr=stderr) - all_processes.append(process) # Add the process to the global list - process.wait() + stderr = None if (verbose or show_errors) else subprocess.DEVNULL + proc = subprocess.Popen(cmd, stdout=stdout, stderr=stderr) + all_processes.append(proc) + proc.wait() def handle_process_completion(process, stderr_file, config_path: Path, verbose: bool, gpu_id: Optional[int] = None): - """Handles the completion of a process, checks for errors, cleans up stderr_file, and logs messages.""" - returncode = process.returncode - - # Read and clean up stderr output + code = process.returncode if stderr_file is not None: stderr_file.seek(0) - error_output = stderr_file.read() + error_out = stderr_file.read() stderr_file.close() - os.unlink(stderr_file.name) # Delete the temp file + os.unlink(stderr_file.name) else: - error_output = "Error output was displayed above." + error_out = "Error output was displayed above." - # Construct job identifier - if gpu_id is not None: - job_identifier = f"Job on GPU {gpu_id} for {config_path.name}" - else: - job_identifier = f"Job for {config_path.name}" + job_label = f"Job for {config_path.name}" if gpu_id is None else f"Job on GPU {gpu_id} for {config_path.name}" - if returncode != 0: - # The process exited with an error + if code != 0: if verbose: - print(f"{job_identifier} failed with return code {returncode}") + print(f"{job_label} failed with return code {code}") print("Error Output:") - print(error_output) + print(error_out) else: - console.print(f"[red]{job_identifier} failed with return code {returncode}[/red]") - console.print(f"[red]Error Output:[/red]\n{error_output}") + console.print(f"[red]{job_label} failed with return code {code}[/red]") + console.print(f"[red]Error Output:[/red]\n{error_out}") else: - # The process completed successfully if verbose: - print(f"{job_identifier} has finished successfully.") + print(f"{job_label} finished successfully.") else: - console.log(f"{job_identifier} has finished successfully.") + console.log(f"{job_label} has finished successfully.") def run_job( @@ -186,19 +169,17 @@ def run_job( gpu_id: Optional[int] = None, gpu_ids: Optional[List[int]] = None, ): - """Run a job with optional GPU management.""" if gpu_id is not None: - # GPU management is required env = os.environ.copy() env["CUDA_VISIBLE_DEVICES"] = str(gpu_id) elif gpu_ids is not None: env = os.environ.copy() env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_ids)) else: - env = None # Use default environment + env = None if verbose: - stdout = None # Output will be shown directly + stdout = None stderr = None stderr_file = None else: @@ -206,69 +187,59 @@ def run_job( stderr_file = tempfile.NamedTemporaryFile(mode="w+", delete=False) stderr = stderr_file - process = subprocess.Popen(["python", "eval.py", str(config_path)], env=env, stdout=stdout, stderr=stderr) - all_processes.append(process) # Add the process to the global list + proc = subprocess.Popen(["python", "eval.py", str(config_path)], env=env, stdout=stdout, stderr=stderr) + all_processes.append(proc) if gpu_id is not None: - # Store process info for GPU management - gpus_in_use[gpu_id] = {"process": process, "stderr_file": stderr_file, "config": config_path} - - if gpu_id is None: - process.wait() - handle_process_completion(process, stderr_file, config_path, verbose, gpu_id=None) + gpus_in_use[gpu_id] = {"process": proc, "stderr_file": stderr_file, "config": config_path} + else: + proc.wait() + handle_process_completion(proc, stderr_file, config_path, verbose, gpu_id=None) if delete_eval_yamls: config_path.unlink() - return process + return proc def check_finished_jobs(verbose: bool = False): - """Check for finished jobs and free up their GPUs.""" - finished_gpus = [] + done_gpus = [] for gpu_id, info in gpus_in_use.items(): process = info["process"] stderr_file = info["stderr_file"] config = info["config"] - if process.poll() is not None: # Job has finished - # Handle process completion + if process.poll() is not None: handle_process_completion(process, stderr_file, config, verbose, gpu_id=gpu_id) - finished_gpus.append(gpu_id) + done_gpus.append(gpu_id) - for gpu_id in finished_gpus: - del gpus_in_use[gpu_id] - potentially_free_gpus.append(gpu_id) + for g in done_gpus: + del gpus_in_use[g] + potentially_free_gpus.append(g) def manage_jobs(configs: List[Path], verbose: bool = False, delete_eval_yamls: bool = True): - """Manage the launching of jobs for each configuration file in the directory.""" - if verbose: - for config in configs: + for cfg in configs: while True: check_finished_jobs(verbose) - gpu_id = get_free_gpu() - if gpu_id is not None: + free = get_free_gpu() + if free is not None: time.sleep(random.randint(0, 5)) - print(f"\nLaunching job for {config} on GPU {gpu_id}\n") - run_job(config, gpu_id=gpu_id, verbose=verbose, delete_eval_yamls=delete_eval_yamls) + print(f"\nLaunching job for {cfg} on GPU {free}\n") + run_job(cfg, gpu_id=free, verbose=verbose, delete_eval_yamls=delete_eval_yamls) break else: time.sleep(10) - - # Wait for all remaining jobs to finish while gpus_in_use: check_finished_jobs(verbose) time.sleep(10) else: - def update_progress_for_finished_jobs(): - """Update progress bars for any finished GPU jobs.""" - for gpu_id, info in list(gpus_in_use.items()): - process = info["process"] - if process.poll() is not None: # Job finished - if gpu_id in gpu_tasks: - gpu_progress.update(gpu_tasks[gpu_id], completed=1, visible=False) + for gpuid, info in list(gpus_in_use.items()): + prc = info["process"] + if prc.poll() is not None: + if gpuid in gpu_tasks: + gpu_progress.update(gpu_tasks[gpuid], completed=1, visible=False) completed_configs.add(info["config"]) overall_progress.update(overall_task, completed=len(completed_configs)) @@ -279,41 +250,40 @@ def update_progress_for_finished_jobs(): TextColumn("[progress.percentage]{task.completed}/{task.total}"), TimeElapsedColumn(), ) - gpu_progress = Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), TimeElapsedColumn() ) progress_group = Group( Panel(overall_progress, title="Overall Progress", border_style="blue", padding=(1, 1)), - Panel(gpu_progress, title="GPU Jobs", border_style="green", padding=(1, 1)), + Panel(gpu_progress, title="GPU Jobs", border_style="green", padding=(1, 1)) ) with Live(progress_group, console=console, refresh_per_second=4): overall_task = overall_progress.add_task("[cyan]Overall Progress", total=len(configs)) gpu_tasks = {} - completed_configs = set() # Track completed configs + completed_configs = set() - for config in configs: + for cfg in configs: while True: check_finished_jobs(verbose) update_progress_for_finished_jobs() - gpu_id = get_free_gpu() - if gpu_id is not None: + free = get_free_gpu() + if free is not None: time.sleep(random.randint(0, 5)) - if gpu_id not in gpu_tasks: - gpu_tasks[gpu_id] = gpu_progress.add_task(f"[green]GPU {gpu_id}", total=1) + if free not in gpu_tasks: + gpu_tasks[free] = gpu_progress.add_task(f"[green]GPU {free}", total=1) else: - gpu_progress.update(gpu_tasks[gpu_id], completed=1, visible=False) - gpu_tasks[gpu_id] = gpu_progress.add_task(f"[green]GPU {gpu_id}", total=1) - gpu_progress.update(gpu_tasks[gpu_id], description=f"[green]GPU {gpu_id}: {config.name}") - run_job(config, gpu_id=gpu_id, verbose=verbose, delete_eval_yamls=delete_eval_yamls) + gpu_progress.update(gpu_tasks[free], completed=1, visible=False) + gpu_tasks[free] = gpu_progress.add_task(f"[green]GPU {free}", total=1) + + gpu_progress.update(gpu_tasks[free], description=f"[green]GPU {free}: {cfg.name}") + run_job(cfg, gpu_id=free, verbose=verbose, delete_eval_yamls=delete_eval_yamls) break else: time.sleep(10) - # Wait for all remaining jobs to finish while gpus_in_use: check_finished_jobs(verbose) update_progress_for_finished_jobs() @@ -322,62 +292,56 @@ def update_progress_for_finished_jobs(): overall_progress.update(overall_task, completed=len(configs)) if delete_eval_yamls: - for config in configs: + for c in configs: try: - config.unlink() + c.unlink() except FileNotFoundError: pass def create_symlink_for_newest_checkpoint(folder: Path, override_existing: bool = False): - """Create a symlink to the newest checkpoint file if 'latest-rank0.pt' does not exist.""" - if folder.is_dir(): - pt_files = list(folder.glob("*.pt")) - if not pt_files: - print(f" Warning: No .pt file found in {folder}.") + if not folder.is_dir(): + return + + pt_files = list(folder.glob("*.pt")) + if not pt_files: + print(f" Warning: No .pt file found in {folder}, skipping symlink creation.") + return + + if len(pt_files) == 1 and pt_files[0].name == "latest-rank0.pt" and not pt_files[0].is_symlink(): + print(f" Only found one .pt in {folder.name}, named 'latest-rank0.pt' (real file). Skipping symlink creation.") + return + + def extract_nums(fp: Path): + m = re.search(r"ep(\d+)-ba(\d+)", fp.stem) + if m: + ep, ba = map(int, m.groups()) + return (ep, ba) + return (0, 0) + + newest_file = max(pt_files, key=extract_nums) + symlink_path = folder / "latest-rank0.pt" + + if symlink_path.is_symlink(): + if symlink_path.resolve() == newest_file.resolve(): + print(f" Existing symlink in {folder.name} already points to {newest_file.name}") return - - # Sort files based on epoch and batch numbers extracted from filenames - def extract_numbers(filename: Path): - if filename.is_symlink(): - return (0, 0) - if filename.name == "latest-rank0.pt": - return (0, 0) - - try: - # Using regex to find patterns of 'ep' followed by digits and 'ba' followed by digits - match = re.search(r"ep(\d+)-ba(\d+)", filename.stem) - if match: - epoch, batch = map(int, match.groups()) - return (epoch, batch) - else: - raise ValueError(f"Filename does not match expected pattern: {filename}") - except Exception as e: - print(f" Error extracting numbers from filename {filename}: {e}") - return (0, 0) - - newest_file = max(pt_files, key=extract_numbers) - - symlink_path = folder / "latest-rank0.pt" - if symlink_path.exists() and symlink_path.is_symlink(): - if symlink_path.resolve() == newest_file.resolve(): - print(f" Existing symlink points to latest checkpoint: {newest_file.parent.name}/{newest_file.name}") + else: + print(f" Warning: symlink in {folder.name} points to {symlink_path.resolve().name}, but newest is {newest_file.name}") + if not override_existing: return - else: - print( - f" Warning: Existing symlink points to {symlink_path.parent.name}/{symlink_path.name}, " - f"but latest checkpoint is {newest_file.parent.name}/{newest_file.name}" - ) - if not override_existing: - return + symlink_path.unlink(missing_ok=True) + elif symlink_path.exists(): + if not override_existing: + print(f" {symlink_path.name} is a real file in {folder.name}. Use override to remove it.") + return + symlink_path.unlink(missing_ok=True) - symlink_path.symlink_to(newest_file.name) - if override_existing: - print( - f" Overwriting existing symlink with {symlink_path.parent.name}/{symlink_path.name} -> {newest_file.name}" - ) - else: - print(f" Created new symlink {symlink_path.parent.name}/{symlink_path.name} -> {newest_file.name}") + symlink_path.symlink_to(newest_file.name) + if override_existing: + print(f" Overwrote symlink {symlink_path.name} -> {newest_file.name}") + else: + print(f" Created new symlink {symlink_path.name} -> {newest_file.name}") def generate_eval_configs( @@ -392,38 +356,30 @@ def generate_eval_configs( head_class_act: Optional[str], head_class_norm: Optional[str], head_class_dropout: float, - tasks: Optional[List[TaskName]], # type: ignore + tasks: Optional[List[Union[TaskName, str]]], fast_ultrafeedback: bool, seeds: List[int], parallel: bool, use_dir_names: Optional[bool], model_size: ModelSize, rope_theta: Optional[float], + gpu_ids: Optional[List[int]] = None, ): - """Generate evaluation configs for each checkpoint.""" - folders = [ - folder - for folder in checkpoints.glob("*") - if folder.is_dir() - and not folder.name.startswith(".") - and any(file.suffix == ".pt" for file in folder.glob("*.pt")) + f + for f in checkpoints.glob("*") + if f.is_dir() and not f.name.startswith(".") and any(x.suffix == ".pt" for x in f.glob("*.pt")) ] if use_dir_names is None and len(folders) > 1: use_dir_names = True - print("Using folder names as run names since multiple `checkpoints` were provided with one `train_config`.") + print("Using folder names as run names since multiple checkpoint folders found.") for folder in folders: cmd = [ - "python", - "generate_eval_config.py", - "--checkpoint", - str(folder), - "--output-dir", - str(checkpoints), + "python", "generate_eval_config.py", + "--checkpoint", str(folder), + "--output-dir", str(checkpoints), ] - - # Add optional arguments if they're provided if use_dir_names: cmd.append("--use-dir-name") if model_size: @@ -443,7 +399,6 @@ def generate_eval_configs( if track_run_project: cmd.extend(["--track-run-project", track_run_project]) - # Classification head options if pooling_type: cmd.extend(["--pooling-type", pooling_type]) if head_class_act: @@ -453,10 +408,13 @@ def generate_eval_configs( if head_class_dropout > 0: cmd.extend(["--head-class-dropout", str(head_class_dropout)]) - # Add tasks + # Handle tasks as either TaskName or str if tasks: for task in tasks: - cmd.extend(["--tasks", task.value]) + if hasattr(task, "value"): + cmd.extend(["--tasks", task.value]) + else: + cmd.extend(["--tasks", str(task)]) if fast_ultrafeedback: cmd.append("--fast-ultrafeedback") @@ -464,9 +422,15 @@ def generate_eval_configs( for seed in seeds: cmd.extend(["--seeds", str(seed)]) - cmd.append("--parallel") if parallel else cmd.append("--single") + if parallel: + cmd.append("--parallel") + else: + cmd.append("--single") + + if gpu_ids: + if isinstance(gpu_ids, int): gpu_ids = [gpu_ids] + for g in gpu_ids: cmd.extend(["--gpu-ids", str(g)]) - # Run the config generation process without suppressing output run_subprocess(cmd, show_errors=True) if not train_config: time.sleep(1) @@ -480,62 +444,55 @@ def download_dataset(dataset_name: str, subset: Optional[str] = None): return f"Error in processing {dataset_name}: {e}" -def download_datasets(tasks: List[TaskName], msg_queue): # type: ignore +def download_datasets(tasks: List[Union[TaskName, str]], msg_queue): try: required_datasets = [] - task_to_datasets = { "mlmmlu_amateur_semipro": [["answerdotai/MLMMLU", "Amateur"], ["answerdotai/MLMMLU", "Semipro"]], "mlmmlu_rookie_reserve": [["answerdotai/MLMMLU", "Rookie"], ["answerdotai/MLMMLU", "Reserve"]], "eurlex": [["coastalcph/lex_glue", "eurlex"]], "ultrafeedback": [["rbiswasfc/ultrafeedback-binary-classification"]], } + for t in tasks: + if hasattr(t, "value"): + task_val = t.value + else: + task_val = str(t) - for task in tasks: - if task.value in GLUE_TASKS: - datasets_info = [["glue", task.value]] - elif task.value in SUPERGLUE_TASKS: - datasets_info = [["aps/super_glue", task.value]] + if task_val in GLUE_TASKS: + required_datasets.append(["glue", task_val]) + elif task_val in SUPERGLUE_TASKS: + required_datasets.append(["aps/super_glue", task_val]) else: - datasets_info = task_to_datasets.get(task.value, []) - required_datasets.extend(datasets_info) + extras = task_to_datasets.get(task_val, []) + required_datasets.extend(extras) - # Suppress output globally in this process import sys - sys.stdout = open(os.devnull, "w") sys.stderr = open(os.devnull, "w") msgs = [] - for dataset_name, subset in required_datasets: - datasets.load_dataset(dataset_name, subset) - msgs.append(f"Successfully downloaded {dataset_name} {subset}") - msg_queue.put(" " + "\n ".join(msgs) + "\n") + for ds_name, subset in required_datasets: + datasets.load_dataset(ds_name, subset, trust_remote_code=True) + msgs.append(f"Successfully downloaded {ds_name} {subset}") + msg_queue.put("\n ".join([""] + msgs)) except Exception as e: msg_queue.put(f"Error in downloading datasets: {e}") def find_checkpoint_file(file_path: str, repo_files: List[str]) -> Optional[str]: import re + valid = [f for f in repo_files if f.startswith(file_path) and f.endswith((".pt", ".yaml"))] + if len(valid) == 1: + return valid[0] - # Filter files in the specified file_path that end with .pt or .yaml - valid_files = [file for file in repo_files if file.startswith(file_path) and file.endswith((".pt", ".yaml"))] + def extract_nums(fn: str): + m = re.search(r"ep(\d+)-ba(\d+)", fn) + if m: + return tuple(map(int, m.groups())) + return (-1, -1) - if len(valid_files) == 1: - return valid_files[0] - - # Function to extract epoch and batch numbers from the filename - def extract_numbers(filename: str): - match = re.search(r"ep(\d+)-ba(\d+)", filename) - if match: - epoch, batch = map(int, match.groups()) - return epoch, batch - return -1, -1 # Return a default value for files that don't match the pattern - - # Find the newest file based on epoch and batch numbers - newest_file = max(valid_files, key=extract_numbers, default=None) - - return newest_file + return max(valid, key=extract_nums, default=None) def download_hub_files( @@ -545,41 +502,29 @@ def download_hub_files( repo_type: str = "model", token: Optional[str] = None, ) -> List[Path]: - """Download specific files or the entire repository from a Hugging Face Hub repository.""" output_dir.mkdir(parents=True, exist_ok=True) - downloaded_files = [] + downloaded_files: List[Path] = [] def move_and_flatten_files(local_dir: Path): - for file_path in local_dir.rglob("*"): - if file_path.is_file() and file_path.name.endswith((".pt", ".yaml")): - # Determine the target directory - target_dir = output_dir / file_path.parent.name - - # Check if the file is already in the correct location - if file_path.parent.resolve() in [target_dir.resolve(), output_dir.resolve()]: - downloaded_files.append(file_path) + for fp in local_dir.rglob("*"): + if fp.is_file() and fp.name.endswith((".pt", ".yaml")): + target_dir = output_dir / fp.parent.name + if fp.parent.resolve() in [target_dir.resolve(), output_dir.resolve()]: + downloaded_files.append(fp) continue - - # Create the target directory if it doesn't exist target_dir.mkdir(parents=True, exist_ok=True) - # Move the file to the target directory - new_path = target_dir / file_path.name - file_path.rename(new_path) + new_path = target_dir / fp.name + fp.rename(new_path) downloaded_files.append(new_path) - # List all files in the repository api = HfApi() repo_files = api.list_repo_files(repo_id=repo_id, repo_type=repo_type, token=token) - try: if not filenames: - # Check if files already exist before downloading entire repository existing_files = list(output_dir.glob("**/*.pt")) + list(output_dir.glob("**/*.yaml")) if existing_files: - print(f"Found existing files in '{output_dir}', skipping download.") + print(f"Found existing files in '{output_dir}', skipping snapshot_download.") return existing_files - - # Download the entire repository local_dir = snapshot_download( repo_id=repo_id, repo_type=repo_type, @@ -589,25 +534,22 @@ def move_and_flatten_files(local_dir: Path): use_auth_token=token, ) move_and_flatten_files(Path(local_dir)) - print(f"Successfully downloaded and flattened the repository '{repo_id}' to '{output_dir}'.") + print(f"Successfully downloaded entire repo '{repo_id}' -> '{output_dir}'.") else: - for filename in filenames: - resolved_filename = find_checkpoint_file(filename, repo_files) + for fn in filenames: + resolved_filename = find_checkpoint_file(fn, repo_files) if not resolved_filename: - print(f"Warning: Could not find matching file for '{filename}' in repository.") + print(f"Warning: No match for '{fn}' in {repo_id}.") continue - - # Check if file exists in output_dir or any immediate subdirectory - filename = Path(resolved_filename).name + just_name = Path(resolved_filename).name parent_dir = Path(resolved_filename).parent.name - existing_files = list(output_dir.glob(f"**/{parent_dir}/{filename}")) - if existing_files: - existing_file = existing_files[0] - print(f"File '{parent_dir}/{filename}' already exists at '{existing_file}', skipping download.") + existing_fs = list(output_dir.glob(f"**/{parent_dir}/{just_name}")) + if existing_fs: + existing_file = existing_fs[0] + print(f"File '{parent_dir}/{just_name}' already exists at '{existing_file}', skip download.") downloaded_files.append(existing_file) continue - # Download the file _ = hf_hub_download( repo_id=repo_id, filename=resolved_filename, @@ -616,7 +558,7 @@ def move_and_flatten_files(local_dir: Path): local_dir=output_dir, cache_dir=None, ) - print(f"Successfully downloaded '{resolved_filename}' from '{repo_id}'.") + print(f"Downloaded '{resolved_filename}' from '{repo_id}'.") move_and_flatten_files(output_dir) except Exception as e: print(f"Error downloading from '{repo_id}': {e}") @@ -624,81 +566,77 @@ def move_and_flatten_files(local_dir: Path): return downloaded_files -console = Console() - +def _main( + checkpoints: Union[str, Path], + train_config: Optional[Union[str, Path]] = None, + model_size: ModelSize = ModelSize.BASE, + rope_theta: Optional[float] = None, + skip_generation: bool = False, + run_all_yamls: bool = False, + tasks: Optional[List[Union[str, TaskName]]] = None, + hub_repo: Optional[str] = None, + hub_files: Optional[List[str]] = None, + hub_token: Optional[str] = None, + wandb_run: Optional[str] = None, + wandb_project: Optional[str] = None, + wandb_entity: Optional[str] = None, + track_run: bool = False, + track_run_project: Optional[str] = None, + pooling_type: Optional[str] = None, + head_class_act: Optional[str] = None, + head_class_norm: Optional[str] = None, + head_class_dropout: float = 0.0, + fast_ultrafeedback: bool = False, + seeds: List[int] = [1618, 42, 6033, 3145], + verbose: bool = False, + overwrite_existing_symlinks: bool = False, + parallel: bool = False, + delete_eval_yamls: bool = False, + use_dir_names: Optional[bool] = None, + gpu_ids: Optional[List[int]] = None, + config: Optional[Union[str, Path]] = None, +): + if isinstance(checkpoints, str): + checkpoints = Path(checkpoints) + if isinstance(train_config, str): + train_config = Path(train_config) -@app.command() -def main( - checkpoints: Annotated[Path, Option(help="Path to the directory containing FlexBert checkpoints or location to download checkpoints from Hugging Face Hub to", rich_help_panel="Checkpoint & Config Paths", show_default=False)], - train_config: Annotated[Optional[Path], Option(help="Path to a .yaml file containing training configuration. If one is not provided, will attempt to load the config from a wandb run or use defaults.", rich_help_panel="Checkpoint & Config Paths")] = None, - model_size: Annotated[ModelSize, Option("--model-size", help="Model to use for default model config", rich_help_panel="Checkpoint & Config Paths")] = ModelSize.BASE, - rope_theta: Annotated[Optional[float], Option("--rope-theta", help="Value for `rotary_emb_base` in the model configuration. If not provided, defaults to pretraining value of 10000.0", rich_help_panel="Checkpoint & Config Paths")] = None, - skip_generation: Annotated[bool, Option("--skip-generation", help="Skip generation of evaluation configs. If not true, assumes all existing eval yamls have been already ran.", rich_help_panel="Checkpoint & Config Paths")] = False, - run_all_yamls: Annotated[bool, Option("--run-all-yamls", help="Run all evaluation yamls in the `checkpoints` directory, even if some have already been run.", rich_help_panel="Checkpoint & Config Paths")] = False, - tasks: Annotated[Optional[List[TaskName]], Option(help="List of tasks to include in the evaluation. Default is all tasks.", rich_help_panel="Eval Tasks", case_sensitive=False, show_default=False)] = None, # type: ignore - hub_repo: Annotated[Optional[str], Option(help="Hugging Face Hub repository ID to download FlexBert weights. Downloads to `checkpoints` directory.", rich_help_panel="Hugging Face Download")] = None, - hub_files: Annotated[Optional[List[str]], Option(help="List of files to download from the `hub_repo`. If not provided, will download all files in the repo.", rich_help_panel="Hugging Face Download")] = None, - hub_token: Annotated[Optional[str], Option(help="Authentication token for private Hugging Face Hub repositories if not already logged in via `huggingface-cli login`", rich_help_panel="Hugging Face Download")] = None, - wandb_run: Annotated[Optional[str], Option(help="wandb run containing the training configuration", rich_help_panel="Weights & Biases")] = None, - wandb_project: Annotated[Optional[str], Option(help="wandb project for the run", rich_help_panel="Weights & Biases")] = None, - wandb_entity: Annotated[Optional[str], Option(help="wandb entity for the project", rich_help_panel="Weights & Biases")] = None, - track_run: Annotated[bool, Option("--track-run", help="Track the eval run with wandb", rich_help_panel="Weights & Biases")] = False, - track_run_project: Annotated[Optional[str], Option(help="wandb project for tracking the run", rich_help_panel="Weights & Biases")] = None, - pooling_type: Annotated[Optional[str], Option(help="Pooling type for the classification head", show_default=False, rich_help_panel="Model Options")] = None, - head_class_act: Annotated[Optional[str], Option(help="Classification head activation function", show_default=False, rich_help_panel="Model Options")] = None, - head_class_norm: Annotated[Optional[str], Option(help="Classification head normalization function", show_default=False, rich_help_panel="Model Options")] = None, - head_class_dropout: Annotated[float, Option(help="Classification head dropout rate", rich_help_panel="Model Options")] = 0.0, - fast_ultrafeedback: Annotated[bool, Option("--fast-ultrafeedback", help="Use a shorter sequence length (1536) for the UltraFeedback eval", rich_help_panel="Task Settings")] = False, - seeds: Annotated[List[int], Option(help="List of seeds to use for the eval", rich_help_panel="Task Settings")] = [1618, 42, 6033, 3145], - verbose: Annotated[bool, Option("-v", "--verbose", help="Show detailed output from evaluation jobs", rich_help_panel="Config Options")] = False, - overwrite_existing_symlinks: Annotated[bool, Option("--override-existing-symlinks", help="Overwrite existing symlinks to point to latest checkpoint", rich_help_panel="Config Options")] = False, - parallel: Annotated[bool, Option("--parallel/--single", help="Run the evals in parallel on multiple GPUs or one GPU. Use `parallel` if passing to `config`. Only use if evaluating a single checkpoint on multiple GPUs.", rich_help_panel="Task Settings")] = False, - delete_eval_yamls: Annotated[bool, Option("--delete/--keep", help="Delete all evaluation YAML files after running the evals. Use `delete_eval_yamls` if passing to `config`", rich_help_panel="Config Options")] = False, - use_dir_names: Annotated[Optional[bool], Option("--use-dir-names", help="Use the folder names as the wandb run names. Defaults to true if multiple `checkpoints` are provided with one `train_config`", rich_help_panel="Config Options")] = None, - gpu_ids: Annotated[Optional[List[int]], Option(help="List of GPU IDs to use", rich_help_panel="GPU Options")] = None, - config: Annotated[Optional[Path], Option(callback=conf_callback, is_eager=True, help="Relative path to YAML config file for setting options. Passing CLI options will supersede config options.", case_sensitive=False, rich_help_panel="Options")] = None, -): # fmt: skip - """Run evaluations on model checkpoints.""" - - # Set the allowed_gpus global variable global allowed_gpus - if gpu_ids is not None: - allowed_gpus = gpu_ids - else: - allowed_gpus = None # Use all GPUs + allowed_gpus = gpu_ids if hub_repo: - print(f"\nDownloading files from {hub_repo}...") + print(f"\nDownloading from {hub_repo} to {checkpoints} ...") downloaded_files = download_hub_files( - repo_id=hub_repo, filenames=hub_files, output_dir=checkpoints, token=hub_token + repo_id=hub_repo, + filenames=hub_files, + output_dir=checkpoints, + token=hub_token ) if not downloaded_files: print("No files were downloaded successfully. Exiting.") raise Exit(code=1) print(f"Successfully downloaded {len(downloaded_files)} files to {checkpoints}") - # Set default tasks to all tasks if not provided - all_tasks = [task for task in TaskName] - tasks = tasks or all_tasks + if not tasks or len(tasks) == 0: + tasks = [t for t in TaskName] print("\nAsynchronously downloading required datasets...") msg_queue = Queue() download_process = Process(target=download_datasets, args=(tasks, msg_queue)) download_process.start() - print("\nCreating symlinks for latest checkpoints...") + print("\nCreating symlinks for newest checkpoints...") for folder in checkpoints.glob("*"): if folder.is_dir() and not folder.name.startswith("."): create_symlink_for_newest_checkpoint(folder, overwrite_existing_symlinks) if not skip_generation: print("\nGenerating evaluation configs...\n") - if not run_all_yamls: config_files_completed = list(checkpoints.glob("*_evaluation.yaml")) - print("Skipping Completed Jobs (delete yamls to run):") - for config in config_files_completed: - print(f" {config.name}\n") + print("Skipping Completed Jobs (delete yamls to re-run):") + for c in config_files_completed: + print(f" {c.name}\n") else: config_files_completed = [] @@ -721,38 +659,37 @@ def main( use_dir_names=use_dir_names, model_size=model_size, rope_theta=rope_theta, + gpu_ids=gpu_ids, ) - config_files = list(checkpoints.glob("*_evaluation.yaml")) - config_files = sorted(list(set(config_files) - set(config_files_completed))) + config_files = list(set(checkpoints.glob("*_evaluation.yaml")) - set(config_files_completed)) + config_files = sorted(config_files) else: config_files = list(checkpoints.glob("*_evaluation.yaml")) print("Jobs to be run:") - for config in config_files: - print(f" {config.name}\n") + for cfg in config_files: + print(f" {cfg.name}\n") - # Wait for the dataset download to complete print("Waiting for dataset downloads to complete...") download_process.join() print("\nDataset downloading complete.") while not msg_queue.empty(): print(msg_queue.get()) - if len(config_files) >= 1 and parallel is False: - manage_jobs(configs=config_files, verbose=verbose, delete_eval_yamls=delete_eval_yamls) - elif len(config_files) > 1 and parallel is True: - raise ValueError(f"{parallel=} is only supported for running one config at a time.") - elif len(config_files) == 1 and parallel is True: + if len(config_files) >= 1 and not parallel: + manage_jobs(config_files, verbose=verbose, delete_eval_yamls=delete_eval_yamls) + elif len(config_files) > 1 and parallel: + raise ValueError("Parallel runs only supported for a single config at a time.") + elif len(config_files) == 1 and parallel: if not verbose: - console.print(f"[bold green]Running {config_files[0].name} in parallel on GPUs {', '.join(map(str, gpu_ids))}") # fmt: skip + console.print(f"[bold green]Running {config_files[0].name} in parallel on GPUs: {gpu_ids}") run_job(config_files[0], verbose=verbose, delete_eval_yamls=delete_eval_yamls, gpu_ids=gpu_ids) else: - message = "No configuration files found in the specified directory." + msg = "No evaluation config (.yaml) files found." if verbose: - print(message) + print(msg) else: - console.print(f"[bold red]{message}") - + console.print(f"[bold red]{msg}") raise Exit(code=1) if verbose: @@ -761,7 +698,69 @@ def main( console.print("[bold green]All jobs completed.") -# Register the signal handler +@app.command() +def main( + checkpoints: Annotated[Path, Option("--checkpoints", help="Directory for model checkpoints.")], + train_config: Annotated[Optional[Path], Option(help="Path to .yaml config")] = None, + model_size: Annotated[ModelSize, Option("--model-size")] = ModelSize.BASE, + rope_theta: Annotated[Optional[float], Option("--rope-theta")] = None, + skip_generation: Annotated[bool, Option("--skip-generation")] = False, + run_all_yamls: Annotated[bool, Option("--run-all-yamls")] = False, + tasks: Annotated[Optional[List[TaskName]], Option(help="Tasks")] = None, + hub_repo: Annotated[Optional[str], Option("--hub-repo")] = None, + hub_files: Annotated[Optional[List[str]], Option("--hub-files")] = None, + hub_token: Annotated[Optional[str], Option("--hub-token")] = None, + wandb_run: Annotated[Optional[str], Option("--wandb-run")] = None, + wandb_project: Annotated[Optional[str], Option("--wandb-project")] = None, + wandb_entity: Annotated[Optional[str], Option("--wandb-entity")] = None, + track_run: Annotated[bool, Option("--track-run")] = False, + track_run_project: Annotated[Optional[str], Option("--track-run-project")] = None, + pooling_type: Annotated[Optional[str], Option("--pooling-type")] = None, + head_class_act: Annotated[Optional[str], Option("--head-class-act")] = None, + head_class_norm: Annotated[Optional[str], Option("--head-class-norm")] = None, + head_class_dropout: Annotated[float, Option("--head-class-dropout")] = 0.0, + fast_ultrafeedback: Annotated[bool, Option("--fast-ultrafeedback")] = False, + seeds: Annotated[List[int], Option("--seeds")] = [1618, 42, 6033, 3145], + verbose: Annotated[bool, Option("-v", "--verbose")] = False, + overwrite_existing_symlinks: Annotated[bool, Option("--override-existing-symlinks")] = False, + parallel: Annotated[bool, Option("--parallel/--single")] = False, + delete_eval_yamls: Annotated[bool, Option("--delete/--keep")] = False, + use_dir_names: Annotated[Optional[bool], Option("--use-dir-names")] = None, + gpu_ids: Annotated[Optional[List[int]], Option("--gpu-ids")] = None, + config: Annotated[Optional[Path], Option(callback=conf_callback, is_eager=True, help="YAML config file")] = None, +): + _main( + checkpoints=checkpoints, + train_config=train_config, + model_size=model_size, + rope_theta=rope_theta, + skip_generation=skip_generation, + run_all_yamls=run_all_yamls, + tasks=tasks, + hub_repo=hub_repo, + hub_files=hub_files, + hub_token=hub_token, + wandb_run=wandb_run, + wandb_project=wandb_project, + wandb_entity=wandb_entity, + track_run=track_run, + track_run_project=track_run_project, + pooling_type=pooling_type, + head_class_act=head_class_act, + head_class_norm=head_class_norm, + head_class_dropout=head_class_dropout, + fast_ultrafeedback=fast_ultrafeedback, + seeds=seeds, + verbose=verbose, + overwrite_existing_symlinks=overwrite_existing_symlinks, + parallel=parallel, + delete_eval_yamls=delete_eval_yamls, + use_dir_names=use_dir_names, + gpu_ids=gpu_ids, + config=config, + ) + + signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) @@ -769,12 +768,11 @@ def main( try: app() finally: - # Ensure all subprocesses are terminated when the script exits - for process in all_processes: - if process.poll() is None: - process.terminate() - for process in all_processes: + for p in all_processes: + if p.poll() is None: + p.terminate() + for p in all_processes: try: - process.wait(timeout=5) + p.wait(timeout=5) except subprocess.TimeoutExpired: - process.kill() + p.kill() From 6de99c06e4806b2d866aa6897239bcd365ef02e3 Mon Sep 17 00:00:00 2001 From: bclavie Date: Tue, 18 Mar 2025 01:28:50 +0000 Subject: [PATCH 2/8] code cleanup (redo) --- run_evals.py | 420 ++++++++++++++++++++++++++++++--------------------- 1 file changed, 252 insertions(+), 168 deletions(-) diff --git a/run_evals.py b/run_evals.py index 3621844a..df9935ed 100644 --- a/run_evals.py +++ b/run_evals.py @@ -30,9 +30,11 @@ warnings.simplefilter("ignore", category=FutureWarning) from eval import GLUE_TASKS, SUPERGLUE_TASKS, TASK_NAME_TO_CLASS + # Create TaskName enum dynamically from TASK_NAME_TO_CLASS keys TaskName = Enum("TaskName", {name: name for name in TASK_NAME_TO_CLASS.keys()}, type=str) + app = typer.Typer(context_settings={"help_option_names": ["-h", "--help"]}, pretty_exceptions_show_locals=False) @@ -55,10 +57,16 @@ def conf_callback(ctx: typer.Context, param: typer.CallbackParam, config: Option return config +# Global dictionary to keep track of GPUs with running jobs +# Changed to store more information per GPU gpus_in_use = {} +# Queue to keep track of GPUs that might be free potentially_free_gpus = deque() +# Global list to keep track of all running processes all_processes = [] -allowed_gpus = None + +# Global list to specify which GPUs to use +allowed_gpus = None # Will be set to list of GPU IDs or None console = Console() @@ -79,48 +87,49 @@ def kill_process_tree(pid: int): def signal_handler(signum, frame): print("\nReceived termination signal. Cleaning up subprocesses...") - for proc in all_processes: - if proc.poll() is None: - kill_process_tree(proc.pid) + for process in all_processes: + if process.poll() is None: # If the process is still running + kill_process_tree(process.pid) + print("Cleanup completed. Exiting.") - os._exit(0) + os._exit(0) # Force exit without running cleanup handlers -def get_gpu_memory_usage(gpu_id: int) -> Optional[int]: +def get_gpu_memory_usage(gpu_id): + """Get memory usage for a specific GPU.""" try: - out = ( + output = ( subprocess.check_output( - f"nvidia-smi --query-gpu=memory.used --format=csv,nounits,noheader -i {gpu_id}", - shell=True + f"nvidia-smi --query-gpu=memory.used --format=csv,nounits,noheader -i {gpu_id}", shell=True ) .decode("utf-8") .strip() ) - return int(out) + return int(output) except subprocess.CalledProcessError: print(f"Failed to get memory usage for GPU {gpu_id}") return None -def get_free_gpu() -> Optional[int]: +def get_free_gpu(): + """Check for free GPUs, prioritizing potentially free GPUs.""" global allowed_gpus while potentially_free_gpus: gpu_id = potentially_free_gpus.popleft() if (allowed_gpus is None or gpu_id in allowed_gpus) and gpu_id not in gpus_in_use: - used = get_gpu_memory_usage(gpu_id) - if used is not None and used < 100: + memory_used = get_gpu_memory_usage(gpu_id) + if memory_used is not None and memory_used < 100: return gpu_id + # If no potentially free GPUs, check allowed GPUs try: gpu_output = subprocess.check_output( "nvidia-smi --query-gpu=index,memory.used --format=csv,nounits,noheader", shell=True ).decode("utf-8") for line in gpu_output.strip().split("\n"): - g_id_str, mem_str = line.split(",") - g_id_int = int(g_id_str) - mem_used = int(mem_str) - if (allowed_gpus is None or g_id_int in allowed_gpus) and mem_used < 100 and g_id_int not in gpus_in_use: - return g_id_int + gpu_id, memory_used = map(int, line.split(",")) + if (allowed_gpus is None or gpu_id in allowed_gpus) and memory_used < 100 and gpu_id not in gpus_in_use: + return gpu_id return None except subprocess.CalledProcessError: print("Failed to execute nvidia-smi") @@ -129,37 +138,46 @@ def get_free_gpu() -> Optional[int]: def run_subprocess(cmd: List[str], verbose: bool = False, show_errors: bool = False): stdout = None if verbose else subprocess.DEVNULL - stderr = None if (verbose or show_errors) else subprocess.DEVNULL - proc = subprocess.Popen(cmd, stdout=stdout, stderr=stderr) - all_processes.append(proc) - proc.wait() + stderr = None if verbose or show_errors else subprocess.DEVNULL + process = subprocess.Popen(cmd, stdout=stdout, stderr=stderr) + all_processes.append(process) # Add the process to the global list + process.wait() def handle_process_completion(process, stderr_file, config_path: Path, verbose: bool, gpu_id: Optional[int] = None): - code = process.returncode + """Handles the completion of a process, checks for errors, cleans up stderr_file, and logs messages.""" + returncode = process.returncode + + # Read and clean up stderr output if stderr_file is not None: stderr_file.seek(0) - error_out = stderr_file.read() + error_output = stderr_file.read() stderr_file.close() - os.unlink(stderr_file.name) + os.unlink(stderr_file.name) # Delete the temp file else: - error_out = "Error output was displayed above." + error_output = "Error output was displayed above." - job_label = f"Job for {config_path.name}" if gpu_id is None else f"Job on GPU {gpu_id} for {config_path.name}" + # Construct job identifier + if gpu_id is not None: + job_identifier = f"Job on GPU {gpu_id} for {config_path.name}" + else: + job_identifier = f"Job for {config_path.name}" - if code != 0: + if returncode != 0: + # The process exited with an error if verbose: - print(f"{job_label} failed with return code {code}") + print(f"{job_identifier} failed with return code {returncode}") print("Error Output:") - print(error_out) + print(error_output) else: - console.print(f"[red]{job_label} failed with return code {code}[/red]") - console.print(f"[red]Error Output:[/red]\n{error_out}") + console.print(f"[red]{job_identifier} failed with return code {returncode}[/red]") + console.print(f"[red]Error Output:[/red]\n{error_output}") else: + # The process completed successfully if verbose: - print(f"{job_label} finished successfully.") + print(f"{job_identifier} has finished successfully.") else: - console.log(f"{job_label} has finished successfully.") + console.log(f"{job_identifier} has finished successfully.") def run_job( @@ -169,17 +187,19 @@ def run_job( gpu_id: Optional[int] = None, gpu_ids: Optional[List[int]] = None, ): + """Run a job with optional GPU management.""" if gpu_id is not None: + # GPU management is required env = os.environ.copy() env["CUDA_VISIBLE_DEVICES"] = str(gpu_id) elif gpu_ids is not None: env = os.environ.copy() env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_ids)) else: - env = None + env = None # Use default environment if verbose: - stdout = None + stdout = None # Output will be shown directly stderr = None stderr_file = None else: @@ -187,59 +207,69 @@ def run_job( stderr_file = tempfile.NamedTemporaryFile(mode="w+", delete=False) stderr = stderr_file - proc = subprocess.Popen(["python", "eval.py", str(config_path)], env=env, stdout=stdout, stderr=stderr) - all_processes.append(proc) + process = subprocess.Popen(["python", "eval.py", str(config_path)], env=env, stdout=stdout, stderr=stderr) + all_processes.append(process) # Add the process to the global list if gpu_id is not None: - gpus_in_use[gpu_id] = {"process": proc, "stderr_file": stderr_file, "config": config_path} + # Store process info for GPU management + gpus_in_use[gpu_id] = {"process": process, "stderr_file": stderr_file, "config": config_path} + else: - proc.wait() - handle_process_completion(proc, stderr_file, config_path, verbose, gpu_id=None) + process.wait() + handle_process_completion(process, stderr_file, config_path, verbose, gpu_id=None) if delete_eval_yamls: config_path.unlink() - return proc + return process def check_finished_jobs(verbose: bool = False): - done_gpus = [] + """Check for finished jobs and free up their GPUs.""" + finished_gpus = [] for gpu_id, info in gpus_in_use.items(): process = info["process"] stderr_file = info["stderr_file"] config = info["config"] - if process.poll() is not None: + if process.poll() is not None: # Job has finished + # Handle process completion handle_process_completion(process, stderr_file, config, verbose, gpu_id=gpu_id) - done_gpus.append(gpu_id) + finished_gpus.append(gpu_id) - for g in done_gpus: - del gpus_in_use[g] - potentially_free_gpus.append(g) + for gpu_id in finished_gpus: + del gpus_in_use[gpu_id] + potentially_free_gpus.append(gpu_id) def manage_jobs(configs: List[Path], verbose: bool = False, delete_eval_yamls: bool = True): + """Manage the launching of jobs for each configuration file in the directory.""" + if verbose: - for cfg in configs: + for config in configs: while True: check_finished_jobs(verbose) - free = get_free_gpu() - if free is not None: + gpu_id = get_free_gpu() + if gpu_id is not None: time.sleep(random.randint(0, 5)) - print(f"\nLaunching job for {cfg} on GPU {free}\n") - run_job(cfg, gpu_id=free, verbose=verbose, delete_eval_yamls=delete_eval_yamls) + print(f"\nLaunching job for {config} on GPU {gpu_id}\n") + run_job(config, gpu_id=gpu_id, verbose=verbose, delete_eval_yamls=delete_eval_yamls) break else: time.sleep(10) + + # Wait for all remaining jobs to finish while gpus_in_use: check_finished_jobs(verbose) time.sleep(10) else: + def update_progress_for_finished_jobs(): - for gpuid, info in list(gpus_in_use.items()): - prc = info["process"] - if prc.poll() is not None: - if gpuid in gpu_tasks: - gpu_progress.update(gpu_tasks[gpuid], completed=1, visible=False) + """Update progress bars for any finished GPU jobs.""" + for gpu_id, info in list(gpus_in_use.items()): + process = info["process"] + if process.poll() is not None: # Job finished + if gpu_id in gpu_tasks: + gpu_progress.update(gpu_tasks[gpu_id], completed=1, visible=False) completed_configs.add(info["config"]) overall_progress.update(overall_task, completed=len(completed_configs)) @@ -250,40 +280,41 @@ def update_progress_for_finished_jobs(): TextColumn("[progress.percentage]{task.completed}/{task.total}"), TimeElapsedColumn(), ) + gpu_progress = Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), TimeElapsedColumn() ) progress_group = Group( Panel(overall_progress, title="Overall Progress", border_style="blue", padding=(1, 1)), - Panel(gpu_progress, title="GPU Jobs", border_style="green", padding=(1, 1)) + Panel(gpu_progress, title="GPU Jobs", border_style="green", padding=(1, 1)), ) with Live(progress_group, console=console, refresh_per_second=4): overall_task = overall_progress.add_task("[cyan]Overall Progress", total=len(configs)) gpu_tasks = {} - completed_configs = set() + completed_configs = set() # Track completed configs - for cfg in configs: + for config in configs: while True: check_finished_jobs(verbose) update_progress_for_finished_jobs() - free = get_free_gpu() - if free is not None: + gpu_id = get_free_gpu() + if gpu_id is not None: time.sleep(random.randint(0, 5)) - if free not in gpu_tasks: - gpu_tasks[free] = gpu_progress.add_task(f"[green]GPU {free}", total=1) + if gpu_id not in gpu_tasks: + gpu_tasks[gpu_id] = gpu_progress.add_task(f"[green]GPU {gpu_id}", total=1) else: - gpu_progress.update(gpu_tasks[free], completed=1, visible=False) - gpu_tasks[free] = gpu_progress.add_task(f"[green]GPU {free}", total=1) - - gpu_progress.update(gpu_tasks[free], description=f"[green]GPU {free}: {cfg.name}") - run_job(cfg, gpu_id=free, verbose=verbose, delete_eval_yamls=delete_eval_yamls) + gpu_progress.update(gpu_tasks[gpu_id], completed=1, visible=False) + gpu_tasks[gpu_id] = gpu_progress.add_task(f"[green]GPU {gpu_id}", total=1) + gpu_progress.update(gpu_tasks[gpu_id], description=f"[green]GPU {gpu_id}: {config.name}") + run_job(config, gpu_id=gpu_id, verbose=verbose, delete_eval_yamls=delete_eval_yamls) break else: time.sleep(10) + # Wait for all remaining jobs to finish while gpus_in_use: check_finished_jobs(verbose) update_progress_for_finished_jobs() @@ -292,56 +323,70 @@ def update_progress_for_finished_jobs(): overall_progress.update(overall_task, completed=len(configs)) if delete_eval_yamls: - for c in configs: + for config in configs: try: - c.unlink() + config.unlink() except FileNotFoundError: pass def create_symlink_for_newest_checkpoint(folder: Path, override_existing: bool = False): - if not folder.is_dir(): - return - - pt_files = list(folder.glob("*.pt")) - if not pt_files: - print(f" Warning: No .pt file found in {folder}, skipping symlink creation.") - return - - if len(pt_files) == 1 and pt_files[0].name == "latest-rank0.pt" and not pt_files[0].is_symlink(): - print(f" Only found one .pt in {folder.name}, named 'latest-rank0.pt' (real file). Skipping symlink creation.") - return - - def extract_nums(fp: Path): - m = re.search(r"ep(\d+)-ba(\d+)", fp.stem) - if m: - ep, ba = map(int, m.groups()) - return (ep, ba) - return (0, 0) - - newest_file = max(pt_files, key=extract_nums) - symlink_path = folder / "latest-rank0.pt" - - if symlink_path.is_symlink(): - if symlink_path.resolve() == newest_file.resolve(): - print(f" Existing symlink in {folder.name} already points to {newest_file.name}") + """Create a symlink to the newest checkpoint file if 'latest-rank0.pt' does not exist.""" + if folder.is_dir(): + pt_files = list(folder.glob("*.pt")) + if not pt_files: + print(f" Warning: No .pt file found in {folder}, skipping symlink creation.") return - else: - print(f" Warning: symlink in {folder.name} points to {symlink_path.resolve().name}, but newest is {newest_file.name}") + + if len(pt_files) == 1 and pt_files[0].name == "latest-rank0.pt" and not pt_files[0].is_symlink(): + print(f" Only found one .pt in {folder.name}, named 'latest-rank0.pt' (real file). Skipping symlink creation.") + return + + # Sort files based on epoch and batch numbers extracted from filenames + def extract_numbers(filename: Path): + if filename.is_symlink(): + return (0, 0) + if filename.name == "latest-rank0.pt": + return (0, 0) + + try: + # Using regex to find patterns of 'ep' followed by digits and 'ba' followed by digits + match = re.search(r"ep(\d+)-ba(\d+)", filename.stem) + if match: + epoch, batch = map(int, match.groups()) + return (epoch, batch) + else: + raise ValueError(f"Filename does not match expected pattern: {filename}") + except Exception as e: + print(f" Error extracting numbers from filename {filename}: {e}") + return (0, 0) + + newest_file = max(pt_files, key=extract_numbers) + + symlink_path = folder / "latest-rank0.pt" + if symlink_path.is_symlink(): + if symlink_path.resolve() == newest_file.resolve(): + print(f" Existing symlink in {folder.name} already points to {newest_file.name}") + return + else: + print( + f" Warning: symlink in {folder.name} points to {symlink_path.resolve().name}, " + f"but newest is {newest_file.name}" + ) + if not override_existing: + return + symlink_path.unlink(missing_ok=True) + elif symlink_path.exists(): if not override_existing: + print(f" {symlink_path.name} is a real file in {folder.name}. Use override to remove it.") return symlink_path.unlink(missing_ok=True) - elif symlink_path.exists(): - if not override_existing: - print(f" {symlink_path.name} is a real file in {folder.name}. Use override to remove it.") - return - symlink_path.unlink(missing_ok=True) - symlink_path.symlink_to(newest_file.name) - if override_existing: - print(f" Overwrote symlink {symlink_path.name} -> {newest_file.name}") - else: - print(f" Created new symlink {symlink_path.name} -> {newest_file.name}") + symlink_path.symlink_to(newest_file.name) + if override_existing: + print(f" Overwrote symlink {symlink_path.name} -> {newest_file.name}") + else: + print(f" Created new symlink {symlink_path.name} -> {newest_file.name}") def generate_eval_configs( @@ -356,7 +401,7 @@ def generate_eval_configs( head_class_act: Optional[str], head_class_norm: Optional[str], head_class_dropout: float, - tasks: Optional[List[Union[TaskName, str]]], + tasks: Optional[List[Union[TaskName, str]]], # type: ignore fast_ultrafeedback: bool, seeds: List[int], parallel: bool, @@ -365,21 +410,30 @@ def generate_eval_configs( rope_theta: Optional[float], gpu_ids: Optional[List[int]] = None, ): + """Generate evaluation configs for each checkpoint.""" + folders = [ - f - for f in checkpoints.glob("*") - if f.is_dir() and not f.name.startswith(".") and any(x.suffix == ".pt" for x in f.glob("*.pt")) + folder + for folder in checkpoints.glob("*") + if folder.is_dir() + and not folder.name.startswith(".") + and any(file.suffix == ".pt" for file in folder.glob("*.pt")) ] if use_dir_names is None and len(folders) > 1: use_dir_names = True - print("Using folder names as run names since multiple checkpoint folders found.") + print("Using folder names as run names since multiple `checkpoints` were provided with one `train_config`.") for folder in folders: cmd = [ - "python", "generate_eval_config.py", - "--checkpoint", str(folder), - "--output-dir", str(checkpoints), + "python", + "generate_eval_config.py", + "--checkpoint", + str(folder), + "--output-dir", + str(checkpoints), ] + + # Add optional arguments if they're provided if use_dir_names: cmd.append("--use-dir-name") if model_size: @@ -399,6 +453,7 @@ def generate_eval_configs( if track_run_project: cmd.extend(["--track-run-project", track_run_project]) + # Classification head options if pooling_type: cmd.extend(["--pooling-type", pooling_type]) if head_class_act: @@ -408,7 +463,7 @@ def generate_eval_configs( if head_class_dropout > 0: cmd.extend(["--head-class-dropout", str(head_class_dropout)]) - # Handle tasks as either TaskName or str + # Add tasks if tasks: for task in tasks: if hasattr(task, "value"): @@ -424,13 +479,13 @@ def generate_eval_configs( if parallel: cmd.append("--parallel") - else: - cmd.append("--single") if gpu_ids: if isinstance(gpu_ids, int): gpu_ids = [gpu_ids] for g in gpu_ids: cmd.extend(["--gpu-ids", str(g)]) + # Run the config generation process without suppressing output + run_subprocess(cmd, show_errors=True) if not train_config: time.sleep(1) @@ -444,7 +499,7 @@ def download_dataset(dataset_name: str, subset: Optional[str] = None): return f"Error in processing {dataset_name}: {e}" -def download_datasets(tasks: List[Union[TaskName, str]], msg_queue): +def download_datasets(tasks: List[Union[TaskName, str]], msg_queue): # type: ignore try: required_datasets = [] task_to_datasets = { @@ -467,32 +522,42 @@ def download_datasets(tasks: List[Union[TaskName, str]], msg_queue): extras = task_to_datasets.get(task_val, []) required_datasets.extend(extras) + # Suppress output globally in this process import sys + sys.stdout = open(os.devnull, "w") sys.stderr = open(os.devnull, "w") msgs = [] - for ds_name, subset in required_datasets: - datasets.load_dataset(ds_name, subset, trust_remote_code=True) - msgs.append(f"Successfully downloaded {ds_name} {subset}") - msg_queue.put("\n ".join([""] + msgs)) + for dataset_name, subset in required_datasets: + datasets.load_dataset(dataset_name, subset, trust_remote_code=True) + msgs.append(f"Successfully downloaded {dataset_name} {subset}") + msg_queue.put(" " + "\n ".join(msgs) + "\n") except Exception as e: msg_queue.put(f"Error in downloading datasets: {e}") def find_checkpoint_file(file_path: str, repo_files: List[str]) -> Optional[str]: import re - valid = [f for f in repo_files if f.startswith(file_path) and f.endswith((".pt", ".yaml"))] - if len(valid) == 1: - return valid[0] - def extract_nums(fn: str): - m = re.search(r"ep(\d+)-ba(\d+)", fn) - if m: - return tuple(map(int, m.groups())) - return (-1, -1) + # Filter files in the specified file_path that end with .pt or .yaml + valid_files = [file for file in repo_files if file.startswith(file_path) and file.endswith((".pt", ".yaml"))] + + if len(valid_files) == 1: + return valid_files[0] + + # Function to extract epoch and batch numbers from the filename + def extract_numbers(filename: str): + match = re.search(r"ep(\d+)-ba(\d+)", filename) + if match: + epoch, batch = map(int, match.groups()) + return epoch, batch + return -1, -1 # Return a default value for files that don't match the pattern - return max(valid, key=extract_nums, default=None) + # Find the newest file based on epoch and batch numbers + newest_file = max(valid_files, key=extract_numbers, default=None) + + return newest_file def download_hub_files( @@ -502,29 +567,41 @@ def download_hub_files( repo_type: str = "model", token: Optional[str] = None, ) -> List[Path]: + """Download specific files or the entire repository from a Hugging Face Hub repository.""" output_dir.mkdir(parents=True, exist_ok=True) - downloaded_files: List[Path] = [] + downloaded_files = [] def move_and_flatten_files(local_dir: Path): - for fp in local_dir.rglob("*"): - if fp.is_file() and fp.name.endswith((".pt", ".yaml")): - target_dir = output_dir / fp.parent.name - if fp.parent.resolve() in [target_dir.resolve(), output_dir.resolve()]: - downloaded_files.append(fp) + for file_path in local_dir.rglob("*"): + if file_path.is_file() and file_path.name.endswith((".pt", ".yaml")): + # Determine the target directory + target_dir = output_dir / file_path.parent.name + + # Check if the file is already in the correct location + if file_path.parent.resolve() in [target_dir.resolve(), output_dir.resolve()]: + downloaded_files.append(file_path) continue + + # Create the target directory if it doesn't exist target_dir.mkdir(parents=True, exist_ok=True) - new_path = target_dir / fp.name - fp.rename(new_path) + # Move the file to the target directory + new_path = target_dir / file_path.name + file_path.rename(new_path) downloaded_files.append(new_path) + # List all files in the repository api = HfApi() repo_files = api.list_repo_files(repo_id=repo_id, repo_type=repo_type, token=token) + try: if not filenames: + # Check if files already exist before downloading entire repository existing_files = list(output_dir.glob("**/*.pt")) + list(output_dir.glob("**/*.yaml")) if existing_files: - print(f"Found existing files in '{output_dir}', skipping snapshot_download.") + print(f"Found existing files in '{output_dir}', skipping download.") return existing_files + + # Download the entire repository local_dir = snapshot_download( repo_id=repo_id, repo_type=repo_type, @@ -534,22 +611,25 @@ def move_and_flatten_files(local_dir: Path): use_auth_token=token, ) move_and_flatten_files(Path(local_dir)) - print(f"Successfully downloaded entire repo '{repo_id}' -> '{output_dir}'.") + print(f"Successfully downloaded and flattened the repository '{repo_id}' to '{output_dir}'.") else: - for fn in filenames: - resolved_filename = find_checkpoint_file(fn, repo_files) + for filename in filenames: + resolved_filename = find_checkpoint_file(filename, repo_files) if not resolved_filename: - print(f"Warning: No match for '{fn}' in {repo_id}.") + print(f"Warning: Could not find matching file for '{filename}' in repository.") continue - just_name = Path(resolved_filename).name + + # Check if file exists in output_dir or any immediate subdirectory + filename = Path(resolved_filename).name parent_dir = Path(resolved_filename).parent.name - existing_fs = list(output_dir.glob(f"**/{parent_dir}/{just_name}")) - if existing_fs: - existing_file = existing_fs[0] - print(f"File '{parent_dir}/{just_name}' already exists at '{existing_file}', skip download.") + existing_files = list(output_dir.glob(f"**/{parent_dir}/{filename}")) + if existing_files: + existing_file = existing_files[0] + print(f"File '{parent_dir}/{filename}' already exists at '{existing_file}', skipping download.") downloaded_files.append(existing_file) continue + # Download the file _ = hf_hub_download( repo_id=repo_id, filename=resolved_filename, @@ -558,7 +638,7 @@ def move_and_flatten_files(local_dir: Path): local_dir=output_dir, cache_dir=None, ) - print(f"Downloaded '{resolved_filename}' from '{repo_id}'.") + print(f"Successfully downloaded '{resolved_filename}' from '{repo_id}'.") move_and_flatten_files(output_dir) except Exception as e: print(f"Error downloading from '{repo_id}': {e}") @@ -632,11 +712,12 @@ def _main( if not skip_generation: print("\nGenerating evaluation configs...\n") + if not run_all_yamls: config_files_completed = list(checkpoints.glob("*_evaluation.yaml")) - print("Skipping Completed Jobs (delete yamls to re-run):") - for c in config_files_completed: - print(f" {c.name}\n") + print("Skipping Completed Jobs (delete yamls to run):") + for config in config_files_completed: + print(f" {config.name}\n") else: config_files_completed = [] @@ -667,9 +748,10 @@ def _main( config_files = list(checkpoints.glob("*_evaluation.yaml")) print("Jobs to be run:") - for cfg in config_files: - print(f" {cfg.name}\n") + for config in config_files: + print(f" {config.name}\n") + # Wait for the dataset download to complete print("Waiting for dataset downloads to complete...") download_process.join() print("\nDataset downloading complete.") @@ -685,11 +767,12 @@ def _main( console.print(f"[bold green]Running {config_files[0].name} in parallel on GPUs: {gpu_ids}") run_job(config_files[0], verbose=verbose, delete_eval_yamls=delete_eval_yamls, gpu_ids=gpu_ids) else: - msg = "No evaluation config (.yaml) files found." + message = "No configuration files found in the specified directory." if verbose: - print(msg) + print(message) else: - console.print(f"[bold red]{msg}") + console.print(f"[bold red]{message}") + raise Exit(code=1) if verbose: @@ -723,7 +806,7 @@ def main( seeds: Annotated[List[int], Option("--seeds")] = [1618, 42, 6033, 3145], verbose: Annotated[bool, Option("-v", "--verbose")] = False, overwrite_existing_symlinks: Annotated[bool, Option("--override-existing-symlinks")] = False, - parallel: Annotated[bool, Option("--parallel/--single")] = False, + parallel: Annotated[bool, Option("--parallel")] = False, delete_eval_yamls: Annotated[bool, Option("--delete/--keep")] = False, use_dir_names: Annotated[Optional[bool], Option("--use-dir-names")] = None, gpu_ids: Annotated[Optional[List[int]], Option("--gpu-ids")] = None, @@ -768,11 +851,12 @@ def main( try: app() finally: - for p in all_processes: - if p.poll() is None: - p.terminate() - for p in all_processes: + # Ensure all subprocesses are terminated when the script exits + for process in all_processes: + if process.poll() is None: + process.terminate() + for process in all_processes: try: - p.wait(timeout=5) + process.wait(timeout=5) except subprocess.TimeoutExpired: - p.kill() + process.kill() From 9b3e6457d6dc8bba1b933fa2b7b5e6096474be0f Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Tue, 25 Mar 2025 12:29:30 -0500 Subject: [PATCH 3/8] fix run_evals --- run_evals.py | 162 ++++++++++++++++----------------------------------- 1 file changed, 51 insertions(+), 111 deletions(-) diff --git a/run_evals.py b/run_evals.py index df9935ed..421f0772 100644 --- a/run_evals.py +++ b/run_evals.py @@ -13,7 +13,7 @@ from enum import Enum from multiprocessing import Process, Queue from pathlib import Path -from typing import Annotated, List, Optional, Union +from typing import Annotated, List, Optional import datasets import psutil @@ -43,6 +43,7 @@ class ModelSize(str, Enum): LARGE = "large" HUGE = "huge" + # from maxb2: https://github.com/tiangolo/typer/issues/86#issuecomment-996374166 def conf_callback(ctx: typer.Context, param: typer.CallbackParam, config: Optional[str] = None): if config is not None: @@ -70,6 +71,7 @@ def conf_callback(ctx: typer.Context, param: typer.CallbackParam, config: Option console = Console() + def kill_process_tree(pid: int): try: parent = psutil.Process(pid) @@ -339,7 +341,7 @@ def create_symlink_for_newest_checkpoint(folder: Path, override_existing: bool = return if len(pt_files) == 1 and pt_files[0].name == "latest-rank0.pt" and not pt_files[0].is_symlink(): - print(f" Only found one .pt in {folder.name}, named 'latest-rank0.pt' (real file). Skipping symlink creation.") + print(f" Only found one .pt in {folder.name}, named 'latest-rank0.pt' (real file). Skipping symlink creation.") # fmt: skip return # Sort files based on epoch and batch numbers extracted from filenames @@ -401,7 +403,7 @@ def generate_eval_configs( head_class_act: Optional[str], head_class_norm: Optional[str], head_class_dropout: float, - tasks: Optional[List[Union[TaskName, str]]], # type: ignore + tasks: Optional[List[TaskName]], # type: ignore fast_ultrafeedback: bool, seeds: List[int], parallel: bool, @@ -481,8 +483,10 @@ def generate_eval_configs( cmd.append("--parallel") if gpu_ids: - if isinstance(gpu_ids, int): gpu_ids = [gpu_ids] - for g in gpu_ids: cmd.extend(["--gpu-ids", str(g)]) + if isinstance(gpu_ids, int): + gpu_ids = [gpu_ids] + for g in gpu_ids: + cmd.extend(["--gpu-ids", str(g)]) # Run the config generation process without suppressing output @@ -499,7 +503,7 @@ def download_dataset(dataset_name: str, subset: Optional[str] = None): return f"Error in processing {dataset_name}: {e}" -def download_datasets(tasks: List[Union[TaskName, str]], msg_queue): # type: ignore +def download_datasets(tasks: List[TaskName], msg_queue): # type: ignore try: required_datasets = [] task_to_datasets = { @@ -508,19 +512,15 @@ def download_datasets(tasks: List[Union[TaskName, str]], msg_queue): # type: ig "eurlex": [["coastalcph/lex_glue", "eurlex"]], "ultrafeedback": [["rbiswasfc/ultrafeedback-binary-classification"]], } - for t in tasks: - if hasattr(t, "value"): - task_val = t.value - else: - task_val = str(t) - if task_val in GLUE_TASKS: - required_datasets.append(["glue", task_val]) - elif task_val in SUPERGLUE_TASKS: - required_datasets.append(["aps/super_glue", task_val]) + for task in tasks: + if task.value in GLUE_TASKS: + datasets_info = [["glue", task.value]] + elif task.value in SUPERGLUE_TASKS: + datasets_info = [["aps/super_glue", task.value]] else: - extras = task_to_datasets.get(task_val, []) - required_datasets.extend(extras) + datasets_info = task_to_datasets.get(task.value, []) + required_datasets.extend(datasets_info) # Suppress output globally in this process import sys @@ -646,36 +646,38 @@ def move_and_flatten_files(local_dir: Path): return downloaded_files -def _main( - checkpoints: Union[str, Path], - train_config: Optional[Union[str, Path]] = None, - model_size: ModelSize = ModelSize.BASE, - rope_theta: Optional[float] = None, - skip_generation: bool = False, - run_all_yamls: bool = False, - tasks: Optional[List[Union[str, TaskName]]] = None, - hub_repo: Optional[str] = None, - hub_files: Optional[List[str]] = None, - hub_token: Optional[str] = None, - wandb_run: Optional[str] = None, - wandb_project: Optional[str] = None, - wandb_entity: Optional[str] = None, - track_run: bool = False, - track_run_project: Optional[str] = None, - pooling_type: Optional[str] = None, - head_class_act: Optional[str] = None, - head_class_norm: Optional[str] = None, - head_class_dropout: float = 0.0, - fast_ultrafeedback: bool = False, - seeds: List[int] = [1618, 42, 6033, 3145], - verbose: bool = False, - overwrite_existing_symlinks: bool = False, - parallel: bool = False, - delete_eval_yamls: bool = False, - use_dir_names: Optional[bool] = None, - gpu_ids: Optional[List[int]] = None, - config: Optional[Union[str, Path]] = None, -): +@app.command() +def main( + checkpoints: Annotated[Path, Option(help="Path to the directory containing FlexBert checkpoints or location to download checkpoints from Hugging Face Hub to", rich_help_panel="Checkpoint & Config Paths", show_default=False)], + train_config: Annotated[Optional[Path], Option(help="Path to a .yaml file containing training configuration. If one is not provided, will attempt to load the config from a wandb run or use defaults.", rich_help_panel="Checkpoint & Config Paths")] = None, + model_size: Annotated[ModelSize, Option("--model-size", help="Model to use for default model config", rich_help_panel="Checkpoint & Config Paths")] = ModelSize.BASE, + rope_theta: Annotated[Optional[float], Option("--rope-theta", help="Value for `rotary_emb_base` in the model configuration. If not provided, defaults to pretraining value of 10000.0", rich_help_panel="Checkpoint & Config Paths")] = None, + skip_generation: Annotated[bool, Option("--skip-generation", help="Skip generation of evaluation configs. If not true, assumes all existing eval yamls have been already ran.", rich_help_panel="Checkpoint & Config Paths")] = False, + run_all_yamls: Annotated[bool, Option("--run-all-yamls", help="Run all evaluation yamls in the `checkpoints` directory, even if some have already been run.", rich_help_panel="Checkpoint & Config Paths")] = False, + tasks: Annotated[Optional[List[TaskName]], Option(help="List of tasks to include in the evaluation. Default is all tasks.", rich_help_panel="Eval Tasks", case_sensitive=False, show_default=False)] = None, # type: ignore + hub_repo: Annotated[Optional[str], Option(help="Hugging Face Hub repository ID to download FlexBert weights. Downloads to `checkpoints` directory.", rich_help_panel="Hugging Face Download")] = None, + hub_files: Annotated[Optional[List[str]], Option(help="List of files to download from the `hub_repo`. If not provided, will download all files in the repo.", rich_help_panel="Hugging Face Download")] = None, + hub_token: Annotated[Optional[str], Option(help="Authentication token for private Hugging Face Hub repositories if not already logged in via `huggingface-cli login`", rich_help_panel="Hugging Face Download")] = None, + wandb_run: Annotated[Optional[str], Option(help="wandb run containing the training configuration", rich_help_panel="Weights & Biases")] = None, + wandb_project: Annotated[Optional[str], Option(help="wandb project for the run", rich_help_panel="Weights & Biases")] = None, + wandb_entity: Annotated[Optional[str], Option(help="wandb entity for the project", rich_help_panel="Weights & Biases")] = None, + track_run: Annotated[bool, Option("--track-run", help="Track the eval run with wandb", rich_help_panel="Weights & Biases")] = False, + track_run_project: Annotated[Optional[str], Option(help="wandb project for tracking the run", rich_help_panel="Weights & Biases")] = None, + pooling_type: Annotated[Optional[str], Option(help="Pooling type for the classification head", show_default=False, rich_help_panel="Model Options")] = None, + head_class_act: Annotated[Optional[str], Option(help="Classification head activation function", show_default=False, rich_help_panel="Model Options")] = None, + head_class_norm: Annotated[Optional[str], Option(help="Classification head normalization function", show_default=False, rich_help_panel="Model Options")] = None, + head_class_dropout: Annotated[float, Option(help="Classification head dropout rate", rich_help_panel="Model Options")] = 0.0, + fast_ultrafeedback: Annotated[bool, Option("--fast-ultrafeedback", help="Use a shorter sequence length (1536) for the UltraFeedback eval", rich_help_panel="Task Settings")] = False, + seeds: Annotated[List[int], Option(help="List of seeds to use for the eval", rich_help_panel="Task Settings")] = [1618, 42, 6033, 3145], + verbose: Annotated[bool, Option("-v", "--verbose", help="Show detailed output from evaluation jobs", rich_help_panel="Config Options")] = False, + overwrite_existing_symlinks: Annotated[bool, Option("--override-existing-symlinks", help="Overwrite existing symlinks to point to latest checkpoint", rich_help_panel="Config Options")] = False, + parallel: Annotated[bool, Option("--parallel/--single", help="Run the evals in parallel on multiple GPUs or one GPU. Use `parallel` if passing to `config`. Only use if evaluating a single checkpoint on multiple GPUs.", rich_help_panel="Task Settings")] = False, + delete_eval_yamls: Annotated[bool, Option("--delete/--keep", help="Delete all evaluation YAML files after running the evals. Use `delete_eval_yamls` if passing to `config`", rich_help_panel="Config Options")] = False, + use_dir_names: Annotated[Optional[bool], Option("--use-dir-names", help="Use the folder names as the wandb run names. Defaults to true if multiple `checkpoints` are provided with one `train_config`", rich_help_panel="Config Options")] = None, + gpu_ids: Annotated[Optional[List[int]], Option(help="List of GPU IDs to use", rich_help_panel="GPU Options")] = None, + config: Annotated[Optional[Path], Option(callback=conf_callback, is_eager=True, help="Relative path to YAML config file for setting options. Passing CLI options will supersede config options.", case_sensitive=False, rich_help_panel="Options")] = None, +): # fmt: skip + """Run evaluations on model checkpoints.""" if isinstance(checkpoints, str): checkpoints = Path(checkpoints) if isinstance(train_config, str): @@ -690,7 +692,7 @@ def _main( repo_id=hub_repo, filenames=hub_files, output_dir=checkpoints, - token=hub_token + token=hub_token, ) if not downloaded_files: print("No files were downloaded successfully. Exiting.") @@ -781,69 +783,7 @@ def _main( console.print("[bold green]All jobs completed.") -@app.command() -def main( - checkpoints: Annotated[Path, Option("--checkpoints", help="Directory for model checkpoints.")], - train_config: Annotated[Optional[Path], Option(help="Path to .yaml config")] = None, - model_size: Annotated[ModelSize, Option("--model-size")] = ModelSize.BASE, - rope_theta: Annotated[Optional[float], Option("--rope-theta")] = None, - skip_generation: Annotated[bool, Option("--skip-generation")] = False, - run_all_yamls: Annotated[bool, Option("--run-all-yamls")] = False, - tasks: Annotated[Optional[List[TaskName]], Option(help="Tasks")] = None, - hub_repo: Annotated[Optional[str], Option("--hub-repo")] = None, - hub_files: Annotated[Optional[List[str]], Option("--hub-files")] = None, - hub_token: Annotated[Optional[str], Option("--hub-token")] = None, - wandb_run: Annotated[Optional[str], Option("--wandb-run")] = None, - wandb_project: Annotated[Optional[str], Option("--wandb-project")] = None, - wandb_entity: Annotated[Optional[str], Option("--wandb-entity")] = None, - track_run: Annotated[bool, Option("--track-run")] = False, - track_run_project: Annotated[Optional[str], Option("--track-run-project")] = None, - pooling_type: Annotated[Optional[str], Option("--pooling-type")] = None, - head_class_act: Annotated[Optional[str], Option("--head-class-act")] = None, - head_class_norm: Annotated[Optional[str], Option("--head-class-norm")] = None, - head_class_dropout: Annotated[float, Option("--head-class-dropout")] = 0.0, - fast_ultrafeedback: Annotated[bool, Option("--fast-ultrafeedback")] = False, - seeds: Annotated[List[int], Option("--seeds")] = [1618, 42, 6033, 3145], - verbose: Annotated[bool, Option("-v", "--verbose")] = False, - overwrite_existing_symlinks: Annotated[bool, Option("--override-existing-symlinks")] = False, - parallel: Annotated[bool, Option("--parallel")] = False, - delete_eval_yamls: Annotated[bool, Option("--delete/--keep")] = False, - use_dir_names: Annotated[Optional[bool], Option("--use-dir-names")] = None, - gpu_ids: Annotated[Optional[List[int]], Option("--gpu-ids")] = None, - config: Annotated[Optional[Path], Option(callback=conf_callback, is_eager=True, help="YAML config file")] = None, -): - _main( - checkpoints=checkpoints, - train_config=train_config, - model_size=model_size, - rope_theta=rope_theta, - skip_generation=skip_generation, - run_all_yamls=run_all_yamls, - tasks=tasks, - hub_repo=hub_repo, - hub_files=hub_files, - hub_token=hub_token, - wandb_run=wandb_run, - wandb_project=wandb_project, - wandb_entity=wandb_entity, - track_run=track_run, - track_run_project=track_run_project, - pooling_type=pooling_type, - head_class_act=head_class_act, - head_class_norm=head_class_norm, - head_class_dropout=head_class_dropout, - fast_ultrafeedback=fast_ultrafeedback, - seeds=seeds, - verbose=verbose, - overwrite_existing_symlinks=overwrite_existing_symlinks, - parallel=parallel, - delete_eval_yamls=delete_eval_yamls, - use_dir_names=use_dir_names, - gpu_ids=gpu_ids, - config=config, - ) - - +# Register the signal handler signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) From 9f80f95d99c21caec8b314012b13d89e875b3e66 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Tue, 25 Mar 2025 14:19:33 -0500 Subject: [PATCH 4/8] convert to typer cli app with optional config file --- checkpoint_eval_monitor.py | 145 ++++++++++++++++++++++--------------- 1 file changed, 87 insertions(+), 58 deletions(-) diff --git a/checkpoint_eval_monitor.py b/checkpoint_eval_monitor.py index 8e7e03a2..b3046e66 100644 --- a/checkpoint_eval_monitor.py +++ b/checkpoint_eval_monitor.py @@ -1,49 +1,41 @@ -import os -import time import json -import sys import logging -import argparse +import os +import sys +import time from pathlib import Path -from typing import Set +from typing import Annotated, List, Optional, Set + +import typer +import yaml from huggingface_hub import HfApi, list_repo_files -from run_evals import _main as eval_main +from typer import Option + +from run_evals import main as eval_main +from run_evals import TaskName logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", - handlers=[logging.StreamHandler(sys.stdout)] + handlers=[logging.StreamHandler(sys.stdout)], ) logger = logging.getLogger("poller") +app = typer.Typer(context_settings={"help_option_names": ["-h", "--help"]}, pretty_exceptions_show_locals=False) -def parse_args(): - parser = argparse.ArgumentParser( - description="Poll a Hugging Face repo for new .pt checkpoints (with 'rank' in filename); call run_evals." - ) - parser.add_argument("--repo_id", type=str, default="PLACEHOLDER", - help="Hugging Face repo ID to monitor for new checkpoints") - parser.add_argument("--token", type=str, default=None, - help="Optional HF API token for private repos") - parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints", - help="Local directory to store or download checkpoints; " - "this is passed to run_evals._main(..., checkpoints=...)") - parser.add_argument("--poll_interval", type=int, default=60, - help="How many seconds to wait between polls") - - parser.add_argument("--wandb_project", type=str, default=None, - help="Optional W&B project to pass to _main") - parser.add_argument("--wandb_entity", type=str, default=None, - help="Optional W&B entity to pass to _main") - parser.add_argument("--tasks", nargs="+", default=["mnli"], - help="Which tasks to evaluate. Will pass as a list of strings to _main.") - parser.add_argument("--seeds", nargs="+", type=int, default=[42, 314, 1234], - help="Random seeds to pass to _main") - parser.add_argument("--gpu_ids", nargs="+", type=int, default=[3,4,5], - help="Optional list of GPU IDs to use for evaluation") - parser.add_argument("--skip_generation", action="store_true", - help="If set, pass skip_generation=True to _main") - return parser.parse_args() + +# from maxb2: https://github.com/tiangolo/typer/issues/86#issuecomment-996374166 +def conf_callback(ctx: typer.Context, param: typer.CallbackParam, config: Optional[str] = None): + if config is not None: + typer.echo(f"Loading config file: {config}\n") + try: + with open(config, "r") as f: # Load config file + conf = yaml.safe_load(f) + ctx.default_map = ctx.default_map or {} # Initialize the default map + ctx.default_map.update(conf) # Merge the config dict into default_map + except Exception as ex: + raise typer.BadParameter(str(ex)) + return config def load_processed(file_path: str = "processed_checkpoints.json") -> Set[str]: @@ -82,24 +74,35 @@ def find_new_checkpoints(files_in_repo: list[str], processed: Set[str]) -> Set[s return new_ckpts -def poll_loop(args): +def poll_loop( + repo_id: str, + token: Optional[str], + checkpoint_dir: str, + poll_interval: int, + wandb_project: Optional[str], + wandb_entity: Optional[str], + tasks: List[str], + seeds: List[int], + gpu_ids: List[int], + skip_generation: bool, +): """ - Main polling loop: + Main polling loop: - check the HF repo for new .pt files - pass them to run_evals.programmatic_main - record them in JSON - sleep """ - hf_api = HfApi(token=args.token) + hf_api = HfApi(token=token) processed = load_processed() - logger.info(f"Starting poller for {args.repo_id}") - logger.info(f"Polling every {args.poll_interval} seconds.\n") + logger.info(f"Starting poller for {repo_id}") + logger.info(f"Polling every {poll_interval} seconds.\n") while True: try: - logger.info(f"Checking for new checkpoints in {args.repo_id}...") - repo_files = list_repo_files(args.repo_id, token=args.token) + logger.info(f"Checking for new checkpoints in {repo_id}...") + repo_files = list_repo_files(repo_id, token=token) new_ckpts = find_new_checkpoints(repo_files, processed) if not new_ckpts: @@ -111,16 +114,16 @@ def poll_loop(args): try: eval_main( - checkpoints=args.checkpoint_dir, - hub_repo=args.repo_id, + checkpoints=checkpoint_dir, + hub_repo=repo_id, hub_files=[ckpt], - hub_token=args.token, - wandb_project=args.wandb_project, - wandb_entity=args.wandb_entity, - tasks=args.tasks, - seeds=args.seeds, - skip_generation=args.skip_generation, - gpu_ids=args.gpu_ids, + hub_token=token, + wandb_project=wandb_project, + wandb_entity=wandb_entity, + tasks=tasks, + seeds=seeds, + skip_generation=skip_generation, + gpu_ids=gpu_ids, verbose=True, parallel=True, ) @@ -133,14 +136,40 @@ def poll_loop(args): except Exception as e: logger.error(f"Error in poll loop: {e}", exc_info=True) - logger.info(f"Sleeping {args.poll_interval} seconds...\n") - time.sleep(args.poll_interval) - - -def main(): - args = parse_args() - poll_loop(args) + logger.info(f"Sleeping {poll_interval} seconds...\n") + time.sleep(poll_interval) + + +@app.command() +def main( + repo_id: Annotated[str, Option(help="Hugging Face repo ID to monitor for new checkpoints", show_default=False)], + token: Annotated[Optional[str], Option(help="Optional HF API token for private repos")] = None, + checkpoint_dir: Annotated[Path, Option(help="Local directory to store or download checkpoints")] = "./checkpoints", + poll_interval: Annotated[int, Option(help="How many seconds to wait between polls")] = 60, + wandb_project: Annotated[Optional[str], Option(help="Optional W&B project to pass to eval script")] = None, + wandb_entity: Annotated[Optional[str], Option(help="Optional W&B entity to pass to eval script")] = None, + tasks: Annotated[List[TaskName], Option(help="Which tasks to evaluate")] = [TaskName.mnli], # type: ignore + seeds: Annotated[List[int], Option(help="Random seeds to pass to _main")] = [42, 314, 1234], + gpu_ids: Annotated[Optional[List[int]], Option(help="Optional list of GPU IDs to use for evaluation")] = None, + skip_generation: Annotated[bool, Option(help="If set, pass skip_generation=True to eval script")] = False, + config: Annotated[Optional[Path], Option(callback=conf_callback, is_eager=True, help="Relative path to YAML config file for setting options. Passing CLI options will supersede config options.", case_sensitive=False, rich_help_panel="Options")] = None, +): # fmt: skip + """ + Poll a Hugging Face repo for new .pt checkpoints (with 'rank' in filename); call run_evals. + """ + poll_loop( + repo_id=repo_id, + token=token, + checkpoint_dir=checkpoint_dir, + poll_interval=poll_interval, + wandb_project=wandb_project, + wandb_entity=wandb_entity, + tasks=tasks, + seeds=seeds, + gpu_ids=gpu_ids, + skip_generation=skip_generation, + ) if __name__ == "__main__": - main() + app() From edefd055896da0b15a3c70e912ae2ab853e83c8a Mon Sep 17 00:00:00 2001 From: bclavie Date: Fri, 16 May 2025 00:47:30 +0000 Subject: [PATCH 5/8] take in train config in an easier way --- checkpoint_eval_monitor.py | 30 ++++++++++++++++++++++++++++++ run_evals.py | 2 +- 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/checkpoint_eval_monitor.py b/checkpoint_eval_monitor.py index b3046e66..e91f8f72 100644 --- a/checkpoint_eval_monitor.py +++ b/checkpoint_eval_monitor.py @@ -1,3 +1,27 @@ +# Filter out specific FutureWarnings from flash_attn +import warnings +import re + +# Define the warning patterns to filter +warning_patterns = [ + r"torch\.cuda\.amp\.custom_fwd.*is deprecated", + r"torch\.cuda\.amp\.custom_bwd.*is deprecated" +] + +# Create a filter function +def filter_flash_attn_warnings(message, category, filename, lineno, file=None, line=None): + # Check if it's a FutureWarning from flash_attn + if category == FutureWarning and "flash_attn" in filename: + # Check if the message matches any of our patterns + for pattern in warning_patterns: + if re.search(pattern, str(message)): + return None # Suppress the warning + # Return anything else + return True # Show other warnings + +# Apply the filter +warnings.filterwarnings("ignore", category=FutureWarning, module="flash_attn") + import json import logging import os @@ -85,6 +109,7 @@ def poll_loop( seeds: List[int], gpu_ids: List[int], skip_generation: bool, + train_config: Optional[Path], ): """ Main polling loop: @@ -124,6 +149,7 @@ def poll_loop( seeds=seeds, skip_generation=skip_generation, gpu_ids=gpu_ids, + train_config=train_config, verbose=True, parallel=True, ) @@ -146,12 +172,14 @@ def main( token: Annotated[Optional[str], Option(help="Optional HF API token for private repos")] = None, checkpoint_dir: Annotated[Path, Option(help="Local directory to store or download checkpoints")] = "./checkpoints", poll_interval: Annotated[int, Option(help="How many seconds to wait between polls")] = 60, + # wandb_run: Annotated[Optional[str], Option(help="Optional W&B run to pass to eval script")] = None, wandb_project: Annotated[Optional[str], Option(help="Optional W&B project to pass to eval script")] = None, wandb_entity: Annotated[Optional[str], Option(help="Optional W&B entity to pass to eval script")] = None, tasks: Annotated[List[TaskName], Option(help="Which tasks to evaluate")] = [TaskName.mnli], # type: ignore seeds: Annotated[List[int], Option(help="Random seeds to pass to _main")] = [42, 314, 1234], gpu_ids: Annotated[Optional[List[int]], Option(help="Optional list of GPU IDs to use for evaluation")] = None, skip_generation: Annotated[bool, Option(help="If set, pass skip_generation=True to eval script")] = False, + train_config: Annotated[Optional[Path], Option(help="Path to a .yaml file containing training configuration. If one is not provided, will attempt to load the config from a wandb run or use defaults.", rich_help_panel="Checkpoint & Config Paths")] = None, config: Annotated[Optional[Path], Option(callback=conf_callback, is_eager=True, help="Relative path to YAML config file for setting options. Passing CLI options will supersede config options.", case_sensitive=False, rich_help_panel="Options")] = None, ): # fmt: skip """ @@ -162,12 +190,14 @@ def main( token=token, checkpoint_dir=checkpoint_dir, poll_interval=poll_interval, + # wandb_run=wandb_run, wandb_project=wandb_project, wandb_entity=wandb_entity, tasks=tasks, seeds=seeds, gpu_ids=gpu_ids, skip_generation=skip_generation, + train_config=train_config, ) diff --git a/run_evals.py b/run_evals.py index 421f0772..f0147853 100644 --- a/run_evals.py +++ b/run_evals.py @@ -650,7 +650,7 @@ def move_and_flatten_files(local_dir: Path): def main( checkpoints: Annotated[Path, Option(help="Path to the directory containing FlexBert checkpoints or location to download checkpoints from Hugging Face Hub to", rich_help_panel="Checkpoint & Config Paths", show_default=False)], train_config: Annotated[Optional[Path], Option(help="Path to a .yaml file containing training configuration. If one is not provided, will attempt to load the config from a wandb run or use defaults.", rich_help_panel="Checkpoint & Config Paths")] = None, - model_size: Annotated[ModelSize, Option("--model-size", help="Model to use for default model config", rich_help_panel="Checkpoint & Config Paths")] = ModelSize.BASE, + model_size: Annotated[ModelSize, Option("--model-size", help="Model to use for default model config", rich_help_panel="Checkpoint & Config Paths")] = ModelSize.HUGE, rope_theta: Annotated[Optional[float], Option("--rope-theta", help="Value for `rotary_emb_base` in the model configuration. If not provided, defaults to pretraining value of 10000.0", rich_help_panel="Checkpoint & Config Paths")] = None, skip_generation: Annotated[bool, Option("--skip-generation", help="Skip generation of evaluation configs. If not true, assumes all existing eval yamls have been already ran.", rich_help_panel="Checkpoint & Config Paths")] = False, run_all_yamls: Annotated[bool, Option("--run-all-yamls", help="Run all evaluation yamls in the `checkpoints` directory, even if some have already been run.", rich_help_panel="Checkpoint & Config Paths")] = False, From 56e646a08d14647e0d2c0edc4552629e5eefdaee Mon Sep 17 00:00:00 2001 From: bclavie Date: Fri, 16 May 2025 00:48:54 +0000 Subject: [PATCH 6/8] ignore latest-rank --- checkpoint_eval_monitor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/checkpoint_eval_monitor.py b/checkpoint_eval_monitor.py index e91f8f72..b80cd584 100644 --- a/checkpoint_eval_monitor.py +++ b/checkpoint_eval_monitor.py @@ -93,7 +93,7 @@ def find_new_checkpoints(files_in_repo: list[str], processed: Set[str]) -> Set[s """ new_ckpts = set() for f in files_in_repo: - if f.endswith(".pt") and "rank" in f and f not in processed: + if f.endswith(".pt") and "rank" in f and f not in processed and 'latest' not in f: new_ckpts.add(f) return new_ckpts From 107048376742ec3190718b23613034583171a2c0 Mon Sep 17 00:00:00 2001 From: bclavie Date: Fri, 16 May 2025 00:50:27 +0000 Subject: [PATCH 7/8] add checkpoint uploader --- hf_checkpoints_uploader.py | 174 +++++++++++++++++++++++++++++++++++++ 1 file changed, 174 insertions(+) create mode 100644 hf_checkpoints_uploader.py diff --git a/hf_checkpoints_uploader.py b/hf_checkpoints_uploader.py new file mode 100644 index 00000000..35c4cabb --- /dev/null +++ b/hf_checkpoints_uploader.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +import re +import sys +import time +from pathlib import Path +from typing import Dict, List, Optional, Set + +import typer +import yaml +from huggingface_hub import HfApi, hf_hub_download, list_repo_files +from huggingface_hub.utils import HfHubHTTPError + +CHECKPOINT_RE = re.compile(r"ep\d+-ba(\d+)-rank0\.pt$") + +app = typer.Typer(pretty_exceptions_show_locals=False, context_settings={"help_option_names": ["-h", "--help"]}) + + +def find_local_checkpoints(base_dir: Path, model_dirs: List[str]) -> Dict[str, Path]: + """Return {repo_path: local_path} for every `ep*-ba*-rank0.pt` under each model dir.""" + ckpts: Dict[str, Path] = {} + for mdir in model_dirs: + for path in (base_dir / mdir).glob("ep*-ba*-rank0.pt"): + if path.name.startswith("latest-"): + continue # Skip alias + if CHECKPOINT_RE.match(path.name): + ckpts[f"{mdir}/{path.name}"] = path + return ckpts + + +def find_remote_checkpoints(api: HfApi, repo_id: str, token: Optional[str]) -> Set[str]: + """Return the set of path strings already present in the HF repo.""" + try: + return set(list_repo_files(repo_id, token=token)) + except HfHubHTTPError as e: + typer.secho(f"❌ Cannot list repo files: {e}", fg=typer.colors.RED, err=True) + raise typer.Exit(1) + + +def upload_file( + api: HfApi, + repo_id: str, + local_path: Path, + path_in_repo: str, + token: Optional[str], + commit_msg: str = "Add checkpoint", +): + print('Uploading ', local_path) + api.upload_file( + path_or_fileobj=str(local_path), + path_in_repo=path_in_repo, + repo_id=repo_id, + token=token, + repo_type="model", + commit_message=commit_msg, + ) + typer.echo(f"✅ Uploaded {path_in_repo}") + + +def catchup_upload( + api: HfApi, + repo_id: str, + base_dir: Path, + model_dirs: List[str], + token: Optional[str], +): + typer.echo("🔍 Running one-off catch-up scan …") + + local_ckpts = find_local_checkpoints(base_dir, model_dirs) + remote_ckpts = find_remote_checkpoints(api, repo_id, token) + + to_upload = [ + (repo_path, local_path) + for repo_path, local_path in local_ckpts.items() + if repo_path not in remote_ckpts + ] + + # Order by batch number (int) to keep history tidy + to_upload.sort(key=lambda x: int(CHECKPOINT_RE.search(x[0]).group(1))) # type: ignore + + for repo_path, local_path in to_upload: + upload_file(api, repo_id, local_path, repo_path, token, commit_msg="Catch-up upload") + + if not to_upload: + typer.echo("✨ Repo already up to date.") + + +def poll_loop( + api: HfApi, + repo_id: str, + base_dir: Path, + model_dirs: List[str], + poll_interval: int, + token: Optional[str], +): + typer.echo(f"🔄 Entering polling loop (every {poll_interval}s) …\n") + + while True: + try: + local_ckpts = find_local_checkpoints(base_dir, model_dirs) + remote_ckpts = find_remote_checkpoints(api, repo_id, token) + + new_items = [ + (rp, lp) for rp, lp in local_ckpts.items() if rp not in remote_ckpts + ] + new_items.sort(key=lambda x: int(CHECKPOINT_RE.search(x[0]).group(1))) # type: ignore + + for repo_path, local_path in new_items: + upload_file(api, repo_id, local_path, repo_path, token, commit_msg="Add checkpoint") + + except Exception as e: + # Log and continue; do not kill the loop + typer.secho(f"⚠️ Error in poll loop: {e}", fg=typer.colors.YELLOW, err=True) + + time.sleep(poll_interval) + + +# --------------------------- CLI ------------------------------------------------- + + +def conf_callback(ctx: typer.Context, param: typer.CallbackParam, config: Optional[str] = None): + """Merge YAML config into Typer defaults (same helper you already use).""" + if config: + with open(config, "r") as f: + cfg = yaml.safe_load(f) + ctx.default_map = ctx.default_map or {} + ctx.default_map.update(cfg) + return config + + +@app.command() +def main( + repo_id: str = typer.Option(..., help="HF repo to push to, e.g. answerdotai/huge-in-run-checkpoints"), + base_dir: Path = typer.Option( + Path("/mnt/nfs/bert24/checkpoints"), help="Root with model_dir/checkpoints" + ), + model_dirs: List[str] = typer.Option( + ["modernbert-huge-pretrain-v1"], help="One or more sub-dirs to watch" + ), + token: Optional[str] = typer.Option(None, help="HF token (or set HF_TOKEN env var)"), + poll_interval: int = typer.Option(60, help="Seconds between scans after catch-up"), + once: bool = typer.Option( + False, "--once", help="Exit after the catch-up pass (no polling)" + ), + config: Optional[Path] = typer.Option( + None, + "--config", + callback=conf_callback, + is_eager=True, + help="YAML file with default values (CLI overrides)", + ), +): # fmt: skip + """ + Upload all `ep*-ba*-rank0.pt` checkpoints found under *base_dir/model_dir/* to Hugging Face Hub. + + 1. Performs an initial catch-up (only missing files are pushed). + 2. Unless `--once` is given, keeps polling local dirs for fresh checkpoints. + """ + + api = HfApi(token=token) + + catchup_upload(api, repo_id, base_dir, model_dirs, token) + + if once: + typer.echo("🏁 Done (catch-up only).") + return + + poll_loop(api, repo_id, base_dir, model_dirs, poll_interval, token) + + +if __name__ == "__main__": + app() \ No newline at end of file From a41f8e83a28994369574849fc0f6fd089d998f09 Mon Sep 17 00:00:00 2001 From: bclavie Date: Fri, 16 May 2025 04:44:04 +0000 Subject: [PATCH 8/8] functional --- checkpoint_eval_monitor.py | 17 ++++++++++++++--- generate_eval_config.py | 5 ++++- hf_checkpoints_uploader.py | 4 ++-- processed_checkpoints.json | 1 + run_evals.py | 18 ++++++++++++++++-- 5 files changed, 37 insertions(+), 8 deletions(-) create mode 100644 processed_checkpoints.json diff --git a/checkpoint_eval_monitor.py b/checkpoint_eval_monitor.py index b80cd584..a26ed9b3 100644 --- a/checkpoint_eval_monitor.py +++ b/checkpoint_eval_monitor.py @@ -110,6 +110,9 @@ def poll_loop( gpu_ids: List[int], skip_generation: bool, train_config: Optional[Path], + wandb_run: Optional[str] = None, + track_run: bool = True, + track_run_project: Optional[str] = None, ): """ Main polling loop: @@ -134,7 +137,8 @@ def poll_loop( logger.info("No new checkpoints found.") else: for ckpt in new_ckpts: - logger.info(f"Found new checkpoint: {ckpt}") + eval_batch_count = str(ckpt).split('ba')[1].split('-rank0.pt')[0] + logger.info(f"Found new checkpoint: {ckpt} with eval_batch_count: {eval_batch_count}") logger.info("Calling run_evals.programmatic_main(...) on that checkpoint...") try: @@ -149,9 +153,12 @@ def poll_loop( seeds=seeds, skip_generation=skip_generation, gpu_ids=gpu_ids, + eval_batch_count=eval_batch_count, train_config=train_config, verbose=True, parallel=True, + track_run=track_run, + track_run_project=track_run_project, ) # Mark it processed processed.add(ckpt) @@ -172,13 +179,15 @@ def main( token: Annotated[Optional[str], Option(help="Optional HF API token for private repos")] = None, checkpoint_dir: Annotated[Path, Option(help="Local directory to store or download checkpoints")] = "./checkpoints", poll_interval: Annotated[int, Option(help="How many seconds to wait between polls")] = 60, - # wandb_run: Annotated[Optional[str], Option(help="Optional W&B run to pass to eval script")] = None, + wandb_run: Annotated[Optional[str], Option(help="Optional W&B run to pass to eval script")] = None, wandb_project: Annotated[Optional[str], Option(help="Optional W&B project to pass to eval script")] = None, wandb_entity: Annotated[Optional[str], Option(help="Optional W&B entity to pass to eval script")] = None, tasks: Annotated[List[TaskName], Option(help="Which tasks to evaluate")] = [TaskName.mnli], # type: ignore seeds: Annotated[List[int], Option(help="Random seeds to pass to _main")] = [42, 314, 1234], gpu_ids: Annotated[Optional[List[int]], Option(help="Optional list of GPU IDs to use for evaluation")] = None, skip_generation: Annotated[bool, Option(help="If set, pass skip_generation=True to eval script")] = False, + track_run: Annotated[bool, Option(help="Track the eval run with wandb", rich_help_panel="Weights & Biases")] = True, + track_run_project: Annotated[Optional[str], Option(help="wandb project for tracking the run", rich_help_panel="Weights & Biases")] = None, train_config: Annotated[Optional[Path], Option(help="Path to a .yaml file containing training configuration. If one is not provided, will attempt to load the config from a wandb run or use defaults.", rich_help_panel="Checkpoint & Config Paths")] = None, config: Annotated[Optional[Path], Option(callback=conf_callback, is_eager=True, help="Relative path to YAML config file for setting options. Passing CLI options will supersede config options.", case_sensitive=False, rich_help_panel="Options")] = None, ): # fmt: skip @@ -190,11 +199,13 @@ def main( token=token, checkpoint_dir=checkpoint_dir, poll_interval=poll_interval, - # wandb_run=wandb_run, + wandb_run=wandb_run, wandb_project=wandb_project, wandb_entity=wandb_entity, tasks=tasks, seeds=seeds, + track_run=track_run, + track_run_project=track_run_project, gpu_ids=gpu_ids, skip_generation=skip_generation, train_config=train_config, diff --git a/generate_eval_config.py b/generate_eval_config.py index 4a2c80d7..a07d8a12 100644 --- a/generate_eval_config.py +++ b/generate_eval_config.py @@ -234,6 +234,7 @@ def main( seeds: Annotated[List[int], Option(help="List of seeds to use for the eval", rich_help_panel="Task Settings")] = [1618, 42, 6033, 3145], gpu_ids: Annotated[List[int], Option(help="List of GPU IDs to use for the eval", rich_help_panel="Task Settings")] = [0], parallel: Annotated[bool, Option("--parallel/--single", help="Run the evals in parallel on multiple GPUs or one GPU. Only use if evaluating a single checkpoint on multiple GPUs.", rich_help_panel="Task Settings")] = False, + eval_batch_count: Annotated[Optional[int], Option(help="Number of batches to evaluate", rich_help_panel="Task Settings")] = None, config: Annotated[Optional[Path], Option(callback=conf_callback, is_eager=True, help="Relative path to YAML config file for setting options. Passing CLI options will supersede config options.", case_sensitive=False, rich_help_panel="Options")] = None, ): # fmt: skip # Read the input YAML file @@ -314,6 +315,8 @@ def main( else: base_run_name = safe_get(input_config, "run_name", ckpt_path.name) new_config["base_run_name"] = base_run_name + if eval_batch_count is not None: + new_config["base_run_name"] = f"{base_run_name}-{eval_batch_count}" new_config["default_seed"] = 19 new_config["precision"] = safe_get(input_config, "precision") @@ -410,7 +413,7 @@ def main( task_config["trainer_kwargs"] = { "save_num_checkpoints_to_keep": 1, "max_duration": "2ep", - "batch_size": 2, + "batch_size": 64, "device_train_microbatch_size": "auto", } diff --git a/hf_checkpoints_uploader.py b/hf_checkpoints_uploader.py index 35c4cabb..a096c994 100644 --- a/hf_checkpoints_uploader.py +++ b/hf_checkpoints_uploader.py @@ -134,10 +134,10 @@ def conf_callback(ctx: typer.Context, param: typer.CallbackParam, config: Option def main( repo_id: str = typer.Option(..., help="HF repo to push to, e.g. answerdotai/huge-in-run-checkpoints"), base_dir: Path = typer.Option( - Path("/mnt/nfs/bert24/checkpoints"), help="Root with model_dir/checkpoints" + ..., help="Root with model_dir/checkpoints" ), model_dirs: List[str] = typer.Option( - ["modernbert-huge-pretrain-v1"], help="One or more sub-dirs to watch" + ..., help="One or more sub-dirs to watch" ), token: Optional[str] = typer.Option(None, help="HF token (or set HF_TOKEN env var)"), poll_interval: int = typer.Option(60, help="Seconds between scans after catch-up"), diff --git a/processed_checkpoints.json b/processed_checkpoints.json new file mode 100644 index 00000000..77679bf1 --- /dev/null +++ b/processed_checkpoints.json @@ -0,0 +1 @@ +["modernbert-huge-pretrain-v1/ep0-ba3000-rank0.pt"] \ No newline at end of file diff --git a/run_evals.py b/run_evals.py index f0147853..d9918f9d 100644 --- a/run_evals.py +++ b/run_evals.py @@ -4,17 +4,20 @@ import os import random import re +import shutil import signal import subprocess import tempfile import time import warnings + from collections import deque from enum import Enum from multiprocessing import Process, Queue from pathlib import Path from typing import Annotated, List, Optional + import datasets import psutil import typer @@ -402,6 +405,7 @@ def generate_eval_configs( pooling_type: Optional[str], head_class_act: Optional[str], head_class_norm: Optional[str], + eval_batch_count: Optional[str], head_class_dropout: float, tasks: Optional[List[TaskName]], # type: ignore fast_ultrafeedback: bool, @@ -482,6 +486,9 @@ def generate_eval_configs( if parallel: cmd.append("--parallel") + if eval_batch_count: + cmd.extend(["--eval-batch-count", str(eval_batch_count)]) + if gpu_ids: if isinstance(gpu_ids, int): gpu_ids = [gpu_ids] @@ -675,6 +682,7 @@ def main( delete_eval_yamls: Annotated[bool, Option("--delete/--keep", help="Delete all evaluation YAML files after running the evals. Use `delete_eval_yamls` if passing to `config`", rich_help_panel="Config Options")] = False, use_dir_names: Annotated[Optional[bool], Option("--use-dir-names", help="Use the folder names as the wandb run names. Defaults to true if multiple `checkpoints` are provided with one `train_config`", rich_help_panel="Config Options")] = None, gpu_ids: Annotated[Optional[List[int]], Option(help="List of GPU IDs to use", rich_help_panel="GPU Options")] = None, + eval_batch_count: Annotated[Optional[int], Option("--eval-batch-count", help="Number of batches to evaluate", rich_help_panel="Task Settings")] = None, config: Annotated[Optional[Path], Option(callback=conf_callback, is_eager=True, help="Relative path to YAML config file for setting options. Passing CLI options will supersede config options.", case_sensitive=False, rich_help_panel="Options")] = None, ): # fmt: skip """Run evaluations on model checkpoints.""" @@ -738,6 +746,7 @@ def main( tasks=tasks, fast_ultrafeedback=fast_ultrafeedback, seeds=seeds, + eval_batch_count=eval_batch_count, parallel=parallel, use_dir_names=use_dir_names, model_size=model_size, @@ -782,16 +791,21 @@ def main( else: console.print("[bold green]All jobs completed.") + shutil.rmtree("./checkpoints", ignore_errors=True) + shutil.rmtree("./finetuned-checkpoints", ignore_errors=True) + # Register the signal handler signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) +# Recursively delete ./checkpoints and ignore any problems. + if __name__ == "__main__": try: app() finally: - # Ensure all subprocesses are terminated when the script exits + # Ensure all subprocesses are terminated when the script exits for process in all_processes: if process.poll() is None: process.terminate() @@ -799,4 +813,4 @@ def main( try: process.wait(timeout=5) except subprocess.TimeoutExpired: - process.kill() + process.kill() \ No newline at end of file