diff --git a/docs/source/blogs/index.md b/docs/source/blogs/index.md
index d71fce441b..5df3bdb279 100644
--- a/docs/source/blogs/index.md
+++ b/docs/source/blogs/index.md
@@ -4,6 +4,9 @@
```{gallery-grid}
:grid-columns: 1 2 2 3
+- header: "{octicon}`terminal` Introducing olive init: An Interactive Wizard for Model Optimization"
+ content: "Get started with Olive in seconds — the new `olive init` wizard guides you step-by-step from model selection to a ready-to-run optimization command.
+{octicon}`arrow-right` [Introducing olive init](olive-init-cli.md)"
+
- header: "{octicon}`cpu` Exploring Optimal Quantization Settings for Small Language Models"
content: "An exploration of how Olive applies different quantization strategies such as GPTQ, mixed precision, and QuaRot to optimize small language models for efficiency and accuracy.
{octicon}`arrow-right` [Exploring Optimal Quantization Settings for Small Language Models](quant-slms.md)"
@@ -16,6 +19,7 @@
:maxdepth: 2
:hidden:
+olive-init-cli.md
quant-slms.md
sd-lora.md
```
diff --git a/docs/source/blogs/olive-init-cli.md b/docs/source/blogs/olive-init-cli.md
new file mode 100644
index 0000000000..3cf838fd69
--- /dev/null
+++ b/docs/source/blogs/olive-init-cli.md
@@ -0,0 +1,221 @@
+# Introducing `olive init`: An Interactive Wizard for Model Optimization
+
+*Author: Xiaoyu Zhang*
+*Created: 2026-04-06*
+
+Getting started with AI model optimization can be overwhelming — choosing the right exporter, quantization algorithm, precision, and target hardware involves navigating a complex decision space. The new **`olive init`** command solves this with an interactive, step-by-step wizard that guides you from model selection to a ready-to-run optimization command.
+
+---
+
+## Why `olive init`?
+
+Olive offers a powerful set of CLI commands — `optimize`, `quantize`, `capture-onnx-graph`, `finetune`, `diffusion-lora`, and more — each with many options. While this flexibility is great for experts, it can be daunting for newcomers. Common questions include:
+
+- *Which command should I use?*
+- *What exporter is best for my LLM?*
+- *Which quantization algorithm should I pick?*
+- *Do I need calibration data?*
+
+**`olive init`** answers all of these with a guided, step-by-step wizard. It asks the right questions, provides sensible defaults, and generates the exact CLI command or JSON config you need.
+
+---
+
+## Quick Start
+
+```bash
+pip install olive-ai
+olive init
+```
+
+That's it! The wizard launches in your terminal and walks you through every step.
+
+---
+
+## How It Works
+
+The wizard follows a simple 4-step flow:
+
+### Step 1: Choose Your Model Type
+
+```
+? What type of model do you want to optimize?
+❯ PyTorch (HuggingFace or local)
+ ONNX
+ Diffusers (Stable Diffusion, SDXL, Flux, etc.)
+```
+
+### Step 2: Specify Your Model
+
+Depending on the model type, you can provide:
+
+- **HuggingFace model name** (e.g., `meta-llama/Llama-3.1-8B`)
+- **Local directory path**
+- **AzureML registry path**
+- **PyTorch model with custom script**
+
+### Step 3: Configure Your Workflow
+
+This is where the wizard really shines. Based on your model type, it presents relevant operations and guides you through the configuration:
+
+::::{tab-set}
+
+:::{tab-item} PyTorch Models
+
+**Available operations:**
+
+| Operation | Description |
+|-----------|-------------|
+| Optimize | Export to ONNX + quantize + graph optimize (all-in-one) |
+| Export to ONNX | Convert to ONNX format using Model Builder, Dynamo, or TorchScript |
+| Quantize | Apply PyTorch quantization (RTN, GPTQ, AWQ, QuaRot, SpinQuant) |
+| Fine-tune | LoRA or QLoRA fine-tuning on custom datasets |
+
+For the **Optimize** operation, you can choose between:
+
+- **Auto Mode** — Olive automatically selects the best passes for your target hardware and precision
+- **Custom Mode** — Manually pick which operations to include (export, quantize, graph optimization) and configure each one
+
+:::
+
+:::{tab-item} ONNX Models
+
+**Available operations:**
+
+| Operation | Description |
+|-----------|-------------|
+| Optimize | Auto-select best passes for target hardware |
+| Quantize | Static, dynamic, block-wise RTN, HQQ, or BnB quantization |
+| Graph optimization | Apply ONNX graph-level optimizations |
+| Convert precision | FP32 → FP16 conversion |
+| Tune session params | Auto-tune ONNX Runtime inference parameters |
+
+:::
+
+:::{tab-item} Diffusers Models
+
+**Available operations:**
+
+| Operation | Description |
+|-----------|-------------|
+| Export to ONNX | Export diffusion pipeline for ONNX Runtime deployment |
+| LoRA Training | Fine-tune with LoRA on custom images (DreamBooth supported) |
+
+**Supported architectures:** SD 1.x/2.x, SDXL, SD3, Flux, Sana
+
+:::
+
+::::
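
Behind these menus, the wizard also lets you step backwards: menu steps can include a Back choice that rewinds to the previous question. A stripped-down sketch of how such a rewindable step loop can work (plain Python, not the wizard's actual implementation; the real version is built on `questionary` prompts):

```python
class GoBackError(Exception):
    """Raised when the user picks 'Back' at a prompt."""


def run_steps(steps):
    """Run wizard steps in order; a GoBackError rewinds one step."""
    answers = []
    i = 0
    while i < len(steps):
        try:
            answers = answers[:i]       # discard any stale answer for this step
            answers.append(steps[i]())  # ask the question
            i += 1
        except GoBackError:
            i = max(i - 1, 0)           # rewind to the previous question
    return answers
```

Each entry in `steps` is a callable that prompts the user and either returns an answer or raises `GoBackError`.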
+
+### Step 4: Choose Your Output
+
+```
+? What would you like to do?
+❯ Generate CLI command (copy and run later)
+ Generate configuration file (JSON, for olive run)
+ Run optimization now
+```
+
+You can generate the command to review first, save a reusable JSON config, or execute immediately.
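
If you pick the configuration-file option, you get a JSON workflow config that `olive run` consumes. As a rough sketch (field names follow Olive's workflow-config schema, but the exact file the wizard writes may differ), a config equivalent to an int4 CPU optimization could be assembled and saved like this:

```python
import json

# Hypothetical sketch of a wizard-generated workflow config;
# the real output may contain additional fields.
config = {
    "input_model": {
        "type": "HfModel",
        "model_path": "Qwen/Qwen2.5-0.5B-Instruct",
    },
    "passes": {
        "builder": {"type": "ModelBuilder", "precision": "int4"},
    },
    "output_dir": "./olive-output",
}

with open("olive-config.json", "w") as f:
    json.dump(config, f, indent=2)

# Reusable later with: olive run --config olive-config.json
```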
+
+---
+
+## Examples
+
+### Example 1: Optimize a HuggingFace LLM for CPU with INT4
+
+```
+$ olive init
+
+Welcome to Olive Init! This wizard will help you optimize your model.
+
+? What type of model do you want to optimize? PyTorch (HuggingFace or local)
+? How would you like to specify your model? HuggingFace model name
+? Model name or path: Qwen/Qwen2.5-0.5B-Instruct
+? What do you want to do? Optimize model (export to ONNX + quantize + graph optimize)
+? How would you like to configure optimization? Auto Mode (recommended)
+? Select target device: CPU
+? Select target precision: INT4 (smallest size, best for LLMs)
+? Output directory: ./olive-output
+? What would you like to do? Generate CLI command (copy and run later)
+
+Generated command:
+
+ olive optimize -m Qwen/Qwen2.5-0.5B-Instruct --provider CPUExecutionProvider --precision int4 -o ./olive-output
+```
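
Each answer above maps directly onto one flag of the generated command. A minimal, hypothetical sketch of that assembly (not Olive's actual implementation):

```python
def build_optimize_command(model: str, provider: str, precision: str, output_dir: str) -> str:
    """Assemble an `olive optimize` invocation from wizard answers."""
    return (
        f"olive optimize -m {model} "
        f"--provider {provider} --precision {precision} -o {output_dir}"
    )


cmd = build_optimize_command(
    "Qwen/Qwen2.5-0.5B-Instruct", "CPUExecutionProvider", "int4", "./olive-output"
)
print(cmd)
```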
+
+### Example 2: Quantize a PyTorch Model with GPTQ
+
+```
+$ olive init
+
+? What type of model do you want to optimize? PyTorch (HuggingFace or local)
+? How would you like to specify your model? HuggingFace model name
+? Model name or path: meta-llama/Llama-3.1-8B
+? What do you want to do? Quantize only (PyTorch quantization)
+? Select quantization algorithm: GPTQ - High quality, requires calibration
+? Precision: int4
+? Calibration data source: Use default (wikitext-2)
+? Output directory: ./olive-output
+? What would you like to do? Run optimization now
+```
+
+### Example 3: Fine-tune with LoRA
+
+```
+$ olive init
+
+? What type of model do you want to optimize? PyTorch (HuggingFace or local)
+? How would you like to specify your model? HuggingFace model name
+? Model name or path: microsoft/Phi-4-mini-instruct
+? What do you want to do? Fine-tune model (LoRA, QLoRA)
+? Select fine-tuning method: LoRA (recommended)
+? LoRA rank (r): 64 (default)
+? LoRA alpha: 16
+? Training dataset: HuggingFace dataset
+? Dataset name: tatsu-lab/alpaca
+? Train split: train
+? How to construct training text? Use chat template
+? Max sequence length: 1024
+? Max training samples: 256
+? Torch dtype for training: bfloat16 (recommended)
+? Output directory: ./olive-output
+? What would you like to do? Generate CLI command (copy and run later)
+
+Generated command:
+
+ olive finetune -m microsoft/Phi-4-mini-instruct --method lora --lora_r 64 --lora_alpha 16 -d tatsu-lab/alpaca --train_split train --use_chat_template --max_seq_len 1024 --max_samples 256 --torch_dtype bfloat16 -o ./olive-output
+```
+
+### Example 4: Train a Diffusion LoRA with DreamBooth
+
+```
+$ olive init
+
+? What type of model do you want to optimize? Diffusers (Stable Diffusion, SDXL, Flux, etc.)
+? Select diffuser model variant: Stable Diffusion XL (SDXL)
+? Enter model name or path: stabilityai/stable-diffusion-xl-base-1.0
+? What do you want to do? LoRA Training (fine-tune on custom images)
+? LoRA rank (r): 16 (recommended)
+? LoRA alpha: 16
+? Training data source: Local image folder
+? Path to image folder: ./my-dog-photos
+? Enable DreamBooth training? Yes
+? Instance prompt: a photo of sks dog
+? Enable prior preservation? Yes
+? Class prompt: a photo of a dog
+? Max training steps: 1000 (recommended)
+? Output directory: ./olive-output
+? What would you like to do? Generate CLI command (copy and run later)
+
+Generated command:
+
+ olive diffusion-lora -m stabilityai/stable-diffusion-xl-base-1.0 --model_variant sdxl -r 16 --alpha 16 --lora_dropout 0.0 -d ./my-dog-photos --dreambooth --instance_prompt "a photo of sks dog" --with_prior_preservation --class_prompt "a photo of a dog" --num_class_images 200 --max_train_steps 1000 --learning_rate 1e-4 --train_batch_size 1 --gradient_accumulation_steps 4 --mixed_precision bf16 --lr_scheduler constant --lr_warmup_steps 0 -o ./olive-output
+```
+
+---
+
+## Related Resources
+
+- [Olive CLI Reference](https://microsoft.github.io/Olive/reference/cli.html)
+- [Olive Getting Started Guide](https://microsoft.github.io/Olive/getting-started/getting-started.html)
+- [Olive GitHub Repository](https://github.com/microsoft/Olive)
diff --git a/olive/cli/init/__init__.py b/olive/cli/init/__init__.py
new file mode 100644
index 0000000000..26ae8d3562
--- /dev/null
+++ b/olive/cli/init/__init__.py
@@ -0,0 +1,30 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+from argparse import ArgumentParser
+
+from olive.cli.base import BaseOliveCLICommand
+
+
+class InitCommand(BaseOliveCLICommand):
+ @staticmethod
+ def register_subcommand(parser: ArgumentParser):
+ sub_parser = parser.add_parser(
+ "init",
+ help="Interactive wizard to configure and generate Olive optimization commands.",
+ )
+ sub_parser.add_argument(
+ "-o",
+ "--output_path",
+ type=str,
+ default="./olive-output",
+ help="Default output directory to use in generated commands. Defaults to ./olive-output.",
+ )
+ sub_parser.set_defaults(func=InitCommand)
+
+ def run(self):
+ from olive.cli.init.wizard import InitWizard
+
+ wizard = InitWizard(default_output_path=self.args.output_path)
+ wizard.start()
diff --git a/olive/cli/init/diffusers_flow.py b/olive/cli/init/diffusers_flow.py
new file mode 100644
index 0000000000..04f2c3fd02
--- /dev/null
+++ b/olive/cli/init/diffusers_flow.py
@@ -0,0 +1,220 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import questionary
+
+from olive.cli.init.helpers import DiffuserVariant, SourceType, _ask, _ask_select
+
+# Diffusers operations
+OP_EXPORT = "export"
+OP_LORA = "lora"
+
+# Training
+TRAIN_STEPS_CUSTOM = "custom"
+
+
+def run_diffusers_flow(model_config):
+ model_path = model_config.get("model_path", "")
+ variant = model_config.get("variant", DiffuserVariant.AUTO)
+
+ operation = _ask_select(
+ "What do you want to do?",
+ choices=[
+ questionary.Choice("Export to ONNX (for deployment with ONNX Runtime)", value=OP_EXPORT),
+ questionary.Choice("LoRA Training (fine-tune on custom images)", value=OP_LORA),
+ ],
+ )
+
+ if operation == OP_EXPORT:
+ return _export_flow(model_path, variant)
+ elif operation == OP_LORA:
+ return _lora_flow(model_path, variant)
+ return {}
+
+
+def _export_flow(model_path, variant):
+ torch_dtype = _ask(
+ questionary.select(
+ "Torch dtype:",
+ choices=[
+ questionary.Choice("float16", value="float16"),
+ questionary.Choice("float32", value="float32"),
+ ],
+ )
+ )
+
+ cmd = f"olive capture-onnx-graph -m {model_path} --torch_dtype {torch_dtype}"
+ if variant != DiffuserVariant.AUTO:
+ cmd += f" --model_variant {variant}"
+
+ return {"command": cmd}
+
+
+def _lora_flow(model_path, variant):
+ # LoRA parameters
+ lora_r = _ask(
+ questionary.select(
+ "LoRA rank (r):",
+ choices=[
+ questionary.Choice("16 (recommended)", value="16"),
+ questionary.Choice("4", value="4"),
+ questionary.Choice("8", value="8"),
+ questionary.Choice("32", value="32"),
+ questionary.Choice("64", value="64"),
+ ],
+ )
+ )
+
+ lora_alpha = _ask(
+ questionary.text(
+ "LoRA alpha (default = same as rank):",
+ default=lora_r,
+ )
+ )
+
+ lora_dropout = _ask(questionary.text("LoRA dropout:", default="0.0"))
+
+ # Data source
+ data_source = _ask(
+ questionary.select(
+ "Training data source:",
+ choices=[
+ questionary.Choice("Local image folder", value=SourceType.LOCAL),
+ questionary.Choice("HuggingFace dataset", value=SourceType.HF),
+ ],
+ )
+ )
+
+ data_args = ""
+ if data_source == SourceType.LOCAL:
+ data_dir = _ask(
+ questionary.text(
+ "Path to image folder:",
+ validate=lambda x: True if x.strip() else "Please enter a path",
+ )
+ )
+ data_args = f" -d {data_dir}"
+ else:
+ data_name = _ask(
+ questionary.text(
+ "Dataset name:",
+ instruction="e.g., linoyts/Tuxemon",
+ )
+ )
+ data_split = _ask(questionary.text("Split:", default="train"))
+ image_column = _ask(questionary.text("Image column name:", default="image"))
+ caption_column = _ask(questionary.text("Caption column name (optional):", default=""))
+
+ data_args = f" --data_name {data_name} --data_split {data_split} --image_column {image_column}"
+ if caption_column:
+ data_args += f" --caption_column {caption_column}"
+
+ # DreamBooth
+ dreambooth_args = ""
+ enable_dreambooth = _ask(questionary.confirm("Enable DreamBooth training?", default=False))
+ if enable_dreambooth:
+ dreambooth_args = " --dreambooth"
+
+ instance_prompt = _ask(
+ questionary.text(
+ "Instance prompt (e.g., 'a photo of sks dog'):",
+ validate=lambda x: True if x.strip() else "Instance prompt is required for DreamBooth",
+ )
+ )
+ dreambooth_args += f' --instance_prompt "{instance_prompt}"'
+
+ with_prior = _ask(questionary.confirm("Enable prior preservation?", default=True))
+ if with_prior:
+ class_prompt = _ask(
+ questionary.text(
+ "Class prompt (e.g., 'a photo of a dog'):",
+ validate=lambda x: True if x.strip() else "Class prompt is required for prior preservation",
+ )
+ )
+ dreambooth_args += f' --with_prior_preservation --class_prompt "{class_prompt}"'
+
+ class_data_dir = _ask(questionary.text("Class data directory (optional):", default=""))
+ if class_data_dir:
+ dreambooth_args += f" --class_data_dir {class_data_dir}"
+
+ num_class_images = _ask(questionary.text("Number of class images:", default="200"))
+ dreambooth_args += f" --num_class_images {num_class_images}"
+
+ # Training parameters
+ max_train_steps = _ask(
+ questionary.select(
+ "Max training steps:",
+ choices=[
+ questionary.Choice("1000 (recommended)", value="1000"),
+ questionary.Choice("500 (quick)", value="500"),
+ questionary.Choice("2000 (thorough)", value="2000"),
+ questionary.Choice("Custom", value=TRAIN_STEPS_CUSTOM),
+ ],
+ )
+ )
+ if max_train_steps == TRAIN_STEPS_CUSTOM:
+ max_train_steps = _ask(questionary.text("Enter max training steps:"))
+
+ learning_rate = _ask(questionary.text("Learning rate:", default="1e-4"))
+ train_batch_size = _ask(questionary.text("Train batch size:", default="1"))
+ gradient_accumulation = _ask(questionary.text("Gradient accumulation steps:", default="4"))
+
+ mixed_precision = _ask(
+ questionary.select(
+ "Mixed precision:",
+ choices=[
+ questionary.Choice("bf16 (recommended)", value="bf16"),
+ questionary.Choice("fp16", value="fp16"),
+ questionary.Choice("no", value="no"),
+ ],
+ )
+ )
+
+ lr_scheduler = _ask(
+ questionary.select(
+ "Learning rate scheduler:",
+ choices=[
+ questionary.Choice("constant", value="constant"),
+ questionary.Choice("linear", value="linear"),
+ questionary.Choice("cosine", value="cosine"),
+ questionary.Choice("cosine_with_restarts", value="cosine_with_restarts"),
+ questionary.Choice("polynomial", value="polynomial"),
+ questionary.Choice("constant_with_warmup", value="constant_with_warmup"),
+ ],
+ )
+ )
+
+ warmup_steps = _ask(questionary.text("Warmup steps:", default="0"))
+ seed = _ask(questionary.text("Random seed (optional, press Enter to skip):", default=""))
+
+ # Flux-specific
+ flux_args = ""
+ if variant == DiffuserVariant.FLUX:
+ guidance_scale = _ask(questionary.text("Guidance scale (Flux-specific):", default="3.5"))
+ flux_args = f" --guidance_scale {guidance_scale}"
+
+ # Merge LoRA
+ merge_lora = _ask(questionary.confirm("Merge LoRA into base model?", default=False))
+
+ # Build command
+ cmd = f"olive diffusion-lora -m {model_path}"
+ if variant != DiffuserVariant.AUTO:
+ cmd += f" --model_variant {variant}"
+ cmd += f" -r {lora_r} --alpha {lora_alpha} --lora_dropout {lora_dropout}"
+ cmd += data_args
+ cmd += dreambooth_args
+ cmd += f" --max_train_steps {max_train_steps}"
+ cmd += f" --learning_rate {learning_rate}"
+ cmd += f" --train_batch_size {train_batch_size}"
+ cmd += f" --gradient_accumulation_steps {gradient_accumulation}"
+ cmd += f" --mixed_precision {mixed_precision}"
+ cmd += f" --lr_scheduler {lr_scheduler}"
+ cmd += f" --lr_warmup_steps {warmup_steps}"
+ if seed:
+ cmd += f" --seed {seed}"
+ cmd += flux_args
+ if merge_lora:
+ cmd += " --merge_lora"
+
+ return {"command": cmd}
diff --git a/olive/cli/init/helpers.py b/olive/cli/init/helpers.py
new file mode 100644
index 0000000000..86cd9e3b3d
--- /dev/null
+++ b/olive/cli/init/helpers.py
@@ -0,0 +1,126 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import sys
+
+from olive.common.utils import StrEnumBase
+
+
+class SourceType(StrEnumBase):
+ """Source types (shared across flows)."""
+
+ HF = "hf"
+ LOCAL = "local"
+ AZUREML = "azureml"
+ SCRIPT = "script"
+ DEFAULT = "default"
+
+
+class DiffuserVariant(StrEnumBase):
+ """Diffuser variants (only those used in routing)."""
+
+ AUTO = "auto"
+ FLUX = "flux"
+
+
+class GoBackError(Exception):
+ """Raised when user wants to go back to the previous wizard step."""
+
+
+_BACK = "__back__"
+
+
+def _device_choices():
+ """Return device choices aligned with olive optimize --provider."""
+ import questionary
+
+ return [
+ questionary.Choice("CPU", value="CPUExecutionProvider"),
+ questionary.Choice("GPU (NVIDIA CUDA)", value="CUDAExecutionProvider"),
+ questionary.Choice("GPU (NvTensorRTRTX)", value="NvTensorRTRTXExecutionProvider"),
+ questionary.Choice("NPU (Qualcomm QNN)", value="QNNExecutionProvider"),
+ questionary.Choice("NPU (Intel OpenVINO)", value="OpenVINOExecutionProvider"),
+ questionary.Choice("NPU (AMD Vitis AI)", value="VitisAIExecutionProvider"),
+ questionary.Choice("WebGPU", value="WebGpuExecutionProvider"),
+ ]
+
+
+def _precision_choices():
+ """Return precision choices aligned with olive optimize --precision."""
+ import questionary
+
+ return [
+ questionary.Choice("INT4 (smallest size, best for LLMs)", value="int4"),
+ questionary.Choice("INT8 (balanced)", value="int8"),
+ questionary.Choice("FP16 (half precision)", value="fp16"),
+ questionary.Choice("FP32 (full precision)", value="fp32"),
+ ]
+
+
+def _ask(question):
+ """Ask a questionary question; exit via sys.exit(0) when the user cancels (questionary returns None on Ctrl+C)."""
+ result = question.ask()
+ if result is None:
+ sys.exit(0)
+ return result
+
+
+def _ask_select(message, choices, allow_back=True):
+ """Ask a select question with optional Back choice."""
+ import questionary
+
+ all_choices = list(choices)
+ if allow_back:
+ all_choices.append(questionary.Choice("\u2190 Back", value=_BACK))
+ result = _ask(questionary.select(message, choices=all_choices))
+ if result == _BACK:
+ raise GoBackError
+ return result
+
+
+def prompt_calibration_source():
+ """Prompt for calibration data source. Returns dict or None (for default)."""
+ import questionary
+
+ source = _ask(
+ questionary.select(
+ "Calibration data source:",
+ choices=[
+ questionary.Choice("Use default (wikitext-2)", value=SourceType.DEFAULT),
+ questionary.Choice("HuggingFace dataset", value=SourceType.HF),
+ questionary.Choice("Local file", value=SourceType.LOCAL),
+ ],
+ )
+ )
+
+ if source == SourceType.DEFAULT:
+ return None
+ elif source == SourceType.HF:
+ data_name = _ask(questionary.text("Dataset name:", default="Salesforce/wikitext"))
+ subset = _ask(questionary.text("Subset (optional):", default="wikitext-2-raw-v1"))
+ split = _ask(questionary.text("Split:", default="train"))
+ num_samples = _ask(questionary.text("Number of samples:", default="128"))
+ return {
+ "source": SourceType.HF,
+ "data_name": data_name,
+ "subset": subset,
+ "split": split,
+ "num_samples": num_samples,
+ }
+ else:
+ data_files = _ask(questionary.text("Data file path:"))
+ return {"source": SourceType.LOCAL, "data_files": data_files}
+
+
+def build_calibration_args(calibration):
+ """Build CLI args string from calibration config dict."""
+ if calibration["source"] == SourceType.HF:
+ result = f" -d {calibration['data_name']}"
+ if calibration.get("subset"):
+ result += f" --subset {calibration['subset']}"
+ result += f" --split {calibration['split']} --max_samples {calibration['num_samples']}"
+ return result
+ elif calibration["source"] == SourceType.LOCAL:
+ return f" --data_files {calibration['data_files']}"
+ return ""
diff --git a/olive/cli/init/onnx_flow.py b/olive/cli/init/onnx_flow.py
new file mode 100644
index 0000000000..283a040713
--- /dev/null
+++ b/olive/cli/init/onnx_flow.py
@@ -0,0 +1,170 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import questionary
+
+from olive.cli.init.helpers import (
+ _ask,
+ _ask_select,
+ _device_choices,
+ _precision_choices,
+ build_calibration_args,
+ prompt_calibration_source,
+)
+from olive.common.utils import StrEnumBase
+
+
+class OnnxOperation(StrEnumBase):
+ """ONNX operations."""
+
+ OPTIMIZE = "optimize"
+ QUANTIZE = "quantize"
+ GRAPH_OPT = "graph_opt"
+ CONVERT_PRECISION = "convert_precision"
+ TUNE_SESSION = "tune_session"
+
+
+class QuantizationType(StrEnumBase):
+ """Quantization types."""
+
+ STATIC = "static"
+ DYNAMIC = "dynamic"
+ BLOCKWISE_RTN = "blockwise_rtn"
+ HQQ = "hqq"
+ BNB = "bnb"
+
+
+def run_onnx_flow(model_config):
+ model_path = model_config.get("model_path", "")
+
+ operation = _ask_select(
+ "What do you want to do?",
+ choices=[
+ questionary.Choice(
+ "Optimize model (auto-select best passes for target hardware)", value=OnnxOperation.OPTIMIZE
+ ),
+ questionary.Choice("Quantize", value=OnnxOperation.QUANTIZE),
+ questionary.Choice("Graph optimization", value=OnnxOperation.GRAPH_OPT),
+ questionary.Choice("Convert precision (FP32 \u2192 FP16)", value=OnnxOperation.CONVERT_PRECISION),
+ questionary.Choice("Tune session parameters", value=OnnxOperation.TUNE_SESSION),
+ ],
+ )
+
+ if operation == OnnxOperation.OPTIMIZE:
+ return _optimize_flow(model_path)
+ elif operation == OnnxOperation.QUANTIZE:
+ return _quantize_flow(model_path)
+ elif operation == OnnxOperation.GRAPH_OPT:
+ return _graph_opt_flow(model_path)
+ elif operation == OnnxOperation.CONVERT_PRECISION:
+ return _convert_precision_flow(model_path)
+ elif operation == OnnxOperation.TUNE_SESSION:
+ return _tune_session_flow(model_path)
+ return {}
+
+
+def _optimize_flow(model_path):
+ provider = _ask(questionary.select("Select target device:", choices=_device_choices()))
+ precision = _ask(questionary.select("Select target precision:", choices=_precision_choices()))
+
+ cmd = f"olive optimize -m {model_path} --provider {provider} --precision {precision}"
+ return {"command": cmd}
+
+
+def _quantize_flow(model_path):
+ quant_type = _ask(
+ questionary.select(
+ "Select quantization type:",
+ choices=[
+ questionary.Choice(
+ "Static Quantization (INT8) - requires calibration data", value=QuantizationType.STATIC
+ ),
+ questionary.Choice(
+ "Dynamic Quantization (INT8) - no calibration needed", value=QuantizationType.DYNAMIC
+ ),
+ questionary.Choice(
+ "Block-wise RTN (INT4) - no calibration needed", value=QuantizationType.BLOCKWISE_RTN
+ ),
+ questionary.Choice("HQQ Quantization (INT4) - no calibration needed", value=QuantizationType.HQQ),
+ questionary.Choice("BnB Quantization (FP4/NF4) - no calibration needed", value=QuantizationType.BNB),
+ ],
+ )
+ )
+
+ # Map to olive quantize CLI args
+ quant_map = {
+ QuantizationType.STATIC: {"implementation": "ort", "precision": "int8"},
+ QuantizationType.DYNAMIC: {"implementation": "ort", "precision": "int8"},
+ QuantizationType.BLOCKWISE_RTN: {"implementation": "ort", "precision": "int4"},
+ QuantizationType.HQQ: {"implementation": "ort", "precision": "int4"},
+ QuantizationType.BNB: {"implementation": "bnb", "precision": "nf4"},
+ }
+
+ params = quant_map[quant_type]
+ cmd = (
+ f"olive quantize -m {model_path} --precision {params['precision']} --implementation {params['implementation']}"
+ )
+
+ if quant_type == QuantizationType.DYNAMIC:
+ cmd += " --algorithm rtn"
+
+ # Calibration for static quantization
+ if quant_type == QuantizationType.STATIC:
+ calib = prompt_calibration_source()
+ if calib:
+ cmd += build_calibration_args(calib)
+
+ return {"command": cmd}
+
+
+def _graph_opt_flow(model_path):
+ cmd = f"olive optimize -m {model_path} --precision fp32"
+ return {"command": cmd}
+
+
+def _convert_precision_flow(model_path):
+ cmd = f"olive run-pass --pass-name OnnxFloatToFloat16 -m {model_path}"
+ return {"command": cmd}
+
+
+def _tune_session_flow(model_path):
+ device = _ask(
+ questionary.select(
+ "Select target device:",
+ choices=[
+ questionary.Choice("CPU", value="cpu"),
+ questionary.Choice("GPU", value="gpu"),
+ ],
+ )
+ )
+
+ providers = _ask(
+ questionary.checkbox(
+ "Select execution providers:",
+ choices=[
+ questionary.Choice("CPUExecutionProvider", value="CPUExecutionProvider", checked=(device == "cpu")),
+ questionary.Choice("CUDAExecutionProvider", value="CUDAExecutionProvider", checked=(device == "gpu")),
+ questionary.Choice("TensorrtExecutionProvider", value="TensorrtExecutionProvider"),
+ ],
+ instruction="(Space to toggle, Enter to confirm)",
+ )
+ )
+
+ cmd = f"olive tune-session-params -m {model_path} --device {device}"
+ if providers:
+ cmd += " --providers_list " + " ".join(providers)
+
+ cpu_cores = _ask(questionary.text("CPU cores for thread tuning (optional, press Enter to skip):", default=""))
+ if cpu_cores:
+ cmd += f" --cpu_cores {cpu_cores}"
+
+ io_bind = _ask(questionary.confirm("Enable IO binding?", default=False))
+ if io_bind:
+ cmd += " --io_bind"
+
+ enable_cuda_graph = _ask(questionary.confirm("Enable CUDA graph?", default=False))
+ if enable_cuda_graph:
+ cmd += " --enable_cuda_graph"
+
+ return {"command": cmd}
diff --git a/olive/cli/init/pytorch_flow.py b/olive/cli/init/pytorch_flow.py
new file mode 100644
index 0000000000..df0dbca4e7
--- /dev/null
+++ b/olive/cli/init/pytorch_flow.py
@@ -0,0 +1,566 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import questionary
+
+from olive.cli.init.helpers import (
+ SourceType,
+ _ask,
+ _ask_select,
+ _device_choices,
+ _precision_choices,
+ build_calibration_args,
+ prompt_calibration_source,
+)
+
+# PyTorch operations
+OP_OPTIMIZE = "optimize"
+OP_EXPORT = "export"
+OP_QUANTIZE = "quantize"
+OP_FINETUNE = "finetune"
+OP_GRAPH_OPT = "graph_opt"
+
+# Optimize modes
+MODE_AUTO = "auto"
+MODE_CUSTOM = "custom"
+
+# Exporters
+EXPORTER_MODEL_BUILDER = "model_builder"
+EXPORTER_DYNAMO = "dynamo"
+EXPORTER_TORCHSCRIPT = "torchscript"
+
+# Text construction modes
+TEXT_FIELD = "text_field"
+TEXT_TEMPLATE = "template"
+TEXT_CHAT_TEMPLATE = "chat_template"
+
+# Precision (used in routing)
+PRECISION_INT4 = "int4"
+
+# Algorithms that require calibration data
+CALIBRATION_ALGORITHMS = {"gptq", "awq", "quarot", "spinquant"}
+
+# Algorithms that need a non-default --implementation value
+# (rtn and gptq use the default "olive" implementation, so they are omitted)
+ALGORITHM_TO_IMPLEMENTATION = {
+ "awq": "awq",
+ "quarot": "quarot",
+ "spinquant": "spinquant",
+}
+
+
+def _build_model_args(model_config):
+ """Build model-related CLI args from model config."""
+ parts = []
+ model_path = model_config.get("model_path")
+ if model_path:
+ parts.append(f"-m {model_path}")
+ if model_config.get("model_script"):
+ parts.append(f"--model_script {model_config['model_script']}")
+ if model_config.get("script_dir"):
+ parts.append(f"--script_dir {model_config['script_dir']}")
+ return " ".join(parts)
+
+
+def run_pytorch_flow(model_config):
+ operation = _ask_select(
+ "What do you want to do?",
+ choices=[
+ questionary.Choice("Optimize model (export to ONNX + quantize + graph optimize)", value=OP_OPTIMIZE),
+ questionary.Choice("Export to ONNX only", value=OP_EXPORT),
+ questionary.Choice("Quantize only (PyTorch quantization)", value=OP_QUANTIZE),
+ questionary.Choice("Fine-tune model (LoRA, QLoRA)", value=OP_FINETUNE),
+ ],
+ )
+
+ if operation == OP_OPTIMIZE:
+ return _optimize_flow(model_config)
+ elif operation == OP_EXPORT:
+ return _export_flow(model_config)
+ elif operation == OP_QUANTIZE:
+ return _quantize_flow(model_config)
+ elif operation == OP_FINETUNE:
+ return _finetune_flow(model_config)
+ return {}
+
+
+def _optimize_flow(model_config):
+ mode = _ask(
+ questionary.select(
+ "How would you like to configure optimization?",
+ choices=[
+ questionary.Choice(
+ "Auto Mode (recommended) - Automatically select best passes for your target", value=MODE_AUTO
+ ),
+ questionary.Choice("Custom Mode - Manually pick operations and parameters", value=MODE_CUSTOM),
+ ],
+ )
+ )
+
+ if mode == MODE_AUTO:
+ return _optimize_auto_mode(model_config)
+ else:
+ return _optimize_custom_mode(model_config)
+
+
+def _optimize_auto_mode(model_config):
+ model_args = _build_model_args(model_config)
+
+ provider = _ask(questionary.select("Select target device:", choices=_device_choices()))
+ precision = _ask(questionary.select("Select target precision:", choices=_precision_choices()))
+
+ cmd = f"olive optimize {model_args} --provider {provider} --precision {precision}"
+ return {"command": cmd}
+
+
+def _optimize_custom_mode(model_config):
+ model_args = _build_model_args(model_config)
+
+ operations = _ask(
+ questionary.checkbox(
+ "Select operations to perform:",
+ choices=[
+ questionary.Choice("Export to ONNX", value=OP_EXPORT, checked=True),
+ questionary.Choice("Quantize", value=OP_QUANTIZE, checked=True),
+ questionary.Choice("Graph Optimization", value=OP_GRAPH_OPT),
+ ],
+ instruction="(Space to toggle, Enter to confirm)",
+ )
+ )
+
+ if not operations:
+ print("No operations selected.")
+ return {}
+
+ export_config = None
+ quant_config = None
+
+ # Export options
+ if OP_EXPORT in operations:
+ export_config = _prompt_export_options()
+
+ # Quantize options
+ if OP_QUANTIZE in operations:
+ quant_config = _prompt_quantize_options()
+
+ has_export = OP_EXPORT in operations
+ has_quantize = OP_QUANTIZE in operations
+ has_graph_opt = OP_GRAPH_OPT in operations
+
+ if has_export and has_quantize:
+ # Combined export + quantize (±graph_opt) → use olive optimize with --exporter
+ provider = _ask(questionary.select("Select target device:", choices=_device_choices()))
+ precision = quant_config["precision"] if quant_config else "fp32"
+ cmd = f"olive optimize {model_args} --provider {provider} --precision {precision}"
+ # Pass exporter choice to olive optimize
+ if export_config:
+ exporter_map = {
+ EXPORTER_MODEL_BUILDER: "model_builder",
+ EXPORTER_DYNAMO: "dynamo_exporter",
+ EXPORTER_TORCHSCRIPT: "torchscript_exporter",
+ }
+ exporter_arg = exporter_map.get(export_config.get("exporter"), "model_builder")
+ cmd += f" --exporter {exporter_arg}"
+ # Warn that olive optimize auto-selects the quantization algorithm
+ algorithm = quant_config.get("algorithm") if quant_config else None
+ if algorithm:
+ print(
+ f"\nNote: 'olive optimize' selects the quantization algorithm automatically from the"
+ f" provider and precision, so your selection '{algorithm}' is not passed to the command"
+ f" and the actual algorithm may differ. To use '{algorithm}' exactly, select only"
+ f" 'Quantize' instead.\n"
+ )
+ elif has_export and has_graph_opt:
+ # Export + graph_opt (no quantize) → olive optimize with fp32
+ provider = _ask(questionary.select("Select target device:", choices=_device_choices()))
+ cmd = f"olive optimize {model_args} --provider {provider} --precision fp32"
+ if export_config:
+ exporter_map = {
+ EXPORTER_MODEL_BUILDER: "model_builder",
+ EXPORTER_DYNAMO: "dynamo_exporter",
+ EXPORTER_TORCHSCRIPT: "torchscript_exporter",
+ }
+ exporter_arg = exporter_map.get(export_config.get("exporter"), "model_builder")
+ cmd += f" --exporter {exporter_arg}"
+ elif has_quantize and has_graph_opt:
+ # Quantize + graph_opt (no export, ONNX input assumed) → olive optimize
+ provider = _ask(questionary.select("Select target device:", choices=_device_choices()))
+ precision = quant_config["precision"] if quant_config else "fp32"
+ cmd = f"olive optimize {model_args} --provider {provider} --precision {precision}"
+ elif has_export and export_config:
+ # Export only → olive capture-onnx-graph with specific exporter options
+ cmd = _build_export_command(model_args, export_config)
+ elif has_quantize and quant_config:
+ # Quantize only → olive quantize with specific algorithm/precision/calibration
+ cmd = _build_quantize_command(model_args, quant_config)
+ elif has_graph_opt:
+ # Graph opt only
+ provider = _ask(questionary.select("Select target device:", choices=_device_choices()))
+ cmd = f"olive optimize {model_args} --provider {provider} --precision fp32"
+ else:
+ return {}
+
+ return {"command": cmd}
+
+
+def _build_export_command(model_args, export_config):
+ """Build olive capture-onnx-graph command from export config."""
+ exporter = export_config.get("exporter", EXPORTER_DYNAMO)
+
+ if exporter == EXPORTER_MODEL_BUILDER:
+ precision = export_config.get("precision", "fp16")
+ cmd = f"olive capture-onnx-graph {model_args} --use_model_builder --precision {precision}"
+ if precision == PRECISION_INT4 and "int4_block_size" in export_config:
+ cmd += f" --int4_block_size {export_config['int4_block_size']}"
+ elif exporter == EXPORTER_DYNAMO:
+ torch_dtype = export_config.get("torch_dtype", "float32")
+ cmd = f"olive capture-onnx-graph {model_args} --torch_dtype {torch_dtype}"
+ else:
+ cmd = f"olive capture-onnx-graph {model_args}"
+
+ return cmd
+
+
+def _build_quantize_command(model_args, quant_config):
+ """Build olive quantize command from quantize config."""
+ algorithm = quant_config.get("algorithm", "rtn")
+ precision = quant_config.get("precision", "int4")
+
+ cmd = f"olive quantize {model_args} --algorithm {algorithm} --precision {precision}"
+
+ # Add --implementation for algorithms that need a non-default one
+ impl = ALGORITHM_TO_IMPLEMENTATION.get(algorithm)
+ if impl:
+ cmd += f" --implementation {impl}"
+
+ calibration = quant_config.get("calibration")
+ if calibration:
+ cmd += build_calibration_args(calibration)
+
+ return cmd
+
+
+def _export_flow(model_config):
+ model_args = _build_model_args(model_config)
+
+ exporter = _ask(
+ questionary.select(
+ "Select exporter:",
+ choices=[
+ questionary.Choice("Model Builder (recommended for LLMs)", value=EXPORTER_MODEL_BUILDER),
+ questionary.Choice("Dynamo Exporter (general purpose)", value=EXPORTER_DYNAMO),
+ questionary.Choice("TorchScript Exporter (legacy)", value=EXPORTER_TORCHSCRIPT),
+ ],
+ )
+ )
+
+ if exporter == EXPORTER_MODEL_BUILDER:
+ precision = _ask(
+ questionary.select(
+ "Export precision:",
+ choices=[
+ questionary.Choice("fp16", value="fp16"),
+ questionary.Choice("fp32", value="fp32"),
+ questionary.Choice("bf16", value="bf16"),
+ questionary.Choice("int4", value=PRECISION_INT4),
+ ],
+ )
+ )
+
+ cmd = f"olive capture-onnx-graph {model_args} --use_model_builder --precision {precision}"
+
+ if precision == PRECISION_INT4:
+ block_size = _ask(
+ questionary.select(
+ "INT4 block size:",
+ choices=[
+ questionary.Choice("32 (recommended)", value="32"),
+ questionary.Choice("16", value="16"),
+ questionary.Choice("64", value="64"),
+ questionary.Choice("128", value="128"),
+ questionary.Choice("256", value="256"),
+ ],
+ )
+ )
+ cmd += f" --int4_block_size {block_size}"
+
+ accuracy_level = _ask(
+ questionary.select(
+ "INT4 accuracy level:",
+ choices=[
+ questionary.Choice("4 (int8, recommended)", value="4"),
+ questionary.Choice("1 (fp32)", value="1"),
+ questionary.Choice("2 (fp16)", value="2"),
+ questionary.Choice("3 (bf16)", value="3"),
+ ],
+ )
+ )
+ cmd += f" --int4_accuracy_level {accuracy_level}"
+
+ return {"command": cmd}
+
+ elif exporter == EXPORTER_DYNAMO:
+ torch_dtype = _ask(
+ questionary.select(
+ "Torch dtype:",
+ choices=[
+ questionary.Choice("fp32", value="float32"),
+ questionary.Choice("fp16", value="float16"),
+ ],
+ )
+ )
+
+ cmd = f"olive capture-onnx-graph {model_args} --torch_dtype {torch_dtype}"
+ return {"command": cmd}
+
+ else:
+ # TorchScript
+ cmd = f"olive capture-onnx-graph {model_args}"
+ return {"command": cmd}
+
+
+def _quantize_flow(model_config):
+ model_args = _build_model_args(model_config)
+
+ algorithm = _ask(
+ questionary.select(
+ "Select quantization algorithm:",
+ choices=[
+ questionary.Choice("RTN - Fast, no calibration needed", value="rtn"),
+ questionary.Choice("GPTQ - High quality, requires calibration", value="gptq"),
+ questionary.Choice("AWQ - Activation-aware, good for LLMs", value="awq"),
+ questionary.Choice("QuaRot - For QNN/VitisAI deployment", value="quarot"),
+ questionary.Choice("SpinQuant - Learned rotations, accuracy-preserving", value="spinquant"),
+ ],
+ )
+ )
+
+ precision = _ask(
+ questionary.select(
+ "Precision:",
+ choices=[
+ questionary.Choice("int4", value="int4"),
+ questionary.Choice("uint4", value="uint4"),
+ questionary.Choice("int8", value="int8"),
+ ],
+ )
+ )
+
+ cmd = f"olive quantize {model_args} --algorithm {algorithm} --precision {precision}"
+
+ # Add --implementation for algorithms that need a non-default one
+ impl = ALGORITHM_TO_IMPLEMENTATION.get(algorithm)
+ if impl:
+ cmd += f" --implementation {impl}"
+
+ # Calibration data for algorithms that need it
+ if algorithm in CALIBRATION_ALGORITHMS:
+ calib = prompt_calibration_source()
+ if calib:
+ cmd += build_calibration_args(calib)
+
+ return {"command": cmd}
+
+
+def _finetune_flow(model_config):
+ model_args = _build_model_args(model_config)
+
+ method = _ask(
+ questionary.select(
+ "Select fine-tuning method:",
+ choices=[
+ questionary.Choice("LoRA (recommended)", value="lora"),
+ questionary.Choice("QLoRA (quantized, saves GPU memory)", value="qlora"),
+ ],
+ )
+ )
+
+ lora_r = _ask(
+ questionary.select(
+ "LoRA rank (r):",
+ choices=[
+ questionary.Choice("64 (default)", value="64"),
+ questionary.Choice("4", value="4"),
+ questionary.Choice("8", value="8"),
+ questionary.Choice("16", value="16"),
+ questionary.Choice("32", value="32"),
+ ],
+ )
+ )
+
+ lora_alpha = _ask(questionary.text("LoRA alpha:", default="16"))
+
+ # Dataset
+ data_source = _ask(
+ questionary.select(
+ "Training dataset:",
+ choices=[
+ questionary.Choice("HuggingFace dataset", value=SourceType.HF),
+ questionary.Choice("Local file", value=SourceType.LOCAL),
+ ],
+ )
+ )
+
+ cmd = f"olive finetune {model_args} --method {method} --lora_r {lora_r} --lora_alpha {lora_alpha}"
+
+ if data_source == SourceType.HF:
+ data_name = _ask(
+ questionary.text(
+ "Dataset name:",
+ default="tatsu-lab/alpaca",
+ )
+ )
+ train_split = _ask(questionary.text("Train split:", default="train"))
+ eval_split = _ask(questionary.text("Eval split (optional, press Enter to skip):", default=""))
+
+ cmd += f" -d {data_name} --train_split {train_split}"
+ if eval_split:
+ cmd += f" --eval_split {eval_split}"
+ else:
+ data_files = _ask(
+ questionary.text(
+ "Path to data file(s):",
+ validate=lambda x: True if x.strip() else "Please enter a file path",
+ )
+ )
+ # a dataset name must still be passed; "nouse" is a dummy since the data comes from --data_files
+ cmd += f" -d nouse --data_files {data_files}"
+
+ # Text construction
+ text_mode = _ask(
+ questionary.select(
+ "How to construct training text?",
+ choices=[
+ questionary.Choice("Single text field (specify column name)", value=TEXT_FIELD),
+ questionary.Choice(
+ "Text template (e.g., '### Question: {prompt} \\n### Answer: {response}')", value=TEXT_TEMPLATE
+ ),
+ questionary.Choice("Use chat template", value=TEXT_CHAT_TEMPLATE),
+ ],
+ )
+ )
+
+ if text_mode == TEXT_FIELD:
+ text_field = _ask(questionary.text("Text field name:", default="text"))
+ cmd += f" --text_field {text_field}"
+ elif text_mode == TEXT_TEMPLATE:
+ template = _ask(questionary.text("Text template:"))
+ cmd += f' --text_template "{template}"'
+ else:
+ cmd += " --use_chat_template"
+
+ max_seq_len = _ask(questionary.text("Max sequence length:", default="1024"))
+ cmd += f" --max_seq_len {max_seq_len}"
+
+ max_samples = _ask(questionary.text("Max training samples:", default="256"))
+ cmd += f" --max_samples {max_samples}"
+
+ # Torch dtype
+ torch_dtype = _ask(
+ questionary.select(
+ "Torch dtype for training:",
+ choices=[
+ questionary.Choice("bfloat16 (recommended)", value="bfloat16"),
+ questionary.Choice("float16", value="float16"),
+ questionary.Choice("float32", value="float32"),
+ ],
+ )
+ )
+ cmd += f" --torch_dtype {torch_dtype}"
+
+ return {"command": cmd}
+
+
+def _prompt_export_options():
+ """Prompt export options for custom mode."""
+ exporter = _ask(
+ questionary.select(
+ "Select exporter:",
+ choices=[
+ questionary.Choice("Model Builder (recommended for LLMs)", value=EXPORTER_MODEL_BUILDER),
+ questionary.Choice("Dynamo Exporter (general purpose)", value=EXPORTER_DYNAMO),
+ questionary.Choice("TorchScript Exporter (legacy)", value=EXPORTER_TORCHSCRIPT),
+ ],
+ )
+ )
+
+ config = {"exporter": exporter}
+
+ if exporter == EXPORTER_MODEL_BUILDER:
+ precision = _ask(
+ questionary.select(
+ "Export precision:",
+ choices=[
+ questionary.Choice("fp16", value="fp16"),
+ questionary.Choice("fp32", value="fp32"),
+ questionary.Choice("bf16", value="bf16"),
+ questionary.Choice("int4", value=PRECISION_INT4),
+ ],
+ )
+ )
+ config["precision"] = precision
+
+ if precision == PRECISION_INT4:
+ block_size = _ask(
+ questionary.select(
+ "INT4 block size:",
+ choices=[
+ questionary.Choice("32 (recommended)", value="32"),
+ questionary.Choice("16", value="16"),
+ questionary.Choice("64", value="64"),
+ questionary.Choice("128", value="128"),
+ questionary.Choice("256", value="256"),
+ ],
+ )
+ )
+ config["int4_block_size"] = block_size
+
+ elif exporter == EXPORTER_DYNAMO:
+ torch_dtype = _ask(
+ questionary.select(
+ "Torch dtype:",
+ choices=[
+ questionary.Choice("fp32", value="float32"),
+ questionary.Choice("fp16", value="float16"),
+ ],
+ )
+ )
+ config["torch_dtype"] = torch_dtype
+
+ return config
+
+
+def _prompt_quantize_options():
+ """Prompt quantization options for custom mode."""
+ algorithm = _ask(
+ questionary.select(
+ "Select quantization algorithm:",
+ choices=[
+ questionary.Choice("RTN - Round-to-Nearest, fast, no calibration needed", value="rtn"),
+ questionary.Choice("GPTQ - High quality, requires calibration data", value="gptq"),
+ questionary.Choice("AWQ - Activation-aware, good for LLMs", value="awq"),
+ questionary.Choice("QuaRot - Rotation-based, for QNN/VitisAI deployment", value="quarot"),
+ questionary.Choice("SpinQuant - Learned rotations, accuracy-preserving", value="spinquant"),
+ ],
+ )
+ )
+
+ precision = _ask(
+ questionary.select(
+ "Quantization precision:",
+ choices=[
+ questionary.Choice("int4", value="int4"),
+ questionary.Choice("uint4", value="uint4"),
+ questionary.Choice("int8", value="int8"),
+ ],
+ )
+ )
+
+ config = {"algorithm": algorithm, "precision": precision}
+
+ if algorithm in CALIBRATION_ALGORITHMS:
+ calib = prompt_calibration_source()
+ if calib:
+ config["calibration"] = calib
+
+ return config
diff --git a/olive/cli/init/wizard.py b/olive/cli/init/wizard.py
new file mode 100644
index 0000000000..74e31efb49
--- /dev/null
+++ b/olive/cli/init/wizard.py
@@ -0,0 +1,231 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import subprocess
+import sys
+from pathlib import Path
+
+import questionary
+
+from olive.cli.init.helpers import (
+ DiffuserVariant,
+ GoBackError,
+ SourceType,
+ _ask,
+ _ask_select,
+)
+from olive.common.utils import StrEnumBase
+
+
+class ModelType(StrEnumBase):
+ """Model types."""
+
+ PYTORCH = "pytorch"
+ ONNX = "onnx"
+ DIFFUSERS = "diffusers"
+
+
+class OutputAction(StrEnumBase):
+ """Output actions."""
+
+ COMMAND = "command"
+ CONFIG = "config"
+ RUN = "run"
+
+
+class InitWizard:
+ def __init__(self, default_output_path: str = "./olive-output"):
+ self.default_output_path = default_output_path
+
+ def start(self):
+ print("\nWelcome to Olive Init! This wizard will help you optimize your model.\n")
+
+ try:
+ step = 0
+ model_type = None
+ model_config = None
+ result = None
+
+ while step < 4:
+ try:
+ if step == 0:
+ model_type = self._prompt_model_type()
+ elif step == 1:
+ model_config = self._prompt_model_source(model_type)
+ elif step == 2:
+ result = self._run_model_flow(model_type, model_config)
+ elif step == 3:
+ self._prompt_output(result)
+ step += 1
+ except GoBackError:
+ if step > 0:
+ step -= 1
+
+ except KeyboardInterrupt:
+ sys.exit(0)
+
+ def _prompt_model_type(self):
+ return _ask_select(
+ "What type of model do you want to optimize?",
+ choices=[
+ questionary.Choice("PyTorch (HuggingFace or local)", value=ModelType.PYTORCH),
+ questionary.Choice("ONNX", value=ModelType.ONNX),
+ questionary.Choice("Diffusers (Stable Diffusion, SDXL, Flux, etc.)", value=ModelType.DIFFUSERS),
+ ],
+ allow_back=False,
+ )
+
+ def _prompt_model_source(self, model_type):
+ if model_type == ModelType.PYTORCH:
+ return self._prompt_pytorch_source()
+ elif model_type == ModelType.ONNX:
+ return self._prompt_onnx_source()
+ elif model_type == ModelType.DIFFUSERS:
+ return self._prompt_diffusers_source()
+ return {}
+
+ def _prompt_pytorch_source(self):
+ source_type = _ask_select(
+ "How would you like to specify your model?",
+ choices=[
+ questionary.Choice("HuggingFace model name (e.g., meta-llama/Llama-3.1-8B)", value=SourceType.HF),
+ questionary.Choice("Local directory path", value=SourceType.LOCAL),
+ questionary.Choice("AzureML registry path", value=SourceType.AZUREML),
+ questionary.Choice("PyTorch model with custom script", value=SourceType.SCRIPT),
+ ],
+ )
+
+ config = {"source_type": source_type}
+
+ if source_type == SourceType.SCRIPT:
+ config["model_script"] = _ask(
+ questionary.path(
+ "Path to model script (.py):",
+ )
+ )
+ script_dir = _ask(
+ questionary.text(
+ "Script directory (optional, press Enter to skip):",
+ default="",
+ )
+ )
+ if script_dir:
+ config["script_dir"] = script_dir
+ model_path = _ask(
+ questionary.text(
+ "Model name or path (optional, press Enter to skip):",
+ default="",
+ )
+ )
+ if model_path:
+ config["model_path"] = model_path
+ else:
+ if source_type == SourceType.HF:
+ placeholder = "e.g., meta-llama/Llama-3.1-8B"
+ elif source_type == SourceType.AZUREML:
+ placeholder = "e.g., azureml://registries/<registry_name>/models/<model_name>/versions/<version>"
+ else:
+ placeholder = "e.g., ./my-model/"
+ config["model_path"] = _ask(
+ questionary.text(
+ "Model name or path:",
+ validate=lambda x: True if x.strip() else "Please enter a model name or path",
+ instruction=placeholder,
+ )
+ )
+
+ return config
+
+ def _prompt_onnx_source(self):
+ model_path = _ask(
+ questionary.text(
+ "Enter ONNX model path (file or directory):",
+ validate=lambda x: True if x.strip() else "Please enter a model path",
+ )
+ )
+ return {"source_type": SourceType.LOCAL, "model_path": model_path}
+
+ def _prompt_diffusers_source(self):
+ variant = _ask_select(
+ "Select diffuser model variant:",
+ choices=[
+ questionary.Choice("Auto-detect", value=DiffuserVariant.AUTO),
+ questionary.Choice("Stable Diffusion (SD 1.x/2.x)", value="sd"),
+ questionary.Choice("Stable Diffusion XL (SDXL)", value="sdxl"),
+ questionary.Choice("Stable Diffusion 3 (SD3)", value="sd3"),
+ questionary.Choice("Flux", value=DiffuserVariant.FLUX),
+ questionary.Choice("Sana", value="sana"),
+ ],
+ )
+
+ model_path = _ask(
+ questionary.text(
+ "Enter model name or path:",
+ validate=lambda x: True if x.strip() else "Please enter a model name or path",
+ instruction="e.g., stabilityai/stable-diffusion-xl-base-1.0",
+ )
+ )
+
+ return {"source_type": SourceType.HF, "model_path": model_path, "variant": variant}
+
+ def _run_model_flow(self, model_type, model_config):
+ if model_type == ModelType.PYTORCH:
+ from olive.cli.init.pytorch_flow import run_pytorch_flow
+
+ return run_pytorch_flow(model_config)
+ elif model_type == ModelType.ONNX:
+ from olive.cli.init.onnx_flow import run_onnx_flow
+
+ return run_onnx_flow(model_config)
+ elif model_type == ModelType.DIFFUSERS:
+ from olive.cli.init.diffusers_flow import run_diffusers_flow
+
+ return run_diffusers_flow(model_config)
+ return {}
+
+ def _prompt_output(self, result):
+ command_str = result.get("command")
+
+ if not command_str:
+ print("No command generated.")
+ raise GoBackError
+
+ output_dir = _ask(
+ questionary.text(
+ "Output directory:",
+ default=self.default_output_path,
+ )
+ )
+
+ # Append output dir to command if not already present
+ if " -o " not in command_str and " --output_path " not in command_str:
+ command_str += f" -o {output_dir}"
+
+ action = _ask_select(
+ "What would you like to do?",
+ choices=[
+ questionary.Choice("Generate CLI command (copy and run later)", value=OutputAction.COMMAND),
+ questionary.Choice("Generate configuration file (JSON, for olive run)", value=OutputAction.CONFIG),
+ questionary.Choice("Run optimization now", value=OutputAction.RUN),
+ ],
+ )
+
+ if action == OutputAction.COMMAND:
+ print(f"\nGenerated command:\n\n {command_str}\n")
+ run_now = _ask(questionary.confirm("Run this command now?", default=False))
+ if run_now:
+ print(f"\nRunning: {command_str}\n")
+ subprocess.run(command_str, shell=True, check=False)
+
+ elif action == OutputAction.CONFIG:
+ config_cmd = command_str + " --save_config_file --dry_run"
+ print("\nGenerating configuration file...\n")
+ subprocess.run(config_cmd, shell=True, check=False)
+ config_path = Path(output_dir) / "config.json"
+ if config_path.exists():
+ print(f"\nYou can run it later with:\n olive run --config {config_path}\n")
+ else:
+ print("\nConfig file was not generated; check the output above for errors.\n")
+
+ elif action == OutputAction.RUN:
+ print(f"\nRunning: {command_str}\n")
+ subprocess.run(command_str, shell=True, check=False)
diff --git a/olive/cli/launcher.py b/olive/cli/launcher.py
index d9088bc89b..fed339f87d 100644
--- a/olive/cli/launcher.py
+++ b/olive/cli/launcher.py
@@ -16,6 +16,7 @@
from olive.cli.finetune import FineTuneCommand
from olive.cli.generate_adapter import GenerateAdapterCommand
from olive.cli.generate_cost_model import GenerateCostModelCommand
+from olive.cli.init import InitCommand
from olive.cli.optimize import OptimizeCommand
from olive.cli.quantize import QuantizeCommand
from olive.cli.run import WorkflowRunCommand
@@ -37,6 +38,7 @@ def get_cli_parser(called_as_console_script: bool = True) -> ArgumentParser:
# Register commands
# TODO(jambayk): Consider adding a common tempdir option to all commands
# NOTE: The order of the commands is to organize the documentation better.
+ InitCommand.register_subcommand(commands_parser)
WorkflowRunCommand.register_subcommand(commands_parser)
RunPassCommand.register_subcommand(commands_parser)
AutoOptCommand.register_subcommand(commands_parser)
diff --git a/olive/cli/optimize.py b/olive/cli/optimize.py
index d80392ecf9..6d94c1407c 100644
--- a/olive/cli/optimize.py
+++ b/olive/cli/optimize.py
@@ -191,7 +191,7 @@ def register_subcommand(parser: ArgumentParser):
def __init__(self, parser: ArgumentParser, args: Namespace, unknown_args: Optional[list] = None):
super().__init__(parser, args, unknown_args)
- self.need_wikitest_data_config = False
+ self.need_wikitext_data_config = False
self.is_hf_model = False # will be set in _get_run_config
# Pass enabled flags
@@ -202,7 +202,6 @@ def __init__(self, parser: ArgumentParser, args: Namespace, unknown_args: Option
self.enable_onnx_conversion = False
self.enable_optimum_openvino_conversion = False
self.enable_dynamic_to_fixed_shape = False
- self.enable_vitis_ai_preprocess = False
self.enable_onnx_io_datatype_converter = False
self.enable_openvino_io_update = False
self.enable_onnx_peephole_optimizer = False
@@ -304,7 +303,7 @@ def _update_system_config(self, config: dict[str, Any]):
config["target"] = "qnn_system"
def _add_data_config(self, config: dict[str, Any]):
- config["data_configs"] = WIKITEXT2_DATA_CONFIG_TEMPLATE if self.need_wikitest_data_config else []
+ config["data_configs"] = WIKITEXT2_DATA_CONFIG_TEMPLATE if self.need_wikitext_data_config else []
def _build_passes_config(self) -> dict[str, Any]:
passes_config = OrderedDict()
@@ -553,7 +552,7 @@ def _get_openvino_io_update_pass_config(self) -> dict[str, Any]:
def _enable_onnx_peephole_optimizer_pass(self) -> bool:
"""Return true if condition to add OnnxPeepholeOptimizer pass is met."""
- return self.args.exporter != "model_builder"
+ return not self.is_hf_model or self.args.exporter != "model_builder"
def _get_onnx_peephole_optimizer_pass_config(self) -> dict[str, Any]:
"""Return pass dictionary for OnnxPeepholeOptimizer pass."""
@@ -668,7 +667,7 @@ def _get_onnx_static_quantization_pass_config(self) -> dict[str, Any]:
# Add data_config for text modality
if self.args.modality == "text":
- self.need_wikitest_data_config = True
+ self.need_wikitext_data_config = True
config["data_config"] = "wikitext2_train"
return config
diff --git a/requirements.txt b/requirements.txt
index 400f0bfe7b..1032c72f97 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,6 +8,7 @@ optuna
pandas
pydantic>=2.0
pyyaml
+questionary
torch
torchmetrics>=1.0.0
transformers
diff --git a/test/cli/init/__init__.py b/test/cli/init/__init__.py
new file mode 100644
index 0000000000..54aa1f92bf
--- /dev/null
+++ b/test/cli/init/__init__.py
@@ -0,0 +1,4 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# -------------------------------------------------------------------------
diff --git a/test/cli/init/conftest.py b/test/cli/init/conftest.py
new file mode 100644
index 0000000000..a411298306
--- /dev/null
+++ b/test/cli/init/conftest.py
@@ -0,0 +1,24 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# -------------------------------------------------------------------------
+import sys
+from unittest.mock import MagicMock
+
+import pytest
+
+
+@pytest.fixture(autouse=True)
+def _mock_questionary(monkeypatch):
+ """Replace ``questionary`` with a MagicMock in every module that imports it."""
+ mock_q = MagicMock()
+
+ # Modules that do ``import questionary`` at the top level.
+ for mod in (
+ "olive.cli.init.wizard",
+ "olive.cli.init.onnx_flow",
+ "olive.cli.init.pytorch_flow",
+ "olive.cli.init.diffusers_flow",
+ ):
+ monkeypatch.setattr(f"{mod}.questionary", mock_q, raising=False)
+ monkeypatch.setitem(sys.modules, "questionary", mock_q)
diff --git a/test/cli/init/test_diffusers_flow.py b/test/cli/init/test_diffusers_flow.py
new file mode 100644
index 0000000000..4a45786ee3
--- /dev/null
+++ b/test/cli/init/test_diffusers_flow.py
@@ -0,0 +1,216 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# -------------------------------------------------------------------------
+from unittest.mock import patch
+
+
+class TestExportFlow:
+ @patch("olive.cli.init.diffusers_flow._ask")
+ def test_export_with_variant(self, mock_ask):
+ from olive.cli.init.diffusers_flow import _export_flow
+
+ mock_ask.return_value = "float16"
+ result = _export_flow("stabilityai/sdxl", "sdxl")
+ cmd = result["command"]
+ assert "olive capture-onnx-graph -m stabilityai/sdxl" in cmd
+ assert "--torch_dtype float16" in cmd
+ assert "--model_variant sdxl" in cmd
+
+ @patch("olive.cli.init.diffusers_flow._ask")
+ def test_export_auto_variant(self, mock_ask):
+ from olive.cli.init.diffusers_flow import _export_flow
+ from olive.cli.init.helpers import DiffuserVariant
+
+ mock_ask.return_value = "float32"
+ result = _export_flow("my-model", DiffuserVariant.AUTO)
+ assert "--model_variant" not in result["command"]
+
+
+class TestLoraFlow:
+ @patch("olive.cli.init.diffusers_flow._ask")
+ def test_basic_lora_local_data(self, mock_ask):
+ from olive.cli.init.diffusers_flow import _lora_flow
+ from olive.cli.init.helpers import DiffuserVariant, SourceType
+
+ mock_ask.side_effect = [
+ "16", # lora_r
+ "16", # lora_alpha
+ "0.0", # lora_dropout
+ SourceType.LOCAL, # data_source
+ "/images", # data_dir
+ False, # enable_dreambooth
+ "1000", # max_train_steps
+ "1e-4", # learning_rate
+ "1", # train_batch_size
+ "4", # gradient_accumulation
+ "bf16", # mixed_precision
+ "constant", # lr_scheduler
+ "0", # warmup_steps
+ "", # seed (skip)
+ False, # merge_lora
+ ]
+ result = _lora_flow("my-model", DiffuserVariant.AUTO)
+ cmd = result["command"]
+ assert "olive diffusion-lora -m my-model" in cmd
+ assert "-r 16 --alpha 16" in cmd
+ assert "-d /images" in cmd
+ assert "--max_train_steps 1000" in cmd
+ assert "--model_variant" not in cmd
+
+ @patch("olive.cli.init.diffusers_flow._ask")
+ def test_flux_with_dreambooth(self, mock_ask):
+ from olive.cli.init.diffusers_flow import _lora_flow
+ from olive.cli.init.helpers import DiffuserVariant, SourceType
+
+ mock_ask.side_effect = [
+ "16", # lora_r
+ "16", # lora_alpha
+ "0.1", # lora_dropout
+ SourceType.LOCAL, # data_source
+ "/images", # data_dir
+ True, # enable_dreambooth
+ "a photo of sks dog", # instance_prompt
+ True, # with_prior
+ "a photo of a dog", # class_prompt
+ "", # class_data_dir (skip)
+ "200", # num_class_images
+ "500", # max_train_steps
+ "1e-4", # learning_rate
+ "1", # train_batch_size
+ "4", # gradient_accumulation
+ "bf16", # mixed_precision
+ "constant", # lr_scheduler
+ "0", # warmup_steps
+ "", # seed (skip)
+ "3.5", # guidance_scale (flux-specific)
+ True, # merge_lora
+ ]
+ result = _lora_flow("my-flux-model", DiffuserVariant.FLUX)
+ cmd = result["command"]
+ assert f"--model_variant {DiffuserVariant.FLUX}" in cmd
+ assert "--dreambooth" in cmd
+ assert '--instance_prompt "a photo of sks dog"' in cmd
+ assert "--with_prior_preservation" in cmd
+ assert "--guidance_scale 3.5" in cmd
+ assert "--merge_lora" in cmd
+
+ @patch("olive.cli.init.diffusers_flow._ask")
+ def test_hf_data_source_with_caption(self, mock_ask):
+ from olive.cli.init.diffusers_flow import _lora_flow
+ from olive.cli.init.helpers import DiffuserVariant, SourceType
+
+ mock_ask.side_effect = [
+ "16", # lora_r
+ "16", # lora_alpha
+ "0.0", # lora_dropout
+ SourceType.HF, # data_source
+ "linoyts/Tuxemon", # data_name
+ "train", # data_split
+ "image", # image_column
+ "caption", # caption_column
+ False, # enable_dreambooth
+ "1000", # max_train_steps
+ "1e-4", # learning_rate
+ "1", # train_batch_size
+ "4", # gradient_accumulation
+ "bf16", # mixed_precision
+ "constant", # lr_scheduler
+ "0", # warmup_steps
+ "42", # seed (provided)
+ False, # merge_lora
+ ]
+ result = _lora_flow("my-model", DiffuserVariant.AUTO)
+ cmd = result["command"]
+ assert "--data_name linoyts/Tuxemon" in cmd
+ assert "--data_split train" in cmd
+ assert "--image_column image" in cmd
+ assert "--caption_column caption" in cmd
+ assert "--seed 42" in cmd
+
+ @patch("olive.cli.init.diffusers_flow._ask")
+ def test_custom_max_train_steps(self, mock_ask):
+ from olive.cli.init.diffusers_flow import TRAIN_STEPS_CUSTOM, _lora_flow
+ from olive.cli.init.helpers import DiffuserVariant, SourceType
+
+ mock_ask.side_effect = [
+ "16", # lora_r
+ "16", # lora_alpha
+ "0.0", # lora_dropout
+ SourceType.LOCAL, # data_source
+ "/images", # data_dir
+ False, # enable_dreambooth
+ TRAIN_STEPS_CUSTOM, # max_train_steps
+ "3000", # custom value
+ "1e-4", # learning_rate
+ "1", # train_batch_size
+ "4", # gradient_accumulation
+ "bf16", # mixed_precision
+ "constant", # lr_scheduler
+ "0", # warmup_steps
+ "", # seed (skip)
+ False, # merge_lora
+ ]
+ result = _lora_flow("my-model", DiffuserVariant.AUTO)
+ assert "--max_train_steps 3000" in result["command"]
+
+ @patch("olive.cli.init.diffusers_flow._ask")
+ def test_dreambooth_with_class_data_dir(self, mock_ask):
+ from olive.cli.init.diffusers_flow import _lora_flow
+ from olive.cli.init.helpers import DiffuserVariant, SourceType
+
+ mock_ask.side_effect = [
+ "16", # lora_r
+ "16", # lora_alpha
+ "0.0", # lora_dropout
+ SourceType.LOCAL, # data_source
+ "/images", # data_dir
+ True, # enable_dreambooth
+ "a photo of sks dog", # instance_prompt
+ True, # with_prior
+ "a photo of a dog", # class_prompt
+ "/class_images", # class_data_dir (provided)
+ "200", # num_class_images
+ "1000", # max_train_steps
+ "1e-4", # learning_rate
+ "1", # train_batch_size
+ "4", # gradient_accumulation
+ "bf16", # mixed_precision
+ "constant", # lr_scheduler
+ "0", # warmup_steps
+ "", # seed (skip)
+ False, # merge_lora
+ ]
+ result = _lora_flow("my-model", DiffuserVariant.AUTO)
+ assert "--class_data_dir /class_images" in result["command"]
+
+
+class TestRunDiffusersFlowRouting:
+ @patch("olive.cli.init.diffusers_flow._export_flow")
+ @patch("olive.cli.init.diffusers_flow._ask_select")
+ def test_routes_to_export(self, mock_select, mock_flow):
+ from olive.cli.init.diffusers_flow import OP_EXPORT, run_diffusers_flow
+
+ mock_select.return_value = OP_EXPORT
+ mock_flow.return_value = {"command": "test"}
+ run_diffusers_flow({"model_path": "m", "variant": "sdxl"})
+ mock_flow.assert_called_once_with("m", "sdxl")
+
+ @patch("olive.cli.init.diffusers_flow._lora_flow")
+ @patch("olive.cli.init.diffusers_flow._ask_select")
+ def test_routes_to_lora(self, mock_select, mock_flow):
+ from olive.cli.init.diffusers_flow import OP_LORA, run_diffusers_flow
+ from olive.cli.init.helpers import DiffuserVariant
+
+ mock_select.return_value = OP_LORA
+ mock_flow.return_value = {"command": "test"}
+ run_diffusers_flow({"model_path": "m", "variant": DiffuserVariant.FLUX})
+ mock_flow.assert_called_once_with("m", DiffuserVariant.FLUX)
+
+ @patch("olive.cli.init.diffusers_flow._ask_select", return_value="unknown")
+ def test_unknown_operation_returns_empty(self, mock_select):
+ from olive.cli.init.diffusers_flow import run_diffusers_flow
+ from olive.cli.init.helpers import DiffuserVariant
+
+ result = run_diffusers_flow({"model_path": "m", "variant": DiffuserVariant.AUTO})
+ assert not result
diff --git a/test/cli/init/test_init_command.py b/test/cli/init/test_init_command.py
new file mode 100644
index 0000000000..257f8c6dad
--- /dev/null
+++ b/test/cli/init/test_init_command.py
@@ -0,0 +1,37 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# -------------------------------------------------------------------------
+from unittest.mock import patch
+
+
+class TestInitCommand:
+ def test_register_subcommand(self):
+ from argparse import ArgumentParser
+
+ from olive.cli.init import InitCommand
+
+ parser = ArgumentParser()
+ sub_parsers = parser.add_subparsers()
+ InitCommand.register_subcommand(sub_parsers)
+
+ args = parser.parse_args(["init", "-o", "/tmp/out"])
+ assert args.output_path == "/tmp/out"
+ assert args.func is InitCommand
+
+ @patch("olive.cli.init.wizard.InitWizard")
+ def test_run(self, mock_wizard_cls):
+ from argparse import ArgumentParser
+
+ from olive.cli.init import InitCommand
+
+ parser = ArgumentParser()
+ sub_parsers = parser.add_subparsers()
+ InitCommand.register_subcommand(sub_parsers)
+
+ args = parser.parse_args(["init", "-o", "./my-output"])
+ cmd = InitCommand(parser, args, [])
+ cmd.run()
+
+ mock_wizard_cls.assert_called_once_with(default_output_path="./my-output")
+ mock_wizard_cls.return_value.start.assert_called_once()
diff --git a/test/cli/init/test_onnx_flow.py b/test/cli/init/test_onnx_flow.py
new file mode 100644
index 0000000000..c35ce93ea0
--- /dev/null
+++ b/test/cli/init/test_onnx_flow.py
@@ -0,0 +1,180 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# -------------------------------------------------------------------------
+from unittest.mock import patch
+
+
+class TestQuantizeFlow:
+ @patch("olive.cli.init.onnx_flow._ask")
+ def test_static_quantization_default_calib(self, mock_ask):
+ from olive.cli.init.onnx_flow import QuantizationType, _quantize_flow
+
+ with patch("olive.cli.init.onnx_flow.prompt_calibration_source", return_value=None):
+ mock_ask.return_value = QuantizationType.STATIC
+ result = _quantize_flow("/model.onnx")
+ cmd = result["command"]
+ assert "--implementation ort" in cmd
+ assert "--precision int8" in cmd
+
+ @patch("olive.cli.init.onnx_flow._ask")
+ def test_dynamic_quantization(self, mock_ask):
+ from olive.cli.init.onnx_flow import QuantizationType, _quantize_flow
+
+ mock_ask.return_value = QuantizationType.DYNAMIC
+ result = _quantize_flow("/model.onnx")
+ cmd = result["command"]
+ assert "--algorithm rtn" in cmd
+ assert "--implementation ort" in cmd
+
+ @patch("olive.cli.init.onnx_flow._ask")
+ def test_bnb_quantization(self, mock_ask):
+ from olive.cli.init.onnx_flow import QuantizationType, _quantize_flow
+
+ mock_ask.return_value = QuantizationType.BNB
+ result = _quantize_flow("/model.onnx")
+ cmd = result["command"]
+ assert "--implementation bnb" in cmd
+ assert "--precision nf4" in cmd
+
+ @patch("olive.cli.init.onnx_flow.build_calibration_args", return_value=" -d data --split train --max_samples 128")
+ @patch("olive.cli.init.onnx_flow.prompt_calibration_source")
+ @patch("olive.cli.init.onnx_flow._ask")
+ def test_static_with_calibration_data(self, mock_ask, mock_calib, mock_build):
+ from olive.cli.init.helpers import SourceType
+ from olive.cli.init.onnx_flow import QuantizationType, _quantize_flow
+
+ mock_ask.return_value = QuantizationType.STATIC
+ mock_calib.return_value = {
+ "source": SourceType.HF,
+ "data_name": "data",
+ "subset": "",
+ "split": "train",
+ "num_samples": "128",
+ }
+ result = _quantize_flow("/model.onnx")
+ cmd = result["command"]
+ assert "--implementation ort" in cmd
+ assert "-d data" in cmd
+
+
+class TestOptimizeFlow:
+ @patch("olive.cli.init.onnx_flow._ask")
+ def test_generates_command(self, mock_ask):
+ from olive.cli.init.onnx_flow import _optimize_flow
+
+ mock_ask.side_effect = ["CPUExecutionProvider", "fp32"]
+ result = _optimize_flow("/model.onnx")
+ assert result["command"] == "olive optimize -m /model.onnx --provider CPUExecutionProvider --precision fp32"
+
+
+class TestTuneSessionFlow:
+ @patch("olive.cli.init.onnx_flow._ask")
+ def test_cpu_with_options(self, mock_ask):
+ from olive.cli.init.onnx_flow import _tune_session_flow
+
+ mock_ask.side_effect = [
+ "cpu", # device
+ ["CPUExecutionProvider"], # providers
+ "4", # cpu_cores
+ False, # io_bind
+ False, # enable_cuda_graph
+ ]
+ result = _tune_session_flow("/model.onnx")
+ cmd = result["command"]
+ assert "--device cpu" in cmd
+ assert "--providers_list CPUExecutionProvider" in cmd
+ assert "--cpu_cores 4" in cmd
+ assert "--io_bind" not in cmd
+
+ @patch("olive.cli.init.onnx_flow._ask")
+ def test_gpu_with_io_bind_and_cuda_graph(self, mock_ask):
+ from olive.cli.init.onnx_flow import _tune_session_flow
+
+ mock_ask.side_effect = [
+ "gpu", # device
+ ["CUDAExecutionProvider"], # providers
+ "", # cpu_cores (skip)
+ True, # io_bind
+ True, # enable_cuda_graph
+ ]
+ result = _tune_session_flow("/model.onnx")
+ cmd = result["command"]
+ assert "--device gpu" in cmd
+ assert "--io_bind" in cmd
+ assert "--enable_cuda_graph" in cmd
+
+
+class TestConvertPrecisionFlow:
+ def test_generates_command(self):
+ from olive.cli.init.onnx_flow import _convert_precision_flow
+
+ result = _convert_precision_flow("/model.onnx")
+ assert result["command"] == "olive run-pass --pass-name OnnxFloatToFloat16 -m /model.onnx"
+
+
+class TestGraphOptFlow:
+ def test_generates_command(self):
+ from olive.cli.init.onnx_flow import _graph_opt_flow
+
+ result = _graph_opt_flow("/model.onnx")
+ assert result["command"] == "olive optimize -m /model.onnx --precision fp32"
+
+
+class TestRunOnnxFlowRouting:
+ @patch("olive.cli.init.onnx_flow._optimize_flow")
+ @patch("olive.cli.init.onnx_flow._ask_select")
+ def test_routes_to_optimize(self, mock_select, mock_flow):
+ from olive.cli.init.onnx_flow import OnnxOperation, run_onnx_flow
+
+ mock_select.return_value = OnnxOperation.OPTIMIZE
+ mock_flow.return_value = {"command": "test"}
+ run_onnx_flow({"model_path": "/m.onnx"})
+ mock_flow.assert_called_once_with("/m.onnx")
+
+ @patch("olive.cli.init.onnx_flow._quantize_flow")
+ @patch("olive.cli.init.onnx_flow._ask_select")
+ def test_routes_to_quantize(self, mock_select, mock_flow):
+ from olive.cli.init.onnx_flow import OnnxOperation, run_onnx_flow
+
+ mock_select.return_value = OnnxOperation.QUANTIZE
+ mock_flow.return_value = {"command": "test"}
+ run_onnx_flow({"model_path": "/m.onnx"})
+ mock_flow.assert_called_once()
+
+ @patch("olive.cli.init.onnx_flow._graph_opt_flow")
+ @patch("olive.cli.init.onnx_flow._ask_select")
+ def test_routes_to_graph_opt(self, mock_select, mock_flow):
+ from olive.cli.init.onnx_flow import OnnxOperation, run_onnx_flow
+
+ mock_select.return_value = OnnxOperation.GRAPH_OPT
+ mock_flow.return_value = {"command": "test"}
+ run_onnx_flow({"model_path": "/m.onnx"})
+ mock_flow.assert_called_once_with("/m.onnx")
+
+ @patch("olive.cli.init.onnx_flow._convert_precision_flow")
+ @patch("olive.cli.init.onnx_flow._ask_select")
+ def test_routes_to_convert_precision(self, mock_select, mock_flow):
+ from olive.cli.init.onnx_flow import OnnxOperation, run_onnx_flow
+
+ mock_select.return_value = OnnxOperation.CONVERT_PRECISION
+ mock_flow.return_value = {"command": "test"}
+ run_onnx_flow({"model_path": "/m.onnx"})
+ mock_flow.assert_called_once()
+
+ @patch("olive.cli.init.onnx_flow._tune_session_flow")
+ @patch("olive.cli.init.onnx_flow._ask_select")
+ def test_routes_to_tune_session(self, mock_select, mock_flow):
+ from olive.cli.init.onnx_flow import OnnxOperation, run_onnx_flow
+
+ mock_select.return_value = OnnxOperation.TUNE_SESSION
+ mock_flow.return_value = {"command": "test"}
+ run_onnx_flow({"model_path": "/m.onnx"})
+ mock_flow.assert_called_once()
+
+ @patch("olive.cli.init.onnx_flow._ask_select", return_value="unknown")
+ def test_unknown_operation_returns_empty(self, mock_select):
+ from olive.cli.init.onnx_flow import run_onnx_flow
+
+ result = run_onnx_flow({"model_path": "/m.onnx"})
+ assert not result
diff --git a/test/cli/init/test_pytorch_flow.py b/test/cli/init/test_pytorch_flow.py
new file mode 100644
index 0000000000..18f9e90051
--- /dev/null
+++ b/test/cli/init/test_pytorch_flow.py
@@ -0,0 +1,502 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# -------------------------------------------------------------------------
+from unittest.mock import patch
+
+
+class TestBuildModelArgs:
+ def test_model_path_only(self):
+ from olive.cli.init.pytorch_flow import _build_model_args
+
+ result = _build_model_args({"model_path": "meta-llama/Llama-3.1-8B"})
+ assert result == "-m meta-llama/Llama-3.1-8B"
+
+ def test_with_model_script(self):
+ from olive.cli.init.pytorch_flow import _build_model_args
+
+ config = {"model_path": "my_model", "model_script": "train.py", "script_dir": "./src"}
+ result = _build_model_args(config)
+ assert "-m my_model" in result
+ assert "--model_script train.py" in result
+ assert "--script_dir ./src" in result
+
+ def test_script_only_no_model_path(self):
+ from olive.cli.init.pytorch_flow import _build_model_args
+
+ config = {"model_script": "train.py"}
+ result = _build_model_args(config)
+ assert not result.startswith("-m ")
+ assert "--model_script train.py" in result
+
+ def test_empty_config(self):
+ from olive.cli.init.pytorch_flow import _build_model_args
+
+ result = _build_model_args({})
+ assert result == ""
+
+
+class TestBuildExportCommand:
+ def test_model_builder(self):
+ from olive.cli.init.pytorch_flow import EXPORTER_MODEL_BUILDER, _build_export_command
+
+ cmd = _build_export_command("-m model", {"exporter": EXPORTER_MODEL_BUILDER, "precision": "fp16"})
+ assert cmd == "olive capture-onnx-graph -m model --use_model_builder --precision fp16"
+
+ def test_model_builder_int4_with_block_size(self):
+ from olive.cli.init.pytorch_flow import EXPORTER_MODEL_BUILDER, PRECISION_INT4, _build_export_command
+
+ config = {"exporter": EXPORTER_MODEL_BUILDER, "precision": PRECISION_INT4, "int4_block_size": "32"}
+ cmd = _build_export_command("-m model", config)
+ assert "--use_model_builder --precision int4" in cmd
+ assert "--int4_block_size 32" in cmd
+
+ def test_dynamo(self):
+ from olive.cli.init.pytorch_flow import EXPORTER_DYNAMO, _build_export_command
+
+ cmd = _build_export_command("-m model", {"exporter": EXPORTER_DYNAMO, "torch_dtype": "float16"})
+ assert cmd == "olive capture-onnx-graph -m model --torch_dtype float16"
+
+ def test_torchscript(self):
+ from olive.cli.init.pytorch_flow import EXPORTER_TORCHSCRIPT, _build_export_command
+
+ cmd = _build_export_command("-m model", {"exporter": EXPORTER_TORCHSCRIPT})
+ assert cmd == "olive capture-onnx-graph -m model"
+
+
+class TestBuildQuantizeCommand:
+ def test_rtn_no_implementation(self):
+ from olive.cli.init.pytorch_flow import _build_quantize_command
+
+ cmd = _build_quantize_command("-m model", {"algorithm": "rtn", "precision": "int4"})
+ assert cmd == "olive quantize -m model --algorithm rtn --precision int4"
+ assert "--implementation" not in cmd
+
+ def test_awq_with_implementation(self):
+ from olive.cli.init.pytorch_flow import _build_quantize_command
+
+ cmd = _build_quantize_command("-m model", {"algorithm": "awq", "precision": "int4"})
+ assert "--algorithm awq" in cmd
+ assert "--implementation awq" in cmd
+
+ def test_quarot_with_implementation(self):
+ from olive.cli.init.pytorch_flow import _build_quantize_command
+
+ cmd = _build_quantize_command("-m model", {"algorithm": "quarot", "precision": "int4"})
+ assert "--implementation quarot" in cmd
+
+ def test_spinquant_with_implementation(self):
+ from olive.cli.init.pytorch_flow import _build_quantize_command
+
+ cmd = _build_quantize_command("-m model", {"algorithm": "spinquant", "precision": "int4"})
+ assert "--implementation spinquant" in cmd
+
+ def test_with_calibration(self):
+ from olive.cli.init.helpers import SourceType
+ from olive.cli.init.pytorch_flow import _build_quantize_command
+
+ config = {
+ "algorithm": "gptq",
+ "precision": "int4",
+ "calibration": {
+ "source": SourceType.HF,
+ "data_name": "Salesforce/wikitext",
+ "subset": "wikitext-2-raw-v1",
+ "split": "train",
+ "num_samples": "128",
+ },
+ }
+ cmd = _build_quantize_command("-m model", config)
+ assert "--algorithm gptq" in cmd
+ assert "--implementation" not in cmd  # gptq uses the default olive implementation
+ assert "-d Salesforce/wikitext" in cmd
+ assert "--split train" in cmd
+
+
+class TestOptimizeAutoMode:
+ @patch("olive.cli.init.pytorch_flow._ask")
+ def test_generates_command(self, mock_ask):
+ from olive.cli.init.pytorch_flow import _optimize_auto_mode
+
+ mock_ask.side_effect = ["CUDAExecutionProvider", "int4"]
+ result = _optimize_auto_mode({"model_path": "my-model"})
+ assert result["command"] == "olive optimize -m my-model --provider CUDAExecutionProvider --precision int4"
+
+
+class TestQuantizeFlow:
+ @patch("olive.cli.init.pytorch_flow.prompt_calibration_source")
+ @patch("olive.cli.init.pytorch_flow._ask")
+ def test_rtn_no_calibration(self, mock_ask, mock_calib):
+ from olive.cli.init.pytorch_flow import _quantize_flow
+
+ mock_ask.side_effect = ["rtn", "int4"]
+ result = _quantize_flow({"model_path": "my-model"})
+ assert "--algorithm rtn" in result["command"]
+ assert "--implementation" not in result["command"]
+ mock_calib.assert_not_called()
+
+ @patch("olive.cli.init.pytorch_flow.prompt_calibration_source", return_value=None)
+ @patch("olive.cli.init.pytorch_flow._ask")
+ def test_awq_with_default_calibration(self, mock_ask, mock_calib):
+ from olive.cli.init.pytorch_flow import _quantize_flow
+
+ mock_ask.side_effect = ["awq", "int4"]
+ result = _quantize_flow({"model_path": "my-model"})
+ assert "--algorithm awq" in result["command"]
+ assert "--implementation awq" in result["command"]
+ mock_calib.assert_called_once()
+
+ @patch(
+ "olive.cli.init.pytorch_flow.build_calibration_args", return_value=" -d data --split train --max_samples 128"
+ )
+ @patch("olive.cli.init.pytorch_flow.prompt_calibration_source")
+ @patch("olive.cli.init.pytorch_flow._ask")
+ def test_gptq_with_calibration(self, mock_ask, mock_calib, mock_build):
+ from olive.cli.init.helpers import SourceType
+ from olive.cli.init.pytorch_flow import _quantize_flow
+
+ mock_calib.return_value = {
+ "source": SourceType.HF,
+ "data_name": "data",
+ "subset": "",
+ "split": "train",
+ "num_samples": "128",
+ }
+ mock_ask.side_effect = ["gptq", "int4"]
+ result = _quantize_flow({"model_path": "my-model"})
+ assert "--algorithm gptq" in result["command"]
+ assert "-d data" in result["command"]
+
+
+class TestExportFlow:
+ @patch("olive.cli.init.pytorch_flow._ask")
+ def test_dynamo_exporter(self, mock_ask):
+ from olive.cli.init.pytorch_flow import EXPORTER_DYNAMO, _export_flow
+
+ mock_ask.side_effect = [EXPORTER_DYNAMO, "float16"]
+ result = _export_flow({"model_path": "my-model"})
+ assert result["command"] == "olive capture-onnx-graph -m my-model --torch_dtype float16"
+
+ @patch("olive.cli.init.pytorch_flow._ask")
+ def test_model_builder_fp16(self, mock_ask):
+ from olive.cli.init.pytorch_flow import EXPORTER_MODEL_BUILDER, _export_flow
+
+ mock_ask.side_effect = [EXPORTER_MODEL_BUILDER, "fp16"]
+ result = _export_flow({"model_path": "my-model"})
+ assert "--use_model_builder --precision fp16" in result["command"]
+
+ @patch("olive.cli.init.pytorch_flow._ask")
+ def test_model_builder_int4(self, mock_ask):
+ from olive.cli.init.pytorch_flow import EXPORTER_MODEL_BUILDER, PRECISION_INT4, _export_flow
+
+ mock_ask.side_effect = [EXPORTER_MODEL_BUILDER, PRECISION_INT4, "32", "4"]
+ result = _export_flow({"model_path": "my-model"})
+ assert "--precision int4" in result["command"]
+ assert "--int4_block_size 32" in result["command"]
+ assert "--int4_accuracy_level 4" in result["command"]
+
+ @patch("olive.cli.init.pytorch_flow._ask")
+ def test_torchscript(self, mock_ask):
+ from olive.cli.init.pytorch_flow import EXPORTER_TORCHSCRIPT, _export_flow
+
+ mock_ask.side_effect = [EXPORTER_TORCHSCRIPT]
+ result = _export_flow({"model_path": "my-model"})
+ assert result["command"] == "olive capture-onnx-graph -m my-model"
+
+ @patch("olive.cli.init.pytorch_flow._ask")
+ def test_with_model_script(self, mock_ask):
+ from olive.cli.init.pytorch_flow import EXPORTER_DYNAMO, _export_flow
+
+ mock_ask.side_effect = [EXPORTER_DYNAMO, "float32"]
+ result = _export_flow({"model_script": "script.py", "script_dir": "./src"})
+ assert "--model_script script.py" in result["command"]
+ assert "--script_dir ./src" in result["command"]
+
+
+class TestFinetuneFlow:
+ @patch("olive.cli.init.pytorch_flow._ask")
+ def test_lora_hf_dataset(self, mock_ask):
+ from olive.cli.init.helpers import SourceType
+ from olive.cli.init.pytorch_flow import TEXT_FIELD, _finetune_flow
+
+ mock_ask.side_effect = [
+ "lora", # method
+ "64", # lora_r
+ "16", # lora_alpha
+ SourceType.HF, # data_source
+ "tatsu-lab/alpaca", # data_name
+ "train", # train_split
+ "", # eval_split (skip)
+ TEXT_FIELD, # text_mode
+ "text", # text_field
+ "1024", # max_seq_len
+ "256", # max_samples
+ "bfloat16", # torch_dtype
+ ]
+ result = _finetune_flow({"model_path": "my-model"})
+ cmd = result["command"]
+ assert "olive finetune -m my-model" in cmd
+ assert "--method lora" in cmd
+ assert "-d tatsu-lab/alpaca" in cmd
+ assert "--train_split train" in cmd
+ assert "--text_field text" in cmd
+
+ @patch("olive.cli.init.pytorch_flow._ask")
+ def test_qlora_local_data_template(self, mock_ask):
+ from olive.cli.init.helpers import SourceType
+ from olive.cli.init.pytorch_flow import TEXT_TEMPLATE, _finetune_flow
+
+ mock_ask.side_effect = [
+ "qlora", # method
+ "16", # lora_r
+ "16", # lora_alpha
+ SourceType.LOCAL, # data_source
+ "/data/train.json", # data_files
+ TEXT_TEMPLATE, # text_mode
+ "Q: {q} A: {a}", # template
+ "512", # max_seq_len
+ "100", # max_samples
+ "float16", # torch_dtype
+ ]
+ result = _finetune_flow({"model_path": "my-model"})
+ cmd = result["command"]
+ assert "--method qlora" in cmd
+ assert "--data_files /data/train.json" in cmd
+ assert '--text_template "Q: {q} A: {a}"' in cmd
+
+ @patch("olive.cli.init.pytorch_flow._ask")
+ def test_hf_with_eval_split(self, mock_ask):
+ from olive.cli.init.helpers import SourceType
+ from olive.cli.init.pytorch_flow import TEXT_FIELD, _finetune_flow
+
+ mock_ask.side_effect = [
+ "lora", # method
+ "64", # lora_r
+ "16", # lora_alpha
+ SourceType.HF, # data_source
+ "tatsu-lab/alpaca", # data_name
+ "train", # train_split
+ "test", # eval_split (provided)
+ TEXT_FIELD, # text_mode
+ "text", # text_field
+ "1024", # max_seq_len
+ "256", # max_samples
+ "bfloat16", # torch_dtype
+ ]
+ result = _finetune_flow({"model_path": "my-model"})
+ assert "--eval_split test" in result["command"]
+
+ @patch("olive.cli.init.pytorch_flow._ask")
+ def test_chat_template(self, mock_ask):
+ from olive.cli.init.helpers import SourceType
+ from olive.cli.init.pytorch_flow import TEXT_CHAT_TEMPLATE, _finetune_flow
+
+ mock_ask.side_effect = [
+ "lora", # method
+ "64", # lora_r
+ "16", # lora_alpha
+ SourceType.HF, # data_source
+ "dataset", # data_name
+ "train", # train_split
+ "", # eval_split (skip)
+ TEXT_CHAT_TEMPLATE, # text_mode
+ "1024", # max_seq_len
+ "256", # max_samples
+ "bfloat16", # torch_dtype
+ ]
+ result = _finetune_flow({"model_path": "my-model"})
+ assert "--use_chat_template" in result["command"]
+
+
+class TestOptimizeCustomMode:
+ @patch("olive.cli.init.pytorch_flow._ask")
+ def test_export_and_quantize(self, mock_ask):
+ from olive.cli.init.pytorch_flow import EXPORTER_DYNAMO, OP_EXPORT, OP_QUANTIZE, _optimize_custom_mode
+
+ mock_ask.side_effect = [
+ [OP_EXPORT, OP_QUANTIZE], # operations checkbox
+ EXPORTER_DYNAMO, # exporter
+ "float32", # torch_dtype
+ "rtn", # algorithm
+ "int4", # precision
+ "CUDAExecutionProvider", # provider
+ ]
+ result = _optimize_custom_mode({"model_path": "my-model"})
+ cmd = result["command"]
+ assert "olive optimize" in cmd
+ assert "--provider CUDAExecutionProvider" in cmd
+ assert "--exporter dynamo_exporter" in cmd
+
+ @patch("olive.cli.init.pytorch_flow._ask")
+ def test_export_only(self, mock_ask):
+ from olive.cli.init.pytorch_flow import EXPORTER_DYNAMO, OP_EXPORT, _optimize_custom_mode
+
+ mock_ask.side_effect = [
+ [OP_EXPORT], # operations checkbox
+ EXPORTER_DYNAMO, # exporter
+ "float16", # torch_dtype
+ ]
+ result = _optimize_custom_mode({"model_path": "my-model"})
+ assert "olive capture-onnx-graph" in result["command"]
+ assert "--torch_dtype float16" in result["command"]
+
+ @patch(
+ "olive.cli.init.pytorch_flow.build_calibration_args", return_value=" -d data --split train --max_samples 128"
+ )
+ @patch("olive.cli.init.pytorch_flow.prompt_calibration_source")
+ @patch("olive.cli.init.pytorch_flow._ask")
+ def test_quantize_only(self, mock_ask, mock_calib, mock_build):
+ from olive.cli.init.helpers import SourceType
+ from olive.cli.init.pytorch_flow import OP_QUANTIZE, _optimize_custom_mode
+
+ mock_calib.return_value = {
+ "source": SourceType.HF,
+ "data_name": "data",
+ "subset": "",
+ "split": "train",
+ "num_samples": "128",
+ }
+ mock_ask.side_effect = [
+ [OP_QUANTIZE], # operations checkbox
+ "gptq", # algorithm
+ "int4", # precision
+ ]
+ result = _optimize_custom_mode({"model_path": "my-model"})
+ assert "olive quantize" in result["command"]
+ assert "--algorithm gptq" in result["command"]
+
+ @patch("olive.cli.init.pytorch_flow._ask")
+ def test_graph_opt_only(self, mock_ask):
+ from olive.cli.init.pytorch_flow import OP_GRAPH_OPT, _optimize_custom_mode
+
+ mock_ask.side_effect = [
+ [OP_GRAPH_OPT], # operations checkbox
+ "CPUExecutionProvider", # provider
+ ]
+ result = _optimize_custom_mode({"model_path": "my-model"})
+ assert "olive optimize" in result["command"]
+ assert "--precision fp32" in result["command"]
+
+ @patch("olive.cli.init.pytorch_flow._ask")
+ def test_export_and_graph_opt(self, mock_ask):
+ from olive.cli.init.pytorch_flow import EXPORTER_MODEL_BUILDER, OP_EXPORT, OP_GRAPH_OPT, _optimize_custom_mode
+
+ mock_ask.side_effect = [
+ [OP_EXPORT, OP_GRAPH_OPT], # operations checkbox
+ EXPORTER_MODEL_BUILDER, # exporter
+ "fp16", # precision
+ "CPUExecutionProvider", # provider
+ ]
+ result = _optimize_custom_mode({"model_path": "my-model"})
+ cmd = result["command"]
+ assert "olive optimize" in cmd
+ assert "--precision fp32" in cmd
+ assert "--exporter model_builder" in cmd
+
+ @patch("olive.cli.init.pytorch_flow._ask")
+ def test_quantize_and_graph_opt(self, mock_ask):
+ from olive.cli.init.pytorch_flow import OP_GRAPH_OPT, OP_QUANTIZE, _optimize_custom_mode
+
+ mock_ask.side_effect = [
+ [OP_QUANTIZE, OP_GRAPH_OPT], # operations checkbox
+ "rtn", # algorithm
+ "int4", # precision
+ "CUDAExecutionProvider", # provider
+ ]
+ result = _optimize_custom_mode({"model_path": "my-model"})
+ cmd = result["command"]
+ assert "olive optimize" in cmd
+ assert "--precision int4" in cmd
+
+ @patch("olive.cli.init.pytorch_flow._ask")
+ def test_no_operations_selected(self, mock_ask):
+ from olive.cli.init.pytorch_flow import _optimize_custom_mode
+
+ mock_ask.return_value = [] # empty checkbox
+ result = _optimize_custom_mode({"model_path": "my-model"})
+ assert not result
+
+
+class TestOptimizeFlow:
+ @patch("olive.cli.init.pytorch_flow._optimize_auto_mode")
+ @patch("olive.cli.init.pytorch_flow._ask")
+ def test_routes_to_auto(self, mock_ask, mock_auto):
+ from olive.cli.init.pytorch_flow import MODE_AUTO, _optimize_flow
+
+ mock_ask.return_value = MODE_AUTO
+ mock_auto.return_value = {"command": "test"}
+ _optimize_flow({"model_path": "m"})
+ mock_auto.assert_called_once()
+
+ @patch("olive.cli.init.pytorch_flow._optimize_custom_mode")
+ @patch("olive.cli.init.pytorch_flow._ask")
+ def test_routes_to_custom(self, mock_ask, mock_custom):
+ from olive.cli.init.pytorch_flow import MODE_CUSTOM, _optimize_flow
+
+ mock_ask.return_value = MODE_CUSTOM
+ mock_custom.return_value = {"command": "test"}
+ _optimize_flow({"model_path": "m"})
+ mock_custom.assert_called_once()
+
+
+class TestPromptExportOptionsInt4:
+ @patch("olive.cli.init.pytorch_flow._ask")
+ def test_model_builder_int4_block_size(self, mock_ask):
+ from olive.cli.init.pytorch_flow import EXPORTER_MODEL_BUILDER, PRECISION_INT4, _prompt_export_options
+
+ mock_ask.side_effect = [
+ EXPORTER_MODEL_BUILDER, # exporter
+ PRECISION_INT4, # precision
+ "64", # block_size
+ ]
+ config = _prompt_export_options()
+ assert config == {"exporter": EXPORTER_MODEL_BUILDER, "precision": PRECISION_INT4, "int4_block_size": "64"}
+
+
+class TestRunPytorchFlowRouting:
+ @patch("olive.cli.init.pytorch_flow._optimize_flow")
+ @patch("olive.cli.init.pytorch_flow._ask_select")
+ def test_routes_to_optimize(self, mock_select, mock_flow):
+ from olive.cli.init.pytorch_flow import OP_OPTIMIZE, run_pytorch_flow
+
+ mock_select.return_value = OP_OPTIMIZE
+ mock_flow.return_value = {"command": "test"}
+ run_pytorch_flow({"model_path": "m"})
+ mock_flow.assert_called_once_with({"model_path": "m"})
+
+ @patch("olive.cli.init.pytorch_flow._export_flow")
+ @patch("olive.cli.init.pytorch_flow._ask_select")
+ def test_routes_to_export(self, mock_select, mock_flow):
+ from olive.cli.init.pytorch_flow import OP_EXPORT, run_pytorch_flow
+
+ mock_select.return_value = OP_EXPORT
+ mock_flow.return_value = {"command": "test"}
+ run_pytorch_flow({"model_path": "m"})
+ mock_flow.assert_called_once()
+
+ @patch("olive.cli.init.pytorch_flow._quantize_flow")
+ @patch("olive.cli.init.pytorch_flow._ask_select")
+ def test_routes_to_quantize(self, mock_select, mock_flow):
+ from olive.cli.init.pytorch_flow import OP_QUANTIZE, run_pytorch_flow
+
+ mock_select.return_value = OP_QUANTIZE
+ mock_flow.return_value = {"command": "test"}
+ run_pytorch_flow({"model_path": "m"})
+ mock_flow.assert_called_once()
+
+ @patch("olive.cli.init.pytorch_flow._finetune_flow")
+ @patch("olive.cli.init.pytorch_flow._ask_select")
+ def test_routes_to_finetune(self, mock_select, mock_flow):
+ from olive.cli.init.pytorch_flow import OP_FINETUNE, run_pytorch_flow
+
+ mock_select.return_value = OP_FINETUNE
+ mock_flow.return_value = {"command": "test"}
+ run_pytorch_flow({"model_path": "m"})
+ mock_flow.assert_called_once()
+
+ @patch("olive.cli.init.pytorch_flow._ask_select", return_value="unknown")
+ def test_unknown_operation_returns_empty(self, mock_select):
+ from olive.cli.init.pytorch_flow import run_pytorch_flow
+
+ result = run_pytorch_flow({"model_path": "m"})
+ assert not result
diff --git a/test/cli/init/test_wizard.py b/test/cli/init/test_wizard.py
new file mode 100644
index 0000000000..6afe329f9b
--- /dev/null
+++ b/test/cli/init/test_wizard.py
@@ -0,0 +1,384 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# -------------------------------------------------------------------------
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+
+class TestBuildCalibrationArgs:
+ def test_hf_source_with_subset(self):
+ from olive.cli.init.helpers import SourceType, build_calibration_args
+
+ calib = {
+ "source": SourceType.HF,
+ "data_name": "Salesforce/wikitext",
+ "subset": "wikitext-2-raw-v1",
+ "split": "train",
+ "num_samples": "128",
+ }
+ result = build_calibration_args(calib)
+ assert result == " -d Salesforce/wikitext --subset wikitext-2-raw-v1 --split train --max_samples 128"
+
+ def test_hf_source_without_subset(self):
+ from olive.cli.init.helpers import SourceType, build_calibration_args
+
+ calib = {
+ "source": SourceType.HF,
+ "data_name": "Salesforce/wikitext",
+ "subset": "",
+ "split": "train",
+ "num_samples": "64",
+ }
+ result = build_calibration_args(calib)
+ assert "--subset" not in result
+ assert result == " -d Salesforce/wikitext --split train --max_samples 64"
+
+ def test_local_source(self):
+ from olive.cli.init.helpers import SourceType, build_calibration_args
+
+ calib = {"source": SourceType.LOCAL, "data_files": "/data/calib.json"}
+ result = build_calibration_args(calib)
+ assert result == " --data_files /data/calib.json"
+
+ def test_unknown_source(self):
+ from olive.cli.init.helpers import build_calibration_args
+
+ result = build_calibration_args({"source": "unknown"})
+ assert result == ""
+
+
+class TestPromptCalibrationSource:
+ @patch("olive.cli.init.helpers._ask")
+ def test_default_returns_none(self, mock_ask):
+ from olive.cli.init.helpers import SourceType, prompt_calibration_source
+
+ mock_ask.return_value = SourceType.DEFAULT
+ result = prompt_calibration_source()
+ assert result is None
+
+ @patch("olive.cli.init.helpers._ask")
+ def test_hf_source(self, mock_ask):
+ from olive.cli.init.helpers import SourceType, prompt_calibration_source
+
+ mock_ask.side_effect = [SourceType.HF, "my_dataset", "my_subset", "validation", "64"]
+ result = prompt_calibration_source()
+ assert result == {
+ "source": SourceType.HF,
+ "data_name": "my_dataset",
+ "subset": "my_subset",
+ "split": "validation",
+ "num_samples": "64",
+ }
+
+ @patch("olive.cli.init.helpers._ask")
+ def test_local_source(self, mock_ask):
+ from olive.cli.init.helpers import SourceType, prompt_calibration_source
+
+ mock_ask.side_effect = [SourceType.LOCAL, "/data/calib.json"]
+ result = prompt_calibration_source()
+ assert result == {"source": SourceType.LOCAL, "data_files": "/data/calib.json"}
+
+
+class TestAskHelpers:
+ @patch("olive.cli.init.helpers.sys.exit")
+ def test_ask_exits_on_none(self, mock_exit):
+ from olive.cli.init.helpers import _ask
+
+ question = MagicMock()
+ question.ask.return_value = None
+ _ask(question)
+ mock_exit.assert_called_once_with(0)
+
+ def test_ask_returns_value(self):
+ from olive.cli.init.helpers import _ask
+
+ question = MagicMock()
+ question.ask.return_value = "hello"
+ assert _ask(question) == "hello"
+
+ @patch("olive.cli.init.helpers._ask")
+ def test_ask_select_raises_go_back(self, mock_ask):
+ from olive.cli.init.helpers import GoBackError, _ask_select
+
+ mock_ask.return_value = "__back__"
+ with pytest.raises(GoBackError):
+ _ask_select("Pick one:", choices=["a", "b"])
+
+ @patch("olive.cli.init.helpers._ask")
+ def test_ask_select_returns_value(self, mock_ask):
+ from olive.cli.init.helpers import _ask_select
+
+ mock_ask.return_value = "a"
+ result = _ask_select("Pick one:", choices=["a", "b"])
+ assert result == "a"
+
+ @patch("olive.cli.init.helpers._ask")
+ def test_ask_select_no_back(self, mock_ask):
+ from olive.cli.init.helpers import _ask_select
+
+ mock_ask.return_value = "a"
+ result = _ask_select("Pick:", choices=["a"], allow_back=False)
+ assert result == "a"
+
+
+class TestInitWizard:
+ """Test InitWizard end-to-end.
+
+ The wizard dispatches to onnx_flow, which binds _ask/_ask_select at module
+ import time, so we must patch both the wizard and onnx_flow references.
+ """
+
+ @patch("olive.cli.init.onnx_flow._ask_select")
+ @patch("olive.cli.init.wizard.subprocess.run")
+ @patch("olive.cli.init.wizard._ask")
+ @patch("olive.cli.init.wizard._ask_select")
+ def test_full_flow_generate_command(self, mock_select, mock_ask, mock_subprocess, mock_onnx_select):
+ from olive.cli.init.onnx_flow import OnnxOperation
+ from olive.cli.init.wizard import InitWizard, ModelType, OutputAction
+
+ mock_onnx_select.return_value = OnnxOperation.CONVERT_PRECISION
+ mock_select.side_effect = [ModelType.ONNX, OutputAction.COMMAND]
+ mock_ask.side_effect = ["/model.onnx", "./output", False]
+
+ InitWizard().start()
+ mock_subprocess.assert_not_called()
+
+ @patch("olive.cli.init.onnx_flow._ask_select")
+ @patch("olive.cli.init.wizard.subprocess.run")
+ @patch("olive.cli.init.wizard._ask")
+ @patch("olive.cli.init.wizard._ask_select")
+ def test_full_flow_run_now(self, mock_select, mock_ask, mock_subprocess, mock_onnx_select):
+ from olive.cli.init.onnx_flow import OnnxOperation
+ from olive.cli.init.wizard import InitWizard, ModelType, OutputAction
+
+ mock_onnx_select.return_value = OnnxOperation.CONVERT_PRECISION
+ mock_select.side_effect = [ModelType.ONNX, OutputAction.RUN]
+ mock_ask.side_effect = ["/model.onnx", "./output"]
+
+ InitWizard().start()
+ mock_subprocess.assert_called_once()
+
+ @patch("olive.cli.init.onnx_flow._ask_select")
+ @patch("olive.cli.init.wizard.subprocess.run")
+ @patch("olive.cli.init.wizard._ask")
+ @patch("olive.cli.init.wizard._ask_select")
+ def test_full_flow_generate_config(self, mock_select, mock_ask, mock_subprocess, mock_onnx_select):
+ from olive.cli.init.onnx_flow import OnnxOperation
+ from olive.cli.init.wizard import InitWizard, ModelType, OutputAction
+
+ mock_onnx_select.return_value = OnnxOperation.CONVERT_PRECISION
+ mock_select.side_effect = [ModelType.ONNX, OutputAction.CONFIG]
+ mock_ask.side_effect = ["/model.onnx", "./output"]
+
+ InitWizard().start()
+ mock_subprocess.assert_called_once()
+ cmd = mock_subprocess.call_args[0][0]
+ assert "--save_config_file" in cmd
+ assert "--dry_run" in cmd
+
+ @patch("olive.cli.init.onnx_flow._ask_select")
+ @patch("olive.cli.init.wizard._ask")
+ @patch("olive.cli.init.wizard._ask_select")
+ def test_go_back(self, mock_select, mock_ask, mock_onnx_select):
+ from olive.cli.init.helpers import GoBackError
+ from olive.cli.init.onnx_flow import OnnxOperation
+ from olive.cli.init.wizard import InitWizard, ModelType, OutputAction
+
+ mock_onnx_select.return_value = OnnxOperation.CONVERT_PRECISION
+ # An exception class inside a side_effect iterable is raised by the mock,
+ # so GoBackError simulates the user backing out of the second step.
+ mock_select.side_effect = [ModelType.ONNX, GoBackError, ModelType.ONNX, OutputAction.COMMAND]
+ mock_ask.side_effect = ["/model.onnx", "./output", False]
+
+ InitWizard().start()
+
+ @patch("olive.cli.init.wizard.sys.exit")
+ @patch("olive.cli.init.wizard._ask_select")
+ def test_keyboard_interrupt(self, mock_select, mock_exit):
+ from olive.cli.init.wizard import InitWizard
+
+ mock_select.side_effect = KeyboardInterrupt
+ InitWizard().start()
+ mock_exit.assert_called_once_with(0)
+
+ @patch("olive.cli.init.onnx_flow._ask_select")
+ @patch("olive.cli.init.wizard.subprocess.run")
+ @patch("olive.cli.init.wizard._ask")
+ @patch("olive.cli.init.wizard._ask_select")
+ def test_command_then_run_now(self, mock_select, mock_ask, mock_subprocess, mock_onnx_select):
+ from olive.cli.init.onnx_flow import OnnxOperation
+ from olive.cli.init.wizard import InitWizard, ModelType, OutputAction
+
+ mock_onnx_select.return_value = OnnxOperation.CONVERT_PRECISION
+ mock_select.side_effect = [ModelType.ONNX, OutputAction.COMMAND]
+ mock_ask.side_effect = ["/model.onnx", "./output", True]
+
+ InitWizard().start()
+ mock_subprocess.assert_called_once()
+
+ @patch("olive.cli.init.wizard.Path")
+ @patch("olive.cli.init.onnx_flow._ask_select")
+ @patch("olive.cli.init.wizard.subprocess.run")
+ @patch("olive.cli.init.wizard._ask")
+ @patch("olive.cli.init.wizard._ask_select")
+ def test_config_with_existing_file(self, mock_select, mock_ask, mock_subprocess, mock_onnx_select, mock_path):
+ from olive.cli.init.onnx_flow import OnnxOperation
+ from olive.cli.init.wizard import InitWizard, ModelType, OutputAction
+
+ mock_onnx_select.return_value = OnnxOperation.CONVERT_PRECISION
+ mock_select.side_effect = [ModelType.ONNX, OutputAction.CONFIG]
+ mock_ask.side_effect = ["/model.onnx", "./output"]
+ # Make any (output_dir / name) path report that the file already exists.
+ mock_path.return_value.__truediv__ = lambda self, x: MagicMock(exists=lambda: True)
+
+ InitWizard().start()
+ mock_subprocess.assert_called_once()
+
+ def test_no_command_raises_go_back(self):
+ """When _run_model_flow returns {} the wizard should go back, not silently finish."""
+ from olive.cli.init.helpers import GoBackError
+ from olive.cli.init.wizard import InitWizard
+
+ wizard = InitWizard()
+ result = {"not_command": True} # no "command" key
+ with pytest.raises(GoBackError):
+ wizard._prompt_output(result) # pylint: disable=protected-access
+
+ @patch("olive.cli.init.pytorch_flow._ask_select")
+ @patch("olive.cli.init.pytorch_flow._optimize_flow", return_value={"command": "olive optimize -m m"})
+ @patch("olive.cli.init.wizard.subprocess.run")
+ @patch("olive.cli.init.wizard._ask")
+ @patch("olive.cli.init.wizard._ask_select")
+ def test_pytorch_flow_dispatch(self, mock_select, mock_ask, mock_subprocess, mock_opt, mock_pt_select):
+ from olive.cli.init.helpers import SourceType
+ from olive.cli.init.pytorch_flow import OP_OPTIMIZE
+ from olive.cli.init.wizard import InitWizard, ModelType, OutputAction
+
+ mock_pt_select.return_value = OP_OPTIMIZE
+ mock_select.side_effect = [ModelType.PYTORCH, SourceType.HF, OutputAction.RUN]
+ mock_ask.side_effect = ["meta-llama/Llama-3.1-8B", "./output"]
+
+ InitWizard().start()
+ mock_subprocess.assert_called_once()
+
+ @patch("olive.cli.init.diffusers_flow._ask_select")
+ @patch("olive.cli.init.diffusers_flow._ask")
+ @patch("olive.cli.init.wizard.subprocess.run")
+ @patch("olive.cli.init.wizard._ask")
+ @patch("olive.cli.init.wizard._ask_select")
+ def test_diffusers_flow_dispatch(self, mock_select, mock_ask, mock_subprocess, mock_diff_ask, mock_diff_select):
+ from olive.cli.init.diffusers_flow import OP_EXPORT
+ from olive.cli.init.helpers import DiffuserVariant
+ from olive.cli.init.wizard import InitWizard, ModelType, OutputAction
+
+ mock_diff_select.return_value = OP_EXPORT
+ mock_select.side_effect = [ModelType.DIFFUSERS, DiffuserVariant.AUTO, OutputAction.RUN]
+ mock_ask.side_effect = ["my-model", "./output"]
+ mock_diff_ask.return_value = "float16"
+
+ InitWizard().start()
+ mock_subprocess.assert_called_once()
+
+
+class TestPromptPytorchSource:
+ @patch("olive.cli.init.wizard._ask")
+ @patch("olive.cli.init.wizard._ask_select")
+ def test_hf_source(self, mock_select, mock_ask):
+ from olive.cli.init.helpers import SourceType
+ from olive.cli.init.wizard import InitWizard
+
+ mock_select.return_value = SourceType.HF
+ mock_ask.return_value = "meta-llama/Llama-3.1-8B"
+ result = InitWizard()._prompt_pytorch_source() # pylint: disable=protected-access
+ assert result == {"source_type": SourceType.HF, "model_path": "meta-llama/Llama-3.1-8B"}
+
+ @patch("olive.cli.init.wizard._ask")
+ @patch("olive.cli.init.wizard._ask_select")
+ def test_local_source(self, mock_select, mock_ask):
+ from olive.cli.init.helpers import SourceType
+ from olive.cli.init.wizard import InitWizard
+
+ mock_select.return_value = SourceType.LOCAL
+ mock_ask.return_value = "./my-model/"
+ result = InitWizard()._prompt_pytorch_source() # pylint: disable=protected-access
+ assert result == {"source_type": SourceType.LOCAL, "model_path": "./my-model/"}
+
+ @patch("olive.cli.init.wizard._ask")
+ @patch("olive.cli.init.wizard._ask_select")
+ def test_azureml_source(self, mock_select, mock_ask):
+ from olive.cli.init.helpers import SourceType
+ from olive.cli.init.wizard import InitWizard
+
+ mock_select.return_value = SourceType.AZUREML
+ mock_ask.return_value = "azureml://registries/r/models/m/versions/1"
+ result = InitWizard()._prompt_pytorch_source() # pylint: disable=protected-access
+ assert result == {"source_type": SourceType.AZUREML, "model_path": "azureml://registries/r/models/m/versions/1"}
+
+ @patch("olive.cli.init.wizard._ask")
+ @patch("olive.cli.init.wizard._ask_select")
+ def test_script_source_full(self, mock_select, mock_ask):
+ from olive.cli.init.helpers import SourceType
+ from olive.cli.init.wizard import InitWizard
+
+ mock_select.return_value = SourceType.SCRIPT
+ mock_ask.side_effect = ["train.py", "./src", "my-model"]
+ result = InitWizard()._prompt_pytorch_source() # pylint: disable=protected-access
+ assert result == {
+ "source_type": SourceType.SCRIPT,
+ "model_script": "train.py",
+ "script_dir": "./src",
+ "model_path": "my-model",
+ }
+
+ @patch("olive.cli.init.wizard._ask")
+ @patch("olive.cli.init.wizard._ask_select")
+ def test_script_source_minimal(self, mock_select, mock_ask):
+ from olive.cli.init.helpers import SourceType
+ from olive.cli.init.wizard import InitWizard
+
+ mock_select.return_value = SourceType.SCRIPT
+ mock_ask.side_effect = ["train.py", "", ""] # no script_dir, no model_path
+ result = InitWizard()._prompt_pytorch_source() # pylint: disable=protected-access
+ assert result == {"source_type": SourceType.SCRIPT, "model_script": "train.py"}
+ assert "script_dir" not in result
+ assert "model_path" not in result
+
+
+class TestPromptDiffusersSource:
+ @patch("olive.cli.init.wizard._ask")
+ @patch("olive.cli.init.wizard._ask_select")
+ def test_diffusers_source(self, mock_select, mock_ask):
+ from olive.cli.init.helpers import SourceType
+ from olive.cli.init.wizard import InitWizard
+
+ mock_select.return_value = "sdxl"
+ mock_ask.return_value = "stabilityai/sdxl-base-1.0"
+ result = InitWizard()._prompt_diffusers_source() # pylint: disable=protected-access
+ assert result == {
+ "source_type": SourceType.HF,
+ "model_path": "stabilityai/sdxl-base-1.0",
+ "variant": "sdxl",
+ }
+
+
+class TestPromptModelSource:
+ def test_unknown_model_type_returns_empty(self):
+ from olive.cli.init.wizard import InitWizard
+
+ result = InitWizard()._prompt_model_source("unknown") # pylint: disable=protected-access
+ assert not result
+
+
+class TestRunModelFlow:
+ def test_unknown_model_type_returns_empty(self):
+ from olive.cli.init.wizard import InitWizard
+
+ result = InitWizard()._run_model_flow("unknown", {}) # pylint: disable=protected-access
+ assert not result