diff --git a/docs/source/blogs/index.md b/docs/source/blogs/index.md index d71fce441b..5df3bdb279 100644 --- a/docs/source/blogs/index.md +++ b/docs/source/blogs/index.md @@ -4,6 +4,9 @@ ```{gallery-grid} :grid-columns: 1 2 2 3 +- header: "{octicon}`terminal` Introducing olive init: An Interactive Wizard for Model Optimization" + content: "Get started with Olive in seconds — the new `olive init` wizard guides you step-by-step from model selection to a ready-to-run optimization command.
{octicon}`arrow-right` [Introducing olive init](olive-init-cli.md)" + - header: "{octicon}`cpu` Exploring Optimal Quantization Settings for Small Language Models" content: "An exploration of how Olive applies different quantization strategies such as GPTQ, mixed precision, and QuaRot to optimize small language models for efficiency and accuracy.
{octicon}`arrow-right` [Exploring Optimal Quantization Settings for Small Language Models](quant-slms.md)" @@ -16,6 +19,7 @@ :maxdepth: 2 :hidden: +olive-init-cli.md quant-slms.md sd-lora.md ``` diff --git a/docs/source/blogs/olive-init-cli.md b/docs/source/blogs/olive-init-cli.md new file mode 100644 index 0000000000..3cf838fd69 --- /dev/null +++ b/docs/source/blogs/olive-init-cli.md @@ -0,0 +1,221 @@ +# Introducing `olive init`: An Interactive Wizard for Model Optimization + +*Author: Xiaoyu Zhang* +*Created: 2026-04-06* + +Getting started with AI model optimization can be overwhelming — choosing the right exporter, quantization algorithm, precision, and target hardware involves navigating a complex decision space. The new **`olive init`** command solves this with an interactive, step-by-step wizard that guides you from model selection to a ready-to-run optimization command. + +--- + +## Why `olive init`? + +Olive offers a powerful set of CLI commands — `optimize`, `quantize`, `capture-onnx-graph`, `finetune`, `diffusion-lora`, and more — each with many options. While this flexibility is great for experts, it can be daunting for newcomers. Common questions include: + +- *Which command should I use?* +- *What exporter is best for my LLM?* +- *Which quantization algorithm should I pick?* +- *Do I need calibration data?* + +**`olive init`** answers all of these by walking you through a guided wizard. It asks the right questions, provides sensible defaults, and generates the exact CLI command or JSON config you need. + +--- + +## Quick Start + +```bash +pip install olive-ai +olive init +``` + +That's it! The wizard launches in your terminal and walks you through every step. + +--- + +## How It Works + +The wizard follows a simple 4-step flow: + +### Step 1: Choose Your Model Type + +``` +? What type of model do you want to optimize? +❯ PyTorch (HuggingFace or local) + ONNX + Diffusers (Stable Diffusion, SDXL, Flux, etc.) 
+``` + +### Step 2: Specify Your Model + +Depending on the model type, you can provide: + +- **HuggingFace model name** (e.g., `meta-llama/Llama-3.1-8B`) +- **Local directory path** +- **AzureML registry path** +- **PyTorch model with custom script** + +### Step 3: Configure Your Workflow + +This is where the wizard really shines. Based on your model type, it presents relevant operations and guides you through the configuration: + +::::{tab-set} + +:::{tab-item} PyTorch Models + +**Available operations:** + +| Operation | Description | +|-----------|-------------| +| Optimize | Export to ONNX + quantize + graph optimize (all-in-one) | +| Export to ONNX | Convert to ONNX format using Model Builder, Dynamo, or TorchScript | +| Quantize | Apply PyTorch quantization (RTN, GPTQ, AWQ, QuaRot, SpinQuant) | +| Fine-tune | LoRA or QLoRA fine-tuning on custom datasets | + +For the **Optimize** operation, you can choose between: + +- **Auto Mode** — Olive automatically selects the best passes for your target hardware and precision +- **Custom Mode** — Manually pick which operations to include (export, quantize, graph optimization) and configure each one + +::: + +:::{tab-item} ONNX Models + +**Available operations:** + +| Operation | Description | +|-----------|-------------| +| Optimize | Auto-select best passes for target hardware | +| Quantize | Static, dynamic, block-wise RTN, HQQ, or BnB quantization | +| Graph optimization | Apply ONNX graph-level optimizations | +| Convert precision | FP32 → FP16 conversion | +| Tune session params | Auto-tune ONNX Runtime inference parameters | + +::: + +:::{tab-item} Diffusers Models + +**Available operations:** + +| Operation | Description | +|-----------|-------------| +| Export to ONNX | Export diffusion pipeline for ONNX Runtime deployment | +| LoRA Training | Fine-tune with LoRA on custom images (DreamBooth supported) | + +**Supported architectures:** SD 1.x/2.x, SDXL, SD3, Flux, Sana + +::: + +:::: + +### Step 4: Choose Your 
Output + +``` +? What would you like to do? +❯ Generate CLI command (copy and run later) + Generate configuration file (JSON, for olive run) + Run optimization now +``` + +You can generate the command to review first, save a reusable JSON config, or execute immediately. + +--- + +## Examples + +### Example 1: Optimize a HuggingFace LLM for CPU with INT4 + +``` +$ olive init + +Welcome to Olive Init! This wizard will help you optimize your model. + +? What type of model do you want to optimize? PyTorch (HuggingFace or local) +? How would you like to specify your model? HuggingFace model name +? Model name or path: Qwen/Qwen2.5-0.5B-Instruct +? What do you want to do? Optimize model (export to ONNX + quantize + graph optimize) +? How would you like to configure optimization? Auto Mode (recommended) +? Select target device: CPU +? Select target precision: INT4 (smallest size, best for LLMs) +? Output directory: ./olive-output +? What would you like to do? Generate CLI command (copy and run later) + +Generated command: + + olive optimize -m Qwen/Qwen2.5-0.5B-Instruct --provider CPUExecutionProvider --precision int4 -o ./olive-output +``` + +### Example 2: Quantize a PyTorch Model with GPTQ + +``` +$ olive init + +? What type of model do you want to optimize? PyTorch (HuggingFace or local) +? How would you like to specify your model? HuggingFace model name +? Model name or path: meta-llama/Llama-3.1-8B +? What do you want to do? Quantize only (PyTorch quantization) +? Select quantization algorithm: GPTQ - High quality, requires calibration +? Precision: int4 +? Calibration data source: Use default (wikitext-2) +? Output directory: ./olive-output +? What would you like to do? Run optimization now +``` + +### Example 3: Fine-tune with LoRA + +``` +$ olive init + +? What type of model do you want to optimize? PyTorch (HuggingFace or local) +? How would you like to specify your model? HuggingFace model name +? Model name or path: microsoft/Phi-4-mini-instruct +? 
What do you want to do? Fine-tune model (LoRA, QLoRA) +? Select fine-tuning method: LoRA (recommended) +? LoRA rank (r): 64 (default) +? LoRA alpha: 16 +? Training dataset: HuggingFace dataset +? Dataset name: tatsu-lab/alpaca +? Train split: train +? How to construct training text? Use chat template +? Max sequence length: 1024 +? Max training samples: 256 +? Torch dtype for training: bfloat16 (recommended) +? Output directory: ./olive-output +? What would you like to do? Generate CLI command (copy and run later) + +Generated command: + + olive finetune -m microsoft/Phi-4-mini-instruct --method lora --lora_r 64 --lora_alpha 16 -d tatsu-lab/alpaca --train_split train --use_chat_template --max_seq_len 1024 --max_samples 256 --torch_dtype bfloat16 -o ./olive-output +``` + +### Example 4: Train a Diffusion LoRA with DreamBooth + +``` +$ olive init + +? What type of model do you want to optimize? Diffusers (Stable Diffusion, SDXL, Flux, etc.) +? Select diffuser model variant: Stable Diffusion XL (SDXL) +? Enter model name or path: stabilityai/stable-diffusion-xl-base-1.0 +? What do you want to do? LoRA Training (fine-tune on custom images) +? LoRA rank (r): 16 (recommended) +? LoRA alpha: 16 +? Training data source: Local image folder +? Path to image folder: ./my-dog-photos +? Enable DreamBooth training? Yes +? Instance prompt: a photo of sks dog +? Enable prior preservation? Yes +? Class prompt: a photo of a dog +? Max training steps: 1000 (recommended) +? Output directory: ./olive-output +? What would you like to do? 
Generate CLI command (copy and run later) + +Generated command: + + olive diffusion-lora -m stabilityai/stable-diffusion-xl-base-1.0 --model_variant sdxl -r 16 --alpha 16 --lora_dropout 0.0 -d ./my-dog-photos --dreambooth --instance_prompt "a photo of sks dog" --with_prior_preservation --class_prompt "a photo of a dog" --num_class_images 200 --max_train_steps 1000 --learning_rate 1e-4 --train_batch_size 1 --gradient_accumulation_steps 4 --mixed_precision bf16 --lr_scheduler constant --lr_warmup_steps 0 -o ./olive-output +``` + +--- + +## Related Resources + +- [Olive CLI Reference](https://microsoft.github.io/Olive/reference/cli.html) +- [Olive Getting Started Guide](https://microsoft.github.io/Olive/getting-started/getting-started.html) +- [Olive GitHub Repository](https://github.com/microsoft/Olive) diff --git a/olive/cli/init/__init__.py b/olive/cli/init/__init__.py new file mode 100644 index 0000000000..26ae8d3562 --- /dev/null +++ b/olive/cli/init/__init__.py @@ -0,0 +1,30 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from argparse import ArgumentParser + +from olive.cli.base import BaseOliveCLICommand + + +class InitCommand(BaseOliveCLICommand): + @staticmethod + def register_subcommand(parser: ArgumentParser): + sub_parser = parser.add_parser( + "init", + help="Interactive wizard to configure and generate Olive optimization commands.", + ) + sub_parser.add_argument( + "-o", + "--output_path", + type=str, + default="./olive-output", + help="Default output directory for the generated command. 
Default is ./olive-output.", + ) + sub_parser.set_defaults(func=InitCommand) + + def run(self): + from olive.cli.init.wizard import InitWizard + + wizard = InitWizard(default_output_path=self.args.output_path) + wizard.start() diff --git a/olive/cli/init/diffusers_flow.py b/olive/cli/init/diffusers_flow.py new file mode 100644 index 0000000000..04f2c3fd02 --- /dev/null +++ b/olive/cli/init/diffusers_flow.py @@ -0,0 +1,220 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import questionary + +from olive.cli.init.helpers import DiffuserVariant, SourceType, _ask, _ask_select + +# Diffusers operations +OP_EXPORT = "export" +OP_LORA = "lora" + +# Training +TRAIN_STEPS_CUSTOM = "custom" + + +def run_diffusers_flow(model_config): + model_path = model_config.get("model_path", "") + variant = model_config.get("variant", DiffuserVariant.AUTO) + + operation = _ask_select( + "What do you want to do?", + choices=[ + questionary.Choice("Export to ONNX (for deployment with ONNX Runtime)", value=OP_EXPORT), + questionary.Choice("LoRA Training (fine-tune on custom images)", value=OP_LORA), + ], + ) + + if operation == OP_EXPORT: + return _export_flow(model_path, variant) + elif operation == OP_LORA: + return _lora_flow(model_path, variant) + return {} + + +def _export_flow(model_path, variant): + torch_dtype = _ask( + questionary.select( + "Torch dtype:", + choices=[ + questionary.Choice("float16", value="float16"), + questionary.Choice("float32", value="float32"), + ], + ) + ) + + cmd = f"olive capture-onnx-graph -m {model_path} --torch_dtype {torch_dtype}" + if variant != DiffuserVariant.AUTO: + cmd += f" --model_variant {variant}" + + return {"command": cmd} + + +def _lora_flow(model_path, variant): + # LoRA parameters + lora_r = _ask( + questionary.select( + "LoRA rank (r):", 
+ choices=[ + questionary.Choice("16 (recommended)", value="16"), + questionary.Choice("4", value="4"), + questionary.Choice("8", value="8"), + questionary.Choice("32", value="32"), + questionary.Choice("64", value="64"), + ], + ) + ) + + lora_alpha = _ask( + questionary.text( + "LoRA alpha (default = same as rank):", + default=lora_r, + ) + ) + + lora_dropout = _ask(questionary.text("LoRA dropout:", default="0.0")) + + # Data source + data_source = _ask( + questionary.select( + "Training data source:", + choices=[ + questionary.Choice("Local image folder", value=SourceType.LOCAL), + questionary.Choice("HuggingFace dataset", value=SourceType.HF), + ], + ) + ) + + data_args = "" + if data_source == SourceType.LOCAL: + data_dir = _ask( + questionary.text( + "Path to image folder:", + validate=lambda x: True if x.strip() else "Please enter a path", + ) + ) + data_args = f" -d {data_dir}" + else: + data_name = _ask( + questionary.text( + "Dataset name:", + instruction="e.g., linoyts/Tuxemon", + ) + ) + data_split = _ask(questionary.text("Split:", default="train")) + image_column = _ask(questionary.text("Image column name:", default="image")) + caption_column = _ask(questionary.text("Caption column name (optional):", default="")) + + data_args = f" --data_name {data_name} --data_split {data_split} --image_column {image_column}" + if caption_column: + data_args += f" --caption_column {caption_column}" + + # DreamBooth + dreambooth_args = "" + enable_dreambooth = _ask(questionary.confirm("Enable DreamBooth training?", default=False)) + if enable_dreambooth: + dreambooth_args = " --dreambooth" + + instance_prompt = _ask( + questionary.text( + "Instance prompt (e.g., 'a photo of sks dog'):", + validate=lambda x: True if x.strip() else "Instance prompt is required for DreamBooth", + ) + ) + dreambooth_args += f' --instance_prompt "{instance_prompt}"' + + with_prior = _ask(questionary.confirm("Enable prior preservation?", default=True)) + if with_prior: + class_prompt = _ask( 
+ questionary.text( + "Class prompt (e.g., 'a photo of a dog'):", + validate=lambda x: True if x.strip() else "Class prompt is required for prior preservation", + ) + ) + dreambooth_args += f' --with_prior_preservation --class_prompt "{class_prompt}"' + + class_data_dir = _ask(questionary.text("Class data directory (optional):", default="")) + if class_data_dir: + dreambooth_args += f" --class_data_dir {class_data_dir}" + + num_class_images = _ask(questionary.text("Number of class images:", default="200")) + dreambooth_args += f" --num_class_images {num_class_images}" + + # Training parameters + max_train_steps = _ask( + questionary.select( + "Max training steps:", + choices=[ + questionary.Choice("1000 (recommended)", value="1000"), + questionary.Choice("500 (quick)", value="500"), + questionary.Choice("2000 (thorough)", value="2000"), + questionary.Choice("Custom", value=TRAIN_STEPS_CUSTOM), + ], + ) + ) + if max_train_steps == TRAIN_STEPS_CUSTOM: + max_train_steps = _ask(questionary.text("Enter max training steps:")) + + learning_rate = _ask(questionary.text("Learning rate:", default="1e-4")) + train_batch_size = _ask(questionary.text("Train batch size:", default="1")) + gradient_accumulation = _ask(questionary.text("Gradient accumulation steps:", default="4")) + + mixed_precision = _ask( + questionary.select( + "Mixed precision:", + choices=[ + questionary.Choice("bf16 (recommended)", value="bf16"), + questionary.Choice("fp16", value="fp16"), + questionary.Choice("no", value="no"), + ], + ) + ) + + lr_scheduler = _ask( + questionary.select( + "Learning rate scheduler:", + choices=[ + questionary.Choice("constant", value="constant"), + questionary.Choice("linear", value="linear"), + questionary.Choice("cosine", value="cosine"), + questionary.Choice("cosine_with_restarts", value="cosine_with_restarts"), + questionary.Choice("polynomial", value="polynomial"), + questionary.Choice("constant_with_warmup", value="constant_with_warmup"), + ], + ) + ) + + warmup_steps 
= _ask(questionary.text("Warmup steps:", default="0")) + seed = _ask(questionary.text("Random seed (optional, press Enter to skip):", default="")) + + # Flux-specific + flux_args = "" + if variant == DiffuserVariant.FLUX: + guidance_scale = _ask(questionary.text("Guidance scale (Flux-specific):", default="3.5")) + flux_args = f" --guidance_scale {guidance_scale}" + + # Merge LoRA + merge_lora = _ask(questionary.confirm("Merge LoRA into base model?", default=False)) + + # Build command + cmd = f"olive diffusion-lora -m {model_path}" + if variant != DiffuserVariant.AUTO: + cmd += f" --model_variant {variant}" + cmd += f" -r {lora_r} --alpha {lora_alpha} --lora_dropout {lora_dropout}" + cmd += data_args + cmd += dreambooth_args + cmd += f" --max_train_steps {max_train_steps}" + cmd += f" --learning_rate {learning_rate}" + cmd += f" --train_batch_size {train_batch_size}" + cmd += f" --gradient_accumulation_steps {gradient_accumulation}" + cmd += f" --mixed_precision {mixed_precision}" + cmd += f" --lr_scheduler {lr_scheduler}" + cmd += f" --lr_warmup_steps {warmup_steps}" + if seed: + cmd += f" --seed {seed}" + cmd += flux_args + if merge_lora: + cmd += " --merge_lora" + + return {"command": cmd} diff --git a/olive/cli/init/helpers.py b/olive/cli/init/helpers.py new file mode 100644 index 0000000000..86cd9e3b3d --- /dev/null +++ b/olive/cli/init/helpers.py @@ -0,0 +1,126 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# --------------------------------------------------------------------------
+import sys
+
+from olive.common.utils import StrEnumBase
+
+
+class SourceType(StrEnumBase):
+    """Source types (shared across flows)."""
+
+    HF = "hf"
+    LOCAL = "local"
+    AZUREML = "azureml"
+    SCRIPT = "script"
+    DEFAULT = "default"
+
+
+class DiffuserVariant(StrEnumBase):
+    """Diffuser variants (only those used in routing)."""
+
+    AUTO = "auto"
+    FLUX = "flux"
+
+
+class GoBackError(Exception):
+    """Raised when the user chooses Back to return to the previous wizard step."""
+
+
+_BACK = "__back__"
+
+
+def _device_choices():
+    """Return device choices aligned with olive optimize --provider."""
+    import questionary
+
+    return [
+        questionary.Choice("CPU", value="CPUExecutionProvider"),
+        questionary.Choice("GPU (NVIDIA CUDA)", value="CUDAExecutionProvider"),
+        questionary.Choice("GPU (NvTensorRTRTX)", value="NvTensorRTRTXExecutionProvider"),
+        questionary.Choice("NPU (Qualcomm QNN)", value="QNNExecutionProvider"),
+        questionary.Choice("NPU (Intel OpenVINO)", value="OpenVINOExecutionProvider"),
+        questionary.Choice("NPU (AMD Vitis AI)", value="VitisAIExecutionProvider"),
+        questionary.Choice("WebGPU", value="WebGpuExecutionProvider"),
+    ]
+
+
+def _precision_choices():
+    """Return precision choices aligned with olive optimize --precision."""
+    import questionary
+
+    return [
+        questionary.Choice("INT4 (smallest size, best for LLMs)", value="int4"),
+        questionary.Choice("INT8 (balanced)", value="int8"),
+        questionary.Choice("FP16 (half precision)", value="fp16"),
+        questionary.Choice("FP32 (full precision)", value="fp32"),
+    ]
+
+
+def _ask(question):
+    """Ask a questionary question; exit cleanly on Ctrl+C (questionary returns None when interrupted)."""
+    result = question.ask()
+    if result is None:
+        sys.exit(0)
+    return result
+
+
+def _ask_select(message, choices, allow_back=True):
+    """Ask a select question with an optional Back choice; choosing Back raises GoBackError."""
+    import questionary
+
+    all_choices = list(choices)
+    if allow_back:
all_choices.append(questionary.Choice("\u2190 Back", value=_BACK)) + result = _ask(questionary.select(message, choices=all_choices)) + if result == _BACK: + raise GoBackError + return result + + +def prompt_calibration_source(): + """Prompt for calibration data source. Returns dict or None (for default).""" + import questionary + + source = _ask( + questionary.select( + "Calibration data source:", + choices=[ + questionary.Choice("Use default (wikitext-2)", value=SourceType.DEFAULT), + questionary.Choice("HuggingFace dataset", value=SourceType.HF), + questionary.Choice("Local file", value=SourceType.LOCAL), + ], + ) + ) + + if source == SourceType.DEFAULT: + return None + elif source == SourceType.HF: + data_name = _ask(questionary.text("Dataset name:", default="Salesforce/wikitext")) + subset = _ask(questionary.text("Subset (optional):", default="wikitext-2-raw-v1")) + split = _ask(questionary.text("Split:", default="train")) + num_samples = _ask(questionary.text("Number of samples:", default="128")) + return { + "source": SourceType.HF, + "data_name": data_name, + "subset": subset, + "split": split, + "num_samples": num_samples, + } + else: + data_files = _ask(questionary.text("Data file path:")) + return {"source": SourceType.LOCAL, "data_files": data_files} + + +def build_calibration_args(calibration): + """Build CLI args string from calibration config dict.""" + if calibration["source"] == SourceType.HF: + result = f" -d {calibration['data_name']}" + if calibration.get("subset"): + result += f" --subset {calibration['subset']}" + result += f" --split {calibration['split']} --max_samples {calibration['num_samples']}" + return result + elif calibration["source"] == SourceType.LOCAL: + return f" --data_files {calibration['data_files']}" + return "" diff --git a/olive/cli/init/onnx_flow.py b/olive/cli/init/onnx_flow.py new file mode 100644 index 0000000000..283a040713 --- /dev/null +++ b/olive/cli/init/onnx_flow.py @@ -0,0 +1,170 @@ +# 
------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import questionary + +from olive.cli.init.helpers import ( + _ask, + _ask_select, + _device_choices, + _precision_choices, + build_calibration_args, + prompt_calibration_source, +) +from olive.common.utils import StrEnumBase + + +class OnnxOperation(StrEnumBase): + """ONNX operations.""" + + OPTIMIZE = "optimize" + QUANTIZE = "quantize" + GRAPH_OPT = "graph_opt" + CONVERT_PRECISION = "convert_precision" + TUNE_SESSION = "tune_session" + + +class QuantizationType(StrEnumBase): + """Quantization types.""" + + STATIC = "static" + DYNAMIC = "dynamic" + BLOCKWISE_RTN = "blockwise_rtn" + HQQ = "hqq" + BNB = "bnb" + + +def run_onnx_flow(model_config): + model_path = model_config.get("model_path", "") + + operation = _ask_select( + "What do you want to do?", + choices=[ + questionary.Choice( + "Optimize model (auto-select best passes for target hardware)", value=OnnxOperation.OPTIMIZE + ), + questionary.Choice("Quantize", value=OnnxOperation.QUANTIZE), + questionary.Choice("Graph optimization", value=OnnxOperation.GRAPH_OPT), + questionary.Choice("Convert precision (FP32 \u2192 FP16)", value=OnnxOperation.CONVERT_PRECISION), + questionary.Choice("Tune session parameters", value=OnnxOperation.TUNE_SESSION), + ], + ) + + if operation == OnnxOperation.OPTIMIZE: + return _optimize_flow(model_path) + elif operation == OnnxOperation.QUANTIZE: + return _quantize_flow(model_path) + elif operation == OnnxOperation.GRAPH_OPT: + return _graph_opt_flow(model_path) + elif operation == OnnxOperation.CONVERT_PRECISION: + return _convert_precision_flow(model_path) + elif operation == OnnxOperation.TUNE_SESSION: + return _tune_session_flow(model_path) + return {} + + +def _optimize_flow(model_path): + provider = _ask(questionary.select("Select 
target device:", choices=_device_choices())) + precision = _ask(questionary.select("Select target precision:", choices=_precision_choices())) + + cmd = f"olive optimize -m {model_path} --provider {provider} --precision {precision}" + return {"command": cmd} + + +def _quantize_flow(model_path): + quant_type = _ask( + questionary.select( + "Select quantization type:", + choices=[ + questionary.Choice( + "Static Quantization (INT8) - requires calibration data", value=QuantizationType.STATIC + ), + questionary.Choice( + "Dynamic Quantization (INT8) - no calibration needed", value=QuantizationType.DYNAMIC + ), + questionary.Choice( + "Block-wise RTN (INT4) - no calibration needed", value=QuantizationType.BLOCKWISE_RTN + ), + questionary.Choice("HQQ Quantization (INT4) - no calibration needed", value=QuantizationType.HQQ), + questionary.Choice("BnB Quantization (FP4/NF4) - no calibration needed", value=QuantizationType.BNB), + ], + ) + ) + + # Map to olive quantize CLI args + quant_map = { + QuantizationType.STATIC: {"implementation": "ort", "precision": "int8"}, + QuantizationType.DYNAMIC: {"implementation": "ort", "precision": "int8"}, + QuantizationType.BLOCKWISE_RTN: {"implementation": "ort", "precision": "int4"}, + QuantizationType.HQQ: {"implementation": "ort", "precision": "int4"}, + QuantizationType.BNB: {"implementation": "bnb", "precision": "nf4"}, + } + + params = quant_map[quant_type] + cmd = ( + f"olive quantize -m {model_path} --precision {params['precision']} --implementation {params['implementation']}" + ) + + if quant_type == QuantizationType.DYNAMIC: + cmd += " --algorithm rtn" + + # Calibration for static quantization + if quant_type == QuantizationType.STATIC: + calib = prompt_calibration_source() + if calib: + cmd += build_calibration_args(calib) + + return {"command": cmd} + + +def _graph_opt_flow(model_path): + cmd = f"olive optimize -m {model_path} --precision fp32" + return {"command": cmd} + + +def _convert_precision_flow(model_path): + cmd = 
f"olive run-pass --pass-name OnnxFloatToFloat16 -m {model_path}" + return {"command": cmd} + + +def _tune_session_flow(model_path): + device = _ask( + questionary.select( + "Select target device:", + choices=[ + questionary.Choice("CPU", value="cpu"), + questionary.Choice("GPU", value="gpu"), + ], + ) + ) + + providers = _ask( + questionary.checkbox( + "Select execution providers:", + choices=[ + questionary.Choice("CPUExecutionProvider", value="CPUExecutionProvider", checked=(device == "cpu")), + questionary.Choice("CUDAExecutionProvider", value="CUDAExecutionProvider", checked=(device == "gpu")), + questionary.Choice("TensorrtExecutionProvider", value="TensorrtExecutionProvider"), + ], + instruction="(Space to toggle, Enter to confirm)", + ) + ) + + cmd = f"olive tune-session-params -m {model_path} --device {device}" + if providers: + cmd += " --providers_list " + " ".join(providers) + + cpu_cores = _ask(questionary.text("CPU cores for thread tuning (optional, press Enter to skip):", default="")) + if cpu_cores: + cmd += f" --cpu_cores {cpu_cores}" + + io_bind = _ask(questionary.confirm("Enable IO binding?", default=False)) + if io_bind: + cmd += " --io_bind" + + enable_cuda_graph = _ask(questionary.confirm("Enable CUDA graph?", default=False)) + if enable_cuda_graph: + cmd += " --enable_cuda_graph" + + return {"command": cmd} diff --git a/olive/cli/init/pytorch_flow.py b/olive/cli/init/pytorch_flow.py new file mode 100644 index 0000000000..df0dbca4e7 --- /dev/null +++ b/olive/cli/init/pytorch_flow.py @@ -0,0 +1,566 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +import questionary + +from olive.cli.init.helpers import ( + SourceType, + _ask, + _ask_select, + _device_choices, + _precision_choices, + build_calibration_args, + prompt_calibration_source, +) + +# PyTorch operations +OP_OPTIMIZE = "optimize" +OP_EXPORT = "export" +OP_QUANTIZE = "quantize" +OP_FINETUNE = "finetune" +OP_GRAPH_OPT = "graph_opt" + +# Optimize modes +MODE_AUTO = "auto" +MODE_CUSTOM = "custom" + +# Exporters +EXPORTER_MODEL_BUILDER = "model_builder" +EXPORTER_DYNAMO = "dynamo" +EXPORTER_TORCHSCRIPT = "torchscript" + +# Text construction modes +TEXT_FIELD = "text_field" +TEXT_TEMPLATE = "template" +TEXT_CHAT_TEMPLATE = "chat_template" + +# Precision (used in routing) +PRECISION_INT4 = "int4" + +# Algorithms that require calibration data +CALIBRATION_ALGORITHMS = {"gptq", "awq", "quarot", "spinquant"} + +# Algorithms that need a non-default --implementation value +# (rtn and gptq use the default "olive" implementation, so they are omitted) +ALGORITHM_TO_IMPLEMENTATION = { + "awq": "awq", + "quarot": "quarot", + "spinquant": "spinquant", +} + + +def _build_model_args(model_config): + """Build model-related CLI args from model config.""" + parts = [] + model_path = model_config.get("model_path") + if model_path: + parts.append(f"-m {model_path}") + if model_config.get("model_script"): + parts.append(f"--model_script {model_config['model_script']}") + if model_config.get("script_dir"): + parts.append(f"--script_dir {model_config['script_dir']}") + return " ".join(parts) + + +def run_pytorch_flow(model_config): + operation = _ask_select( + "What do you want to do?", + choices=[ + questionary.Choice("Optimize model (export to ONNX + quantize + graph optimize)", value=OP_OPTIMIZE), + questionary.Choice("Export to ONNX only", value=OP_EXPORT), + questionary.Choice("Quantize only (PyTorch quantization)", value=OP_QUANTIZE), + questionary.Choice("Fine-tune model (LoRA, QLoRA)", 
value=OP_FINETUNE), + ], + ) + + if operation == OP_OPTIMIZE: + return _optimize_flow(model_config) + elif operation == OP_EXPORT: + return _export_flow(model_config) + elif operation == OP_QUANTIZE: + return _quantize_flow(model_config) + elif operation == OP_FINETUNE: + return _finetune_flow(model_config) + return {} + + +def _optimize_flow(model_config): + mode = _ask( + questionary.select( + "How would you like to configure optimization?", + choices=[ + questionary.Choice( + "Auto Mode (recommended) - Automatically select best passes for your target", value=MODE_AUTO + ), + questionary.Choice("Custom Mode - Manually pick operations and parameters", value=MODE_CUSTOM), + ], + ) + ) + + if mode == MODE_AUTO: + return _optimize_auto_mode(model_config) + else: + return _optimize_custom_mode(model_config) + + +def _optimize_auto_mode(model_config): + model_args = _build_model_args(model_config) + + provider = _ask(questionary.select("Select target device:", choices=_device_choices())) + precision = _ask(questionary.select("Select target precision:", choices=_precision_choices())) + + cmd = f"olive optimize {model_args} --provider {provider} --precision {precision}" + return {"command": cmd} + + +def _optimize_custom_mode(model_config): + model_args = _build_model_args(model_config) + + operations = _ask( + questionary.checkbox( + "Select operations to perform:", + choices=[ + questionary.Choice("Export to ONNX", value=OP_EXPORT, checked=True), + questionary.Choice("Quantize", value=OP_QUANTIZE, checked=True), + questionary.Choice("Graph Optimization", value=OP_GRAPH_OPT), + ], + instruction="(Space to toggle, Enter to confirm)", + ) + ) + + if not operations: + print("No operations selected.") + return {} + + export_config = None + quant_config = None + + # Export options + if OP_EXPORT in operations: + export_config = _prompt_export_options() + + # Quantize options + if OP_QUANTIZE in operations: + quant_config = _prompt_quantize_options() + + has_export = 
OP_EXPORT in operations + has_quantize = OP_QUANTIZE in operations + has_graph_opt = OP_GRAPH_OPT in operations + + if has_export and has_quantize: + # Combined export + quantize (±graph_opt) → use olive optimize with --exporter + provider = _ask(questionary.select("Select target device:", choices=_device_choices())) + precision = quant_config["precision"] if quant_config else "fp32" + cmd = f"olive optimize {model_args} --provider {provider} --precision {precision}" + # Pass exporter choice to olive optimize + if export_config: + exporter_map = { + EXPORTER_MODEL_BUILDER: "model_builder", + EXPORTER_DYNAMO: "dynamo_exporter", + EXPORTER_TORCHSCRIPT: "torchscript_exporter", + } + exporter_arg = exporter_map.get(export_config.get("exporter"), "model_builder") + cmd += f" --exporter {exporter_arg}" + # Warn that olive optimize auto-selects the quantization algorithm + algorithm = quant_config.get("algorithm") if quant_config else None + if algorithm: + print( + f"\nNote: 'olive optimize' automatically selects the quantization algorithm based on" + f" provider and precision. Your selection '{algorithm}' is used as a reference but the" + f" actual algorithm may differ. 
To use '{algorithm}' exactly, select 'Quantize only'" + f" instead.\n" + ) + elif has_export and has_graph_opt: + # Export + graph_opt (no quantize) → olive optimize with fp32 + provider = _ask(questionary.select("Select target device:", choices=_device_choices())) + cmd = f"olive optimize {model_args} --provider {provider} --precision fp32" + if export_config: + exporter_map = { + EXPORTER_MODEL_BUILDER: "model_builder", + EXPORTER_DYNAMO: "dynamo_exporter", + EXPORTER_TORCHSCRIPT: "torchscript_exporter", + } + exporter_arg = exporter_map.get(export_config.get("exporter"), "model_builder") + cmd += f" --exporter {exporter_arg}" + elif has_quantize and has_graph_opt: + # Quantize + graph_opt (no export, ONNX input assumed) → olive optimize + provider = _ask(questionary.select("Select target device:", choices=_device_choices())) + precision = quant_config["precision"] if quant_config else "fp32" + cmd = f"olive optimize {model_args} --provider {provider} --precision {precision}" + elif has_export and export_config: + # Export only → olive capture-onnx-graph with specific exporter options + cmd = _build_export_command(model_args, export_config) + elif has_quantize and quant_config: + # Quantize only → olive quantize with specific algorithm/precision/calibration + cmd = _build_quantize_command(model_args, quant_config) + elif has_graph_opt: + # Graph opt only + provider = _ask(questionary.select("Select target device:", choices=_device_choices())) + cmd = f"olive optimize {model_args} --provider {provider} --precision fp32" + else: + return {} + + return {"command": cmd} + + +def _build_export_command(model_args, export_config): + """Build olive capture-onnx-graph command from export config.""" + exporter = export_config.get("exporter", EXPORTER_DYNAMO) + + if exporter == EXPORTER_MODEL_BUILDER: + precision = export_config.get("precision", "fp16") + cmd = f"olive capture-onnx-graph {model_args} --use_model_builder --precision {precision}" + if precision == 
PRECISION_INT4 and "int4_block_size" in export_config: + cmd += f" --int4_block_size {export_config['int4_block_size']}" + elif exporter == EXPORTER_DYNAMO: + torch_dtype = export_config.get("torch_dtype", "float32") + cmd = f"olive capture-onnx-graph {model_args} --torch_dtype {torch_dtype}" + else: + cmd = f"olive capture-onnx-graph {model_args}" + + return cmd + + +def _build_quantize_command(model_args, quant_config): + """Build olive quantize command from quantize config.""" + algorithm = quant_config.get("algorithm", "rtn") + precision = quant_config.get("precision", "int4") + + cmd = f"olive quantize {model_args} --algorithm {algorithm} --precision {precision}" + + # Add --implementation for algorithms that need a non-default one + impl = ALGORITHM_TO_IMPLEMENTATION.get(algorithm) + if impl: + cmd += f" --implementation {impl}" + + calibration = quant_config.get("calibration") + if calibration: + cmd += build_calibration_args(calibration) + + return cmd + + +def _export_flow(model_config): + model_args = _build_model_args(model_config) + + exporter = _ask( + questionary.select( + "Select exporter:", + choices=[ + questionary.Choice("Model Builder (recommended for LLMs)", value=EXPORTER_MODEL_BUILDER), + questionary.Choice("Dynamo Exporter (general purpose)", value=EXPORTER_DYNAMO), + questionary.Choice("TorchScript Exporter (legacy)", value=EXPORTER_TORCHSCRIPT), + ], + ) + ) + + if exporter == EXPORTER_MODEL_BUILDER: + precision = _ask( + questionary.select( + "Export precision:", + choices=[ + questionary.Choice("fp16", value="fp16"), + questionary.Choice("fp32", value="fp32"), + questionary.Choice("bf16", value="bf16"), + questionary.Choice("int4", value=PRECISION_INT4), + ], + ) + ) + + cmd = f"olive capture-onnx-graph {model_args} --use_model_builder --precision {precision}" + + if precision == PRECISION_INT4: + block_size = _ask( + questionary.select( + "INT4 block size:", + choices=[ + questionary.Choice("32 (recommended)", value="32"), + 
questionary.Choice("16", value="16"), + questionary.Choice("64", value="64"), + questionary.Choice("128", value="128"), + questionary.Choice("256", value="256"), + ], + ) + ) + cmd += f" --int4_block_size {block_size}" + + accuracy_level = _ask( + questionary.select( + "INT4 accuracy level:", + choices=[ + questionary.Choice("4 (int8, recommended)", value="4"), + questionary.Choice("1 (fp32)", value="1"), + questionary.Choice("2 (fp16)", value="2"), + questionary.Choice("3 (bf16)", value="3"), + ], + ) + ) + cmd += f" --int4_accuracy_level {accuracy_level}" + + return {"command": cmd} + + elif exporter == EXPORTER_DYNAMO: + torch_dtype = _ask( + questionary.select( + "Torch dtype:", + choices=[ + questionary.Choice("fp32", value="float32"), + questionary.Choice("fp16", value="float16"), + ], + ) + ) + + cmd = f"olive capture-onnx-graph {model_args} --torch_dtype {torch_dtype}" + return {"command": cmd} + + else: + # TorchScript + cmd = f"olive capture-onnx-graph {model_args}" + return {"command": cmd} + + +def _quantize_flow(model_config): + model_args = _build_model_args(model_config) + + algorithm = _ask( + questionary.select( + "Select quantization algorithm:", + choices=[ + questionary.Choice("RTN - Fast, no calibration needed", value="rtn"), + questionary.Choice("GPTQ - High quality, requires calibration", value="gptq"), + questionary.Choice("AWQ - Activation-aware, good for LLMs", value="awq"), + questionary.Choice("QuaRot - For QNN/VitisAI deployment", value="quarot"), + questionary.Choice("SpinQuant - Spin quantization", value="spinquant"), + ], + ) + ) + + precision = _ask( + questionary.select( + "Precision:", + choices=[ + questionary.Choice("int4", value="int4"), + questionary.Choice("uint4", value="uint4"), + questionary.Choice("int8", value="int8"), + ], + ) + ) + + cmd = f"olive quantize {model_args} --algorithm {algorithm} --precision {precision}" + + # Add --implementation for algorithms that need a non-default one + impl = 
ALGORITHM_TO_IMPLEMENTATION.get(algorithm) + if impl: + cmd += f" --implementation {impl}" + + # Calibration data for algorithms that need it + if algorithm in CALIBRATION_ALGORITHMS: + calib = prompt_calibration_source() + if calib: + cmd += build_calibration_args(calib) + + return {"command": cmd} + + +def _finetune_flow(model_config): + model_args = _build_model_args(model_config) + + method = _ask( + questionary.select( + "Select fine-tuning method:", + choices=[ + questionary.Choice("LoRA (recommended)", value="lora"), + questionary.Choice("QLoRA (quantized, saves GPU memory)", value="qlora"), + ], + ) + ) + + lora_r = _ask( + questionary.select( + "LoRA rank (r):", + choices=[ + questionary.Choice("64 (default)", value="64"), + questionary.Choice("4", value="4"), + questionary.Choice("8", value="8"), + questionary.Choice("16", value="16"), + questionary.Choice("32", value="32"), + ], + ) + ) + + lora_alpha = _ask(questionary.text("LoRA alpha:", default="16")) + + # Dataset + data_source = _ask( + questionary.select( + "Training dataset:", + choices=[ + questionary.Choice("HuggingFace dataset", value=SourceType.HF), + questionary.Choice("Local file", value=SourceType.LOCAL), + ], + ) + ) + + cmd = f"olive finetune {model_args} --method {method} --lora_r {lora_r} --lora_alpha {lora_alpha}" + + if data_source == SourceType.HF: + data_name = _ask( + questionary.text( + "Dataset name:", + default="tatsu-lab/alpaca", + ) + ) + train_split = _ask(questionary.text("Train split:", default="train")) + eval_split = _ask(questionary.text("Eval split (optional, press Enter to skip):", default="")) + + cmd += f" -d {data_name} --train_split {train_split}" + if eval_split: + cmd += f" --eval_split {eval_split}" + else: + data_files = _ask( + questionary.text( + "Path to data file(s):", + validate=lambda x: True if x.strip() else "Please enter a file path", + ) + ) + cmd += f" -d nouse --data_files {data_files}" + + # Text construction + text_mode = _ask( + 
questionary.select( + "How to construct training text?", + choices=[ + questionary.Choice("Single text field (specify column name)", value=TEXT_FIELD), + questionary.Choice( + "Text template (e.g., '### Question: {prompt} \\n### Answer: {response}')", value=TEXT_TEMPLATE + ), + questionary.Choice("Use chat template", value=TEXT_CHAT_TEMPLATE), + ], + ) + ) + + if text_mode == TEXT_FIELD: + text_field = _ask(questionary.text("Text field name:", default="text")) + cmd += f" --text_field {text_field}" + elif text_mode == TEXT_TEMPLATE: + template = _ask(questionary.text("Text template:")) + cmd += f' --text_template "{template}"' + else: + cmd += " --use_chat_template" + + max_seq_len = _ask(questionary.text("Max sequence length:", default="1024")) + cmd += f" --max_seq_len {max_seq_len}" + + max_samples = _ask(questionary.text("Max training samples:", default="256")) + cmd += f" --max_samples {max_samples}" + + # Torch dtype + torch_dtype = _ask( + questionary.select( + "Torch dtype for training:", + choices=[ + questionary.Choice("bfloat16 (recommended)", value="bfloat16"), + questionary.Choice("float16", value="float16"), + questionary.Choice("float32", value="float32"), + ], + ) + ) + cmd += f" --torch_dtype {torch_dtype}" + + return {"command": cmd} + + +def _prompt_export_options(): + """Prompt export options for custom mode.""" + exporter = _ask( + questionary.select( + "Select exporter:", + choices=[ + questionary.Choice("Model Builder (recommended for LLMs)", value=EXPORTER_MODEL_BUILDER), + questionary.Choice("Dynamo Exporter (general purpose)", value=EXPORTER_DYNAMO), + questionary.Choice("TorchScript Exporter (legacy)", value=EXPORTER_TORCHSCRIPT), + ], + ) + ) + + config = {"exporter": exporter} + + if exporter == EXPORTER_MODEL_BUILDER: + precision = _ask( + questionary.select( + "Export precision:", + choices=[ + questionary.Choice("fp16", value="fp16"), + questionary.Choice("fp32", value="fp32"), + questionary.Choice("bf16", value="bf16"), + 
questionary.Choice("int4", value=PRECISION_INT4), + ], + ) + ) + config["precision"] = precision + + if precision == PRECISION_INT4: + block_size = _ask( + questionary.select( + "INT4 block size:", + choices=[ + questionary.Choice("32 (recommended)", value="32"), + questionary.Choice("16", value="16"), + questionary.Choice("64", value="64"), + questionary.Choice("128", value="128"), + questionary.Choice("256", value="256"), + ], + ) + ) + config["int4_block_size"] = block_size + + elif exporter == EXPORTER_DYNAMO: + torch_dtype = _ask( + questionary.select( + "Torch dtype:", + choices=[ + questionary.Choice("fp32", value="float32"), + questionary.Choice("fp16", value="float16"), + ], + ) + ) + config["torch_dtype"] = torch_dtype + + return config + + +def _prompt_quantize_options(): + """Prompt quantization options for custom mode.""" + algorithm = _ask( + questionary.select( + "Select quantization algorithm:", + choices=[ + questionary.Choice("RTN - Round-to-Nearest, fast, no calibration needed", value="rtn"), + questionary.Choice("GPTQ - High quality, requires calibration data", value="gptq"), + questionary.Choice("AWQ - Activation-aware, good for LLMs", value="awq"), + questionary.Choice("QuaRot - Rotation-based, for QNN/VitisAI deployment", value="quarot"), + questionary.Choice("SpinQuant - Spin quantization", value="spinquant"), + ], + ) + ) + + precision = _ask( + questionary.select( + "Quantization precision:", + choices=[ + questionary.Choice("int4", value="int4"), + questionary.Choice("uint4", value="uint4"), + questionary.Choice("int8", value="int8"), + ], + ) + ) + + config = {"algorithm": algorithm, "precision": precision} + + if algorithm in CALIBRATION_ALGORITHMS: + calib = prompt_calibration_source() + if calib: + config["calibration"] = calib + + return config diff --git a/olive/cli/init/wizard.py b/olive/cli/init/wizard.py new file mode 100644 index 0000000000..74e31efb49 --- /dev/null +++ b/olive/cli/init/wizard.py @@ -0,0 +1,231 @@ +# 
------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import subprocess +import sys +from pathlib import Path + +import questionary + +from olive.cli.init.helpers import ( + DiffuserVariant, + GoBackError, + SourceType, + _ask, + _ask_select, +) +from olive.common.utils import StrEnumBase + + +class ModelType(StrEnumBase): + """Model types.""" + + PYTORCH = "pytorch" + ONNX = "onnx" + DIFFUSERS = "diffusers" + + +class OutputAction(StrEnumBase): + """Output actions.""" + + COMMAND = "command" + CONFIG = "config" + RUN = "run" + + +class InitWizard: + def __init__(self, default_output_path: str = "./olive-output"): + self.default_output_path = default_output_path + + def start(self): + print("\nWelcome to Olive Init! This wizard will help you optimize your model.\n") + + try: + step = 0 + model_type = None + model_config = None + result = None + + while step < 4: + try: + if step == 0: + model_type = self._prompt_model_type() + elif step == 1: + model_config = self._prompt_model_source(model_type) + elif step == 2: + result = self._run_model_flow(model_type, model_config) + elif step == 3: + self._prompt_output(result) + step += 1 + except GoBackError: + if step > 0: + step -= 1 + + except KeyboardInterrupt: + sys.exit(0) + + def _prompt_model_type(self): + return _ask_select( + "What type of model do you want to optimize?", + choices=[ + questionary.Choice("PyTorch (HuggingFace or local)", value=ModelType.PYTORCH), + questionary.Choice("ONNX", value=ModelType.ONNX), + questionary.Choice("Diffusers (Stable Diffusion, SDXL, Flux, etc.)", value=ModelType.DIFFUSERS), + ], + allow_back=False, + ) + + def _prompt_model_source(self, model_type): + if model_type == ModelType.PYTORCH: + return self._prompt_pytorch_source() + elif model_type == ModelType.ONNX: + return 
self._prompt_onnx_source() + elif model_type == ModelType.DIFFUSERS: + return self._prompt_diffusers_source() + return {} + + def _prompt_pytorch_source(self): + source_type = _ask_select( + "How would you like to specify your model?", + choices=[ + questionary.Choice("HuggingFace model name (e.g., meta-llama/Llama-3.1-8B)", value=SourceType.HF), + questionary.Choice("Local directory path", value=SourceType.LOCAL), + questionary.Choice("AzureML registry path", value=SourceType.AZUREML), + questionary.Choice("PyTorch model with custom script", value=SourceType.SCRIPT), + ], + ) + + config = {"source_type": source_type} + + if source_type == SourceType.SCRIPT: + config["model_script"] = _ask( + questionary.path( + "Path to model script (.py):", + ) + ) + script_dir = _ask( + questionary.text( + "Script directory (optional, press Enter to skip):", + default="", + ) + ) + if script_dir: + config["script_dir"] = script_dir + model_path = _ask( + questionary.text( + "Model name or path (optional, press Enter to skip):", + default="", + ) + ) + if model_path: + config["model_path"] = model_path + else: + if source_type == SourceType.HF: + placeholder = "e.g., meta-llama/Llama-3.1-8B" + elif source_type == SourceType.AZUREML: + placeholder = "e.g., azureml://registries//models//versions/" + else: + placeholder = "e.g., ./my-model/" + config["model_path"] = _ask( + questionary.text( + "Model name or path:", + validate=lambda x: True if x.strip() else "Please enter a model name or path", + instruction=placeholder, + ) + ) + + return config + + def _prompt_onnx_source(self): + model_path = _ask( + questionary.text( + "Enter ONNX model path (file or directory):", + validate=lambda x: True if x.strip() else "Please enter a model path", + ) + ) + return {"source_type": SourceType.LOCAL, "model_path": model_path} + + def _prompt_diffusers_source(self): + variant = _ask_select( + "Select diffuser model variant:", + choices=[ + questionary.Choice("Auto-detect", 
value=DiffuserVariant.AUTO), + questionary.Choice("Stable Diffusion (SD 1.x/2.x)", value="sd"), + questionary.Choice("Stable Diffusion XL (SDXL)", value="sdxl"), + questionary.Choice("Stable Diffusion 3 (SD3)", value="sd3"), + questionary.Choice("Flux", value=DiffuserVariant.FLUX), + questionary.Choice("Sana", value="sana"), + ], + ) + + model_path = _ask( + questionary.text( + "Enter model name or path:", + validate=lambda x: True if x.strip() else "Please enter a model name or path", + instruction="e.g., stabilityai/stable-diffusion-xl-base-1.0", + ) + ) + + return {"source_type": SourceType.HF, "model_path": model_path, "variant": variant} + + def _run_model_flow(self, model_type, model_config): + if model_type == ModelType.PYTORCH: + from olive.cli.init.pytorch_flow import run_pytorch_flow + + return run_pytorch_flow(model_config) + elif model_type == ModelType.ONNX: + from olive.cli.init.onnx_flow import run_onnx_flow + + return run_onnx_flow(model_config) + elif model_type == ModelType.DIFFUSERS: + from olive.cli.init.diffusers_flow import run_diffusers_flow + + return run_diffusers_flow(model_config) + return {} + + def _prompt_output(self, result): + command_str = result.get("command") + + if not command_str: + print("No command generated.") + raise GoBackError + + output_dir = _ask( + questionary.text( + "Output directory:", + default=self.default_output_path, + ) + ) + + # Append output dir to command if not already present + if " -o " not in command_str and " --output_path " not in command_str: + command_str += f" -o {output_dir}" + + action = _ask_select( + "What would you like to do?", + choices=[ + questionary.Choice("Generate CLI command (copy and run later)", value=OutputAction.COMMAND), + questionary.Choice("Generate configuration file (JSON, for olive run)", value=OutputAction.CONFIG), + questionary.Choice("Run optimization now", value=OutputAction.RUN), + ], + ) + + if action == OutputAction.COMMAND: + print(f"\nGenerated command:\n\n 
{command_str}\n") + run_now = _ask(questionary.confirm("Run this command now?", default=False)) + if run_now: + print(f"\nRunning: {command_str}\n") + subprocess.run(command_str, shell=True, check=False) + + elif action == OutputAction.CONFIG: + config_cmd = command_str + " --save_config_file --dry_run" + print("\nGenerating configuration file...\n") + subprocess.run(config_cmd, shell=True, check=False) + config_path = Path(output_dir) / "config.json" + if config_path.exists(): + print(f"\nYou can run it later with:\n olive run --config {config_path}\n") + + elif action == OutputAction.RUN: + print(f"\nRunning: {command_str}\n") + subprocess.run(command_str, shell=True, check=False) diff --git a/olive/cli/launcher.py b/olive/cli/launcher.py index d9088bc89b..fed339f87d 100644 --- a/olive/cli/launcher.py +++ b/olive/cli/launcher.py @@ -16,6 +16,7 @@ from olive.cli.finetune import FineTuneCommand from olive.cli.generate_adapter import GenerateAdapterCommand from olive.cli.generate_cost_model import GenerateCostModelCommand +from olive.cli.init import InitCommand from olive.cli.optimize import OptimizeCommand from olive.cli.quantize import QuantizeCommand from olive.cli.run import WorkflowRunCommand @@ -37,6 +38,7 @@ def get_cli_parser(called_as_console_script: bool = True) -> ArgumentParser: # Register commands # TODO(jambayk): Consider adding a common tempdir option to all commands # NOTE: The order of the commands is to organize the documentation better. 
+ InitCommand.register_subcommand(commands_parser) WorkflowRunCommand.register_subcommand(commands_parser) RunPassCommand.register_subcommand(commands_parser) AutoOptCommand.register_subcommand(commands_parser) diff --git a/olive/cli/optimize.py b/olive/cli/optimize.py index d80392ecf9..6d94c1407c 100644 --- a/olive/cli/optimize.py +++ b/olive/cli/optimize.py @@ -191,7 +191,7 @@ def register_subcommand(parser: ArgumentParser): def __init__(self, parser: ArgumentParser, args: Namespace, unknown_args: Optional[list] = None): super().__init__(parser, args, unknown_args) - self.need_wikitest_data_config = False + self.need_wikitext_data_config = False self.is_hf_model = False # will be set in _get_run_config # Pass enabled flags @@ -202,7 +202,6 @@ def __init__(self, parser: ArgumentParser, args: Namespace, unknown_args: Option self.enable_onnx_conversion = False self.enable_optimum_openvino_conversion = False self.enable_dynamic_to_fixed_shape = False - self.enable_vitis_ai_preprocess = False self.enable_onnx_io_datatype_converter = False self.enable_openvino_io_update = False self.enable_onnx_peephole_optimizer = False @@ -304,7 +303,7 @@ def _update_system_config(self, config: dict[str, Any]): config["target"] = "qnn_system" def _add_data_config(self, config: dict[str, Any]): - config["data_configs"] = WIKITEXT2_DATA_CONFIG_TEMPLATE if self.need_wikitest_data_config else [] + config["data_configs"] = WIKITEXT2_DATA_CONFIG_TEMPLATE if self.need_wikitext_data_config else [] def _build_passes_config(self) -> dict[str, Any]: passes_config = OrderedDict() @@ -553,7 +552,7 @@ def _get_openvino_io_update_pass_config(self) -> dict[str, Any]: def _enable_onnx_peephole_optimizer_pass(self) -> bool: """Return true if condition to add OnnxPeepholeOptimizer pass is met.""" - return self.args.exporter != "model_builder" + return not self.is_hf_model or self.args.exporter != "model_builder" def _get_onnx_peephole_optimizer_pass_config(self) -> dict[str, Any]: """Return pass 
dictionary for OnnxPeepholeOptimizer pass.""" @@ -668,7 +667,7 @@ def _get_onnx_static_quantization_pass_config(self) -> dict[str, Any]: # Add data_config for text modality if self.args.modality == "text": - self.need_wikitest_data_config = True + self.need_wikitext_data_config = True config["data_config"] = "wikitext2_train" return config diff --git a/requirements.txt b/requirements.txt index 400f0bfe7b..1032c72f97 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,6 +8,7 @@ optuna pandas pydantic>=2.0 pyyaml +questionary torch torchmetrics>=1.0.0 transformers diff --git a/test/cli/init/__init__.py b/test/cli/init/__init__.py new file mode 100644 index 0000000000..54aa1f92bf --- /dev/null +++ b/test/cli/init/__init__.py @@ -0,0 +1,4 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# ------------------------------------------------------------------------- diff --git a/test/cli/init/conftest.py b/test/cli/init/conftest.py new file mode 100644 index 0000000000..a411298306 --- /dev/null +++ b/test/cli/init/conftest.py @@ -0,0 +1,24 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# ------------------------------------------------------------------------- +import sys +from unittest.mock import MagicMock + +import pytest + + +@pytest.fixture(autouse=True) +def _mock_questionary(monkeypatch): + """Replace ``questionary`` with a MagicMock in every module that imports it.""" + mock_q = MagicMock() + + # Modules that do ``import questionary`` at the top level. 
+ for mod in ( + "olive.cli.init.wizard", + "olive.cli.init.onnx_flow", + "olive.cli.init.pytorch_flow", + "olive.cli.init.diffusers_flow", + ): + monkeypatch.setattr(f"{mod}.questionary", mock_q, raising=False) + monkeypatch.setitem(sys.modules, "questionary", mock_q) diff --git a/test/cli/init/test_diffusers_flow.py b/test/cli/init/test_diffusers_flow.py new file mode 100644 index 0000000000..4a45786ee3 --- /dev/null +++ b/test/cli/init/test_diffusers_flow.py @@ -0,0 +1,216 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# ------------------------------------------------------------------------- +from unittest.mock import patch + + +class TestExportFlow: + @patch("olive.cli.init.diffusers_flow._ask") + def test_export_with_variant(self, mock_ask): + from olive.cli.init.diffusers_flow import _export_flow + + mock_ask.return_value = "float16" + result = _export_flow("stabilityai/sdxl", "sdxl") + cmd = result["command"] + assert "olive capture-onnx-graph -m stabilityai/sdxl" in cmd + assert "--torch_dtype float16" in cmd + assert "--model_variant sdxl" in cmd + + @patch("olive.cli.init.diffusers_flow._ask") + def test_export_auto_variant(self, mock_ask): + from olive.cli.init.diffusers_flow import _export_flow + from olive.cli.init.helpers import DiffuserVariant + + mock_ask.return_value = "float32" + result = _export_flow("my-model", DiffuserVariant.AUTO) + assert "--model_variant" not in result["command"] + + +class TestLoraFlow: + @patch("olive.cli.init.diffusers_flow._ask") + def test_basic_lora_local_data(self, mock_ask): + from olive.cli.init.diffusers_flow import _lora_flow + from olive.cli.init.helpers import DiffuserVariant, SourceType + + mock_ask.side_effect = [ + "16", # lora_r + "16", # lora_alpha + "0.0", # lora_dropout + SourceType.LOCAL, # data_source + "/images", # data_dir + False, # enable_dreambooth + "1000", # 
max_train_steps + "1e-4", # learning_rate + "1", # train_batch_size + "4", # gradient_accumulation + "bf16", # mixed_precision + "constant", # lr_scheduler + "0", # warmup_steps + "", # seed (skip) + False, # merge_lora + ] + result = _lora_flow("my-model", DiffuserVariant.AUTO) + cmd = result["command"] + assert "olive diffusion-lora -m my-model" in cmd + assert "-r 16 --alpha 16" in cmd + assert "-d /images" in cmd + assert "--max_train_steps 1000" in cmd + assert "--model_variant" not in cmd + + @patch("olive.cli.init.diffusers_flow._ask") + def test_flux_with_dreambooth(self, mock_ask): + from olive.cli.init.diffusers_flow import _lora_flow + from olive.cli.init.helpers import DiffuserVariant, SourceType + + mock_ask.side_effect = [ + "16", # lora_r + "16", # lora_alpha + "0.1", # lora_dropout + SourceType.LOCAL, # data_source + "/images", # data_dir + True, # enable_dreambooth + "a photo of sks dog", # instance_prompt + True, # with_prior + "a photo of a dog", # class_prompt + "", # class_data_dir (skip) + "200", # num_class_images + "500", # max_train_steps + "1e-4", # learning_rate + "1", # train_batch_size + "4", # gradient_accumulation + "bf16", # mixed_precision + "constant", # lr_scheduler + "0", # warmup_steps + "", # seed (skip) + "3.5", # guidance_scale (flux-specific) + True, # merge_lora + ] + result = _lora_flow("my-flux-model", DiffuserVariant.FLUX) + cmd = result["command"] + assert f"--model_variant {DiffuserVariant.FLUX}" in cmd + assert "--dreambooth" in cmd + assert '--instance_prompt "a photo of sks dog"' in cmd + assert "--with_prior_preservation" in cmd + assert "--guidance_scale 3.5" in cmd + assert "--merge_lora" in cmd + + @patch("olive.cli.init.diffusers_flow._ask") + def test_hf_data_source_with_caption(self, mock_ask): + from olive.cli.init.diffusers_flow import _lora_flow + from olive.cli.init.helpers import DiffuserVariant, SourceType + + mock_ask.side_effect = [ + "16", # lora_r + "16", # lora_alpha + "0.0", # lora_dropout + 
SourceType.HF, # data_source + "linoyts/Tuxemon", # data_name + "train", # data_split + "image", # image_column + "caption", # caption_column + False, # enable_dreambooth + "1000", # max_train_steps + "1e-4", # learning_rate + "1", # train_batch_size + "4", # gradient_accumulation + "bf16", # mixed_precision + "constant", # lr_scheduler + "0", # warmup_steps + "42", # seed (provided) + False, # merge_lora + ] + result = _lora_flow("my-model", DiffuserVariant.AUTO) + cmd = result["command"] + assert "--data_name linoyts/Tuxemon" in cmd + assert "--data_split train" in cmd + assert "--image_column image" in cmd + assert "--caption_column caption" in cmd + assert "--seed 42" in cmd + + @patch("olive.cli.init.diffusers_flow._ask") + def test_custom_max_train_steps(self, mock_ask): + from olive.cli.init.diffusers_flow import TRAIN_STEPS_CUSTOM, _lora_flow + from olive.cli.init.helpers import DiffuserVariant, SourceType + + mock_ask.side_effect = [ + "16", # lora_r + "16", # lora_alpha + "0.0", # lora_dropout + SourceType.LOCAL, # data_source + "/images", # data_dir + False, # enable_dreambooth + TRAIN_STEPS_CUSTOM, # max_train_steps + "3000", # custom value + "1e-4", # learning_rate + "1", # train_batch_size + "4", # gradient_accumulation + "bf16", # mixed_precision + "constant", # lr_scheduler + "0", # warmup_steps + "", # seed (skip) + False, # merge_lora + ] + result = _lora_flow("my-model", DiffuserVariant.AUTO) + assert "--max_train_steps 3000" in result["command"] + + @patch("olive.cli.init.diffusers_flow._ask") + def test_dreambooth_with_class_data_dir(self, mock_ask): + from olive.cli.init.diffusers_flow import _lora_flow + from olive.cli.init.helpers import DiffuserVariant, SourceType + + mock_ask.side_effect = [ + "16", # lora_r + "16", # lora_alpha + "0.0", # lora_dropout + SourceType.LOCAL, # data_source + "/images", # data_dir + True, # enable_dreambooth + "a photo of sks dog", # instance_prompt + True, # with_prior + "a photo of a dog", # class_prompt + 
"/class_images", # class_data_dir (provided) + "200", # num_class_images + "1000", # max_train_steps + "1e-4", # learning_rate + "1", # train_batch_size + "4", # gradient_accumulation + "bf16", # mixed_precision + "constant", # lr_scheduler + "0", # warmup_steps + "", # seed (skip) + False, # merge_lora + ] + result = _lora_flow("my-model", DiffuserVariant.AUTO) + assert "--class_data_dir /class_images" in result["command"] + + +class TestRunDiffusersFlowRouting: + @patch("olive.cli.init.diffusers_flow._export_flow") + @patch("olive.cli.init.diffusers_flow._ask_select") + def test_routes_to_export(self, mock_select, mock_flow): + from olive.cli.init.diffusers_flow import OP_EXPORT, run_diffusers_flow + + mock_select.return_value = OP_EXPORT + mock_flow.return_value = {"command": "test"} + run_diffusers_flow({"model_path": "m", "variant": "sdxl"}) + mock_flow.assert_called_once_with("m", "sdxl") + + @patch("olive.cli.init.diffusers_flow._lora_flow") + @patch("olive.cli.init.diffusers_flow._ask_select") + def test_routes_to_lora(self, mock_select, mock_flow): + from olive.cli.init.diffusers_flow import OP_LORA, run_diffusers_flow + from olive.cli.init.helpers import DiffuserVariant + + mock_select.return_value = OP_LORA + mock_flow.return_value = {"command": "test"} + run_diffusers_flow({"model_path": "m", "variant": DiffuserVariant.FLUX}) + mock_flow.assert_called_once_with("m", DiffuserVariant.FLUX) + + @patch("olive.cli.init.diffusers_flow._ask_select", return_value="unknown") + def test_unknown_operation_returns_empty(self, mock_select): + from olive.cli.init.diffusers_flow import run_diffusers_flow + from olive.cli.init.helpers import DiffuserVariant + + result = run_diffusers_flow({"model_path": "m", "variant": DiffuserVariant.AUTO}) + assert not result diff --git a/test/cli/init/test_init_command.py b/test/cli/init/test_init_command.py new file mode 100644 index 0000000000..257f8c6dad --- /dev/null +++ b/test/cli/init/test_init_command.py @@ -0,0 +1,37 @@ +# 
------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# ------------------------------------------------------------------------- +from unittest.mock import patch + + +class TestInitCommand: + def test_register_subcommand(self): + from argparse import ArgumentParser + + from olive.cli.init import InitCommand + + parser = ArgumentParser() + sub_parsers = parser.add_subparsers() + InitCommand.register_subcommand(sub_parsers) + + args = parser.parse_args(["init", "-o", "/tmp/out"]) + assert args.output_path == "/tmp/out" + assert args.func is InitCommand + + @patch("olive.cli.init.wizard.InitWizard") + def test_run(self, mock_wizard_cls): + from argparse import ArgumentParser + + from olive.cli.init import InitCommand + + parser = ArgumentParser() + sub_parsers = parser.add_subparsers() + InitCommand.register_subcommand(sub_parsers) + + args = parser.parse_args(["init", "-o", "./my-output"]) + cmd = InitCommand(parser, args, []) + cmd.run() + + mock_wizard_cls.assert_called_once_with(default_output_path="./my-output") + mock_wizard_cls.return_value.start.assert_called_once() diff --git a/test/cli/init/test_onnx_flow.py b/test/cli/init/test_onnx_flow.py new file mode 100644 index 0000000000..c35ce93ea0 --- /dev/null +++ b/test/cli/init/test_onnx_flow.py @@ -0,0 +1,180 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# ------------------------------------------------------------------------- +from unittest.mock import patch + + +class TestQuantizeFlow: + @patch("olive.cli.init.onnx_flow._ask") + def test_static_quantization_default_calib(self, mock_ask): + from olive.cli.init.onnx_flow import QuantizationType, _quantize_flow + + with patch("olive.cli.init.onnx_flow.prompt_calibration_source", return_value=None): + mock_ask.return_value = QuantizationType.STATIC + result = _quantize_flow("/model.onnx") + cmd = result["command"] + assert "--implementation ort" in cmd + assert "--precision int8" in cmd + + @patch("olive.cli.init.onnx_flow._ask") + def test_dynamic_quantization(self, mock_ask): + from olive.cli.init.onnx_flow import QuantizationType, _quantize_flow + + mock_ask.return_value = QuantizationType.DYNAMIC + result = _quantize_flow("/model.onnx") + cmd = result["command"] + assert "--algorithm rtn" in cmd + assert "--implementation ort" in cmd + + @patch("olive.cli.init.onnx_flow._ask") + def test_bnb_quantization(self, mock_ask): + from olive.cli.init.onnx_flow import QuantizationType, _quantize_flow + + mock_ask.return_value = QuantizationType.BNB + result = _quantize_flow("/model.onnx") + cmd = result["command"] + assert "--implementation bnb" in cmd + assert "--precision nf4" in cmd + + @patch("olive.cli.init.onnx_flow.build_calibration_args", return_value=" -d data --split train --max_samples 128") + @patch("olive.cli.init.onnx_flow.prompt_calibration_source") + @patch("olive.cli.init.onnx_flow._ask") + def test_static_with_calibration_data(self, mock_ask, mock_calib, mock_build): + from olive.cli.init.helpers import SourceType + from olive.cli.init.onnx_flow import QuantizationType, _quantize_flow + + mock_ask.return_value = QuantizationType.STATIC + mock_calib.return_value = { + "source": SourceType.HF, + "data_name": "data", + "subset": "", + "split": "train", + "num_samples": "128", + } + result = _quantize_flow("/model.onnx") + cmd = result["command"] + assert 
"--implementation ort" in cmd + assert "-d data" in cmd + + +class TestOptimizeFlow: + @patch("olive.cli.init.onnx_flow._ask") + def test_generates_command(self, mock_ask): + from olive.cli.init.onnx_flow import _optimize_flow + + mock_ask.side_effect = ["CPUExecutionProvider", "fp32"] + result = _optimize_flow("/model.onnx") + assert result["command"] == "olive optimize -m /model.onnx --provider CPUExecutionProvider --precision fp32" + + +class TestTuneSessionFlow: + @patch("olive.cli.init.onnx_flow._ask") + def test_cpu_with_options(self, mock_ask): + from olive.cli.init.onnx_flow import _tune_session_flow + + mock_ask.side_effect = [ + "cpu", # device + ["CPUExecutionProvider"], # providers + "4", # cpu_cores + False, # io_bind + False, # enable_cuda_graph + ] + result = _tune_session_flow("/model.onnx") + cmd = result["command"] + assert "--device cpu" in cmd + assert "--providers_list CPUExecutionProvider" in cmd + assert "--cpu_cores 4" in cmd + assert "--io_bind" not in cmd + + @patch("olive.cli.init.onnx_flow._ask") + def test_gpu_with_io_bind_and_cuda_graph(self, mock_ask): + from olive.cli.init.onnx_flow import _tune_session_flow + + mock_ask.side_effect = [ + "gpu", # device + ["CUDAExecutionProvider"], # providers + "", # cpu_cores (skip) + True, # io_bind + True, # enable_cuda_graph + ] + result = _tune_session_flow("/model.onnx") + cmd = result["command"] + assert "--device gpu" in cmd + assert "--io_bind" in cmd + assert "--enable_cuda_graph" in cmd + + +class TestConvertPrecisionFlow: + def test_generates_command(self): + from olive.cli.init.onnx_flow import _convert_precision_flow + + result = _convert_precision_flow("/model.onnx") + assert result["command"] == "olive run-pass --pass-name OnnxFloatToFloat16 -m /model.onnx" + + +class TestGraphOptFlow: + def test_generates_command(self): + from olive.cli.init.onnx_flow import _graph_opt_flow + + result = _graph_opt_flow("/model.onnx") + assert result["command"] == "olive optimize -m /model.onnx 
--precision fp32" + + +class TestRunOnnxFlowRouting: + @patch("olive.cli.init.onnx_flow._optimize_flow") + @patch("olive.cli.init.onnx_flow._ask_select") + def test_routes_to_optimize(self, mock_select, mock_flow): + from olive.cli.init.onnx_flow import OnnxOperation, run_onnx_flow + + mock_select.return_value = OnnxOperation.OPTIMIZE + mock_flow.return_value = {"command": "test"} + run_onnx_flow({"model_path": "/m.onnx"}) + mock_flow.assert_called_once_with("/m.onnx") + + @patch("olive.cli.init.onnx_flow._quantize_flow") + @patch("olive.cli.init.onnx_flow._ask_select") + def test_routes_to_quantize(self, mock_select, mock_flow): + from olive.cli.init.onnx_flow import OnnxOperation, run_onnx_flow + + mock_select.return_value = OnnxOperation.QUANTIZE + mock_flow.return_value = {"command": "test"} + run_onnx_flow({"model_path": "/m.onnx"}) + mock_flow.assert_called_once() + + @patch("olive.cli.init.onnx_flow._graph_opt_flow") + @patch("olive.cli.init.onnx_flow._ask_select") + def test_routes_to_graph_opt(self, mock_select, mock_flow): + from olive.cli.init.onnx_flow import OnnxOperation, run_onnx_flow + + mock_select.return_value = OnnxOperation.GRAPH_OPT + mock_flow.return_value = {"command": "test"} + run_onnx_flow({"model_path": "/m.onnx"}) + mock_flow.assert_called_once_with("/m.onnx") + + @patch("olive.cli.init.onnx_flow._convert_precision_flow") + @patch("olive.cli.init.onnx_flow._ask_select") + def test_routes_to_convert_precision(self, mock_select, mock_flow): + from olive.cli.init.onnx_flow import OnnxOperation, run_onnx_flow + + mock_select.return_value = OnnxOperation.CONVERT_PRECISION + mock_flow.return_value = {"command": "test"} + run_onnx_flow({"model_path": "/m.onnx"}) + mock_flow.assert_called_once() + + @patch("olive.cli.init.onnx_flow._tune_session_flow") + @patch("olive.cli.init.onnx_flow._ask_select") + def test_routes_to_tune_session(self, mock_select, mock_flow): + from olive.cli.init.onnx_flow import OnnxOperation, run_onnx_flow + + 
mock_select.return_value = OnnxOperation.TUNE_SESSION + mock_flow.return_value = {"command": "test"} + run_onnx_flow({"model_path": "/m.onnx"}) + mock_flow.assert_called_once() + + @patch("olive.cli.init.onnx_flow._ask_select", return_value="unknown") + def test_unknown_operation_returns_empty(self, mock_select): + from olive.cli.init.onnx_flow import run_onnx_flow + + result = run_onnx_flow({"model_path": "/m.onnx"}) + assert not result diff --git a/test/cli/init/test_pytorch_flow.py b/test/cli/init/test_pytorch_flow.py new file mode 100644 index 0000000000..18f9e90051 --- /dev/null +++ b/test/cli/init/test_pytorch_flow.py @@ -0,0 +1,502 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# ------------------------------------------------------------------------- +from unittest.mock import patch + + +class TestBuildModelArgs: + def test_model_path_only(self): + from olive.cli.init.pytorch_flow import _build_model_args + + result = _build_model_args({"model_path": "meta-llama/Llama-3.1-8B"}) + assert result == "-m meta-llama/Llama-3.1-8B" + + def test_with_model_script(self): + from olive.cli.init.pytorch_flow import _build_model_args + + config = {"model_path": "my_model", "model_script": "train.py", "script_dir": "./src"} + result = _build_model_args(config) + assert "-m my_model" in result + assert "--model_script train.py" in result + assert "--script_dir ./src" in result + + def test_script_only_no_model_path(self): + from olive.cli.init.pytorch_flow import _build_model_args + + config = {"model_script": "train.py"} + result = _build_model_args(config) + assert not result.startswith("-m ") + assert "--model_script train.py" in result + + def test_empty_config(self): + from olive.cli.init.pytorch_flow import _build_model_args + + result = _build_model_args({}) + assert result == "" + + +class TestBuildExportCommand: + def 
test_model_builder(self): + from olive.cli.init.pytorch_flow import EXPORTER_MODEL_BUILDER, _build_export_command + + cmd = _build_export_command("-m model", {"exporter": EXPORTER_MODEL_BUILDER, "precision": "fp16"}) + assert cmd == "olive capture-onnx-graph -m model --use_model_builder --precision fp16" + + def test_model_builder_int4_with_block_size(self): + from olive.cli.init.pytorch_flow import EXPORTER_MODEL_BUILDER, PRECISION_INT4, _build_export_command + + config = {"exporter": EXPORTER_MODEL_BUILDER, "precision": PRECISION_INT4, "int4_block_size": "32"} + cmd = _build_export_command("-m model", config) + assert "--use_model_builder --precision int4" in cmd + assert "--int4_block_size 32" in cmd + + def test_dynamo(self): + from olive.cli.init.pytorch_flow import EXPORTER_DYNAMO, _build_export_command + + cmd = _build_export_command("-m model", {"exporter": EXPORTER_DYNAMO, "torch_dtype": "float16"}) + assert cmd == "olive capture-onnx-graph -m model --torch_dtype float16" + + def test_torchscript(self): + from olive.cli.init.pytorch_flow import EXPORTER_TORCHSCRIPT, _build_export_command + + cmd = _build_export_command("-m model", {"exporter": EXPORTER_TORCHSCRIPT}) + assert cmd == "olive capture-onnx-graph -m model" + + +class TestBuildQuantizeCommand: + def test_rtn_no_implementation(self): + from olive.cli.init.pytorch_flow import _build_quantize_command + + cmd = _build_quantize_command("-m model", {"algorithm": "rtn", "precision": "int4"}) + assert cmd == "olive quantize -m model --algorithm rtn --precision int4" + assert "--implementation" not in cmd + + def test_awq_with_implementation(self): + from olive.cli.init.pytorch_flow import _build_quantize_command + + cmd = _build_quantize_command("-m model", {"algorithm": "awq", "precision": "int4"}) + assert "--algorithm awq" in cmd + assert "--implementation awq" in cmd + + def test_quarot_with_implementation(self): + from olive.cli.init.pytorch_flow import _build_quantize_command + + cmd = 
_build_quantize_command("-m model", {"algorithm": "quarot", "precision": "int4"}) + assert "--implementation quarot" in cmd + + def test_spinquant_with_implementation(self): + from olive.cli.init.pytorch_flow import _build_quantize_command + + cmd = _build_quantize_command("-m model", {"algorithm": "spinquant", "precision": "int4"}) + assert "--implementation spinquant" in cmd + + def test_with_calibration(self): + from olive.cli.init.helpers import SourceType + from olive.cli.init.pytorch_flow import _build_quantize_command + + config = { + "algorithm": "gptq", + "precision": "int4", + "calibration": { + "source": SourceType.HF, + "data_name": "Salesforce/wikitext", + "subset": "wikitext-2-raw-v1", + "split": "train", + "num_samples": "128", + }, + } + cmd = _build_quantize_command("-m model", config) + assert "--algorithm gptq" in cmd + assert "--implementation" not in cmd # gptq uses default olive + assert "-d Salesforce/wikitext" in cmd + assert "--split train" in cmd + + +class TestOptimizeAutoMode: + @patch("olive.cli.init.pytorch_flow._ask") + def test_generates_command(self, mock_ask): + from olive.cli.init.pytorch_flow import _optimize_auto_mode + + mock_ask.side_effect = ["CUDAExecutionProvider", "int4"] + result = _optimize_auto_mode({"model_path": "my-model"}) + assert result["command"] == "olive optimize -m my-model --provider CUDAExecutionProvider --precision int4" + + +class TestQuantizeFlow: + @patch("olive.cli.init.pytorch_flow.prompt_calibration_source") + @patch("olive.cli.init.pytorch_flow._ask") + def test_rtn_no_calibration(self, mock_ask, mock_calib): + from olive.cli.init.pytorch_flow import _quantize_flow + + mock_ask.side_effect = ["rtn", "int4"] + result = _quantize_flow({"model_path": "my-model"}) + assert "--algorithm rtn" in result["command"] + assert "--implementation" not in result["command"] + mock_calib.assert_not_called() + + @patch("olive.cli.init.pytorch_flow.prompt_calibration_source", return_value=None) + 
@patch("olive.cli.init.pytorch_flow._ask") + def test_awq_with_default_calibration(self, mock_ask, mock_calib): + from olive.cli.init.pytorch_flow import _quantize_flow + + mock_ask.side_effect = ["awq", "int4"] + result = _quantize_flow({"model_path": "my-model"}) + assert "--algorithm awq" in result["command"] + assert "--implementation awq" in result["command"] + mock_calib.assert_called_once() + + @patch( + "olive.cli.init.pytorch_flow.build_calibration_args", return_value=" -d data --split train --max_samples 128" + ) + @patch("olive.cli.init.pytorch_flow.prompt_calibration_source") + @patch("olive.cli.init.pytorch_flow._ask") + def test_gptq_with_calibration(self, mock_ask, mock_calib, mock_build): + from olive.cli.init.helpers import SourceType + from olive.cli.init.pytorch_flow import _quantize_flow + + mock_calib.return_value = { + "source": SourceType.HF, + "data_name": "data", + "subset": "", + "split": "train", + "num_samples": "128", + } + mock_ask.side_effect = ["gptq", "int4"] + result = _quantize_flow({"model_path": "my-model"}) + assert "--algorithm gptq" in result["command"] + assert "-d data" in result["command"] + + +class TestExportFlow: + @patch("olive.cli.init.pytorch_flow._ask") + def test_dynamo_exporter(self, mock_ask): + from olive.cli.init.pytorch_flow import EXPORTER_DYNAMO, _export_flow + + mock_ask.side_effect = [EXPORTER_DYNAMO, "float16"] + result = _export_flow({"model_path": "my-model"}) + assert result["command"] == "olive capture-onnx-graph -m my-model --torch_dtype float16" + + @patch("olive.cli.init.pytorch_flow._ask") + def test_model_builder_fp16(self, mock_ask): + from olive.cli.init.pytorch_flow import EXPORTER_MODEL_BUILDER, _export_flow + + mock_ask.side_effect = [EXPORTER_MODEL_BUILDER, "fp16"] + result = _export_flow({"model_path": "my-model"}) + assert "--use_model_builder --precision fp16" in result["command"] + + @patch("olive.cli.init.pytorch_flow._ask") + def test_model_builder_int4(self, mock_ask): + from 
olive.cli.init.pytorch_flow import EXPORTER_MODEL_BUILDER, PRECISION_INT4, _export_flow + + mock_ask.side_effect = [EXPORTER_MODEL_BUILDER, PRECISION_INT4, "32", "4"] + result = _export_flow({"model_path": "my-model"}) + assert "--precision int4" in result["command"] + assert "--int4_block_size 32" in result["command"] + assert "--int4_accuracy_level 4" in result["command"] + + @patch("olive.cli.init.pytorch_flow._ask") + def test_torchscript(self, mock_ask): + from olive.cli.init.pytorch_flow import EXPORTER_TORCHSCRIPT, _export_flow + + mock_ask.side_effect = [EXPORTER_TORCHSCRIPT] + result = _export_flow({"model_path": "my-model"}) + assert result["command"] == "olive capture-onnx-graph -m my-model" + + @patch("olive.cli.init.pytorch_flow._ask") + def test_with_model_script(self, mock_ask): + from olive.cli.init.pytorch_flow import EXPORTER_DYNAMO, _export_flow + + mock_ask.side_effect = [EXPORTER_DYNAMO, "float32"] + result = _export_flow({"model_script": "script.py", "script_dir": "./src"}) + assert "--model_script script.py" in result["command"] + assert "--script_dir ./src" in result["command"] + + +class TestFinetuneFlow: + @patch("olive.cli.init.pytorch_flow._ask") + def test_lora_hf_dataset(self, mock_ask): + from olive.cli.init.helpers import SourceType + from olive.cli.init.pytorch_flow import TEXT_FIELD, _finetune_flow + + mock_ask.side_effect = [ + "lora", # method + "64", # lora_r + "16", # lora_alpha + SourceType.HF, # data_source + "tatsu-lab/alpaca", # data_name + "train", # train_split + "", # eval_split (skip) + TEXT_FIELD, # text_mode + "text", # text_field + "1024", # max_seq_len + "256", # max_samples + "bfloat16", # torch_dtype + ] + result = _finetune_flow({"model_path": "my-model"}) + cmd = result["command"] + assert "olive finetune -m my-model" in cmd + assert "--method lora" in cmd + assert "-d tatsu-lab/alpaca" in cmd + assert "--train_split train" in cmd + assert "--text_field text" in cmd + + @patch("olive.cli.init.pytorch_flow._ask") 
+ def test_qlora_local_data_template(self, mock_ask): + from olive.cli.init.helpers import SourceType + from olive.cli.init.pytorch_flow import TEXT_TEMPLATE, _finetune_flow + + mock_ask.side_effect = [ + "qlora", # method + "16", # lora_r + "16", # lora_alpha + SourceType.LOCAL, # data_source + "/data/train.json", # data_files + TEXT_TEMPLATE, # text_mode + "Q: {q} A: {a}", # template + "512", # max_seq_len + "100", # max_samples + "float16", # torch_dtype + ] + result = _finetune_flow({"model_path": "my-model"}) + cmd = result["command"] + assert "--method qlora" in cmd + assert "--data_files /data/train.json" in cmd + assert '--text_template "Q: {q} A: {a}"' in cmd + + @patch("olive.cli.init.pytorch_flow._ask") + def test_hf_with_eval_split(self, mock_ask): + from olive.cli.init.helpers import SourceType + from olive.cli.init.pytorch_flow import TEXT_FIELD, _finetune_flow + + mock_ask.side_effect = [ + "lora", # method + "64", # lora_r + "16", # lora_alpha + SourceType.HF, # data_source + "tatsu-lab/alpaca", # data_name + "train", # train_split + "test", # eval_split (provided) + TEXT_FIELD, # text_mode + "text", # text_field + "1024", # max_seq_len + "256", # max_samples + "bfloat16", # torch_dtype + ] + result = _finetune_flow({"model_path": "my-model"}) + assert "--eval_split test" in result["command"] + + @patch("olive.cli.init.pytorch_flow._ask") + def test_chat_template(self, mock_ask): + from olive.cli.init.helpers import SourceType + from olive.cli.init.pytorch_flow import TEXT_CHAT_TEMPLATE, _finetune_flow + + mock_ask.side_effect = [ + "lora", # method + "64", # lora_r + "16", # lora_alpha + SourceType.HF, # data_source + "dataset", # data_name + "train", # train_split + "", # eval_split (skip) + TEXT_CHAT_TEMPLATE, # text_mode + "1024", # max_seq_len + "256", # max_samples + "bfloat16", # torch_dtype + ] + result = _finetune_flow({"model_path": "my-model"}) + assert "--use_chat_template" in result["command"] + + +class TestOptimizeCustomMode: + 
@patch("olive.cli.init.pytorch_flow._ask") + def test_export_and_quantize(self, mock_ask): + from olive.cli.init.pytorch_flow import EXPORTER_DYNAMO, OP_EXPORT, OP_QUANTIZE, _optimize_custom_mode + + mock_ask.side_effect = [ + [OP_EXPORT, OP_QUANTIZE], # operations checkbox + EXPORTER_DYNAMO, # exporter + "float32", # torch_dtype + "rtn", # algorithm + "int4", # precision + "CUDAExecutionProvider", # provider + ] + result = _optimize_custom_mode({"model_path": "my-model"}) + cmd = result["command"] + assert "olive optimize" in cmd + assert "--provider CUDAExecutionProvider" in cmd + assert "--exporter dynamo_exporter" in cmd + + @patch("olive.cli.init.pytorch_flow._ask") + def test_export_only(self, mock_ask): + from olive.cli.init.pytorch_flow import EXPORTER_DYNAMO, OP_EXPORT, _optimize_custom_mode + + mock_ask.side_effect = [ + [OP_EXPORT], # operations checkbox + EXPORTER_DYNAMO, # exporter + "float16", # torch_dtype + ] + result = _optimize_custom_mode({"model_path": "my-model"}) + assert "olive capture-onnx-graph" in result["command"] + assert "--torch_dtype float16" in result["command"] + + @patch( + "olive.cli.init.pytorch_flow.build_calibration_args", return_value=" -d data --split train --max_samples 128" + ) + @patch("olive.cli.init.pytorch_flow.prompt_calibration_source") + @patch("olive.cli.init.pytorch_flow._ask") + def test_quantize_only(self, mock_ask, mock_calib, mock_build): + from olive.cli.init.helpers import SourceType + from olive.cli.init.pytorch_flow import OP_QUANTIZE, _optimize_custom_mode + + mock_calib.return_value = { + "source": SourceType.HF, + "data_name": "data", + "subset": "", + "split": "train", + "num_samples": "128", + } + mock_ask.side_effect = [ + [OP_QUANTIZE], # operations checkbox + "gptq", # algorithm + "int4", # precision + ] + result = _optimize_custom_mode({"model_path": "my-model"}) + assert "olive quantize" in result["command"] + assert "--algorithm gptq" in result["command"] + + 
@patch("olive.cli.init.pytorch_flow._ask") + def test_graph_opt_only(self, mock_ask): + from olive.cli.init.pytorch_flow import OP_GRAPH_OPT, _optimize_custom_mode + + mock_ask.side_effect = [ + [OP_GRAPH_OPT], # operations checkbox + "CPUExecutionProvider", # provider + ] + result = _optimize_custom_mode({"model_path": "my-model"}) + assert "olive optimize" in result["command"] + assert "--precision fp32" in result["command"] + + @patch("olive.cli.init.pytorch_flow._ask") + def test_export_and_graph_opt(self, mock_ask): + from olive.cli.init.pytorch_flow import EXPORTER_MODEL_BUILDER, OP_EXPORT, OP_GRAPH_OPT, _optimize_custom_mode + + mock_ask.side_effect = [ + [OP_EXPORT, OP_GRAPH_OPT], # operations checkbox + EXPORTER_MODEL_BUILDER, # exporter + "fp16", # precision + "CPUExecutionProvider", # provider + ] + result = _optimize_custom_mode({"model_path": "my-model"}) + cmd = result["command"] + assert "olive optimize" in cmd + assert "--precision fp32" in cmd + assert "--exporter model_builder" in cmd + + @patch("olive.cli.init.pytorch_flow._ask") + def test_quantize_and_graph_opt(self, mock_ask): + from olive.cli.init.pytorch_flow import OP_GRAPH_OPT, OP_QUANTIZE, _optimize_custom_mode + + mock_ask.side_effect = [ + [OP_QUANTIZE, OP_GRAPH_OPT], # operations checkbox + "rtn", # algorithm + "int4", # precision + "CUDAExecutionProvider", # provider + ] + result = _optimize_custom_mode({"model_path": "my-model"}) + cmd = result["command"] + assert "olive optimize" in cmd + assert "--precision int4" in cmd + + @patch("olive.cli.init.pytorch_flow._ask") + def test_no_operations_selected(self, mock_ask): + from olive.cli.init.pytorch_flow import _optimize_custom_mode + + mock_ask.return_value = [] # empty checkbox + result = _optimize_custom_mode({"model_path": "my-model"}) + assert not result + + +class TestOptimizeFlow: + @patch("olive.cli.init.pytorch_flow._optimize_auto_mode") + @patch("olive.cli.init.pytorch_flow._ask") + def test_routes_to_auto(self, mock_ask, 
mock_auto): + from olive.cli.init.pytorch_flow import MODE_AUTO, _optimize_flow + + mock_ask.return_value = MODE_AUTO + mock_auto.return_value = {"command": "test"} + _optimize_flow({"model_path": "m"}) + mock_auto.assert_called_once() + + @patch("olive.cli.init.pytorch_flow._optimize_custom_mode") + @patch("olive.cli.init.pytorch_flow._ask") + def test_routes_to_custom(self, mock_ask, mock_custom): + from olive.cli.init.pytorch_flow import MODE_CUSTOM, _optimize_flow + + mock_ask.return_value = MODE_CUSTOM + mock_custom.return_value = {"command": "test"} + _optimize_flow({"model_path": "m"}) + mock_custom.assert_called_once() + + +class TestPromptExportOptionsInt4: + @patch("olive.cli.init.pytorch_flow._ask") + def test_model_builder_int4_block_size(self, mock_ask): + from olive.cli.init.pytorch_flow import EXPORTER_MODEL_BUILDER, PRECISION_INT4, _prompt_export_options + + mock_ask.side_effect = [ + EXPORTER_MODEL_BUILDER, # exporter + PRECISION_INT4, # precision + "64", # block_size + ] + config = _prompt_export_options() + assert config == {"exporter": EXPORTER_MODEL_BUILDER, "precision": PRECISION_INT4, "int4_block_size": "64"} + + +class TestRunPytorchFlowRouting: + @patch("olive.cli.init.pytorch_flow._optimize_flow") + @patch("olive.cli.init.pytorch_flow._ask_select") + def test_routes_to_optimize(self, mock_select, mock_flow): + from olive.cli.init.pytorch_flow import OP_OPTIMIZE, run_pytorch_flow + + mock_select.return_value = OP_OPTIMIZE + mock_flow.return_value = {"command": "test"} + run_pytorch_flow({"model_path": "m"}) + mock_flow.assert_called_once_with({"model_path": "m"}) + + @patch("olive.cli.init.pytorch_flow._export_flow") + @patch("olive.cli.init.pytorch_flow._ask_select") + def test_routes_to_export(self, mock_select, mock_flow): + from olive.cli.init.pytorch_flow import OP_EXPORT, run_pytorch_flow + + mock_select.return_value = OP_EXPORT + mock_flow.return_value = {"command": "test"} + run_pytorch_flow({"model_path": "m"}) + 
mock_flow.assert_called_once() + + @patch("olive.cli.init.pytorch_flow._quantize_flow") + @patch("olive.cli.init.pytorch_flow._ask_select") + def test_routes_to_quantize(self, mock_select, mock_flow): + from olive.cli.init.pytorch_flow import OP_QUANTIZE, run_pytorch_flow + + mock_select.return_value = OP_QUANTIZE + mock_flow.return_value = {"command": "test"} + run_pytorch_flow({"model_path": "m"}) + mock_flow.assert_called_once() + + @patch("olive.cli.init.pytorch_flow._finetune_flow") + @patch("olive.cli.init.pytorch_flow._ask_select") + def test_routes_to_finetune(self, mock_select, mock_flow): + from olive.cli.init.pytorch_flow import OP_FINETUNE, run_pytorch_flow + + mock_select.return_value = OP_FINETUNE + mock_flow.return_value = {"command": "test"} + run_pytorch_flow({"model_path": "m"}) + mock_flow.assert_called_once() + + @patch("olive.cli.init.pytorch_flow._ask_select", return_value="unknown") + def test_unknown_operation_returns_empty(self, mock_select): + from olive.cli.init.pytorch_flow import run_pytorch_flow + + result = run_pytorch_flow({"model_path": "m"}) + assert not result diff --git a/test/cli/init/test_wizard.py b/test/cli/init/test_wizard.py new file mode 100644 index 0000000000..6afe329f9b --- /dev/null +++ b/test/cli/init/test_wizard.py @@ -0,0 +1,384 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# ------------------------------------------------------------------------- +from unittest.mock import MagicMock, patch + +import pytest + + +class TestBuildCalibrationArgs: + def test_hf_source_with_subset(self): + from olive.cli.init.helpers import SourceType, build_calibration_args + + calib = { + "source": SourceType.HF, + "data_name": "Salesforce/wikitext", + "subset": "wikitext-2-raw-v1", + "split": "train", + "num_samples": "128", + } + result = build_calibration_args(calib) + assert result == " -d Salesforce/wikitext --subset wikitext-2-raw-v1 --split train --max_samples 128" + + def test_hf_source_without_subset(self): + from olive.cli.init.helpers import SourceType, build_calibration_args + + calib = { + "source": SourceType.HF, + "data_name": "Salesforce/wikitext", + "subset": "", + "split": "train", + "num_samples": "64", + } + result = build_calibration_args(calib) + assert "--subset" not in result + assert result == " -d Salesforce/wikitext --split train --max_samples 64" + + def test_local_source(self): + from olive.cli.init.helpers import SourceType, build_calibration_args + + calib = {"source": SourceType.LOCAL, "data_files": "/data/calib.json"} + result = build_calibration_args(calib) + assert result == " --data_files /data/calib.json" + + def test_unknown_source(self): + from olive.cli.init.helpers import build_calibration_args + + result = build_calibration_args({"source": "unknown"}) + assert result == "" + + +class TestPromptCalibrationSource: + @patch("olive.cli.init.helpers._ask") + def test_default_returns_none(self, mock_ask): + from olive.cli.init.helpers import SourceType, prompt_calibration_source + + mock_ask.return_value = SourceType.DEFAULT + result = prompt_calibration_source() + assert result is None + + @patch("olive.cli.init.helpers._ask") + def test_hf_source(self, mock_ask): + from olive.cli.init.helpers import SourceType, prompt_calibration_source + + mock_ask.side_effect = [SourceType.HF, "my_dataset", "my_subset", 
"validation", "64"] + result = prompt_calibration_source() + assert result == { + "source": SourceType.HF, + "data_name": "my_dataset", + "subset": "my_subset", + "split": "validation", + "num_samples": "64", + } + + @patch("olive.cli.init.helpers._ask") + def test_local_source(self, mock_ask): + from olive.cli.init.helpers import SourceType, prompt_calibration_source + + mock_ask.side_effect = [SourceType.LOCAL, "/data/calib.json"] + result = prompt_calibration_source() + assert result == {"source": SourceType.LOCAL, "data_files": "/data/calib.json"} + + +class TestAskHelpers: + @patch("olive.cli.init.helpers.sys.exit") + def test_ask_exits_on_none(self, mock_exit): + from olive.cli.init.helpers import _ask + + question = MagicMock() + question.ask.return_value = None + _ask(question) + mock_exit.assert_called_once_with(0) + + def test_ask_returns_value(self): + from olive.cli.init.helpers import _ask + + question = MagicMock() + question.ask.return_value = "hello" + assert _ask(question) == "hello" + + @patch("olive.cli.init.helpers._ask") + def test_ask_select_raises_go_back(self, mock_ask): + from olive.cli.init.helpers import GoBackError, _ask_select + + mock_ask.return_value = "__back__" + with pytest.raises(GoBackError): + _ask_select("Pick one:", choices=["a", "b"]) + + @patch("olive.cli.init.helpers._ask") + def test_ask_select_returns_value(self, mock_ask): + from olive.cli.init.helpers import _ask_select + + mock_ask.return_value = "a" + result = _ask_select("Pick one:", choices=["a", "b"]) + assert result == "a" + + @patch("olive.cli.init.helpers._ask") + def test_ask_select_no_back(self, mock_ask): + from olive.cli.init.helpers import _ask_select + + mock_ask.return_value = "a" + result = _ask_select("Pick:", choices=["a"], allow_back=False) + assert result == "a" + + +class TestInitWizard: + """Test InitWizard end-to-end. 
+ + The wizard dispatches to onnx_flow which imports _ask/_ask_select at module + level, so we must patch both wizard and onnx_flow references. + """ + + @patch("olive.cli.init.onnx_flow._ask_select") + @patch("olive.cli.init.wizard.subprocess.run") + @patch("olive.cli.init.wizard._ask") + @patch("olive.cli.init.wizard._ask_select") + def test_full_flow_generate_command(self, mock_select, mock_ask, mock_subprocess, mock_onnx_select): + from olive.cli.init.onnx_flow import OnnxOperation + from olive.cli.init.wizard import InitWizard, ModelType, OutputAction + + mock_onnx_select.return_value = OnnxOperation.CONVERT_PRECISION + mock_select.side_effect = [ModelType.ONNX, OutputAction.COMMAND] + mock_ask.side_effect = ["/model.onnx", "./output", False] + + InitWizard().start() + mock_subprocess.assert_not_called() + + @patch("olive.cli.init.onnx_flow._ask_select") + @patch("olive.cli.init.wizard.subprocess.run") + @patch("olive.cli.init.wizard._ask") + @patch("olive.cli.init.wizard._ask_select") + def test_full_flow_run_now(self, mock_select, mock_ask, mock_subprocess, mock_onnx_select): + from olive.cli.init.onnx_flow import OnnxOperation + from olive.cli.init.wizard import InitWizard, ModelType, OutputAction + + mock_onnx_select.return_value = OnnxOperation.CONVERT_PRECISION + mock_select.side_effect = [ModelType.ONNX, OutputAction.RUN] + mock_ask.side_effect = ["/model.onnx", "./output"] + + InitWizard().start() + mock_subprocess.assert_called_once() + + @patch("olive.cli.init.onnx_flow._ask_select") + @patch("olive.cli.init.wizard.subprocess.run") + @patch("olive.cli.init.wizard._ask") + @patch("olive.cli.init.wizard._ask_select") + def test_full_flow_generate_config(self, mock_select, mock_ask, mock_subprocess, mock_onnx_select): + from olive.cli.init.onnx_flow import OnnxOperation + from olive.cli.init.wizard import InitWizard, ModelType, OutputAction + + mock_onnx_select.return_value = OnnxOperation.CONVERT_PRECISION + mock_select.side_effect = [ModelType.ONNX, 
OutputAction.CONFIG] + mock_ask.side_effect = ["/model.onnx", "./output"] + + InitWizard().start() + mock_subprocess.assert_called_once() + cmd = mock_subprocess.call_args[0][0] + assert "--save_config_file" in cmd + assert "--dry_run" in cmd + + @patch("olive.cli.init.onnx_flow._ask_select") + @patch("olive.cli.init.wizard._ask") + @patch("olive.cli.init.wizard._ask_select") + def test_go_back(self, mock_select, mock_ask, mock_onnx_select): + from olive.cli.init.helpers import GoBackError + from olive.cli.init.onnx_flow import OnnxOperation + from olive.cli.init.wizard import InitWizard, ModelType, OutputAction + + mock_onnx_select.return_value = OnnxOperation.CONVERT_PRECISION + select_values = [ModelType.ONNX, GoBackError, ModelType.ONNX, OutputAction.COMMAND] + + def select_with_goback(*args, **kwargs): + val = select_values.pop(0) + if val is GoBackError: + raise GoBackError + return val + + mock_select.side_effect = select_with_goback + mock_ask.side_effect = ["/model.onnx", "./output", False] + + InitWizard().start() + + @patch("olive.cli.init.wizard.sys.exit") + @patch("olive.cli.init.wizard._ask_select") + def test_keyboard_interrupt(self, mock_select, mock_exit): + from olive.cli.init.wizard import InitWizard + + mock_select.side_effect = KeyboardInterrupt + InitWizard().start() + mock_exit.assert_called_once_with(0) + + @patch("olive.cli.init.onnx_flow._ask_select") + @patch("olive.cli.init.wizard.subprocess.run") + @patch("olive.cli.init.wizard._ask") + @patch("olive.cli.init.wizard._ask_select") + def test_command_then_run_now(self, mock_select, mock_ask, mock_subprocess, mock_onnx_select): + from olive.cli.init.onnx_flow import OnnxOperation + from olive.cli.init.wizard import InitWizard, ModelType, OutputAction + + mock_onnx_select.return_value = OnnxOperation.CONVERT_PRECISION + mock_select.side_effect = [ModelType.ONNX, OutputAction.COMMAND] + mock_ask.side_effect = ["/model.onnx", "./output", True] + + InitWizard().start() + 
+        mock_subprocess.assert_called_once()
+
+    @patch("olive.cli.init.wizard.Path")
+    @patch("olive.cli.init.onnx_flow._ask_select")
+    @patch("olive.cli.init.wizard.subprocess.run")
+    @patch("olive.cli.init.wizard._ask")
+    @patch("olive.cli.init.wizard._ask_select")
+    def test_config_with_existing_file(self, mock_select, mock_ask, mock_subprocess, mock_onnx_select, mock_path):
+        from olive.cli.init.onnx_flow import OnnxOperation
+        from olive.cli.init.wizard import InitWizard, ModelType, OutputAction
+
+        mock_onnx_select.return_value = OnnxOperation.CONVERT_PRECISION
+        mock_select.side_effect = [ModelType.ONNX, OutputAction.CONFIG]
+        mock_ask.side_effect = ["/model.onnx", "./output"]
+        mock_path.return_value.__truediv__ = lambda self, x: MagicMock(exists=lambda: True)
+
+        InitWizard().start()
+        mock_subprocess.assert_called_once()
+
+    def test_no_command_raises_go_back(self):
+        """A flow result without a "command" key should send the wizard back, not silently finish."""
+        from olive.cli.init.helpers import GoBackError
+        from olive.cli.init.wizard import InitWizard
+
+        wizard = InitWizard()
+        result = {"not_command": True}  # no "command" key
+        with pytest.raises(GoBackError):
+            wizard._prompt_output(result)  # pylint: disable=protected-access
+
+    @patch("olive.cli.init.pytorch_flow._ask_select")
+    @patch("olive.cli.init.pytorch_flow._optimize_flow", return_value={"command": "olive optimize -m m"})
+    @patch("olive.cli.init.wizard.subprocess.run")
+    @patch("olive.cli.init.wizard._ask")
+    @patch("olive.cli.init.wizard._ask_select")
+    def test_pytorch_flow_dispatch(self, mock_select, mock_ask, mock_subprocess, mock_opt, mock_pt_select):
+        from olive.cli.init.helpers import SourceType
+        from olive.cli.init.pytorch_flow import OP_OPTIMIZE
+        from olive.cli.init.wizard import InitWizard, ModelType, OutputAction
+
+        mock_pt_select.return_value = OP_OPTIMIZE
+        mock_select.side_effect = [ModelType.PYTORCH, SourceType.HF, OutputAction.RUN]
+        mock_ask.side_effect = ["meta-llama/Llama-3.1-8B", "./output"]
+
+        InitWizard().start()
+        mock_subprocess.assert_called_once()
+
+    @patch("olive.cli.init.diffusers_flow._ask_select")
+    @patch("olive.cli.init.diffusers_flow._ask")
+    @patch("olive.cli.init.wizard.subprocess.run")
+    @patch("olive.cli.init.wizard._ask")
+    @patch("olive.cli.init.wizard._ask_select")
+    def test_diffusers_flow_dispatch(self, mock_select, mock_ask, mock_subprocess, mock_diff_ask, mock_diff_select):
+        from olive.cli.init.diffusers_flow import OP_EXPORT
+        from olive.cli.init.helpers import DiffuserVariant
+        from olive.cli.init.wizard import InitWizard, ModelType, OutputAction
+
+        mock_diff_select.return_value = OP_EXPORT
+        mock_select.side_effect = [ModelType.DIFFUSERS, DiffuserVariant.AUTO, OutputAction.RUN]
+        mock_ask.side_effect = ["my-model", "./output"]
+        mock_diff_ask.return_value = "float16"
+
+        InitWizard().start()
+        mock_subprocess.assert_called_once()
+
+
+class TestPromptPytorchSource:
+    @patch("olive.cli.init.wizard._ask")
+    @patch("olive.cli.init.wizard._ask_select")
+    def test_hf_source(self, mock_select, mock_ask):
+        from olive.cli.init.helpers import SourceType
+        from olive.cli.init.wizard import InitWizard
+
+        mock_select.return_value = SourceType.HF
+        mock_ask.return_value = "meta-llama/Llama-3.1-8B"
+        result = InitWizard()._prompt_pytorch_source()  # pylint: disable=protected-access
+        assert result == {"source_type": SourceType.HF, "model_path": "meta-llama/Llama-3.1-8B"}
+
+    @patch("olive.cli.init.wizard._ask")
+    @patch("olive.cli.init.wizard._ask_select")
+    def test_local_source(self, mock_select, mock_ask):
+        from olive.cli.init.helpers import SourceType
+        from olive.cli.init.wizard import InitWizard
+
+        mock_select.return_value = SourceType.LOCAL
+        mock_ask.return_value = "./my-model/"
+        result = InitWizard()._prompt_pytorch_source()  # pylint: disable=protected-access
+        assert result == {"source_type": SourceType.LOCAL, "model_path": "./my-model/"}
+
+    @patch("olive.cli.init.wizard._ask")
+    @patch("olive.cli.init.wizard._ask_select")
+    def test_azureml_source(self, mock_select, mock_ask):
+        from olive.cli.init.helpers import SourceType
+        from olive.cli.init.wizard import InitWizard
+
+        mock_select.return_value = SourceType.AZUREML
+        mock_ask.return_value = "azureml://registries/r/models/m/versions/1"
+        result = InitWizard()._prompt_pytorch_source()  # pylint: disable=protected-access
+        assert result == {"source_type": SourceType.AZUREML, "model_path": "azureml://registries/r/models/m/versions/1"}
+
+    @patch("olive.cli.init.wizard._ask")
+    @patch("olive.cli.init.wizard._ask_select")
+    def test_script_source_full(self, mock_select, mock_ask):
+        from olive.cli.init.helpers import SourceType
+        from olive.cli.init.wizard import InitWizard
+
+        mock_select.return_value = SourceType.SCRIPT
+        mock_ask.side_effect = ["train.py", "./src", "my-model"]
+        result = InitWizard()._prompt_pytorch_source()  # pylint: disable=protected-access
+        assert result == {
+            "source_type": SourceType.SCRIPT,
+            "model_script": "train.py",
+            "script_dir": "./src",
+            "model_path": "my-model",
+        }
+
+    @patch("olive.cli.init.wizard._ask")
+    @patch("olive.cli.init.wizard._ask_select")
+    def test_script_source_minimal(self, mock_select, mock_ask):
+        from olive.cli.init.helpers import SourceType
+        from olive.cli.init.wizard import InitWizard
+
+        mock_select.return_value = SourceType.SCRIPT
+        mock_ask.side_effect = ["train.py", "", ""]  # no script_dir, no model_path
+        result = InitWizard()._prompt_pytorch_source()  # pylint: disable=protected-access
+        assert result == {"source_type": SourceType.SCRIPT, "model_script": "train.py"}
+        assert "script_dir" not in result
+        assert "model_path" not in result
+
+
+class TestPromptDiffusersSource:
+    @patch("olive.cli.init.wizard._ask")
+    @patch("olive.cli.init.wizard._ask_select")
+    def test_diffusers_source(self, mock_select, mock_ask):
+        from olive.cli.init.helpers import SourceType
+        from olive.cli.init.wizard import InitWizard
+
+        mock_select.return_value = "sdxl"
+        mock_ask.return_value = "stabilityai/sdxl-base-1.0"
+        result = InitWizard()._prompt_diffusers_source()  # pylint: disable=protected-access
+        assert result == {
+            "source_type": SourceType.HF,
+            "model_path": "stabilityai/sdxl-base-1.0",
+            "variant": "sdxl",
+        }
+
+
+class TestPromptModelSource:
+    def test_unknown_model_type_returns_empty(self):
+        from olive.cli.init.wizard import InitWizard
+
+        result = InitWizard()._prompt_model_source("unknown")  # pylint: disable=protected-access
+        assert not result
+
+
+class TestRunModelFlow:
+    def test_unknown_model_type_returns_empty(self):
+        from olive.cli.init.wizard import InitWizard
+
+        result = InitWizard()._run_model_flow("unknown", {})  # pylint: disable=protected-access
+        assert not result
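
The class docstring above notes that `onnx_flow` imports `_ask`/`_ask_select` at module level, which is why the tests patch both the `wizard` and `onnx_flow` references. The standard-library behavior behind that can be sketched with two toy in-memory modules (`helpers` and `flow` are illustrative names, not Olive code):

```python
import sys
import types
from unittest.mock import patch

# Toy modules: `helpers` defines ask(); `flow` does `from helpers import ask`,
# which binds flow's own copy of the function at import time.
helpers = types.ModuleType("helpers")
helpers.ask = lambda: "real"
sys.modules["helpers"] = helpers

flow = types.ModuleType("flow")
exec("from helpers import ask\ndef run():\n    return ask()", flow.__dict__)
sys.modules["flow"] = flow

# Patching the defining module does NOT affect the copy `flow` already holds.
with patch("helpers.ask", return_value="patched"):
    before = flow.run()  # -> "real"

# Patching the module that *uses* the name is what takes effect.
with patch("flow.ask", return_value="patched"):
    after = flow.run()  # -> "patched"

print(before, after)
```

This is the standard "patch where it is looked up" rule from `unittest.mock`; the tests apply it by patching `olive.cli.init.onnx_flow._ask_select` in addition to the `wizard`-level names.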