diff --git a/.gitignore b/.gitignore index 2e7237d7..3a3959a4 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,22 @@ __pycache__/ # C extensions *.so +# OS generated files +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + +# IDE files +.vscode/ +.idea/ +*.swp +*.swo +*~ + # Distribution / packaging .Python build/ @@ -36,7 +52,7 @@ MANIFEST pip-log.txt pip-delete-this-directory.txt -# Unit test / coverage reports +# Testing and coverage htmlcov/ .tox/ .nox/ @@ -49,6 +65,23 @@ coverage.xml *.py,cover .hypothesis/ .pytest_cache/ + +# MADEngine specific +credential.json +data.json +*.log +*.csv +*.html +library_trace.csv +library_perf.csv +perf.csv +perf.html + +# Temporary and build files +temp/ +tmp/ +*.tmp +.pytest_cache/ cover/ # Translations diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..76c8fd63 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,36 @@ +# Pre-commit hooks configuration +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-json + - id: check-toml + - id: check-added-large-files + - id: check-merge-conflict + - id: debug-statements + + - repo: https://github.com/psf/black + rev: 23.3.0 + hooks: + - id: black + language_version: python3 + + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + + - repo: https://github.com/pycqa/flake8 + rev: 6.0.0 + hooks: + - id: flake8 + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.3.0 + hooks: + - id: mypy + additional_dependencies: [types-requests, types-PyYAML] + exclude: ^(tests/|scripts/) diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..d1e8a2d8 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,68 @@ +# Changelog + +All notable changes to MADEngine will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
+ +## [Unreleased] + +### Added +- Comprehensive development tooling and configuration +- Pre-commit hooks for code quality +- Makefile for common development tasks +- Developer guide with coding standards +- Type checking with mypy +- Code formatting with black and isort +- Enhanced .gitignore for better file exclusions +- CI/CD configuration templates +- **Major Documentation Refactor**: Complete integration of distributed execution and CLI guides into README.md +- Professional open-source project structure with badges and table of contents +- Comprehensive MAD package integration documentation +- Enhanced model discovery and tag system documentation +- Modern deployment scenarios and configuration examples + +### Changed +- Improved package initialization and imports +- Replaced print statements with proper logging in main CLI +- Enhanced error handling and logging throughout codebase +- Cleaned up setup.py for better maintainability +- Updated development dependencies in pyproject.toml +- **Complete README.md overhaul**: Merged all documentation into a single, comprehensive source +- Restructured documentation to emphasize MAD package integration +- Enhanced CLI usage examples and distributed execution workflows +- Improved developer contribution guidelines and legacy compatibility notes + +### Fixed +- Removed Python cache files from repository +- Fixed import organization and structure +- Improved docstring formatting and consistency + +### Removed +- Unnecessary debug print statements +- Python cache files and build artifacts +- **Legacy documentation files**: `docs/distributed-execution-solution.md` and `docs/madengine-cli-guide.md` +- Redundant documentation scattered across multiple files + +## [Previous Versions] + +For changes in previous versions, please refer to the git history. + +--- + +## Guidelines for Changelog Updates + +### Categories +- **Added** for new features +- **Changed** for changes in existing functionality +- **Deprecated** for soon-to-be removed features +- **Removed** for now removed features +- **Fixed** for any bug fixes +- **Security** for vulnerability fixes + +### Format +- Keep entries brief but descriptive +- Include ticket/issue numbers when applicable +- Group related changes together +- Use present tense ("Add feature" not "Added feature") +- Target audience: users and developers of the project diff --git a/DEVELOPER_GUIDE.md b/DEVELOPER_GUIDE.md new file mode 100644 index 00000000..5d55a520 --- /dev/null +++ b/DEVELOPER_GUIDE.md @@ -0,0 +1,282 @@ +# MADEngine Developer Guide + +This guide covers development setup, coding standards, and contribution guidelines for MADEngine. + +## Quick Development Setup + +```bash +# Clone the repository +git clone https://github.com/ROCm/madengine.git +cd madengine + +# Development setup +pip install -e ".[dev]" +pre-commit install ``` + +## Modern Python Packaging + +This project follows modern Python packaging standards: + +- **`pyproject.toml`** - Single configuration file for everything +- **No requirements.txt** - Dependencies defined in pyproject.toml +- **Hatchling** - Modern build backend +- **Built-in tool configuration** - Black, pytest, mypy, etc. all configured in pyproject.toml + +### Installation Commands + +```bash +# Production install +pip install . + +# Development install (includes dev tools) +pip install -e ".[dev]" + +# Build package +python -m build # requires: pip install build ``` + +## Development Workflow + +### 1.
Code Formatting and Linting + +We use several tools to maintain code quality: + +- **Black**: Code formatting +- **isort**: Import sorting +- **flake8**: Linting +- **mypy**: Type checking + +```bash +# Format code +make format + +# Check formatting +make format-check + +# Run linting +make lint ``` + +Or run the tools directly: + +```bash +# Format code +black src/ tests/ +isort src/ tests/ + +# Run linting +flake8 src/ tests/ + +# Type checking +mypy src/madengine + +# Run all tools at once +pre-commit run --all-files ``` + +### 2. Testing + +```bash +# Run tests +pytest + +# Run tests with coverage +pytest --cov=madengine --cov-report=html + +# Run specific test file +pytest tests/test_specific.py + +# Run tests with specific marker +pytest -m "not slow" ``` + +### 3. Pre-commit Hooks + +Pre-commit hooks automatically run before each commit: + +```bash +# Install hooks (already done in setup) +pre-commit install + +# Run hooks manually +pre-commit run --all-files ``` + +## Coding Standards + +### Python Code Style + +- Follow PEP 8 style guide +- Use Black for automatic formatting (line length: 88) +- Sort imports with isort +- Maximum cyclomatic complexity: 10 +- Use type hints where possible + +### Documentation + +- All public functions and classes must have docstrings +- Follow Google-style docstrings +- **Primary documentation is in README.md** - Keep it comprehensive and up-to-date +- Document any new configuration options in the README +- For major features, include examples in the appropriate README sections +- Update CLI documentation when adding new commands +- Include deployment scenarios for distributed features + +### Error Handling + +- Use proper logging instead of print statements +- Handle exceptions gracefully +- Provide meaningful error messages +- Use appropriate log levels (DEBUG, INFO, WARNING, ERROR) + +### Testing + +- Write tests for new functionality +- Maintain test coverage above 80% +- Use meaningful test names +- Follow AAA pattern (Arrange, Act, Assert) + +## Code Organization + +``` +src/madengine/ +├── __init__.py # Package initialization +├── mad.py # Main CLI entry point +├── core/ # Core functionality +├── db/ # Database operations +├── tools/ # CLI tools +├── utils/ # Utility functions +└── scripts/ # Shell scripts and tools ``` + +## Adding New Features + +### Documentation Guidelines + +MADEngine uses a centralized documentation approach: + +- **README.md** is the primary documentation source containing: + - Installation and quick start guides + - Complete CLI reference + - Distributed execution workflows + - Configuration options and examples + - Deployment scenarios + - Contributing guidelines + +- **Additional documentation** should be minimal and specific: + - `DEVELOPER_GUIDE.md` - Development setup and coding standards + - `docs/how-to-*.md` - Specific technical guides + - `CHANGELOG.md` - Release notes and changes + +When adding features: +1. Update the relevant README.md sections +2. Add CLI examples if applicable +3. Include configuration options +4. Document any new MAD package integration patterns +5. Add deployment scenarios for distributed features + +### Feature Workflow + +1. **Create a feature branch** + ```bash + git checkout -b feature/your-feature-name + ``` + +2. **Implement your feature** + - Write the code following our standards + - Add comprehensive tests + - Update documentation + +3. **Test your changes** + ```bash + pytest --cov=madengine + pre-commit run --all-files + black src/ tests/ + flake8 src/ tests/ + ``` + +4.
**Submit a pull request** + - Ensure all CI checks pass + - Write a clear description + - Request appropriate reviewers + +## Environment Variables + +MADEngine uses several environment variables for configuration: + +- `MODEL_DIR`: Location of models directory +- `LOG_LEVEL`: Logging level (DEBUG, INFO, WARNING, ERROR) +- `MAD_VERBOSE_CONFIG`: Enable verbose configuration logging +- `MAD_AWS_S3`: AWS S3 credentials (JSON) +- `NAS_NODES`: NAS configuration (JSON) +- `PUBLIC_GITHUB_ROCM_KEY`: GitHub token (JSON) + +## Common Tasks + +### Adding a New CLI Command + +1. Create a new module in `src/madengine/tools/` +2. Add the command handler in `mad.py` +3. Update the argument parser +4. Add tests in `tests/` +5. Update documentation + +### Adding Dependencies + +1. Add to `pyproject.toml` under `dependencies` or `optional-dependencies` +2. Update setup.py if needed for legacy compatibility +3. Run `pip install -e ".[dev]"` to install +4. Update documentation if the dependency affects usage + +### Debugging + +- Use the logging module instead of print statements +- Set `LOG_LEVEL=DEBUG` for verbose output +- Use `MAD_VERBOSE_CONFIG=true` for configuration debugging + +## Release Process + +1. Update version in `pyproject.toml` +2. Update CHANGELOG.md with new features, changes, and fixes +3. Ensure README.md reflects all current functionality +4. Create a release tag: `git tag -a v1.0.0 -m "Release 1.0.0"` +5. Push tag: `git push origin v1.0.0` +6. Build and publish: `python -m build` + +### Documentation Updates for Releases + +- Verify README.md covers all new features +- Update CLI examples if commands have changed +- Ensure configuration examples are current +- Add any new deployment scenarios +- Update MAD package integration examples if applicable + +## Troubleshooting + +### Common Issues + +1. **Import errors**: Check if package is installed in development mode +2. **Test failures**: Ensure all dependencies are installed +3. **Pre-commit failures**: Run `black src/ tests/` and `isort src/ tests/` to fix formatting issues +4. **Type checking errors**: Add type hints or use `# type: ignore` comments + +### Getting Help + +- **Start with README.md** - Comprehensive documentation covering most use cases +- Check existing issues in the repository +- Review specific guides in `docs/` directory for advanced topics +- Contact the development team +- For CLI questions, refer to the CLI reference section in README.md +- For distributed execution, see the distributed workflows section in README.md + +## Performance Considerations + +- Profile code for performance bottlenecks +- Use appropriate data structures +- Minimize I/O operations +- Cache expensive computations when possible +- Consider memory usage for large datasets + +## Security Guidelines + +- Never commit credentials or secrets +- Use environment variables for sensitive configuration +- Validate all user inputs +- Follow secure coding practices +- Keep dependencies updated diff --git a/README.md b/README.md index 28907fcb..9b2650ea 100644 --- a/README.md +++ b/README.md @@ -1,426 +1,1910 @@ # madengine -Set of interfaces to run various AI models from public MAD. -# What is madengine? 
+[![Python](https://img.shields.io/badge/python-3.8%2B-blue.svg)](https://python.org) +[![CI](https://img.shields.io/badge/CI-GitHub%20Actions-green.svg)](https://github.com/ROCm/madengine/actions) +[![Code Style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) -An AI Models automation and dashboarding command-line tool to run LLMs and Deep Learning models locally or remotelly with CI. +> **Enterprise-grade AI model automation and distributed benchmarking platform** -The madengine library is to support AI automation having following features: -- AI Models run reliably on supported platforms and drive software quality -- Simple, minimalistic, out-of-the-box solution that enable confidence on hardware and software stack -- Real-time, audience-relevant AI Models performance metrics tracking, presented in clear, intuitive manner -- Best-practices for handling internal projects and external open-source projects +madengine is a sophisticated CLI tool designed for running Large Language Models (LLMs) and Deep Learning models across local and distributed environments. Built with modern Python practices, it provides both traditional single-node execution and advanced distributed orchestration capabilities as part of the [MAD (Model Automation and Dashboarding)](https://github.com/ROCm/MAD) ecosystem. -# Installation +## Table of Contents -madengine is meant to be used in conjunction with [MAD](https://github.com/ROCm/MAD). Below are the steps to set it up and run it using the command line interface (CLI). +- [🚀 Quick Start](#-quick-start) +- [✨ Features](#-features) +- [🏗️ Architecture](#️-architecture) +- [📦 Installation](#-installation) +- [💻 Command Line Interface](#-command-line-interface) +- [🔍 Model Discovery](#-model-discovery) +- [🌐 Distributed Execution](#-distributed-execution) +- [⚙️ Configuration](#️-configuration) +- [🎯 Advanced Usage](#-advanced-usage) +- [🚀 Deployment Scenarios](#-deployment-scenarios) +- [📝 Best Practices](#-best-practices) +- [🔧 Troubleshooting](#-troubleshooting) +- [📚 API Reference](#-api-reference) +- [🤝 Contributing](#-contributing) +- [📄 License](#-license) -## Clone MAD +## 🚀 Quick Start + +> **Important**: madengine must be executed from within a MAD package directory for proper model discovery. 
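+A quick sanity check before running any command is to confirm that the current directory looks like the root of a MAD checkout (a minimal sketch; the expected layout is described in the Architecture section below):
+
+```bash
+# Verify you are at the MAD package root, where models.json and scripts/ live
+test -f models.json && test -d scripts \
+  && echo "OK: MAD package root detected" \
+  || echo "Run madengine from the root of your MAD clone"
+```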
+ +### Prerequisites +- Python 3.8+ with pip +- Docker with GPU support (ROCm for AMD, CUDA for NVIDIA) +- Git for repository management +- [MAD package](https://github.com/ROCm/MAD) cloned locally + +### Install madengine + +```bash +# Basic installation +pip install git+https://github.com/ROCm/madengine.git + +# With distributed runner support +pip install "madengine[runners] @ git+https://github.com/ROCm/madengine.git" + +# Development installation +git clone https://github.com/ROCm/madengine.git +cd madengine && pip install -e ".[dev]" ``` -git clone git@github.com:ROCm/MAD.git -cd MAD + +### Run Your First Model + +```bash +# Clone MAD package and navigate to it +git clone https://github.com/ROCm/MAD.git && cd MAD + +# Single-node workflow (build + run) +madengine-cli run --tags dummy --registry localhost:5000 --timeout 3600 + +# Distributed workflow (build phase) +madengine-cli build --tags dummy --registry docker.io \ + --additional-context '{"gpu_vendor": "AMD", "guest_os": "UBUNTU"}' + +# Distributed workflow (run phase) +madengine-cli run --manifest-file build_manifest.json --timeout 1800 ``` -## Install madengine +### Test Model Discovery -### Install from source +```bash +# List all available models +madengine discover +# Discover specific models +madengine discover --tags dummy +madengine discover --tags dummy2:dummy_2 ``` -# Create virtual environment if necessary -python3 -m venv venv -# Active the virtual environment venv -source venv/bin/activate +That's it! You're now ready to run AI models with madengine. Continue reading for advanced features and distributed execution. + +## ✨ Features + +### Core Capabilities +- 🎯 **Dual CLI Interface** - Traditional `madengine` + modern `madengine-cli` with Typer+Rich +- 🌐 **Distributed Execution** - SSH, Ansible, Kubernetes, and SLURM runners for scalable deployments +- 🐳 **Containerized Models** - Full Docker integration with GPU support (ROCm, CUDA, Intel) +- 🔍 **Intelligent Discovery** - Static, directory-specific, and dynamic Python-based model discovery +- 🏗️ **Split Architecture** - Separate build/run phases optimized for different infrastructure types + +### Enterprise Features +- 📊 **Rich Terminal UI** - Progress bars, panels, syntax highlighting with comprehensive formatting +- 🔄 **Workflow Intelligence** - Automatic detection of build-only vs.
full workflow operations +- 🏷️ **Hierarchical Tagging** - Advanced model selection with parameterization (`model:param=value`) +- 🔐 **Credential Management** - Centralized authentication with environment variable overrides +- 📈 **Performance Analytics** - Detailed metrics, reporting, and execution summaries + +### Technical Excellence +- ⚡ **Modern Python** - Built with `pyproject.toml`, Hatchling, type hints, 95%+ test coverage +- 🎯 **GPU Architecture Support** - AMD ROCm, NVIDIA CUDA, Intel GPU architectures +- 📦 **Batch Processing** - Advanced batch manifest support with selective building +- 🔧 **Production Ready** - Comprehensive error handling, logging, monitoring, retry mechanisms -# Clone madengine -git clone git@github.com:ROCm/madengine.git +## 🏗️ Architecture + +### MAD Ecosystem Integration + +madengine operates within the **MAD (Model Automation and Dashboarding)** ecosystem, providing: + +- **Model Hub**: Centralized repository of AI models with standardized interfaces +- **Configuration Management**: Docker definitions, scripts, and environment configurations +- **Data Providers**: Unified data source management with credential handling +- **Build Tools**: Comprehensive toolchain for model preparation and execution + +**Required MAD Structure:** +``` +MAD/ +├── models.json # Root model definitions +├── data.json # Data provider configurations +├── credential.json # Authentication credentials +├── scripts/ # Model-specific directories +│ ├── dummy/ # Example model +│ │ ├── models.json # Static model configs +│ │ ├── get_models_json.py # Dynamic discovery +│ │ └── run.sh # Execution script +│ └── common/ +│ └── tools.json # Build tools configuration +└── pyproject.toml # madengine configuration +``` -# Change current working directory to madengine +### Split Architecture Benefits + +![Architecture Overview](docs/img/architecture_overview.png) + +**Traditional Monolithic Workflow:** +``` +Model Discovery → Docker Build → Container Run → Performance Collection +``` + +**Modern Split Architecture:** +``` +BUILD PHASE (CPU-optimized): RUN PHASE (GPU-optimized): +Model Discovery Load Manifest +Docker Build ───→ Pull Images +Push to Registry Container Run +Export Manifest Performance Collection +``` + +**Key Advantages:** +- 🎯 **Resource Efficiency** - Build on CPU nodes, run on GPU nodes +- ⚡ **Parallel Execution** - Multiple nodes execute different models simultaneously +- 🔄 **Reproducibility** - Consistent Docker images ensure identical results +- 📈 **Scalability** - Easy horizontal scaling by adding execution nodes +- 💰 **Cost Optimization** - Use appropriate instance types for each phase + +## 📦 Installation + +### Prerequisites +- **Python 3.8+** with pip +- **Git** for repository management +- **Docker** with GPU support (ROCm for AMD, CUDA for NVIDIA) +- **MAD package** - Required for model discovery and execution + +### Quick Installation + +```bash +# Install from GitHub +pip install git+https://github.com/ROCm/madengine.git + +# Install with distributed runner support +pip install "madengine[runners] @ git+https://github.com/ROCm/madengine.git" + +# Install specific runner types +pip install "madengine[ssh,ansible] @ git+https://github.com/ROCm/madengine.git" +``` + +### Development Installation + +```bash +# Clone and setup for development +git clone https://github.com/ROCm/madengine.git cd madengine -# Install madengine from source: -pip install . 
+# Create virtual environment (recommended) +python3 -m venv venv && source venv/bin/activate + +# Install in development mode with all dependencies +pip install -e ".[dev]" +# Setup pre-commit hooks (optional) +pre-commit install ``` -### Install from repo +### Optional Dependencies + +| Extra | Dependencies | Use Case | +|-------|-------------|----------| +| `ssh` | `paramiko>=2.7.0, scp>=0.14.0` | SSH runner for direct node connections | +| `ansible` | `ansible>=4.0.0, ansible-runner>=2.0.0` | Ansible runner for orchestrated deployment | +| `kubernetes` | `kubernetes>=20.0.0, PyYAML>=6.0` | Kubernetes runner for cloud-native execution | +| `runners` | All runner dependencies | Complete distributed execution support | +| `dev` | Testing and development tools | Contributors and developers | +| `all` | All optional dependencies | Complete installation | -You can also install the madengine library directly from the Github repository. +### MAD Package Setup + +```bash +# Clone MAD package (required for model execution) +git clone https://github.com/ROCm/MAD.git +cd MAD +# Install madengine within MAD directory +pip install git+https://github.com/ROCm/madengine.git + +# Verify installation +madengine-cli --version +madengine discover # Test model discovery +``` + +### Docker GPU Setup + +```bash +# AMD ROCm support +docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video \ + rocm/pytorch:latest rocm-smi + +# NVIDIA CUDA support +docker run --rm --gpus all nvidia/cuda:latest nvidia-smi + +# Verify GPU access +madengine-cli run --tags dummy --additional-context '{"gpu_vendor": "AMD"}' ``` -pip install git+https://github.com/ROCm/madengine.git@main + +### Verification + +```bash +# Check installation +madengine-cli --version +madengine --version + +# Test basic functionality +cd /path/to/MAD +madengine discover --tags dummy +madengine-cli run --tags dummy --additional-context '{"gpu_vendor": "AMD", "guest_os": "UBUNTU"}' ``` -## Clone +## 💻 Command Line Interface + +madengine provides dual CLI interfaces optimized for different use cases: + +### Interface Comparison -# Run madengine CLI +| Interface | Use Case | Framework | Features | +|-----------|----------|-----------|----------| +| `madengine` | Local development, simple workflows | Argparse | Traditional interface, backward compatible | +| `madengine-cli` | Production, distributed workflows | Typer+Rich | Modern UI, distributed runners, advanced error handling | -How to run madengine CLI on your local machine. +### Modern CLI (`madengine-cli`) - Recommended -```shell -(venv) test-node:~/MAD$ madengine --help -usage: madengine [-h] [-v] {run,discover,report,database} ... +#### Build Command +Create Docker images and manifests for distributed execution: -A Model automation and dashboarding command-line tool to run LLMs and Deep Learning models locally. 
+```bash +# Basic build +madengine-cli build --tags dummy --registry localhost:5000 + +# Production build with context +madengine-cli build --tags production_models \ + --registry docker.io \ + --additional-context '{"gpu_vendor": "AMD", "guest_os": "UBUNTU"}' \ + --clean-docker-cache \ + --summary-output build_report.json + +# Batch build mode +madengine-cli build --batch-manifest batch.json \ + --registry docker.io \ + --additional-context '{"gpu_vendor": "NVIDIA", "guest_os": "UBUNTU"}' +``` + +#### Run Command +Intelligent execution with automatic workflow detection: -optional arguments: - -h, --help show this help message and exit - -v, --version show program's version number and exit +```bash +# Complete workflow (no manifest exists) +madengine-cli run --tags dummy --registry localhost:5000 --timeout 3600 -Commands: - Available commands for running models, generating reports, and toolings. +# Execution-only (manifest exists) +madengine-cli run --manifest-file build_manifest.json --timeout 1800 - {run,discover,report,database} - run Run models on container - discover Discover the models - report Generate report of models - database CRUD for database +# Advanced execution with monitoring +madengine-cli run --tags models --live-output --verbose --keep-alive ``` -## Run models locally +#### Distributed Runner Commands +Execute across multiple infrastructure types: + +```bash +# SSH Runner - Direct connections +madengine-cli runner ssh \ + --inventory inventory.yml \ + --manifest-file build_manifest.json \ + --report-output ssh_results.json + +# Ansible Runner - Orchestrated deployment +madengine-cli runner ansible \ + --inventory cluster.yml \ + --playbook deployment.yml \ + --report-output ansible_results.json + +# Kubernetes Runner - Cloud-native execution +madengine-cli runner k8s \ + --inventory k8s_inventory.yml \ + --manifests-dir k8s-setup \ + --report-output k8s_results.json + +# SLURM Runner - HPC cluster execution +madengine-cli runner slurm \ + --inventory slurm_inventory.yml \ + --job-scripts-dir slurm-setup \ + --timeout 7200 +``` -Command to run LLMs and Deep Learning Models on container. 
+#### Generate Commands +Create deployment configurations: +```bash +# Generate Ansible playbook +madengine-cli generate ansible \ + --manifest-file build_manifest.json \ + --output cluster-deployment.yml + +# Generate Kubernetes manifests +madengine-cli generate k8s \ + --manifest-file build_manifest.json \ + --namespace madengine-prod + +# Generate SLURM job scripts +madengine-cli generate slurm \ + --manifest-file build_manifest.json \ + --environment prod \ + --output-dir slurm-setup ``` -# An example CLI command to run a model -madengine run --tags pyt_huggingface_bert --live-output --additional-context "{'guest_os': 'UBUNTU'}" + +### Traditional CLI (`madengine`) + +Simplified interface for local development: + +```bash +# Run models locally +madengine run --tags pyt_huggingface_bert --live-output \ + --additional-context '{"guest_os": "UBUNTU"}' + +# Model discovery +madengine discover --tags dummy + +# Generate reports +madengine report to-html --csv-file-path perf.csv + +# Database operations +madengine database create-table ``` -```shell -(venv) test-node:~/MAD$ madengine run --help -usage: madengine run [-h] [--tags TAGS [TAGS ...]] [--timeout TIMEOUT] [--live-output] [--clean-docker-cache] [--additional-context-file ADDITIONAL_CONTEXT_FILE] - [--additional-context ADDITIONAL_CONTEXT] [--data-config-file-name DATA_CONFIG_FILE_NAME] [--tools-json-file-name TOOLS_JSON_FILE_NAME] - [--generate-sys-env-details GENERATE_SYS_ENV_DETAILS] [--force-mirror-local FORCE_MIRROR_LOCAL] [--keep-alive] [--keep-model-dir] - [--skip-model-run] [--disable-skip-gpu-arch] [-o OUTPUT] +### Key Command Options + +| Option | Description | Example | +|--------|-------------|---------| +| `--tags, -t` | Model tags to process | `--tags dummy resnet` | +| `--registry, -r` | Docker registry URL | `--registry docker.io` | +| `--additional-context, -c` | Runtime context JSON | `--additional-context '{"gpu_vendor": "AMD"}'` | +| `--timeout` | Execution timeout (seconds) | `--timeout 3600` | +| `--live-output, -l` | Real-time output streaming | `--live-output` | +| `--verbose, -v` | Detailed logging | `--verbose` | +| `--manifest-file, -m` | Build manifest file | `--manifest-file build_manifest.json` | +| `--batch-manifest` | Batch build configuration | `--batch-manifest batch.json` | +## 🔍 Model Discovery + +madengine provides flexible model discovery through the MAD package ecosystem with support for static, directory-specific, and dynamic configurations. + +### Discovery Methods -Run LLMs and Deep Learning models on container +#### 1. Root Models (`models.json`) +Central model definitions at MAD package root: -optional arguments: - -h, --help show this help message and exit - --tags TAGS [TAGS ...] - tags to run (can be multiple). - --timeout TIMEOUT time out for model run in seconds; Overrides per-model timeout if specified or default timeout of 7200 (2 hrs). Timeout of 0 will never - timeout. - --live-output prints output in real-time directly on STDOUT - --clean-docker-cache rebuild docker image without using cache - --additional-context-file ADDITIONAL_CONTEXT_FILE - additonal context, as json file, to filter behavior of workloads. Overrides detected contexts. - --additional-context ADDITIONAL_CONTEXT - additional context, as string representation of python dict, to filter behavior of workloads. Overrides detected contexts and additional- - context-file. - --data-config-file-name DATA_CONFIG_FILE_NAME - custom data configuration file. 
- --tools-json-file-name TOOLS_JSON_FILE_NAME - custom tools json configuration file. - --generate-sys-env-details GENERATE_SYS_ENV_DETAILS - generate system config env details by default - --force-mirror-local FORCE_MIRROR_LOCAL - Path to force all relevant dataproviders to mirror data locally on. - --keep-alive keep Docker container alive after run; will keep model directory after run - --keep-model-dir keep model directory after run - --skip-model-run skips running the model; will not keep model directory after run unless specified through keep-alive or keep-model-dir - --disable-skip-gpu-arch - disables skipping model based on gpu architecture - -o OUTPUT, --output OUTPUT - output file +```bash +# Discover and run root models +madengine discover --tags dummy +madengine-cli run --tags dummy pyt_huggingface_bert ``` -For each model in models.json, the script -- builds docker images associated with each model. The images are named 'ci-$(model_name)', and are not removed after the script completes. -- starts the docker container, with name, 'container_$(model_name)'. The container should automatically be stopped and removed whenever the script exits. -- clones the git 'url', and runs the 'script' -- compiles the final perf.csv and perf.html +#### 2. Directory-Specific (`scripts/{model_dir}/models.json`) +Organized model definitions in subdirectories: -### Tag functionality for running model +```bash +# Directory-specific models +madengine discover --tags dummy2:dummy_2 +madengine-cli run --tags dummy2:dummy_2 +``` -With the tag functionality, the user can select a subset of the models, that have the corresponding tags matching user specified tags, to be run. User specified tags can be specified with the `--tags` argument. If multiple tags are specified, all models that match any tag is selected. -Each model name in models.json is automatically a tag that can be used to run that model. Tags are also supported in comma-separated form as a Jenkins parameter. +#### 3. Dynamic Discovery (`scripts/{model_dir}/get_models_json.py`) +Python scripts generating model configurations with parameters: +```bash +# Dynamic models with parameterization +madengine discover --tags dummy3:dummy_3:batch_size=512 +madengine-cli run --tags dummy3:dummy_3:batch_size=512:in=32:out=16 +``` -#### Search models with tags +### Tag System -Use cases of running models with static and dynamic search. Tags option supports searching models in models.json, scripts/model_dir/models.json, and scripts/model_dir/get_models_json.py. A user can add new models not only to the models.json file of DLM but also to the model folder in Flexible. To do this, the user needs to follow these steps: +| Tag Format | Description | Example | +|------------|-------------|---------| +| `model` | Simple model tag | `dummy` | +| `dir:model` | Directory-specific model | `dummy2:dummy_2` | +| `dir:model:param=value` | Parameterized model | `dummy3:dummy_3:batch_size=512` | +| `dir:model:p1=v1:p2=v2` | Multiple parameters | `dummy3:dummy_3:batch_size=512:in=32` | -Update models.json: Add the new model's configuration details to the models.json file. This includes specifying the model's name, version, and any other relevant metadata. -Place Model Files: Copy the model files into the appropriate directory within the model folder in Flexible. Ensure that the folder structure and file naming conventions match the expected format. +### Required MAD Structure ``` -# 1. 
run models in ~/MAD/models.json -(venv) test-node:~/MAD$ madengine run --tags dummy --live-output +MAD/ +├── models.json # Root model definitions +├── data.json # Data provider configurations +├── credential.json # Authentication credentials +├── scripts/ +│ ├── model_name/ # Model-specific directory +│ │ ├── models.json # Static configurations +│ │ ├── get_models_json.py # Dynamic discovery script +│ │ ├── run.sh # Model execution script +│ │ └── Dockerfile # Container definition +│ └── common/ +│ └── tools.json # Build tools configuration +└── pyproject.toml # madengine configuration +``` -# 2. run model in ~/MAD/scripts/dummy2/models.json -(venv) test-node:~/MAD$ madengine run --tags dummy2:dummy_2 --live-output +### Discovery Commands -# 3. run model in ~/MAD/scripts/dummy3/get_models_json.py -(venv) test-node:~/MAD$ madengine run --tags dummy3:dummy_3 --live-output +```bash +# List all available models +madengine discover + +# Discover specific models +madengine discover --tags dummy +madengine discover --tags dummy2:dummy_2 +madengine discover --tags dummy3:dummy_3:batch_size=256 + +# Validate model configurations +madengine discover --tags production_models --verbose +``` -# 4. run model with configurations -(venv) test-node:~/MAD$ madengine run --tags dummy2:dummy_2:batch_size=512:in=32:out=16 --live-output +### Batch Processing + +Define multiple models for selective building: + +**batch.json:** +```json +[ + { + "model_name": "dummy", + "build_new": true, + "registry": "docker.io", + "registry_image": "my-org/dummy:latest" + }, + { + "model_name": "resnet", + "build_new": false, + "registry_image": "existing-registry/resnet:v1.0" + } +] +``` -# 5. run model with configurations -(venv) test-node:~/MAD$ madengine run --tags dummy3:dummy_3:batch_size=512:in=32:out=16 --live-output +**Usage:** +```bash +# Build only models with build_new=true +madengine-cli build --batch-manifest batch.json \ + --additional-context '{"gpu_vendor": "AMD", "guest_os": "UBUNTU"}' ``` -The configs of batch_size512:in32:out16 will be pass to environment variables and build arguments of docker. +## 🌐 Distributed Execution -### Custom timeouts -The default timeout for model run is 2 hrs. This can be overridden if the model in models.json contains a `'timeout' : TIMEOUT` entry. Both the default timeout and/or timeout specified in models.json can be overridden using `--timeout TIMEOUT` command line argument. Having `TIMEOUT` set to 0 means that the model run will never timeout. +madengine supports sophisticated distributed execution with unified orchestration across multiple infrastructure types for optimal resource utilization and scalability. -### Live output functionality -By default, `madengine` is silent. The output is piped into log files. By specifying `--live-output`, the output is printed in real-time to STDOUT. +![Distributed Workflow](docs/img/distributed_workflow.png) -### Contexts -Contexts are run-time parameters that change how the model is executed. Some contexts are auto-detected. Detected contexts may be over-ridden. Contexts are also used to filter Dockerfile used in model. 
+### Architecture Overview -For more details, see [How to provide contexts](docs/how-to-provide-contexts.md) +``` +┌─────────────────────────────────────────────────────────────────┐ +│ madengine CLI │ +│ (madengine-cli runner) │ +└─────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ Runner Factory │ +│ (RunnerFactory.create_runner) │ +└─────────────────────────────────────────────────────────────────┘ + │ + ┌───────────────┼───────────────┼───────────────┐ + ▼ ▼ ▼ ▼ + ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ + │ SSH Runner │ │ Ansible Runner │ │ Kubernetes │ │ SLURM Runner │ + │ │ │ │ │ Runner │ │ │ + └─────────────────┘ └─────────────────┘ └─────────────────┘ └─────────────────┘ +``` + +### Runner Types -### Credentials -Credentials to clone model git urls are provided in a centralized `credential.json` file. Models that require special credentials for cloning have a special `cred` field in the model definition in `models.json`. This field denotes the specific credential in `credential.json` to use. Public models repositories can skip the `cred` field. +#### 🔗 SSH Runner +Direct SSH connections for simple distributed execution: -There are several types of credentials supported. +**Use Cases:** Individual workstations, small clusters, development +**Features:** Direct SSH with paramiko, SCP file transfer, parallel execution + +```bash +madengine-cli runner ssh \ + --inventory inventory.yml \ + --manifest-file build_manifest.json \ + --report-output ssh_results.json +``` + +#### 📋 Ansible Runner +Orchestrated deployment using Ansible playbooks: + +**Use Cases:** Large clusters, complex deployment, configuration management +**Features:** Playbook generation, inventory management, rich error reporting + +```bash +madengine-cli runner ansible \ + --inventory cluster.yml \ + --playbook deployment.yml \ + --report-output ansible_results.json +``` + +#### ☸️ Kubernetes Runner +Cloud-native execution in Kubernetes clusters: + +**Use Cases:** Cloud deployments, container orchestration, auto-scaling +**Features:** Dynamic Job creation, ConfigMap management, namespace isolation + +```bash +madengine-cli runner k8s \ + --inventory k8s_inventory.yml \ + --manifests-dir k8s-setup \ + --report-output k8s_results.json +``` + +#### 🖥️ SLURM Runner +HPC cluster execution with job scheduling: + +**Use Cases:** Academic institutions, supercomputers, resource-constrained environments +**Features:** Job arrays, resource management, module system integration + +```bash +# Two-step workflow +madengine-cli generate slurm --manifest-file build_manifest.json --output-dir slurm-setup +madengine-cli runner slurm --inventory slurm_inventory.yml --job-scripts-dir slurm-setup +``` + +### Environment Setup Process + +All runners automatically perform these steps on each node/pod: + +1. **Clone MAD Repository** - Downloads latest MAD package from GitHub +2. **Setup Virtual Environment** - Creates isolated Python environment +3. **Install Dependencies** - Installs madengine and all required packages +4. **Copy Configuration** - Transfers credentials, data configs, build manifests +5. **Verify Installation** - Validates madengine-cli functionality +6. 
**Execute from MAD Directory** - Runs with proper MODEL_DIR context + +### Inventory Configuration Examples + +#### SSH/Ansible Inventory +```yaml +nodes: + - hostname: "gpu-node-1" + address: "192.168.1.101" + username: "madengine" + ssh_key_path: "~/.ssh/id_rsa" + gpu_count: 4 + gpu_vendor: "AMD" + environment: + ROCR_VISIBLE_DEVICES: "0,1,2,3" +``` + +#### Kubernetes Inventory +```yaml +pods: + - name: "madengine-pod-1" + node_selector: + gpu-type: "amd" + resources: + requests: + amd.com/gpu: "2" + gpu_vendor: "AMD" +``` + +#### SLURM Inventory +```yaml +slurm_cluster: + login_node: + hostname: "hpc-login01" + address: "hpc-login01.example.com" + username: "madengine" + partitions: + - name: "gpu" + max_time: "24:00:00" + gpu_types: ["MI250X", "A100"] + gpu_vendor: "AMD" +``` + +### Use Case Examples + +#### Single GPU Development +```bash +madengine-cli runner ssh \ + --inventory dev_inventory.yml \ + --manifest-file build_manifest.json \ + --timeout 1800 +``` + +#### Multi-Node Production +```bash +madengine-cli runner ansible \ + --inventory production_cluster.yml \ + --manifest-file build_manifest.json \ + --parallelism 4 \ + --report-output production_results.json +``` + +#### Cloud Kubernetes Deployment +```bash +madengine-cli generate k8s --manifest-file build_manifest.json --namespace prod +madengine-cli runner k8s --inventory k8s_prod.yml --manifests-dir k8s-manifests +``` + +#### HPC SLURM Cluster +```bash +madengine-cli generate slurm --manifest-file research_models.json --environment hpc +madengine-cli runner slurm --inventory hpc_cluster.yml --job-scripts-dir slurm-setup --timeout 28800 +``` +## ⚙️ Configuration + +### Context System + +Runtime parameters controlling model execution behavior: + +```json +{ + "gpu_vendor": "AMD", + "guest_os": "UBUNTU", + "timeout_multiplier": 2.0, + "tools": [{"name": "rocprof"}] +} +``` + +**Required Build Context:** +- `gpu_vendor`: AMD, NVIDIA, INTEL (case-insensitive) +- `guest_os`: UBUNTU, CENTOS, ROCKY (case-insensitive) + +**Context Usage:** +```bash +# JSON string +--additional-context '{"gpu_vendor": "AMD", "guest_os": "UBUNTU"}' + +# From file +--additional-context-file context.json +``` + +### Credential Management + +Centralized authentication in `credential.json`: + +```json +{ + "dockerhub": { + "username": "dockerhub_username", + "password": "dockerhub_token", + "repository": "my-org" + }, + "AMD_GITHUB": { + "username": "github_username", + "password": "github_token" + }, + "MAD_AWS_S3": { + "username": "aws_access_key", + "password": "aws_secret_key" + } +} +``` + +### Registry Configuration + +**Automatic Registry Detection:** +- `docker.io` or empty → uses `dockerhub` credentials +- `localhost:5000` → uses `localhost:5000` credentials +- Custom URLs → uses URL as credential key + +**Registry Override with Environment Variables:** +```bash +export MAD_DOCKERHUB_USER=my_username +export MAD_DOCKERHUB_PASSWORD=my_token +export MAD_DOCKERHUB_REPO=my_org +``` + +### Data Provider Configuration + +Configure data sources in `data.json`: + +```json +{ + "data_sources": { + "model_data": { + "nas": {"path": "/home/datum"}, + "minio": {"path": "s3://datasets/datum"}, + "aws": {"path": "s3://datasets/datum"} + } + }, + "mirrorlocal": "/tmp/local_mirror" +} +``` + +### Environment Variables + +| Variable | Description | Example | +|----------|-------------|---------| +| `MAD_VERBOSE_CONFIG` | Enable verbose configuration logging | `"true"` | +| `MAD_SETUP_MODEL_DIR` | Auto-setup MODEL_DIR during import | `"true"` | +| `MODEL_DIR` | 
Model directory path | `/path/to/models` | +| `MAD_DOCKERHUB_*` | Docker Hub credentials override | See above | + +**Configuration Priority:** +1. Environment variables (highest) +2. Command-line arguments +3. Configuration files +4. Built-in defaults (lowest) +## 🎯 Advanced Usage + +### Custom Timeouts + +```bash +# Model-specific timeout in models.json +{"timeout": 3600} + +# Command-line timeout override +madengine-cli run --tags models --timeout 7200 + +# No timeout (run indefinitely) +madengine-cli run --tags models --timeout 0 +``` + +### Performance Profiling + +```bash +# GPU profiling with ROCm +madengine-cli run --tags models \ + --additional-context '{"tools": [{"name":"rocprof"}]}' + +# Memory and performance monitoring +madengine-cli run --tags models --live-output --verbose \ + --summary-output detailed_metrics.json + +# Multiple profiling tools +madengine-cli run --tags models \ + --additional-context '{"tools": [{"name":"rocprof"}, {"name":"trace"}]}' +``` + +### Local Data Mirroring + +```bash +# Force local mirroring for all workloads +madengine-cli run --tags models --force-mirror-local /tmp/mirror + +# Configure per-model in data.json +{ + "mirrorlocal": "/path/to/local/mirror" +} +``` + +### Development and Debugging + +```bash +# Keep containers alive for debugging +madengine-cli run --tags models --keep-alive --keep-model-dir + +# Skip model execution (build/setup only) +madengine-cli run --tags models --skip-model-run + +# Detailed logging with stack traces +madengine-cli run --tags models --verbose + +# Clean rebuild without cache +madengine-cli build --tags models --clean-docker-cache +``` + +### Batch Processing Advanced + +**Selective Building:** +```json +[ + { + "model_name": "production_model", + "build_new": true, + "registry": "prod.registry.com", + "registry_image": "prod/model:v2.0" + }, + { + "model_name": "cached_model", + "build_new": false, + "registry_image": "cache/model:v1.5" + } +] +``` + +**Complex Context Override:** +```bash +madengine-cli build --batch-manifest batch.json \ + --additional-context '{ + "gpu_vendor": "AMD", + "guest_os": "UBUNTU", + "docker_env_vars": {"ROCR_VISIBLE_DEVICES": "0,1,2,3"}, + "timeout_multiplier": 2.0 + }' +``` + +### Registry Management + +```bash +# Multi-registry deployment +madengine-cli build --tags models --registry docker.io +scp build_manifest.json remote-cluster:/shared/ + +# Private registry with authentication +madengine-cli build --tags models --registry private.company.com \ + --additional-context '{"registry_auth": {"username": "user", "password": "token"}}' + +# Local registry for development +docker run -d -p 5000:5000 registry:2 +madengine-cli build --tags dev_models --registry localhost:5000 +``` + +### Error Recovery and Monitoring + +```bash +# Retry failed operations +madengine-cli run --tags models --timeout 3600 --verbose + +# Generate comprehensive reports +madengine-cli run --tags models \ + --summary-output execution_summary.json \ + --report-output detailed_report.json + +# Monitor execution progress +madengine-cli run --tags models --live-output --verbose +``` -1. For HTTP/HTTPS git urls, `username` and `password` should be provided in the credential. For Source Code Management(SCM) systems that support Access Tokens, the token can be substituted for the `password` field. The `username` and `password` will be passed as a docker build argument and a container environment variable in the docker build and run steps. 
Fore example, for `"cred":"AMD_GITHUB"` field in `models.json` and entry `"AMD_GITHUB": { "username": "github_username", "password":"pass" }` in `credential.json` the following docker build arguments and container environment variables will be added: `AMD_GITHUB_USERNAME="github_username"` and `AMD_GITHUB_PASSWORD="pass"`. +## 🚀 Deployment Scenarios + +### Research Lab Environment + +**Setup:** Multiple GPU workstations, shared storage, local registry +**Goal:** Model comparison across different GPU architectures + +```bash +# Central build server +madengine-cli build --tags research_models --registry lab-registry:5000 \ + --additional-context '{"gpu_vendor": "AMD", "guest_os": "UBUNTU"}' \ + --summary-output research_build_$(date +%Y%m%d).json + +# Distribute via shared storage +cp build_manifest.json /shared/nfs/madengine/experiments/ + +# Execute on researcher workstations +madengine-cli run --manifest-file /shared/nfs/madengine/experiments/build_manifest.json \ + --live-output --timeout 7200 --verbose +``` + +### Cloud Service Provider + +**Setup:** Kubernetes cluster, CI/CD pipeline, cloud registry +**Goal:** ML benchmarking as a service for customers + +```bash +# CI/CD build pipeline +madengine-cli build --tags customer_models --registry gcr.io/ml-bench \ + --additional-context-file customer_context.json \ + --summary-output build_report_${CUSTOMER_ID}.json + +# Batch build for multiple customer models +madengine-cli build --batch-manifest customer_${CUSTOMER_ID}_models.json \ + --registry gcr.io/ml-bench \ + --additional-context-file customer_context.json + +# Generate and deploy K8s configuration +madengine-cli generate k8s \ + --manifest-file build_manifest.json \ + --namespace customer-bench-${CUSTOMER_ID} + +kubectl apply -f k8s-manifests/ --namespace customer-bench-${CUSTOMER_ID} +``` + +### Enterprise Data Center + +**Setup:** Large-scale on-premise infrastructure with heterogeneous GPU nodes +**Goal:** Centralized benchmarking and resource optimization + +```bash +# Centralized build on dedicated build server +madengine-cli build --tags enterprise_models --registry dc-registry.local \ + --additional-context '{"gpu_vendor": "NVIDIA", "guest_os": "UBUNTU"}' \ + --clean-docker-cache \ + --summary-output enterprise_build_$(date +%Y%m%d).json + +# Distributed execution across data center +madengine-cli runner ansible \ + --inventory datacenter_inventory.yml \ + --manifest-file enterprise_build_$(date +%Y%m%d).json \ + --parallelism 12 \ + --report-output datacenter_execution_$(date +%Y%m%d).json \ + --verbose + +# Generate comprehensive performance reports +madengine report to-html --csv-file-path datacenter_perf_$(date +%Y%m%d).csv +``` + +### Academic HPC Institution + +**Setup:** SLURM-managed supercomputer with shared filesystem +**Goal:** Large-scale research model benchmarking + +```bash +# Generate SLURM configuration for research workload +madengine-cli generate slurm \ + --manifest-file research_models_v2.json \ + --environment hpc \ + --output-dir research-slurm-$(date +%Y%m%d) + +# Submit to HPC job scheduler +madengine-cli runner slurm \ + --inventory supercomputer_cluster.yml \ + --job-scripts-dir research-slurm-$(date +%Y%m%d) \ + --timeout 86400 \ + --verbose + +# Monitor and collect results +squeue -u $USER +ls /shared/results/research-*/job_summary.json +``` + +### Hybrid Cloud-Edge Deployment + +**Setup:** Mixed cloud and edge infrastructure +**Goal:** Distributed model validation across environments + +```bash +# Build for multiple environments +madengine-cli 
build --tags hybrid_models --registry hybrid-registry.com \ + --additional-context '{"gpu_vendor": "AMD", "guest_os": "UBUNTU"}' \ + --summary-output hybrid_build.json + +# Cloud execution (Kubernetes) +madengine-cli runner k8s \ + --inventory cloud_k8s_inventory.yml \ + --manifests-dir cloud-k8s-setup \ + --report-output cloud_results.json + +# Edge execution (SSH) +madengine-cli runner ssh \ + --inventory edge_nodes_inventory.yml \ + --manifest-file hybrid_build.json \ + --report-output edge_results.json + +# Aggregate results +python scripts/aggregate_hybrid_results.py cloud_results.json edge_results.json +``` + +### CI/CD Pipeline Integration + +**Setup:** GitHub Actions with automated model validation +**Goal:** Continuous benchmarking for model releases + +```yaml +# .github/workflows/model-benchmark.yml +name: Model Benchmark +on: + push: + paths: ['models/**', 'scripts/**'] + +jobs: + benchmark: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 -2. For SSH git urls, `username` and `ssh_key_file` should be provided in the credential. The `username` is the SSH username, and `ssh_key_file` is the private ssh key, that has been registed with the SCM system. -Due to legal requirements, the Credentials to access all models is not provided by default in DLM. Please contact the model owner if you wish to access and run the model. + - name: Build Models + run: | + madengine-cli build --tags ci_models \ + --registry ${{ secrets.REGISTRY_URL }} \ + --additional-context '{"gpu_vendor": "NVIDIA", "guest_os": "UBUNTU"}' \ + --summary-output ci_build_${{ github.sha }}.json + + - name: Deploy to Test Cluster + run: | + madengine-cli runner k8s \ + --inventory .github/k8s_test_inventory.yml \ + --manifests-dir ci-k8s-setup \ + --report-output ci_test_results.json +``` + +## 📝 Best Practices + +### 🔧 Infrastructure Management + +**Inventory Organization:** +- Store inventory files in version control with environment separation +- Use descriptive hostnames and consistent naming conventions +- Document node purposes, GPU configurations, and network topology +- Validate inventory files before deployment with dry-run tests + +**Security Hardening:** +- Use SSH keys instead of passwords for all remote connections +- Implement least privilege access with dedicated service accounts +- Restrict network access to essential ports and trusted sources +- Rotate credentials regularly and store them securely + +### ⚡ Performance Optimization + +**Resource Allocation:** +- Match CPU/memory requests to actual model requirements +- Monitor GPU utilization and adjust parallelism accordingly +- Use local or geographically close registries for faster image pulls +- Implement resource quotas to prevent over-subscription + +**Parallelism Tuning:** +```bash +# Start conservative and scale up +madengine-cli runner ansible --parallelism 2 # Initial test +madengine-cli runner ansible --parallelism 4 # Scale based on results +madengine-cli runner ansible --parallelism 8 # Monitor resource usage +``` + +**Network Optimization:** +- Use high-bandwidth connections (10GbE+) for large clusters +- Minimize network latency between build and execution nodes +- Implement registry caching for frequently used images + +### 🔍 Error Handling & Monitoring + +**Comprehensive Logging:** +```bash +# Enable verbose logging for troubleshooting +madengine-cli run --tags models --verbose --live-output + +# Capture execution summaries for analysis +madengine-cli run --tags models --summary-output execution_$(date +%Y%m%d).json +``` + 
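+The execution summaries captured above are JSON files, so they can be spot-checked straight from the shell before deeper analysis (a minimal sketch; the exact fields depend on the madengine version and are not assumed here):
+
+```bash
+# Pretty-print a captured execution summary for a quick look
+python -m json.tool execution_$(date +%Y%m%d).json | head -40
+
+# Keep dated summaries together so runs can be compared later
+mkdir -p run-summaries && mv execution_*.json run-summaries/
+```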
+**Proactive Monitoring:** +- Monitor cluster resource usage and job queue status +- Set up alerts for failed executions and resource exhaustion +- Implement health checks for critical infrastructure components +- Track performance metrics over time for capacity planning + +### 📊 Registry & Build Management + +**Registry Strategy:** +```bash +# Use environment-specific registries +madengine-cli build --registry dev-registry.local # Development +madengine-cli build --registry staging-registry.com # Staging +madengine-cli build --registry prod-registry.com # Production +``` + +**Build Optimization:** +- Use Docker layer caching and multi-stage builds +- Clean up intermediate containers and unused images regularly +- Tag images with semantic versions for reproducibility +- Implement registry garbage collection policies + +### 🔄 Workflow Management + +**Environment Separation:** +```bash +# Separate configurations for each environment +inventory/ +├── dev_inventory.yml +├── staging_inventory.yml +└── prod_inventory.yml + +contexts/ +├── dev_context.json +├── staging_context.json +└── prod_context.json +``` -3. For NAS urls, `HOST`, `PORT`, `USERNAME`, and `PASSWORD` should be provided in the credential. Please check env variables starting with NAS in [Environment Variables] (https://github.com/ROCm/madengine/blob/main/README.md#environment-variables) +**Version Control:** +- Track all configuration files (inventory, contexts, batch manifests) +- Use branching strategies for environment promotion +- Tag releases with corresponding model versions +- Maintain change logs for configuration updates -3. For AWS S3 urls, `USERNAME`, and `PASSWORD` should be provided in the credential with var name as MAD_AWS_S3 as mentioned in [Environment Variables] (https://github.com/ROCm/madengine/blob/main/README.md#environment-variables) +### 🎯 Model Lifecycle Management +**Discovery Organization:** +``` +scripts/ +├── production_models/ # Stable, validated models +├── experimental_models/ # Development and testing +├── archived_models/ # Historical or deprecated +└── common/ # Shared tooling and utilities +``` + +**Testing Strategy:** +- Test new models in development environment first +- Use subset of data for initial validation runs +- Implement automated testing for critical model changes +- Maintain baseline performance metrics for comparison -### Local data provider -The DLM user may wish to run a model locally multiple times, with the input data downloaded once, and reused subsquently. This functionality is only supported on models that support the Data Provider functionality. That is, the model specification in `models.json` have the `data` field, which points to a data specification in `data.json`. +## 🔧 Troubleshooting -To use existing data on a local path, add to the data specification, using a `local` field within `data.json`. By default, this path is mounted read-only. To change this path to read-write, specify the `readwrite` field to `'true'` in the data configuration. +### Common Issues & Solutions + +#### 🔗 SSH Connection Failures + +**Symptoms:** Cannot connect to remote nodes +```bash +# Test basic connectivity +ping +ssh -v -i ~/.ssh/id_rsa user@node # Verbose SSH test + +# Fix common issues +chmod 600 ~/.ssh/id_rsa # Fix key permissions +ssh-add ~/.ssh/id_rsa # Add key to agent +systemctl status sshd # Check SSH service +``` -If no data exists in local path, a local copy of data can be downloaded using by setting the `mirrorlocal` field in data specification in `data.json`. 
Not all providers support `mirrorlocal`. For the ones that do support this feature, the remote data is mirrored on this host path during the first run. In subsequent runs, the data may be reused through synchronization mechanisms. If the user wishes to skip the remote synchronization, the same location can be set as a `local` data provider in data.json, with higher precedence, or as the only provider for the data, by locally editing `data.json`. +#### 📋 Ansible Execution Errors -Alternatively, the command-line argument, `--force-mirror-local` forces local mirroring on *all* workloads, to the provided FORCEMIRRORLOCAL path. +**Symptoms:** Playbook failures or connectivity issues +```bash +# Test Ansible connectivity +ansible all -i inventory.yml -m ping -## Discover models +# Debug inventory format +ansible-inventory -i inventory.yml --list -Commands for discovering models through models.json, scripts/{model_dir}/models.json, or scripts/{model_dir}/get_models_json.py +# Check Python installation +ansible all -i inventory.yml -m setup +# Run with increased verbosity +madengine-cli runner ansible --verbose ``` -(venv) test-node:~/MAD$ madengine discover --help -usage: madengine discover [-h] [--tags TAGS [TAGS ...]] -Discover the models +#### ☸️ Kubernetes Job Failures + +**Symptoms:** Jobs fail to start or complete +```bash +# Check cluster health +kubectl get nodes +kubectl get pods --all-namespaces + +# Inspect job details +kubectl describe job madengine-job -n madengine +kubectl logs job/madengine-job -n madengine -optional arguments: - -h, --help show this help message and exit - --tags TAGS [TAGS ...] - tags to discover models (can be multiple). +# Check resource availability +kubectl describe quota -n madengine +kubectl top nodes ``` -Use cases about how to discover models: +#### 🐳 Docker Registry Issues + +**Symptoms:** Image pull failures or authentication errors +```bash +# Test registry connectivity +docker pull / + +# Check authentication +docker login + +# Verify image exists +docker images | grep +# Test network access +curl -I https:///v2/ ``` -# 1 discover all models in DLM -(venv) test-node:~/MAD$ madengine discover -# 2. discover specified model using tags in models.json of DLM -(venv) test-node:~/MAD$ madengine discover --tags dummy +#### 🖥️ GPU Resource Problems -# 3. discover specified model using tags in scripts/{model_dir}/models.json with static search i.e. models.json -(venv) test-node:~/MAD$ madengine discover --tags dummy2/dummy_2 +**Symptoms:** GPU not detected or allocated properly +```bash +# Check GPU status +nvidia-smi # NVIDIA GPUs +rocm-smi # AMD GPUs + +# Verify Kubernetes GPU resources +kubectl describe nodes | grep -A5 "Allocated resources" -# 4. discover specified model using tags in scripts/{model_dir}/get_models_json.py with dynamic search i.e. get_models_json.py -(venv) test-node:~/MAD$ madengine discover --tags dummy3/dummy_3 +# Check device plugin status +kubectl get pods -n kube-system | grep gpu +``` -# 5. pass additional args to your model script from CLI -(venv) test-node:~/MAD$ madengine discover --tags dummy3/dummy_3:bs16 +#### 🏗️ MAD Environment Setup Failures -# 6. 
get multiple models using tags -(venv) test-node:~/MAD$ madengine discover --tags pyt_huggingface_bert pyt_huggingface_gpt2 +**Symptoms:** Repository cloning or installation issues +```bash +# Test GitHub connectivity +ping github.com +curl -I https://github.com + +# Manual setup test +git clone https://github.com/ROCm/MAD.git test_mad +cd test_mad && python3 -m venv test_venv +source test_venv/bin/activate && pip install git+https://github.com/ROCm/madengine.git + +# Check system requirements +python3 --version # Ensure Python 3.8+ +pip --version # Verify pip availability +df -h # Check disk space ``` -Note: You cannot use a backslash '/' or a colon ':' in a model name or a tag for a model in `models.json` or `get_models_json.py` +#### 📊 SLURM Job Problems + +**Symptoms:** Job submission or execution failures +```bash +# Check SLURM cluster status +sinfo # Cluster overview +sinfo -p gpu # GPU partition status +squeue -u $(whoami) # Your job queue -## Generate reports +# Verify SLURM account and permissions +sacctmgr show assoc user=$(whoami) +sacctmgr show qos # Available QoS options -Commands for generating reports. +# Test manual job submission +sbatch --test-only job_script.sh +# Check job logs +cat logs/madengine_*.out +cat logs/madengine_*.err ``` -(venv) test-node:~/MAD$ madengine report --help -usage: madengine report [-h] {update-perf,to-html,to-email} ... -optional arguments: - -h, --help show this help message and exit +### Debugging Strategies + +#### 🔍 Systematic Troubleshooting + +1. **Enable Verbose Logging** + ```bash + madengine-cli run --tags models --verbose --live-output + ``` + +2. **Test Components Individually** + ```bash + # Test model discovery first + madengine discover --tags dummy + + # Test build phase only + madengine-cli build --tags dummy --registry localhost:5000 + + # Test run phase with existing manifest + madengine-cli run --manifest-file build_manifest.json + ``` + +3. **Use Minimal Test Cases** + ```bash + # Start with simple dummy model + madengine-cli run --tags dummy --timeout 300 + + # Test single node before multi-node + madengine-cli runner ssh --inventory single_node.yml + ``` + +4. **Check Resource Utilization** + ```bash + # Monitor during execution + htop # CPU/Memory usage + nvidia-smi -l 1 # GPU utilization + iotop # Disk I/O + nethogs # Network usage + ``` + +### Performance Diagnostics + +#### 🚀 Optimization Analysis + +**Identify Bottlenecks:** +```bash +# Profile container execution +madengine-cli run --tags models --live-output --keep-alive + +# Monitor registry pull times +time docker pull / + +# Check network throughput +iperf3 -c -Report Commands: - Available commands for generating reports. 
+# Analyze build times +madengine-cli build --tags models --verbose --summary-output build_profile.json +``` + +**Resource Monitoring:** +```bash +# Real-time monitoring during execution +watch -n 1 'kubectl top nodes && kubectl top pods' - {update-perf,to-html,to-email} - update-perf Update perf.csv to database - to-html Convert CSV to HTML report of models - to-email Convert CSV to Email of models +# Generate resource usage reports +madengine-cli runner ansible --report-output detailed_metrics.json ``` -### Report command - Update perf CSV to database +### Emergency Recovery + +#### 🆘 Cluster Recovery Procedures + +**Clean Up Failed Jobs:** +```bash +# Kubernetes cleanup +kubectl delete jobs --all -n madengine +kubectl delete pods --field-selector=status.phase=Failed -n madengine + +# SLURM cleanup +scancel -u $(whoami) # Cancel all your jobs +squeue -u $(whoami) # Verify cancellation + +# Docker cleanup +docker system prune -f # Clean unused containers/images +``` -Update perf.csv to database +**Reset Environment:** +```bash +# Reset MAD environment on remote nodes +madengine-cli runner ssh --inventory inventory.yml \ + --additional-context '{"reset_environment": true}' +# Recreate virtual environments +ssh node1 'rm -rf /path/to/MAD/venv && python3 -m venv /path/to/MAD/venv' ``` -(venv) test-node:~/MAD$ madengine report update-perf --help -usage: madengine report update-perf [-h] [--single_result SINGLE_RESULT] [--exception-result EXCEPTION_RESULT] [--failed-result FAILED_RESULT] - [--multiple-results MULTIPLE_RESULTS] [--perf-csv PERF_CSV] [--model-name MODEL_NAME] [--common-info COMMON_INFO] -Update performance metrics of models perf.csv to database. +### Getting Help + +#### 📞 Support Resources + +**Log Collection for Support:** +```bash +# Collect comprehensive logs +madengine-cli run --tags failing_model --verbose > madengine_debug.log 2>&1 + +# Generate system information +madengine-cli run --tags dummy --sys-env-details --summary-output system_info.json -optional arguments: - -h, --help show this help message and exit - --single_result SINGLE_RESULT - path to the single result json - --exception-result EXCEPTION_RESULT - path to the single result json - --failed-result FAILED_RESULT - path to the single result json - --multiple-results MULTIPLE_RESULTS - path to the results csv - --perf-csv PERF_CSV - --model-name MODEL_NAME - --common-info COMMON_INFO +# Package logs for support +tar -czf madengine_support_$(date +%Y%m%d).tar.gz \ + madengine_debug.log system_info.json build_manifest.json ``` -### Report command - Convert CSV to HTML +**Community Support:** +- GitHub Issues: https://github.com/ROCm/madengine/issues +- ROCm Community: https://rocm.docs.amd.com/en/latest/ +- Documentation: https://github.com/ROCm/madengine/tree/main/docs + +## 📚 API Reference -Convert CSV to HTML report of models +### Core Command Structure +```bash +# Modern CLI (Recommended) +madengine-cli [options] + +# Traditional CLI (Compatibility) +madengine [options] ``` -(venv) test-node:~/MAD$ madengine report to-html --help -usage: madengine report to-html [-h] [--csv-file-path CSV_FILE_PATH] -Convert CSV to HTML report of models. 
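+Both entry points can also be driven from scripts. The sketch below is illustrative only and is not part of the madengine API: it wraps `madengine-cli` with Python's `subprocess` module and returns the process exit code, which can be mapped to the exit codes documented later in this reference. The tag and registry values are placeholders.
+
+```python
+# Illustrative automation wrapper around the madengine-cli entry point.
+# Not part of madengine itself; the tag and registry values are placeholders.
+import subprocess
+import sys
+
+
+def run_madengine_cli(*args: str, timeout: int = 3600) -> int:
+    """Run madengine-cli with the given arguments and return its exit code."""
+    cmd = ["madengine-cli", *args]
+    print("> " + " ".join(cmd))
+    completed = subprocess.run(cmd, timeout=timeout)  # timeout guards the wrapper, not the run itself
+    return completed.returncode
+
+
+if __name__ == "__main__":
+    rc = run_madengine_cli("run", "--tags", "dummy", "--registry", "localhost:5000")
+    sys.exit(rc)
+```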
+### Build Command + +**Purpose:** Create Docker images and manifests for distributed execution -optional arguments: - -h, --help show this help message and exit - --csv-file-path CSV_FILE_PATH +```bash +madengine-cli build [OPTIONS] +``` + +| Option | Type | Description | Default | +|--------|------|-------------|---------| +| `--tags, -t` | Multiple | Model tags to build | `[]` | +| `--registry, -r` | String | Docker registry URL | `None` | +| `--batch-manifest` | File | Batch build configuration file | `None` | +| `--additional-context, -c` | JSON | Runtime context as JSON string | `"{}"` | +| `--additional-context-file, -f` | File | Runtime context from file | `None` | +| `--clean-docker-cache` | Flag | Rebuild without Docker cache | `false` | +| `--manifest-output, -m` | File | Build manifest output path | `build_manifest.json` | +| `--summary-output, -s` | File | Build summary JSON output | `None` | +| `--live-output, -l` | Flag | Real-time output streaming | `false` | +| `--verbose, -v` | Flag | Enable detailed logging | `false` | + +**Examples:** +```bash +# Basic build +madengine-cli build --tags dummy --registry localhost:5000 + +# Production build +madengine-cli build --tags production_models \ + --registry docker.io \ + --additional-context '{"gpu_vendor": "AMD", "guest_os": "UBUNTU"}' \ + --clean-docker-cache \ + --summary-output build_report.json ``` -### Report command - Convert CSV to Email +### Run Command + +**Purpose:** Execute models with intelligent workflow detection + +```bash +madengine-cli run [OPTIONS] +``` -Convert CSV to Email report of models +| Option | Type | Description | Default | +|--------|------|-------------|---------| +| `--tags, -t` | Multiple | Model tags to run | `[]` | +| `--manifest-file, -m` | File | Build manifest file path | `""` | +| `--registry, -r` | String | Docker registry URL | `None` | +| `--timeout` | Integer | Execution timeout in seconds | `-1` | +| `--additional-context, -c` | JSON | Runtime context as JSON string | `"{}"` | +| `--additional-context-file, -f` | File | Runtime context from file | `None` | +| `--keep-alive` | Flag | Keep containers alive after run | `false` | +| `--keep-model-dir` | Flag | Keep model directory after run | `false` | +| `--skip-model-run` | Flag | Skip model execution (setup only) | `false` | +| `--live-output, -l` | Flag | Real-time output streaming | `false` | +| `--verbose, -v` | Flag | Enable detailed logging | `false` | + +**Examples:** +```bash +# Complete workflow +madengine-cli run --tags dummy --registry localhost:5000 --timeout 3600 +# Execution-only +madengine-cli run --manifest-file build_manifest.json --timeout 1800 ``` -(venv) test-node:~/MAD$ madengine report to-email --help -usage: madengine report to-email [-h] [--csv-file-path CSV_FILE_PATH] -Convert CSV to Email of models. +### Runner Commands -optional arguments: - -h, --help show this help message and exit - --csv-file-path CSV_FILE_PATH - Path to the directory containing the CSV files. +**Purpose:** Execute across distributed infrastructure + +```bash +madengine-cli runner [OPTIONS] ``` -## Database +**Runner Types:** `ssh`, `ansible`, `k8s`, `slurm` + +#### Common Runner Options -Commands for database, such as create and update table of DB. 
+| Option | Type | Description | Default | +|--------|------|-------------|---------| +| `--inventory, -i` | File | Inventory configuration file | `inventory.yml` | +| `--report-output` | File | Execution report output | `runner_report.json` | +| `--verbose, -v` | Flag | Enable detailed logging | `false` | +#### SSH Runner + +| Option | Type | Description | Default | +|--------|------|-------------|---------| +| `--manifest-file, -m` | File | Build manifest file | `build_manifest.json` | + +#### Ansible Runner + +| Option | Type | Description | Default | +|--------|------|-------------|---------| +| `--playbook` | File | Ansible playbook file | `madengine_distributed.yml` | + +#### Kubernetes Runner + +| Option | Type | Description | Default | +|--------|------|-------------|---------| +| `--manifests-dir, -d` | Directory | Kubernetes manifests directory | `k8s-setup` | +| `--kubeconfig` | File | Kubeconfig file path | Auto-detected | + +#### SLURM Runner + +| Option | Type | Description | Default | +|--------|------|-------------|---------| +| `--job-scripts-dir, -j` | Directory | SLURM job scripts directory | `slurm-setup` | +| `--timeout, -t` | Integer | Execution timeout in seconds | `3600` | + +### Generate Commands + +**Purpose:** Create deployment configurations + +```bash +madengine-cli generate [OPTIONS] ``` -(venv) test-node:~/MAD$ madengine database --help -usage: madengine database [-h] {create-table,update-table,upload-mongodb} ... -optional arguments: - -h, --help show this help message and exit +**Types:** `ansible`, `k8s`, `slurm` + +| Option | Type | Description | Default | +|--------|------|-------------|---------| +| `--manifest-file, -m` | File | Build manifest input file | `build_manifest.json` | +| `--output, -o` | File/Dir | Output file or directory | Type-specific | +| `--namespace` | String | Kubernetes namespace (k8s only) | `madengine` | +| `--environment` | String | SLURM environment (slurm only) | `default` | -Database Commands: - Available commands for database, such as creating and updating table in DB. +### Traditional CLI Commands - {create-table,update-table,upload-mongodb} - create-table Create table in DB - update-table Update table in DB - upload-mongodb Update table in DB +#### Model Operations +```bash +madengine run --tags [OPTIONS] +madengine discover --tags [OPTIONS] +``` + +#### Reporting +```bash +madengine report to-html --csv-file-path +madengine report to-email --csv-file-path +madengine report update-perf --perf-csv ``` -### Database - Create Table +#### Database Operations +```bash +madengine database create-table +madengine database update-table --csv-file-path +madengine database upload-mongodb --type --file-path ``` -(venv) test-node:~/MAD$ madengine database create-table --help -usage: madengine database create-table [-h] [-v] -Create table in DB. 
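+The reporting commands above consume the `perf.csv` file produced by a run. A quick sanity check before converting it to HTML or uploading it is to load it with pandas (already a madengine dependency). The snippet below is a hedged sketch that assumes a `perf.csv` exists in the current directory; column names vary by model, so none are hard-coded here.
+
+```python
+# Hedged sketch: inspect a perf.csv produced by a previous `madengine run`
+# before feeding it to `madengine report to-html --csv-file-path perf.csv`.
+import pandas as pd
+
+df = pd.read_csv("perf.csv")
+print(f"{len(df)} rows, columns: {list(df.columns)}")
+print(df.head())
+```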
+### Exit Codes + +| Code | Description | +|------|-------------| +| `0` | Success | +| `1` | General failure | +| `2` | Build failure | +| `3` | Execution failure | +| `4` | Invalid arguments | +| `5` | Configuration error | + +### Configuration Files + +#### Batch Manifest Format +```json +[ + { + "model_name": "model1", + "build_new": true, + "registry": "docker.io", + "registry_image": "org/model1:latest" + } +] +``` -optional arguments: - -h, --help show this help message and exit - -v, --verbose verbose output +#### Context Format +```json +{ + "gpu_vendor": "AMD|NVIDIA|INTEL", + "guest_os": "UBUNTU|CENTOS|ROCKY", + "timeout_multiplier": 2.0, + "tools": [{"name": "rocprof"}], + "docker_env_vars": {"VAR": "value"} +} ``` -### Database - Update Table +#### Inventory Format (SSH/Ansible) +```yaml +nodes: + - hostname: "node1" + address: "192.168.1.100" + username: "user" + ssh_key_path: "~/.ssh/id_rsa" + gpu_count: 4 + gpu_vendor: "AMD" ``` -(venv) test-node:~/MAD$ madengine database update-table --help -usage: madengine database update-table [-h] [--csv-file-path CSV_FILE_PATH] [--model-json-path MODEL_JSON_PATH] -Update table in DB. +#### Inventory Format (Kubernetes) +```yaml +pods: + - name: "madengine-pod" + resources: + requests: + amd.com/gpu: "2" + gpu_vendor: "AMD" +``` -optional arguments: - -h, --help show this help message and exit - --csv-file-path CSV_FILE_PATH - Path to the csv file - --model-json-path MODEL_JSON_PATH - Path to the model json file +#### Inventory Format (SLURM) +```yaml +slurm_cluster: + login_node: + hostname: "hpc-login" + address: "login.hpc.edu" + partitions: + - name: "gpu" + gpu_types: ["MI250X"] + gpu_vendor: "AMD" ``` -### Database - Upload MongoDB +## 🤝 Contributing + +We welcome contributions to madengine! This project follows modern Python development practices with comprehensive testing and code quality standards. + +### 🚀 Quick Start for Contributors + +```bash +# Fork and clone the repository +git clone https://github.com/yourusername/madengine.git +cd madengine + +# Create development environment +python3 -m venv venv && source venv/bin/activate +# Install in development mode with all tools +pip install -e ".[dev]" + +# Setup pre-commit hooks (recommended) +pre-commit install + +# Run tests to verify setup +pytest ``` -(venv) test-node:~/MAD$ madengine database upload-mongodb --help -usage: madengine database upload-mongodb [-h] [--type TYPE] [--file-path FILE_PATH] [--name NAME] -Update table in DB. +### 🧪 Development Workflow + +#### Testing +```bash +# Run full test suite +pytest + +# Run with coverage report +pytest --cov=src/madengine --cov-report=html -optional arguments: - -h, --help show this help message and exit - --type TYPE type of document to upload: job or run - --file-path FILE_PATH - total path to directory where perf_entry.csv, *env.csv, and *.log are stored - --name NAME name of model to upload +# Run specific test categories +pytest -m "not slow" # Skip slow tests +pytest tests/test_cli.py # Specific test file +pytest -k "test_build" # Tests matching pattern ``` -## Tools in madengine +#### Code Quality +```bash +# Format code +black src/ tests/ +isort src/ tests/ + +# Lint code +flake8 src/ tests/ + +# Type checking +mypy src/madengine -There are some tools distributed with madengine together. They work with madengine CLI to profile GPU and get trace of ROCm libraries. 
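+The `slow` and `integration` markers used with `-m` above are registered under `[tool.pytest.ini_options]` in `pyproject.toml`. A minimal sketch of how a new test can opt into them is shown below; the second test body is a placeholder, not an existing madengine test.
+
+```python
+# Sketch of tests using the markers registered in pyproject.toml.
+import pytest
+
+import madengine
+
+
+def test_version_is_exposed():
+    """Fast unit test: the package exposes a __version__ string."""
+    assert isinstance(madengine.__version__, str)
+
+
+@pytest.mark.slow
+@pytest.mark.integration
+def test_full_workflow_placeholder():
+    """Placeholder for an end-to-end workflow test (deselected by -m "not slow")."""
+    pytest.skip("illustrative placeholder only")
+```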
+# Run all quality checks +pre-commit run --all-files +``` -### Tools - GPU Info Profile +#### Documentation +```bash +# Build documentation locally +cd docs && make html -Profile GPU usage of running LLMs and Deep Learning models. +# Test documentation examples +python docs/test_examples.py +# Update API documentation +sphinx-apidoc -o docs/api src/madengine ``` -(venv) test-node:~/MAD$ madengine run --tags pyt_huggingface_bert --additional-context "{'guest_os': 'UBUNTU','tools': [{'name':'rocprof'}]}" + +### 📋 Contribution Guidelines + +#### Code Standards +- **Python Style:** Follow PEP 8 with Black formatting (88 character line length) +- **Type Hints:** Add type hints for all public functions and class methods +- **Docstrings:** Use Google-style docstrings for all modules, classes, and functions +- **Testing:** Maintain 95%+ test coverage for new code +- **Imports:** Use isort for consistent import ordering + +#### Commit Guidelines +- **Semantic Commits:** Use conventional commit format +- **Scope:** Include relevant scope (cli, runner, docs, etc.) +- **Description:** Clear, concise description of changes + +```bash +# Good commit examples +git commit -m "feat(cli): add SLURM runner support for HPC clusters" +git commit -m "fix(ssh): handle connection timeouts gracefully" +git commit -m "docs: update distributed execution examples" +git commit -m "test: add integration tests for Kubernetes runner" ``` -### Tools - Trace Libraries of ROCm +#### Pull Request Process +1. **Create Feature Branch:** `git checkout -b feature/your-feature-name` +2. **Write Tests:** Add comprehensive tests for new functionality +3. **Update Documentation:** Update relevant documentation and examples +4. **Run Quality Checks:** Ensure all tests pass and code quality checks succeed +5. **Create Pull Request:** Use the provided PR template +6. **Address Reviews:** Respond to review feedback promptly + +### 🎯 Areas for Contribution + +#### High Priority +- **Additional Runners:** Support for new distributed execution platforms +- **Performance Optimization:** Improve execution speed and resource utilization +- **Error Handling:** Enhanced error messages and recovery mechanisms +- **Testing:** Expand test coverage for edge cases and integration scenarios + +#### Medium Priority +- **CLI Enhancements:** New commands and improved user experience +- **Documentation:** Tutorials, guides, and API documentation improvements +- **Monitoring:** Advanced metrics and observability features +- **Configuration:** Simplified configuration management -Trace library usage of running LLMs and Deep Learning models. A demo of running model with tracing rocBlas. 
+#### Low Priority +- **UI Improvements:** Enhanced terminal output and progress indicators +- **Utilities:** Helper scripts and development tools +- **Examples:** Additional deployment scenarios and use cases +### � Bug Reports + +When reporting bugs, please include: + +```bash +# System information +madengine-cli --version +python --version +docker --version + +# Error reproduction +madengine-cli run --tags failing_model --verbose > debug.log 2>&1 + +# Environment details +madengine-cli run --tags dummy --sys-env-details --summary-output env_info.json ``` -(venv) test-node:~/MAD$ madengine run --tags pyt_huggingface_bert --additional-context "{'guest_os': 'UBUNTU','tools': [{'name':'rocblas_trace'}]}" + +**Bug Report Template:** +- **Description:** Clear description of the issue +- **Steps to Reproduce:** Minimal steps to reproduce the problem +- **Expected Behavior:** What should happen +- **Actual Behavior:** What actually happens +- **Environment:** OS, Python version, Docker version, madengine version +- **Logs:** Relevant log output with `--verbose` enabled + +### 💡 Feature Requests + +For feature requests, please provide: +- **Use Case:** Detailed description of the use case +- **Proposed Solution:** How you envision the feature working +- **Alternatives:** Any alternative solutions you've considered +- **Impact:** Who would benefit from this feature + +### 🏗️ Development Environment + +#### System Requirements +- **Python 3.8+** with pip and venv +- **Docker** with GPU support (for testing containerized execution) +- **Git** for version control +- **Optional:** Kubernetes cluster, SLURM cluster, or SSH-accessible nodes for distributed testing + +#### IDE Configuration +**VS Code (Recommended):** +```json +// .vscode/settings.json +{ + "python.defaultInterpreterPath": "./venv/bin/python", + "python.linting.enabled": true, + "python.linting.flake8Enabled": true, + "python.formatting.provider": "black", + "python.sortImports.args": ["--profile", "black"] +} ``` -## Environment Variables +**PyCharm:** +- Set interpreter to project venv +- Enable Black as code formatter +- Configure isort with Black profile +- Enable flake8 as linter -Madengine also exposes environment variables to allow for models location setting or data loading at DLM/MAD runtime. +### 🔧 Architecture Understanding -| Field | Description | -|-----------------------------| ----------------------------------------------------------------------------------| -| MODEL_DIR | the location of models dir | -| PUBLIC_GITHUB_ROCM_KEY | username and token of GitHub | -| MAD_AWS_S3 | the username and password of AWS S3 | -| NAS_NODES | the list of credentials of NAS Nodes | +#### Key Components +- **CLI Layer:** Typer+Rich for modern CLI interface (`mad_cli.py`) +- **Orchestrator:** Core workflow orchestration (`orchestrator.py`) +- **Runners:** Distributed execution implementations (`runners/`) +- **Discovery:** Model discovery system (`discover.py`) +- **Container:** Docker integration (`container_runner.py`) -Examples for running models using environment variables. 
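+A useful entry point when reading the core is the `Context` class in `src/madengine/core/context.py`, which separates build-only initialization (no GPU detection) from runtime initialization (full system and GPU detection). The sketch below mirrors the constructor signature in this codebase; the context values are illustrative, and on a build-only node GPU-specific build arguments are expected to come from `--additional-context` rather than hardware detection.
+
+```python
+# Sketch: constructing a build-only Context (no GPU detection),
+# mirroring madengine.core.context.Context. Values are illustrative.
+from madengine.core.context import Context
+
+build_ctx = Context(
+    additional_context="{'guest_os': 'UBUNTU', 'docker_build_arg': {'MAD_SYSTEM_GPU_ARCHITECTURE': 'gfx90a'}}",
+    build_only_mode=True,  # skip GPU/system detection on nodes without GPUs
+)
+print(build_ctx.ctx["docker_build_arg"])  # build args supplied via additional context
+```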
+#### Testing Philosophy +- **Unit Tests:** Fast, isolated tests for individual components +- **Integration Tests:** End-to-end workflow testing +- **Mock-Heavy:** Extensive use of mocks for external dependencies +- **GPU-Aware:** Tests automatically adapt to available hardware + +### 📞 Getting Help + +- **GitHub Issues:** https://github.com/ROCm/madengine/issues +- **Discussions:** https://github.com/ROCm/madengine/discussions +- **ROCm Community:** https://rocm.docs.amd.com/en/latest/ +- **Documentation:** https://github.com/ROCm/madengine/tree/main/docs + +### 🙏 Recognition + +Contributors are recognized in: +- **CHANGELOG.md:** All contributions documented +- **GitHub Contributors:** Automatic recognition +- **Release Notes:** Major contributions highlighted +- **Documentation:** Author attribution where appropriate + +## 📄 License + +This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. + +--- + +## 📖 Additional Resources + +### SLURM Runner Quick Reference + +For users working with HPC clusters, the SLURM runner provides a two-step workflow: + +#### Step 1: Generate SLURM Configuration ```bash -# Apply AWS S3 -MAD_AWS_S3='{"USERNAME":"username","PASSWORD":"password"}' madengine run --tags dummy_data_aws --live-output +madengine-cli generate slurm \ + --manifest-file build_manifest.json \ + --environment prod \ + --output-dir slurm-setup +``` -# Apply customized NAS -NAS_NODES=[{"HOST":"hostname","PORT":"22","USERNAME":"username","PASSWORD":"password"}] madengine run --tags dummy_data_austin_nas --live-output +#### Step 2: Execute SLURM Workload +```bash +madengine-cli runner slurm \ + --inventory slurm_inventory.yml \ + --job-scripts-dir slurm-setup \ + --timeout 14400 ``` -## Unit Test -Run pytest to validate unit tests of MAD Engine. +**Key Features:** +- Job arrays for parallel model execution +- Automated MAD environment setup on shared filesystems +- Integration with HPC module systems +- Resource management across SLURM partitions + +### Legacy Command Reference + +For compatibility with existing workflows: + +```bash +# Model execution +madengine run --tags pyt_huggingface_bert --live-output + +# Model discovery +madengine discover --tags dummy2:dummy_2 +# Report generation +madengine report to-html --csv-file-path perf.csv + +# Database operations +madengine database create-table ``` -pytest -v -s + +### Migration Guide + +**From Legacy to Modern CLI:** +```bash +# Old approach +madengine run --tags models --live-output + +# New approach +madengine-cli run --tags models --live-output --verbose ``` + +**Key Advantages of Modern CLI:** +- Rich terminal output with progress bars and panels +- Distributed execution across SSH, Ansible, Kubernetes, SLURM +- Advanced error handling with helpful suggestions +- Intelligent workflow detection (build vs. 
run phases) +- Comprehensive validation and configuration management + +--- + +## 🚀 Project Status + +### Current Implementation Status + +✅ **Production Ready** +- Dual CLI interface (traditional + modern) +- Distributed runners (SSH, Ansible, Kubernetes, SLURM) +- Model discovery (static, directory-specific, dynamic) +- Comprehensive error handling with Rich formatting +- Extensive testing infrastructure (95%+ coverage) +- Complete documentation and API reference + +🔄 **Active Development** +- Performance optimization for large-scale deployments +- Enhanced monitoring and observability features +- Configuration management simplification +- Additional runner implementations + +⚠️ **Known Considerations** +- Maintaining dual CLI implementations for compatibility +- Complex configuration file ecosystem +- Some orchestrator methods could benefit from refactoring + +### Roadmap + +**Short Term (Next Release)** +- CLI consolidation while maintaining backward compatibility +- Performance optimizations for distributed execution +- Enhanced error reporting and debugging tools + +**Medium Term** +- Unified configuration management system +- Advanced metrics and monitoring dashboard +- Additional cloud provider integrations + +**Long Term** +- Machine learning model recommendation system +- Automated performance optimization +- Integration with popular ML frameworks and platforms + +--- + +**Note:** Model names and tags cannot contain backslash '/' or colon ':' characters, as these are reserved for the hierarchical tag system (`directory:model:parameter=value`). diff --git a/docs/img/architecture_overview.png b/docs/img/architecture_overview.png new file mode 100755 index 00000000..7bf972b3 Binary files /dev/null and b/docs/img/architecture_overview.png differ diff --git a/docs/img/distributed_workflow.png b/docs/img/distributed_workflow.png new file mode 100755 index 00000000..a6723b44 Binary files /dev/null and b/docs/img/distributed_workflow.png differ diff --git a/pyproject.toml b/pyproject.toml index 00e9011d..bc7e7a26 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,11 @@ dependencies = [ "typing-extensions", "pymongo", "toml", + "typer>=0.9.0", + "rich>=13.0.0", + "click>=8.0.0", + "jinja2>=3.0.0", + "pyyaml>=6.0", ] classifiers = [ "Programming Language :: Python :: 3", @@ -34,12 +39,13 @@ classifiers = [ [project.scripts] madengine = "madengine.mad:main" +madengine-cli = "madengine.mad_cli:cli_main" [project.urls] Homepage = "https://github.com/ROCm/madengine" Issues = "https://github.com/ROCm/madengine/issues" -[project.extras] +[project.optional-dependencies] dev = [ "pytest", "pytest-cov", @@ -47,6 +53,54 @@ dev = [ "pytest-timeout", "pytest-mock", "pytest-asyncio", + "black>=21.0.0", + "flake8", + "mypy>=0.910", + "isort", + "pre-commit", +] +# Optional dependencies for distributed runners +ssh = [ + "paramiko>=2.7.0", + "scp>=0.14.0", +] +ansible = [ + "ansible>=4.0.0", + "ansible-runner>=2.0.0", + "PyYAML>=6.0", +] +kubernetes = [ + "kubernetes>=20.0.0", + "PyYAML>=6.0", +] +# All runner dependencies +runners = [ + "paramiko>=2.7.0", + "scp>=0.14.0", + "ansible>=4.0.0", + "ansible-runner>=2.0.0", + "kubernetes>=20.0.0", + "PyYAML>=6.0", +] +# Complete development environment +all = [ + "paramiko>=2.7.0", + "scp>=0.14.0", + "ansible>=4.0.0", + "ansible-runner>=2.0.0", + "kubernetes>=20.0.0", + "PyYAML>=6.0", + "pytest", + "pytest-cov", + "pytest-xdist", + "pytest-timeout", + "pytest-mock", + "pytest-asyncio", + "black>=21.0.0", + "flake8", + "mypy>=0.910", + "isort", + 
"pre-commit", ] [tool.hatch.build.targets.wheel] @@ -68,3 +122,87 @@ regex = "v(?P.*)" distance = "{base_version}.post{distance}+{vcs}{rev}" dirty = "{base_version}+d{build_date:%Y%m%d}" distance-dirty = "{base_version}.post{distance}+{vcs}{rev}.d{build_date:%Y%m%d}" + +# Code formatting and linting configuration +[tool.black] +line-length = 88 +target-version = ['py38', 'py39', 'py310', 'py311'] +include = '\.pyi?$' +extend-exclude = ''' +/( + # directories + \.eggs + | \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | build + | dist +)/ +''' + +[tool.isort] +profile = "black" +multi_line_output = 3 +line_length = 88 +known_first_party = ["madengine"] +known_third_party = ["pytest", "pandas", "numpy", "sqlalchemy"] + +[tool.mypy] +python_version = "3.8" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = false +disallow_incomplete_defs = false +check_untyped_defs = true +disallow_untyped_decorators = false +no_implicit_optional = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_no_return = true +warn_unreachable = true +strict_equality = true + +[[tool.mypy.overrides]] +module = [ + "paramiko.*", + "pymongo.*", + "mysql.connector.*", + "pymysql.*", + "toml.*", + "jsondiff.*", + "git.*", +] +ignore_missing_imports = true + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_paths = ["src"] +addopts = "-v --tb=short" +markers = [ + "slow: marks tests as slow (deselect with '-m \"not slow\"')", + "integration: marks tests as integration tests", +] + +[tool.coverage.run] +source = ["src/madengine"] +omit = [ + "*/tests/*", + "*/test_*", + "*/__pycache__/*", +] + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "def __repr__", + "if self.debug:", + "if settings.DEBUG", + "raise AssertionError", + "raise NotImplementedError", + "if 0:", + "if __name__ == .__main__.:", + "class .*\\bProtocol\\):", + "@(abc\\.)?abstractmethod", +] diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..a45628ee --- /dev/null +++ b/setup.py @@ -0,0 +1,307 @@ +#!/usr/bin/env python3 +""" +Simplified setup.py for madengine + +This setup.py provides compatibility with environments that require traditional +setup.py installations while reading configuration from pyproject.toml. + +For modern installations, prefer: + pip install . + python -m build + pip install -e .[dev] + +For legacy compatibility: + python setup.py install + python setup.py develop + +Copyright (c) Advanced Micro Devices, Inc. All rights reserved. 
+""" + +import sys +from pathlib import Path + +try: + from setuptools import setup, find_packages +except ImportError: + print("setuptools is required for setup.py") + print("Install it using: pip install setuptools") + sys.exit(1) + +def read_readme(readme_file="README.md"): + """Read README.md file for long description.""" + readme_path = Path(__file__).parent / readme_file + if readme_path.exists(): + with open(readme_path, "r", encoding="utf-8") as f: + return f.read() + + # Fallback to README.md if specified file doesn't exist + fallback_path = Path(__file__).parent / "README.md" + if fallback_path.exists() and readme_file != "README.md": + with open(fallback_path, "r", encoding="utf-8") as f: + return f.read() + + return "" + +def get_config_from_pyproject(): + """Read configuration from pyproject.toml.""" + try: + import tomllib + except ImportError: + try: + import tomli as tomllib + except ImportError: + try: + import toml as tomllib_alt + def load(f): + if hasattr(f, 'read'): + content = f.read() + if isinstance(content, bytes): + content = content.decode('utf-8') + return tomllib_alt.loads(content) + else: + return tomllib_alt.load(f) + tomllib.load = load + except ImportError: + print("Warning: No TOML library found. Using fallback configuration.") + return get_fallback_config() + + pyproject_path = Path(__file__).parent / "pyproject.toml" + if not pyproject_path.exists(): + print("Warning: pyproject.toml not found. Using fallback configuration.") + return get_fallback_config() + + try: + with open(pyproject_path, "rb") as f: + data = tomllib.load(f) + + project = data.get("project", {}) + + # Extract configuration + config = { + "name": project.get("name", "madengine"), + "description": project.get("description", "MAD Engine"), + "authors": project.get("authors", []), + "dependencies": project.get("dependencies", []), + "optional_dependencies": project.get("optional-dependencies", {}), + "requires_python": project.get("requires-python", ">=3.8"), + "classifiers": project.get("classifiers", []), + "urls": project.get("urls", {}), + "scripts": project.get("scripts", {}), + "readme": project.get("readme", "README.md"), + } + + return config + + except Exception as e: + print(f"Warning: Could not read pyproject.toml: {e}") + return get_fallback_config() + +def get_fallback_config(): + """Fallback configuration if pyproject.toml cannot be read.""" + return { + "name": "madengine", + "description": "MAD Engine is a set of interfaces to run various AI models from public MAD.", + "authors": [{"name": "Advanced Micro Devices", "email": "mad.support@amd.com"}], + "dependencies": [ + "pandas", "GitPython", "jsondiff", "sqlalchemy", "setuptools-rust", + "paramiko", "mysql-connector-python", "pymysql", "tqdm", "pytest", + "typing-extensions", "pymongo", "toml", + ], + "optional_dependencies": { + "dev": [ + "pytest", "pytest-cov", "pytest-xdist", "pytest-timeout", + "pytest-mock", "pytest-asyncio", + ] + }, + "requires_python": ">=3.8", + "classifiers": [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + ], + "urls": { + "Homepage": "https://github.com/ROCm/madengine", + "Issues": "https://github.com/ROCm/madengine/issues", + }, + "scripts": { + "madengine": "madengine.mad:main" + }, + } + +def get_version(): + """Get version from git tags or fallback to a default.""" + try: + import subprocess + import re + + # Try to get version from git describe first (more accurate) + try: + result = subprocess.run( + ["git", 
"describe", "--tags", "--dirty", "--always", "--long"], + capture_output=True, text=True, timeout=10, cwd=Path(__file__).parent + ) + if result.returncode == 0: + version_str = result.stdout.strip() + + # Handle case where there are no tags yet + if not version_str or len(version_str.split('-')) < 3: + # Try to get just the commit hash + result = subprocess.run( + ["git", "rev-parse", "--short", "HEAD"], + capture_output=True, text=True, timeout=10, cwd=Path(__file__).parent + ) + if result.returncode == 0: + commit = result.stdout.strip() + # Check if dirty + dirty_result = subprocess.run( + ["git", "diff-index", "--quiet", "HEAD", "--"], + capture_output=True, cwd=Path(__file__).parent + ) + is_dirty = dirty_result.returncode != 0 + if is_dirty: + return f"1.0.0.dev0+g{commit}.dirty" + else: + return f"1.0.0.dev0+g{commit}" + + # Clean up the version string to be PEP 440 compliant + if version_str.startswith('v'): + version_str = version_str[1:] + + # Handle patterns like "1.0.0-5-g1234567" or "1.0.0-5-g1234567-dirty" + match = re.match(r'^([^-]+)-(\d+)-g([a-f0-9]+)(-dirty)?$', version_str) + if match: + base_version, distance, commit, dirty = match.groups() + if distance == "0": + # Exact tag match + if dirty: + return f"{base_version}+dirty" + else: + return base_version + else: + # Post-release version + version_str = f"{base_version}.post{distance}+g{commit}" + if dirty: + version_str += ".dirty" + return version_str + + # Handle case where we just have a commit hash (no tags) + if re.match(r'^[a-f0-9]+(-dirty)?$', version_str): + clean_hash = version_str.replace('-dirty', '') + if '-dirty' in version_str: + return f"1.0.0.dev0+g{clean_hash}.dirty" + else: + return f"1.0.0.dev0+g{clean_hash}" + + return version_str + + except (subprocess.SubprocessError, FileNotFoundError): + pass + + # Fallback to short commit hash + result = subprocess.run( + ["git", "rev-parse", "--short", "HEAD"], + capture_output=True, text=True, timeout=10, cwd=Path(__file__).parent + ) + if result.returncode == 0: + commit = result.stdout.strip() + return f"1.0.0.dev0+g{commit}" + + except Exception: + pass + + # Final fallback + return "1.0.0.dev0" + +def main(): + """Main setup function.""" + try: + config = get_config_from_pyproject() + + # Extract author information + authors = config.get("authors", []) + if authors: + author_name = authors[0].get("name", "Advanced Micro Devices") + author_email = authors[0].get("email", "mad.support@amd.com") + else: + author_name = "Advanced Micro Devices" + author_email = "mad.support@amd.com" + + # Extract scripts/entry points + scripts = config.get("scripts", {}) + entry_points = {"console_scripts": []} + for script_name, module_path in scripts.items(): + entry_points["console_scripts"].append(f"{script_name}={module_path}") + + # Find all packages + packages = find_packages(where="src") + if not packages: + print("Warning: No packages found in src/ directory") + # Fallback: look for madengine package specifically + import os + src_path = Path(__file__).parent / "src" + if (src_path / "madengine").exists(): + packages = ["madengine"] + [ + f"madengine.{name}" for name in find_packages(where="src/madengine") + ] + + # Setup package data to include scripts + package_data = {"madengine": ["scripts/**/*"]} + + # Check if scripts directory exists and add patterns accordingly + scripts_path = Path(__file__).parent / "src" / "madengine" / "scripts" + if scripts_path.exists(): + # Add more specific patterns to ensure all script files are included + 
package_data["madengine"].extend([ + "scripts/*", + "scripts/*/*", + "scripts/*/*/*", + "scripts/*/*/*/*", + ]) + + # Get version + version = get_version() + + # Setup configuration + setup_kwargs = { + "name": config["name"], + "version": version, + "author": author_name, + "author_email": author_email, + "description": config["description"], + "long_description": read_readme(config.get("readme", "README.md")), + "long_description_content_type": "text/markdown", + "url": config["urls"].get("Homepage", "https://github.com/ROCm/madengine"), + "project_urls": config["urls"], + "package_dir": {"": "src"}, + "packages": packages, + "install_requires": config["dependencies"], + "extras_require": config["optional_dependencies"], + "python_requires": config["requires_python"], + "entry_points": entry_points if entry_points["console_scripts"] else None, + "classifiers": config["classifiers"], + "include_package_data": True, + "package_data": package_data, + "zip_safe": False, + "platforms": ["any"], + } + + # Remove None values to avoid setuptools warnings + setup_kwargs = {k: v for k, v in setup_kwargs.items() if v is not None} + + # Print some info for debugging + if len(sys.argv) > 1 and any(arg in sys.argv for arg in ["--version", "--help", "--help-commands"]): + print(f"madengine version: {version}") + print(f"Found {len(packages)} packages") + if entry_points and entry_points["console_scripts"]: + print(f"Console scripts: {', '.join(entry_points['console_scripts'])}") + + setup(**setup_kwargs) + + except Exception as e: + print(f"Error during setup: {e}") + import traceback + traceback.print_exc() + sys.exit(1) + +if __name__ == "__main__": + main() diff --git a/src/madengine/__init__.py b/src/madengine/__init__.py index 8db410f6..f667022e 100644 --- a/src/madengine/__init__.py +++ b/src/madengine/__init__.py @@ -1,26 +1,22 @@ """ -Copyright (c) Advanced Micro Devices, Inc. All rights reserved. -""" -r''' -# What is MADEngine? +MADEngine - AI Models automation and dashboarding command-line tool. -An AI Models automation and dashboarding command-line tool to run LLMs and Deep Learning models locally or remotelly with CI. -The MADEngine library is to support AI automation having following features: +An AI Models automation and dashboarding command-line tool to run LLMs and Deep Learning +models locally or remotely with CI. The MADEngine library supports AI automation with: - AI Models run reliably on supported platforms and drive software quality -- Simple, minimalistic, out-of-the-box solution that enable confidence on hardware and software stack +- Simple, minimalistic, out-of-the-box solution that enables confidence on hardware and software stack - Real-time, audience-relevant AI Models performance metrics tracking, presented in clear, intuitive manner - Best-practices for handling internal projects and external open-source projects +Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +""" +from importlib.metadata import version, PackageNotFoundError -.. include:: ../../docs/how-to-build.md -.. include:: ../../docs/how-to-quick-start.md -.. include:: ../../docs/how-to-provide-contexts.md -.. include:: ../../docs/how-to-profile-a-model.md -.. include:: ../../docs/how-to-collect-competitive-library-perf.md -.. 
include:: ../../docs/how-to-contribute.md - -''' -from importlib.metadata import version +try: + __version__ = version("madengine") +except PackageNotFoundError: + # Package is not installed, use a default version + __version__ = "dev" -__version__ = version("madengine") \ No newline at end of file +__all__ = ["__version__"] diff --git a/src/madengine/core/console.py b/src/madengine/core/console.py index 9340924a..cee93c47 100644 --- a/src/madengine/core/console.py +++ b/src/madengine/core/console.py @@ -8,24 +8,23 @@ # built-in modules import subprocess import typing +import re + # third-party modules import typing_extensions class Console: """Class to run console commands. - + Attributes: shellVerbose (bool): The shell verbose flag. live_output (bool): The live output flag. """ - def __init__( - self, - shellVerbose: bool=True, - live_output: bool=False - ) -> None: + + def __init__(self, shellVerbose: bool = True, live_output: bool = False) -> None: """Constructor of the Console class. - + Args: shellVerbose (bool): The shell verbose flag. live_output (bool): The live output flag. @@ -33,17 +32,84 @@ def __init__( self.shellVerbose = shellVerbose self.live_output = live_output + def _highlight_docker_operations(self, command: str) -> str: + """Highlight docker push/pull/build/run operations for better visibility. + + Args: + command (str): The command to potentially highlight. + + Returns: + str: The highlighted command if it's a docker operation. + """ + # Check if this is a docker operation + docker_push_pattern = r"^docker\s+push\s+" + docker_pull_pattern = r"^docker\s+pull\s+" + docker_build_pattern = r"^docker\s+build\s+" + docker_run_pattern = r"^docker\s+run\s+" + + if re.match(docker_push_pattern, command, re.IGNORECASE): + return f"\n{'='*80}\n🚀 DOCKER PUSH OPERATION: {command}\n{'='*80}" + elif re.match(docker_pull_pattern, command, re.IGNORECASE): + return f"\n{'='*80}\n📥 DOCKER PULL OPERATION: {command}\n{'='*80}" + elif re.match(docker_build_pattern, command, re.IGNORECASE): + return f"\n{'='*80}\n🔨 DOCKER BUILD OPERATION: {command}\n{'='*80}" + elif re.match(docker_run_pattern, command, re.IGNORECASE): + return f"\n{'='*80}\n🏃 DOCKER RUN OPERATION: {command}\n{'='*80}" + + return command + + def _show_docker_completion(self, command: str, success: bool = True) -> None: + """Show completion message for docker operations. + + Args: + command (str): The command that was executed. + success (bool): Whether the operation was successful. 
+ """ + docker_push_pattern = r"^docker\s+push\s+" + docker_pull_pattern = r"^docker\s+pull\s+" + docker_build_pattern = r"^docker\s+build\s+" + docker_run_pattern = r"^docker\s+run\s+" + + if re.match(docker_push_pattern, command, re.IGNORECASE): + if success: + print(f"✅ DOCKER PUSH COMPLETED SUCCESSFULLY") + print(f"{'='*80}\n") + else: + print(f"❌ DOCKER PUSH FAILED") + print(f"{'='*80}\n") + elif re.match(docker_pull_pattern, command, re.IGNORECASE): + if success: + print(f"✅ DOCKER PULL COMPLETED SUCCESSFULLY") + print(f"{'='*80}\n") + else: + print(f"❌ DOCKER PULL FAILED") + print(f"{'='*80}\n") + elif re.match(docker_build_pattern, command, re.IGNORECASE): + if success: + print(f"✅ DOCKER BUILD COMPLETED SUCCESSFULLY") + print(f"{'='*80}\n") + else: + print(f"❌ DOCKER BUILD FAILED") + print(f"{'='*80}\n") + elif re.match(docker_run_pattern, command, re.IGNORECASE): + if success: + print(f"✅ DOCKER RUN COMPLETED SUCCESSFULLY") + print(f"{'='*80}\n") + else: + print(f"❌ DOCKER RUN FAILED") + print(f"{'='*80}\n") + def sh( - self, - command: str, - canFail: bool=False, - timeout: int=60, - secret: bool=False, - prefix: str="", - env: typing.Optional[typing.Dict[str, str]]=None - ) -> str: + self, + command: str, + canFail: bool = False, + timeout: int = 60, + secret: bool = False, + prefix: str = "", + env: typing.Optional[typing.Dict[str, str]] = None, + ) -> str: """Run shell command. - + Args: command (str): The shell command. canFail (bool): The flag to allow failure. @@ -51,7 +117,7 @@ def sh( secret (bool): The flag to hide the command. prefix (str): The prefix of the output. env (typing_extensions.TypedDict): The environment variables. - + Returns: str: The output of the shell command. @@ -60,7 +126,8 @@ def sh( """ # Print the command if shellVerbose is True if self.shellVerbose and not secret: - print("> " + command, flush=True) + highlighted_command = self._highlight_docker_operations(command) + print("> " + highlighted_command, flush=True) # Run the shell command proc = subprocess.Popen( @@ -80,7 +147,12 @@ def sh( outs, errs = proc.communicate(timeout=timeout) else: outs = [] - for stdout_line in iter(lambda: proc.stdout.readline().encode('utf-8', errors='replace').decode('utf-8', errors='replace'), ""): + for stdout_line in iter( + lambda: proc.stdout.readline() + .encode("utf-8", errors="replace") + .decode("utf-8", errors="replace"), + "", + ): print(prefix + stdout_line, end="") outs.append(stdout_line) outs = "".join(outs) @@ -89,8 +161,14 @@ def sh( except subprocess.TimeoutExpired as exc: proc.kill() raise RuntimeError("Console script timeout") from exc - + # Check for failure + success = proc.returncode == 0 + + # Show docker operation completion status + if not secret: + self._show_docker_completion(command, success) + if proc.returncode != 0: if not canFail: if not secret: @@ -102,11 +180,9 @@ def sh( ) else: raise RuntimeError( - "Subprocess '" - + secret - + "' failed with exit code " + "Subprocess '***HIDDEN COMMAND***' failed with exit code " + str(proc.returncode) ) - + # Return the output return outs.strip() diff --git a/src/madengine/core/constants.py b/src/madengine/core/constants.py index c0cbd5c0..2bba883f 100644 --- a/src/madengine/core/constants.py +++ b/src/madengine/core/constants.py @@ -3,89 +3,219 @@ This module provides the constants used in the MAD Engine. 
+Environment Variables: + - MAD_VERBOSE_CONFIG: Set to "true" to enable verbose configuration logging + - MAD_SETUP_MODEL_DIR: Set to "true" to enable automatic MODEL_DIR setup during import + - MODEL_DIR: Path to model directory to copy to current working directory + - MAD_MINIO: JSON string with MinIO configuration + - MAD_AWS_S3: JSON string with AWS S3 configuration + - NAS_NODES: JSON string with NAS nodes configuration + - PUBLIC_GITHUB_ROCM_KEY: JSON string with GitHub token configuration + +Configuration Loading: + All configuration constants follow a priority order: + 1. Environment variables (as JSON strings) + 2. credential.json file + 3. Built-in defaults + + Invalid JSON in environment variables will fall back to defaults with error logging. + Copyright (c) Advanced Micro Devices, Inc. All rights reserved. """ # built-in modules import os import json +import logging + + +# Utility function for optional verbose logging of configuration +def _log_config_info(message: str, force_print: bool = False): + """Log configuration information either to logger or print if specified.""" + if force_print or os.environ.get("MAD_VERBOSE_CONFIG", "").lower() == "true": + print(message) + else: + logging.debug(message) + + # third-party modules from madengine.core.console import Console # Get the model directory, if it is not set, set it to None. MODEL_DIR = os.environ.get("MODEL_DIR") - -# MADEngine update -if MODEL_DIR: - # Copy MODEL_DIR to the current working directory. - cwd_path = os.getcwd() - print(f"Current working directory: {cwd_path}") - console = Console(live_output=True) - # copy the MODEL_DIR to the current working directory - console.sh(f"cp -vLR --preserve=all {MODEL_DIR}/* {cwd_path}") - print(f"Model dir: {MODEL_DIR} copied to current dir: {cwd_path}") - -# MADEngine update + + +def _setup_model_dir(): + """Setup model directory if MODEL_DIR environment variable is set.""" + if MODEL_DIR: + # Copy MODEL_DIR to the current working directory. 
+ cwd_path = os.getcwd() + _log_config_info(f"Current working directory: {cwd_path}") + console = Console(live_output=True) + # copy the MODEL_DIR to the current working directory + console.sh(f"cp -vLR --preserve=all {MODEL_DIR}/* {cwd_path}") + _log_config_info(f"Model dir: {MODEL_DIR} copied to current dir: {cwd_path}") + + +# Only setup model directory if explicitly requested (when not just importing for constants) +if os.environ.get("MAD_SETUP_MODEL_DIR", "").lower() == "true": + _setup_model_dir() + +# MADEngine credentials configuration CRED_FILE = "credential.json" -try: - # read credentials - with open(CRED_FILE) as f: - CREDS = json.load(f) -except FileNotFoundError: - CREDS = {} -if "NAS_NODES" not in os.environ: - if "NAS_NODES" in CREDS: - NAS_NODES = CREDS["NAS_NODES"] +def _load_credentials(): + """Load credentials from file with proper error handling.""" + try: + # read credentials + with open(CRED_FILE) as f: + creds = json.load(f) + _log_config_info(f"Credentials loaded from {CRED_FILE}") + return creds + except FileNotFoundError: + _log_config_info(f"Credentials file {CRED_FILE} not found, using defaults") + return {} + except json.JSONDecodeError as e: + _log_config_info(f"Error parsing {CRED_FILE}: {e}, using defaults") + return {} + except Exception as e: + _log_config_info(f"Unexpected error loading {CRED_FILE}: {e}, using defaults") + return {} + + +CREDS = _load_credentials() + + +def _get_nas_nodes(): + """Initialize NAS_NODES configuration.""" + if "NAS_NODES" not in os.environ: + _log_config_info("NAS_NODES environment variable is not set.") + if "NAS_NODES" in CREDS: + _log_config_info("NAS_NODES loaded from credentials file.") + return CREDS["NAS_NODES"] + else: + _log_config_info("NAS_NODES is using default values.") + return [ + { + "NAME": "DEFAULT", + "HOST": "localhost", + "PORT": 22, + "USERNAME": "username", + "PASSWORD": "password", + } + ] else: - NAS_NODES = [{ - "NAME": "DEFAULT", - "HOST": "localhost", - "PORT": 22, - "USERNAME": "username", - "PASSWORD": "password", - }] -else: - NAS_NODES = json.loads(os.environ["NAS_NODES"]) - -# Check the MAD_AWS_S3 environment variable which is a dict, if it is not set, set its element to default values. -if "MAD_AWS_S3" not in os.environ: - # Check if the MAD_AWS_S3 is in the credentials.json file. 
- if "MAD_AWS_S3" in CREDS: - MAD_AWS_S3 = CREDS["MAD_AWS_S3"] + _log_config_info("NAS_NODES is loaded from env variables.") + try: + return json.loads(os.environ["NAS_NODES"]) + except json.JSONDecodeError as e: + _log_config_info( + f"Error parsing NAS_NODES environment variable: {e}, using defaults" + ) + return [ + { + "NAME": "DEFAULT", + "HOST": "localhost", + "PORT": 22, + "USERNAME": "username", + "PASSWORD": "password", + } + ] + + +NAS_NODES = _get_nas_nodes() + + +def _get_mad_aws_s3(): + """Initialize MAD_AWS_S3 configuration.""" + if "MAD_AWS_S3" not in os.environ: + _log_config_info("MAD_AWS_S3 environment variable is not set.") + if "MAD_AWS_S3" in CREDS: + _log_config_info("MAD_AWS_S3 loaded from credentials file.") + return CREDS["MAD_AWS_S3"] + else: + _log_config_info("MAD_AWS_S3 is using default values.") + return { + "USERNAME": None, + "PASSWORD": None, + } else: - MAD_AWS_S3 = { - "USERNAME": None, - "PASSWORD": None, - } -else: - MAD_AWS_S3 = json.loads(os.environ["MAD_AWS_S3"]) + _log_config_info("MAD_AWS_S3 is loaded from env variables.") + try: + return json.loads(os.environ["MAD_AWS_S3"]) + except json.JSONDecodeError as e: + _log_config_info( + f"Error parsing MAD_AWS_S3 environment variable: {e}, using defaults" + ) + return { + "USERNAME": None, + "PASSWORD": None, + } + + +MAD_AWS_S3 = _get_mad_aws_s3() + # Check the MAD_MINIO environment variable which is a dict. -if "MAD_MINIO" not in os.environ: - print("MAD_MINIO environment variable is not set.") - if "MAD_MINIO" in CREDS: - MAD_MINIO = CREDS["MAD_MINIO"] +def _get_mad_minio(): + """Initialize MAD_MINIO configuration.""" + if "MAD_MINIO" not in os.environ: + _log_config_info("MAD_MINIO environment variable is not set.") + if "MAD_MINIO" in CREDS: + _log_config_info("MAD_MINIO loaded from credentials file.") + return CREDS["MAD_MINIO"] + else: + _log_config_info("MAD_MINIO is using default values.") + return { + "USERNAME": None, + "PASSWORD": None, + "MINIO_ENDPOINT": "http://localhost:9000", + "AWS_ENDPOINT_URL_S3": "http://localhost:9000", + } else: - print("MAD_MINIO is using default values.") - MAD_MINIO = { - "USERNAME": None, - "PASSWORD": None, - "MINIO_ENDPOINT": "http://localhost:9000", - "AWS_ENDPOINT_URL_S3": "http://localhost:9000", - } -else: - print("MAD_MINIO is loaded from env variables.") - MAD_MINIO = json.loads(os.environ["MAD_MINIO"]) - -# Check the auth GitHub token environment variable which is a dict, if it is not set, set it to None. 
-if "PUBLIC_GITHUB_ROCM_KEY" not in os.environ: - if "PUBLIC_GITHUB_ROCM_KEY" in CREDS: - PUBLIC_GITHUB_ROCM_KEY = CREDS["PUBLIC_GITHUB_ROCM_KEY"] + _log_config_info("MAD_MINIO is loaded from env variables.") + try: + return json.loads(os.environ["MAD_MINIO"]) + except json.JSONDecodeError as e: + _log_config_info( + f"Error parsing MAD_MINIO environment variable: {e}, using defaults" + ) + return { + "USERNAME": None, + "PASSWORD": None, + "MINIO_ENDPOINT": "http://localhost:9000", + "AWS_ENDPOINT_URL_S3": "http://localhost:9000", + } + + +MAD_MINIO = _get_mad_minio() + + +def _get_public_github_rocm_key(): + """Initialize PUBLIC_GITHUB_ROCM_KEY configuration.""" + if "PUBLIC_GITHUB_ROCM_KEY" not in os.environ: + _log_config_info("PUBLIC_GITHUB_ROCM_KEY environment variable is not set.") + if "PUBLIC_GITHUB_ROCM_KEY" in CREDS: + _log_config_info("PUBLIC_GITHUB_ROCM_KEY loaded from credentials file.") + return CREDS["PUBLIC_GITHUB_ROCM_KEY"] + else: + _log_config_info("PUBLIC_GITHUB_ROCM_KEY is using default values.") + return { + "username": None, + "token": None, + } else: - PUBLIC_GITHUB_ROCM_KEY = { - "username": None, - "token": None, - } -else: - PUBLIC_GITHUB_ROCM_KEY = json.loads(os.environ["PUBLIC_GITHUB_ROCM_KEY"]) + _log_config_info("PUBLIC_GITHUB_ROCM_KEY is loaded from env variables.") + try: + return json.loads(os.environ["PUBLIC_GITHUB_ROCM_KEY"]) + except json.JSONDecodeError as e: + _log_config_info( + f"Error parsing PUBLIC_GITHUB_ROCM_KEY environment variable: {e}, using defaults" + ) + return { + "username": None, + "token": None, + } + + +PUBLIC_GITHUB_ROCM_KEY = _get_public_github_rocm_key() diff --git a/src/madengine/core/context.py b/src/madengine/core/context.py index b1c7c225..d5f06bce 100644 --- a/src/madengine/core/context.py +++ b/src/madengine/core/context.py @@ -18,6 +18,7 @@ import os import re import typing + # third-party modules from madengine.core.console import Console from madengine.utils.gpu_validator import validate_rocm_installation, GPUInstallationError @@ -25,11 +26,11 @@ def update_dict(d: typing.Dict, u: typing.Dict) -> typing.Dict: """Update dictionary. - + Args: d: The dictionary. u: The update dictionary. - + Returns: dict: The updated dictionary. """ @@ -45,11 +46,14 @@ def update_dict(d: typing.Dict, u: typing.Dict) -> typing.Dict: class Context: """Class to determine context. - + Attributes: console: The console. ctx: The context. - + _gpu_context_initialized: Flag to track if GPU context is initialized. + _system_context_initialized: Flag to track if system context is initialized. + _build_only_mode: Flag to indicate if running in build-only mode. + Methods: get_ctx_test: Get context test. get_gpu_vendor: Get GPU vendor. @@ -60,110 +64,275 @@ class Context: get_docker_gpus: Get Docker GPUs. get_gpu_renderD_nodes: Get GPU renderD nodes. set_multi_node_runner: Sets multi-node runner context. + init_system_context: Initialize system-specific context. + init_gpu_context: Initialize GPU-specific context for runtime. + init_build_context: Initialize build-specific context. + init_runtime_context: Initialize runtime-specific context. + ensure_system_context: Ensure system context is initialized. + ensure_runtime_context: Ensure runtime context is initialized. filter: Filter. """ + def __init__( - self, - additional_context: str=None, - additional_context_file: str=None - ) -> None: + self, + additional_context: str = None, + additional_context_file: str = None, + build_only_mode: bool = False, + ) -> None: """Constructor of the Context class. 
- + Args: additional_context: The additional context. additional_context_file: The additional context file. - + build_only_mode: Whether running in build-only mode (no GPU detection). + Raises: - RuntimeError: If the GPU vendor is not detected. - RuntimeError: If the GPU architecture is not detected. + RuntimeError: If GPU detection fails and not in build-only mode. """ # Initialize the console self.console = Console() + self._gpu_context_initialized = False + self._build_only_mode = build_only_mode + self._system_context_initialized = False - # Initialize the context + # Initialize base context self.ctx = {} - self.ctx["ctx_test"] = self.get_ctx_test() - self.ctx["host_os"] = self.get_host_os() - self.ctx["numa_balancing"] = self.get_numa_balancing() - # Check if NUMA balancing is enabled or disabled. - if self.ctx["numa_balancing"] == "1": - print("Warning: numa balancing is ON ...") - elif self.ctx["numa_balancing"] == "0": - print("Warning: numa balancing is OFF ...") - else: - print("Warning: unknown numa balancing setup ...") - # Keeping gpu_vendor for filterning purposes, if we filter using file names we can get rid of this attribute. - self.ctx["gpu_vendor"] = self.get_gpu_vendor() - - # Validate ROCm installation if AMD GPU is detected - if self.ctx["gpu_vendor"] == "AMD": - try: - validate_rocm_installation(verbose=False, raise_on_error=True) - except GPUInstallationError as e: - print("\n" + "="*70) - print("ERROR: ROCm Installation Validation Failed") - print("="*70) - print(str(e)) - print("="*70) - raise - - # Initialize the docker context + # Initialize docker contexts as empty - will be populated based on mode + self.ctx["docker_build_arg"] = {} self.ctx["docker_env_vars"] = {} - self.ctx["docker_env_vars"]["MAD_GPU_VENDOR"] = self.ctx["gpu_vendor"] - self.ctx["docker_env_vars"]["MAD_SYSTEM_NGPUS"] = self.get_system_ngpus() - self.ctx["docker_env_vars"]["MAD_SYSTEM_GPU_ARCHITECTURE"] = self.get_system_gpu_architecture() - self.ctx["docker_env_vars"]["MAD_SYSTEM_GPU_PRODUCT_NAME"] = self.get_system_gpu_product_name() - self.ctx['docker_env_vars']['MAD_SYSTEM_HIP_VERSION'] = self.get_system_hip_version() - self.ctx["docker_build_arg"] = { - "MAD_SYSTEM_GPU_ARCHITECTURE": self.get_system_gpu_architecture(), - "MAD_SYSTEM_GPU_PRODUCT_NAME": self.get_system_gpu_product_name() - } - self.ctx["docker_gpus"] = self.get_docker_gpus() - self.ctx["gpu_renderDs"] = self.get_gpu_renderD_nodes() - - # Default multi-node configuration - self.ctx['multi_node_args'] = { - 'RUNNER': 'torchrun', - 'MAD_RUNTIME_NGPUS': self.ctx['docker_env_vars']['MAD_SYSTEM_NGPUS'], # Use system's GPU count - 'NNODES': 1, - 'NODE_RANK': 0, - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': 6006, - 'HOST_LIST': '', - 'NCCL_SOCKET_IFNAME': '', - 'GLOO_SOCKET_IFNAME': '' - } - - # Read and update MAD SECRETS env variable + + # Read and update MAD SECRETS env variable (can be used for both build and run) mad_secrets = {} for key in os.environ: if "MAD_SECRETS" in key: mad_secrets[key] = os.environ[key] if mad_secrets: - update_dict(self.ctx['docker_build_arg'], mad_secrets) - update_dict(self.ctx['docker_env_vars'], mad_secrets) - - ## ADD MORE CONTEXTS HERE ## + update_dict(self.ctx["docker_build_arg"], mad_secrets) + update_dict(self.ctx["docker_env_vars"], mad_secrets) - # additional contexts provided in file override detected contexts + # Additional contexts provided in file override detected contexts if additional_context_file: with open(additional_context_file) as f: update_dict(self.ctx, json.load(f)) - # 
additional contexts provided in command-line override detected contexts and contexts in file + # Additional contexts provided in command-line override detected contexts and contexts in file if additional_context: # Convert the string representation of python dictionary to a dictionary. dict_additional_context = ast.literal_eval(additional_context) - update_dict(self.ctx, dict_additional_context) - # Set multi-node runner after context update - self.ctx['docker_env_vars']['MAD_MULTI_NODE_RUNNER'] = self.set_multi_node_runner() + # Initialize context based on mode + # User-provided contexts will not be overridden by detection + if not build_only_mode: + # For full workflow mode, initialize everything (legacy behavior preserved) + self.init_runtime_context() + else: + # For build-only mode, only initialize what's needed for building + self.init_build_context() + + ## ADD MORE CONTEXTS HERE ## + + def init_build_context(self) -> None: + """Initialize build-specific context. + + This method sets up only the context needed for Docker builds, + avoiding GPU detection that would fail on build-only nodes. + System-specific contexts (host_os, numa_balancing, etc.) should be + provided via --additional-context for build-only nodes if needed. + """ + print("Initializing build-only context...") + + # Initialize only essential system contexts if not provided via additional_context + if "host_os" not in self.ctx: + try: + self.ctx["host_os"] = self.get_host_os() + print(f"Detected host OS: {self.ctx['host_os']}") + except Exception as e: + print(f"Warning: Could not detect host OS on build node: {e}") + print( + "Consider providing host_os via --additional-context if needed for build" + ) + + # Don't detect GPU-specific contexts in build-only mode + # These should be provided via additional_context if needed for build args + if "MAD_SYSTEM_GPU_ARCHITECTURE" not in self.ctx.get("docker_build_arg", {}): + print( + "Info: MAD_SYSTEM_GPU_ARCHITECTURE not provided - should be set via --additional-context for GPU-specific builds" + ) + + # Handle multi-node configuration for build phase + self._setup_build_multi_node_context() + + # Don't initialize NUMA balancing check for build-only nodes + # This is runtime-specific and should be handled on execution nodes + + def init_runtime_context(self) -> None: + """Initialize runtime-specific context. + + This method sets up the full context including system and GPU detection + for nodes that will run containers. + """ + print("Initializing runtime context with system and GPU detection...") + + # Initialize system context first + self.init_system_context() + + # Initialize GPU context + self.init_gpu_context() + + # Setup runtime multi-node runner + self._setup_runtime_multi_node_context() + + def init_system_context(self) -> None: + """Initialize system-specific context. + + This method detects system configuration like OS, NUMA balancing, etc. + Should be called on runtime nodes to get actual execution environment context. + """ + if self._system_context_initialized: + return + + print("Detecting system configuration...") + + try: + # Initialize system contexts if not already provided via additional_context + if "ctx_test" not in self.ctx: + self.ctx["ctx_test"] = self.get_ctx_test() + + if "host_os" not in self.ctx: + self.ctx["host_os"] = self.get_host_os() + print(f"Detected host OS: {self.ctx['host_os']}") + + if "numa_balancing" not in self.ctx: + self.ctx["numa_balancing"] = self.get_numa_balancing() + + # Check if NUMA balancing is enabled or disabled. 
+ if self.ctx["numa_balancing"] == "1": + print("Warning: numa balancing is ON ...") + elif self.ctx["numa_balancing"] == "0": + print("Warning: numa balancing is OFF ...") + else: + print("Warning: unknown numa balancing setup ...") + + self._system_context_initialized = True + + except Exception as e: + print(f"Warning: System context detection failed: {e}") + if not self._build_only_mode: + raise RuntimeError( + f"System context detection failed on runtime node: {e}" + ) + + def init_gpu_context(self) -> None: + """Initialize GPU-specific context for runtime. + + This method detects GPU configuration and sets up environment variables + needed for container execution. Should only be called on GPU nodes. + User-provided GPU contexts will not be overridden. + + Raises: + RuntimeError: If GPU detection fails. + """ + if self._gpu_context_initialized: + return + + print("Detecting GPU configuration...") + + try: + # GPU vendor detection - only if not provided by user + if "gpu_vendor" not in self.ctx: + self.ctx["gpu_vendor"] = self.get_gpu_vendor() + print(f"Detected GPU vendor: {self.ctx['gpu_vendor']}") + else: + print(f"Using provided GPU vendor: {self.ctx['gpu_vendor']}") + + # Initialize docker env vars for runtime - only if not already set + if "MAD_GPU_VENDOR" not in self.ctx["docker_env_vars"]: + self.ctx["docker_env_vars"]["MAD_GPU_VENDOR"] = self.ctx["gpu_vendor"] + + if "MAD_SYSTEM_NGPUS" not in self.ctx["docker_env_vars"]: + self.ctx["docker_env_vars"][ + "MAD_SYSTEM_NGPUS" + ] = self.get_system_ngpus() + + if "MAD_SYSTEM_GPU_ARCHITECTURE" not in self.ctx["docker_env_vars"]: + self.ctx["docker_env_vars"][ + "MAD_SYSTEM_GPU_ARCHITECTURE" + ] = self.get_system_gpu_architecture() + + if "MAD_SYSTEM_HIP_VERSION" not in self.ctx["docker_env_vars"]: + self.ctx["docker_env_vars"][ + "MAD_SYSTEM_HIP_VERSION" + ] = self.get_system_hip_version() + + if "MAD_SYSTEM_GPU_PRODUCT_NAME" not in self.ctx["docker_env_vars"]: + self.ctx["docker_env_vars"][ + "MAD_SYSTEM_GPU_PRODUCT_NAME" + ] = self.get_system_gpu_product_name() + + # Also add to build args (for runtime builds) - only if not already set + if "MAD_SYSTEM_GPU_ARCHITECTURE" not in self.ctx["docker_build_arg"]: + self.ctx["docker_build_arg"]["MAD_SYSTEM_GPU_ARCHITECTURE"] = self.ctx[ + "docker_env_vars" + ]["MAD_SYSTEM_GPU_ARCHITECTURE"] + + # Docker GPU configuration - only if not already set + if "docker_gpus" not in self.ctx: + self.ctx["docker_gpus"] = self.get_docker_gpus() + + if "gpu_renderDs" not in self.ctx: + self.ctx["gpu_renderDs"] = self.get_gpu_renderD_nodes() + + # Default multi-node configuration - only if not already set + if "multi_node_args" not in self.ctx: + self.ctx["multi_node_args"] = { + "RUNNER": "torchrun", + "MAD_RUNTIME_NGPUS": self.ctx["docker_env_vars"][ + "MAD_SYSTEM_NGPUS" + ], # Use system's GPU count + "NNODES": 1, + "NODE_RANK": 0, + "MASTER_ADDR": "localhost", + "MASTER_PORT": 6006, + "HOST_LIST": "", + "NCCL_SOCKET_IFNAME": "", + "GLOO_SOCKET_IFNAME": "", + } + + self._gpu_context_initialized = True + + except Exception as e: + if self._build_only_mode: + print( + f"Warning: GPU detection failed in build-only mode (expected): {e}" + ) + else: + raise RuntimeError(f"GPU detection failed: {e}") + + def ensure_runtime_context(self) -> None: + """Ensure runtime context is initialized. + + This method should be called before any runtime operations + that require system and GPU context. 
+ """ + if not self._system_context_initialized and not self._build_only_mode: + self.init_system_context() + if not self._gpu_context_initialized and not self._build_only_mode: + self.init_gpu_context() + + def ensure_system_context(self) -> None: + """Ensure system context is initialized. + + This method should be called when system context is needed + but may not be initialized (e.g., in build-only mode). + """ + if not self._system_context_initialized: + self.init_system_context() def get_ctx_test(self) -> str: """Get context test. - + Returns: str: The output of the shell command. @@ -177,13 +346,13 @@ def get_ctx_test(self) -> str: def get_gpu_vendor(self) -> str: """Get GPU vendor. - + Returns: str: The output of the shell command. - + Raises: RuntimeError: If the GPU vendor is unable to detect. - + Note: What types of GPU vendors are supported? - NVIDIA @@ -196,10 +365,10 @@ def get_gpu_vendor(self) -> str: def get_host_os(self) -> str: """Get host OS. - + Returns: str: The output of the shell command. - + Raises: RuntimeError: If the host OS is unable to detect. @@ -216,7 +385,7 @@ def get_host_os(self) -> str: def get_numa_balancing(self) -> bool: """Get NUMA balancing. - + Returns: bool: The output of the shell command. @@ -225,9 +394,9 @@ def get_numa_balancing(self) -> bool: Note: NUMA balancing is enabled if the output is '1', and disabled if the output is '0'. - + What is NUMA balancing? - Non-Uniform Memory Access (NUMA) is a computer memory design used in multiprocessing, + Non-Uniform Memory Access (NUMA) is a computer memory design used in multiprocessing, where the memory access time depends on the memory location relative to the processor. """ # Check if NUMA balancing is enabled or disabled. @@ -239,13 +408,13 @@ def get_numa_balancing(self) -> bool: def get_system_ngpus(self) -> int: """Get system number of GPUs. - + Returns: int: The number of GPUs. - + Raises: - RuntimeError: If the GPU vendor is not detected or GPU count cannot be determined. - + RuntimeError: If the GPU vendor is not detected. + Note: What types of GPU vendors are supported? - NVIDIA @@ -274,14 +443,14 @@ def get_system_ngpus(self) -> int: def get_system_gpu_architecture(self) -> str: """Get system GPU architecture. - + Returns: str: The GPU architecture. - + Raises: RuntimeError: If the GPU vendor is not detected. RuntimeError: If the GPU architecture is unable to determine. - + Note: What types of GPU vendors are supported? - NVIDIA @@ -348,7 +517,7 @@ def get_system_hip_version(self): def get_docker_gpus(self) -> typing.Optional[str]: """Get Docker GPUs. - + Returns: str: The range of GPUs. """ @@ -360,7 +529,7 @@ def get_docker_gpus(self) -> typing.Optional[str]: def get_gpu_renderD_nodes(self) -> typing.Optional[typing.List[int]]: """Get GPU renderD nodes from KFD properties. - + Returns: list: The list of GPU renderD nodes, or None if not AMD GPU. @@ -539,9 +708,11 @@ def set_multi_node_runner(self) -> str: environment variable settings. 
""" # NOTE: mpirun is untested - if self.ctx["multi_node_args"]["RUNNER"] == 'mpirun': + if self.ctx["multi_node_args"]["RUNNER"] == "mpirun": if not self.ctx["multi_node_args"]["HOST_LIST"]: - self.ctx["multi_node_args"]["HOST_LIST"] = f"localhost:{self.ctx['multi_node_args']['MAD_RUNTIME_NGPUS']}" + self.ctx["multi_node_args"][ + "HOST_LIST" + ] = f"localhost:{self.ctx['multi_node_args']['MAD_RUNTIME_NGPUS']}" multi_node_runner = ( f"mpirun -np {self.ctx['multi_node_args']['NNODES'] * self.ctx['multi_node_args']['MAD_RUNTIME_NGPUS']} " f"--host {self.ctx['multi_node_args']['HOST_LIST']}" @@ -565,12 +736,161 @@ def set_multi_node_runner(self) -> str: return multi_node_runner + def _setup_build_multi_node_context(self) -> None: + """Setup multi-node context for build phase. + + This method handles multi-node configuration during build phase, + storing the configuration for inclusion in the manifest without requiring + runtime GPU detection. The multi_node_args will be preserved as-is and + MAD_MULTI_NODE_RUNNER will be generated at runtime. + """ + if "multi_node_args" in self.ctx: + print("Setting up multi-node context for build phase...") + + # Store the complete multi_node_args structure (excluding MAD_RUNTIME_NGPUS) + # This will be included in build_manifest.json and used at runtime + build_multi_node_args = {} + for key, value in self.ctx["multi_node_args"].items(): + # Skip MAD_RUNTIME_NGPUS as it's runtime-specific - will be set at runtime + if key != "MAD_RUNTIME_NGPUS": + build_multi_node_args[key] = value + + # Store the multi_node_args for inclusion in the manifest + # This will be accessible in build_manifest.json under context + self.ctx["build_multi_node_args"] = build_multi_node_args + + # Remove any individual MAD_MULTI_NODE_* env vars from docker_env_vars + # Only structured multi_node_args should be stored in the manifest + env_vars_to_remove = [] + for env_var in self.ctx.get("docker_env_vars", {}): + if ( + env_var.startswith("MAD_MULTI_NODE_") + and env_var != "MAD_MULTI_NODE_RUNNER" + ): + env_vars_to_remove.append(env_var) + + for env_var in env_vars_to_remove: + del self.ctx["docker_env_vars"][env_var] + print( + f"Removed {env_var} from docker_env_vars - will be reconstructed at runtime" + ) + + print( + f"Multi-node configuration stored for runtime: {list(build_multi_node_args.keys())}" + ) + print("MAD_RUNTIME_NGPUS will be resolved at runtime phase") + + def _create_build_multi_node_runner_template(self) -> str: + """Create a build-time multi-node runner command template. + + This creates a command template that uses environment variable substitution + for runtime-specific values like MAD_RUNTIME_NGPUS. 
+ + Returns: + str: Command template string with environment variable placeholders + """ + runner = self.ctx["multi_node_args"].get("RUNNER", "torchrun") + + if runner == "mpirun": + # For mpirun, construct command with runtime substitution + host_list = self.ctx["multi_node_args"].get("HOST_LIST", "") + if not host_list: + # Use runtime GPU count substitution + multi_node_runner = ( + "mpirun -np $(($MAD_MULTI_NODE_NNODES * ${MAD_RUNTIME_NGPUS:-1})) " + "--host ${MAD_MULTI_NODE_HOST_LIST:-localhost:${MAD_RUNTIME_NGPUS:-1}}" + ) + else: + multi_node_runner = ( + "mpirun -np $(($MAD_MULTI_NODE_NNODES * ${MAD_RUNTIME_NGPUS:-1})) " + f"--host {host_list}" + ) + else: + # For torchrun, use environment variable substitution + distributed_args = ( + "--nproc_per_node ${MAD_RUNTIME_NGPUS:-1} " + "--nnodes ${MAD_MULTI_NODE_NNODES:-1} " + "--node_rank ${MAD_MULTI_NODE_NODE_RANK:-0} " + "--master_addr ${MAD_MULTI_NODE_MASTER_ADDR:-localhost} " + "--master_port ${MAD_MULTI_NODE_MASTER_PORT:-6006}" + ) + multi_node_runner = f"torchrun {distributed_args}" + + # Add NCCL and GLOO interface environment variables with conditional setting + nccl_var = "${MAD_MULTI_NODE_NCCL_SOCKET_IFNAME:+NCCL_SOCKET_IFNAME=$MAD_MULTI_NODE_NCCL_SOCKET_IFNAME}" + gloo_var = "${MAD_MULTI_NODE_GLOO_SOCKET_IFNAME:+GLOO_SOCKET_IFNAME=$MAD_MULTI_NODE_GLOO_SOCKET_IFNAME}" + + multi_node_runner = f"{nccl_var} {gloo_var} {multi_node_runner}" + + return multi_node_runner + + def _setup_runtime_multi_node_context(self) -> None: + """Setup runtime multi-node context. + + This method handles multi-node configuration during runtime phase, + setting MAD_RUNTIME_NGPUS and creating the final MAD_MULTI_NODE_RUNNER. + """ + # Set MAD_RUNTIME_NGPUS for runtime based on detected GPU count + if "MAD_RUNTIME_NGPUS" not in self.ctx["docker_env_vars"]: + runtime_ngpus = self.ctx["docker_env_vars"].get("MAD_SYSTEM_NGPUS", 1) + self.ctx["docker_env_vars"]["MAD_RUNTIME_NGPUS"] = runtime_ngpus + print(f"Set MAD_RUNTIME_NGPUS to {runtime_ngpus} for runtime") + + # If we have multi_node_args from build phase or runtime, ensure MAD_RUNTIME_NGPUS is set + if "multi_node_args" in self.ctx: + # Add MAD_RUNTIME_NGPUS to multi_node_args if not already present + if "MAD_RUNTIME_NGPUS" not in self.ctx["multi_node_args"]: + self.ctx["multi_node_args"]["MAD_RUNTIME_NGPUS"] = self.ctx[ + "docker_env_vars" + ]["MAD_RUNTIME_NGPUS"] + + # If we have build_multi_node_args from manifest, reconstruct full multi_node_args + elif "build_multi_node_args" in self.ctx: + print("Reconstructing multi_node_args from build manifest...") + self.ctx["multi_node_args"] = self.ctx["build_multi_node_args"].copy() + self.ctx["multi_node_args"]["MAD_RUNTIME_NGPUS"] = self.ctx[ + "docker_env_vars" + ]["MAD_RUNTIME_NGPUS"] + + # Generate MAD_MULTI_NODE_RUNNER if we have multi_node_args + if "multi_node_args" in self.ctx: + print("Creating MAD_MULTI_NODE_RUNNER with runtime values...") + + # Set individual MAD_MULTI_NODE_* environment variables for runtime execution + # These are needed by the bash scripts that use the template runner command + multi_node_mapping = { + "NNODES": "MAD_MULTI_NODE_NNODES", + "NODE_RANK": "MAD_MULTI_NODE_NODE_RANK", + "MASTER_ADDR": "MAD_MULTI_NODE_MASTER_ADDR", + "MASTER_PORT": "MAD_MULTI_NODE_MASTER_PORT", + "NCCL_SOCKET_IFNAME": "MAD_MULTI_NODE_NCCL_SOCKET_IFNAME", + "GLOO_SOCKET_IFNAME": "MAD_MULTI_NODE_GLOO_SOCKET_IFNAME", + "HOST_LIST": "MAD_MULTI_NODE_HOST_LIST", + } + + for multi_node_key, env_var_name in multi_node_mapping.items(): + if 
multi_node_key in self.ctx["multi_node_args"]: + self.ctx["docker_env_vars"][env_var_name] = str( + self.ctx["multi_node_args"][multi_node_key] + ) + print( + f"Set {env_var_name} to {self.ctx['multi_node_args'][multi_node_key]} for runtime" + ) + + # Generate the MAD_MULTI_NODE_RUNNER command + self.ctx["docker_env_vars"][ + "MAD_MULTI_NODE_RUNNER" + ] = self.set_multi_node_runner() + print( + f"MAD_MULTI_NODE_RUNNER: {self.ctx['docker_env_vars']['MAD_MULTI_NODE_RUNNER']}" + ) + def filter(self, unfiltered: typing.Dict) -> typing.Dict: """Filter the unfiltered dictionary based on the context. - + Args: unfiltered: The unfiltered dictionary. - + Returns: dict: The filtered dictionary. """ diff --git a/src/madengine/core/dataprovider.py b/src/madengine/core/dataprovider.py index 29e675fe..d552b3fd 100644 --- a/src/madengine/core/dataprovider.py +++ b/src/madengine/core/dataprovider.py @@ -118,7 +118,7 @@ def prepare_data(self, model_docker: Docker) -> bool: Args: model_docker: The model docker object - + Returns: bool: The status of preparing the data """ @@ -135,23 +135,19 @@ class CustomDataProvider(DataProvider): provider_type = "custom" - def __init__( - self, - dataname: str, - config: typing.Dict - ) -> None: + def __init__(self, dataname: str, config: typing.Dict) -> None: """Constructor of the CustomDataProvider class.""" super().__init__(dataname, config) def check_source(self, config: typing.Dict) -> bool: """Check if the data source is valid - + Args: config (dict): Configuration of the data provider - + Returns: bool: The status of the data source - + Raises: RuntimeError: Raised when the mirrorlocal path is a non-existent path """ @@ -165,7 +161,7 @@ def check_source(self, config: typing.Dict) -> bool: os.makedirs( self.config["mirrorlocal"] + "/" + self.dataname, exist_ok=True ) - + # get the base directory of the current file. 
BASE_DIR = os.path.dirname(os.path.realpath(__file__)) print("DEBUG - BASE_DIR::", BASE_DIR) @@ -269,7 +265,7 @@ def check_source(self, config): return True else: print(f"Failed to connect to NAS {self.name} at {self.ip}:{self.port}") - + print("Failed to connect to all available NAS nodes.") return False @@ -333,7 +329,7 @@ def prepare_data(self, model_docker): touch ~/.ssh/known_hosts ssh-keyscan -p {port} {ip} >> ~/.ssh/known_hosts echo '#!/bin/bash' > /tmp/ssh.sh - echo 'sshpass -p {password} rsync --progress -avz -e \\\"ssh -p {port} \\\" \\\"\$@\\\"' >> /tmp/ssh.sh + echo 'sshpass -p {password} rsync --progress -avz -e \\"ssh -p {port} \\" \\"\\$@\\"' >> /tmp/ssh.sh cat /tmp/ssh.sh chmod u+x /tmp/ssh.sh timeout --preserve-status {timeout} /tmp/ssh.sh {username}@{ip}:{datapath}/* {datahome} && rm -f /tmp/ssh.sh @@ -371,7 +367,7 @@ def prepare_data(self, model_docker): touch ~/.ssh/known_hosts ssh-keyscan -p {port} {ip} >> ~/.ssh/known_hosts echo '#!/bin/bash' > /tmp/ssh.sh - echo 'sshpass -p {password} ssh -v \$*' >> /tmp/ssh.sh + echo 'sshpass -p {password} ssh -v \\$*' >> /tmp/ssh.sh chmod u+x /tmp/ssh.sh timeout --preserve-status {timeout} mount -t fuse sshfs#{username}@{ip}:{datapath} {datahome} -o ssh_command=/tmp/ssh.sh,port={port} && rm -f /tmp/ssh.sh """ @@ -507,7 +503,7 @@ def check_source(self, config): except Exception as e: print(f"Failed to connect to Minio endpoint ({self.minio_endpoint}): {e}") return False - + return True def get_mountpath(self): @@ -545,7 +541,7 @@ def prepare_data(self, model_docker): datahome=datahome, dataname=self.dataname, ) - + # Measure time taken to copy data from MinIO to local start = time.time() model_docker.sh(cmd, timeout=3600) # 60 min timeout @@ -553,13 +549,13 @@ def prepare_data(self, model_docker): self.duration = end - start print("Copy data from MinIO to local") print("Data Download Duration: {} seconds".format(self.duration)) - + # Get the size of the data of dataname in the path of datahome and store it in the config cmd = f"du -sh {datahome} | cut -f1" data_size = model_docker.sh(cmd) self.size = data_size print("Data Size: ", self.size) - + return True @@ -721,9 +717,11 @@ def find_dataprovider(self, dataname: str) -> typing.Optional[DataProvider]: self.selected_data_provider = { "dataname": dataname, "data_provider_type": data_provider_type, - "data_provider_config": self.data_provider_config[dataname][data_provider_type], + "data_provider_config": self.data_provider_config[dataname][ + data_provider_type + ], "duration": data_provider.duration, - "size": data_provider.size + "size": data_provider.size, } break diff --git a/src/madengine/core/docker.py b/src/madengine/core/docker.py index 7ed4ff36..57b26473 100644 --- a/src/madengine/core/docker.py +++ b/src/madengine/core/docker.py @@ -8,6 +8,7 @@ # built-in modules import os import typing + # user-defined modules from madengine.core.console import Console @@ -83,7 +84,7 @@ def __init__( if mounts is not None: for mount in mounts: command += "-v " + mount + ":" + mount + " " - + # add current working directory command += "-v " + cwd + ":/myworkspace/ " @@ -91,12 +92,13 @@ def __init__( if envVars is not None: for evar in envVars.keys(): command += "-e " + evar + "=" + envVars[evar] + " " - + command += "--workdir /myworkspace/ " command += "--name " + container_name + " " command += image + " " - # hack to keep docker open + # Use 'cat' command to keep the container running in interactive mode + # This allows subsequent exec commands while maintaining the container state 
command += "cat " self.console.sh(command) @@ -105,19 +107,14 @@ def __init__( "docker ps -aqf 'name=" + container_name + "' " ) - def sh( - self, - command: str, - timeout: int=60, - secret: bool=False - ) -> str: + def sh(self, command: str, timeout: int = 60, secret: bool = False) -> str: """Run shell command inside docker. - + Args: command (str): The shell command. timeout (int): The timeout in seconds. secret (bool): The flag to hide the command. - + Returns: str: The output of the shell command. """ diff --git a/src/madengine/core/errors.py b/src/madengine/core/errors.py new file mode 100644 index 00000000..c8a460a9 --- /dev/null +++ b/src/madengine/core/errors.py @@ -0,0 +1,386 @@ +#!/usr/bin/env python3 +""" +Unified Error Handling System for MADEngine + +This module provides a centralized error handling system with structured +error types and consistent Rich console-based error reporting. +""" + +import logging +import traceback +from dataclasses import dataclass +from typing import Optional, Any, Dict, List +from enum import Enum + +try: + from rich.console import Console + from rich.panel import Panel + from rich.text import Text + from rich.table import Table +except ImportError: + raise ImportError("Rich is required for error handling. Install with: pip install rich") + + +class ErrorCategory(Enum): + """Error category enumeration for classification.""" + + VALIDATION = "validation" + CONNECTION = "connection" + AUTHENTICATION = "authentication" + RUNTIME = "runtime" + BUILD = "build" + DISCOVERY = "discovery" + ORCHESTRATION = "orchestration" + RUNNER = "runner" + CONFIGURATION = "configuration" + TIMEOUT = "timeout" + + +@dataclass +class ErrorContext: + """Context information for errors.""" + + operation: str + phase: Optional[str] = None + component: Optional[str] = None + model_name: Optional[str] = None + node_id: Optional[str] = None + file_path: Optional[str] = None + additional_info: Optional[Dict[str, Any]] = None + + +class MADEngineError(Exception): + """Base exception for all MADEngine errors.""" + + def __init__( + self, + message: str, + category: ErrorCategory, + context: Optional[ErrorContext] = None, + cause: Optional[Exception] = None, + recoverable: bool = False, + suggestions: Optional[List[str]] = None + ): + super().__init__(message) + self.message = message + self.category = category + self.context = context or ErrorContext(operation="unknown") + self.cause = cause + self.recoverable = recoverable + self.suggestions = suggestions or [] + + +class ValidationError(MADEngineError): + """Validation and input errors.""" + + def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwargs): + super().__init__( + message, + ErrorCategory.VALIDATION, + context, + recoverable=True, + **kwargs + ) + + +class ConnectionError(MADEngineError): + """Connection and network errors.""" + + def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwargs): + super().__init__( + message, + ErrorCategory.CONNECTION, + context, + recoverable=True, + **kwargs + ) + + +class AuthenticationError(MADEngineError): + """Authentication and credential errors.""" + + def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwargs): + super().__init__( + message, + ErrorCategory.AUTHENTICATION, + context, + recoverable=True, + **kwargs + ) + + +class RuntimeError(MADEngineError): + """Runtime execution errors.""" + + def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwargs): + super().__init__( + message, + 
ErrorCategory.RUNTIME, + context, + recoverable=False, + **kwargs + ) + + +class BuildError(MADEngineError): + """Build and compilation errors.""" + + def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwargs): + super().__init__( + message, + ErrorCategory.BUILD, + context, + recoverable=False, + **kwargs + ) + + +class DiscoveryError(MADEngineError): + """Model discovery errors.""" + + def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwargs): + super().__init__( + message, + ErrorCategory.DISCOVERY, + context, + recoverable=True, + **kwargs + ) + + +class OrchestrationError(MADEngineError): + """Distributed orchestration errors.""" + + def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwargs): + super().__init__( + message, + ErrorCategory.ORCHESTRATION, + context, + recoverable=False, + **kwargs + ) + + +class RunnerError(MADEngineError): + """Distributed runner errors.""" + + def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwargs): + super().__init__( + message, + ErrorCategory.RUNNER, + context, + recoverable=True, + **kwargs + ) + + +class ConfigurationError(MADEngineError): + """Configuration and setup errors.""" + + def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwargs): + super().__init__( + message, + ErrorCategory.CONFIGURATION, + context, + recoverable=True, + **kwargs + ) + + +class TimeoutError(MADEngineError): + """Timeout and duration errors.""" + + def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwargs): + super().__init__( + message, + ErrorCategory.TIMEOUT, + context, + recoverable=True, + **kwargs + ) + + +class ErrorHandler: + """Unified error handler with Rich console integration.""" + + def __init__(self, console: Optional[Console] = None, verbose: bool = False): + self.console = console or Console() + self.verbose = verbose + self.logger = logging.getLogger(__name__) + + def handle_error( + self, + error: Exception, + context: Optional[ErrorContext] = None, + show_traceback: Optional[bool] = None + ) -> None: + """Handle and display errors with rich formatting.""" + + show_tb = show_traceback if show_traceback is not None else self.verbose + + if isinstance(error, MADEngineError): + self._handle_madengine_error(error, show_tb) + else: + self._handle_generic_error(error, context, show_tb) + + def _handle_madengine_error(self, error: MADEngineError, show_traceback: bool) -> None: + """Handle MADEngine structured errors.""" + + # Determine error emoji and color + category_info = { + ErrorCategory.VALIDATION: ("⚠️", "yellow"), + ErrorCategory.CONNECTION: ("🔌", "blue"), + ErrorCategory.AUTHENTICATION: ("🔒", "red"), + ErrorCategory.RUNTIME: ("💥", "red"), + ErrorCategory.BUILD: ("🔨", "red"), + ErrorCategory.DISCOVERY: ("🔍", "yellow"), + ErrorCategory.ORCHESTRATION: ("⚡", "red"), + ErrorCategory.RUNNER: ("🚀", "red"), + ErrorCategory.CONFIGURATION: ("⚙️", "yellow"), + ErrorCategory.TIMEOUT: ("⏱️", "yellow"), + } + + emoji, color = category_info.get(error.category, ("❌", "red")) + + # Create error panel + title = f"{emoji} {error.category.value.title()} Error" + + # Build error content + content = Text() + content.append(f"{error.message}\n", style=f"bold {color}") + + # Add context information + if error.context: + content.append("\n📋 Context:\n", style="bold cyan") + if error.context.operation: + content.append(f" Operation: {error.context.operation}\n") + if error.context.phase: + content.append(f" Phase: 
{error.context.phase}\n") + if error.context.component: + content.append(f" Component: {error.context.component}\n") + if error.context.model_name: + content.append(f" Model: {error.context.model_name}\n") + if error.context.node_id: + content.append(f" Node: {error.context.node_id}\n") + if error.context.file_path: + content.append(f" File: {error.context.file_path}\n") + + # Add cause information + if error.cause: + content.append(f"\n🔗 Caused by: {str(error.cause)}\n", style="dim") + + # Add suggestions + if error.suggestions: + content.append("\n💡 Suggestions:\n", style="bold green") + for suggestion in error.suggestions: + content.append(f" • {suggestion}\n", style="green") + + # Add recovery information + if error.recoverable: + content.append("\n♻️ This error may be recoverable", style="bold blue") + + panel = Panel( + content, + title=title, + border_style=color, + expand=False + ) + + self.console.print(panel) + + # Show traceback if requested + if show_traceback and error.cause: + self.console.print("\n📚 [bold]Full Traceback:[/bold]") + self.console.print_exception() + + # Log to file + self.logger.error( + f"{error.category.value}: {error.message}", + extra={ + "context": error.context.__dict__ if error.context else {}, + "recoverable": error.recoverable, + "suggestions": error.suggestions + } + ) + + def _handle_generic_error( + self, + error: Exception, + context: Optional[ErrorContext], + show_traceback: bool + ) -> None: + """Handle generic Python exceptions.""" + + title = f"❌ {type(error).__name__}" + + content = Text() + content.append(f"{str(error)}\n", style="bold red") + + if context: + content.append("\n📋 Context:\n", style="bold cyan") + content.append(f" Operation: {context.operation}\n") + if context.phase: + content.append(f" Phase: {context.phase}\n") + if context.component: + content.append(f" Component: {context.component}\n") + + panel = Panel( + content, + title=title, + border_style="red", + expand=False + ) + + self.console.print(panel) + + if show_traceback: + self.console.print("\n📚 [bold]Full Traceback:[/bold]") + self.console.print_exception() + + # Log to file + self.logger.error(f"{type(error).__name__}: {str(error)}") + + +# Global error handler instance +_global_error_handler: Optional[ErrorHandler] = None + + +def set_error_handler(handler: ErrorHandler) -> None: + """Set the global error handler.""" + global _global_error_handler + _global_error_handler = handler + + +def get_error_handler() -> Optional[ErrorHandler]: + """Get the global error handler.""" + return _global_error_handler + + +def handle_error( + error: Exception, + context: Optional[ErrorContext] = None, + show_traceback: Optional[bool] = None +) -> None: + """Handle error using the global error handler.""" + if _global_error_handler: + _global_error_handler.handle_error(error, context, show_traceback) + else: + # Fallback to basic logging + logging.error(f"Error: {error}") + if show_traceback: + logging.exception("Exception details:") + + +def create_error_context( + operation: str, + phase: Optional[str] = None, + component: Optional[str] = None, + **kwargs +) -> ErrorContext: + """Convenience function to create error context.""" + return ErrorContext( + operation=operation, + phase=phase, + component=component, + **kwargs + ) \ No newline at end of file diff --git a/src/madengine/core/timeout.py b/src/madengine/core/timeout.py index 705a972a..0f72bd84 100644 --- a/src/madengine/core/timeout.py +++ b/src/madengine/core/timeout.py @@ -12,16 +12,14 @@ class Timeout: """Class to 
handle timeouts. - + Attributes: seconds (int): The timeout in seconds. """ - def __init__( - self, - seconds: int=15 - ) -> None: + + def __init__(self, seconds: int = 15) -> None: """Constructor of the Timeout class. - + Args: seconds (int): The timeout in seconds. """ @@ -29,14 +27,14 @@ def __init__( def handle_timeout(self, signum, frame) -> None: """Handle timeout. - + Args: signum: The signal number. frame: The frame. Returns: None - + Raises: TimeoutError: If the program times out. """ diff --git a/src/madengine/db/base_class.py b/src/madengine/db/base_class.py index e8ca31ac..e71fe72c 100644 --- a/src/madengine/db/base_class.py +++ b/src/madengine/db/base_class.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -""" Module for creating DB tables interfaces +"""Module for creating DB tables interfaces This module provides the base class for our own common functionalities among tables @@ -29,8 +29,6 @@ def obj_as_list_dict(cls, obj): for elem in obj: # extra elem at top of dict elem.__dict__.pop("_sa_instance_state", None) - # print(elem.__dict__) - # print(row.__table__.columns) dict_list.append(elem.__dict__) return dict_list diff --git a/src/madengine/db/database.py b/src/madengine/db/database.py index 1e384854..1ba0310f 100644 --- a/src/madengine/db/database.py +++ b/src/madengine/db/database.py @@ -8,6 +8,7 @@ # built-in modules import os from datetime import datetime, timezone + # third-party modules from sqlalchemy import Column, Integer, String, DateTime, TEXT, MetaData, Table from sqlalchemy.exc import OperationalError @@ -47,32 +48,35 @@ ) # Define the path to the SQL file -SQL_FILE_PATH = os.path.join(os.path.dirname(__file__), 'db_table_def.sql') +SQL_FILE_PATH = os.path.join(os.path.dirname(__file__), "db_table_def.sql") # Update TABLE_SCHEMA and TABLE_NAME variables TABLE_SCHEMA = ENV_VARS["db_name"] TABLE_NAME = None # get table name from SQL file -with open(SQL_FILE_PATH, 'r') as file: +with open(SQL_FILE_PATH, "r") as file: for line in file: - if 'CREATE TABLE' in line: - TABLE_NAME = line.split(' ')[2].split('(')[0] - TABLE_NAME = TABLE_NAME.replace('`', '') + if "CREATE TABLE" in line: + TABLE_NAME = line.split(" ")[2].split("(")[0] + TABLE_NAME = TABLE_NAME.replace("`", "") break if TABLE_NAME is None: raise ValueError("Table name not found in SQL file") + def read_sql_file(file_path: str) -> str: """Read the SQL file and return its content.""" - with open(file_path, 'r') as file: + with open(file_path, "r") as file: return file.read() + def parse_table_definition(sql_content: str) -> Table: """Parse the SQL content and return the table definition.""" metadata = MetaData() table = Table(TABLE_NAME, metadata, autoload_with=ENGINE, autoload_replace=True) return table + # Read and parse the SQL file sql_content = read_sql_file(SQL_FILE_PATH) db_table_definition = parse_table_definition(sql_content) @@ -80,9 +84,11 @@ def parse_table_definition(sql_content: str) -> Table: # Clear any existing mappers clear_mappers() + # Define the DB_TABLE class dynamically class DB_TABLE(BaseMixin, BASE): """Represents db job table""" + __tablename__ = db_table_definition.name __table__ = db_table_definition @@ -146,7 +152,9 @@ def show_db() -> None: result = ENGINE.execute( "SELECT * FROM {} \ WHERE {}.created_date= \ - (SELECT MAX(created_date) FROM {}) ;".format(DB_TABLE.__tablename__) + (SELECT MAX(created_date) FROM {}) ;".format( + DB_TABLE.__tablename__ + ) ) for row in result: print(row) @@ -222,7 +230,9 @@ def get_column_names() -> list: "SELECT `COLUMN_NAME` \ FROM 
`INFORMATION_SCHEMA`.`COLUMNS` \ WHERE `TABLE_SCHEMA`='{}' \ - AND `TABLE_NAME`='{}'".format(db_name, DB_TABLE.__tablename__) + AND `TABLE_NAME`='{}'".format( + db_name, DB_TABLE.__tablename__ + ) ) ret = [] for row in result: diff --git a/src/madengine/db/database_functions.py b/src/madengine/db/database_functions.py index 97561fc1..9ad4a49d 100644 --- a/src/madengine/db/database_functions.py +++ b/src/madengine/db/database_functions.py @@ -45,9 +45,7 @@ def get_matching_db_entries( """ print( "Looking for entries with {}, {} and {}".format( - recent_entry["model"], - recent_entry["gpu_architecture"], - filters + recent_entry["model"], recent_entry["gpu_architecture"], filters ) ) @@ -57,8 +55,7 @@ def get_matching_db_entries( WHERE model='{}' \ AND gpu_architecture='{}' \ ".format( - recent_entry["model"], - recent_entry["gpu_architecture"] + recent_entry["model"], recent_entry["gpu_architecture"] ) ) matching_entries = matching_entries.mappings().all() @@ -76,8 +73,7 @@ def get_matching_db_entries( print( "Found {} similar entries in database filtered down to {} entries".format( - len(matching_entries), - len(filtered_matching_entries) + len(matching_entries), len(filtered_matching_entries) ) ) return filtered_matching_entries diff --git a/src/madengine/db/logger.py b/src/madengine/db/logger.py index 8f450013..07731eea 100644 --- a/src/madengine/db/logger.py +++ b/src/madengine/db/logger.py @@ -4,6 +4,7 @@ Copyright (c) Advanced Micro Devices, Inc. All rights reserved. """ + # built-in modules import logging import os diff --git a/src/madengine/db/relative_perf.py b/src/madengine/db/relative_perf.py index 93d2569f..11d6b179 100644 --- a/src/madengine/db/relative_perf.py +++ b/src/madengine/db/relative_perf.py @@ -4,6 +4,7 @@ Copyright (c) Advanced Micro Devices, Inc. All rights reserved. """ + # built-in modules import argparse import ast @@ -112,12 +113,12 @@ def relative_perf( def relative_perf_all_configs(data: pd.DataFrame) -> pd.DataFrame: """Get the relative performance of all configurations. - + This function gets the relative performance of all configurations. - + Args: data (pd.DataFrame): The data. - + Returns: pd.DataFrame: The data. """ diff --git a/src/madengine/db/upload_csv_to_db.py b/src/madengine/db/upload_csv_to_db.py index d70d15b5..da63350d 100644 --- a/src/madengine/db/upload_csv_to_db.py +++ b/src/madengine/db/upload_csv_to_db.py @@ -1,10 +1,11 @@ -"""Script to upload csv files to the database, +"""Script to upload csv files to the database, and create or update tables in the database. This script uploads csv files to the database, and creates or updates tables in the database. Copyright (c) Advanced Micro Devices, Inc. All rights reserved. 
""" + # built-in modules import os import sys @@ -12,9 +13,11 @@ import pandas as pd import typing from datetime import datetime + # third-party modules from tqdm import tqdm from sqlalchemy.orm import sessionmaker + # MAD Engine modules from database import ENGINE, create_tables, DB_TABLE, LOGGER from utils import dataFrame_to_list, load_perf_csv, replace_nans_with_None @@ -42,21 +45,21 @@ def add_csv_to_db(data: pd.DataFrame) -> bool: data = replace_nans_with_None(data) # Add unique ID column if it doesn't exist - if 'id' not in data.columns: + if "id" not in data.columns: # Get the max ID from the existing table to ensure uniqueness try: max_id_query = s.query(DB_TABLE.id).order_by(DB_TABLE.id.desc()).first() start_id = 1 if max_id_query is None else max_id_query[0] + 1 - except: - LOGGER.warning('Failed to query max ID, starting from 1') + except Exception as e: + LOGGER.warning("Failed to query max ID, starting from 1: %s", str(e)) start_id = 1 # Add sequential unique IDs - data['id'] = range(start_id, start_id + len(data)) + data["id"] = range(start_id, start_id + len(data)) # Explicitly set created_date to current timestamp if not provided - if 'created_date' not in data.columns: - data['created_date'] = datetime.now() + if "created_date" not in data.columns: + data["created_date"] = datetime.now() LOGGER.info("Data:") LOGGER.info(data) @@ -68,26 +71,31 @@ def add_csv_to_db(data: pd.DataFrame) -> bool: for model_perf_info in tqdm(data_as_list): try: # Ensure created_date is set for each record if not present - if 'created_date' not in model_perf_info or model_perf_info['created_date'] is None: - model_perf_info['created_date'] = datetime.now() + if ( + "created_date" not in model_perf_info + or model_perf_info["created_date"] is None + ): + model_perf_info["created_date"] = datetime.now() record = DB_TABLE(**model_perf_info) s.add(record) success_count += 1 except Exception as e: - LOGGER.warning( - 'Failed to add record to table due to %s \n', str(e)) + LOGGER.warning("Failed to add record to table due to %s \n", str(e)) LOGGER.info(model_perf_info) s.rollback() # commit changes and close sesstion try: s.commit() - LOGGER.info('Successfully added %d out of %d records to the database', - success_count, total_records) + LOGGER.info( + "Successfully added %d out of %d records to the database", + success_count, + total_records, + ) success = success_count > 0 except Exception as e: - LOGGER.error('Failed to commit changes: %s', str(e)) + LOGGER.error("Failed to commit changes: %s", str(e)) s.rollback() success = False finally: @@ -99,12 +107,12 @@ def add_csv_to_db(data: pd.DataFrame) -> bool: def main() -> None: """Main script function to upload csv files to the database.""" # parse arg - parser = argparse.ArgumentParser(description='Upload perf.csv to database') + parser = argparse.ArgumentParser(description="Upload perf.csv to database") parser.add_argument("--csv-file-path", type=str) args = parser.parse_args() ret = create_tables() - LOGGER.info('DB creation successful: %s', ret) + LOGGER.info("DB creation successful: %s", ret) if args.csv_file_path is None: LOGGER.info("Only creating tables in the database") @@ -116,5 +124,6 @@ def main() -> None: data = relative_perf_all_configs(data) add_csv_to_db(data) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/src/madengine/db/utils.py b/src/madengine/db/utils.py index 13c6e879..a16acb56 100644 --- a/src/madengine/db/utils.py +++ b/src/madengine/db/utils.py @@ -29,7 +29,7 @@ def get_env_vars() -> dict: - 
SLURM_CPUS_ON_NODE - LOG_LEVEL - MODEL_DIR - + Returns: dict: Dictionary of DLM specific env_vars """ @@ -76,20 +76,19 @@ def get_env_vars() -> dict: env_vars["ssh_port"] = str(os.environ["TUNA_SSH_PORT"]) else: env_vars["ssh_port"] = "22" - + return env_vars def get_avg_perf( - entry_list: typing.List[dict], - n: int=5 - ) -> typing.Tuple[float, typing.List[float]]: + entry_list: typing.List[dict], n: int = 5 +) -> typing.Tuple[float, typing.List[float]]: """Get average performance from the last n entries - + Args: entry_list (list): List of entries n (int): Number of entries to consider - + Returns: tuple: Tuple of average performance and list of performances """ @@ -109,10 +108,10 @@ def get_avg_perf( def replace_nans_with_None(data: pd.DataFrame) -> pd.DataFrame: """Replace NaNs with None in the dataframe - + Args: data (pd.DataFrame): Dataframe to replace NaNs with None - + Returns: pd.DataFrame: Dataframe with NaNs replaced with None """ @@ -124,15 +123,24 @@ def replace_nans_with_None(data: pd.DataFrame) -> pd.DataFrame: def load_perf_csv(csv: str) -> pd.DataFrame: """Load performance csv file - + Args: csv (str): Path to the performance csv file - + Returns: pd.DataFrame: Dataframe of the performance csv file """ df = pd.read_csv(csv) - df = df.drop(columns=["dataname", "data_provider_type", "data_size", "data_download_duration", "build_number"], errors="ignore") + df = df.drop( + columns=[ + "dataname", + "data_provider_type", + "data_size", + "data_download_duration", + "build_number", + ], + errors="ignore", + ) df.rename(columns=lambda x: x.strip(), inplace=True) df = df.rename(columns=lambda x: x.strip()) df = df.where((pd.notnull(df)), None) @@ -147,10 +155,10 @@ def trim_strings(x): def dataFrame_to_list(df: pd.DataFrame) -> typing.List[dict]: """Convert dataframe to list of dictionaries - + Args: df (pd.DataFrame): Dataframe to convert - + Returns: list: List of dictionaries """ diff --git a/src/madengine/mad.py b/src/madengine/mad.py index 861571b7..87232561 100644 --- a/src/madengine/mad.py +++ b/src/madengine/mad.py @@ -1,13 +1,15 @@ -#!/usr/bin/env python -"""Mad Engine CLI tool. +#!/usr/bin/env python3 +"""MAD Engine CLI tool. This script provides a command-line interface to run models, generate reports, and tools for profiling and tracing. This tool is used to run LLMs and Deep Learning models locally. Copyright (c) Advanced Micro Devices, Inc. All rights reserved. """ -# built-in imports + import argparse +import logging + import sys # MAD Engine imports from madengine import __version__ @@ -19,9 +21,15 @@ from madengine.tools.update_perf_csv import UpdatePerfCsv from madengine.tools.csv_to_html import ConvertCsvToHtml from madengine.tools.csv_to_email import ConvertCsvToEmail -from madengine.core.constants import MODEL_DIR # pylint: disable=unused-import +from madengine.core.constants import MODEL_DIR # pylint: disable=unused-import from madengine.utils.gpu_validator import validate_gpu_installation, GPUInstallationError, detect_gpu_vendor, GPUVendor +# Setup logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + # ----------------------------------------------------------------------------- # Sub-command functions @@ -29,83 +37,84 @@ # Router of the command-line arguments to the corresponding functions def run_models(args: argparse.Namespace): """Run models on container. - + Args: args: The command-line arguments. 
""" - print(f"Running models on container") - run_models = RunModels(args=args) - return run_models.run() - + logger.info("Running models on container") + run_models_instance = RunModels(args=args) + return run_models_instance.run() + def discover_models(args: argparse.Namespace): """Discover the models. - + Args: args: The command-line arguments. """ - print(f"Discovering all models in the project") - discover_models = DiscoverModels(args=args) - return discover_models.run() - + logger.info("Discovering all models in the project") + discover_models_instance = DiscoverModels(args=args) + return discover_models_instance.run() + def update_perf_csv(args): """Update performance metrics of models perf.csv to database. - + Args: args: The command-line arguments. """ - print(f"Running update_perf_csv") - update_perf_csv = UpdatePerfCsv(args=args) - return update_perf_csv.run() + logger.info("Running update_perf_csv") + update_perf_csv_instance = UpdatePerfCsv(args=args) + return update_perf_csv_instance.run() def csv_to_html(args): """Convert CSV to HTML report of models. - + Args: args: The command-line arguments. """ - print(f"Running csv_to_html") + logger.info("Running csv_to_html") convert_csv_to_html = ConvertCsvToHtml(args=args) return convert_csv_to_html.run() def csv_to_email(args): """Convert CSV to Email of models. - + Args: args: The command-line arguments. """ - print(f"Convert CSV to Email of models") + logger.info("Convert CSV to Email of models") convert_csv_to_email = ConvertCsvToEmail(args=args) return convert_csv_to_email.run() def create_table(args): """Create table in DB. - + Args: args: The command-line arguments. - """ - print(f"Create table in DB") - create_table = CreateTable(args=args) - return create_table.run() + """ + logger.info("Create table in DB") + create_table_instance = CreateTable(args=args) + return create_table_instance.run() def update_table(args): """Update table in DB. - + Args: args: The command-line arguments. - """ - print(f"Update table in DB") - update_table = UpdateTable(args=args) - return update_table.run() + """ + logger.info("Update table in DB") + update_table_instance = UpdateTable(args=args) + return update_table_instance.run() + def upload_mongodb(args): """Upload to MongoDB. - + Args: args: The command-line arguments. """ @@ -193,83 +202,217 @@ def validate_gpu(args): # Main function # ----------------------------------------------------------------------------- def main(): - """Main function to parse the command-line arguments. - """ - parser = argparse.ArgumentParser(description="A Models automation and dashboarding command-line tool to run LLMs and Deep Learning models locally.") + """Main function to parse the command-line arguments.""" + parser = argparse.ArgumentParser( + description="A Models automation and dashboarding command-line tool to run LLMs and Deep Learning models locally." 
+ ) + + parser.add_argument("-v", "--version", action="version", version=__version__) + + subparsers = parser.add_subparsers( + title="Commands", + description="Available commands for running models, generating reports, and toolings.", + dest="command", + ) - parser.add_argument('-v', '--version', action='version', version=__version__) - - subparsers = parser.add_subparsers(title="Commands", description="Available commands for running models, generating reports, and toolings.", dest="command") - # Run models command - parser_run = subparsers.add_parser('run', description="Run LLMs and Deep Learning models on container", help='Run models on container') - parser_run.add_argument('--tags', nargs='+', default=[], help="tags to run (can be multiple).") + parser_run = subparsers.add_parser( + "run", + description="Run LLMs and Deep Learning models on container", + help="Run models on container", + ) + parser_run.add_argument( + "--tags", nargs="+", default=[], help="tags to run (can be multiple)." + ) # Deprecated Tag - parser_run.add_argument('--ignore-deprecated-flag', action='store_true', help="Force run deprecated models even if marked deprecated.") - - parser_run.add_argument('--timeout', type=int, default=-1, help="time out for model run in seconds; Overrides per-model timeout if specified or default timeout of 7200 (2 hrs).\ - Timeout of 0 will never timeout.") - parser_run.add_argument('--live-output', action='store_true', help="prints output in real-time directly on STDOUT") - parser_run.add_argument('--clean-docker-cache', action='store_true', help="rebuild docker image without using cache") - parser_run.add_argument('--additional-context-file', default=None, help="additonal context, as json file, to filter behavior of workloads. Overrides detected contexts.") - parser_run.add_argument('--additional-context', default='{}', help="additional context, as string representation of python dict, to filter behavior of workloads. 
" + - " Overrides detected contexts and additional-context-file.") - parser_run.add_argument('--data-config-file-name', default="data.json", help="custom data configuration file.") - parser_run.add_argument('--tools-json-file-name', default="./scripts/common/tools.json", help="custom tools json configuration file.") - parser_run.add_argument('--generate-sys-env-details', default=True, help='generate system config env details by default') - parser_run.add_argument('--force-mirror-local', default=None, help="Path to force all relevant dataproviders to mirror data locally on.") - parser_run.add_argument('--keep-alive', action='store_true', help="keep Docker container alive after run; will keep model directory after run") - parser_run.add_argument('--keep-model-dir', action='store_true', help="keep model directory after run") - parser_run.add_argument('--skip-model-run', action='store_true', help="skips running the model; will not keep model directory after run unless specified through keep-alive or keep-model-dir") - parser_run.add_argument('--disable-skip-gpu-arch', action='store_true', help="disables skipping model based on gpu architecture") - parser_run.add_argument('-o', '--output', default='perf.csv', help='output file') + parser_run.add_argument( + "--ignore-deprecated-flag", + action="store_true", + help="Force run deprecated models even if marked deprecated.", + ) + + parser_run.add_argument( + "--timeout", + type=int, + default=-1, + help="time out for model run in seconds; Overrides per-model timeout if specified or default timeout of 7200 (2 hrs).\ + Timeout of 0 will never timeout.", + ) + parser_run.add_argument( + "--live-output", + action="store_true", + help="prints output in real-time directly on STDOUT", + ) + parser_run.add_argument( + "--clean-docker-cache", + action="store_true", + help="rebuild docker image without using cache", + ) + parser_run.add_argument( + "--additional-context-file", + default=None, + help="additonal context, as json file, to filter behavior of workloads. Overrides detected contexts.", + ) + parser_run.add_argument( + "--additional-context", + default="{}", + help="additional context, as string representation of python dict, to filter behavior of workloads. 
" + + " Overrides detected contexts and additional-context-file.", + ) + parser_run.add_argument( + "--data-config-file-name", + default="data.json", + help="custom data configuration file.", + ) + parser_run.add_argument( + "--tools-json-file-name", + default="./scripts/common/tools.json", + help="custom tools json configuration file.", + ) + parser_run.add_argument( + "--generate-sys-env-details", + default=True, + help="generate system config env details by default", + ) + parser_run.add_argument( + "--force-mirror-local", + default=None, + help="Path to force all relevant dataproviders to mirror data locally on.", + ) + parser_run.add_argument( + "--keep-alive", + action="store_true", + help="keep Docker container alive after run; will keep model directory after run", + ) + parser_run.add_argument( + "--keep-model-dir", action="store_true", help="keep model directory after run" + ) + parser_run.add_argument( + "--skip-model-run", + action="store_true", + help="skips running the model; will not keep model directory after run unless specified through keep-alive or keep-model-dir", + ) + parser_run.add_argument( + "--disable-skip-gpu-arch", + action="store_true", + help="disables skipping model based on gpu architecture", + ) + parser_run.add_argument("-o", "--output", default="perf.csv", help="output file") parser_run.set_defaults(func=run_models) # Discover models command - parser_discover = subparsers.add_parser('discover', description="Discover all models in the project", help='Discover the models.') - parser_discover.add_argument('--tags', nargs='+', default=[], help="tags to discover models (can be multiple).") + parser_discover = subparsers.add_parser( + "discover", + description="Discover all models in the project", + help="Discover the models.", + ) + parser_discover.add_argument( + "--tags", + nargs="+", + default=[], + help="tags to discover models (can be multiple).", + ) parser_discover.set_defaults(func=discover_models) # Report command - parser_report = subparsers.add_parser('report', description="", help='Generate report of models') - subparsers_report = parser_report.add_subparsers(title="Report Commands", description="Available commands for generating reports.", dest="report_command") + parser_report = subparsers.add_parser( + "report", description="", help="Generate report of models" + ) + subparsers_report = parser_report.add_subparsers( + title="Report Commands", + description="Available commands for generating reports.", + dest="report_command", + ) # Report subcommand update-perf - parser_report_update_perf= subparsers_report.add_parser('update-perf', description="Update performance metrics of models perf.csv to database.", help='Update perf.csv to database') - parser_report_update_perf.add_argument("--single_result", help="path to the single result json") - parser_report_update_perf.add_argument("--exception-result", help="path to the single result json") - parser_report_update_perf.add_argument("--failed-result", help="path to the single result json") - parser_report_update_perf.add_argument("--multiple-results", help="path to the results csv") + parser_report_update_perf = subparsers_report.add_parser( + "update-perf", + description="Update performance metrics of models perf.csv to database.", + help="Update perf.csv to database", + ) + parser_report_update_perf.add_argument( + "--single_result", help="path to the single result json" + ) + parser_report_update_perf.add_argument( + "--exception-result", help="path to the single result json" + ) + 
parser_report_update_perf.add_argument(
+        "--failed-result", help="path to the single result json"
+    )
+    parser_report_update_perf.add_argument(
+        "--multiple-results", help="path to the results csv"
+    )
     parser_report_update_perf.add_argument("--perf-csv", default="perf.csv")
     parser_report_update_perf.add_argument("--model-name")
     parser_report_update_perf.add_argument("--common-info")
     parser_report_update_perf.set_defaults(func=update_perf_csv)
 
     # Report subcommand to-html
-    parser_report_html= subparsers_report.add_parser('to-html', description="Convert CSV to HTML report of models.", help='Convert CSV to HTML report of models')
+    parser_report_html = subparsers_report.add_parser(
+        "to-html",
+        description="Convert CSV to HTML report of models.",
+        help="Convert CSV to HTML report of models",
+    )
     parser_report_html.add_argument("--csv-file-path", type=str)
     parser_report_html.set_defaults(func=csv_to_html)
 
     # Report subcommand to-email
-    parser_report_email= subparsers_report.add_parser('to-email', description="Convert CSV to Email of models.", help='Convert CSV to Email of models')
-    parser_report_email.add_argument("--csv-file-path", type=str, default='.', help="Path to the directory containing the CSV files.")
+    parser_report_email = subparsers_report.add_parser(
+        "to-email",
+        description="Convert CSV to Email of models.",
+        help="Convert CSV to Email of models",
+    )
+    parser_report_email.add_argument(
+        "--csv-file-path",
+        type=str,
+        default=".",
+        help="Path to the directory containing the CSV files.",
+    )
     parser_report_email.set_defaults(func=csv_to_email)
 
     # Database command
-    parser_database = subparsers.add_parser('database', help='CRUD for database')
-    subparsers_database = parser_database.add_subparsers(title="Database Commands", description="Available commands for database, such as creating and updating table in DB.", dest="database_command")
+    parser_database = subparsers.add_parser("database", help="CRUD for database")
+    subparsers_database = parser_database.add_subparsers(
+        title="Database Commands",
+        description="Available commands for database, such as creating and updating table in DB.",
+        dest="database_command",
+    )
 
     # Database subcommand creating tabe
-    parser_database_create_table = subparsers_database.add_parser('create-table', description="Create table in DB.", help='Create table in DB')
-    parser_database_create_table.add_argument('-v', '--verbose', action='store_true', help='verbose output')
+    parser_database_create_table = subparsers_database.add_parser(
+        "create-table", description="Create table in DB.", help="Create table in DB"
+    )
+    parser_database_create_table.add_argument(
+        "-v", "--verbose", action="store_true", help="verbose output"
+    )
     parser_database_create_table.set_defaults(func=create_table)
 
     # Database subcommand updating table
-    parser_database_update_table = subparsers_database.add_parser('update-table', description="Update table in DB.", help='Update table in DB')
-    parser_database_update_table.add_argument('--csv-file-path', type=str, help='Path to the csv file')
-    parser_database_update_table.add_argument('--model-json-path', type=str, help='Path to the model json file')
+    parser_database_update_table = subparsers_database.add_parser(
+        "update-table", description="Update table in DB.", help="Update table in DB"
+    )
+    parser_database_update_table.add_argument(
+        "--csv-file-path", type=str, help="Path to the csv file"
+    )
+    parser_database_update_table.add_argument(
+        "--model-json-path", type=str, help="Path to the model json file"
+    )
     parser_database_update_table.set_defaults(func=update_table)

     # Database subcommand uploading to MongoDB
-    parser_database_upload_mongodb = subparsers_database.add_parser('upload-mongodb', description="Update table in DB.", help='Update table in DB')
-    parser_database_upload_mongodb.add_argument('--csv-file-path', type=str, default='perf_entry.csv', help='Path to the csv file')
-    parser_database_upload_mongodb.add_argument("--database-name", type=str, required=True, help="Name of the MongoDB database")
-    parser_database_upload_mongodb.add_argument("--collection-name", type=str, required=True, help="Name of the MongoDB collection")
+    parser_database_upload_mongodb = subparsers_database.add_parser(
+        "upload-mongodb",
+        description="Upload performance metrics CSV to MongoDB.",
+        help="Upload CSV to MongoDB",
+    )
+    parser_database_upload_mongodb.add_argument(
+        "--csv-file-path",
+        type=str,
+        default="perf_entry.csv",
+        help="Path to the csv file",
+    )
+    parser_database_upload_mongodb.add_argument(
+        "--database-name", type=str, required=True, help="Name of the MongoDB database"
+    )
+    parser_database_upload_mongodb.add_argument(
+        "--collection-name",
+        type=str,
+        required=True,
+        help="Name of the MongoDB collection",
+    )
     parser_database_upload_mongodb.set_defaults(func=upload_mongodb)

     # Validate GPU command
@@ -278,7 +421,7 @@ def main():
     parser_validate.set_defaults(func=validate_gpu)

     args = parser.parse_args()
-    
+
     if args.command:
         result = args.func(args)
         if args.command == 'validate' and result is not None:
diff --git a/src/madengine/mad_cli.py b/src/madengine/mad_cli.py
new file mode 100644
index 00000000..0ea5dcc6
--- /dev/null
+++ b/src/madengine/mad_cli.py
@@ -0,0 +1,2267 @@
+#!/usr/bin/env python3
+"""
+Modern CLI for madengine Distributed Orchestrator
+
+Production-ready command-line interface built with Typer and Rich
+for building and running models in distributed scenarios.
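+
+Typical usage (illustrative; flag values are examples only):
+
+    madengine-cli build --tags dummy --additional-context '{"gpu_vendor": "AMD", "guest_os": "UBUNTU"}'
+    madengine-cli run --tags dummy --manifest-file build_manifest.json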
+""" + +import ast +import json +import logging +import os +import sys +import glob +from pathlib import Path +from typing import Dict, List, Optional, Union + +try: + from typing import Annotated # Python 3.9+ +except ImportError: + from typing_extensions import Annotated # Python 3.8 + +import typer +from rich import print as rprint +from rich.console import Console +from rich.logging import RichHandler +from rich.panel import Panel +from rich.progress import Progress, SpinnerColumn, TextColumn +from rich.syntax import Syntax +from rich.table import Table +from rich.traceback import install + +# Install rich traceback handler for better error displays +install(show_locals=True) + +# Initialize Rich console +console = Console() + +# Import madengine components +from madengine.tools.distributed_orchestrator import DistributedOrchestrator +from madengine.tools.discover_models import DiscoverModels +from madengine.runners.orchestrator_generation import ( + generate_ansible_setup, + generate_k8s_setup, + generate_slurm_setup, +) +from madengine.runners.factory import RunnerFactory +from madengine.core.errors import ErrorHandler, set_error_handler + +# Initialize the main Typer app +app = typer.Typer( + name="madengine-cli", + help="🚀 madengine Distributed Orchestrator - Build and run AI models in distributed scenarios", + rich_markup_mode="rich", + add_completion=False, + no_args_is_help=True, +) + +# Sub-applications for organized commands +generate_app = typer.Typer( + name="generate", + help="📋 Generate orchestration files (Slurm, Ansible, Kubernetes)", + rich_markup_mode="rich", +) +app.add_typer(generate_app, name="generate") + +# Runner application for distributed execution +runner_app = typer.Typer( + name="runner", + help="🚀 Distributed runner for orchestrated execution across multiple nodes (SSH, Slurm, Ansible, Kubernetes)", + rich_markup_mode="rich", +) +app.add_typer(runner_app, name="runner") + +# Constants +DEFAULT_MANIFEST_FILE = "build_manifest.json" +DEFAULT_PERF_OUTPUT = "perf.csv" +DEFAULT_DATA_CONFIG = "data.json" +DEFAULT_TOOLS_CONFIG = "./scripts/common/tools.json" +DEFAULT_ANSIBLE_OUTPUT = "madengine_distributed.yml" +DEFAULT_TIMEOUT = -1 +DEFAULT_INVENTORY_FILE = "inventory.yml" +DEFAULT_RUNNER_REPORT = "runner_report.json" + + +# Exit codes +class ExitCode: + SUCCESS = 0 + FAILURE = 1 + BUILD_FAILURE = 2 + RUN_FAILURE = 3 + INVALID_ARGS = 4 + + +# Valid values for validation +VALID_GPU_VENDORS = ["AMD", "NVIDIA", "INTEL"] +VALID_GUEST_OS = ["UBUNTU", "CENTOS", "ROCKY"] + + +def setup_logging(verbose: bool = False) -> None: + """Setup Rich logging configuration and unified error handler.""" + log_level = logging.DEBUG if verbose else logging.INFO + + # Setup rich logging handler + rich_handler = RichHandler( + console=console, + show_time=True, + show_path=verbose, + markup=True, + rich_tracebacks=True, + ) + + logging.basicConfig( + level=log_level, + format="%(message)s", + datefmt="[%X]", + handlers=[rich_handler], + ) + + # Setup unified error handler + error_handler = ErrorHandler(console=console, verbose=verbose) + set_error_handler(error_handler) + + +def create_args_namespace(**kwargs) -> object: + """Create an argparse.Namespace-like object from keyword arguments.""" + + class Args: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + return Args(**kwargs) + + +def process_batch_manifest(batch_manifest_file: str) -> Dict[str, List[str]]: + """Process batch manifest file and extract model tags based on build_new 
flag. + + Args: + batch_manifest_file: Path to the input batch.json file + + Returns: + Dict containing 'build_tags' and 'all_tags' lists + + Raises: + FileNotFoundError: If the manifest file doesn't exist + ValueError: If the manifest format is invalid + """ + if not os.path.exists(batch_manifest_file): + raise FileNotFoundError(f"Batch manifest file not found: {batch_manifest_file}") + + try: + with open(batch_manifest_file, "r") as f: + manifest_data = json.load(f) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in batch manifest file: {e}") + + if not isinstance(manifest_data, list): + raise ValueError("Batch manifest must be a list of model objects") + + build_tags = [] # Models that need to be built (build_new=true) + all_tags = [] # All models in the manifest + + for i, model in enumerate(manifest_data): + if not isinstance(model, dict): + raise ValueError(f"Model entry {i} must be a dictionary") + + if "model_name" not in model: + raise ValueError(f"Model entry {i} missing required 'model_name' field") + + model_name = model["model_name"] + build_new = model.get("build_new", False) + + all_tags.append(model_name) + if build_new: + build_tags.append(model_name) + + return { + "build_tags": build_tags, + "all_tags": all_tags, + "manifest_data": manifest_data, + } + + +def validate_additional_context( + additional_context: str, + additional_context_file: Optional[str] = None, +) -> Dict[str, str]: + """ + Validate and parse additional context. + + Args: + additional_context: JSON string containing additional context + additional_context_file: Optional file containing additional context + + Returns: + Dict containing parsed additional context + + Raises: + typer.Exit: If validation fails + """ + context = {} + + # Load from file first + if additional_context_file: + try: + with open(additional_context_file, "r") as f: + context = json.load(f) + console.print( + f"✅ Loaded additional context from file: [cyan]{additional_context_file}[/cyan]" + ) + except (FileNotFoundError, json.JSONDecodeError) as e: + console.print(f"❌ Failed to load additional context file: [red]{e}[/red]") + raise typer.Exit(ExitCode.INVALID_ARGS) + + # Parse string context (overrides file) + if additional_context and additional_context != "{}": + try: + string_context = json.loads(additional_context) + context.update(string_context) + console.print("✅ Loaded additional context from command line") + except json.JSONDecodeError as e: + console.print(f"❌ Invalid JSON in additional context: [red]{e}[/red]") + console.print("💡 Please provide valid JSON format") + raise typer.Exit(ExitCode.INVALID_ARGS) + + if not context: + console.print("❌ [red]No additional context provided[/red]") + console.print( + "💡 For build operations, you must provide additional context with gpu_vendor and guest_os" + ) + + # Show example usage + example_panel = Panel( + """[bold cyan]Example usage:[/bold cyan] +madengine-cli build --tags dummy --additional-context '{"gpu_vendor": "AMD", "guest_os": "UBUNTU"}' + +[bold cyan]Or using a file:[/bold cyan] +madengine-cli build --tags dummy --additional-context-file context.json + +[bold cyan]Required fields:[/bold cyan] +• gpu_vendor: [green]AMD[/green], [green]NVIDIA[/green], [green]INTEL[/green] +• guest_os: [green]UBUNTU[/green], [green]CENTOS[/green], [green]ROCKY[/green]""", + title="Additional Context Help", + border_style="blue", + ) + console.print(example_panel) + raise typer.Exit(ExitCode.INVALID_ARGS) + + # Validate required fields + required_fields = ["gpu_vendor", 
"guest_os"] + missing_fields = [field for field in required_fields if field not in context] + + if missing_fields: + console.print( + f"❌ Missing required fields: [red]{', '.join(missing_fields)}[/red]" + ) + console.print( + "💡 Both gpu_vendor and guest_os are required for build operations" + ) + raise typer.Exit(ExitCode.INVALID_ARGS) + + # Validate gpu_vendor + gpu_vendor = context["gpu_vendor"].upper() + if gpu_vendor not in VALID_GPU_VENDORS: + console.print(f"❌ Invalid gpu_vendor: [red]{context['gpu_vendor']}[/red]") + console.print( + f"💡 Supported values: [green]{', '.join(VALID_GPU_VENDORS)}[/green]" + ) + raise typer.Exit(ExitCode.INVALID_ARGS) + + # Validate guest_os + guest_os = context["guest_os"].upper() + if guest_os not in VALID_GUEST_OS: + console.print(f"❌ Invalid guest_os: [red]{context['guest_os']}[/red]") + console.print( + f"💡 Supported values: [green]{', '.join(VALID_GUEST_OS)}[/green]" + ) + raise typer.Exit(ExitCode.INVALID_ARGS) + + console.print( + f"✅ Context validated: [green]{gpu_vendor}[/green] + [green]{guest_os}[/green]" + ) + return context + + +def save_summary_with_feedback( + summary: Dict, output_path: Optional[str], summary_type: str +) -> None: + """Save summary to file with user feedback.""" + if output_path: + try: + with open(output_path, "w") as f: + json.dump(summary, f, indent=2) + console.print( + f"💾 {summary_type} summary saved to: [cyan]{output_path}[/cyan]" + ) + except IOError as e: + console.print(f"❌ Failed to save {summary_type} summary: [red]{e}[/red]") + raise typer.Exit(ExitCode.FAILURE) + + +def _process_batch_manifest_entries( + batch_data: Dict, + manifest_output: str, + registry: Optional[str], + guest_os: Optional[str], + gpu_vendor: Optional[str], +) -> None: + """Process batch manifest and add entries for all models to build_manifest.json. 
+ + Args: + batch_data: Processed batch manifest data + manifest_output: Path to the build manifest file + registry: Registry used for the build + guest_os: Guest OS for the build + gpu_vendor: GPU vendor for the build + """ + + # Load the existing build manifest + if os.path.exists(manifest_output): + with open(manifest_output, "r") as f: + build_manifest = json.load(f) + # Remove top-level registry if present + build_manifest.pop("registry", None) + else: + # Create a minimal manifest structure + build_manifest = { + "built_images": {}, + "built_models": {}, + "context": {}, + "credentials_required": [], + } + + # Process each model in the batch manifest + for model_entry in batch_data["manifest_data"]: + model_name = model_entry["model_name"] + build_new = model_entry.get("build_new", False) + model_registry_image = model_entry.get("registry_image", "") + model_registry = model_entry.get("registry", "") + + # If the model was not built (build_new=false), create an entry for it + if not build_new: + # Find the model configuration by discovering models with this tag + try: + # Create a temporary args object to discover the model + temp_args = create_args_namespace( + tags=[model_name], + registry=registry, + additional_context="{}", + additional_context_file=None, + clean_docker_cache=False, + manifest_output=manifest_output, + live_output=False, + output="perf.csv", + ignore_deprecated_flag=False, + data_config_file_name="data.json", + tools_json_file_name="scripts/common/tools.json", + generate_sys_env_details=True, + force_mirror_local=None, + disable_skip_gpu_arch=False, + verbose=False, + _separate_phases=True, + ) + + discover_models = DiscoverModels(args=temp_args) + models = discover_models.run() + + for model_info in models: + if model_info["name"] == model_name: + # Get dockerfile + dockerfile = model_info.get("dockerfile") + dockerfile_specified = ( + f"{dockerfile}.{guest_os.lower()}.{gpu_vendor.lower()}" + ) + dockerfile_matched_list = glob.glob(f"{dockerfile_specified}.*") + + # Check the matched list + if not dockerfile_matched_list: + console.print( + f"Warning: No Dockerfile found for {dockerfile_specified}" + ) + raise FileNotFoundError( + f"No Dockerfile found for {dockerfile_specified}" + ) + else: + dockerfile_matched = dockerfile_matched_list[0].split("/")[-1].replace(".Dockerfile", "") + + # Create a synthetic image name for this model + synthetic_image_name = f"ci-{model_name}_{dockerfile_matched}" + + # Add to built_images (even though it wasn't actually built) + build_manifest["built_images"][synthetic_image_name] = { + "docker_image": synthetic_image_name, + "dockerfile": model_info.get("dockerfile"), + "base_docker": "", # No base since not built + "docker_sha": "", # No SHA since not built + "build_duration": 0, + "build_command": f"# Skipped build for {model_name} (build_new=false)", + "log_file": f"{model_name}_{dockerfile_matched}.build.skipped.log", + "registry_image": ( + model_registry_image + or f"{model_registry or registry or 'dockerhub'}/{synthetic_image_name}" + if model_registry_image or model_registry or registry + else "" + ), + "registry": model_registry or registry or "dockerhub", + } + + # Add to built_models - include all discovered model fields + model_entry = model_info.copy() # Start with all fields from discovered model + + # Ensure minimum required fields have fallback values + model_entry.setdefault("name", model_name) + model_entry.setdefault("dockerfile", f"docker/{model_name}") + model_entry.setdefault("scripts", 
f"scripts/{model_name}/run.sh") + model_entry.setdefault("n_gpus", "1") + model_entry.setdefault("owner", "") + model_entry.setdefault("training_precision", "") + model_entry.setdefault("tags", []) + model_entry.setdefault("args", "") + model_entry.setdefault("cred", "") + + build_manifest["built_models"][synthetic_image_name] = model_entry + break + + except Exception as e: + console.print(f"Warning: Could not process model {model_name}: {e}") + # Create a minimal entry anyway + synthetic_image_name = f"ci-{model_name}_{dockerfile_matched}" + build_manifest["built_images"][synthetic_image_name] = { + "docker_image": synthetic_image_name, + "dockerfile": f"docker/{model_name}", + "base_docker": "", + "docker_sha": "", + "build_duration": 0, + "build_command": f"# Skipped build for {model_name} (build_new=false)", + "log_file": f"{model_name}_{dockerfile_matched}.build.skipped.log", + "registry_image": model_registry_image or "", + "registry": model_registry or registry or "dockerhub", + } + build_manifest["built_models"][synthetic_image_name] = { + "name": model_name, + "dockerfile": f"docker/{model_name}", + "scripts": f"scripts/{model_name}/run.sh", + "n_gpus": "1", + "owner": "", + "training_precision": "", + "tags": [], + "args": "", + } + + # Save the updated manifest + with open(manifest_output, "w") as f: + json.dump(build_manifest, f, indent=2) + + console.print( + f"✅ Added entries for all models from batch manifest to {manifest_output}" + ) + + +def display_results_table(summary: Dict, title: str, show_gpu_arch: bool = False) -> None: + """Display results in a formatted table with each model as a separate row.""" + table = Table(title=title, show_header=True, header_style="bold magenta") + table.add_column("Index", justify="right", style="dim") + table.add_column("Status", style="bold") + table.add_column("Model", style="cyan") + + # Add GPU Architecture column if multi-arch build was used + if show_gpu_arch: + table.add_column("GPU Architecture", style="yellow") + + successful = summary.get("successful_builds", summary.get("successful_runs", [])) + failed = summary.get("failed_builds", summary.get("failed_runs", [])) + + # Helper function to extract model name from build result + def extract_model_name(item): + if isinstance(item, dict): + # Prioritize direct model name field if available + if "model" in item: + return item["model"] + elif "name" in item: + return item["name"] + # Fallback to extracting from docker_image for backward compatibility + elif "docker_image" in item: + # Extract model name from docker image name + # e.g., "ci-dummy_dummy.ubuntu.amd" -> "dummy" + # e.g., "ci-dummy_dummy.ubuntu.amd_gfx908" -> "dummy" + docker_image = item["docker_image"] + if docker_image.startswith("ci-"): + # Remove ci- prefix and extract model name + parts = docker_image[3:].split("_") + if len(parts) >= 2: + model_name = parts[0] # First part is the model name + else: + model_name = parts[0] if parts else docker_image + else: + model_name = docker_image + return model_name + return str(item)[:20] # Fallback + + # Helper function to extract GPU architecture + def extract_gpu_arch(item): + if isinstance(item, dict) and "gpu_architecture" in item: + return item["gpu_architecture"] + return "N/A" + + # Add successful builds/runs + row_index = 1 + for item in successful: + model_name = extract_model_name(item) + if show_gpu_arch: + gpu_arch = extract_gpu_arch(item) + table.add_row(str(row_index), "✅ Success", model_name, gpu_arch) + else: + table.add_row(str(row_index), "✅ Success", 
model_name) + row_index += 1 + + # Add failed builds/runs + for item in failed: + if isinstance(item, dict): + model_name = item.get("model", "Unknown") + if show_gpu_arch: + gpu_arch = item.get("architecture", "N/A") + table.add_row(str(row_index), "❌ Failed", model_name, gpu_arch) + else: + table.add_row(str(row_index), "❌ Failed", model_name) + else: + if show_gpu_arch: + table.add_row(str(row_index), "❌ Failed", str(item), "N/A") + else: + table.add_row(str(row_index), "❌ Failed", str(item)) + row_index += 1 + + # Show empty state if no results + if not successful and not failed: + if show_gpu_arch: + table.add_row("1", "ℹ️ No items", "", "") + else: + table.add_row("1", "ℹ️ No items", "") + + console.print(table) + + +@app.command() +def build( + tags: Annotated[ + List[str], + typer.Option("--tags", "-t", help="Model tags to build (can specify multiple)"), + ] = [], + target_archs: Annotated[ + List[str], + typer.Option( + "--target-archs", + "-a", + help="Target GPU architectures to build for (e.g., gfx908,gfx90a,gfx942). If not specified, builds single image with MAD_SYSTEM_GPU_ARCHITECTURE from additional_context or detected GPU architecture." + ), + ] = [], + registry: Annotated[ + Optional[str], + typer.Option("--registry", "-r", help="Docker registry to push images to"), + ] = None, + batch_manifest: Annotated[ + Optional[str], + typer.Option( + "--batch-manifest", help="Input batch.json file for batch build mode" + ), + ] = None, + additional_context: Annotated[ + str, + typer.Option( + "--additional-context", "-c", help="Additional context as JSON string" + ), + ] = "{}", + additional_context_file: Annotated[ + Optional[str], + typer.Option( + "--additional-context-file", + "-f", + help="File containing additional context JSON", + ), + ] = None, + clean_docker_cache: Annotated[ + bool, + typer.Option("--clean-docker-cache", help="Rebuild images without using cache"), + ] = False, + manifest_output: Annotated[ + str, + typer.Option("--manifest-output", "-m", help="Output file for build manifest"), + ] = DEFAULT_MANIFEST_FILE, + summary_output: Annotated[ + Optional[str], + typer.Option( + "--summary-output", "-s", help="Output file for build summary JSON" + ), + ] = None, + live_output: Annotated[ + bool, typer.Option("--live-output", "-l", help="Print output in real-time") + ] = False, + output: Annotated[ + str, typer.Option("--output", "-o", help="Performance output file") + ] = DEFAULT_PERF_OUTPUT, + ignore_deprecated_flag: Annotated[ + bool, typer.Option("--ignore-deprecated", help="Force run deprecated models") + ] = False, + data_config_file_name: Annotated[ + str, typer.Option("--data-config", help="Custom data configuration file") + ] = DEFAULT_DATA_CONFIG, + tools_json_file_name: Annotated[ + str, typer.Option("--tools-config", help="Custom tools JSON configuration") + ] = DEFAULT_TOOLS_CONFIG, + generate_sys_env_details: Annotated[ + bool, + typer.Option("--sys-env-details", help="Generate system config env details"), + ] = True, + force_mirror_local: Annotated[ + Optional[str], + typer.Option("--force-mirror-local", help="Path to force local data mirroring"), + ] = None, + disable_skip_gpu_arch: Annotated[ + bool, + typer.Option( + "--disable-skip-gpu-arch", + help="Disable skipping models based on GPU architecture", + ), + ] = False, + verbose: Annotated[ + bool, typer.Option("--verbose", "-v", help="Enable verbose logging") + ] = False, +) -> None: + """ + 🔨 Build Docker images for models in distributed scenarios. 
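+
+    In batch mode (--batch-manifest batch.json), batch.json is a JSON list of
+    {"model_name": ..., "build_new": ...} entries; only entries with
+    build_new=true are rebuilt.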
+ + This command builds Docker images for the specified model tags and optionally + pushes them to a registry. Additional context with gpu_vendor and guest_os + is required for build-only operations. + """ + setup_logging(verbose) + + # Validate mutually exclusive options + if batch_manifest and tags: + console.print( + "❌ [bold red]Error: Cannot specify both --batch-manifest and --tags options[/bold red]" + ) + raise typer.Exit(ExitCode.INVALID_ARGS) + + # Process batch manifest if provided + batch_data = None + effective_tags = tags + batch_build_metadata = None + + # There are 2 scenarios for batch builds and single builds + # - Batch builds: Use the batch manifest to determine which models to build + # - Single builds: Use the tags directly + if batch_manifest: + # Process the batch manifest + if verbose: + console.print(f"[DEBUG] Processing batch manifest: {batch_manifest}") + try: + batch_data = process_batch_manifest(batch_manifest) + if verbose: + console.print(f"[DEBUG] batch_data: {batch_data}") + + effective_tags = batch_data["build_tags"] + # Build a mapping of model_name -> registry_image/registry for build_new models + batch_build_metadata = {} + for model in batch_data["manifest_data"]: + if model.get("build_new", False): + batch_build_metadata[model["model_name"]] = { + "registry_image": model.get("registry_image"), + "registry": model.get("registry"), + } + if verbose: + console.print(f"[DEBUG] batch_build_metadata: {batch_build_metadata}") + + console.print( + Panel( + f"� [bold cyan]Batch Build Mode[/bold cyan]\n" + f"Input manifest: [yellow]{batch_manifest}[/yellow]\n" + f"Total models: [yellow]{len(batch_data['all_tags'])}[/yellow]\n" + f"Models to build: [yellow]{len(batch_data['build_tags'])}[/yellow] ({', '.join(batch_data['build_tags']) if batch_data['build_tags'] else 'none'})\n" + f"Registry: [yellow]{registry or 'Local only'}[/yellow]", + title="Batch Build Configuration", + border_style="blue", + ) + ) + except (FileNotFoundError, ValueError) as e: + console.print( + f"❌ [bold red]Error processing batch manifest: {e}[/bold red]" + ) + raise typer.Exit(ExitCode.INVALID_ARGS) + else: + console.print( + Panel( + f"�🔨 [bold cyan]Building Models[/bold cyan]\n" + f"Tags: [yellow]{', '.join(tags) if tags else 'All models'}[/yellow]\n" + f"Registry: [yellow]{registry or 'Local only'}[/yellow]", + title="Build Configuration", + border_style="blue", + ) + ) + + try: + # Validate additional context + validate_additional_context(additional_context, additional_context_file) + + # Create arguments object + args = create_args_namespace( + tags=effective_tags, + target_archs=target_archs, + registry=registry, + additional_context=additional_context, + additional_context_file=additional_context_file, + clean_docker_cache=clean_docker_cache, + manifest_output=manifest_output, + live_output=live_output, + output=output, + ignore_deprecated_flag=ignore_deprecated_flag, + data_config_file_name=data_config_file_name, + tools_json_file_name=tools_json_file_name, + generate_sys_env_details=generate_sys_env_details, + force_mirror_local=force_mirror_local, + disable_skip_gpu_arch=disable_skip_gpu_arch, + verbose=verbose, + _separate_phases=True, + batch_build_metadata=batch_build_metadata if batch_build_metadata else None, + ) + + # Initialize orchestrator in build-only mode + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("Initializing build orchestrator...", total=None) + 
orchestrator = DistributedOrchestrator(args, build_only_mode=True) + progress.update(task, description="Building models...") + + # Prepare build phase arguments + build_phase_kwargs = dict( + registry=registry, + clean_cache=clean_docker_cache, + manifest_output=manifest_output, + ) + # Pass batch_build_metadata to build_phase if present + if batch_build_metadata: + build_phase_kwargs["batch_build_metadata"] = batch_build_metadata + + build_summary = orchestrator.build_phase(**build_phase_kwargs) + progress.update(task, description="Build completed!") + + # Handle batch manifest post-processing + if batch_data: + with console.status("Processing batch manifest..."): + additional_context = getattr(args, "additional_context", None) + if isinstance(additional_context, str): + additional_context = json.loads(additional_context) + guest_os = ( + additional_context.get("guest_os") if additional_context else None + ) + gpu_vendor = ( + additional_context.get("gpu_vendor") if additional_context else None + ) + _process_batch_manifest_entries( + batch_data, manifest_output, registry, guest_os, gpu_vendor + ) + + # Display results + # Check if target_archs was used to show GPU architecture column + show_gpu_arch = bool(target_archs) + display_results_table(build_summary, "Build Results", show_gpu_arch) + + # Save summary + save_summary_with_feedback(build_summary, summary_output, "Build") + + # Check results and exit + failed_builds = len(build_summary.get("failed_builds", [])) + if failed_builds == 0: + console.print( + "🎉 [bold green]All builds completed successfully![/bold green]" + ) + raise typer.Exit(ExitCode.SUCCESS) + else: + console.print( + f"💥 [bold red]Build failed for {failed_builds} models[/bold red]" + ) + raise typer.Exit(ExitCode.BUILD_FAILURE) + + except typer.Exit: + raise + except Exception as e: + from madengine.core.errors import handle_error, create_error_context + + context = create_error_context( + operation="build", + phase="build", + component="build_command" + ) + handle_error(e, context=context) + raise typer.Exit(ExitCode.FAILURE) + + +@app.command() +def run( + tags: Annotated[ + List[str], + typer.Option("--tags", "-t", help="Model tags to run (can specify multiple)"), + ] = [], + manifest_file: Annotated[ + str, typer.Option("--manifest-file", "-m", help="Build manifest file path") + ] = "", + registry: Annotated[ + Optional[str], typer.Option("--registry", "-r", help="Docker registry URL") + ] = None, + timeout: Annotated[ + int, + typer.Option( + "--timeout", + help="Timeout for model run in seconds (-1 for default, 0 for no timeout)", + ), + ] = DEFAULT_TIMEOUT, + additional_context: Annotated[ + str, + typer.Option( + "--additional-context", "-c", help="Additional context as JSON string" + ), + ] = "{}", + additional_context_file: Annotated[ + Optional[str], + typer.Option( + "--additional-context-file", + "-f", + help="File containing additional context JSON", + ), + ] = None, + keep_alive: Annotated[ + bool, + typer.Option("--keep-alive", help="Keep Docker containers alive after run"), + ] = False, + keep_model_dir: Annotated[ + bool, typer.Option("--keep-model-dir", help="Keep model directory after run") + ] = False, + skip_model_run: Annotated[ + bool, typer.Option("--skip-model-run", help="Skip running the model") + ] = False, + clean_docker_cache: Annotated[ + bool, + typer.Option( + "--clean-docker-cache", + help="Rebuild images without using cache (for full workflow)", + ), + ] = False, + manifest_output: Annotated[ + str, + typer.Option( + 
"--manifest-output", help="Output file for build manifest (full workflow)" + ), + ] = DEFAULT_MANIFEST_FILE, + summary_output: Annotated[ + Optional[str], + typer.Option("--summary-output", "-s", help="Output file for summary JSON"), + ] = None, + live_output: Annotated[ + bool, typer.Option("--live-output", "-l", help="Print output in real-time") + ] = False, + output: Annotated[ + str, typer.Option("--output", "-o", help="Performance output file") + ] = DEFAULT_PERF_OUTPUT, + ignore_deprecated_flag: Annotated[ + bool, typer.Option("--ignore-deprecated", help="Force run deprecated models") + ] = False, + data_config_file_name: Annotated[ + str, typer.Option("--data-config", help="Custom data configuration file") + ] = DEFAULT_DATA_CONFIG, + tools_json_file_name: Annotated[ + str, typer.Option("--tools-config", help="Custom tools JSON configuration") + ] = DEFAULT_TOOLS_CONFIG, + generate_sys_env_details: Annotated[ + bool, + typer.Option("--sys-env-details", help="Generate system config env details"), + ] = True, + force_mirror_local: Annotated[ + Optional[str], + typer.Option("--force-mirror-local", help="Path to force local data mirroring"), + ] = None, + disable_skip_gpu_arch: Annotated[ + bool, + typer.Option( + "--disable-skip-gpu-arch", + help="Disable skipping models based on GPU architecture", + ), + ] = False, + verbose: Annotated[ + bool, typer.Option("--verbose", "-v", help="Enable verbose logging") + ] = False, +) -> None: + """ + 🚀 Run model containers in distributed scenarios. + + If manifest-file is provided and exists, runs execution phase only. + Otherwise runs the complete workflow (build + run). + """ + setup_logging(verbose) + + # Input validation + if timeout < -1: + console.print( + "❌ [red]Timeout must be -1 (default) or a positive integer[/red]" + ) + raise typer.Exit(ExitCode.INVALID_ARGS) + + try: + # Check if we're doing execution-only or full workflow + manifest_exists = manifest_file and os.path.exists(manifest_file) + + if manifest_exists: + console.print( + Panel( + f"🚀 [bold cyan]Running Models (Execution Only)[/bold cyan]\n" + f"Manifest: [yellow]{manifest_file}[/yellow]\n" + f"Registry: [yellow]{registry or 'Auto-detected'}[/yellow]\n" + f"Timeout: [yellow]{timeout if timeout != -1 else 'Default'}[/yellow]s", + title="Execution Configuration", + border_style="green", + ) + ) + + # Create arguments object for execution only + args = create_args_namespace( + tags=tags, + manifest_file=manifest_file, + registry=registry, + timeout=timeout, + keep_alive=keep_alive, + keep_model_dir=keep_model_dir, + skip_model_run=skip_model_run, + live_output=live_output, + output=output, + ignore_deprecated_flag=ignore_deprecated_flag, + data_config_file_name=data_config_file_name, + tools_json_file_name=tools_json_file_name, + generate_sys_env_details=generate_sys_env_details, + force_mirror_local=force_mirror_local, + disable_skip_gpu_arch=disable_skip_gpu_arch, + verbose=verbose, + _separate_phases=True, + ) + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task( + "Initializing execution orchestrator...", total=None + ) + orchestrator = DistributedOrchestrator(args) + progress.update(task, description="Running models...") + + execution_summary = orchestrator.run_phase( + manifest_file=manifest_file, + registry=registry, + timeout=timeout, + keep_alive=keep_alive, + ) + progress.update(task, description="Execution completed!") + + # Display results + 
display_results_table(execution_summary, "Execution Results") + save_summary_with_feedback(execution_summary, summary_output, "Execution") + + failed_runs = len(execution_summary.get("failed_runs", [])) + if failed_runs == 0: + console.print( + "🎉 [bold green]All model executions completed successfully![/bold green]" + ) + raise typer.Exit(ExitCode.SUCCESS) + else: + console.print( + f"💥 [bold red]Execution failed for {failed_runs} models[/bold red]" + ) + raise typer.Exit(ExitCode.RUN_FAILURE) + + else: + # Check if MAD_CONTAINER_IMAGE is provided - this enables local image mode + additional_context_dict = {} + try: + if additional_context and additional_context != "{}": + additional_context_dict = json.loads(additional_context) + except json.JSONDecodeError: + try: + # Try parsing as Python dict literal + additional_context_dict = ast.literal_eval(additional_context) + except (ValueError, SyntaxError): + console.print( + f"❌ [red]Invalid additional_context format: {additional_context}[/red]" + ) + raise typer.Exit(ExitCode.INVALID_ARGS) + + # Load additional context from file if provided + if additional_context_file and os.path.exists(additional_context_file): + try: + with open(additional_context_file, 'r') as f: + file_context = json.load(f) + additional_context_dict.update(file_context) + except json.JSONDecodeError: + console.print( + f"❌ [red]Invalid JSON format in {additional_context_file}[/red]" + ) + raise typer.Exit(ExitCode.INVALID_ARGS) + + # Check for MAD_CONTAINER_IMAGE in additional context + mad_container_image = additional_context_dict.get("MAD_CONTAINER_IMAGE") + + if mad_container_image: + # Local image mode - skip build phase and generate manifest + console.print( + Panel( + f"🏠📦 [bold cyan]Local Image Mode (Skip Build + Run)[/bold cyan]\n" + f"Container Image: [yellow]{mad_container_image}[/yellow]\n" + f"Tags: [yellow]{', '.join(tags) if tags else 'All models'}[/yellow]\n" + f"Timeout: [yellow]{timeout if timeout != -1 else 'Default'}[/yellow]s\n" + f"[dim]Note: Build phase will be skipped, using local image[/dim]", + title="Local Image Configuration", + border_style="blue", + ) + ) + + # Create arguments object for local image mode + args = create_args_namespace( + tags=tags, + registry=registry, + timeout=timeout, + additional_context=additional_context, + additional_context_file=additional_context_file, + keep_alive=keep_alive, + keep_model_dir=keep_model_dir, + skip_model_run=skip_model_run, + clean_docker_cache=clean_docker_cache, + manifest_output=manifest_output, + live_output=live_output, + output=output, + ignore_deprecated_flag=ignore_deprecated_flag, + data_config_file_name=data_config_file_name, + tools_json_file_name=tools_json_file_name, + generate_sys_env_details=generate_sys_env_details, + force_mirror_local=force_mirror_local, + disable_skip_gpu_arch=disable_skip_gpu_arch, + verbose=verbose, + _separate_phases=True, + ) + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task( + "Initializing local image orchestrator...", total=None + ) + orchestrator = DistributedOrchestrator(args) + + # Generate manifest for local image (skip build phase) + progress.update(task, description="Generating manifest for local image...") + build_summary = orchestrator.generate_local_image_manifest( + container_image=mad_container_image, + manifest_output=manifest_output, + ) + + # Run phase with local image + progress.update(task, description="Running models with local 
image...") + execution_summary = orchestrator.run_phase( + manifest_file=manifest_output, + registry=registry, + timeout=timeout, + keep_alive=keep_alive, + ) + progress.update(task, description="Local image workflow completed!") + + # Combine summaries for local image mode + workflow_summary = { + "build_phase": build_summary, + "run_phase": execution_summary, + "local_image_mode": True, + "container_image": mad_container_image, + "overall_success": len(execution_summary.get("failed_runs", [])) == 0, + } + + # Display results + display_results_table(execution_summary, "Local Image Execution Results") + save_summary_with_feedback(workflow_summary, summary_output, "Local Image Workflow") + + if workflow_summary["overall_success"]: + console.print( + "🎉 [bold green]Local image workflow finished successfully![/bold green]" + ) + raise typer.Exit(ExitCode.SUCCESS) + else: + failed_runs = len(execution_summary.get("failed_runs", [])) + console.print( + f"💥 [bold red]Local image workflow completed but {failed_runs} model executions failed[/bold red]" + ) + raise typer.Exit(ExitCode.RUN_FAILURE) + + else: + # Full workflow + if manifest_file: + console.print( + f"⚠️ Manifest file [yellow]{manifest_file}[/yellow] not found, running complete workflow" + ) + + console.print( + Panel( + f"🔨🚀 [bold cyan]Complete Workflow (Build + Run)[/bold cyan]\n" + f"Tags: [yellow]{', '.join(tags) if tags else 'All models'}[/yellow]\n" + f"Registry: [yellow]{registry or 'Local only'}[/yellow]\n" + f"Timeout: [yellow]{timeout if timeout != -1 else 'Default'}[/yellow]s", + title="Workflow Configuration", + border_style="magenta", + ) + ) + + # Create arguments object for full workflow + args = create_args_namespace( + tags=tags, + registry=registry, + timeout=timeout, + additional_context=additional_context, + additional_context_file=additional_context_file, + keep_alive=keep_alive, + keep_model_dir=keep_model_dir, + skip_model_run=skip_model_run, + clean_docker_cache=clean_docker_cache, + manifest_output=manifest_output, + live_output=live_output, + output=output, + ignore_deprecated_flag=ignore_deprecated_flag, + data_config_file_name=data_config_file_name, + tools_json_file_name=tools_json_file_name, + generate_sys_env_details=generate_sys_env_details, + force_mirror_local=force_mirror_local, + disable_skip_gpu_arch=disable_skip_gpu_arch, + verbose=verbose, + _separate_phases=True, + ) + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + # Build phase + task = progress.add_task( + "Initializing workflow orchestrator...", total=None + ) + orchestrator = DistributedOrchestrator(args) + + progress.update(task, description="Building models...") + build_summary = orchestrator.build_phase( + registry=registry, + clean_cache=clean_docker_cache, + manifest_output=manifest_output, + ) + + failed_builds = len(build_summary.get("failed_builds", [])) + if failed_builds > 0: + progress.update(task, description="Build failed!") + console.print( + f"💥 [bold red]Build failed for {failed_builds} models, aborting workflow[/bold red]" + ) + display_results_table(build_summary, "Build Results") + raise typer.Exit(ExitCode.BUILD_FAILURE) + + # Run phase + progress.update(task, description="Running models...") + execution_summary = orchestrator.run_phase( + manifest_file=manifest_output, + registry=registry, + timeout=timeout, + keep_alive=keep_alive, + ) + progress.update(task, description="Workflow completed!") + + # Combine summaries + workflow_summary 
= { + "build_phase": build_summary, + "run_phase": execution_summary, + "overall_success": ( + len(build_summary.get("failed_builds", [])) == 0 + and len(execution_summary.get("failed_runs", [])) == 0 + ), + } + + # Display results + display_results_table(build_summary, "Build Results") + display_results_table(execution_summary, "Execution Results") + save_summary_with_feedback(workflow_summary, summary_output, "Workflow") + + if workflow_summary["overall_success"]: + console.print( + "🎉 [bold green]Complete workflow finished successfully![/bold green]" + ) + raise typer.Exit(ExitCode.SUCCESS) + else: + failed_runs = len(execution_summary.get("failed_runs", [])) + if failed_runs > 0: + console.print( + f"💥 [bold red]Workflow completed but {failed_runs} model executions failed[/bold red]" + ) + raise typer.Exit(ExitCode.RUN_FAILURE) + else: + console.print( + "💥 [bold red]Workflow failed for unknown reasons[/bold red]" + ) + raise typer.Exit(ExitCode.FAILURE) + + except typer.Exit: + raise + except Exception as e: + console.print(f"💥 [bold red]Run process failed: {e}[/bold red]") + if verbose: + console.print_exception() + raise typer.Exit(ExitCode.FAILURE) + + +@app.command() +def discover( + tags: Annotated[ + List[str], + typer.Option("--tags", "-t", help="Model tags to discover (can specify multiple)"), + ] = [], + verbose: Annotated[ + bool, typer.Option("--verbose", "-v", help="Enable verbose logging") + ] = False, +) -> None: + """ + 🔍 Discover all models in the project. + + This command discovers all available models in the project based on the + specified tags. If no tags are provided, all models will be discovered. + """ + setup_logging(verbose) + + console.print( + Panel( + f"🔍 [bold cyan]Discovering Models[/bold cyan]\n" + f"Tags: [yellow]{tags if tags else 'All models'}[/yellow]", + title="Model Discovery", + border_style="blue", + ) + ) + + try: + # Create args namespace similar to mad.py + args = create_args_namespace(tags=tags) + + # Use DiscoverModels class + # Note: DiscoverModels prints output directly and returns None + discover_models_instance = DiscoverModels(args=args) + result = discover_models_instance.run() + + console.print("✅ [bold green]Model discovery completed successfully[/bold green]") + + except Exception as e: + console.print(f"💥 [bold red]Model discovery failed: {e}[/bold red]") + if verbose: + console.print_exception() + raise typer.Exit(ExitCode.FAILURE) + + +@generate_app.command("ansible") +def generate_ansible( + manifest_file: Annotated[ + str, typer.Option("--manifest-file", "-m", help="Build manifest file") + ] = DEFAULT_MANIFEST_FILE, + environment: Annotated[ + str, typer.Option("--environment", "-e", help="Environment configuration") + ] = "default", + output: Annotated[ + str, typer.Option("--output", "-o", help="Output Ansible playbook file") + ] = DEFAULT_ANSIBLE_OUTPUT, + verbose: Annotated[ + bool, typer.Option("--verbose", "-v", help="Enable verbose logging") + ] = False, +) -> None: + """ + 📋 Generate Ansible playbook for distributed execution. + + Uses the enhanced build manifest as the primary configuration source + with environment-specific values for customization. 
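+
+    Example (illustrative paths):
+        madengine-cli generate ansible --manifest-file build_manifest.json --output madengine_distributed.yml
+
+    The generated playbook can then be run with:
+        madengine-cli runner ansible --inventory inventory.yml --playbook madengine_distributed.yml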
+ """ + setup_logging(verbose) + + console.print( + Panel( + f"📋 [bold cyan]Generating Ansible Playbook[/bold cyan]\n" + f"Manifest: [yellow]{manifest_file}[/yellow]\n" + f"Environment: [yellow]{environment}[/yellow]\n" + f"Output: [yellow]{output}[/yellow]", + title="Ansible Generation", + border_style="blue", + ) + ) + + try: + # Validate input files + if not os.path.exists(manifest_file): + console.print( + f"❌ [bold red]Manifest file not found: {manifest_file}[/bold red]" + ) + raise typer.Exit(ExitCode.FAILURE) + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("Generating Ansible playbook...", total=None) + + # Use the new template system + result = generate_ansible_setup( + manifest_file=manifest_file, + environment=environment, + output_dir=str(Path(output).parent), + ) + + progress.update(task, description="Ansible playbook generated!") + + console.print( + f"✅ [bold green]Ansible setup generated successfully:[/bold green]" + ) + for file_type, file_path in result.items(): + console.print(f" 📄 {file_type}: [cyan]{file_path}[/cyan]") + + except Exception as e: + console.print( + f"💥 [bold red]Failed to generate Ansible playbook: {e}[/bold red]" + ) + if verbose: + console.print_exception() + raise typer.Exit(ExitCode.FAILURE) + + +@generate_app.command("k8s") +def generate_k8s( + manifest_file: Annotated[ + str, typer.Option("--manifest-file", "-m", help="Build manifest file") + ] = DEFAULT_MANIFEST_FILE, + environment: Annotated[ + str, typer.Option("--environment", "-e", help="Environment configuration") + ] = "default", + output_dir: Annotated[ + str, typer.Option("--output-dir", "-o", help="Output directory for manifests") + ] = "k8s-setup", + verbose: Annotated[ + bool, typer.Option("--verbose", "-v", help="Enable verbose logging") + ] = False, +) -> None: + """ + ☸️ Generate Kubernetes manifests for distributed execution. + + Uses the enhanced build manifest as the primary configuration source + with environment-specific values for customization. 
+ """ + setup_logging(verbose) + + console.print( + Panel( + f"☸️ [bold cyan]Generating Kubernetes Manifests[/bold cyan]\n" + f"Manifest: [yellow]{manifest_file}[/yellow]\n" + f"Environment: [yellow]{environment}[/yellow]\n" + f"Output Directory: [yellow]{output_dir}[/yellow]", + title="Kubernetes Generation", + border_style="blue", + ) + ) + + try: + # Validate input files + if not os.path.exists(manifest_file): + console.print( + f"❌ [bold red]Manifest file not found: {manifest_file}[/bold red]" + ) + raise typer.Exit(ExitCode.FAILURE) + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("Generating Kubernetes manifests...", total=None) + + # Use the new template system + result = generate_k8s_setup( + manifest_file=manifest_file, + environment=environment, + output_dir=output_dir, + ) + + progress.update(task, description="Kubernetes manifests generated!") + + console.print( + f"✅ [bold green]Kubernetes setup generated successfully:[/bold green]" + ) + for file_type, file_paths in result.items(): + console.print(f" 📄 {file_type}:") + if isinstance(file_paths, list): + for file_path in file_paths: + console.print(f" - [cyan]{file_path}[/cyan]") + else: + console.print(f" - [cyan]{file_paths}[/cyan]") + + except Exception as e: + console.print( + f"💥 [bold red]Failed to generate Kubernetes manifests: {e}[/bold red]" + ) + if verbose: + console.print_exception() + raise typer.Exit(ExitCode.FAILURE) + + +@generate_app.command("slurm") +def generate_slurm( + manifest_file: Annotated[ + str, + typer.Option( + "--manifest-file", + "-m", + help="📄 Path to build manifest JSON file", + ), + ] = "build_manifest.json", + environment: Annotated[ + str, + typer.Option( + "--environment", + "-e", + help="🌍 Environment configuration (default, dev, prod, test)", + ), + ] = "default", + output_dir: Annotated[ + str, + typer.Option( + "--output-dir", + "-o", + help="📂 Output directory for generated SLURM files", + ), + ] = "slurm-setup", + verbose: Annotated[ + bool, typer.Option("--verbose", "-v", help="Enable verbose logging") + ] = False, +) -> None: + """ + 🖥️ Generate SLURM job scripts and configuration for distributed execution. + + Creates job array scripts, individual job scripts, inventory configuration, + and submission helper scripts for SLURM cluster execution. 
+ + Example: + madengine-cli generate slurm --manifest-file build_manifest.json --environment prod --output-dir slurm-setup + """ + setup_logging(verbose) + + console.print( + Panel( + f"🖥️ [bold cyan]Generating SLURM Setup[/bold cyan]\n" + f"📄 Manifest: {manifest_file}\n" + f"🌍 Environment: {environment}\n" + f"📂 Output: {output_dir}", + title="SLURM Generation", + border_style="blue", + ) + ) + + # Validate manifest file exists + if not os.path.exists(manifest_file): + console.print(f"❌ [bold red]Manifest file not found: {manifest_file}[/bold red]") + raise typer.Exit(ExitCode.FAILURE) + + try: + with console.status("[bold green]Generating SLURM configuration..."): + # Generate complete SLURM setup + result = generate_slurm_setup( + manifest_file=manifest_file, + environment=environment, + output_dir=output_dir, + ) + + # Display success message with generated files + console.print(f"✅ [bold green]SLURM setup generated successfully![/bold green]") + console.print(f"📁 [cyan]Setup directory:[/cyan] {output_dir}") + + console.print("\n📋 [cyan]Generated files:[/cyan]") + for file_type, file_path in result.items(): + if file_type == "individual_jobs": + console.print(f" • [yellow]{file_type}:[/yellow] {len(file_path)} job scripts") + for job_script in file_path[:3]: # Show first 3 + console.print(f" - {os.path.basename(job_script)}") + if len(file_path) > 3: + console.print(f" - ... and {len(file_path) - 3} more") + else: + console.print(f" • [yellow]{file_type}:[/yellow] {file_path}") + + console.print( + f"\n💡 [dim]Next step:[/dim] [cyan]madengine-cli runner slurm --inventory {os.path.join(output_dir, 'inventory.yml')} --job-scripts-dir {output_dir}[/cyan]" + ) + + except FileNotFoundError as e: + console.print( + f"💥 [bold red]File not found: {e}[/bold red]" + ) + raise typer.Exit(ExitCode.FAILURE) + except Exception as e: + console.print( + f"💥 [bold red]Failed to generate SLURM setup: {e}[/bold red]" + ) + if verbose: + console.print_exception() + raise typer.Exit(ExitCode.FAILURE) + + +@generate_app.command("list") +def list_templates( + template_dir: Annotated[ + Optional[str], typer.Option("--template-dir", help="Custom template directory") + ] = None, + verbose: Annotated[ + bool, typer.Option("--verbose", "-v", help="Enable verbose logging") + ] = False, +) -> None: + """ + 📋 List available templates. + + Shows all available Jinja2 templates organized by type (ansible, k8s, etc.). 
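+
+    Example:
+        madengine-cli generate list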
+ """ + setup_logging(verbose) + + console.print( + Panel( + f"📋 [bold cyan]Available Templates[/bold cyan]", + title="Template Listing", + border_style="blue", + ) + ) + + try: + # Create template generator + from madengine.runners.template_generator import TemplateGenerator + + generator = TemplateGenerator(template_dir) + + templates = generator.list_templates() + + if not templates: + console.print("❌ [yellow]No templates found[/yellow]") + raise typer.Exit(ExitCode.SUCCESS) + + # Display templates in a formatted table + table = Table( + title="Available Templates", show_header=True, header_style="bold magenta" + ) + table.add_column("Type", style="cyan") + table.add_column("Templates", style="yellow") + + for template_type, template_files in templates.items(): + files_str = "\n".join(template_files) if template_files else "No templates" + table.add_row(template_type.upper(), files_str) + + console.print(table) + + except Exception as e: + console.print(f"💥 [bold red]Failed to list templates: {e}[/bold red]") + if verbose: + console.print_exception() + raise typer.Exit(ExitCode.FAILURE) + + +@generate_app.command("validate") +def validate_template( + template_path: Annotated[ + str, typer.Argument(help="Path to template file to validate") + ], + template_dir: Annotated[ + Optional[str], typer.Option("--template-dir", help="Custom template directory") + ] = None, + verbose: Annotated[ + bool, typer.Option("--verbose", "-v", help="Enable verbose logging") + ] = False, +) -> None: + """ + ✅ Validate template syntax. + + Validates Jinja2 template syntax and checks for common issues. + """ + setup_logging(verbose) + + console.print( + Panel( + f"✅ [bold cyan]Validating Template[/bold cyan]\n" + f"Template: [yellow]{template_path}[/yellow]", + title="Template Validation", + border_style="green", + ) + ) + + try: + # Create template generator + from madengine.runners.template_generator import TemplateGenerator + + generator = TemplateGenerator(template_dir) + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("Validating template...", total=None) + + is_valid = generator.validate_template(template_path) + + progress.update(task, description="Validation completed!") + + if is_valid: + console.print( + f"✅ [bold green]Template validation successful:[/bold green]" + ) + console.print(f" 📄 Template: [cyan]{template_path}[/cyan]") + console.print(f" 🎯 Syntax: [green]Valid[/green]") + else: + console.print(f"❌ [bold red]Template validation failed:[/bold red]") + console.print(f" 📄 Template: [cyan]{template_path}[/cyan]") + console.print(f" 🎯 Syntax: [red]Invalid[/red]") + raise typer.Exit(ExitCode.FAILURE) + + except Exception as e: + console.print(f"💥 [bold red]Failed to validate template: {e}[/bold red]") + if verbose: + console.print_exception() + raise typer.Exit(ExitCode.FAILURE) + + +@app.callback(invoke_without_command=True) +def main( + ctx: typer.Context, + version: Annotated[ + bool, typer.Option("--version", help="Show version and exit") + ] = False, +) -> None: + """ + 🚀 madengine Distributed Orchestrator + + Modern CLI for building and running AI models in distributed scenarios. + Built with Typer and Rich for a beautiful, production-ready experience. 
+ """ + if version: + # You might want to get the actual version from your package + console.print( + "🚀 [bold cyan]madengine-cli[/bold cyan] version [green]1.0.0[/green]" + ) + raise typer.Exit() + + # If no command is provided, show help + if ctx.invoked_subcommand is None: + console.print(ctx.get_help()) + ctx.exit() + + +def cli_main() -> None: + """Entry point for the CLI application.""" + try: + app() + except KeyboardInterrupt: + console.print("\n🛑 [yellow]Operation cancelled by user[/yellow]") + sys.exit(ExitCode.FAILURE) + except Exception as e: + console.print(f"💥 [bold red]Unexpected error: {e}[/bold red]") + console.print_exception() + sys.exit(ExitCode.FAILURE) + + +if __name__ == "__main__": + cli_main() + + +# ============================================================================ +# RUNNER COMMANDS +# ============================================================================ + + +@runner_app.command("ssh") +def runner_ssh( + inventory_file: Annotated[ + str, + typer.Option( + "--inventory", + "-i", + help="🗂️ Path to inventory file (YAML or JSON format)", + ), + ] = DEFAULT_INVENTORY_FILE, + manifest_file: Annotated[ + str, + typer.Option( + "--manifest-file", + "-m", + help="📋 Build manifest file (generated by 'madengine-cli build')", + ), + ] = DEFAULT_MANIFEST_FILE, + report_output: Annotated[ + str, + typer.Option( + "--report-output", + help="📊 Output file for execution report", + ), + ] = DEFAULT_RUNNER_REPORT, + verbose: Annotated[ + bool, + typer.Option( + "--verbose", + "-v", + help="🔍 Enable verbose logging", + ), + ] = False, +): + """ + 🔐 Execute models across multiple nodes using SSH. + + Distributes pre-built build manifest (created by 'madengine-cli build') + to remote nodes based on inventory configuration and executes + 'madengine-cli run' remotely through SSH client. + + The build manifest contains all configuration (tags, timeout, registry, etc.) + so only inventory and manifest file paths are needed. 
+ + Example: + madengine-cli runner ssh --inventory nodes.yml --manifest-file build_manifest.json + """ + setup_logging(verbose) + + try: + # Validate input files + if not os.path.exists(inventory_file): + console.print( + f"❌ [bold red]Inventory file not found: {inventory_file}[/bold red]" + ) + raise typer.Exit(ExitCode.FAILURE) + + if not os.path.exists(manifest_file): + console.print( + f"❌ [bold red]Build manifest file not found: {manifest_file}[/bold red]" + ) + console.print( + "💡 Generate it first using: [cyan]madengine-cli build[/cyan]" + ) + raise typer.Exit(ExitCode.FAILURE) + + # Create SSH runner + console.print("🚀 [bold blue]Starting SSH distributed execution[/bold blue]") + + with console.status("Initializing SSH runner..."): + runner = RunnerFactory.create_runner( + "ssh", inventory_path=inventory_file, console=console, verbose=verbose + ) + + # Execute workload (minimal spec - most info is in the manifest) + console.print(f"� Distributing manifest: [cyan]{manifest_file}[/cyan]") + console.print(f"📋 Using inventory: [cyan]{inventory_file}[/cyan]") + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task( + "Executing SSH distributed workload...", total=None + ) + + # Create minimal workload spec (most info is in the manifest) + from madengine.runners.base import WorkloadSpec + + workload = WorkloadSpec( + model_tags=[], # Not needed - in manifest + manifest_file=manifest_file, # This is the key input + timeout=3600, # Default timeout, actual timeout from manifest + registry=None, # Auto-detected from manifest + additional_context={}, + node_selector={}, + parallelism=1, + ) + + result = runner.run(workload) + + # Display results + _display_runner_results(result, "SSH") + + # Generate report + report_path = runner.generate_report(report_output) + console.print( + f"📊 Execution report saved to: [bold green]{report_path}[/bold green]" + ) + + # Exit with appropriate code + if result.failed_executions == 0: + console.print( + "✅ [bold green]All executions completed successfully[/bold green]" + ) + raise typer.Exit(code=ExitCode.SUCCESS) + else: + console.print( + f"❌ [bold red]{result.failed_executions} execution(s) failed[/bold red]" + ) + raise typer.Exit(code=ExitCode.RUN_FAILURE) + + except ImportError as e: + console.print(f"💥 [bold red]SSH runner not available: {e}[/bold red]") + console.print( + "Install SSH dependencies: [bold cyan]pip install paramiko scp[/bold cyan]" + ) + raise typer.Exit(code=ExitCode.FAILURE) + except Exception as e: + console.print(f"💥 [bold red]SSH execution failed: {e}[/bold red]") + if verbose: + console.print_exception() + raise typer.Exit(code=ExitCode.RUN_FAILURE) + + +@runner_app.command("ansible") +def runner_ansible( + inventory_file: Annotated[ + str, + typer.Option( + "--inventory", + "-i", + help="🗂️ Path to inventory file (YAML or JSON format)", + ), + ] = DEFAULT_INVENTORY_FILE, + playbook_file: Annotated[ + str, + typer.Option( + "--playbook", + help="📋 Path to Ansible playbook file (generated by 'madengine-cli generate ansible')", + ), + ] = DEFAULT_ANSIBLE_OUTPUT, + report_output: Annotated[ + str, + typer.Option( + "--report-output", + help="📊 Output file for execution report", + ), + ] = DEFAULT_RUNNER_REPORT, + verbose: Annotated[ + bool, + typer.Option( + "--verbose", + "-v", + help="🔍 Enable verbose logging", + ), + ] = False, +): + """ + ⚡ Execute models across cluster using Ansible. 
+
+    Runs a pre-generated Ansible playbook (created by 'madengine-cli generate
+    ansible') against the given inventory file, using ansible-runner to
+    distribute the workload and execute models in parallel across the cluster.
+
+    The playbook contains all configuration (tags, timeout, registry, etc.),
+    so only the inventory and playbook paths are needed.
+
+    Example:
+        madengine-cli runner ansible --inventory cluster.yml --playbook madengine_distributed.yml
+    """
+    setup_logging(verbose)
+
+    try:
+        # Validate input files
+        if not os.path.exists(inventory_file):
+            console.print(
+                f"❌ [bold red]Inventory file not found: {inventory_file}[/bold red]"
+            )
+            raise typer.Exit(ExitCode.FAILURE)
+
+        if not os.path.exists(playbook_file):
+            console.print(
+                f"❌ [bold red]Playbook file not found: {playbook_file}[/bold red]"
+            )
+            console.print(
+                "💡 Generate it first using: [cyan]madengine-cli generate ansible[/cyan]"
+            )
+            raise typer.Exit(ExitCode.FAILURE)
+
+        # Create Ansible runner
+        console.print(
+            "🚀 [bold blue]Starting Ansible distributed execution[/bold blue]"
+        )
+
+        with console.status("Initializing Ansible runner..."):
+            runner = RunnerFactory.create_runner(
+                "ansible",
+                inventory_path=inventory_file,
+                playbook_path=playbook_file,
+                console=console,
+                verbose=verbose,
+            )
+
+        # Execute workload (no workload spec needed - everything is in the playbook)
+        console.print(f"⚡ Executing playbook: [cyan]{playbook_file}[/cyan]")
+        console.print(f"📋 Using inventory: [cyan]{inventory_file}[/cyan]")
+
+        with Progress(
+            SpinnerColumn(),
+            TextColumn("[progress.description]{task.description}"),
+            console=console,
+        ) as progress:
+            task = progress.add_task("Executing Ansible playbook...", total=None)
+
+            # Create minimal workload spec (most info is in the playbook)
+            from madengine.runners.base import WorkloadSpec
+
+            workload = WorkloadSpec(
+                model_tags=[],  # Not needed - in playbook
+                manifest_file="",  # Not needed - in playbook
+            )
+
+            result = runner.run(workload)
+
+        # Display results
+        _display_runner_results(result, "Ansible")
+
+        # Generate report
+        report_path = runner.generate_report(report_output)
+        console.print(
+            f"📊 Execution report saved to: [bold green]{report_path}[/bold green]"
+        )
+
+        # Exit with appropriate code
+        if result.failed_executions == 0:
+            console.print(
+                "✅ [bold green]All executions completed successfully[/bold green]"
+            )
+            raise typer.Exit(code=ExitCode.SUCCESS)
+        else:
+            console.print(
+                f"❌ [bold red]{result.failed_executions} execution(s) failed[/bold red]"
+            )
+            raise typer.Exit(code=ExitCode.RUN_FAILURE)
+
+    except ImportError as e:
+        console.print(f"💥 [bold red]Ansible runner not available: {e}[/bold red]")
+        console.print(
+            "Install Ansible dependencies: [bold cyan]pip install ansible-runner[/bold cyan]"
+        )
+        raise typer.Exit(code=ExitCode.FAILURE)
+    except Exception as e:
+        console.print(f"💥 [bold red]Ansible execution failed: {e}[/bold red]")
+        if verbose:
+            console.print_exception()
+        raise typer.Exit(code=ExitCode.RUN_FAILURE)
+
+
+@runner_app.command("k8s")
+def runner_k8s(
+    inventory_file: Annotated[
+        str,
+        typer.Option(
+            "--inventory",
+            "-i",
+            help="🗂️ Path to inventory file (YAML or JSON format)",
+        ),
+    ] = DEFAULT_INVENTORY_FILE,
+    manifests_dir: Annotated[
+        str,
+        typer.Option(
+            "--manifests-dir",
+            "-d",
+            help="📁 Directory containing Kubernetes manifests (generated by 'madengine-cli generate k8s')",
+        ),
+    ] = "k8s-setup",
+    kubeconfig: Annotated[
+        Optional[str],
+        typer.Option(
+            "--kubeconfig",
+            help="⚙️ Path to kubeconfig file",
+        ),
+    ] = None,
+    
report_output: Annotated[ + str, + typer.Option( + "--report-output", + help="📊 Output file for execution report", + ), + ] = DEFAULT_RUNNER_REPORT, + verbose: Annotated[ + bool, + typer.Option( + "--verbose", + "-v", + help="🔍 Enable verbose logging", + ), + ] = False, +): + """ + ☸️ Execute models across Kubernetes cluster. + + Runs pre-generated Kubernetes manifests (created by 'madengine-cli generate k8s') + with inventory file leveraging kubernetes python client to distribute + workload for parallel execution of models on cluster. + + The manifests contain all configuration (tags, timeout, registry, etc.) + so only inventory and manifests directory paths are needed. + + Example: + madengine-cli runner k8s --inventory cluster.yml --manifests-dir k8s-setup + """ + setup_logging(verbose) + + try: + # Validate input files/directories + if not os.path.exists(inventory_file): + console.print( + f"❌ [bold red]Inventory file not found: {inventory_file}[/bold red]" + ) + raise typer.Exit(ExitCode.FAILURE) + + if not os.path.exists(manifests_dir): + console.print( + f"❌ [bold red]Manifests directory not found: {manifests_dir}[/bold red]" + ) + console.print( + "💡 Generate it first using: [cyan]madengine-cli generate k8s[/cyan]" + ) + raise typer.Exit(ExitCode.FAILURE) + + # Create Kubernetes runner + console.print( + "🚀 [bold blue]Starting Kubernetes distributed execution[/bold blue]" + ) + + with console.status("Initializing Kubernetes runner..."): + runner = RunnerFactory.create_runner( + "k8s", + inventory_path=inventory_file, + manifests_dir=manifests_dir, + kubeconfig_path=kubeconfig, + console=console, + verbose=verbose, + ) + + # Execute workload (no workload spec needed - everything is in the manifests) + console.print(f"☸️ Applying manifests from: [cyan]{manifests_dir}[/cyan]") + console.print(f"📋 Using inventory: [cyan]{inventory_file}[/cyan]") + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("Executing Kubernetes manifests...", total=None) + + # Create minimal workload spec (most info is in the manifests) + from madengine.runners.base import WorkloadSpec + + workload = WorkloadSpec( + model_tags=[], # Not needed - in manifests + manifest_file="", # Not needed - in manifests + ) + + result = runner.run(workload) + + # Display results + _display_runner_results(result, "Kubernetes") + + # Generate report + report_path = runner.generate_report(report_output) + console.print( + f"📊 Execution report saved to: [bold green]{report_path}[/bold green]" + ) + + # Exit with appropriate code + if result.failed_executions == 0: + console.print( + "✅ [bold green]All executions completed successfully[/bold green]" + ) + raise typer.Exit(code=ExitCode.SUCCESS) + else: + console.print( + f"❌ [bold red]{result.failed_executions} execution(s) failed[/bold red]" + ) + raise typer.Exit(code=ExitCode.RUN_FAILURE) + + except ImportError as e: + console.print(f"💥 [bold red]Kubernetes runner not available: {e}[/bold red]") + console.print( + "Install Kubernetes dependencies: [bold cyan]pip install kubernetes[/bold cyan]" + ) + raise typer.Exit(code=ExitCode.FAILURE) + except Exception as e: + console.print(f"💥 [bold red]Kubernetes execution failed: {e}[/bold red]") + if verbose: + console.print_exception() + raise typer.Exit(code=ExitCode.RUN_FAILURE) + + +@runner_app.command("slurm") +def runner_slurm( + inventory: Annotated[ + str, + typer.Option( + "--inventory", + "-i", + help="📋 Path to SLURM 
inventory file (generated by 'madengine-cli generate slurm')", + ), + ], + job_scripts_dir: Annotated[ + str, + typer.Option( + "--job-scripts-dir", + "-j", + help="📂 Directory containing generated SLURM job scripts", + ), + ], + timeout: Annotated[ + int, + typer.Option( + "--timeout", + "-t", + help="⏰ Execution timeout in seconds", + ), + ] = 3600, + verbose: Annotated[ + bool, typer.Option("--verbose", "-v", help="Enable verbose logging") + ] = False, +) -> None: + """ + 🖥️ Run distributed workload using pre-generated SLURM job scripts. + + Runs pre-generated SLURM job scripts (created by 'madengine-cli generate slurm') + for distributed model execution across SLURM cluster nodes. + + Example: + madengine-cli runner slurm --inventory cluster.yml --job-scripts-dir slurm-setup + """ + setup_logging(verbose) + + console.print( + Panel( + f"🖥️ [bold cyan]SLURM Distributed Execution[/bold cyan]\n" + f"📋 Inventory: {inventory}\n" + f"📂 Job Scripts: {job_scripts_dir}\n" + f"⏰ Timeout: {timeout}s", + title="SLURM Runner", + border_style="blue", + ) + ) + + try: + # Validate input files/directories + if not os.path.exists(inventory): + console.print( + f"❌ [bold red]Inventory file not found: {inventory}[/bold red]" + ) + raise typer.Exit(ExitCode.FAILURE) + + if not os.path.exists(job_scripts_dir): + console.print( + f"❌ [bold red]Job scripts directory not found: {job_scripts_dir}[/bold red]" + ) + console.print( + "💡 Generate it first using: [cyan]madengine-cli generate slurm[/cyan]" + ) + raise typer.Exit(ExitCode.FAILURE) + + # Create SLURM runner + console.print( + "🚀 [bold blue]Starting SLURM distributed execution[/bold blue]" + ) + + with console.status("Initializing SLURM runner..."): + runner = RunnerFactory.create_runner( + "slurm", + inventory_path=inventory, + job_scripts_dir=job_scripts_dir, + console=console, + verbose=verbose, + ) + + # Create minimal workload spec for SLURM runner + from madengine.runners.base import WorkloadSpec + workload = WorkloadSpec( + model_tags=["slurm-execution"], # Will be determined from job scripts + manifest_file="", # Not needed for pre-generated scripts + timeout=timeout, + ) + + # Execute the workload + with console.status("🔄 Executing SLURM workload..."): + result = runner.run(workload) + + # Display results + _display_runner_results(result, "SLURM") + + # Display success/failure message + if result.successful_executions > 0: + console.print( + f"✅ [bold green]SLURM execution completed with {result.successful_executions} successful tasks[/bold green]" + ) + + if result.failed_executions > 0: + console.print( + f"⚠️ [bold yellow]{result.failed_executions} tasks failed[/bold yellow]" + ) + + # Exit with appropriate code + if result.successful_executions == 0: + raise typer.Exit(code=ExitCode.RUN_FAILURE) + + except KeyboardInterrupt: + console.print("\n⚠️ [bold yellow]SLURM execution interrupted by user[/bold yellow]") + raise typer.Exit(code=ExitCode.FAILURE) + except Exception as e: + console.print(f"💥 [bold red]SLURM execution failed: {e}[/bold red]") + if verbose: + console.print_exception() + raise typer.Exit(code=ExitCode.RUN_FAILURE) + + +def _display_runner_results(result, runner_type: str): + """Display runner execution results in a formatted table. 
+ + Args: + result: DistributedResult object + runner_type: Type of runner (SSH, Ansible, Kubernetes) + """ + console.print(f"\n📊 [bold blue]{runner_type} Execution Results[/bold blue]") + + # Summary table + summary_table = Table(title="Execution Summary") + summary_table.add_column("Metric", style="cyan") + summary_table.add_column("Value", style="magenta") + + summary_table.add_row("Total Nodes", str(result.total_nodes)) + summary_table.add_row("Successful Executions", str(result.successful_executions)) + summary_table.add_row("Failed Executions", str(result.failed_executions)) + summary_table.add_row("Total Duration", f"{result.total_duration:.2f}s") + + console.print(summary_table) + + # Detailed results table + if result.node_results: + results_table = Table(title="Detailed Results") + results_table.add_column("Node", style="cyan") + results_table.add_column("Model", style="yellow") + results_table.add_column("Status", style="green") + results_table.add_column("Duration", style="magenta") + results_table.add_column("Error", style="red") + + for exec_result in result.node_results: + status_color = "green" if exec_result.status == "SUCCESS" else "red" + status_text = f"[{status_color}]{exec_result.status}[/{status_color}]" + + results_table.add_row( + exec_result.node_id, + exec_result.model_tag, + status_text, + f"{exec_result.duration:.2f}s", + exec_result.error_message or "", + ) + + console.print(results_table) diff --git a/src/madengine/runners/__init__.py b/src/madengine/runners/__init__.py new file mode 100644 index 00000000..314dc1e5 --- /dev/null +++ b/src/madengine/runners/__init__.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +""" +MADEngine Distributed Runners Package + +This package provides distributed runners for orchestrating workloads +across multiple nodes and clusters using different infrastructure types. +""" + +from .base import ( + BaseDistributedRunner, + NodeConfig, + WorkloadSpec, + ExecutionResult, + DistributedResult, +) +from .factory import RunnerFactory + +# Import runners (optional imports to handle missing dependencies) +try: + from .ssh_runner import SSHDistributedRunner + + __all__ = ["SSHDistributedRunner"] +except ImportError: + __all__ = [] + +try: + from .ansible_runner import AnsibleDistributedRunner + + __all__.append("AnsibleDistributedRunner") +except ImportError: + pass + +try: + from .k8s_runner import KubernetesDistributedRunner + + __all__.append("KubernetesDistributedRunner") +except ImportError: + pass + +# Always export base classes and factory +__all__.extend( + [ + "BaseDistributedRunner", + "NodeConfig", + "WorkloadSpec", + "ExecutionResult", + "DistributedResult", + "RunnerFactory", + ] +) + +__version__ = "1.0.0" diff --git a/src/madengine/runners/ansible_runner.py b/src/madengine/runners/ansible_runner.py new file mode 100644 index 00000000..aaf01550 --- /dev/null +++ b/src/madengine/runners/ansible_runner.py @@ -0,0 +1,384 @@ +#!/usr/bin/env python3 +""" +Ansible Distributed Runner for MADEngine + +This module implements Ansible-based distributed execution using +the ansible-runner library for orchestrated parallel execution. +""" + +import json +import os +import tempfile +import time +import yaml +from typing import List, Optional, Dict, Any, Union +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass + +try: + import ansible_runner +except ImportError: + raise ImportError( + "Ansible runner requires ansible-runner. 
" + "Install with: pip install ansible-runner" + ) + +from madengine.runners.base import ( + BaseDistributedRunner, + NodeConfig, + WorkloadSpec, + ExecutionResult, + DistributedResult, +) +from madengine.core.errors import ( + RunnerError, + ConfigurationError, + create_error_context +) + + +@dataclass +class AnsibleExecutionError(RunnerError): + """Ansible execution specific errors.""" + + playbook_path: str + + def __init__(self, message: str, playbook_path: str, **kwargs): + self.playbook_path = playbook_path + context = create_error_context( + operation="ansible_execution", + component="AnsibleRunner", + file_path=playbook_path + ) + super().__init__(message, context=context, **kwargs) + + +class AnsibleDistributedRunner(BaseDistributedRunner): + """Distributed runner using Ansible with enhanced error handling.""" + + def __init__(self, inventory_path: str, playbook_path: str = None, **kwargs): + """Initialize Ansible distributed runner. + + Args: + inventory_path: Path to Ansible inventory file + playbook_path: Path to pre-generated Ansible playbook file + **kwargs: Additional arguments passed to base class + """ + super().__init__(inventory_path, **kwargs) + self.playbook_path = playbook_path or "madengine_distributed.yml" + self.playbook_dir = kwargs.get("playbook_dir", "/tmp/madengine_ansible") + self.cleanup_handlers: List[callable] = [] + self.created_files: List[str] = [] + self.executor: Optional[ThreadPoolExecutor] = None + + def _validate_inventory(self) -> bool: + """Validate Ansible inventory file.""" + try: + if not os.path.exists(self.inventory_path): + self.logger.error(f"Inventory file not found: {self.inventory_path}") + return False + + # Try to parse inventory + with open(self.inventory_path, "r") as f: + content = f.read() + + # Basic validation - should contain host information + if not content.strip(): + self.logger.error("Inventory file is empty") + return False + + return True + + except Exception as e: + self.logger.error(f"Invalid inventory file: {e}") + return False + + def _ensure_playbook_directory(self) -> bool: + """Ensure playbook directory exists and is writable.""" + try: + os.makedirs(self.playbook_dir, exist_ok=True) + + # Test write permissions + test_file = os.path.join(self.playbook_dir, ".test_write") + try: + with open(test_file, "w") as f: + f.write("test") + os.remove(test_file) + return True + except Exception as e: + self.logger.error(f"Playbook directory not writable: {e}") + return False + + except Exception as e: + self.logger.error(f"Failed to create playbook directory: {e}") + return False + + def _create_ansible_inventory(self, target_nodes: List[NodeConfig]) -> str: + """Create Ansible inventory file from node configurations. 
+ + Args: + target_nodes: List of target nodes + + Returns: + Path to created inventory file + """ + inventory_data = { + "gpu_nodes": { + "hosts": {}, + "vars": { + "ansible_user": "root", + "ansible_ssh_common_args": "-o StrictHostKeyChecking=no", + }, + } + } + + for node in target_nodes: + host_vars = { + "ansible_host": node.address, + "ansible_port": node.port, + "ansible_user": node.username, + "gpu_count": node.gpu_count, + "gpu_vendor": node.gpu_vendor, + } + + # Add SSH key if provided + if node.ssh_key_path: + host_vars["ansible_ssh_private_key_file"] = node.ssh_key_path + + # Add custom labels as variables + host_vars.update(node.labels) + + inventory_data["gpu_nodes"]["hosts"][node.hostname] = host_vars + + # Write inventory file + inventory_file = os.path.join(self.playbook_dir, "inventory.yml") + with open(inventory_file, "w") as f: + yaml.dump(inventory_data, f, default_flow_style=False) + + return inventory_file + + def setup_infrastructure(self, workload: WorkloadSpec) -> bool: + """Setup Ansible infrastructure for distributed execution. + + Args: + workload: Workload specification + + Returns: + True if setup successful, False otherwise + """ + try: + self.logger.info("Setting up Ansible infrastructure") + + # Validate prerequisites + if not self._validate_inventory(): + return False + + if not self._ensure_playbook_directory(): + return False + + # Validate that the pre-generated playbook exists + if not os.path.exists(self.playbook_path): + self.logger.error( + f"Playbook file not found: {self.playbook_path}. " + f"Generate it first using 'madengine-cli generate ansible'" + ) + return False + + # Create executor + self.executor = ThreadPoolExecutor(max_workers=4) + + self.logger.info("Ansible infrastructure setup completed") + return True + + except Exception as e: + self.logger.error(f"Ansible infrastructure setup failed: {e}") + return False + + def _execute_playbook(self) -> bool: + """Execute the pre-generated Ansible playbook.""" + try: + self.logger.info(f"Executing Ansible playbook: {self.playbook_path}") + + # Use ansible-runner for execution + result = ansible_runner.run( + private_data_dir=self.playbook_dir, + playbook=os.path.basename(self.playbook_path), + inventory=self.inventory_path, + suppress_env_files=True, + quiet=False, + ) + + if result.status == "successful": + self.logger.info("Ansible playbook completed successfully") + return True + else: + self.logger.error( + f"Ansible playbook failed with status: {result.status}" + ) + + # Log detailed error information + if hasattr(result, "stderr") and result.stderr: + self.logger.error(f"Stderr: {result.stderr}") + + return False + + except Exception as e: + self.logger.error(f"Playbook execution failed: {e}") + return False + + def execute_workload(self, workload: WorkloadSpec) -> DistributedResult: + """Execute workload using pre-generated Ansible playbook. + + Args: + workload: Minimal workload specification (most config is in playbook) + + Returns: + Distributed execution result + """ + try: + self.logger.info("Starting Ansible distributed workload execution") + + # Validate that the pre-generated playbook exists + if not os.path.exists(self.playbook_path): + return DistributedResult( + success=False, + node_results=[], + error_message=f"Playbook file not found: {self.playbook_path}. 
" + f"Generate it first using 'madengine-cli generate ansible'", + ) + + # Execute the pre-generated playbook directly + if not self._execute_playbook(): + return DistributedResult( + success=False, + node_results=[], + error_message="Playbook execution failed", + ) + + # Parse results + results = self._parse_execution_results() + + distributed_result = DistributedResult( + success=any(r.success for r in results), node_results=results + ) + + self.logger.info("Ansible distributed workload execution completed") + return distributed_result + + except Exception as e: + self.logger.error(f"Distributed execution failed: {e}") + return DistributedResult( + success=False, node_results=[], error_message=str(e) + ) + + def _parse_execution_results(self) -> List[ExecutionResult]: + """Parse execution results from Ansible output.""" + results = [] + + try: + # Parse results from ansible-runner output + artifacts_dir = os.path.join(self.playbook_dir, "artifacts") + if not os.path.exists(artifacts_dir): + self.logger.warning("No artifacts directory found") + return results + + # Look for job events or stdout + stdout_file = os.path.join(artifacts_dir, "stdout") + if os.path.exists(stdout_file): + with open(stdout_file, "r") as f: + output = f.read() + + # Create a basic result based on overall success + result = ExecutionResult( + node_id="ansible-execution", + model_tag="playbook", + success=True, # If we got here, basic execution succeeded + output=output, + error_message=None, + execution_time=0, + ) + results.append(result) + else: + # No output found - assume failed + result = ExecutionResult( + node_id="ansible-execution", + model_tag="playbook", + success=False, + error_message="No output artifacts found", + ) + results.append(result) + + return results + + except Exception as e: + self.logger.error(f"Failed to parse execution results: {e}") + return [ + ExecutionResult( + node_id="ansible-execution", + model_tag="playbook", + success=False, + error_message=f"Result parsing failed: {e}", + ) + ] + + def cleanup_infrastructure(self, workload: WorkloadSpec) -> bool: + """Cleanup infrastructure after execution. 
+ + Args: + workload: Workload specification + + Returns: + True if cleanup successful, False otherwise + """ + try: + self.logger.info("Cleaning up Ansible infrastructure") + + # Run custom cleanup handlers + for cleanup_handler in self.cleanup_handlers: + try: + cleanup_handler() + except Exception as e: + self.logger.warning(f"Cleanup handler failed: {e}") + + # Clean up created files + for file_path in self.created_files: + try: + if os.path.exists(file_path): + os.remove(file_path) + except Exception as e: + self.logger.warning(f"Failed to remove {file_path}: {e}") + + self.created_files.clear() + + # Shutdown executor + if self.executor: + self.executor.shutdown(wait=True) + self.executor = None + + # Optionally clean up playbook directory + if os.path.exists(self.playbook_dir): + try: + import shutil + + shutil.rmtree(self.playbook_dir) + except Exception as e: + self.logger.warning(f"Failed to remove playbook directory: {e}") + + self.logger.info("Ansible infrastructure cleanup completed") + return True + + except Exception as e: + self.logger.error(f"Cleanup failed: {e}") + return False + + def add_cleanup_handler(self, handler: callable): + """Add a cleanup handler to be called during cleanup.""" + self.cleanup_handlers.append(handler) + + def __enter__(self): + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit with cleanup.""" + self.cleanup_infrastructure(None) diff --git a/src/madengine/runners/base.py b/src/madengine/runners/base.py new file mode 100644 index 00000000..f82fbb53 --- /dev/null +++ b/src/madengine/runners/base.py @@ -0,0 +1,389 @@ +#!/usr/bin/env python3 +""" +Base Distributed Runner for MADEngine + +This module provides the abstract base class for distributed runners +that orchestrate workload execution across multiple nodes and clusters. 
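+
+Concrete runners are normally obtained through RunnerFactory rather than
+instantiated directly. An illustrative sketch (paths and tags are
+placeholders, not shipped defaults):
+
+    from madengine.runners.base import WorkloadSpec
+    from madengine.runners.factory import RunnerFactory
+
+    runner = RunnerFactory.create_runner("ssh", inventory_path="nodes.yml")
+    workload = WorkloadSpec(model_tags=["dummy"], manifest_file="build_manifest.json")
+    result = runner.run(workload)
+    runner.generate_report("distributed_report.json")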
+""" + +import json +import logging +import os +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Any + +from madengine.core.console import Console + + +@dataclass +class NodeConfig: + """Configuration for a single node in the distributed system.""" + + hostname: str + address: str + port: int = 22 + username: str = "root" + ssh_key_path: Optional[str] = None + gpu_count: int = 1 + gpu_vendor: str = "AMD" + labels: Dict[str, str] = field(default_factory=dict) + environment: Dict[str, str] = field(default_factory=dict) + + def __post_init__(self): + """Validate node configuration.""" + if not self.hostname or not self.address: + raise ValueError("hostname and address are required") + if self.gpu_vendor not in ["AMD", "NVIDIA", "INTEL"]: + raise ValueError(f"Invalid gpu_vendor: {self.gpu_vendor}") + + +@dataclass +class WorkloadSpec: + """Specification for a distributed workload.""" + + model_tags: List[str] + manifest_file: str + timeout: int = 3600 + registry: Optional[str] = None + additional_context: Dict[str, Any] = field(default_factory=dict) + node_selector: Dict[str, str] = field(default_factory=dict) + parallelism: int = 1 + + def __post_init__(self): + """Validate workload specification.""" + if not self.model_tags: + raise ValueError("model_tags cannot be empty") + if not os.path.exists(self.manifest_file): + raise FileNotFoundError(f"Manifest file not found: {self.manifest_file}") + + +@dataclass +class ExecutionResult: + """Result of a distributed execution.""" + + node_id: str + model_tag: str + status: str # SUCCESS, FAILURE, TIMEOUT, SKIPPED + duration: float + performance_metrics: Dict[str, Any] = field(default_factory=dict) + error_message: Optional[str] = None + stdout: Optional[str] = None + stderr: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + return { + "node_id": self.node_id, + "model_tag": self.model_tag, + "status": self.status, + "duration": self.duration, + "performance_metrics": self.performance_metrics, + "error_message": self.error_message, + "stdout": self.stdout, + "stderr": self.stderr, + } + + +@dataclass +class DistributedResult: + """Overall result of a distributed execution.""" + + total_nodes: int + successful_executions: int + failed_executions: int + total_duration: float + node_results: List[ExecutionResult] = field(default_factory=list) + + def add_result(self, result: ExecutionResult): + """Add a node execution result.""" + self.node_results.append(result) + if result.status == "SUCCESS": + self.successful_executions += 1 + else: + self.failed_executions += 1 + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + return { + "total_nodes": self.total_nodes, + "successful_executions": self.successful_executions, + "failed_executions": self.failed_executions, + "total_duration": self.total_duration, + "node_results": [result.to_dict() for result in self.node_results], + } + + +class BaseDistributedRunner(ABC): + """Abstract base class for distributed runners.""" + + def __init__( + self, + inventory_path: str, + console: Optional[Console] = None, + verbose: bool = False, + ): + """Initialize the distributed runner. 
+ + Args: + inventory_path: Path to inventory configuration file + console: Console instance for output + verbose: Enable verbose logging + """ + self.inventory_path = inventory_path + self.console = console or Console() + self.verbose = verbose + self.logger = logging.getLogger(self.__class__.__name__) + + # Load inventory configuration + self.nodes = self._load_inventory(inventory_path) + + # Initialize result tracking + self.results = DistributedResult( + total_nodes=len(self.nodes), + successful_executions=0, + failed_executions=0, + total_duration=0.0, + ) + + def _load_inventory(self, inventory_path: str) -> List[NodeConfig]: + """Load inventory from configuration file. + + Args: + inventory_path: Path to inventory file + + Returns: + List of NodeConfig objects + """ + if not os.path.exists(inventory_path): + raise FileNotFoundError(f"Inventory file not found: {inventory_path}") + + with open(inventory_path, "r") as f: + if inventory_path.endswith(".json"): + inventory_data = json.load(f) + elif inventory_path.endswith((".yml", ".yaml")): + import yaml + + inventory_data = yaml.safe_load(f) + else: + raise ValueError(f"Unsupported inventory format: {inventory_path}") + + return self._parse_inventory(inventory_data) + + def _parse_inventory(self, inventory_data: Dict[str, Any]) -> List[NodeConfig]: + """Parse inventory data into NodeConfig objects. + + Args: + inventory_data: Raw inventory data + + Returns: + List of NodeConfig objects + """ + nodes = [] + + # Support different inventory formats + if "nodes" in inventory_data: + # Simple format: {"nodes": [{"hostname": "...", ...}]} + for node_data in inventory_data["nodes"]: + nodes.append(NodeConfig(**node_data)) + elif "gpu_nodes" in inventory_data: + # Ansible-style format: {"gpu_nodes": {...}} + for node_data in inventory_data["gpu_nodes"]: + nodes.append(NodeConfig(**node_data)) + else: + # Auto-detect format + for key, value in inventory_data.items(): + if isinstance(value, list): + for node_data in value: + if isinstance(node_data, dict) and "hostname" in node_data: + nodes.append(NodeConfig(**node_data)) + + if not nodes: + raise ValueError("No valid nodes found in inventory") + + return nodes + + def filter_nodes(self, node_selector: Dict[str, str]) -> List[NodeConfig]: + """Filter nodes based on selector criteria. + + Args: + node_selector: Key-value pairs for node selection + + Returns: + Filtered list of nodes + """ + if not node_selector: + return self.nodes + + filtered_nodes = [] + for node in self.nodes: + match = True + for key, value in node_selector.items(): + if key == "gpu_vendor" and node.gpu_vendor != value: + match = False + break + elif key in node.labels and node.labels[key] != value: + match = False + break + + if match: + filtered_nodes.append(node) + + return filtered_nodes + + def validate_workload(self, workload: WorkloadSpec) -> bool: + """Validate workload specification. 
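+
+        The manifest is expected to be JSON containing at least a
+        "built_images" mapping and, optionally, a "registry" entry, e.g.
+        (illustrative, contents abridged):
+
+            {"built_images": {...}, "registry": "localhost:5000"}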
+ + Args: + workload: Workload specification to validate + + Returns: + True if valid, False otherwise + """ + try: + # Check manifest file exists + if not os.path.exists(workload.manifest_file): + self.logger.error(f"Manifest file not found: {workload.manifest_file}") + return False + + # Load and validate manifest + with open(workload.manifest_file, "r") as f: + manifest = json.load(f) + + if "built_images" not in manifest: + self.logger.error("Invalid manifest: missing built_images") + return False + + # Filter nodes based on selector + target_nodes = self.filter_nodes(workload.node_selector) + if not target_nodes: + self.logger.error("No nodes match the selector criteria") + return False + + return True + + except Exception as e: + self.logger.error(f"Workload validation failed: {e}") + return False + + def prepare_execution_context(self, workload: WorkloadSpec) -> Dict[str, Any]: + """Prepare execution context for distributed execution. + + Args: + workload: Workload specification + + Returns: + Execution context dictionary + """ + # Load manifest + with open(workload.manifest_file, "r") as f: + manifest = json.load(f) + + # Prepare context + context = { + "manifest": manifest, + "registry": workload.registry or manifest.get("registry", ""), + "timeout": workload.timeout, + "additional_context": workload.additional_context, + "model_tags": workload.model_tags, + "parallelism": workload.parallelism, + } + + return context + + @abstractmethod + def setup_infrastructure(self, workload: WorkloadSpec) -> bool: + """Setup infrastructure for distributed execution. + + Args: + workload: Workload specification + + Returns: + True if setup successful, False otherwise + """ + pass + + @abstractmethod + def execute_workload(self, workload: WorkloadSpec) -> DistributedResult: + """Execute workload across distributed nodes. + + Args: + workload: Workload specification + + Returns: + Distributed execution result + """ + pass + + @abstractmethod + def cleanup_infrastructure(self, workload: WorkloadSpec) -> bool: + """Cleanup infrastructure after execution. + + Args: + workload: Workload specification + + Returns: + True if cleanup successful, False otherwise + """ + pass + + def run(self, workload: WorkloadSpec) -> DistributedResult: + """Run the complete distributed execution workflow. + + Args: + workload: Workload specification + + Returns: + Distributed execution result + """ + import time + + start_time = time.time() + + try: + # Validate workload + if not self.validate_workload(workload): + raise ValueError("Invalid workload specification") + + # Setup infrastructure + if not self.setup_infrastructure(workload): + raise RuntimeError("Failed to setup infrastructure") + + # Execute workload + result = self.execute_workload(workload) + + # Cleanup infrastructure + self.cleanup_infrastructure(workload) + + # Update total duration + result.total_duration = time.time() - start_time + + return result + + except Exception as e: + self.logger.error(f"Distributed execution failed: {e}") + # Ensure cleanup even on failure + try: + self.cleanup_infrastructure(workload) + except Exception as cleanup_error: + self.logger.error(f"Cleanup failed: {cleanup_error}") + + # Return failure result + self.results.total_duration = time.time() - start_time + return self.results + + def generate_report(self, output_file: str = "distributed_report.json") -> str: + """Generate execution report. 
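+
+        The report is the JSON form of DistributedResult, roughly
+        (illustrative values):
+
+            {"total_nodes": 2, "successful_executions": 2,
+             "failed_executions": 0, "total_duration": 123.4,
+             "node_results": [...]}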
+ + Args: + output_file: Output file path + + Returns: + Path to generated report + """ + report_data = self.results.to_dict() + + with open(output_file, "w") as f: + json.dump(report_data, f, indent=2) + + return output_file diff --git a/src/madengine/runners/factory.py b/src/madengine/runners/factory.py new file mode 100644 index 00000000..3637efe9 --- /dev/null +++ b/src/madengine/runners/factory.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python3 +""" +Runner Factory for MADEngine + +This module provides a factory for creating distributed runners +based on the specified runner type. +""" + +import logging +from typing import Dict, Type + +from madengine.runners.base import BaseDistributedRunner + + +class RunnerFactory: + """Factory for creating distributed runners.""" + + _runners: Dict[str, Type[BaseDistributedRunner]] = {} + + @classmethod + def register_runner( + cls, runner_type: str, runner_class: Type[BaseDistributedRunner] + ): + """Register a runner class. + + Args: + runner_type: Type identifier for the runner + runner_class: Runner class to register + """ + cls._runners[runner_type] = runner_class + + @classmethod + def create_runner(cls, runner_type: str, **kwargs) -> BaseDistributedRunner: + """Create a runner instance. + + Args: + runner_type: Type of runner to create + **kwargs: Arguments to pass to runner constructor + + Returns: + Runner instance + + Raises: + ValueError: If runner type is not registered + """ + if runner_type not in cls._runners: + available_types = ", ".join(cls._runners.keys()) + raise ValueError( + f"Unknown runner type: {runner_type}. " + f"Available types: {available_types}" + ) + + runner_class = cls._runners[runner_type] + return runner_class(**kwargs) + + @classmethod + def get_available_runners(cls) -> list: + """Get list of available runner types. + + Returns: + List of registered runner types + """ + return list(cls._runners.keys()) + + +def register_default_runners(): + """Register default runners.""" + try: + from madengine.runners.ssh_runner import SSHDistributedRunner + + RunnerFactory.register_runner("ssh", SSHDistributedRunner) + except ImportError as e: + logging.warning(f"SSH runner not available: {e}") + + try: + from madengine.runners.ansible_runner import AnsibleDistributedRunner + + RunnerFactory.register_runner("ansible", AnsibleDistributedRunner) + except ImportError as e: + logging.warning(f"Ansible runner not available: {e}") + + try: + from madengine.runners.k8s_runner import KubernetesDistributedRunner + + RunnerFactory.register_runner("k8s", KubernetesDistributedRunner) + RunnerFactory.register_runner("kubernetes", KubernetesDistributedRunner) + except ImportError as e: + logging.warning(f"Kubernetes runner not available: {e}") + + try: + from madengine.runners.slurm_runner import SlurmDistributedRunner + + RunnerFactory.register_runner("slurm", SlurmDistributedRunner) + except ImportError as e: + logging.warning(f"SLURM runner not available: {e}") + + +# Auto-register default runners +register_default_runners() diff --git a/src/madengine/runners/k8s_runner.py b/src/madengine/runners/k8s_runner.py new file mode 100644 index 00000000..6ac9ce49 --- /dev/null +++ b/src/madengine/runners/k8s_runner.py @@ -0,0 +1,981 @@ +#!/usr/bin/env python3 +""" +Kubernetes Distributed Runner for MADEngine + +This module implements Kubernetes-based distributed execution using +the kubernetes Python client for orchestrated parallel execution. 
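+
+The runner only applies manifests that were generated beforehand; a typical
+invocation goes through the CLI (illustrative):
+
+    madengine-cli generate k8s ...
+    madengine-cli runner k8s --inventory cluster.yml --manifests-dir k8s-setup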
+""" + +import json +import os +import time +import yaml +from typing import Dict, List, Any, Optional +import contextlib +import signal +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass + +try: + from kubernetes import client, config + from kubernetes.client.rest import ApiException +except ImportError: + raise ImportError( + "Kubernetes runner requires kubernetes. Install with: pip install kubernetes" + ) + +from madengine.runners.base import ( + BaseDistributedRunner, + NodeConfig, + WorkloadSpec, + ExecutionResult, + DistributedResult, +) +from madengine.core.errors import ( + RunnerError, + ConfigurationError, + ConnectionError as MADConnectionError, + create_error_context +) + + +@dataclass +class KubernetesExecutionError(RunnerError): + """Kubernetes execution specific errors.""" + + resource_type: str + resource_name: str + + def __init__(self, message: str, resource_type: str, resource_name: str, **kwargs): + self.resource_type = resource_type + self.resource_name = resource_name + context = create_error_context( + operation="kubernetes_execution", + component="KubernetesRunner", + additional_info={ + "resource_type": resource_type, + "resource_name": resource_name + } + ) + super().__init__( + f"Kubernetes error in {resource_type}/{resource_name}: {message}", + context=context, + **kwargs + ) + + +class KubernetesDistributedRunner(BaseDistributedRunner): + """Distributed runner using Kubernetes with enhanced error handling.""" + + def __init__(self, inventory_path: str, manifests_dir: str, **kwargs): + """Initialize Kubernetes distributed runner. + + The runner only executes pre-generated Kubernetes manifests created by the generate command. + It does not create or modify any Kubernetes resources dynamically. + + Args: + inventory_path: Path to Kubernetes inventory/configuration file + manifests_dir: Directory containing pre-generated Kubernetes manifests + **kwargs: Additional arguments (kubeconfig_path, namespace, etc.) 
+ """ + super().__init__(inventory_path, **kwargs) + self.manifests_dir = manifests_dir + self.kubeconfig_path = kwargs.get("kubeconfig_path") + self.namespace = kwargs.get("namespace", "default") + self.cleanup_handlers: List[callable] = [] + self.created_resources: List[Dict[str, str]] = [] + self.executor: Optional[ThreadPoolExecutor] = None + self.k8s_client = None + self.batch_client = None + self._connection_validated = False + + def _validate_kubernetes_connection(self) -> bool: + """Validate Kubernetes connection and permissions.""" + try: + if self._connection_validated: + return True + + # Test basic connectivity + version = self.k8s_client.get_version() + self.logger.info(f"Connected to Kubernetes cluster version: {version}") + + # Test namespace access + try: + self.k8s_client.read_namespace(name=self.namespace) + except client.exceptions.ApiException as e: + if e.status == 404: + self.logger.error(f"Namespace '{self.namespace}' not found") + return False + elif e.status == 403: + self.logger.error(f"No access to namespace '{self.namespace}'") + return False + raise + + # Test job creation permissions + try: + # Try to list jobs to check permissions + self.batch_client.list_namespaced_job(namespace=self.namespace, limit=1) + except client.exceptions.ApiException as e: + if e.status == 403: + self.logger.error("No permission to create jobs") + return False + raise + + self._connection_validated = True + return True + + except Exception as e: + self.logger.error(f"Kubernetes connection validation failed: {e}") + return False + + def _ensure_namespace_exists(self) -> bool: + """Ensure the target namespace exists.""" + try: + self.k8s_client.read_namespace(name=self.namespace) + return True + except client.exceptions.ApiException as e: + if e.status == 404: + # Try to create namespace + try: + namespace = client.V1Namespace( + metadata=client.V1ObjectMeta(name=self.namespace) + ) + self.k8s_client.create_namespace(body=namespace) + self.logger.info(f"Created namespace: {self.namespace}") + return True + except client.exceptions.ApiException as create_e: + self.logger.error(f"Failed to create namespace: {create_e}") + return False + else: + self.logger.error(f"Namespace access error: {e}") + return False + except Exception as e: + self.logger.error(f"Namespace validation failed: {e}") + return False + + def _init_kubernetes_client(self): + """Initialize Kubernetes client.""" + try: + if self.kubeconfig_path: + config.load_kube_config(config_file=self.kubeconfig_path) + else: + # Try in-cluster config first, fallback to default kubeconfig + try: + config.load_incluster_config() + except config.ConfigException: + config.load_kube_config() + + self.k8s_client = client.CoreV1Api() + self.batch_client = client.BatchV1Api() + + # Test connection + self.k8s_client.get_api_resources() + self.logger.info("Successfully connected to Kubernetes cluster") + + except Exception as e: + self.logger.error(f"Failed to initialize Kubernetes client: {e}") + raise + + def _parse_inventory(self, inventory_data: Dict[str, Any]) -> List[NodeConfig]: + """Parse Kubernetes inventory data. + + For Kubernetes, inventory represents node selectors and resource requirements + rather than individual nodes. 
+ + Args: + inventory_data: Raw inventory data + + Returns: + List of NodeConfig objects (representing logical nodes/pods) + """ + nodes = [] + + # Support Kubernetes-specific inventory format + if "pods" in inventory_data: + for pod_spec in inventory_data["pods"]: + node = NodeConfig( + hostname=pod_spec.get("name", f"pod-{len(nodes)}"), + address=pod_spec.get("node_selector", {}).get( + "kubernetes.io/hostname", "" + ), + gpu_count=pod_spec.get("resources", {}) + .get("requests", {}) + .get("nvidia.com/gpu", 1), + gpu_vendor=pod_spec.get("gpu_vendor", "NVIDIA"), + labels=pod_spec.get("node_selector", {}), + environment=pod_spec.get("environment", {}), + ) + nodes.append(node) + elif "node_selectors" in inventory_data: + # Alternative format with explicit node selectors + for i, selector in enumerate(inventory_data["node_selectors"]): + node = NodeConfig( + hostname=f"pod-{i}", + address="", + gpu_count=selector.get("gpu_count", 1), + gpu_vendor=selector.get("gpu_vendor", "NVIDIA"), + labels=selector.get("labels", {}), + environment=selector.get("environment", {}), + ) + nodes.append(node) + else: + # Fallback to base class parsing + return super()._parse_inventory(inventory_data) + + return nodes + + def _create_namespace(self) -> bool: + """Create namespace if it doesn't exist. + + Returns: + True if namespace exists or was created, False otherwise + """ + try: + self.k8s_client.read_namespace(name=self.namespace) + self.logger.info(f"Namespace '{self.namespace}' already exists") + return True + except ApiException as e: + if e.status == 404: + # Namespace doesn't exist, create it + namespace = client.V1Namespace( + metadata=client.V1ObjectMeta(name=self.namespace) + ) + self.k8s_client.create_namespace(body=namespace) + self.logger.info(f"Created namespace '{self.namespace}'") + return True + else: + self.logger.error(f"Failed to check namespace: {e}") + return False + + def _create_configmap(self, workload: WorkloadSpec) -> bool: + """Create ConfigMap with manifest and configuration. 
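+
+        The ConfigMap bundles the files the job pods mount from /workspace:
+        build_manifest.json, additional_context.json and config.json, plus
+        credential.json, data.json and models.json when present locally.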
+ + Args: + workload: Workload specification + + Returns: + True if ConfigMap created successfully, False otherwise + """ + try: + # Read manifest file + with open(workload.manifest_file, "r") as f: + manifest_content = f.read() + + # Create ConfigMap data + config_data = { + "build_manifest.json": manifest_content, + "additional_context.json": json.dumps(workload.additional_context), + "config.json": json.dumps( + { + "timeout": workload.timeout, + "registry": workload.registry, + "model_tags": workload.model_tags, + } + ), + } + + # Add supporting files if they exist + supporting_files = ["credential.json", "data.json", "models.json"] + for file_name in supporting_files: + if os.path.exists(file_name): + try: + with open(file_name, "r") as f: + config_data[file_name] = f.read() + self.logger.info(f"Added {file_name} to ConfigMap") + except Exception as e: + self.logger.warning(f"Failed to read {file_name}: {e}") + + # Create ConfigMap + configmap = client.V1ConfigMap( + metadata=client.V1ObjectMeta( + name=self.configmap_name, namespace=self.namespace + ), + data=config_data, + ) + + # Delete existing ConfigMap if it exists + try: + self.k8s_client.delete_namespaced_config_map( + name=self.configmap_name, namespace=self.namespace + ) + except ApiException as e: + if e.status != 404: + self.logger.warning(f"Failed to delete existing ConfigMap: {e}") + + # Create new ConfigMap + self.k8s_client.create_namespaced_config_map( + namespace=self.namespace, body=configmap + ) + + self.created_resources.append(("ConfigMap", self.configmap_name)) + self.logger.info(f"Created ConfigMap '{self.configmap_name}'") + return True + + except Exception as e: + self.logger.error(f"Failed to create ConfigMap: {e}") + return False + + def _create_job( + self, node: NodeConfig, model_tag: str, workload: WorkloadSpec + ) -> str: + """Create Kubernetes Job for a specific model on a node. + + Args: + node: Node configuration + model_tag: Model tag to execute + workload: Workload specification + + Returns: + Job name if created successfully, None otherwise + """ + job_name = f"{self.job_name_prefix}-{node.hostname}-{model_tag}".replace( + "_", "-" + ).lower() + + try: + # Create container spec + container = client.V1Container( + name="madengine-runner", + image=self.container_image, + command=["sh", "-c"], + args=[ + f""" + # Setup MAD environment + if [ -d MAD ]; then + cd MAD && git pull origin main + else + git clone https://github.com/ROCm/MAD.git + fi + + cd MAD + python3 -m venv venv || true + source venv/bin/activate + pip install -r requirements.txt + pip install paramiko scp ansible-runner kubernetes PyYAML || true + + # Copy config files from mounted volume + cp /workspace/build_manifest.json . + cp /workspace/credential.json . 2>/dev/null || true + cp /workspace/data.json . 2>/dev/null || true + cp /workspace/models.json . 
2>/dev/null || true + + # Execute madengine from MAD directory + madengine-cli run \\ + --manifest-file build_manifest.json \\ + --timeout {workload.timeout} \\ + --tags {model_tag} \\ + --registry {workload.registry or ''} \\ + --additional-context "$(cat /workspace/additional_context.json 2>/dev/null || echo '{{}}')" # noqa: E501 + """ + ], + volume_mounts=[ + client.V1VolumeMount(name="config-volume", mount_path="/workspace") + ], + env=[ + client.V1EnvVar(name=k, value=v) + for k, v in node.environment.items() + ], + resources=client.V1ResourceRequirements( + requests=( + {"nvidia.com/gpu": str(node.gpu_count)} + if node.gpu_vendor == "NVIDIA" + else ( + {"amd.com/gpu": str(node.gpu_count)} + if node.gpu_vendor == "AMD" + else {} + ) + ) + ), + ) + + # Create pod spec + pod_spec = client.V1PodSpec( + containers=[container], + restart_policy="Never", + volumes=[ + client.V1Volume( + name="config-volume", + config_map=client.V1ConfigMapVolumeSource( + name=self.configmap_name + ), + ) + ], + node_selector=node.labels if node.labels else None, + ) + + # Create job spec + job_spec = client.V1JobSpec( + template=client.V1PodTemplateSpec(spec=pod_spec), + backoff_limit=3, + ttl_seconds_after_finished=300, + ) + + # Create job + job = client.V1Job( + metadata=client.V1ObjectMeta(name=job_name, namespace=self.namespace), + spec=job_spec, + ) + + # Submit job + self.batch_client.create_namespaced_job(namespace=self.namespace, body=job) + + self.created_resources.append(("Job", job_name)) + self.logger.info(f"Created job '{job_name}'") + return job_name + + except Exception as e: + self.logger.error(f"Failed to create job '{job_name}': {e}") + return None + + def _wait_for_jobs( + self, job_names: List[str], timeout: int = 3600 + ) -> Dict[str, Any]: + """Wait for jobs to complete. 
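+
+        The returned mapping has one entry per job name, e.g. (illustrative;
+        the exact job name depends on the configured prefix):
+
+            {"madengine-node1-bert": {"status": "SUCCESS",
+                                      "start_time": ..., "completion_time": ...}}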
+ + Args: + job_names: List of job names to wait for + timeout: Timeout in seconds + + Returns: + Dictionary mapping job names to their results + """ + job_results = {} + start_time = time.time() + + while job_names and (time.time() - start_time) < timeout: + completed_jobs = [] + + for job_name in job_names: + try: + job = self.batch_client.read_namespaced_job( + name=job_name, namespace=self.namespace + ) + + if job.status.completion_time: + # Job completed successfully + job_results[job_name] = { + "status": "SUCCESS", + "completion_time": job.status.completion_time, + "start_time": job.status.start_time, + } + completed_jobs.append(job_name) + elif job.status.failed: + # Job failed + job_results[job_name] = { + "status": "FAILURE", + "failed_pods": job.status.failed, + "start_time": job.status.start_time, + } + completed_jobs.append(job_name) + + except ApiException as e: + self.logger.error(f"Failed to get job status for {job_name}: {e}") + job_results[job_name] = {"status": "FAILURE", "error": str(e)} + completed_jobs.append(job_name) + + # Remove completed jobs from the list + for job_name in completed_jobs: + job_names.remove(job_name) + + if job_names: + time.sleep(10) # Wait 10 seconds before checking again + + # Mark remaining jobs as timed out + for job_name in job_names: + job_results[job_name] = { + "status": "TIMEOUT", + "message": f"Job did not complete within {timeout} seconds", + } + + return job_results + + def _create_configmaps(self, workload: WorkloadSpec) -> bool: + """Create ConfigMaps for workload data with size validation.""" + try: + # Create ConfigMap for additional context + if workload.additional_context: + context_data = workload.additional_context + + # Validate ConfigMap size (1MB limit) + if len(json.dumps(context_data).encode("utf-8")) > 1024 * 1024: + self.logger.error("Additional context too large for ConfigMap") + return False + + configmap_name = f"{self.job_name_prefix}-context" + configmap = client.V1ConfigMap( + metadata=client.V1ObjectMeta( + name=configmap_name, namespace=self.namespace + ), + data={"additional_context.json": json.dumps(context_data)}, + ) + + try: + self.k8s_client.create_namespaced_config_map( + namespace=self.namespace, body=configmap + ) + self.created_resources.append( + { + "type": "configmap", + "name": configmap_name, + "namespace": self.namespace, + } + ) + self.logger.info(f"Created ConfigMap: {configmap_name}") + + except client.exceptions.ApiException as e: + if e.status == 409: # Already exists + self.logger.info(f"ConfigMap {configmap_name} already exists") + else: + self.logger.error(f"Failed to create ConfigMap: {e}") + return False + + # Create ConfigMap for manifest file + if workload.manifest_file and os.path.exists(workload.manifest_file): + with open(workload.manifest_file, "r") as f: + manifest_data = f.read() + + # Validate size + if len(manifest_data.encode("utf-8")) > 1024 * 1024: + self.logger.error("Manifest file too large for ConfigMap") + return False + + configmap_name = f"{self.job_name_prefix}-manifest" + configmap = client.V1ConfigMap( + metadata=client.V1ObjectMeta( + name=configmap_name, namespace=self.namespace + ), + data={"build_manifest.json": manifest_data}, + ) + + try: + self.k8s_client.create_namespaced_config_map( + namespace=self.namespace, body=configmap + ) + self.created_resources.append( + { + "type": "configmap", + "name": configmap_name, + "namespace": self.namespace, + } + ) + self.logger.info(f"Created ConfigMap: {configmap_name}") + + except client.exceptions.ApiException as 
e: + if e.status == 409: # Already exists + self.logger.info(f"ConfigMap {configmap_name} already exists") + else: + self.logger.error(f"Failed to create ConfigMap: {e}") + return False + + return True + + except Exception as e: + self.logger.error(f"ConfigMap creation failed: {e}") + return False + + def execute_workload(self, workload: WorkloadSpec = None) -> DistributedResult: + """Execute workload using pre-generated Kubernetes manifests. + + This method applies pre-generated Kubernetes manifests from the manifests_dir + and monitors the resulting jobs for completion. + + Args: + workload: Legacy parameter, not used in simplified workflow + + Returns: + Distributed execution result + """ + try: + self.logger.info( + "Starting Kubernetes distributed execution using pre-generated manifests" + ) + + # Initialize Kubernetes client + self._init_kubernetes_client() + + # Validate connection and permissions + if not self._validate_kubernetes_connection(): + return DistributedResult( + success=False, + node_results=[], + error_message="Failed to validate Kubernetes connection", + ) + + # Apply manifests + if not self._apply_manifests(): + return DistributedResult( + success=False, + node_results=[], + error_message="Failed to apply Kubernetes manifests", + ) + + # Monitor execution + results = self._monitor_execution() + + distributed_result = DistributedResult( + success=any(r.success for r in results) if results else False, + node_results=results, + ) + + self.logger.info("Kubernetes distributed execution completed") + return distributed_result + + except Exception as e: + self.logger.error(f"Distributed execution failed: {e}") + return DistributedResult( + success=False, node_results=[], error_message=str(e) + ) + + def _apply_manifests(self) -> bool: + """Apply pre-generated Kubernetes manifests from manifests_dir. + + Returns: + True if manifests applied successfully, False otherwise + """ + try: + if not os.path.exists(self.manifests_dir): + self.logger.error( + f"Manifests directory not found: {self.manifests_dir}" + ) + return False + + # Find all YAML manifest files + manifest_files = [] + for root, dirs, files in os.walk(self.manifests_dir): + for file in files: + if file.endswith((".yaml", ".yml")): + manifest_files.append(os.path.join(root, file)) + + if not manifest_files: + self.logger.error( + f"No YAML manifest files found in {self.manifests_dir}" + ) + return False + + self.logger.info(f"Applying {len(manifest_files)} manifest files") + + # Apply each manifest + for manifest_file in manifest_files: + if not self._apply_manifest_file(manifest_file): + return False + + self.logger.info("All manifests applied successfully") + return True + + except Exception as e: + self.logger.error(f"Failed to apply manifests: {e}") + return False + + def _apply_manifest_file(self, manifest_file: str) -> bool: + """Apply a single manifest file. 
+ + Args: + manifest_file: Path to the manifest file + + Returns: + True if applied successfully, False otherwise + """ + try: + with open(manifest_file, "r") as f: + manifest_content = f.read() + + # Parse YAML documents (may contain multiple documents) + for document in yaml.safe_load_all(manifest_content): + if not document: + continue + + self._apply_manifest_object(document) + + self.logger.info(f"Applied manifest: {os.path.basename(manifest_file)}") + return True + + except Exception as e: + self.logger.error(f"Failed to apply manifest {manifest_file}: {e}") + return False + + def _apply_manifest_object(self, manifest: Dict[str, Any]) -> None: + """Apply a single Kubernetes manifest object. + + Args: + manifest: Kubernetes manifest as dictionary + """ + try: + kind = manifest.get("kind", "").lower() + api_version = manifest.get("apiVersion", "") + metadata = manifest.get("metadata", {}) + name = metadata.get("name", "unknown") + + # Track created resources for cleanup + resource_info = { + "kind": kind, + "name": name, + "namespace": metadata.get("namespace", self.namespace), + } + self.created_resources.append(resource_info) + + # Apply based on resource type + if kind == "job": + self.batch_client.create_namespaced_job( + namespace=resource_info["namespace"], body=manifest + ) + elif kind == "configmap": + self.k8s_client.create_namespaced_config_map( + namespace=resource_info["namespace"], body=manifest + ) + elif kind == "namespace": + self.k8s_client.create_namespace(body=manifest) + # Add more resource types as needed + else: + self.logger.warning(f"Unsupported resource type: {kind}") + + self.logger.debug(f"Applied {kind}/{name}") + + except ApiException as e: + if e.status == 409: # Already exists + self.logger.info(f"Resource {kind}/{name} already exists") + else: + raise + except Exception as e: + self.logger.error(f"Failed to apply {kind}/{name}: {e}") + raise + + def _monitor_execution(self) -> List[ExecutionResult]: + """Monitor execution of applied manifests. 
+ + Returns: + List of execution results + """ + try: + results = [] + + # Find all job resources that were created + job_resources = [r for r in self.created_resources if r["kind"] == "job"] + + if not job_resources: + self.logger.warning("No jobs found to monitor") + return results + + self.logger.info(f"Monitoring {len(job_resources)} jobs") + + # Monitor each job + for job_resource in job_resources: + result = self._get_job_result( + job_resource["name"], + job_resource["name"], # Use job name as node_id + "unknown", # Model tag not available in simplified workflow + ) + results.append(result) + + return results + + except Exception as e: + self.logger.error(f"Failed to monitor execution: {e}") + return [] + + def _monitor_jobs(self, workload: WorkloadSpec) -> List[ExecutionResult]: + """Monitor job execution with timeout and error handling.""" + results = [] + + try: + # Get target nodes + target_nodes = self.filter_nodes(workload.node_selector) + + # Monitor jobs with timeout + start_time = time.time() + timeout = workload.timeout + 60 # Add buffer + + while (time.time() - start_time) < timeout: + all_completed = True + + for node in target_nodes: + for model_tag in workload.model_tags: + job_name = f"{self.job_name_prefix}-{node.hostname}-{model_tag}".replace( + "_", "-" + ).lower() + + try: + # Check if result already exists + if any( + r.node_id == node.hostname and r.model_tag == model_tag + for r in results + ): + continue + + # Get job status + job = self.batch_client.read_namespaced_job( + name=job_name, namespace=self.namespace + ) + + if job.status.succeeded: + # Job completed successfully + result = self._get_job_result( + job_name, node.hostname, model_tag + ) + results.append(result) + + elif job.status.failed: + # Job failed + result = ExecutionResult( + node_id=node.hostname, + model_tag=model_tag, + success=False, + error_message="Job failed", + ) + results.append(result) + + else: + # Job still running + all_completed = False + + except client.exceptions.ApiException as e: + if e.status == 404: + # Job not found + result = ExecutionResult( + node_id=node.hostname, + model_tag=model_tag, + success=False, + error_message="Job not found", + ) + results.append(result) + else: + self.logger.error(f"Error checking job {job_name}: {e}") + all_completed = False + + if all_completed: + break + + time.sleep(10) # Check every 10 seconds + + # Handle timeout + if (time.time() - start_time) >= timeout: + self.logger.warning("Job monitoring timed out") + # Add timeout results for missing jobs + for node in target_nodes: + for model_tag in workload.model_tags: + if not any( + r.node_id == node.hostname and r.model_tag == model_tag + for r in results + ): + result = ExecutionResult( + node_id=node.hostname, + model_tag=model_tag, + success=False, + error_message="Job timed out", + ) + results.append(result) + + return results + + except Exception as e: + self.logger.error(f"Job monitoring failed: {e}") + return results + + def _get_job_result( + self, job_name: str, node_id: str, model_tag: str + ) -> ExecutionResult: + """Get result from completed job.""" + try: + # Get pod logs + pods = self.k8s_client.list_namespaced_pod( + namespace=self.namespace, label_selector=f"job-name={job_name}" + ) + + if not pods.items: + return ExecutionResult( + node_id=node_id, + model_tag=model_tag, + success=False, + error_message="No pods found for job", + ) + + pod = pods.items[0] + + # Get pod logs + logs = self.k8s_client.read_namespaced_pod_log( + name=pod.metadata.name, 
namespace=self.namespace + ) + + # Parse result from logs + success = "SUCCESS" in logs + + return ExecutionResult( + node_id=node_id, + model_tag=model_tag, + success=success, + output=logs, + error_message=None if success else "Job failed", + ) + + except Exception as e: + self.logger.error(f"Error getting job result: {e}") + return ExecutionResult( + node_id=node_id, + model_tag=model_tag, + success=False, + error_message=str(e), + ) + + def cleanup_infrastructure(self, workload: WorkloadSpec) -> bool: + """Cleanup infrastructure after execution. + + Args: + workload: Workload specification + + Returns: + True if cleanup successful, False otherwise + """ + try: + self.logger.info("Cleaning up Kubernetes infrastructure") + + # Run custom cleanup handlers + for cleanup_handler in self.cleanup_handlers: + try: + cleanup_handler() + except Exception as e: + self.logger.warning(f"Cleanup handler failed: {e}") + + # Clean up created resources + for resource in self.created_resources: + try: + if resource["type"] == "configmap": + self.k8s_client.delete_namespaced_config_map( + name=resource["name"], namespace=resource["namespace"] + ) + self.logger.info(f"Deleted ConfigMap: {resource['name']}") + elif resource["type"] == "job": + self.batch_client.delete_namespaced_job( + name=resource["name"], namespace=resource["namespace"] + ) + self.logger.info(f"Deleted Job: {resource['name']}") + except Exception as e: + self.logger.warning( + f"Failed to delete resource {resource['name']}: {e}" + ) + + self.created_resources.clear() + + # Shutdown executor + if self.executor: + self.executor.shutdown(wait=True) + self.executor = None + + self.logger.info("Kubernetes infrastructure cleanup completed") + return True + + except Exception as e: + self.logger.error(f"Cleanup failed: {e}") + return False + + def add_cleanup_handler(self, handler: callable): + """Add a cleanup handler to be called during cleanup.""" + self.cleanup_handlers.append(handler) + + def __enter__(self): + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit with cleanup.""" + self.cleanup_infrastructure(None) + + # ...existing methods remain the same... diff --git a/src/madengine/runners/orchestrator_generation.py b/src/madengine/runners/orchestrator_generation.py new file mode 100644 index 00000000..8e496731 --- /dev/null +++ b/src/madengine/runners/orchestrator_generation.py @@ -0,0 +1,781 @@ +"""Orchestrator generation module for MADEngine distributed execution. + +This module provides high-level interfaces for generating distributed +execution configurations using the template system. + +Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +""" + +import os +import json +from typing import Dict, Any, Optional, List +from pathlib import Path + +from .template_generator import TemplateGenerator + + +class OrchestatorGenerator: + """High-level interface for generating distributed execution configurations.""" + + def __init__( + self, template_dir: Optional[str] = None, values_dir: Optional[str] = None + ): + """Initialize the orchestrator generator. 
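+
+        Both paths are forwarded unchanged to the underlying TemplateGenerator.
+
+        Example (paths and file names are illustrative):
+
+            generator = OrchestatorGenerator()
+            files = generator.generate_complete_ansible_setup(
+                "build_manifest.json", environment="default"
+            )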
+ + Args: + template_dir: Custom template directory path + values_dir: Custom values directory path + """ + self.template_generator = TemplateGenerator(template_dir, values_dir) + + def generate_complete_ansible_setup( + self, + manifest_file: str, + environment: str = "default", + output_dir: str = "ansible-setup", + ) -> Dict[str, str]: + """Generate complete Ansible setup including playbook, script, and inventory. + + Args: + manifest_file: Path to build manifest JSON file + environment: Environment name for values + output_dir: Output directory for generated files + + Returns: + dict: Dictionary mapping file types to generated file paths + """ + os.makedirs(output_dir, exist_ok=True) + + generated_files = {} + + # Generate playbook + playbook_file = os.path.join(output_dir, "madengine_playbook.yml") + self.template_generator.generate_ansible_playbook( + manifest_file, environment, playbook_file + ) + generated_files["playbook"] = playbook_file + + # Generate execution script + script_file = os.path.join(output_dir, "execute_models.py") + self.template_generator.generate_execution_script( + manifest_file, environment, script_file + ) + generated_files["script"] = script_file + + # Generate inventory file + inventory_file = os.path.join(output_dir, "inventory.yml") + self._generate_ansible_inventory(manifest_file, environment, inventory_file) + generated_files["inventory"] = inventory_file + + # Generate ansible.cfg + config_file = os.path.join(output_dir, "ansible.cfg") + self._generate_ansible_config(environment, config_file) + generated_files["config"] = config_file + + return generated_files + + def generate_complete_k8s_setup( + self, + manifest_file: str, + environment: str = "default", + output_dir: str = "k8s-setup", + ) -> Dict[str, List[str]]: + """Generate complete Kubernetes setup including manifests and deployment scripts. + + Args: + manifest_file: Path to build manifest JSON file + environment: Environment name for values + output_dir: Output directory for generated files + + Returns: + dict: Dictionary mapping resource types to generated file paths + """ + os.makedirs(output_dir, exist_ok=True) + + # Generate manifests + manifests_dir = os.path.join(output_dir, "manifests") + manifest_files = self.template_generator.generate_kubernetes_manifests( + manifest_file, environment, manifests_dir + ) + + # Generate deployment script + deploy_script = os.path.join(output_dir, "deploy.sh") + self._generate_k8s_deploy_script(environment, manifests_dir, deploy_script) + + # Generate cleanup script + cleanup_script = os.path.join(output_dir, "cleanup.sh") + self._generate_k8s_cleanup_script(environment, manifests_dir, cleanup_script) + + return { + "manifests": manifest_files, + "deploy_script": deploy_script, + "cleanup_script": cleanup_script, + } + + def generate_complete_slurm_setup( + self, + manifest_file: str, + environment: str = "default", + output_dir: str = "slurm-setup", + ) -> Dict[str, str]: + """Generate complete SLURM setup including job scripts and configuration. 
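+
+        Besides the job-array and environment setup scripts, this generates a
+        SLURM inventory, one job script per model tag under job_scripts/, and
+        a Python submission helper (submit_jobs.py).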
+ + Args: + manifest_file: Path to build manifest JSON file + environment: Environment name for values + output_dir: Output directory for generated files + + Returns: + dict: Dictionary mapping file types to generated file paths + """ + os.makedirs(output_dir, exist_ok=True) + + generated_files = {} + + # Generate job array script + job_array_script = os.path.join(output_dir, "madengine_job_array.sh") + self.template_generator.generate_slurm_job_array( + manifest_file, environment, job_array_script + ) + generated_files["job_array"] = job_array_script + + # Generate environment setup script + setup_script = os.path.join(output_dir, "setup_environment.sh") + self.template_generator.generate_slurm_setup_script( + manifest_file, environment, setup_script + ) + generated_files["setup_script"] = setup_script + + # Generate SLURM inventory + inventory_file = os.path.join(output_dir, "inventory.yml") + self.template_generator.generate_slurm_inventory( + manifest_file, environment, inventory_file + ) + generated_files["inventory"] = inventory_file + + # Generate individual job scripts for each model + with open(manifest_file, "r") as f: + manifest_data = json.load(f) + + # Extract model tags + model_tags = [] + if "models" in manifest_data: + model_tags = list(manifest_data["models"].keys()) + elif "built_models" in manifest_data: + model_tags = list(manifest_data["built_models"].keys()) + elif "model_tags" in manifest_data: + model_tags = manifest_data["model_tags"] + + # Create job_scripts subdirectory + job_scripts_dir = os.path.join(output_dir, "job_scripts") + os.makedirs(job_scripts_dir, exist_ok=True) + + # Generate individual job script for each model + individual_jobs = [] + for model_tag in model_tags: + safe_tag = model_tag.replace(":", "-").replace("_", "-") + job_script_file = os.path.join(job_scripts_dir, f"madengine_{safe_tag}.sh") + self.template_generator.generate_slurm_single_job( + manifest_file, model_tag, environment, job_script_file + ) + individual_jobs.append(job_script_file) + + generated_files["individual_jobs"] = individual_jobs + + # Generate job submission helper script + submit_script = os.path.join(output_dir, "submit_jobs.py") + self._generate_slurm_submit_script( + manifest_file, environment, submit_script, output_dir + ) + generated_files["submit_script"] = submit_script + + return generated_files + + def _generate_slurm_submit_script( + self, manifest_file: str, environment: str, output_file: str, setup_dir: str + ): + """Generate Python script for SLURM job submission.""" + submit_script_content = f'''#!/usr/bin/env python3 +""" +SLURM Job Submission Script for MADEngine +Generated from manifest: {os.path.basename(manifest_file)} +Environment: {environment} +""" + +import subprocess +import time +import json +import os +from pathlib import Path + +class SlurmJobSubmitter: + def __init__(self, setup_dir="{setup_dir}"): + self.setup_dir = Path(setup_dir) + self.job_array_script = self.setup_dir / "madengine_job_array.sh" + self.setup_script = self.setup_dir / "setup_environment.sh" + self.inventory_file = self.setup_dir / "inventory.yml" + self.submitted_jobs = [] + + def submit_setup_job(self): + """Submit environment setup job first.""" + if not self.setup_script.exists(): + print(f"Setup script not found: {{self.setup_script}}") + return None + + cmd = ["sbatch", str(self.setup_script)] + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode == 0: + job_id = result.stdout.strip().split()[-1] + print(f"Submitted setup job: 
{{job_id}}") + return job_id + else: + print(f"Failed to submit setup job: {{result.stderr}}") + return None + + def submit_job_array(self, dependency_job_id=None): + """Submit the main job array.""" + if not self.job_array_script.exists(): + print(f"Job array script not found: {{self.job_array_script}}") + return None + + cmd = ["sbatch"] + + # Add dependency if setup job was submitted + if dependency_job_id: + cmd.extend(["--dependency", f"afterok:{{dependency_job_id}}"]) + + cmd.append(str(self.job_array_script)) + + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode == 0: + job_id = result.stdout.strip().split()[-1] + print(f"Submitted job array: {{job_id}}") + self.submitted_jobs.append(job_id) + return job_id + else: + print(f"Failed to submit job array: {{result.stderr}}") + return None + + def monitor_jobs(self, job_ids, check_interval=30): + """Monitor job completion.""" + print(f"Monitoring jobs: {{job_ids}}") + + while job_ids: + time.sleep(check_interval) + + # Check job status + cmd = ["squeue", "--job", ",".join(job_ids), "--noheader", "--format=%i %T"] + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode == 0: + running_jobs = [] + for line in result.stdout.strip().split("\\n"): + if line.strip(): + job_id, status = line.strip().split() + if status in ["PENDING", "RUNNING"]: + running_jobs.append(job_id) + else: + print(f"Job {{job_id}} completed with status: {{status}}") + + job_ids = running_jobs + else: + print("No running jobs found") + break + + print("All jobs completed") + + def run_full_workflow(self): + """Run the complete SLURM workflow.""" + print("Starting MADEngine SLURM execution workflow") + + # Submit setup job first + setup_job_id = self.submit_setup_job() + + if setup_job_id: + print(f"Waiting for setup job {{setup_job_id}} to complete...") + time.sleep(10) # Brief wait before submitting main jobs + + # Submit main job array + main_job_id = self.submit_job_array(setup_job_id) + + if main_job_id: + # Monitor the job array + self.monitor_jobs([main_job_id]) + else: + print("Failed to submit main job array") + +if __name__ == "__main__": + submitter = SlurmJobSubmitter() + submitter.run_full_workflow() +''' + + with open(output_file, "w") as f: + f.write(submit_script_content) + + # Make script executable + os.chmod(output_file, 0o755) + + def generate_execution_pipeline( + self, + manifest_file: str, + environment: str = "default", + output_dir: str = "pipeline", + ) -> Dict[str, str]: + """Generate a complete execution pipeline with monitoring. 
+ + Args: + manifest_file: Path to build manifest JSON file + environment: Environment name for values + output_dir: Output directory for generated files + + Returns: + dict: Dictionary mapping component types to generated file paths + """ + os.makedirs(output_dir, exist_ok=True) + + generated_files = {} + + # Generate main execution script + main_script = os.path.join(output_dir, "run_pipeline.py") + self._generate_pipeline_script(manifest_file, environment, main_script) + generated_files["main_script"] = main_script + + # Generate monitoring script + monitor_script = os.path.join(output_dir, "monitor_execution.py") + self._generate_monitoring_script(manifest_file, environment, monitor_script) + generated_files["monitor_script"] = monitor_script + + # Generate configuration + config_file = os.path.join(output_dir, "pipeline_config.json") + self._generate_pipeline_config(manifest_file, environment, config_file) + generated_files["config"] = config_file + + return generated_files + + def validate_manifest(self, manifest_file: str) -> Dict[str, Any]: + """Validate build manifest for completeness. + + Args: + manifest_file: Path to build manifest JSON file + + Returns: + dict: Validation results + """ + if not os.path.exists(manifest_file): + return { + "valid": False, + "error": f"Manifest file not found: {manifest_file}", + } + + try: + with open(manifest_file, "r") as f: + manifest = json.load(f) + + validation_results = {"valid": True, "warnings": [], "errors": []} + + # Check required fields + required_fields = ["built_images", "context"] + for field in required_fields: + if field not in manifest: + validation_results["errors"].append( + f"Missing required field: {field}" + ) + validation_results["valid"] = False + + # Check for built images + if "built_images" in manifest: + if not manifest["built_images"]: + validation_results["warnings"].append( + "No built images found in manifest" + ) + else: + for image_name, image_info in manifest["built_images"].items(): + if "docker_image" not in image_info: + validation_results["warnings"].append( + f"Image {image_name} missing docker_image field" + ) + + # Check context + if "context" in manifest: + context = manifest["context"] + if "gpu_vendor" not in context: + validation_results["warnings"].append( + "GPU vendor not specified in context" + ) + + return validation_results + + except json.JSONDecodeError as e: + return {"valid": False, "error": f"Invalid JSON in manifest: {e}"} + except Exception as e: + return {"valid": False, "error": f"Error reading manifest: {e}"} + + def _generate_ansible_inventory( + self, manifest_file: str, environment: str, output_file: str + ): + """Generate Ansible inventory file.""" + # Load values to get host configuration + values = self.template_generator.load_values(environment) + + # Load manifest for additional context + with open(manifest_file, "r") as f: + manifest = json.load(f) + + gpu_vendor = manifest.get("context", {}).get("gpu_vendor", "") + + inventory_content = f"""# MADEngine Ansible Inventory +# Generated for environment: {environment} +# GPU Vendor: {gpu_vendor} + +[gpu_nodes] +# Add your GPU nodes here +# gpu-node-1 ansible_host=192.168.1.10 ansible_user=ubuntu +# gpu-node-2 ansible_host=192.168.1.11 ansible_user=ubuntu + +[gpu_nodes:vars] +madengine_environment={environment} +gpu_vendor={gpu_vendor} +madengine_registry={manifest.get('registry', '')} + +[all:vars] +ansible_python_interpreter=/usr/bin/python3 +ansible_ssh_common_args='-o StrictHostKeyChecking=no' +""" + + with open(output_file, 
"w") as f: + f.write(inventory_content) + + def _generate_ansible_config(self, environment: str, output_file: str): + """Generate Ansible configuration file.""" + config_content = f"""# MADEngine Ansible Configuration +# Generated for environment: {environment} + +[defaults] +inventory = inventory.yml +host_key_checking = False +stdout_callback = yaml +stderr_callback = yaml +remote_user = ubuntu +private_key_file = ~/.ssh/id_rsa +timeout = 30 +log_path = ./ansible.log + +[ssh_connection] +ssh_args = -o ForwardAgent=yes -o ControlMaster=auto -o ControlPersist=60s +pipelining = True +""" + + with open(output_file, "w") as f: + f.write(config_content) + + def _generate_k8s_deploy_script( + self, environment: str, manifests_dir: str, output_file: str + ): + """Generate Kubernetes deployment script.""" + script_content = f"""#!/bin/bash +# MADEngine Kubernetes Deployment Script +# Generated for environment: {environment} + +set -e + +MANIFESTS_DIR="{manifests_dir}" +NAMESPACE="madengine-{environment}" + +echo "Deploying MADEngine to Kubernetes..." +echo "Environment: {environment}" +echo "Namespace: $NAMESPACE" + +# Apply manifests in order +if [ -f "$MANIFESTS_DIR/namespace.yaml" ]; then + echo "Creating namespace..." + kubectl apply -f "$MANIFESTS_DIR/namespace.yaml" +fi + +if [ -f "$MANIFESTS_DIR/configmap.yaml" ]; then + echo "Creating configmap..." + kubectl apply -f "$MANIFESTS_DIR/configmap.yaml" +fi + +if [ -f "$MANIFESTS_DIR/service.yaml" ]; then + echo "Creating service..." + kubectl apply -f "$MANIFESTS_DIR/service.yaml" +fi + +if [ -f "$MANIFESTS_DIR/job.yaml" ]; then + echo "Creating job..." + kubectl apply -f "$MANIFESTS_DIR/job.yaml" +fi + +echo "Deployment complete!" +echo "Monitor the job with: kubectl get jobs -n $NAMESPACE" +echo "View logs with: kubectl logs -n $NAMESPACE -l app.kubernetes.io/name=madengine" +""" + + with open(output_file, "w") as f: + f.write(script_content) + + os.chmod(output_file, 0o755) + + def _generate_k8s_cleanup_script( + self, environment: str, manifests_dir: str, output_file: str + ): + """Generate Kubernetes cleanup script.""" + script_content = f"""#!/bin/bash +# MADEngine Kubernetes Cleanup Script +# Generated for environment: {environment} + +set -e + +MANIFESTS_DIR="{manifests_dir}" +NAMESPACE="madengine-{environment}" + +echo "Cleaning up MADEngine from Kubernetes..." +echo "Environment: {environment}" +echo "Namespace: $NAMESPACE" + +# Delete resources +if [ -f "$MANIFESTS_DIR/job.yaml" ]; then + echo "Deleting job..." + kubectl delete -f "$MANIFESTS_DIR/job.yaml" --ignore-not-found=true +fi + +if [ -f "$MANIFESTS_DIR/service.yaml" ]; then + echo "Deleting service..." + kubectl delete -f "$MANIFESTS_DIR/service.yaml" --ignore-not-found=true +fi + +if [ -f "$MANIFESTS_DIR/configmap.yaml" ]; then + echo "Deleting configmap..." + kubectl delete -f "$MANIFESTS_DIR/configmap.yaml" --ignore-not-found=true +fi + +if [ -f "$MANIFESTS_DIR/namespace.yaml" ]; then + echo "Deleting namespace..." + kubectl delete -f "$MANIFESTS_DIR/namespace.yaml" --ignore-not-found=true +fi + +echo "Cleanup complete!" 
+""" + + with open(output_file, "w") as f: + f.write(script_content) + + os.chmod(output_file, 0o755) + + def _generate_pipeline_script( + self, manifest_file: str, environment: str, output_file: str + ): + """Generate pipeline execution script.""" + script_content = f"""#!/usr/bin/env python3 +\"\"\" +MADEngine Execution Pipeline +Generated for environment: {environment} +\"\"\" + +import os +import sys +import json +import time +import subprocess +from datetime import datetime + +def main(): + \"\"\"Main pipeline execution function.\"\"\" + print("=" * 80) + print("MADEngine Execution Pipeline") + print("=" * 80) + print(f"Started: {{datetime.now().isoformat()}}") + print(f"Environment: {environment}") + + # Load configuration + with open('pipeline_config.json', 'r') as f: + config = json.load(f) + + # Execute based on orchestrator type + orchestrator_type = config.get('orchestrator_type', 'ansible') + + if orchestrator_type == 'ansible': + return run_ansible_pipeline(config) + elif orchestrator_type == 'k8s': + return run_k8s_pipeline(config) + else: + print(f"Unknown orchestrator type: {{orchestrator_type}}") + return 1 + +def run_ansible_pipeline(config): + \"\"\"Run Ansible-based pipeline.\"\"\" + print("Running Ansible pipeline...") + + # Run ansible playbook + cmd = [ + 'ansible-playbook', + '-i', 'inventory.yml', + 'madengine_playbook.yml' + ] + + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode == 0: + print("Ansible execution completed successfully") + return 0 + else: + print(f"Ansible execution failed: {{result.stderr}}") + return 1 + +def run_k8s_pipeline(config): + \"\"\"Run Kubernetes-based pipeline.\"\"\" + print("Running Kubernetes pipeline...") + + # Deploy to Kubernetes + result = subprocess.run(['./deploy.sh'], capture_output=True, text=True) + + if result.returncode == 0: + print("Kubernetes deployment completed successfully") + return 0 + else: + print(f"Kubernetes deployment failed: {{result.stderr}}") + return 1 + +if __name__ == '__main__': + sys.exit(main()) +""" + + with open(output_file, "w") as f: + f.write(script_content) + + os.chmod(output_file, 0o755) + + def _generate_monitoring_script( + self, manifest_file: str, environment: str, output_file: str + ): + """Generate monitoring script.""" + script_content = f"""#!/usr/bin/env python3 +\"\"\" +MADEngine Execution Monitoring +Generated for environment: {environment} +\"\"\" + +import os +import sys +import json +import time +import subprocess +from datetime import datetime + +def main(): + \"\"\"Main monitoring function.\"\"\" + print("=" * 80) + print("MADEngine Execution Monitor") + print("=" * 80) + print(f"Started: {{datetime.now().isoformat()}}") + print(f"Environment: {environment}") + + # Load configuration + with open('pipeline_config.json', 'r') as f: + config = json.load(f) + + orchestrator_type = config.get('orchestrator_type', 'ansible') + + if orchestrator_type == 'k8s': + return monitor_k8s_execution(config) + else: + print("Monitoring not implemented for this orchestrator type") + return 0 + +def monitor_k8s_execution(config): + \"\"\"Monitor Kubernetes execution.\"\"\" + namespace = config.get('namespace', 'madengine-{environment}') + + print(f"Monitoring namespace: {{namespace}}") + + while True: + try: + # Check job status + result = subprocess.run([ + 'kubectl', 'get', 'jobs', '-n', namespace, + '-o', 'json' + ], capture_output=True, text=True) + + if result.returncode == 0: + jobs = json.loads(result.stdout) + for job in jobs.get('items', []): + name 
= job['metadata']['name'] + status = job.get('status', {{}}) + + if status.get('succeeded', 0) > 0: + print(f"Job {{name}} completed successfully") + return 0 + elif status.get('failed', 0) > 0: + print(f"Job {{name}} failed") + return 1 + else: + print(f"Job {{name}} still running...") + + time.sleep(30) + + except KeyboardInterrupt: + print("Monitoring interrupted by user") + return 0 + except Exception as e: + print(f"Error monitoring: {{e}}") + return 1 + +if __name__ == '__main__': + sys.exit(main()) +""" + + with open(output_file, "w") as f: + f.write(script_content) + + os.chmod(output_file, 0o755) + + def _generate_pipeline_config( + self, manifest_file: str, environment: str, output_file: str + ): + """Generate pipeline configuration.""" + # Load manifest for context + with open(manifest_file, "r") as f: + manifest = json.load(f) + + config = { + "environment": environment, + "orchestrator_type": "ansible", # Default to ansible + "namespace": f"madengine-{environment}", + "manifest_file": manifest_file, + "registry": manifest.get("registry", ""), + "gpu_vendor": manifest.get("context", {}).get("gpu_vendor", ""), + "monitoring": {"enabled": True, "interval": 30}, + "timeouts": {"execution": 7200, "monitoring": 14400}, + } + + with open(output_file, "w") as f: + json.dump(config, f, indent=2) + + +# Convenience functions for backward compatibility +def generate_ansible_setup( + manifest_file: str, environment: str = "default", output_dir: str = "ansible-setup" +) -> Dict[str, str]: + """Generate complete Ansible setup.""" + generator = OrchestatorGenerator() + return generator.generate_complete_ansible_setup( + manifest_file, environment, output_dir + ) + + +def generate_k8s_setup( + manifest_file: str, environment: str = "default", output_dir: str = "k8s-setup" +) -> Dict[str, List[str]]: + """Generate complete Kubernetes setup.""" + generator = OrchestatorGenerator() + return generator.generate_complete_k8s_setup(manifest_file, environment, output_dir) + + +def generate_slurm_setup( + manifest_file: str, environment: str = "default", output_dir: str = "slurm-setup" +) -> Dict[str, str]: + """Generate complete SLURM setup.""" + generator = OrchestatorGenerator() + return generator.generate_complete_slurm_setup(manifest_file, environment, output_dir) diff --git a/src/madengine/runners/slurm_runner.py b/src/madengine/runners/slurm_runner.py new file mode 100644 index 00000000..f6f73cf1 --- /dev/null +++ b/src/madengine/runners/slurm_runner.py @@ -0,0 +1,751 @@ +#!/usr/bin/env python3 +""" +SLURM Distributed Runner for MADEngine + +This module implements SLURM-based distributed execution using +SLURM workload manager for orchestrated parallel execution across HPC clusters. +""" + +import json +import logging +import os +import subprocess +import time +import yaml +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Optional, Dict, Any, List, Tuple +from dataclasses import dataclass +from pathlib import Path + +try: + import paramiko + from scp import SCPClient +except ImportError: + raise ImportError( + "SLURM runner requires paramiko and scp for SSH connections. 
" + "Install with: pip install paramiko scp" + ) + +from madengine.runners.base import ( + BaseDistributedRunner, + NodeConfig, + WorkloadSpec, + ExecutionResult, + DistributedResult, +) +from madengine.core.errors import ( + ConnectionError as MADConnectionError, + AuthenticationError, + TimeoutError as MADTimeoutError, + RunnerError, + create_error_context +) + + +@dataclass +class SlurmNodeConfig(NodeConfig): + """SLURM-specific node configuration.""" + partition: str = "gpu" + qos: Optional[str] = None + account: Optional[str] = None + constraint: Optional[str] = None + exclusive: bool = False + mem_per_gpu: Optional[str] = None + max_time: str = "24:00:00" + + +@dataclass +class SlurmExecutionError(RunnerError): + """SLURM execution specific errors.""" + + job_id: str + + def __init__(self, message: str, job_id: str, **kwargs): + self.job_id = job_id + context = create_error_context( + operation="slurm_execution", + component="SlurmRunner", + additional_info={"job_id": job_id} + ) + super().__init__(f"SLURM job {job_id}: {message}", context=context, **kwargs) + + +class SlurmConnection: + """Manages SSH connection to SLURM login node.""" + + def __init__(self, login_node: Dict[str, Any], timeout: int = 30): + """Initialize SSH connection to SLURM login node. + + Args: + login_node: Login node configuration + timeout: Connection timeout in seconds + """ + self.login_node = login_node + self.timeout = timeout + self.ssh_client = None + self.sftp_client = None + self.logger = logging.getLogger(f"SlurmConnection.{login_node['hostname']}") + self._connected = False + + def connect(self) -> bool: + """Establish SSH connection to SLURM login node. + + Returns: + True if connection successful, False otherwise + """ + try: + self.ssh_client = paramiko.SSHClient() + self.ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + # Connection parameters + connect_params = { + "hostname": self.login_node["address"], + "port": self.login_node.get("port", 22), + "username": self.login_node["username"], + "timeout": self.timeout, + } + + # Use SSH key if provided + if self.login_node.get("ssh_key_path"): + expanded_key_path = os.path.expanduser(self.login_node["ssh_key_path"]) + if os.path.exists(expanded_key_path): + connect_params["key_filename"] = expanded_key_path + os.chmod(expanded_key_path, 0o600) + + self.ssh_client.connect(**connect_params) + self.sftp_client = self.ssh_client.open_sftp() + + self._connected = True + self.logger.info(f"Successfully connected to SLURM login node {self.login_node['hostname']}") + return True + + except Exception as e: + self.logger.error(f"Failed to connect to SLURM login node: {e}") + return False + + def is_connected(self) -> bool: + """Check if connection is active.""" + return ( + self._connected + and self.ssh_client + and self.ssh_client.get_transport() + and self.ssh_client.get_transport().is_active() + ) + + def execute_command(self, command: str, timeout: int = 300) -> Tuple[int, str, str]: + """Execute command on SLURM login node. 
+ + Args: + command: Command to execute + timeout: Command timeout in seconds + + Returns: + Tuple of (exit_code, stdout, stderr) + """ + if not self.is_connected(): + raise MADConnectionError("Connection not established") + + try: + stdin, stdout, stderr = self.ssh_client.exec_command(command, timeout=timeout) + exit_code = stdout.channel.recv_exit_status() + stdout_str = stdout.read().decode("utf-8", errors="replace") + stderr_str = stderr.read().decode("utf-8", errors="replace") + + return exit_code, stdout_str, stderr_str + + except Exception as e: + self.logger.error(f"Command execution failed: {e}") + return 1, "", str(e) + + def copy_file(self, local_path: str, remote_path: str, create_dirs: bool = True) -> bool: + """Copy file to SLURM login node. + + Args: + local_path: Local file path + remote_path: Remote file path + create_dirs: Whether to create remote directories + + Returns: + True if copy successful, False otherwise + """ + if not self.is_connected(): + raise MADConnectionError("Connection not established") + + try: + if not os.path.exists(local_path): + raise FileNotFoundError(f"Local file not found: {local_path}") + + # Create directory if needed + if create_dirs: + remote_dir = os.path.dirname(remote_path) + if remote_dir: + self.execute_command(f"mkdir -p {remote_dir}") + + # Copy file + self.sftp_client.put(local_path, remote_path) + self.sftp_client.chmod(remote_path, 0o644) + + self.logger.debug(f"Successfully copied {local_path} to {remote_path}") + return True + + except Exception as e: + self.logger.error(f"File copy failed: {e}") + return False + + def close(self): + """Close SSH connection.""" + try: + if self.sftp_client: + self.sftp_client.close() + self.sftp_client = None + if self.ssh_client: + self.ssh_client.close() + self.ssh_client = None + self._connected = False + self.logger.debug(f"Closed connection to {self.login_node['hostname']}") + except Exception as e: + self.logger.warning(f"Error closing connection: {e}") + + def __enter__(self): + """Context manager entry.""" + if not self.connect(): + raise MADConnectionError("Failed to establish SLURM connection") + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit.""" + self.close() + + +class SlurmDistributedRunner(BaseDistributedRunner): + """Distributed runner using SLURM workload manager.""" + + def __init__(self, inventory_path: str, job_scripts_dir: str = None, **kwargs): + """Initialize SLURM distributed runner. 
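+
+        The inventory must contain a slurm_cluster section, and the job
+        scripts are expected to have been generated beforehand (for example
+        by the orchestrator generation module).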
+ + Args: + inventory_path: Path to SLURM inventory configuration file + job_scripts_dir: Directory containing pre-generated job scripts + **kwargs: Additional arguments passed to base class + """ + super().__init__(inventory_path, **kwargs) + self.job_scripts_dir = Path(job_scripts_dir) if job_scripts_dir else None + self.slurm_connection: Optional[SlurmConnection] = None + self.submitted_jobs: List[str] = [] + self.cleanup_handlers: List[callable] = [] + + # Load SLURM-specific configuration + self.slurm_config = self._load_slurm_config() + + def _load_slurm_config(self) -> Dict[str, Any]: + """Load SLURM-specific configuration from inventory.""" + if not os.path.exists(self.inventory_path): + raise FileNotFoundError(f"Inventory file not found: {self.inventory_path}") + + with open(self.inventory_path, "r") as f: + if self.inventory_path.endswith(".json"): + inventory_data = json.load(f) + else: + inventory_data = yaml.safe_load(f) + + if "slurm_cluster" not in inventory_data: + raise ValueError("Invalid SLURM inventory: missing 'slurm_cluster' section") + + return inventory_data["slurm_cluster"] + + def _parse_inventory(self, inventory_data: Dict[str, Any]) -> List[NodeConfig]: + """Parse SLURM inventory data into NodeConfig objects. + + For SLURM, nodes represent logical execution units (partitions/resources) + rather than individual physical nodes. + + Args: + inventory_data: Raw inventory data + + Returns: + List of NodeConfig objects representing SLURM partitions + """ + nodes = [] + + if "slurm_cluster" in inventory_data: + slurm_config = inventory_data["slurm_cluster"] + + # Create logical nodes from partitions + for partition in slurm_config.get("partitions", []): + node = SlurmNodeConfig( + hostname=partition["name"], + address="slurm-partition", # Logical address + partition=partition["name"], + gpu_count=partition.get("default_gpu_count", 1), + gpu_vendor=partition.get("gpu_vendor", "AMD"), + labels={"partition": partition["name"]}, + qos=partition.get("qos"), + account=partition.get("account"), + max_time=partition.get("max_time", "24:00:00"), + ) + nodes.append(node) + + if not nodes: + raise ValueError("No SLURM partitions found in inventory") + + return nodes + + def setup_infrastructure(self, workload: WorkloadSpec) -> bool: + """Setup SLURM infrastructure for distributed execution. 
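+
+        Setup validates the pre-generated job scripts, connects to the SLURM
+        login node over SSH, verifies cluster access, and copies the job
+        scripts to the cluster's shared workspace.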
+ + Args: + workload: Workload specification + + Returns: + True if setup successful, False otherwise + """ + try: + self.logger.info("Setting up SLURM infrastructure for distributed execution") + + # Validate pre-generated job scripts exist + if not self._validate_job_scripts(): + self.logger.error("Pre-generated job scripts not found") + return False + + # Establish connection to SLURM login node + login_node = self.slurm_config["login_node"] + self.slurm_connection = SlurmConnection(login_node) + + if not self.slurm_connection.connect(): + self.logger.error("Failed to connect to SLURM login node") + return False + + # Validate SLURM cluster access + if not self._validate_slurm_access(): + self.logger.error("SLURM cluster validation failed") + return False + + # Copy job scripts to SLURM login node + if not self._copy_job_scripts(): + self.logger.error("Failed to copy job scripts to SLURM cluster") + return False + + self.logger.info("SLURM infrastructure setup completed successfully") + return True + + except Exception as e: + self.logger.error(f"SLURM infrastructure setup failed: {e}") + return False + + def _validate_job_scripts(self) -> bool: + """Validate that pre-generated job scripts exist.""" + if not self.job_scripts_dir or not self.job_scripts_dir.exists(): + self.logger.error(f"Job scripts directory not found: {self.job_scripts_dir}") + return False + + # Check for job array script + job_array_script = self.job_scripts_dir / "madengine_job_array.sh" + if not job_array_script.exists(): + self.logger.error(f"Job array script not found: {job_array_script}") + return False + + # Check for setup script + setup_script = self.job_scripts_dir / "setup_environment.sh" + if not setup_script.exists(): + self.logger.error(f"Setup script not found: {setup_script}") + return False + + return True + + def _validate_slurm_access(self) -> bool: + """Validate SLURM cluster access and permissions.""" + try: + # Test basic SLURM commands + exit_code, stdout, stderr = self.slurm_connection.execute_command("sinfo --version") + if exit_code != 0: + self.logger.error(f"SLURM not available: {stderr}") + return False + + # Check available partitions + exit_code, stdout, stderr = self.slurm_connection.execute_command("sinfo -h -o '%P'") + if exit_code != 0: + self.logger.error(f"Failed to query SLURM partitions: {stderr}") + return False + + available_partitions = [p.strip('*') for p in stdout.strip().split('\n') if p.strip()] + self.logger.info(f"Available SLURM partitions: {available_partitions}") + + return True + + except Exception as e: + self.logger.error(f"SLURM access validation failed: {e}") + return False + + def _copy_job_scripts(self) -> bool: + """Copy job scripts to SLURM login node.""" + try: + workspace_path = self.slurm_config.get("workspace", {}).get("shared_filesystem", "/shared/madengine") + scripts_dir = f"{workspace_path}/job_scripts" + + # Create remote scripts directory + self.slurm_connection.execute_command(f"mkdir -p {scripts_dir}") + + # Copy all job scripts + for script_file in self.job_scripts_dir.glob("*.sh"): + remote_path = f"{scripts_dir}/{script_file.name}" + if not self.slurm_connection.copy_file(str(script_file), remote_path): + return False + # Make scripts executable + self.slurm_connection.execute_command(f"chmod +x {remote_path}") + + # Copy Python submission script if exists + submit_script = self.job_scripts_dir / "submit_jobs.py" + if submit_script.exists(): + remote_path = f"{workspace_path}/submit_jobs.py" + if not 
self.slurm_connection.copy_file(str(submit_script), remote_path): + return False + self.slurm_connection.execute_command(f"chmod +x {remote_path}") + + self.logger.info("Successfully copied job scripts to SLURM cluster") + return True + + except Exception as e: + self.logger.error(f"Failed to copy job scripts: {e}") + return False + + def execute_workload(self, workload: WorkloadSpec) -> DistributedResult: + """Execute workload using pre-generated SLURM job scripts. + + Args: + workload: Workload specification (minimal, most config is in scripts) + + Returns: + Distributed execution result + """ + try: + self.logger.info("Starting SLURM distributed execution using pre-generated job scripts") + + # Validate job scripts exist + if not self._validate_job_scripts(): + return DistributedResult( + total_nodes=0, + successful_executions=0, + failed_executions=1, + total_duration=0.0, + node_results=[], + ) + + # Submit environment setup job first + setup_job_id = self._submit_setup_job() + if setup_job_id: + self.logger.info(f"Submitted setup job: {setup_job_id}") + self.submitted_jobs.append(setup_job_id) + + # Submit main job array with dependency on setup job + main_job_id = self._submit_job_array(setup_job_id) + if not main_job_id: + return DistributedResult( + total_nodes=0, + successful_executions=0, + failed_executions=1, + total_duration=0.0, + node_results=[], + ) + + self.logger.info(f"Submitted main job array: {main_job_id}") + self.submitted_jobs.append(main_job_id) + + # Monitor job execution + results = self._monitor_job_execution([main_job_id], workload.timeout) + + # Create distributed result + distributed_result = DistributedResult( + total_nodes=len(results), + successful_executions=sum(1 for r in results if r.status == "SUCCESS"), + failed_executions=sum(1 for r in results if r.status != "SUCCESS"), + total_duration=max([r.duration for r in results], default=0.0), + node_results=results, + ) + + self.logger.info("SLURM distributed execution completed") + return distributed_result + + except Exception as e: + self.logger.error(f"SLURM distributed execution failed: {e}") + return DistributedResult( + total_nodes=0, + successful_executions=0, + failed_executions=1, + total_duration=0.0, + node_results=[], + ) + + def _submit_setup_job(self) -> Optional[str]: + """Submit environment setup job.""" + try: + workspace_path = self.slurm_config.get("workspace", {}).get("shared_filesystem", "/shared/madengine") + setup_script = f"{workspace_path}/job_scripts/setup_environment.sh" + + # Submit setup job + cmd = f"sbatch {setup_script}" + exit_code, stdout, stderr = self.slurm_connection.execute_command(cmd) + + if exit_code == 0: + # Extract job ID from sbatch output + job_id = stdout.strip().split()[-1] + return job_id + else: + self.logger.error(f"Failed to submit setup job: {stderr}") + return None + + except Exception as e: + self.logger.error(f"Setup job submission failed: {e}") + return None + + def _submit_job_array(self, dependency_job_id: Optional[str] = None) -> Optional[str]: + """Submit main job array.""" + try: + workspace_path = self.slurm_config.get("workspace", {}).get("shared_filesystem", "/shared/madengine") + job_array_script = f"{workspace_path}/job_scripts/madengine_job_array.sh" + + # Build sbatch command + cmd = "sbatch" + if dependency_job_id: + cmd += f" --dependency=afterok:{dependency_job_id}" + cmd += f" {job_array_script}" + + # Submit job array + exit_code, stdout, stderr = self.slurm_connection.execute_command(cmd) + + if exit_code == 0: + # Extract job ID 
from sbatch output + job_id = stdout.strip().split()[-1] + return job_id + else: + self.logger.error(f"Failed to submit job array: {stderr}") + return None + + except Exception as e: + self.logger.error(f"Job array submission failed: {e}") + return None + + def _monitor_job_execution(self, job_ids: List[str], timeout: int) -> List[ExecutionResult]: + """Monitor SLURM job execution until completion.""" + results = [] + start_time = time.time() + + self.logger.info(f"Monitoring SLURM jobs: {job_ids}") + + while job_ids and (time.time() - start_time) < timeout: + completed_jobs = [] + + for job_id in job_ids: + try: + # Check job status + status = self._get_job_status(job_id) + + if status in ["COMPLETED", "FAILED", "CANCELLED", "TIMEOUT", "NODE_FAIL"]: + # Job completed, collect results + job_results = self._collect_job_results(job_id, status) + results.extend(job_results) + completed_jobs.append(job_id) + + self.logger.info(f"Job {job_id} completed with status: {status}") + + except Exception as e: + self.logger.error(f"Error checking job {job_id}: {e}") + # Create failed result + result = ExecutionResult( + node_id=job_id, + model_tag="unknown", + status="FAILURE", + duration=time.time() - start_time, + error_message=str(e), + ) + results.append(result) + completed_jobs.append(job_id) + + # Remove completed jobs + for job_id in completed_jobs: + job_ids.remove(job_id) + + if job_ids: + time.sleep(30) # Check every 30 seconds + + # Handle timeout for remaining jobs + for job_id in job_ids: + result = ExecutionResult( + node_id=job_id, + model_tag="timeout", + status="TIMEOUT", + duration=timeout, + error_message=f"Job monitoring timed out after {timeout} seconds", + ) + results.append(result) + + return results + + def _get_job_status(self, job_id: str) -> str: + """Get SLURM job status.""" + try: + cmd = f"squeue -j {job_id} -h -o '%T'" + exit_code, stdout, stderr = self.slurm_connection.execute_command(cmd) + + if exit_code == 0 and stdout.strip(): + return stdout.strip() + else: + # Job not in queue, check if completed + cmd = f"sacct -j {job_id} -n -o 'State' | head -1" + exit_code, stdout, stderr = self.slurm_connection.execute_command(cmd) + + if exit_code == 0 and stdout.strip(): + return stdout.strip() + else: + return "UNKNOWN" + + except Exception as e: + self.logger.error(f"Failed to get job status for {job_id}: {e}") + return "ERROR" + + def _collect_job_results(self, job_id: str, status: str) -> List[ExecutionResult]: + """Collect results from completed SLURM job.""" + results = [] + + try: + # For job arrays, get results for each array task + if "_" in job_id: # Job array format: jobid_arrayindex + # This is a single array task + result = self._get_single_job_result(job_id, status) + results.append(result) + else: + # This is a job array, get results for all tasks + cmd = f"sacct -j {job_id} -n -o 'JobID,State,ExitCode' | grep '{job_id}_'" + exit_code, stdout, stderr = self.slurm_connection.execute_command(cmd) + + if exit_code == 0: + for line in stdout.strip().split('\n'): + if line.strip(): + parts = line.strip().split() + array_job_id = parts[0] + array_status = parts[1] + + result = self._get_single_job_result(array_job_id, array_status) + results.append(result) + else: + # Fallback: create single result + result = self._get_single_job_result(job_id, status) + results.append(result) + + except Exception as e: + self.logger.error(f"Failed to collect results for job {job_id}: {e}") + result = ExecutionResult( + node_id=job_id, + model_tag="error", + status="FAILURE", + 
duration=0.0, + error_message=str(e), + ) + results.append(result) + + return results + + def _get_single_job_result(self, job_id: str, status: str) -> ExecutionResult: + """Get result for a single SLURM job.""" + try: + # Get job details + cmd = f"sacct -j {job_id} -n -o 'JobName,State,ExitCode,Elapsed,NodeList'" + exit_code, stdout, stderr = self.slurm_connection.execute_command(cmd) + + job_name = "unknown" + elapsed_time = 0.0 + node_list = "unknown" + exit_code_val = "0:0" + + if exit_code == 0 and stdout.strip(): + parts = stdout.strip().split() + if len(parts) >= 5: + job_name = parts[0] + exit_code_val = parts[2] + elapsed_str = parts[3] + node_list = parts[4] + + # Parse elapsed time (format: HH:MM:SS or MM:SS) + time_parts = elapsed_str.split(':') + if len(time_parts) == 3: + elapsed_time = int(time_parts[0]) * 3600 + int(time_parts[1]) * 60 + int(time_parts[2]) + elif len(time_parts) == 2: + elapsed_time = int(time_parts[0]) * 60 + int(time_parts[1]) + + # Extract model tag from job name + model_tag = job_name.replace("madengine-", "").replace("-", "_") + if not model_tag or model_tag == "unknown": + model_tag = f"task_{job_id.split('_')[-1] if '_' in job_id else '0'}" + + # Determine success based on SLURM status and exit code + success = status == "COMPLETED" and exit_code_val.startswith("0:") + + return ExecutionResult( + node_id=node_list, + model_tag=model_tag, + status="SUCCESS" if success else "FAILURE", + duration=elapsed_time, + performance_metrics={"slurm_job_id": job_id, "slurm_status": status}, + error_message=None if success else f"SLURM status: {status}, Exit code: {exit_code_val}", + ) + + except Exception as e: + self.logger.error(f"Failed to get job result for {job_id}: {e}") + return ExecutionResult( + node_id=job_id, + model_tag="error", + status="FAILURE", + duration=0.0, + error_message=str(e), + ) + + def cleanup_infrastructure(self, workload: WorkloadSpec) -> bool: + """Cleanup SLURM infrastructure after execution. 
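+
+        Any jobs submitted during this run are cancelled with scancel,
+        registered cleanup handlers are invoked, and the SSH connection to
+        the login node is closed.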
+ + Args: + workload: Workload specification + + Returns: + True if cleanup successful, False otherwise + """ + try: + self.logger.info("Cleaning up SLURM infrastructure") + + # Cancel any remaining/running jobs + for job_id in self.submitted_jobs: + try: + cmd = f"scancel {job_id}" + self.slurm_connection.execute_command(cmd) + self.logger.info(f"Cancelled SLURM job: {job_id}") + except Exception as e: + self.logger.warning(f"Failed to cancel job {job_id}: {e}") + + # Run custom cleanup handlers + for cleanup_handler in self.cleanup_handlers: + try: + cleanup_handler() + except Exception as e: + self.logger.warning(f"Cleanup handler failed: {e}") + + # Close SLURM connection + if self.slurm_connection: + self.slurm_connection.close() + self.slurm_connection = None + + self.logger.info("SLURM infrastructure cleanup completed") + return True + + except Exception as e: + self.logger.error(f"SLURM cleanup failed: {e}") + return False + + def add_cleanup_handler(self, handler: callable): + """Add a cleanup handler to be called during cleanup.""" + self.cleanup_handlers.append(handler) + + def __enter__(self): + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit with cleanup.""" + self.cleanup_infrastructure(None) \ No newline at end of file diff --git a/src/madengine/runners/ssh_runner.py b/src/madengine/runners/ssh_runner.py new file mode 100644 index 00000000..6abcd448 --- /dev/null +++ b/src/madengine/runners/ssh_runner.py @@ -0,0 +1,935 @@ +#!/usr/bin/env python3 +""" +SSH Distributed Runner for MADEngine + +This module implements SSH-based distributed execution using paramiko +for secure remote execution across multiple nodes. +""" + +import json +import logging +import os +import time +import contextlib +import signal +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Optional, Dict, Any, List, Tuple +from dataclasses import dataclass + +try: + import paramiko + from scp import SCPClient +except ImportError: + raise ImportError( + "SSH runner requires paramiko and scp. 
Install with: pip install paramiko scp" + ) + +from madengine.runners.base import ( + BaseDistributedRunner, + NodeConfig, + WorkloadSpec, + ExecutionResult, + DistributedResult, +) +from madengine.core.errors import ( + ConnectionError as MADConnectionError, + AuthenticationError, + TimeoutError as MADTimeoutError, + RunnerError, + create_error_context +) + + +# Legacy error classes - use unified error system instead +# Kept for backward compatibility but deprecated + +@dataclass +class SSHConnectionError(MADConnectionError): + """Deprecated: Use MADConnectionError instead.""" + + hostname: str + error_type: str + message: str + + def __init__(self, hostname: str, error_type: str, message: str): + self.hostname = hostname + self.error_type = error_type + self.message = message + context = create_error_context( + operation="ssh_connection", + component="SSHRunner", + node_id=hostname, + additional_info={"error_type": error_type} + ) + super().__init__(f"SSH {error_type} error on {hostname}: {message}", context=context) + + +class TimeoutError(MADTimeoutError): + """Deprecated: Use MADTimeoutError instead.""" + + def __init__(self, message: str, **kwargs): + context = create_error_context(operation="ssh_execution", component="SSHRunner") + super().__init__(message, context=context, **kwargs) + + +@contextlib.contextmanager +def timeout_context(seconds: int): + """Context manager for handling timeouts.""" + + def signal_handler(signum, frame): + raise TimeoutError(f"Operation timed out after {seconds} seconds") + + old_handler = signal.signal(signal.SIGALRM, signal_handler) + signal.alarm(seconds) + try: + yield + finally: + signal.alarm(0) + signal.signal(signal.SIGALRM, old_handler) + + +class SSHConnection: + """Manages SSH connection to a single node with enhanced error handling.""" + + def __init__(self, node: NodeConfig, timeout: int = 30): + """Initialize SSH connection. + + Args: + node: Node configuration + timeout: Connection timeout in seconds + """ + self.node = node + self.timeout = timeout + self.ssh_client = None + self.sftp_client = None + self.logger = logging.getLogger(f"SSHConnection.{node.hostname}") + self._connected = False + self._connection_attempts = 0 + self._max_connection_attempts = 3 + + def connect(self) -> bool: + """Establish SSH connection to node with retry logic. 
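+
+        Up to three attempts are made with exponential backoff between
+        attempts; authentication failures are not retried and raise
+        SSHConnectionError immediately.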
+ + Returns: + True if connection successful, False otherwise + """ + for attempt in range(self._max_connection_attempts): + try: + self._connection_attempts = attempt + 1 + self.ssh_client = paramiko.SSHClient() + self.ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + # Connection parameters + connect_params = { + "hostname": self.node.address, + "port": self.node.port, + "username": self.node.username, + "timeout": self.timeout, + } + + # Use SSH key if provided - expand path + if self.node.ssh_key_path: + expanded_key_path = os.path.expanduser(self.node.ssh_key_path) + if os.path.exists(expanded_key_path): + connect_params["key_filename"] = expanded_key_path + # Ensure proper permissions + os.chmod(expanded_key_path, 0o600) + else: + self.logger.warning( + f"SSH key file not found: {expanded_key_path}" + ) + + # Test connection with timeout + with timeout_context(self.timeout): + self.ssh_client.connect(**connect_params) + self.sftp_client = self.ssh_client.open_sftp() + + self._connected = True + self.logger.info(f"Successfully connected to {self.node.hostname}") + return True + + except TimeoutError: + self.logger.warning(f"Connection attempt {attempt + 1} timed out") + if attempt < self._max_connection_attempts - 1: + time.sleep(2**attempt) # Exponential backoff + continue + + except paramiko.AuthenticationException as e: + raise SSHConnectionError( + self.node.hostname, "authentication", f"Authentication failed: {e}" + ) + + except paramiko.SSHException as e: + self.logger.warning(f"SSH error on attempt {attempt + 1}: {e}") + if attempt < self._max_connection_attempts - 1: + time.sleep(2**attempt) # Exponential backoff + continue + + except Exception as e: + self.logger.error(f"Unexpected error on attempt {attempt + 1}: {e}") + if attempt < self._max_connection_attempts - 1: + time.sleep(2**attempt) # Exponential backoff + continue + + self.logger.error( + f"Failed to connect to {self.node.hostname} after {self._max_connection_attempts} attempts" + ) + return False + + def is_connected(self) -> bool: + """Check if connection is active.""" + return ( + self._connected + and self.ssh_client + and self.ssh_client.get_transport().is_active() + ) + + def close(self): + """Close SSH connection safely.""" + try: + if self.sftp_client: + self.sftp_client.close() + self.sftp_client = None + if self.ssh_client: + self.ssh_client.close() + self.ssh_client = None + self._connected = False + self.logger.debug(f"Closed connection to {self.node.hostname}") + except Exception as e: + self.logger.warning(f"Error closing connection: {e}") + + def __enter__(self): + """Context manager entry.""" + if not self.connect(): + raise SSHConnectionError( + self.node.hostname, "connection", "Failed to establish connection" + ) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit.""" + self.close() + + def execute_command(self, command: str, timeout: int = 300) -> tuple: + """Execute command on remote node with enhanced error handling. 
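+
+        The command runs under a SIGALRM-based timeout; on timeout an
+        SSHConnectionError is raised, while other execution errors are
+        returned as exit code 1 with the error text in stderr.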
+ + Args: + command: Command to execute + timeout: Command timeout in seconds + + Returns: + Tuple of (exit_code, stdout, stderr) + """ + if not self.is_connected(): + raise SSHConnectionError( + self.node.hostname, "connection", "Connection not established" + ) + + try: + with timeout_context(timeout): + stdin, stdout, stderr = self.ssh_client.exec_command( + command, timeout=timeout + ) + + # Wait for command completion + exit_code = stdout.channel.recv_exit_status() + + stdout_str = stdout.read().decode("utf-8", errors="replace") + stderr_str = stderr.read().decode("utf-8", errors="replace") + + return exit_code, stdout_str, stderr_str + + except TimeoutError: + raise SSHConnectionError( + self.node.hostname, + "timeout", + f"Command timed out after {timeout} seconds: {command}", + ) + except Exception as e: + self.logger.error(f"Command execution failed: {e}") + return 1, "", str(e) + + def copy_file( + self, local_path: str, remote_path: str, create_dirs: bool = True + ) -> bool: + """Copy file to remote node with enhanced error handling. + + Args: + local_path: Local file path + remote_path: Remote file path + create_dirs: Whether to create remote directories + + Returns: + True if copy successful, False otherwise + """ + if not self.is_connected(): + raise SSHConnectionError( + self.node.hostname, "connection", "Connection not established" + ) + + try: + # Validate local file exists + if not os.path.exists(local_path): + raise FileNotFoundError(f"Local file not found: {local_path}") + + # Create directory if needed + if create_dirs: + remote_dir = os.path.dirname(remote_path) + if remote_dir: + self.execute_command(f"mkdir -p {remote_dir}") + + # Copy file + self.sftp_client.put(local_path, remote_path) + + # Set proper permissions + self.sftp_client.chmod(remote_path, 0o644) + + self.logger.debug(f"Successfully copied {local_path} to {remote_path}") + return True + + except Exception as e: + self.logger.error(f"File copy failed: {e}") + return False + + def copy_directory(self, local_path: str, remote_path: str) -> bool: + """Copy directory to remote node with enhanced error handling. + + Args: + local_path: Local directory path + remote_path: Remote directory path + + Returns: + True if copy successful, False otherwise + """ + if not self.is_connected(): + raise SSHConnectionError( + self.node.hostname, "connection", "Connection not established" + ) + + try: + # Validate local directory exists + if not os.path.exists(local_path): + raise FileNotFoundError(f"Local directory not found: {local_path}") + + # Use SCP for directory transfer + with SCPClient(self.ssh_client.get_transport()) as scp: + scp.put(local_path, remote_path, recursive=True) + + self.logger.debug( + f"Successfully copied directory {local_path} to {remote_path}" + ) + return True + + except Exception as e: + self.logger.error(f"Directory copy failed: {e}") + return False + + +class SSHDistributedRunner(BaseDistributedRunner): + """Distributed runner using SSH connections with enhanced error handling.""" + + def __init__(self, inventory_path: str, **kwargs): + """Initialize SSH distributed runner. 
+ + Args: + inventory_path: Path to inventory configuration file + **kwargs: Additional arguments passed to base class + """ + super().__init__(inventory_path, **kwargs) + self.connections: Dict[str, SSHConnection] = {} + self.connection_pool: Optional[ThreadPoolExecutor] = None + self.cleanup_handlers: List[callable] = [] + + def _create_connection(self, node: NodeConfig) -> Optional[SSHConnection]: + """Create SSH connection to node with proper error handling. + + Args: + node: Node configuration + + Returns: + SSH connection instance or None if failed + """ + try: + connection = SSHConnection(node, timeout=30) + if connection.connect(): + self.connections[node.hostname] = connection + return connection + return None + except SSHConnectionError as e: + self.logger.error(f"SSH connection error: {e}") + return None + except Exception as e: + self.logger.error( + f"Unexpected error creating connection to {node.hostname}: {e}" + ) + return None + + def setup_infrastructure(self, workload: WorkloadSpec) -> bool: + """Setup SSH infrastructure for distributed execution with enhanced error handling. + + Args: + workload: Workload specification + + Returns: + True if setup successful, False otherwise + """ + try: + self.logger.info("Setting up SSH infrastructure for distributed execution") + + # Filter nodes based on workload requirements + target_nodes = self.filter_nodes(workload.node_selector) + if not target_nodes: + self.logger.error("No nodes match the workload requirements") + return False + + # Create connection pool + self.connection_pool = ThreadPoolExecutor(max_workers=len(target_nodes)) + + # Setup connections and environment in parallel + setup_futures = [] + + for node in target_nodes: + future = self.connection_pool.submit(self._setup_node, node, workload) + setup_futures.append((node, future)) + + # Collect results + success_count = 0 + failed_nodes = [] + + for node, future in setup_futures: + try: + if future.result(timeout=600): # 10 minute timeout per node + success_count += 1 + else: + failed_nodes.append(node.hostname) + except Exception as e: + self.logger.error(f"Setup failed for {node.hostname}: {e}") + failed_nodes.append(node.hostname) + + if failed_nodes: + self.logger.warning(f"Failed to setup nodes: {failed_nodes}") + + if success_count == 0: + self.logger.error("Failed to setup any nodes") + return False + + self.logger.info( + f"Successfully setup infrastructure on {success_count} nodes" + ) + return True + + except Exception as e: + self.logger.error(f"Infrastructure setup failed: {e}") + return False + + def _setup_node(self, node: NodeConfig, workload: WorkloadSpec) -> bool: + """Setup a single node for execution - simplified to focus on manifest distribution.""" + try: + # Create connection + connection = self._create_connection(node) + if not connection: + return False + + # Setup MAD environment (clone/update repository and install) + if not self._setup_mad_environment(connection, node.hostname): + return False + + # Copy build manifest - this is the key file we need + if not self._copy_build_manifest(connection, workload.manifest_file): + self.logger.error(f"Failed to copy manifest to {node.hostname}") + return False + + # Copy any supporting files that might be needed (credential.json, data.json, etc.) 
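+            # (Illustrative, based on the copy targets used in this class) after a
+            # successful setup the remote node is expected to contain
+            # MAD/build_manifest.json plus any of MAD/credential.json,
+            # MAD/data.json and MAD/models.json that exist on the controller.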
+ if not self._copy_supporting_files(connection): + self.logger.warning( + f"Failed to copy some supporting files to {node.hostname}" + ) + # Don't fail for supporting files, just warn + + return True + + except Exception as e: + self.logger.error(f"Node setup failed for {node.hostname}: {e}") + return False + + def _copy_supporting_files(self, connection: SSHConnection) -> bool: + """Copy supporting files that might be needed for execution.""" + supporting_files = ["credential.json", "data.json", "models.json"] + success = True + + for file_name in supporting_files: + if os.path.exists(file_name): + try: + remote_path = f"MAD/{file_name}" + if not connection.copy_file(file_name, remote_path): + self.logger.warning(f"Failed to copy {file_name}") + success = False + except Exception as e: + self.logger.warning(f"Error copying {file_name}: {e}") + success = False + + return success + + def _setup_mad_environment(self, connection: SSHConnection, hostname: str) -> bool: + """Setup MAD repository and madengine-cli on a remote node with retry logic.""" + self.logger.info(f"Setting up MAD environment on {hostname}") + + max_retries = 3 + + # Enhanced setup commands for madengine-cli + setup_commands = [ + # Clone or update MAD repository + ( + "if [ -d MAD ]; then cd MAD && git pull origin main; " + "else git clone https://github.com/ROCm/MAD.git; fi" + ), + # Setup Python environment and install madengine + "cd MAD", + "python3 -m venv venv || true", + "source venv/bin/activate", + # Install dependencies and madengine + "pip install --upgrade pip", + "pip install -r requirements.txt", + "pip install -e .", + # Verify madengine-cli is installed and working + "which madengine-cli", + "madengine-cli --help > /dev/null", + ] + + for attempt in range(max_retries): + try: + for i, command in enumerate(setup_commands): + self.logger.debug( + f"Executing setup command {i+1}/{len(setup_commands)} on {hostname}" + ) + exit_code, stdout, stderr = connection.execute_command( + command, timeout=300 + ) + if exit_code != 0: + self.logger.warning( + f"MAD setup command failed on attempt {attempt + 1} " + f"on {hostname}: {command}\nStderr: {stderr}" + ) + if attempt == max_retries - 1: + self.logger.error( + f"Failed to setup MAD environment on {hostname} " + f"after {max_retries} attempts" + ) + return False + break + else: + # All commands succeeded + self.logger.info( + f"Successfully set up MAD environment on {hostname}" + ) + return True + + except SSHConnectionError as e: + self.logger.warning(f"SSH error during MAD setup on {hostname}: {e}") + if attempt == max_retries - 1: + return False + time.sleep(2**attempt) # Exponential backoff + + except Exception as e: + self.logger.warning( + f"MAD setup attempt {attempt + 1} exception on " f"{hostname}: {e}" + ) + if attempt == max_retries - 1: + self.logger.error( + f"Failed to setup MAD environment on {hostname} " + f"after {max_retries} attempts" + ) + return False + time.sleep(2**attempt) # Exponential backoff + + return False + + def _copy_build_manifest( + self, connection: SSHConnection, manifest_file: str + ) -> bool: + """Copy build manifest to remote node with error handling.""" + try: + if not manifest_file or not os.path.exists(manifest_file): + self.logger.error(f"Build manifest file not found: {manifest_file}") + return False + + remote_path = "MAD/build_manifest.json" + success = connection.copy_file(manifest_file, remote_path) + + if success: + self.logger.info( + f"Successfully copied build manifest to {connection.node.hostname}" + ) + + 
return success + + except Exception as e: + self.logger.error(f"Failed to copy build manifest: {e}") + return False + + def execute_workload(self, workload: WorkloadSpec) -> DistributedResult: + """Execute workload across distributed nodes using build manifest. + + This method distributes the pre-built manifest to remote nodes and + executes 'madengine-cli run' on each node. + + Args: + workload: Workload specification containing manifest file path + + Returns: + Distributed execution result + """ + try: + self.logger.info("Starting SSH distributed execution using build manifest") + + # Validate manifest file exists + if not workload.manifest_file or not os.path.exists(workload.manifest_file): + return DistributedResult( + success=False, + node_results=[], + error_message=f"Build manifest file not found: {workload.manifest_file}", + ) + + # Load manifest to get model tags and configuration + try: + with open(workload.manifest_file, "r") as f: + manifest_data = json.load(f) + + # Extract model tags from manifest + model_tags = [] + if "models" in manifest_data: + model_tags = list(manifest_data["models"].keys()) + elif "model_tags" in manifest_data: + model_tags = manifest_data["model_tags"] + + if not model_tags: + self.logger.warning("No model tags found in manifest") + model_tags = ["dummy"] # fallback + + except Exception as e: + return DistributedResult( + success=False, + node_results=[], + error_message=f"Failed to parse manifest: {e}", + ) + + # Get target nodes + target_nodes = self.filter_nodes(workload.node_selector) + if not target_nodes: + return DistributedResult( + success=False, + node_results=[], + error_message="No nodes match the workload requirements", + ) + + # Setup infrastructure + if not self.setup_infrastructure(workload): + return DistributedResult( + success=False, + node_results=[], + error_message="Failed to setup SSH infrastructure", + ) + + # Execute in parallel across nodes and models + execution_futures = [] + + for node in target_nodes: + # Execute all models on this node (or distribute models across nodes) + future = self.connection_pool.submit( + self._execute_models_on_node_safe, node, model_tags, workload + ) + execution_futures.append((node, future)) + + # Collect results + results = [] + + for node, future in execution_futures: + try: + node_results = future.result( + timeout=workload.timeout + 120 + ) # Extra buffer + results.extend(node_results) + except Exception as e: + self.logger.error(f"Execution failed on {node.hostname}: {e}") + # Create failed result for all models on this node + for model_tag in model_tags: + failed_result = ExecutionResult( + node_id=node.hostname, + model_tag=model_tag, + success=False, + error_message=str(e), + ) + results.append(failed_result) + + # Aggregate results + distributed_result = DistributedResult( + success=any(r.success for r in results), node_results=results + ) + + self.logger.info("SSH distributed execution completed") + return distributed_result + + except Exception as e: + self.logger.error(f"Distributed execution failed: {e}") + return DistributedResult( + success=False, node_results=[], error_message=str(e) + ) + + def _execute_models_on_node_safe( + self, node: NodeConfig, model_tags: List[str], workload: WorkloadSpec + ) -> List[ExecutionResult]: + """Execute all models on a specific node with comprehensive error handling.""" + try: + return self._execute_models_on_node(node, model_tags, workload) + except Exception as e: + self.logger.error(f"Models execution failed on {node.hostname}: {e}") + # 
Return failed results for all models + results = [] + for model_tag in model_tags: + results.append( + ExecutionResult( + node_id=node.hostname, + model_tag=model_tag, + success=False, + error_message=str(e), + ) + ) + return results + + def _execute_models_on_node( + self, node: NodeConfig, model_tags: List[str], workload: WorkloadSpec + ) -> List[ExecutionResult]: + """Execute models on a specific node using 'madengine-cli run'.""" + results = [] + + try: + connection = self.connections.get(node.hostname) + if not connection or not connection.is_connected(): + raise SSHConnectionError( + node.hostname, "connection", "Connection not available" + ) + + # Execute madengine-cli run with the manifest + start_time = time.time() + + # Build command to run madengine-cli with the manifest + command = self._build_execution_command(workload) + + self.logger.info(f"Executing on {node.hostname}: {command}") + + exit_code, stdout, stderr = connection.execute_command( + command, timeout=workload.timeout + ) + + execution_time = time.time() - start_time + + # Parse output to extract per-model results + # For now, create results for all models with the same status + for model_tag in model_tags: + result = ExecutionResult( + node_id=node.hostname, + model_tag=model_tag, + success=(exit_code == 0), + output=stdout, + error_message=stderr if exit_code != 0 else None, + execution_time=execution_time + / len(model_tags), # Distribute time across models + ) + results.append(result) + + if exit_code == 0: + self.logger.info( + f"Successfully executed {model_tag} on {node.hostname}" + ) + else: + self.logger.warning( + f"Execution failed for {model_tag} on {node.hostname}" + ) + + return results + + except SSHConnectionError as e: + # Return failed results for all models + for model_tag in model_tags: + results.append( + ExecutionResult( + node_id=node.hostname, + model_tag=model_tag, + success=False, + error_message=str(e), + execution_time=0, + ) + ) + return results + except Exception as e: + # Return failed results for all models + for model_tag in model_tags: + results.append( + ExecutionResult( + node_id=node.hostname, + model_tag=model_tag, + success=False, + error_message=str(e), + execution_time=0, + ) + ) + return results + + def _build_execution_command(self, workload: WorkloadSpec) -> str: + """Build the madengine-cli run command with the manifest file. 
+
+        Args:
+            workload: Workload specification containing manifest file
+
+        Returns:
+            Command string to execute on remote node
+        """
+        # The basic command structure
+        cmd_parts = [
+            "cd MAD",
+            "source venv/bin/activate",
+            "madengine-cli run --manifest-file build_manifest.json",
+        ]
+
+        # Add timeout if specified (and not default)
+        if workload.timeout and workload.timeout > 0 and workload.timeout != 3600:
+            cmd_parts[-1] += f" --timeout {workload.timeout}"
+
+        # Add registry if specified
+        if workload.registry:
+            cmd_parts[-1] += f" --registry {workload.registry}"
+
+        # Add live output for better monitoring
+        cmd_parts[-1] += " --live-output"
+
+        # Combine all commands
+        return " && ".join(cmd_parts)
+
+    def _execute_model_on_node_safe(
+        self, node: NodeConfig, model_tag: str, workload: WorkloadSpec
+    ) -> ExecutionResult:
+        """Execute a model on a specific node with comprehensive error handling."""
+        try:
+            return self._execute_model_on_node(node, model_tag, workload)
+        except Exception as e:
+            self.logger.error(f"Model execution failed on {node.hostname}: {e}")
+            return ExecutionResult(
+                node_id=node.hostname,
+                model_tag=model_tag,
+                success=False,
+                error_message=str(e),
+            )
+
+    def _execute_model_on_node(
+        self, node: NodeConfig, model_tag: str, workload: WorkloadSpec
+    ) -> ExecutionResult:
+        """Execute a model on a specific node with timeout and error handling."""
+        start_time = time.time()
+
+        try:
+            connection = self.connections.get(node.hostname)
+            if not connection or not connection.is_connected():
+                raise SSHConnectionError(
+                    node.hostname, "connection", "Connection not available"
+                )
+
+            # Build and execute command (the command is manifest-driven, so it
+            # only takes the workload; model selection comes from the manifest)
+            command = self._build_execution_command(workload)
+
+            exit_code, stdout, stderr = connection.execute_command(
+                command, timeout=workload.timeout
+            )
+
+            execution_time = time.time() - start_time
+
+            # Create execution result
+            result = ExecutionResult(
+                node_id=node.hostname,
+                model_tag=model_tag,
+                success=(exit_code == 0),
+                output=stdout,
+                error_message=stderr if exit_code != 0 else None,
+                execution_time=execution_time,
+            )
+
+            if exit_code == 0:
+                self.logger.info(
+                    f"Successfully executed {model_tag} on {node.hostname}"
+                )
+            else:
+                self.logger.warning(
+                    f"Execution failed for {model_tag} on {node.hostname}"
+                )
+
+            return result
+
+        except SSHConnectionError as e:
+            return ExecutionResult(
+                node_id=node.hostname,
+                model_tag=model_tag,
+                success=False,
+                error_message=str(e),
+                execution_time=time.time() - start_time,
+            )
+        except Exception as e:
+            return ExecutionResult(
+                node_id=node.hostname,
+                model_tag=model_tag,
+                success=False,
+                error_message=str(e),
+                execution_time=time.time() - start_time,
+            )
+
+    def cleanup_infrastructure(self, workload: WorkloadSpec) -> bool:
+        """Cleanup infrastructure after execution with comprehensive cleanup.
+ + Args: + workload: Workload specification + + Returns: + True if cleanup successful, False otherwise + """ + try: + self.logger.info("Cleaning up SSH infrastructure") + + # Run custom cleanup handlers + for cleanup_handler in self.cleanup_handlers: + try: + cleanup_handler() + except Exception as e: + self.logger.warning(f"Cleanup handler failed: {e}") + + # Close all connections + for hostname, connection in self.connections.items(): + try: + connection.close() + except Exception as e: + self.logger.warning(f"Error closing connection to {hostname}: {e}") + + self.connections.clear() + + # Shutdown connection pool + if self.connection_pool: + self.connection_pool.shutdown(wait=True) + self.connection_pool = None + + self.logger.info("SSH infrastructure cleanup completed") + return True + + except Exception as e: + self.logger.error(f"Cleanup failed: {e}") + return False + + def add_cleanup_handler(self, handler: callable): + """Add a cleanup handler to be called during cleanup.""" + self.cleanup_handlers.append(handler) + + def __enter__(self): + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit with cleanup.""" + self.cleanup_infrastructure(None) + + # ...existing methods remain the same... diff --git a/src/madengine/runners/template_generator.py b/src/madengine/runners/template_generator.py new file mode 100644 index 00000000..63985bef --- /dev/null +++ b/src/madengine/runners/template_generator.py @@ -0,0 +1,461 @@ +"""Template generator for MADEngine distributed execution. + +This module provides Jinja2-based template generation for Ansible playbooks +and Kubernetes manifests, supporting environment-specific configurations. + +Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +""" + +import os +import json +import yaml +from typing import Dict, Any, Optional, List +from pathlib import Path +from jinja2 import Environment, FileSystemLoader, select_autoescape +from datetime import datetime + + +class TemplateGenerator: + """Template generator for distributed execution configurations.""" + + def __init__( + self, template_dir: Optional[str] = None, values_dir: Optional[str] = None + ): + """Initialize the template generator. + + Args: + template_dir: Path to template directory (defaults to runners/templates) + values_dir: Path to values directory (defaults to runners/values) + """ + self.base_dir = Path(__file__).parent + self.template_dir = ( + Path(template_dir) if template_dir else self.base_dir / "templates" + ) + self.values_dir = Path(values_dir) if values_dir else self.base_dir / "values" + + # Initialize Jinja2 environment + self.env = Environment( + loader=FileSystemLoader(str(self.template_dir)), + autoescape=select_autoescape(["html", "xml"]), + trim_blocks=True, + lstrip_blocks=True, + ) + + # Add custom filters + self.env.filters["to_yaml"] = self._to_yaml_filter + self.env.filters["to_json"] = self._to_json_filter + self.env.filters["basename"] = lambda x: os.path.basename(x) + self.env.filters["timestamp"] = lambda x: datetime.now().strftime( + "%Y%m%d_%H%M%S" + ) + + def _to_yaml_filter(self, value: Any) -> str: + """Convert value to YAML format.""" + return yaml.dump(value, default_flow_style=False) + + def _to_json_filter(self, value: Any) -> str: + """Convert value to JSON format.""" + return json.dumps(value, indent=2) + + def load_values(self, environment: str = "default") -> Dict[str, Any]: + """Load values from environment-specific YAML file. 
+ + Args: + environment: Environment name (default, dev, prod, test) + + Returns: + dict: Loaded values + """ + values_file = self.values_dir / f"{environment}.yaml" + if not values_file.exists(): + raise FileNotFoundError(f"Values file not found: {values_file}") + + with open(values_file, "r") as f: + return yaml.safe_load(f) or {} + + def merge_values( + self, base_values: Dict[str, Any], manifest_data: Dict[str, Any] + ) -> Dict[str, Any]: + """Merge base values with manifest data. + + Args: + base_values: Base values from environment file + manifest_data: Data from build manifest + + Returns: + dict: Merged values + """ + merged = base_values.copy() + + # Extract relevant data from manifest + manifest_values = { + "manifest": manifest_data, + "images": manifest_data.get("built_images", {}), + "models": manifest_data.get("built_models", {}), + "context": manifest_data.get("context", {}), + "registry": manifest_data.get("registry", ""), + "build_timestamp": manifest_data.get("build_timestamp", ""), + "gpu_vendor": manifest_data.get("context", {}).get("gpu_vendor", ""), + "docker_build_args": manifest_data.get("context", {}).get( + "docker_build_arg", {} + ), + "docker_env_vars": manifest_data.get("context", {}).get( + "docker_env_vars", {} + ), + "docker_mounts": manifest_data.get("context", {}).get("docker_mounts", {}), + "docker_gpus": manifest_data.get("context", {}).get("docker_gpus", ""), + } + + # Deep merge the values + merged.update(manifest_values) + + # Add generation metadata + merged["generation"] = { + "timestamp": datetime.now().isoformat(), + "generator": "MADEngine Template Generator", + "version": "1.0.0", + } + + return merged + + def generate_ansible_playbook( + self, + manifest_file: str, + environment: str = "default", + output_file: str = "madengine_distributed.yml", + ) -> str: + """Generate Ansible playbook from template. + + Args: + manifest_file: Path to build manifest JSON file + environment: Environment name for values + output_file: Output playbook file path + + Returns: + str: Generated playbook content + """ + # Load manifest data + with open(manifest_file, "r") as f: + manifest_data = json.load(f) + + # Load and merge values + base_values = self.load_values(environment) + values = self.merge_values(base_values, manifest_data) + + # Load template + template = self.env.get_template("ansible/playbook.yml.j2") + + # Generate content + content = template.render(**values) + + # Write to file + with open(output_file, "w") as f: + f.write(content) + + return content + + def generate_kubernetes_manifests( + self, + manifest_file: str, + environment: str = "default", + output_dir: str = "k8s-manifests", + ) -> List[str]: + """Generate Kubernetes manifests from templates. 
+ + Args: + manifest_file: Path to build manifest JSON file + environment: Environment name for values + output_dir: Output directory for manifests + + Returns: + list: List of generated manifest files + """ + # Load manifest data + with open(manifest_file, "r") as f: + manifest_data = json.load(f) + + # Load and merge values + base_values = self.load_values(environment) + values = self.merge_values(base_values, manifest_data) + + # Create output directory + os.makedirs(output_dir, exist_ok=True) + + generated_files = [] + + # Generate each manifest type + manifest_types = ["namespace", "configmap", "job", "service"] + + for manifest_type in manifest_types: + template_file = f"k8s/{manifest_type}.yaml.j2" + + try: + template = self.env.get_template(template_file) + content = template.render(**values) + + output_file = os.path.join(output_dir, f"{manifest_type}.yaml") + with open(output_file, "w") as f: + f.write(content) + + generated_files.append(output_file) + + except Exception as e: + print(f"Warning: Could not generate {manifest_type}.yaml: {e}") + + return generated_files + + def generate_slurm_job_array( + self, + manifest_file: str, + environment: str = "default", + output_file: str = "madengine_job_array.sh", + ) -> str: + """Generate SLURM job array script from template. + + Args: + manifest_file: Path to build manifest JSON file + environment: Environment name for values + output_file: Output job script file path + + Returns: + str: Generated job script content + """ + # Load manifest data + with open(manifest_file, "r") as f: + manifest_data = json.load(f) + + # Load and merge values + base_values = self.load_values(environment) + values = self.merge_values(base_values, manifest_data) + + # Extract model tags from manifest for job array + model_tags = [] + if "models" in manifest_data: + model_tags = list(manifest_data["models"].keys()) + elif "built_models" in manifest_data: + model_tags = list(manifest_data["built_models"].keys()) + elif "model_tags" in manifest_data: + model_tags = manifest_data["model_tags"] + + values["model_tags"] = model_tags + + # Load template + template = self.env.get_template("slurm/job_array.sh.j2") + + # Generate content + content = template.render(**values) + + # Write to file + with open(output_file, "w") as f: + f.write(content) + + # Make script executable + os.chmod(output_file, 0o755) + + return content + + def generate_slurm_single_job( + self, + manifest_file: str, + model_tag: str, + environment: str = "default", + output_file: str = None, + ) -> str: + """Generate SLURM single job script from template. 
+ + Args: + manifest_file: Path to build manifest JSON file + model_tag: Specific model tag for this job + environment: Environment name for values + output_file: Output job script file path + + Returns: + str: Generated job script content + """ + if output_file is None: + safe_tag = model_tag.replace(":", "-").replace("_", "-") + output_file = f"madengine_{safe_tag}.sh" + + # Load manifest data + with open(manifest_file, "r") as f: + manifest_data = json.load(f) + + # Load and merge values + base_values = self.load_values(environment) + values = self.merge_values(base_values, manifest_data) + + # Add specific model tag + values["model_tag"] = model_tag + + # Load template + template = self.env.get_template("slurm/single_job.sh.j2") + + # Generate content + content = template.render(**values) + + # Write to file + with open(output_file, "w") as f: + f.write(content) + + # Make script executable + os.chmod(output_file, 0o755) + + return content + + def generate_slurm_setup_script( + self, + manifest_file: str, + environment: str = "default", + output_file: str = "setup_environment.sh", + ) -> str: + """Generate SLURM environment setup script from template. + + Args: + manifest_file: Path to build manifest JSON file + environment: Environment name for values + output_file: Output setup script file path + + Returns: + str: Generated setup script content + """ + # Load manifest data + with open(manifest_file, "r") as f: + manifest_data = json.load(f) + + # Load and merge values + base_values = self.load_values(environment) + values = self.merge_values(base_values, manifest_data) + + # Add config files that should be copied + config_files = [] + for file_name in ["credential.json", "data.json", "models.json"]: + if os.path.exists(file_name): + config_files.append(file_name) + values["config_files"] = config_files + + # Load template + template = self.env.get_template("slurm/setup_environment.sh.j2") + + # Generate content + content = template.render(**values) + + # Write to file + with open(output_file, "w") as f: + f.write(content) + + # Make script executable + os.chmod(output_file, 0o755) + + return content + + def generate_slurm_inventory( + self, + manifest_file: str, + environment: str = "default", + output_file: str = "inventory.yml", + ) -> str: + """Generate SLURM inventory file from template. + + Args: + manifest_file: Path to build manifest JSON file + environment: Environment name for values + output_file: Output inventory file path + + Returns: + str: Generated inventory content + """ + # Load manifest data + with open(manifest_file, "r") as f: + manifest_data = json.load(f) + + # Load and merge values + base_values = self.load_values(environment) + values = self.merge_values(base_values, manifest_data) + + # Load template + template = self.env.get_template("slurm/inventory.yml.j2") + + # Generate content + content = template.render(**values) + + # Write to file + with open(output_file, "w") as f: + f.write(content) + + return content + + def list_templates(self) -> Dict[str, List[str]]: + """List available templates. + + Returns: + dict: Dictionary of template types and their files + """ + templates = {} + + for template_type in ["ansible", "k8s", "slurm"]: + template_path = self.template_dir / template_type + if template_path.exists(): + templates[template_type] = [ + f.name + for f in template_path.iterdir() + if f.is_file() and f.suffix == ".j2" + ] + + return templates + + def validate_template(self, template_path: str) -> bool: + """Validate template syntax. 
+ + Args: + template_path: Path to template file + + Returns: + bool: True if template is valid + """ + try: + template = self.env.get_template(template_path) + # Try to render with minimal context + template.render() + return True + except Exception as e: + print(f"Template validation failed: {e}") + return False + + +# Convenience functions for backward compatibility +def create_ansible_playbook( + manifest_file: str = "build_manifest.json", + environment: str = "default", + playbook_file: str = "madengine_distributed.yml", +) -> None: + """Create an Ansible playbook for distributed execution. + + Args: + manifest_file: Build manifest file + environment: Environment name for values + playbook_file: Output Ansible playbook file + """ + generator = TemplateGenerator() + generator.generate_ansible_playbook(manifest_file, environment, playbook_file) + print(f"Ansible playbook created: {playbook_file}") + + +def create_kubernetes_manifests( + manifest_file: str = "build_manifest.json", + environment: str = "default", + output_dir: str = "k8s-manifests", +) -> None: + """Create Kubernetes manifests for distributed execution. + + Args: + manifest_file: Build manifest file + environment: Environment name for values + output_dir: Output directory for manifests + """ + generator = TemplateGenerator() + generated_files = generator.generate_kubernetes_manifests( + manifest_file, environment, output_dir + ) + print(f"Kubernetes manifests created in {output_dir}:") + for file in generated_files: + print(f" - {file}") diff --git a/src/madengine/runners/templates/ansible/playbook.yml.j2 b/src/madengine/runners/templates/ansible/playbook.yml.j2 new file mode 100644 index 00000000..5454637a --- /dev/null +++ b/src/madengine/runners/templates/ansible/playbook.yml.j2 @@ -0,0 +1,189 @@ +--- +# MADEngine Distributed Execution Playbook +# Generated on: {{ generation.timestamp }} +# Environment: {{ environment | default('default') }} +# Manifest: {{ manifest_file | default('build_manifest.json') }} + +- name: MADEngine Distributed Model Execution + hosts: {{ ansible.target_hosts | default('gpu_nodes') }} + become: {{ ansible.become | default(true) }} + vars: + madengine_workspace: "{{ workspace.path | default('/tmp/madengine_distributed') }}" + manifest_file: "{{ manifest_file | default('build_manifest.json') }}" + registry: "{{ registry | default('') }}" + gpu_vendor: "{{ gpu_vendor | default('') }}" + timeout: {{ execution.timeout | default(7200) }} + + tasks: + - name: Create MADEngine workspace + file: + path: "{{ madengine_workspace }}" + state: directory + mode: '0755' + owner: "{{ workspace.owner | default('root') }}" + group: "{{ workspace.group | default('root') }}" + + - name: Copy build manifest to nodes + copy: + src: "{{ manifest_file }}" + dest: "{{ madengine_workspace }}/{{ manifest_file }}" + mode: '0644' + + {% if credentials %} + - name: Copy credentials to nodes + copy: + src: "{{ credentials.file | default('credential.json') }}" + dest: "{{ madengine_workspace }}/credential.json" + mode: '0600' + when: credentials.required | default(false) + {% endif %} + + {% if data_config %} + - name: Copy data configuration to nodes + copy: + src: "{{ data_config.file | default('data.json') }}" + dest: "{{ madengine_workspace }}/data.json" + mode: '0644' + when: data_config.required | default(false) + {% endif %} + + {% if registry %} + - name: Login to Docker registry + docker_login: + registry: "{{ registry }}" + username: "{{ docker_registry.username | default('') }}" + password: "{{ 
docker_registry.password | default('') }}" + when: docker_registry.login_required | default(false) + {% endif %} + + - name: Pull Docker images from registry + shell: | + cd {{ madengine_workspace }} + python3 -c " + import json + import subprocess + import sys + + try: + with open('{{ manifest_file }}', 'r') as f: + manifest = json.load(f) + + pulled_images = [] + for image_name, build_info in manifest.get('built_images', {}).items(): + if 'registry_image' in build_info: + registry_image = build_info['registry_image'] + docker_image = build_info['docker_image'] + + print(f'Pulling {registry_image}') + result = subprocess.run(['docker', 'pull', registry_image], + capture_output=True, text=True) + if result.returncode == 0: + print(f'Successfully pulled {registry_image}') + + # Tag the image + subprocess.run(['docker', 'tag', registry_image, docker_image], + check=True) + print(f'Tagged as {docker_image}') + pulled_images.append(image_name) + else: + print(f'Failed to pull {registry_image}: {result.stderr}') + + print(f'Successfully pulled {len(pulled_images)} images') + + except Exception as e: + print(f'Error pulling images: {e}') + sys.exit(1) + " + register: pull_result + when: registry != "" + + - name: Display image pull results + debug: + var: pull_result.stdout_lines + when: pull_result is defined + + - name: Install MADEngine dependencies + pip: + name: "{{ item }}" + state: present + loop: {{ python_dependencies | default(['jinja2', 'pyyaml']) | to_yaml }} + when: install_dependencies | default(false) + + - name: Create execution script + template: + src: execution_script.py.j2 + dest: "{{ madengine_workspace }}/execute_models.py" + mode: '0755' + + - name: Run MADEngine model execution + shell: | + cd {{ madengine_workspace }} + python3 execute_models.py + register: execution_results + async: {{ execution.async_timeout | default(14400) }} + poll: {{ execution.poll_interval | default(30) }} + environment: + PYTHONPATH: "{{ python_path | default('/usr/local/lib/python3.8/site-packages') }}" + {% for key, value in docker_env_vars.items() %} + {{ key }}: "{{ value }}" + {% endfor %} + + - name: Create execution results summary + copy: + content: | + # MADEngine Execution Results + ## Execution Summary + + **Timestamp:** {{ generation.timestamp }} + **Node:** {{ '{{ inventory_hostname }}' }} + **Environment:** {{ environment | default('default') }} + **Registry:** {{ registry | default('local') }} + **GPU Vendor:** {{ gpu_vendor | default('unknown') }} + + ## Models Executed + {% for model_name, model_info in models.items() %} + - **{{ model_name }}**: {{ model_info.get('status', 'unknown') }} + {% endfor %} + + ## Execution Output + ``` + {{ '{{ execution_results.stdout | default("No output captured") }}' }} + ``` + + ## Execution Errors + ``` + {{ '{{ execution_results.stderr | default("No errors") }}' }} + ``` + dest: "{{ '{{ madengine_workspace }}' }}/execution_summary.md" + mode: '0644' + + - name: Display execution results + debug: + var: execution_results.stdout_lines + when: execution_results is defined + + - name: Handle execution failures + fail: + msg: "MADEngine execution failed: {{ '{{ execution_results.stderr }}' }}" + when: execution_results is defined and execution_results.rc != 0 + + {% if post_execution.cleanup | default(false) %} + - name: Cleanup workspace + file: + path: "{{ madengine_workspace }}" + state: absent + when: post_execution.cleanup | default(false) + {% endif %} + + {% if post_execution.collect_logs | default(true) %} + - name: Collect execution logs 
+ fetch: + src: "{{ madengine_workspace }}/{{ item }}" + dest: "{{ logs.local_path | default('./logs') }}/{{ inventory_hostname }}_{{ item }}" + flat: yes + loop: + - "execution_summary.md" + - "perf.csv" + - "madengine.log" + ignore_errors: yes + {% endif %} diff --git a/src/madengine/runners/templates/k8s/configmap.yaml.j2 b/src/madengine/runners/templates/k8s/configmap.yaml.j2 new file mode 100644 index 00000000..9cd01f36 --- /dev/null +++ b/src/madengine/runners/templates/k8s/configmap.yaml.j2 @@ -0,0 +1,143 @@ +apiVersion: v1 +kind: ConfigMap +metadata: + name: {{ k8s.configmap.name | default('madengine-config') }} + namespace: {{ k8s.namespace | default('madengine') }} + labels: + app.kubernetes.io/name: madengine + app.kubernetes.io/component: config + app.kubernetes.io/version: {{ generation.version | default('1.0.0') }} + annotations: + generated-on: "{{ generation.timestamp }}" + environment: "{{ environment | default('default') }}" +data: + # Build manifest data + manifest.json: | + {{ manifest | to_json | indent(4) }} + + # Execution configuration + execution-config.json: | + { + "timeout": {{ execution.timeout | default(7200) }}, + "keep_alive": {{ execution.keep_alive | default(false) | lower }}, + "live_output": {{ execution.live_output | default(true) | lower }}, + "output_file": "{{ execution.output_file | default('perf.csv') }}", + "results_file": "{{ execution.results_file | default('execution_results.json') }}", + "generate_sys_env_details": {{ execution.generate_sys_env_details | default(true) | lower }}, + "registry": "{{ registry | default('') }}", + "gpu_vendor": "{{ gpu_vendor | default('') }}" + } + + {% if credentials %} + # Credentials configuration + credential.json: | + {{ credentials | to_json | indent(4) }} + {% endif %} + + {% if data_config %} + # Data configuration + data.json: | + {{ data_config | to_json | indent(4) }} + {% endif %} + + # Execution script + execute_models.py: | + #!/usr/bin/env python3 + """ + MADEngine Kubernetes Execution Script + Generated on: {{ generation.timestamp }} + Environment: {{ environment | default('default') }} + """ + + import os + import sys + import json + import argparse + from datetime import datetime + + try: + from madengine.tools.distributed_orchestrator import DistributedOrchestrator + except ImportError as e: + print(f"Error importing MADEngine: {e}") + sys.exit(1) + + def main(): + """Main execution function.""" + print("=" * 80) + print("MADEngine Kubernetes Model Execution") + print("=" * 80) + print(f"Execution started: {datetime.now().isoformat()}") + print(f"Environment: {{ environment | default('default') }}") + print(f"Registry: {{ registry | default('local') }}") + print(f"GPU Vendor: {{ gpu_vendor | default('unknown') }}") + print("=" * 80) + + # Load configuration + with open('/config/execution-config.json', 'r') as f: + config = json.load(f) + + # Create args + args = argparse.Namespace() + args.live_output = config.get('live_output', True) + args.additional_context = None + args.additional_context_file = None + args.data_config_file_name = '/config/data.json' if os.path.exists('/config/data.json') else 'data.json' + args.force_mirror_local = False + args.output = config.get('output_file', 'perf.csv') + args.generate_sys_env_details = config.get('generate_sys_env_details', True) + args._separate_phases = True + + try: + # Initialize orchestrator + orchestrator = DistributedOrchestrator(args) + + # Execute run phase + execution_summary = orchestrator.run_phase( + manifest_file='/config/manifest.json', 
+ registry=config.get('registry', ''), + timeout=config.get('timeout', 7200), + keep_alive=config.get('keep_alive', False) + ) + + # Save results + results_file = config.get('results_file', 'execution_results.json') + with open(results_file, 'w') as f: + json.dump(execution_summary, f, indent=2) + + print(f"Results saved to: {results_file}") + + # Return appropriate exit code + if execution_summary.get('failed_runs'): + return 1 + return 0 + + except Exception as e: + print(f"Error during execution: {e}") + import traceback + traceback.print_exc() + return 1 + + if __name__ == "__main__": + sys.exit(main()) + + # Additional configuration files + madengine.conf: | + # MADEngine Configuration + [general] + environment = {{ environment | default('default') }} + registry = {{ registry | default('') }} + gpu_vendor = {{ gpu_vendor | default('') }} + + [execution] + timeout = {{ execution.timeout | default(7200) }} + keep_alive = {{ execution.keep_alive | default(false) | lower }} + live_output = {{ execution.live_output | default(true) | lower }} + + [logging] + level = {{ logging.level | default('INFO') }} + format = {{ logging.format | default('%(asctime)s - %(name)s - %(levelname)s - %(message)s') }} + + [resources] + memory_limit = {{ resources.memory_limit | default('4Gi') }} + cpu_limit = {{ resources.cpu_limit | default('2') }} + gpu_limit = {{ resources.gpu_limit | default('1') }} diff --git a/src/madengine/runners/templates/k8s/job.yaml.j2 b/src/madengine/runners/templates/k8s/job.yaml.j2 new file mode 100644 index 00000000..520ed44a --- /dev/null +++ b/src/madengine/runners/templates/k8s/job.yaml.j2 @@ -0,0 +1,238 @@ +apiVersion: batch/v1 +kind: Job +metadata: + name: {{ k8s.job.name | default('madengine-execution') }} + namespace: {{ k8s.namespace | default('madengine') }} + labels: + app.kubernetes.io/name: madengine + app.kubernetes.io/component: execution + app.kubernetes.io/version: {{ generation.version | default('1.0.0') }} + environment: {{ environment | default('default') }} + annotations: + generated-on: "{{ generation.timestamp }}" + registry: "{{ registry | default('local') }}" + gpu-vendor: "{{ gpu_vendor | default('unknown') }}" +spec: + parallelism: {{ k8s.job.parallelism | default(1) }} + completions: {{ k8s.job.completions | default(1) }} + backoffLimit: {{ k8s.job.backoff_limit | default(3) }} + activeDeadlineSeconds: {{ k8s.job.active_deadline_seconds | default(14400) }} + template: + metadata: + labels: + app.kubernetes.io/name: madengine + app.kubernetes.io/component: execution + job-name: {{ k8s.job.name | default('madengine-execution') }} + spec: + restartPolicy: {{ k8s.job.restart_policy | default('Never') }} + + {% if k8s.service_account %} + serviceAccountName: {{ k8s.service_account }} + {% endif %} + + {% if k8s.image_pull_secrets %} + imagePullSecrets: + {% for secret in k8s.image_pull_secrets %} + - name: {{ secret }} + {% endfor %} + {% endif %} + + containers: + - name: madengine-runner + image: {{ k8s.container.image | default('madengine/distributed-runner:latest') }} + imagePullPolicy: {{ k8s.container.image_pull_policy | default('IfNotPresent') }} + + command: ["/bin/bash"] + args: + - "-c" + - | + set -e + echo "Starting MADEngine execution..." 
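+            # (Illustrative note, not prescriptive) the steps below assume the
+            # ConfigMap carrying execute_models.py is mounted at /config and that
+            # /results, when present, is a writable shared volume; both mounts
+            # are declared further down in this Job spec.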
+ + # Set up environment + export PYTHONPATH=/usr/local/lib/python3.8/site-packages:$PYTHONPATH + + # Make script executable + chmod +x /config/execute_models.py + + # Execute the models + python3 /config/execute_models.py + + # Copy results to shared volume if available + if [ -d "/results" ]; then + cp -v *.csv *.json *.log /results/ 2>/dev/null || echo "No results to copy" + fi + + echo "MADEngine execution completed" + + volumeMounts: + - name: config-volume + mountPath: /config + readOnly: true + - name: docker-socket + mountPath: /var/run/docker.sock + {% if k8s.volumes.shared_storage %} + - name: shared-storage + mountPath: /results + {% endif %} + {% if k8s.volumes.data_storage %} + - name: data-storage + mountPath: /data + {% endif %} + + resources: + limits: + {% if gpu_vendor == 'nvidia' %} + nvidia.com/gpu: {{ resources.gpu_limit | default('1') }} + {% elif gpu_vendor == 'amd' %} + amd.com/gpu: {{ resources.gpu_limit | default('1') }} + {% endif %} + memory: {{ resources.memory_limit | default('4Gi') }} + cpu: {{ resources.cpu_limit | default('2') }} + requests: + memory: {{ resources.memory_request | default('2Gi') }} + cpu: {{ resources.cpu_request | default('1') }} + + env: + - name: MADENGINE_ENVIRONMENT + value: "{{ environment | default('default') }}" + - name: MADENGINE_REGISTRY + value: "{{ registry | default('') }}" + - name: MADENGINE_GPU_VENDOR + value: "{{ gpu_vendor | default('') }}" + - name: PYTHONPATH + value: "/usr/local/lib/python3.8/site-packages" + + {% if gpu_vendor == 'nvidia' %} + - name: NVIDIA_VISIBLE_DEVICES + value: "{{ nvidia.visible_devices | default('all') }}" + - name: NVIDIA_DRIVER_CAPABILITIES + value: "{{ nvidia.driver_capabilities | default('compute,utility') }}" + {% elif gpu_vendor == 'amd' %} + - name: ROC_ENABLE_PRE_VEGA + value: "{{ amd.enable_pre_vega | default('1') }}" + - name: HIP_VISIBLE_DEVICES + value: "{{ amd.visible_devices | default('all') }}" + {% endif %} + + {% for key, value in docker_env_vars.items() %} + - name: {{ key }} + value: "{{ value }}" + {% endfor %} + + {% if k8s.container.security_context %} + securityContext: + runAsUser: {{ k8s.container.security_context.run_as_user | default(0) }} + runAsGroup: {{ k8s.container.security_context.run_as_group | default(0) }} + privileged: {{ k8s.container.security_context.privileged | default(false) | lower }} + {% if k8s.container.security_context.capabilities %} + capabilities: + add: + {% for cap in k8s.container.security_context.capabilities.add %} + - {{ cap }} + {% endfor %} + {% endif %} + {% endif %} + + {% if k8s.container.health_checks %} + livenessProbe: + exec: + command: + - /bin/bash + - -c + - "ps aux | grep -v grep | grep python3 > /dev/null" + initialDelaySeconds: {{ k8s.container.health_checks.liveness.initial_delay | default(30) }} + periodSeconds: {{ k8s.container.health_checks.liveness.period | default(60) }} + timeoutSeconds: {{ k8s.container.health_checks.liveness.timeout | default(10) }} + failureThreshold: {{ k8s.container.health_checks.liveness.failure_threshold | default(3) }} + + readinessProbe: + exec: + command: + - /bin/bash + - -c + - "test -f /config/manifest.json" + initialDelaySeconds: {{ k8s.container.health_checks.readiness.initial_delay | default(5) }} + periodSeconds: {{ k8s.container.health_checks.readiness.period | default(10) }} + timeoutSeconds: {{ k8s.container.health_checks.readiness.timeout | default(5) }} + {% endif %} + + volumes: + - name: config-volume + configMap: + name: {{ k8s.configmap.name | default('madengine-config') }} + 
defaultMode: 0755 + - name: docker-socket + hostPath: + path: /var/run/docker.sock + type: Socket + + {% if k8s.volumes.shared_storage %} + - name: shared-storage + {% if k8s.volumes.shared_storage.type == 'pvc' %} + persistentVolumeClaim: + claimName: {{ k8s.volumes.shared_storage.claim_name }} + {% elif k8s.volumes.shared_storage.type == 'nfs' %} + nfs: + server: {{ k8s.volumes.shared_storage.server }} + path: {{ k8s.volumes.shared_storage.path }} + {% elif k8s.volumes.shared_storage.type == 'hostPath' %} + hostPath: + path: {{ k8s.volumes.shared_storage.path }} + type: {{ k8s.volumes.shared_storage.hostPath_type | default('DirectoryOrCreate') }} + {% endif %} + {% endif %} + + {% if k8s.volumes.data_storage %} + - name: data-storage + {% if k8s.volumes.data_storage.type == 'pvc' %} + persistentVolumeClaim: + claimName: {{ k8s.volumes.data_storage.claim_name }} + {% elif k8s.volumes.data_storage.type == 'nfs' %} + nfs: + server: {{ k8s.volumes.data_storage.server }} + path: {{ k8s.volumes.data_storage.path }} + {% elif k8s.volumes.data_storage.type == 'hostPath' %} + hostPath: + path: {{ k8s.volumes.data_storage.path }} + type: {{ k8s.volumes.data_storage.hostPath_type | default('DirectoryOrCreate') }} + {% endif %} + {% endif %} + + {% if k8s.node_selector %} + nodeSelector: + {% for key, value in k8s.node_selector.items() %} + {{ key }}: {{ value }} + {% endfor %} + {% endif %} + + {% if k8s.tolerations %} + tolerations: + {% for toleration in k8s.tolerations %} + - key: {{ toleration.key }} + operator: {{ toleration.operator | default('Equal') }} + {% if toleration.value %} + value: {{ toleration.value }} + {% endif %} + effect: {{ toleration.effect }} + {% if toleration.toleration_seconds %} + tolerationSeconds: {{ toleration.toleration_seconds }} + {% endif %} + {% endfor %} + {% endif %} + + {% if k8s.affinity %} + affinity: + {% if k8s.affinity.node_affinity %} + nodeAffinity: + {{ k8s.affinity.node_affinity | to_yaml | indent(10) }} + {% endif %} + {% if k8s.affinity.pod_affinity %} + podAffinity: + {{ k8s.affinity.pod_affinity | to_yaml | indent(10) }} + {% endif %} + {% if k8s.affinity.pod_anti_affinity %} + podAntiAffinity: + {{ k8s.affinity.pod_anti_affinity | to_yaml | indent(10) }} + {% endif %} + {% endif %} diff --git a/src/madengine/runners/templates/k8s/namespace.yaml.j2 b/src/madengine/runners/templates/k8s/namespace.yaml.j2 new file mode 100644 index 00000000..e4fabf01 --- /dev/null +++ b/src/madengine/runners/templates/k8s/namespace.yaml.j2 @@ -0,0 +1,13 @@ +apiVersion: v1 +kind: Namespace +metadata: + name: {{ k8s.namespace | default('madengine') }} + labels: + name: {{ k8s.namespace | default('madengine') }} + app.kubernetes.io/name: madengine + app.kubernetes.io/version: {{ generation.version | default('1.0.0') }} + app.kubernetes.io/managed-by: {{ generation.generator | default('MADEngine Template Generator') }} + annotations: + generated-on: "{{ generation.timestamp }}" + environment: "{{ environment | default('default') }}" + registry: "{{ registry | default('local') }}" diff --git a/src/madengine/runners/templates/k8s/service.yaml.j2 b/src/madengine/runners/templates/k8s/service.yaml.j2 new file mode 100644 index 00000000..a714dfd3 --- /dev/null +++ b/src/madengine/runners/templates/k8s/service.yaml.j2 @@ -0,0 +1,78 @@ +apiVersion: v1 +kind: Service +metadata: + name: {{ k8s.service.name | default('madengine-service') }} + namespace: {{ k8s.namespace | default('madengine') }} + labels: + app.kubernetes.io/name: madengine + app.kubernetes.io/component: service 
+ app.kubernetes.io/version: {{ generation.version | default('1.0.0') }} + annotations: + generated-on: "{{ generation.timestamp }}" + environment: "{{ environment | default('default') }}" +spec: + type: {{ k8s.service.type | default('ClusterIP') }} + + {% if k8s.service.type == 'LoadBalancer' and k8s.service.load_balancer_ip %} + loadBalancerIP: {{ k8s.service.load_balancer_ip }} + {% endif %} + + {% if k8s.service.type == 'LoadBalancer' and k8s.service.load_balancer_source_ranges %} + loadBalancerSourceRanges: + {% for range in k8s.service.load_balancer_source_ranges %} + - {{ range }} + {% endfor %} + {% endif %} + + {% if k8s.service.external_ips %} + externalIPs: + {% for ip in k8s.service.external_ips %} + - {{ ip }} + {% endfor %} + {% endif %} + + {% if k8s.service.cluster_ip %} + clusterIP: {{ k8s.service.cluster_ip }} + {% endif %} + + {% if k8s.service.external_name %} + externalName: {{ k8s.service.external_name }} + {% endif %} + + ports: + {% if k8s.service.ports %} + {% for port in k8s.service.ports %} + - name: {{ port.name | default('http') }} + port: {{ port.port }} + targetPort: {{ port.target_port | default(port.port) }} + {% if port.protocol %} + protocol: {{ port.protocol }} + {% endif %} + {% if port.node_port and k8s.service.type == 'NodePort' %} + nodePort: {{ port.node_port }} + {% endif %} + {% endfor %} + {% else %} + # Default ports for MADEngine monitoring/logging + - name: http + port: 8080 + targetPort: 8080 + protocol: TCP + - name: metrics + port: 9090 + targetPort: 9090 + protocol: TCP + {% endif %} + + selector: + app.kubernetes.io/name: madengine + app.kubernetes.io/component: execution + + {% if k8s.service.session_affinity %} + sessionAffinity: {{ k8s.service.session_affinity }} + {% if k8s.service.session_affinity == 'ClientIP' and k8s.service.session_affinity_config %} + sessionAffinityConfig: + clientIP: + timeoutSeconds: {{ k8s.service.session_affinity_config.timeout_seconds | default(10800) }} + {% endif %} + {% endif %} diff --git a/src/madengine/runners/templates/slurm/inventory.yml.j2 b/src/madengine/runners/templates/slurm/inventory.yml.j2 new file mode 100644 index 00000000..a31ffd22 --- /dev/null +++ b/src/madengine/runners/templates/slurm/inventory.yml.j2 @@ -0,0 +1,78 @@ +# SLURM Cluster Inventory for MADEngine +# Generated on {{ generation.timestamp }} + +slurm_cluster: + # SLURM login/head node configuration + login_node: + hostname: "{{ slurm.login_node.hostname | default('slurm-login') }}" + address: "{{ slurm.login_node.address | default('localhost') }}" + port: {{ slurm.login_node.port | default(22) }} + username: "{{ slurm.login_node.username | default('madengine') }}" + ssh_key_path: "{{ slurm.login_node.ssh_key_path | default('~/.ssh/id_rsa') }}" + + # SLURM cluster configuration + cluster_name: "{{ slurm.cluster_name | default('madengine-cluster') }}" + + # Available partitions + partitions: +{% for partition in slurm.partitions %} + - name: "{{ partition.name }}" + max_time: "{{ partition.max_time | default('24:00:00') }}" + max_nodes: {{ partition.max_nodes | default(32) }} + default_gpu_count: {{ partition.default_gpu_count | default(1) }} + gpu_types: {{ partition.gpu_types | default(['generic']) | to_yaml | indent(8) }} + memory_per_node: "{{ partition.memory_per_node | default('256G') }}" + {% if partition.qos %} + qos: "{{ partition.qos }}" + {% endif %} + {% if partition.account %} + account: "{{ partition.account }}" + {% endif %} +{% endfor %} + + # Workspace configuration + workspace: + shared_filesystem: "{{ 
workspace.shared_filesystem | default('/shared/madengine') }}" + results_dir: "{{ workspace.results_dir | default('/shared/results') }}" + logs_dir: "{{ workspace.logs_dir | default('logs') }}" + venv_path: "{{ workspace.venv_path | default('venv') }}" + + # Module system + modules: +{% for module in slurm.modules %} + - "{{ module }}" +{% endfor %} + + # Environment variables + environment: +{% for key, value in slurm.environment.items() %} + {{ key }}: "{{ value }}" +{% endfor %} + + # GPU vendor mapping + gpu_mapping: +{% for vendor, config in slurm.gpu_mapping.items() %} + {{ vendor }}: + gres_name: "{{ config.gres_name | default('gpu') }}" + constraint: "{{ config.constraint | default('') }}" + memory_per_gpu: "{{ config.memory_per_gpu | default('16G') }}" +{% endfor %} + + # Job execution settings + execution: + max_concurrent_jobs: {{ slurm.execution.max_concurrent_jobs | default(8) }} + job_array_strategy: {{ slurm.execution.job_array_strategy | default(true) }} + default_timeout: {{ slurm.execution.default_timeout | default(3600) }} + retry_failed_jobs: {{ slurm.execution.retry_failed_jobs | default(true) }} + max_retries: {{ slurm.execution.max_retries | default(3) }} + +# Model-specific overrides (if needed) +{% if model_overrides %} +model_overrides: +{% for model_tag, overrides in model_overrides.items() %} + "{{ model_tag }}": +{% for key, value in overrides.items() %} + {{ key }}: {{ value | to_yaml }} +{% endfor %} +{% endfor %} +{% endif %} \ No newline at end of file diff --git a/src/madengine/runners/templates/slurm/job_array.sh.j2 b/src/madengine/runners/templates/slurm/job_array.sh.j2 new file mode 100644 index 00000000..e79ff420 --- /dev/null +++ b/src/madengine/runners/templates/slurm/job_array.sh.j2 @@ -0,0 +1,101 @@ +#!/bin/bash +#SBATCH --job-name=madengine-array-{{ job_name | default("madengine") }} +#SBATCH --partition={{ partition | default("gpu") }} +#SBATCH --nodes={{ nodes_per_task | default(1) }} +#SBATCH --ntasks-per-node={{ tasks_per_node | default(1) }} +#SBATCH --gres=gpu:{{ gpu_count | default(1) }} +#SBATCH --time={{ time_limit | default("24:00:00") }} +#SBATCH --mem={{ memory | default("32G") }} +{% if account %} +#SBATCH --account={{ account }} +{% endif %} +{% if qos %} +#SBATCH --qos={{ qos }} +{% endif %} +{% if constraint %} +#SBATCH --constraint={{ constraint }} +{% endif %} +{% if exclusive %} +#SBATCH --exclusive +{% endif %} +#SBATCH --array=0-{{ (model_tags | length) - 1 }}%{{ max_concurrent_jobs | default(8) }} +#SBATCH --output={{ output_dir | default("logs") }}/madengine_array_%A_%a.out +#SBATCH --error={{ output_dir | default("logs") }}/madengine_array_%A_%a.err + +# Job configuration +echo "=== SLURM Job Array Information ===" +echo "Job ID: $SLURM_JOB_ID" +echo "Array Task ID: $SLURM_ARRAY_TASK_ID" +echo "Node: $SLURMD_NODENAME" +echo "Partition: {{ partition | default('gpu') }}" +echo "GPUs: {{ gpu_count | default(1) }}" +echo "==================================" + +# Load required modules +{% for module in modules %} +module load {{ module }} +{% endfor %} + +# Set environment variables +export CUDA_VISIBLE_DEVICES=$SLURM_LOCALID +export OMP_NUM_THREADS={{ omp_num_threads | default(1) }} +{% for key, value in environment.items() %} +export {{ key }}="{{ value }}" +{% endfor %} + +# Change to MAD workspace directory +cd {{ mad_workspace_path | default("/shared/madengine") }} + +# Activate Python virtual environment +source {{ venv_path | default("venv") }}/bin/activate + +# Create array of model tags +MODEL_TAGS=( +{% for tag in 
model_tags %} + "{{ tag }}" +{% endfor %} +) + +# Get the model tag for this array task +MODEL_TAG=${MODEL_TAGS[$SLURM_ARRAY_TASK_ID]} + +echo "Processing model tag: $MODEL_TAG" + +# Create output directory for this specific model +MODEL_OUTPUT_DIR="{{ results_dir | default('results') }}/${MODEL_TAG}_${SLURM_JOB_ID}_${SLURM_ARRAY_TASK_ID}" +mkdir -p "$MODEL_OUTPUT_DIR" + +# Execute madengine-cli with the specific model tag +echo "Starting madengine execution for $MODEL_TAG at $(date)" + +madengine-cli run \ + --manifest-file {{ manifest_file | default("build_manifest.json") }} \ + --tags "$MODEL_TAG" \ + --timeout {{ timeout | default(3600) }} \ + {% if registry %}--registry {{ registry }}{% endif %} \ + --live-output \ + --output-dir "$MODEL_OUTPUT_DIR" \ + {% if additional_args %}{{ additional_args }}{% endif %} + +# Capture exit code +EXIT_CODE=$? + +echo "Finished madengine execution for $MODEL_TAG at $(date) with exit code: $EXIT_CODE" + +# Create result summary +cat > "$MODEL_OUTPUT_DIR/job_summary.json" << EOF +{ + "job_id": "$SLURM_JOB_ID", + "array_task_id": "$SLURM_ARRAY_TASK_ID", + "model_tag": "$MODEL_TAG", + "node": "$SLURMD_NODENAME", + "start_time": "$(date -Iseconds)", + "exit_code": $EXIT_CODE, + "gpu_count": {{ gpu_count | default(1) }}, + "partition": "{{ partition | default('gpu') }}", + "output_dir": "$MODEL_OUTPUT_DIR" +} +EOF + +# Exit with the madengine exit code +exit $EXIT_CODE \ No newline at end of file diff --git a/src/madengine/runners/templates/slurm/setup_environment.sh.j2 b/src/madengine/runners/templates/slurm/setup_environment.sh.j2 new file mode 100644 index 00000000..34f59d44 --- /dev/null +++ b/src/madengine/runners/templates/slurm/setup_environment.sh.j2 @@ -0,0 +1,96 @@ +#!/bin/bash +#SBATCH --job-name=madengine-setup +#SBATCH --partition={{ setup_partition | default("cpu") }} +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --time={{ setup_time_limit | default("01:00:00") }} +#SBATCH --mem={{ setup_memory | default("8G") }} +{% if account %} +#SBATCH --account={{ account }} +{% endif %} +#SBATCH --output={{ output_dir | default("logs") }}/madengine_setup_%j.out +#SBATCH --error={{ output_dir | default("logs") }}/madengine_setup_%j.err + +# Environment setup job for MADEngine SLURM execution +echo "=== MADEngine Environment Setup ===" +echo "Job ID: $SLURM_JOB_ID" +echo "Node: $SLURMD_NODENAME" +echo "Workspace: {{ mad_workspace_path | default('/shared/madengine') }}" +echo "==================================" + +# Load required modules +{% for module in modules %} +module load {{ module }} +{% endfor %} + +# Create workspace directory on shared filesystem +WORKSPACE="{{ mad_workspace_path | default('/shared/madengine') }}" +mkdir -p "$WORKSPACE" +mkdir -p "{{ results_dir | default('results') }}" +mkdir -p "{{ output_dir | default('logs') }}" + +cd "$WORKSPACE" + +# Clone or update MAD repository +if [ -d "MAD" ]; then + echo "Updating existing MAD repository..." + cd MAD + git pull origin main + cd .. +else + echo "Cloning MAD repository..." + git clone {{ mad_repo_url | default("https://github.com/ROCm/MAD.git") }} MAD +fi + +cd MAD + +# Create Python virtual environment +echo "Setting up Python virtual environment..." +python3 -m venv {{ venv_path | default("venv") }} +source {{ venv_path | default("venv") }}/bin/activate + +# Install dependencies +echo "Installing Python dependencies..." +pip install --upgrade pip +pip install -r requirements.txt + +# Install madengine with SLURM dependencies +pip install -e . 
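+# Note (assumption): the editable install above presumes the cloned MAD repository
+# ships an installable package (setup.py or pyproject.toml) at its root. If SLURM
+# support were ever packaged as an optional extra, it could be installed instead;
+# the extra name below is hypothetical, not a documented option:
+#   pip install -e ".[slurm]"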
+ +# Copy manifest and configuration files to workspace +{% if manifest_file %} +cp {{ manifest_file }} build_manifest.json +{% endif %} + +{% for config_file in config_files %} +if [ -f "{{ config_file }}" ]; then + cp "{{ config_file }}" . + echo "Copied {{ config_file }}" +fi +{% endfor %} + +# Verify madengine installation +echo "Verifying madengine-cli installation..." +madengine-cli --version +madengine-cli --help > /dev/null + +if [ $? -eq 0 ]; then + echo "✅ MADEngine environment setup completed successfully" + + # Create setup completion marker + cat > setup_complete.json << EOF +{ + "setup_job_id": "$SLURM_JOB_ID", + "setup_node": "$SLURMD_NODENAME", + "setup_time": "$(date -Iseconds)", + "workspace_path": "$WORKSPACE", + "venv_path": "{{ venv_path | default('venv') }}", + "status": "completed" +} +EOF + + exit 0 +else + echo "❌ MADEngine environment setup failed" + exit 1 +fi \ No newline at end of file diff --git a/src/madengine/runners/templates/slurm/single_job.sh.j2 b/src/madengine/runners/templates/slurm/single_job.sh.j2 new file mode 100644 index 00000000..9b166565 --- /dev/null +++ b/src/madengine/runners/templates/slurm/single_job.sh.j2 @@ -0,0 +1,88 @@ +#!/bin/bash +#SBATCH --job-name=madengine-{{ model_tag | replace(":", "-") | replace("_", "-") }} +#SBATCH --partition={{ partition | default("gpu") }} +#SBATCH --nodes={{ nodes | default(1) }} +#SBATCH --ntasks-per-node={{ tasks_per_node | default(1) }} +#SBATCH --gres=gpu:{{ gpu_count | default(1) }} +#SBATCH --time={{ time_limit | default("24:00:00") }} +#SBATCH --mem={{ memory | default("32G") }} +{% if account %} +#SBATCH --account={{ account }} +{% endif %} +{% if qos %} +#SBATCH --qos={{ qos }} +{% endif %} +{% if constraint %} +#SBATCH --constraint={{ constraint }} +{% endif %} +{% if exclusive %} +#SBATCH --exclusive +{% endif %} +#SBATCH --output={{ output_dir | default("logs") }}/madengine_{{ model_tag | replace(":", "-") | replace("_", "-") }}_%j.out +#SBATCH --error={{ output_dir | default("logs") }}/madengine_{{ model_tag | replace(":", "-") | replace("_", "-") }}_%j.err + +# Job configuration +echo "=== SLURM Job Information ===" +echo "Job ID: $SLURM_JOB_ID" +echo "Job Name: madengine-{{ model_tag | replace(":", "-") | replace("_", "-") }}" +echo "Node: $SLURMD_NODENAME" +echo "Partition: {{ partition | default('gpu') }}" +echo "GPUs: {{ gpu_count | default(1) }}" +echo "Model Tag: {{ model_tag }}" +echo "=============================" + +# Load required modules +{% for module in modules %} +module load {{ module }} +{% endfor %} + +# Set environment variables +export CUDA_VISIBLE_DEVICES=$SLURM_LOCALID +export OMP_NUM_THREADS={{ omp_num_threads | default(1) }} +{% for key, value in environment.items() %} +export {{ key }}="{{ value }}" +{% endfor %} + +# Change to MAD workspace directory +cd {{ mad_workspace_path | default("/shared/madengine") }} + +# Activate Python virtual environment +source {{ venv_path | default("venv") }}/bin/activate + +# Create output directory for this specific model +MODEL_OUTPUT_DIR="{{ results_dir | default('results') }}/{{ model_tag | replace(":", "-") | replace("_", "-") }}_${SLURM_JOB_ID}" +mkdir -p "$MODEL_OUTPUT_DIR" + +# Execute madengine-cli with the specific model tag +echo "Starting madengine execution for {{ model_tag }} at $(date)" + +madengine-cli run \ + --manifest-file {{ manifest_file | default("build_manifest.json") }} \ + --tags "{{ model_tag }}" \ + --timeout {{ timeout | default(3600) }} \ + {% if registry %}--registry {{ registry }}{% endif %} \ + 
--live-output \ + --output-dir "$MODEL_OUTPUT_DIR" \ + {% if additional_args %}{{ additional_args }}{% endif %} + +# Capture exit code +EXIT_CODE=$? + +echo "Finished madengine execution for {{ model_tag }} at $(date) with exit code: $EXIT_CODE" + +# Create result summary +cat > "$MODEL_OUTPUT_DIR/job_summary.json" << EOF +{ + "job_id": "$SLURM_JOB_ID", + "model_tag": "{{ model_tag }}", + "node": "$SLURMD_NODENAME", + "start_time": "$(date -Iseconds)", + "exit_code": $EXIT_CODE, + "gpu_count": {{ gpu_count | default(1) }}, + "partition": "{{ partition | default('gpu') }}", + "output_dir": "$MODEL_OUTPUT_DIR" +} +EOF + +# Exit with the madengine exit code +exit $EXIT_CODE \ No newline at end of file diff --git a/src/madengine/runners/values/default.yaml b/src/madengine/runners/values/default.yaml new file mode 100644 index 00000000..77b50c6d --- /dev/null +++ b/src/madengine/runners/values/default.yaml @@ -0,0 +1,205 @@ +# Default configuration for MADEngine distributed execution +# This file contains the base configuration that can be overridden by environment-specific files + +# General configuration +environment: "default" +manifest_file: "build_manifest.json" + +# Workspace configuration +workspace: + path: "/tmp/madengine_distributed" + owner: "root" + group: "root" + +# Execution configuration +execution: + timeout: 7200 # 2 hours + keep_alive: false + live_output: true + output_file: "perf.csv" + results_file: "execution_results.json" + generate_sys_env_details: true + async_timeout: 14400 # 4 hours + poll_interval: 30 + additional_context: null + additional_context_file: null + +# Data configuration +data_config: + file: "data.json" + force_mirror_local: false + required: false + +# Credentials configuration +credentials: + file: "credential.json" + required: false + +# Docker registry configuration +docker_registry: + login_required: false + username: "" + password: "" + +# Python configuration +python_path: "/usr/local/lib/python3.8/site-packages" +python_dependencies: + - jinja2 + - pyyaml + - requests + +# Installation configuration +install_dependencies: false + +# Post-execution configuration +post_execution: + cleanup: false + collect_logs: true + +# Logging configuration +logging: + level: "INFO" + format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + +logs: + local_path: "./logs" + +# Ansible configuration +ansible: + target_hosts: "gpu_nodes" + become: true + +# Kubernetes configuration +k8s: + namespace: "madengine" + + # ConfigMap configuration + configmap: + name: "madengine-config" + + # Job configuration + job: + name: "madengine-execution" + parallelism: 1 + completions: 1 + backoff_limit: 3 + active_deadline_seconds: 14400 # 4 hours + restart_policy: "Never" + + # Container configuration + container: + image: "madengine/distributed-runner:latest" + image_pull_policy: "IfNotPresent" + security_context: + run_as_user: 0 + run_as_group: 0 + privileged: false + health_checks: + liveness: + initial_delay: 30 + period: 60 + timeout: 10 + failure_threshold: 3 + readiness: + initial_delay: 5 + period: 10 + timeout: 5 + + # Service configuration + service: + name: "madengine-service" + type: "ClusterIP" + ports: + - name: "http" + port: 8080 + target_port: 8080 + protocol: "TCP" + - name: "metrics" + port: 9090 + target_port: 9090 + protocol: "TCP" + + # Volume configuration + volumes: + shared_storage: + type: "hostPath" + path: "/tmp/madengine-results" + hostPath_type: "DirectoryOrCreate" + + # Node selector + node_selector: + accelerator: "gpu" + + # 
Tolerations for GPU nodes + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + +# Resource configuration +resources: + memory_limit: "4Gi" + memory_request: "2Gi" + cpu_limit: "2" + cpu_request: "1" + gpu_limit: "1" + +# GPU vendor specific configuration +nvidia: + visible_devices: "all" + driver_capabilities: "compute,utility" + +amd: + visible_devices: "all" + enable_pre_vega: "1" + +# SLURM configuration (basic defaults) +slurm: + # Login/head node configuration + login_node: + hostname: "slurm-login" + address: "localhost" + port: 22 + username: "madengine" + ssh_key_path: "~/.ssh/id_rsa" + + # Cluster identification + cluster_name: "madengine-cluster" + + # Basic partition configuration + partitions: + - name: "gpu" + max_time: "24:00:00" + max_nodes: 8 + default_gpu_count: 1 + gpu_types: ["gpu"] + memory_per_node: "64G" + gpu_vendor: "AMD" + + # Basic modules + modules: + - "python/3.9" + - "gcc/11.2.0" + + # Basic environment + environment: + OMP_NUM_THREADS: "1" + + # GPU mapping + gpu_mapping: + AMD: + gres_name: "gpu" + constraint: "" + memory_per_gpu: "16G" + NVIDIA: + gres_name: "gpu" + constraint: "" + memory_per_gpu: "16G" + + # Execution defaults + execution: + max_concurrent_jobs: 4 + job_array_strategy: true + default_timeout: 3600 + retry_failed_jobs: false + max_retries: 1 diff --git a/src/madengine/runners/values/dev.yaml b/src/madengine/runners/values/dev.yaml new file mode 100644 index 00000000..522c2718 --- /dev/null +++ b/src/madengine/runners/values/dev.yaml @@ -0,0 +1,169 @@ +# Development environment configuration +# Extends default.yaml with development-specific settings + +# General configuration +environment: "dev" + +# Workspace configuration +workspace: + path: "/tmp/madengine_dev" + owner: "developer" + group: "developer" + +# Execution configuration +execution: + timeout: 3600 # 1 hour for dev + keep_alive: true # Keep containers alive for debugging + live_output: true + output_file: "dev_perf.csv" + results_file: "dev_execution_results.json" + generate_sys_env_details: true + async_timeout: 7200 # 2 hours + poll_interval: 10 # More frequent polling + +# Data configuration +data_config: + file: "dev_data.json" + force_mirror_local: true # Use local data for dev + required: false + +# Credentials configuration +credentials: + file: "dev_credential.json" + required: false + +# Docker registry configuration +docker_registry: + login_required: false + username: "dev-user" + password: "" + +# Python configuration +python_dependencies: + - jinja2 + - pyyaml + - requests + - pytest + - black + - mypy + +# Installation configuration +install_dependencies: true + +# Post-execution configuration +post_execution: + cleanup: false # Don't cleanup in dev + collect_logs: true + +# Logging configuration +logging: + level: "DEBUG" + format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + +logs: + local_path: "./dev_logs" + +# Ansible configuration +ansible: + target_hosts: "dev_nodes" + become: false + +# Kubernetes configuration +k8s: + namespace: "madengine-dev" + + # ConfigMap configuration + configmap: + name: "madengine-dev-config" + + # Job configuration + job: + name: "madengine-dev-execution" + parallelism: 1 + completions: 1 + backoff_limit: 1 # Fail fast in dev + active_deadline_seconds: 7200 # 2 hours + restart_policy: "Never" + + # Container configuration + container: + image: "madengine/distributed-runner:dev" + image_pull_policy: "Always" # Always pull latest dev image + security_context: + run_as_user: 1000 + 
run_as_group: 1000 + privileged: false + health_checks: + liveness: + initial_delay: 10 + period: 30 + timeout: 5 + failure_threshold: 2 + readiness: + initial_delay: 5 + period: 5 + timeout: 3 + + # Service configuration + service: + name: "madengine-dev-service" + type: "NodePort" + ports: + - name: "http" + port: 8080 + target_port: 8080 + protocol: "TCP" + node_port: 30080 + - name: "metrics" + port: 9090 + target_port: 9090 + protocol: "TCP" + node_port: 30090 + - name: "debug" + port: 5678 + target_port: 5678 + protocol: "TCP" + node_port: 30678 + + # Volume configuration + volumes: + shared_storage: + type: "hostPath" + path: "/tmp/madengine-dev-results" + hostPath_type: "DirectoryOrCreate" + data_storage: + type: "hostPath" + path: "/tmp/madengine-dev-data" + hostPath_type: "DirectoryOrCreate" + + # Node selector + node_selector: + environment: "dev" + accelerator: "gpu" + + # Tolerations for GPU nodes + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + - key: "dev-environment" + operator: "Equal" + value: "true" + effect: "NoSchedule" + +# Resource configuration +resources: + memory_limit: "2Gi" # Lower limits for dev + memory_request: "1Gi" + cpu_limit: "1" + cpu_request: "0.5" + gpu_limit: "1" + +# GPU vendor specific configuration +nvidia: + visible_devices: "0" # Only use first GPU in dev + driver_capabilities: "compute,utility" + +amd: + visible_devices: "0" + enable_pre_vega: "1" diff --git a/src/madengine/runners/values/prod.yaml b/src/madengine/runners/values/prod.yaml new file mode 100644 index 00000000..7cfb0c6a --- /dev/null +++ b/src/madengine/runners/values/prod.yaml @@ -0,0 +1,179 @@ +# Production environment configuration +# Extends default.yaml with production-specific settings + +# General configuration +environment: "prod" + +# Workspace configuration +workspace: + path: "/opt/madengine/workspace" + owner: "madengine" + group: "madengine" + +# Execution configuration +execution: + timeout: 10800 # 3 hours for production + keep_alive: false # Don't keep containers alive in prod + live_output: false # Reduce output in prod + output_file: "prod_perf.csv" + results_file: "prod_execution_results.json" + generate_sys_env_details: true + async_timeout: 21600 # 6 hours + poll_interval: 60 # Less frequent polling + +# Data configuration +data_config: + file: "prod_data.json" + force_mirror_local: false + required: true + +# Credentials configuration +credentials: + file: "prod_credential.json" + required: true + +# Docker registry configuration +docker_registry: + login_required: true + username: "prod-service-account" + password: "" # Should be set via secret + +# Python configuration +python_dependencies: + - jinja2 + - pyyaml + - requests + +# Installation configuration +install_dependencies: false # Pre-installed in prod images + +# Post-execution configuration +post_execution: + cleanup: true # Clean up in prod + collect_logs: true + +# Logging configuration +logging: + level: "INFO" + format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + +logs: + local_path: "/var/log/madengine" + +# Ansible configuration +ansible: + target_hosts: "prod_gpu_nodes" + become: true + +# Kubernetes configuration +k8s: + namespace: "madengine-prod" + + # ConfigMap configuration + configmap: + name: "madengine-prod-config" + + # Job configuration + job: + name: "madengine-prod-execution" + parallelism: 2 # Higher parallelism in prod + completions: 2 + backoff_limit: 5 # More retries in prod + active_deadline_seconds: 21600 # 6 hours + 
restart_policy: "Never" + + # Container configuration + container: + image: "madengine/distributed-runner:stable" + image_pull_policy: "IfNotPresent" + security_context: + run_as_user: 1001 + run_as_group: 1001 + privileged: false + health_checks: + liveness: + initial_delay: 60 + period: 120 + timeout: 30 + failure_threshold: 5 + readiness: + initial_delay: 30 + period: 30 + timeout: 10 + + # Service configuration + service: + name: "madengine-prod-service" + type: "ClusterIP" + ports: + - name: "http" + port: 8080 + target_port: 8080 + protocol: "TCP" + - name: "metrics" + port: 9090 + target_port: 9090 + protocol: "TCP" + + # Volume configuration + volumes: + shared_storage: + type: "pvc" + claim_name: "madengine-prod-results" + data_storage: + type: "pvc" + claim_name: "madengine-prod-data" + + # Node selector + node_selector: + environment: "prod" + accelerator: "gpu" + instance-type: "high-performance" + + # Tolerations for GPU nodes + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + - key: "prod-workload" + operator: "Equal" + value: "true" + effect: "NoSchedule" + + # Service account for prod + service_account: "madengine-prod-sa" + + # Image pull secrets + image_pull_secrets: + - "prod-registry-secret" + + # Affinity for better pod distribution + affinity: + pod_anti_affinity: + preferredDuringSchedulingIgnoredDuringExecution: + - weight: 100 + podAffinityTerm: + labelSelector: + matchExpressions: + - key: "app.kubernetes.io/name" + operator: In + values: + - "madengine" + topologyKey: "kubernetes.io/hostname" + +# Resource configuration +resources: + memory_limit: "8Gi" # Higher limits for prod + memory_request: "4Gi" + cpu_limit: "4" + cpu_request: "2" + gpu_limit: "2" + +# GPU vendor specific configuration +nvidia: + visible_devices: "all" + driver_capabilities: "compute,utility" + +amd: + visible_devices: "all" + enable_pre_vega: "1" diff --git a/src/madengine/runners/values/slurm.yaml b/src/madengine/runners/values/slurm.yaml new file mode 100644 index 00000000..c389f21f --- /dev/null +++ b/src/madengine/runners/values/slurm.yaml @@ -0,0 +1,122 @@ +# SLURM Configuration Values for MADEngine +# This file provides default configuration values for SLURM cluster execution + +# SLURM cluster configuration +slurm: + # Login/head node configuration + login_node: + hostname: "slurm-login" + address: "slurm-login.example.com" + port: 22 + username: "madengine" + ssh_key_path: "~/.ssh/id_rsa" + + # Cluster identification + cluster_name: "madengine-cluster" + + # Available partitions + partitions: + - name: "gpu" + max_time: "24:00:00" + max_nodes: 32 + default_gpu_count: 1 + gpu_types: ["MI250X", "A100"] + memory_per_node: "256G" + gpu_vendor: "AMD" + qos: "normal" + account: "madengine_proj" + + - name: "cpu" + max_time: "72:00:00" + max_nodes: 128 + default_gpu_count: 0 + gpu_types: [] + memory_per_node: "128G" + gpu_vendor: "" + + - name: "debug" + max_time: "02:00:00" + max_nodes: 4 + default_gpu_count: 1 + gpu_types: ["MI250X"] + memory_per_node: "64G" + gpu_vendor: "AMD" + qos: "debug" + + # Module system modules to load + modules: + - "rocm/5.7.0" + - "python/3.9" + - "gcc/11.2.0" + - "cmake/3.25.0" + + # Environment variables + environment: + ROCM_PATH: "/opt/rocm" + HCC_AMDGPU_TARGET: "gfx90a" + CUDA_VISIBLE_DEVICES: "0" + OMP_NUM_THREADS: "1" + PYTORCH_ROCM_ARCH: "gfx90a" + + # GPU vendor specific configuration + gpu_mapping: + AMD: + gres_name: "gpu" + constraint: "mi250x" + memory_per_gpu: "64G" + NVIDIA: + gres_name: "gpu" + constraint: 
"a100" + memory_per_gpu: "80G" + INTEL: + gres_name: "gpu" + constraint: "pvc" + memory_per_gpu: "48G" + + # Job execution settings + execution: + max_concurrent_jobs: 8 + job_array_strategy: true + default_timeout: 3600 + retry_failed_jobs: true + max_retries: 3 + +# Workspace configuration +workspace: + shared_filesystem: "/shared/madengine" + results_dir: "/shared/results" + logs_dir: "logs" + venv_path: "venv" + mad_repo_url: "https://github.com/ROCm/MAD.git" + +# Job script default settings +job_defaults: + partition: "gpu" + nodes: 1 + tasks_per_node: 1 + gpu_count: 1 + time_limit: "24:00:00" + memory: "32G" + exclusive: false + output_dir: "logs" + omp_num_threads: 1 + +# Model-specific overrides (example) +model_overrides: + "llama2:7b": + memory: "64G" + gpu_count: 2 + time_limit: "12:00:00" + partition: "gpu" + + "stable_diffusion:xl": + memory: "32G" + gpu_count: 1 + time_limit: "06:00:00" + partition: "gpu" + +# Generation metadata (filled automatically) +generation: + timestamp: "" + generator: "MADEngine Template Generator" + version: "1.0.0" \ No newline at end of file diff --git a/src/madengine/runners/values/test.yaml b/src/madengine/runners/values/test.yaml new file mode 100644 index 00000000..4a16200f --- /dev/null +++ b/src/madengine/runners/values/test.yaml @@ -0,0 +1,158 @@ +# Test environment configuration +# Extends default.yaml with test-specific settings + +# General configuration +environment: "test" + +# Workspace configuration +workspace: + path: "/tmp/madengine_test" + owner: "test" + group: "test" + +# Execution configuration +execution: + timeout: 1800 # 30 minutes for tests + keep_alive: false + live_output: true + output_file: "test_perf.csv" + results_file: "test_execution_results.json" + generate_sys_env_details: false # Skip for faster tests + async_timeout: 3600 # 1 hour + poll_interval: 5 # Fast polling for tests + +# Data configuration +data_config: + file: "test_data.json" + force_mirror_local: true + required: false + +# Credentials configuration +credentials: + file: "test_credential.json" + required: false + +# Docker registry configuration +docker_registry: + login_required: false + username: "test-user" + password: "" + +# Python configuration +python_dependencies: + - jinja2 + - pyyaml + - requests + - pytest + - pytest-cov + - mock + +# Installation configuration +install_dependencies: true + +# Post-execution configuration +post_execution: + cleanup: true # Clean up after tests + collect_logs: true + +# Logging configuration +logging: + level: "DEBUG" + format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + +logs: + local_path: "./test_logs" + +# Ansible configuration +ansible: + target_hosts: "test_nodes" + become: false + +# Kubernetes configuration +k8s: + namespace: "madengine-test" + + # ConfigMap configuration + configmap: + name: "madengine-test-config" + + # Job configuration + job: + name: "madengine-test-execution" + parallelism: 1 + completions: 1 + backoff_limit: 0 # No retries in test + active_deadline_seconds: 3600 # 1 hour + restart_policy: "Never" + + # Container configuration + container: + image: "madengine/distributed-runner:test" + image_pull_policy: "Always" + security_context: + run_as_user: 1000 + run_as_group: 1000 + privileged: false + health_checks: + liveness: + initial_delay: 5 + period: 10 + timeout: 3 + failure_threshold: 1 + readiness: + initial_delay: 2 + period: 5 + timeout: 2 + + # Service configuration + service: + name: "madengine-test-service" + type: "ClusterIP" + ports: + - name: "http" + 
port: 8080 + target_port: 8080 + protocol: "TCP" + - name: "test-metrics" + port: 9091 + target_port: 9091 + protocol: "TCP" + + # Volume configuration + volumes: + shared_storage: + type: "hostPath" + path: "/tmp/madengine-test-results" + hostPath_type: "DirectoryOrCreate" + + # Node selector + node_selector: + environment: "test" + accelerator: "gpu" + + # Tolerations for GPU nodes + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + - key: "test-environment" + operator: "Equal" + value: "true" + effect: "NoSchedule" + +# Resource configuration +resources: + memory_limit: "1Gi" # Minimal resources for tests + memory_request: "512Mi" + cpu_limit: "0.5" + cpu_request: "0.25" + gpu_limit: "1" + +# GPU vendor specific configuration +nvidia: + visible_devices: "0" # Only use first GPU for tests + driver_capabilities: "compute,utility" + +amd: + visible_devices: "0" + enable_pre_vega: "1" diff --git a/src/madengine/scripts/common/post_scripts/gpu_info_post.sh b/src/madengine/scripts/common/post_scripts/gpu_info_post.sh index 5582b986..c1a6e457 100644 --- a/src/madengine/scripts/common/post_scripts/gpu_info_post.sh +++ b/src/madengine/scripts/common/post_scripts/gpu_info_post.sh @@ -9,14 +9,21 @@ set -x tool=$1 +# Output filename is tool_output.csv (e.g., gpu_info_power_profiler_output.csv) OUTPUT=${tool}_output.csv SAVESPACE=/myworkspace/ cd $SAVESPACE -if [ -d "$OUTPUT" ]; then - mkdir "$OUTPUT" + +# Check if prof.csv exists (generated by the profiler) +if [ ! -f "prof.csv" ]; then + echo "Error: prof.csv not found in $SAVESPACE" + exit 1 fi +# Move the profiler output to the final location mv prof.csv "$OUTPUT" -chmod -R a+rw "${SAVESPACE}/${OUTPUT}" +chmod a+rw "${SAVESPACE}/${OUTPUT}" + +echo "Profiler output saved to: ${SAVESPACE}/${OUTPUT}" diff --git a/src/madengine/scripts/common/pre_scripts/rocEnvTool/csv_parser.py b/src/madengine/scripts/common/pre_scripts/rocEnvTool/csv_parser.py index 66fb84ac..db504803 100644 --- a/src/madengine/scripts/common/pre_scripts/rocEnvTool/csv_parser.py +++ b/src/madengine/scripts/common/pre_scripts/rocEnvTool/csv_parser.py @@ -284,11 +284,23 @@ def dump_csv_output(self): fs.write(sys_config_info[j]) fs.write("\n") fs.close() - print ("OK: Dumped into {} file.".format(self.filename)) + print("\n" + "="*60) + print(f"✅ SUCCESS: System config data dumped to {self.filename}") + print("="*60 + "\n") def print_csv_output(self): - print ("Printing the sys config info env variables...") + print("\n" + "="*80) + print("📋 SYSTEM CONFIG INFO - ENVIRONMENT VARIABLES") + print("="*80) if self.sys_config_info_list: for j in range(len(self.sys_config_info_list)): line = self.sys_config_info_list[j] - print (line) + # Add some formatting for key-value pairs + if "|" in line and not line.startswith("Tag"): + key, value = line.split("|", 1) + print(f"🔹 {key:<30}: {value}") + else: + print(f"📌 {line}") + else: + print("❌ No system config information available") + print("="*80 + "\n") diff --git a/src/madengine/tools/container_runner.py b/src/madengine/tools/container_runner.py new file mode 100644 index 00000000..72fa2d93 --- /dev/null +++ b/src/madengine/tools/container_runner.py @@ -0,0 +1,1046 @@ +#!/usr/bin/env python3 +""" +Docker Container Runner Module for MADEngine + +This module handles the Docker container execution phase separately from building, +enabling distributed workflows where containers are run on remote nodes +using pre-built images. 
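+
+Illustrative usage sketch (ctx, data, and model_info are placeholders assumed to be
+constructed elsewhere; the registry and image names are examples only):
+
+    runner = ContainerRunner(context=ctx, data=data, live_output=True)
+    manifest = runner.load_build_manifest("build_manifest.json")
+    image = runner.pull_image("localhost:5000/ci-model:latest",
+                              local_name="ci-model:latest")
+    results = runner.run_container(model_info, docker_image=image)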
+""" + +import os +import time +import json +import typing +import warnings +import re +from rich.console import Console as RichConsole +from contextlib import redirect_stdout, redirect_stderr +from madengine.core.console import Console +from madengine.core.context import Context +from madengine.core.docker import Docker +from madengine.core.timeout import Timeout +from madengine.core.dataprovider import Data +from madengine.utils.ops import PythonicTee, file_print +from madengine.tools.update_perf_csv import update_perf_csv, flatten_tags + + +class ContainerRunner: + """Class responsible for running Docker containers with models.""" + + def __init__( + self, + context: Context = None, + data: Data = None, + console: Console = None, + live_output: bool = False, + ): + """Initialize the Container Runner. + + Args: + context: The MADEngine context + data: The data provider instance + console: Optional console instance + live_output: Whether to show live output + """ + self.context = context + self.data = data + self.console = console or Console(live_output=live_output) + self.live_output = live_output + self.rich_console = RichConsole() + self.credentials = None + self.perf_csv_path = "perf.csv" # Default output path + + # Ensure runtime context is initialized for container operations + if self.context: + self.context.ensure_runtime_context() + + def set_perf_csv_path(self, path: str): + """Set the path for the performance CSV output file. + + Args: + path: Path to the performance CSV file + """ + self.perf_csv_path = path + + def ensure_perf_csv_exists(self): + """Ensure the performance CSV file exists with proper headers.""" + if not os.path.exists(self.perf_csv_path): + file_print( + "model,n_gpus,training_precision,pipeline,args,tags,docker_file,base_docker,docker_sha,docker_image,git_commit,machine_name,gpu_architecture,performance,metric,relative_change,status,build_duration,test_duration,dataname,data_provider_type,data_size,data_download_duration,build_number,additional_docker_run_options", + filename=self.perf_csv_path, + mode="w", + ) + print(f"Created performance CSV file: {self.perf_csv_path}") + + def create_run_details_dict( + self, model_info: typing.Dict, build_info: typing.Dict, run_results: typing.Dict + ) -> typing.Dict: + """Create a run details dictionary similar to RunDetails class in run_models.py. 
+ + Args: + model_info: Model information dictionary + build_info: Build information from manifest + run_results: Container execution results + + Returns: + dict: Run details dictionary for CSV generation + """ + import os + + # Create run details dict with all required fields + run_details = { + "model": model_info["name"], + "n_gpus": model_info.get("n_gpus", ""), + "training_precision": model_info.get("training_precision", ""), + "pipeline": os.environ.get("pipeline", ""), + "args": model_info.get("args", ""), + "tags": model_info.get("tags", ""), + "docker_file": build_info.get("dockerfile", ""), + "base_docker": build_info.get("base_docker", ""), + "docker_sha": build_info.get("docker_sha", ""), + "docker_image": build_info.get("docker_image", ""), + "git_commit": run_results.get("git_commit", ""), + "machine_name": run_results.get("machine_name", ""), + "gpu_architecture": ( + self.context.ctx["docker_env_vars"]["MAD_SYSTEM_GPU_ARCHITECTURE"] + if self.context + else "" + ), + "performance": run_results.get("performance", ""), + "metric": run_results.get("metric", ""), + "relative_change": "", + "status": run_results.get("status", "FAILURE"), + "build_duration": build_info.get("build_duration", ""), + "test_duration": run_results.get("test_duration", ""), + "dataname": run_results.get("dataname", ""), + "data_provider_type": run_results.get("data_provider_type", ""), + "data_size": run_results.get("data_size", ""), + "data_download_duration": run_results.get("data_download_duration", ""), + "build_number": os.environ.get("BUILD_NUMBER", "0"), + "additional_docker_run_options": model_info.get( + "additional_docker_run_options", "" + ), + } + + # Flatten tags if they are in list format + flatten_tags(run_details) + + return run_details + + def load_build_manifest( + self, manifest_file: str = "build_manifest.json" + ) -> typing.Dict: + """Load build manifest from file. + + Args: + manifest_file: Path to build manifest file + + Returns: + dict: Build manifest data + """ + with open(manifest_file, "r") as f: + manifest = json.load(f) + + print(f"Loaded build manifest from: {manifest_file}") + return manifest + + def login_to_registry(self, registry: str, credentials: typing.Dict = None) -> None: + """Login to a Docker registry for pulling images. 
+ + Args: + registry: Registry URL (e.g., "localhost:5000", "docker.io") + credentials: Optional credentials dictionary containing username/password + """ + if not credentials: + self.rich_console.print("[yellow]No credentials provided for registry login[/yellow]") + return + + # Check if registry credentials are available + registry_key = registry if registry else "dockerhub" + + # Handle docker.io as dockerhub + if registry and registry.lower() == "docker.io": + registry_key = "dockerhub" + + if registry_key not in credentials: + error_msg = f"No credentials found for registry: {registry_key}" + if registry_key == "dockerhub": + error_msg += f"\nPlease add dockerhub credentials to credential.json:\n" + error_msg += "{\n" + error_msg += ' "dockerhub": {\n' + error_msg += ' "repository": "your-repository",\n' + error_msg += ' "username": "your-dockerhub-username",\n' + error_msg += ' "password": "your-dockerhub-password-or-token"\n' + error_msg += " }\n" + error_msg += "}" + else: + error_msg += ( + f"\nPlease add {registry_key} credentials to credential.json:\n" + ) + error_msg += "{\n" + error_msg += f' "{registry_key}": {{\n' + error_msg += f' "repository": "your-repository",\n' + error_msg += f' "username": "your-{registry_key}-username",\n' + error_msg += f' "password": "your-{registry_key}-password"\n' + error_msg += " }\n" + error_msg += "}" + print(error_msg) + raise RuntimeError(error_msg) + + creds = credentials[registry_key] + + if "username" not in creds or "password" not in creds: + error_msg = f"Invalid credentials format for registry: {registry_key}" + error_msg += f"\nCredentials must contain 'username' and 'password' fields" + print(error_msg) + raise RuntimeError(error_msg) + + # Ensure credential values are strings + username = str(creds["username"]) + password = str(creds["password"]) + + # Perform docker login + login_command = f"echo '{password}' | docker login" + + if registry and registry.lower() not in ["docker.io", "dockerhub"]: + login_command += f" {registry}" + + login_command += f" --username {username} --password-stdin" + + try: + self.console.sh(login_command, secret=True) + self.rich_console.print(f"[green]✅ Successfully logged in to registry: {registry or 'DockerHub'}[/green]") + except Exception as e: + self.rich_console.print(f"[red]❌ Failed to login to registry {registry}: {e}[/red]") + # Don't raise exception here, as public images might still be pullable + + def pull_image( + self, + registry_image: str, + local_name: str = None, + registry: str = None, + credentials: typing.Dict = None, + ) -> str: + """Pull an image from registry. 
+ + Args: + registry_image: Full registry image name + local_name: Optional local name to tag the image + registry: Optional registry URL for authentication + credentials: Optional credentials dictionary for authentication + + Returns: + str: Local image name + """ + # Login to registry if credentials are provided + if registry and credentials: + self.login_to_registry(registry, credentials) + + self.rich_console.print(f"\n[bold blue]📥 Starting docker pull from registry...[/bold blue]") + print(f"📍 Registry: {registry or 'Default'}") + print(f"🏷️ Image: {registry_image}") + try: + self.console.sh(f"docker pull {registry_image}") + + if local_name: + self.console.sh(f"docker tag {registry_image} {local_name}") + print(f"🏷️ Tagged as: {local_name}") + self.rich_console.print(f"[bold green]✅ Successfully pulled and tagged image[/bold green]") + self.rich_console.print(f"[dim]{'='*80}[/dim]") + return local_name + + self.rich_console.print(f"[bold green]✅ Successfully pulled image:[/bold green] [cyan]{registry_image}[/cyan]") + self.rich_console.print(f"[dim]{'='*80}[/dim]") + return registry_image + + except Exception as e: + self.rich_console.print(f"[red]❌ Failed to pull image {registry_image}: {e}[/red]") + raise + + def get_gpu_arg(self, requested_gpus: str) -> str: + """Get the GPU arguments for docker run. + + Args: + requested_gpus: The requested GPUs. + + Returns: + str: The GPU arguments. + """ + gpu_arg = "" + gpu_vendor = self.context.ctx["docker_env_vars"]["MAD_GPU_VENDOR"] + n_system_gpus = self.context.ctx["docker_env_vars"]["MAD_SYSTEM_NGPUS"] + gpu_strings = self.context.ctx["docker_gpus"].split(",") + + # Parse GPU string, example: '{0-4}' -> [0,1,2,3,4] + docker_gpus = [] + for gpu_string in gpu_strings: + if "-" in gpu_string: + gpu_range = gpu_string.split("-") + docker_gpus += [ + item for item in range(int(gpu_range[0]), int(gpu_range[1]) + 1) + ] + else: + docker_gpus.append(int(gpu_string)) + docker_gpus.sort() + + # Check GPU range is valid for system + if requested_gpus == "-1": + print("NGPUS requested is ALL (" + ",".join(map(str, docker_gpus)) + ").") + requested_gpus = len(docker_gpus) + + print( + "NGPUS requested is " + + str(requested_gpus) + + " out of " + + str(n_system_gpus) + ) + + if int(requested_gpus) > int(n_system_gpus) or int(requested_gpus) > len( + docker_gpus + ): + raise RuntimeError( + f"Too many gpus requested({requested_gpus}). System has {n_system_gpus} gpus. Context has {len(docker_gpus)} gpus." 
+ ) + + # Expose number of requested gpus + self.context.ctx["docker_env_vars"]["MAD_RUNTIME_NGPUS"] = str(requested_gpus) + + # Create docker arg to assign requested GPUs + if gpu_vendor.find("AMD") != -1: + gpu_arg = "--device=/dev/kfd " + gpu_renderDs = self.context.ctx["gpu_renderDs"] + if gpu_renderDs is not None: + for idx in range(0, int(requested_gpus)): + gpu_arg += ( + f"--device=/dev/dri/renderD{gpu_renderDs[docker_gpus[idx]]} " + ) + + elif gpu_vendor.find("NVIDIA") != -1: + gpu_str = "" + for idx in range(0, int(requested_gpus)): + gpu_str += str(docker_gpus[idx]) + "," + gpu_arg += f"--gpus '\"device={gpu_str}\"' " + else: + raise RuntimeError("Unable to determine gpu vendor.") + + print(f"GPU arguments: {gpu_arg}") + return gpu_arg + + def get_cpu_arg(self) -> str: + """Get the CPU arguments for docker run.""" + if "docker_cpus" not in self.context.ctx: + return "" + cpus = self.context.ctx["docker_cpus"].replace(" ", "") + return f"--cpuset-cpus {cpus} " + + def get_env_arg(self, run_env: typing.Dict) -> str: + """Get the environment arguments for docker run.""" + env_args = "" + + # Add custom environment variables + if run_env: + for env_arg in run_env: + env_args += f"--env {env_arg}='{str(run_env[env_arg])}' " + + # Add context environment variables + if "docker_env_vars" in self.context.ctx: + for env_arg in self.context.ctx["docker_env_vars"].keys(): + # Skip individual MAD_MULTI_NODE_* env vars (except MAD_MULTI_NODE_RUNNER) + # These are redundant since MAD_MULTI_NODE_RUNNER contains all necessary information + if ( + env_arg.startswith("MAD_MULTI_NODE_") + and env_arg != "MAD_MULTI_NODE_RUNNER" + ): + continue + env_args += f"--env {env_arg}='{str(self.context.ctx['docker_env_vars'][env_arg])}' " + + print(f"Env arguments: {env_args}") + return env_args + + def get_mount_arg(self, mount_datapaths: typing.List) -> str: + """Get the mount arguments for docker run.""" + mount_args = "" + + # Mount data paths + if mount_datapaths: + for mount_datapath in mount_datapaths: + if mount_datapath: + mount_args += ( + f"-v {mount_datapath['path']}:{mount_datapath['home']}" + ) + if ( + "readwrite" in mount_datapath + and mount_datapath["readwrite"] == "true" + ): + mount_args += " " + else: + mount_args += ":ro " + + # Mount context paths + if "docker_mounts" in self.context.ctx: + for mount_arg in self.context.ctx["docker_mounts"].keys(): + mount_args += ( + f"-v {self.context.ctx['docker_mounts'][mount_arg]}:{mount_arg} " + ) + + return mount_args + + def apply_tools( + self, + pre_encapsulate_post_scripts: typing.Dict, + run_env: typing.Dict, + tools_json_file: str, + ) -> None: + """Apply tools configuration to the runtime environment.""" + if "tools" not in self.context.ctx: + return + + # Read tool settings from tools.json + with open(tools_json_file) as f: + tool_file = json.load(f) + + # Iterate over tools in context, apply tool settings + for ctx_tool_config in self.context.ctx["tools"]: + tool_name = ctx_tool_config["name"] + tool_config = tool_file["tools"][tool_name] + + if "cmd" in ctx_tool_config: + tool_config.update({"cmd": ctx_tool_config["cmd"]}) + + if "env_vars" in ctx_tool_config: + for env_var in ctx_tool_config["env_vars"]: + tool_config["env_vars"].update( + {env_var: ctx_tool_config["env_vars"][env_var]} + ) + + print(f"Selected Tool, {tool_name}. 
Configuration : {str(tool_config)}.") + + # Setup tool before other existing scripts + if "pre_scripts" in tool_config: + pre_encapsulate_post_scripts["pre_scripts"] = ( + tool_config["pre_scripts"] + + pre_encapsulate_post_scripts["pre_scripts"] + ) + # Cleanup tool after other existing scripts + if "post_scripts" in tool_config: + pre_encapsulate_post_scripts["post_scripts"] += tool_config[ + "post_scripts" + ] + # Update environment variables + if "env_vars" in tool_config: + run_env.update(tool_config["env_vars"]) + if "cmd" in tool_config: + # Prepend encapsulate cmd + pre_encapsulate_post_scripts["encapsulate_script"] = ( + tool_config["cmd"] + + " " + + pre_encapsulate_post_scripts["encapsulate_script"] + ) + + def run_pre_post_script( + self, model_docker: Docker, model_dir: str, pre_post: typing.List + ) -> None: + """Run pre/post scripts in the container.""" + for script in pre_post: + script_path = script["path"].strip() + model_docker.sh( + f"cp -vLR --preserve=all {script_path} {model_dir}", timeout=600 + ) + script_name = os.path.basename(script_path) + script_args = "" + if "args" in script: + script_args = script["args"].strip() + model_docker.sh( + f"cd {model_dir} && bash {script_name} {script_args}", timeout=600 + ) + + def gather_system_env_details( + self, pre_encapsulate_post_scripts: typing.Dict, model_name: str + ) -> None: + """Gather system environment details. + + Args: + pre_encapsulate_post_scripts: The pre, encapsulate and post scripts. + model_name: The model name. + + Returns: + None + + Raises: + Exception: An error occurred while gathering system environment details. + + Note: + This function is used to gather system environment details. + """ + # initialize pre_env_details + pre_env_details = {} + pre_env_details["path"] = "scripts/common/pre_scripts/run_rocenv_tool.sh" + pre_env_details["args"] = model_name.replace("/", "_") + "_env" + pre_encapsulate_post_scripts["pre_scripts"].append(pre_env_details) + print(f"pre encap post scripts: {pre_encapsulate_post_scripts}") + + def run_container( + self, + model_info: typing.Dict, + docker_image: str, + build_info: typing.Dict = None, + keep_alive: bool = False, + timeout: int = 7200, + tools_json_file: str = "scripts/common/tools.json", + phase_suffix: str = "", + generate_sys_env_details: bool = True, + ) -> typing.Dict: + """Run a model in a Docker container. 
+ + Args: + model_info: Model information dictionary + docker_image: Docker image name to run + build_info: Optional build information from manifest + keep_alive: Whether to keep container alive after execution + timeout: Execution timeout in seconds + tools_json_file: Path to tools configuration file + phase_suffix: Suffix for log file name (e.g., ".run" or "") + generate_sys_env_details: Whether to collect system environment details + + Returns: + dict: Execution results including performance metrics + """ + self.rich_console.print(f"[bold green]🏃 Running model:[/bold green] [bold cyan]{model_info['name']}[/bold cyan] [dim]in container[/dim] [yellow]{docker_image}[/yellow]") + + # Create log file for this run + # Extract dockerfile part from docker image name (remove "ci-" prefix and model name prefix) + image_name_without_ci = docker_image.replace("ci-", "") + model_name_clean = model_info["name"].replace("/", "_").lower() + + # Remove model name from the beginning to get the dockerfile part + if image_name_without_ci.startswith(model_name_clean + "_"): + dockerfile_part = image_name_without_ci[len(model_name_clean + "_") :] + else: + dockerfile_part = image_name_without_ci + + log_file_path = ( + model_info["name"].replace("/", "_") + + "_" + + dockerfile_part + + phase_suffix + + ".live.log" + ) + # Replace / with _ in log file path (already done above, but keeping for safety) + log_file_path = log_file_path.replace("/", "_") + + print(f"Run log will be written to: {log_file_path}") + + # get machine name + machine_name = self.console.sh("hostname") + print(f"MACHINE NAME is {machine_name}") + + # Initialize results + run_results = { + "model": model_info["name"], + "docker_image": docker_image, + "status": "FAILURE", + "performance": "", + "metric": "", + "test_duration": 0, + "machine_name": machine_name, + "log_file": log_file_path, + } + + # If build info provided, merge it + if build_info: + run_results.update(build_info) + + # Prepare docker run options + gpu_vendor = self.context.ctx["gpu_vendor"] + docker_options = "" + + if gpu_vendor.find("AMD") != -1: + docker_options = ( + "--network host -u root --group-add video " + "--cap-add=SYS_PTRACE --cap-add SYS_ADMIN --device /dev/fuse " + "--security-opt seccomp=unconfined --security-opt apparmor=unconfined --ipc=host " + ) + elif gpu_vendor.find("NVIDIA") != -1: + docker_options = ( + "--cap-add=SYS_PTRACE --cap-add SYS_ADMIN --cap-add SYS_NICE --device /dev/fuse " + "--security-opt seccomp=unconfined --security-opt apparmor=unconfined " + "--network host -u root --ipc=host " + ) + else: + raise RuntimeError("Unable to determine gpu vendor.") + + # Initialize scripts + pre_encapsulate_post_scripts = { + "pre_scripts": [], + "encapsulate_script": "", + "post_scripts": [], + } + + if "pre_scripts" in self.context.ctx: + pre_encapsulate_post_scripts["pre_scripts"] = self.context.ctx[ + "pre_scripts" + ] + if "post_scripts" in self.context.ctx: + pre_encapsulate_post_scripts["post_scripts"] = self.context.ctx[ + "post_scripts" + ] + if "encapsulate_script" in self.context.ctx: + pre_encapsulate_post_scripts["encapsulate_script"] = self.context.ctx[ + "encapsulate_script" + ] + + # Add environment variables + docker_options += f"--env MAD_MODEL_NAME='{model_info['name']}' " + docker_options += ( + f"--env JENKINS_BUILD_NUMBER='{os.environ.get('BUILD_NUMBER','0')}' " + ) + + # Gather data and environment + run_env = {} + mount_datapaths = None + + if "data" in model_info and model_info["data"] != "" and self.data: + mount_datapaths = 
self.data.get_mountpaths(model_info["data"]) + model_dataenv = self.data.get_env(model_info["data"]) + if model_dataenv is not None: + run_env.update(model_dataenv) + run_env["MAD_DATANAME"] = model_info["data"] + + # Add credentials to environment + if "cred" in model_info and model_info["cred"] != "" and self.credentials: + if model_info["cred"] not in self.credentials: + raise RuntimeError(f"Credentials({model_info['cred']}) not found") + for key_cred, value_cred in self.credentials[model_info["cred"]].items(): + run_env[model_info["cred"] + "_" + key_cred.upper()] = value_cred + + # Apply tools if configured + if os.path.exists(tools_json_file): + self.apply_tools(pre_encapsulate_post_scripts, run_env, tools_json_file) + + # Add system environment collection script to pre_scripts (equivalent to generate_sys_env_details) + # This ensures distributed runs have the same system environment logging as standard runs + if generate_sys_env_details or self.context.ctx.get("gen_sys_env_details"): + self.gather_system_env_details( + pre_encapsulate_post_scripts, model_info["name"] + ) + + # Build docker options + docker_options += self.get_gpu_arg(model_info["n_gpus"]) + docker_options += self.get_cpu_arg() + docker_options += self.get_env_arg(run_env) + docker_options += self.get_mount_arg(mount_datapaths) + docker_options += f" {model_info.get('additional_docker_run_options', '')}" + + # Generate container name + container_name = "container_" + re.sub( + ".*:", "", docker_image.replace("/", "_").replace(":", "_") + ) + + print(f"Docker options: {docker_options}") + + # set timeout + print(f"⏰ Setting timeout to {str(timeout)} seconds.") + + self.rich_console.print(f"\n[bold blue]🏃 Starting Docker container execution...[/bold blue]") + print(f"🏷️ Image: {docker_image}") + print(f"📦 Container: {container_name}") + print(f"📝 Log file: {log_file_path}") + print(f"🎮 GPU Vendor: {gpu_vendor}") + self.rich_console.print(f"[dim]{'='*80}[/dim]") + + # Run the container with logging + try: + with open(log_file_path, mode="w", buffering=1) as outlog: + with redirect_stdout( + PythonicTee(outlog, self.live_output) + ), redirect_stderr(PythonicTee(outlog, self.live_output)): + with Timeout(timeout): + model_docker = Docker( + docker_image, + container_name, + docker_options, + keep_alive=keep_alive, + console=self.console, + ) + + # Check user + whoami = model_docker.sh("whoami") + print(f"👤 Running as user: {whoami}") + + # Show GPU info + if gpu_vendor.find("AMD") != -1: + print(f"🎮 Checking AMD GPU status...") + model_docker.sh("/opt/rocm/bin/rocm-smi || true") + elif gpu_vendor.find("NVIDIA") != -1: + print(f"🎮 Checking NVIDIA GPU status...") + model_docker.sh("/usr/bin/nvidia-smi || true") + + # Prepare model directory + model_dir = "run_directory" + if "url" in model_info and model_info["url"] != "": + model_dir = model_info["url"].rstrip("/").split("/")[-1] + + # Validate model_dir + special_char = r"[^a-zA-Z0-9\-\_]" + if re.search(special_char, model_dir) is not None: + warnings.warn( + "Model url contains special character. Fix url." 
+ ) + + model_docker.sh(f"rm -rf {model_dir}", timeout=240) + model_docker.sh( + "git config --global --add safe.directory /myworkspace" + ) + + # Clone model repo if needed + if "url" in model_info and model_info["url"] != "": + if ( + "cred" in model_info + and model_info["cred"] != "" + and self.credentials + ): + print(f"Using credentials for {model_info['cred']}") + + if model_info["url"].startswith("ssh://"): + model_docker.sh( + f"git -c core.sshCommand='ssh -l {self.credentials[model_info['cred']]['username']} " + f"-i {self.credentials[model_info['cred']]['ssh_key_file']} -o IdentitiesOnly=yes " + f"-o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no' " + f"clone {model_info['url']}", + timeout=240, + ) + else: # http or https + model_docker.sh( + f"git clone -c credential.helper='!f() {{ echo username={self.credentials[model_info['cred']]['username']}; " + f"echo password={self.credentials[model_info['cred']]['password']}; }};f' " + f"{model_info['url']}", + timeout=240, + secret=f"git clone {model_info['url']}", + ) + else: + model_docker.sh( + f"git clone {model_info['url']}", timeout=240 + ) + + model_docker.sh( + f"git config --global --add safe.directory /myworkspace/{model_dir}" + ) + run_results["git_commit"] = model_docker.sh( + f"cd {model_dir} && git rev-parse HEAD" + ) + print(f"MODEL GIT COMMIT is {run_results['git_commit']}") + model_docker.sh( + f"cd {model_dir}; git submodule update --init --recursive" + ) + else: + model_docker.sh(f"mkdir -p {model_dir}") + + # Run pre-scripts + if pre_encapsulate_post_scripts["pre_scripts"]: + self.run_pre_post_script( + model_docker, + model_dir, + pre_encapsulate_post_scripts["pre_scripts"], + ) + + # Prepare script execution + scripts_arg = model_info["scripts"] + if scripts_arg.endswith(".sh"): + dir_path = os.path.dirname(scripts_arg) + script_name = "bash " + os.path.basename(scripts_arg) + else: + dir_path = model_info["scripts"] + script_name = "bash run.sh" + + # Add script prepend command + script_name = ( + pre_encapsulate_post_scripts["encapsulate_script"] + + " " + + script_name + ) + + # print repo hash + commit = model_docker.sh( + f"cd {dir_path}; git rev-parse HEAD || true" + ) + print("======================================================") + print("MODEL REPO COMMIT: ", commit) + print("======================================================") + + # Copy scripts to model directory + model_docker.sh( + f"cp -vLR --preserve=all {dir_path}/. 
{model_dir}/" + ) + + # Prepare data if needed + if ( + "data" in model_info + and model_info["data"] != "" + and self.data + ): + self.data.prepare_data(model_info["data"], model_docker) + + # Set permissions + model_docker.sh(f"chmod -R a+rw {model_dir}") + + # Run the model + test_start_time = time.time() + self.rich_console.print("[bold blue]Running model...[/bold blue]") + + model_args = self.context.ctx.get( + "model_args", model_info["args"] + ) + model_docker.sh( + f"cd {model_dir} && {script_name} {model_args}", + timeout=None, + ) + + run_results["test_duration"] = time.time() - test_start_time + print(f"Test Duration: {run_results['test_duration']} seconds") + + # Run post-scripts + if pre_encapsulate_post_scripts["post_scripts"]: + self.run_pre_post_script( + model_docker, + model_dir, + pre_encapsulate_post_scripts["post_scripts"], + ) + + # Extract performance metrics from logs + # Look for performance data in the log output similar to original run_models.py + try: + # Check if multiple results file is specified in model_info + multiple_results = model_info.get("multiple_results", None) + + if multiple_results: + run_results["performance"] = multiple_results + # Validate multiple results file format + try: + with open(multiple_results, "r") as f: + header = f.readline().strip().split(",") + for line in f: + row = line.strip().split(",") + for col in row: + if col == "": + run_results["performance"] = None + print( + "Error: Performance metric is empty in multiple results file." + ) + break + except Exception as e: + self.rich_console.print( + f"[yellow]Warning: Could not validate multiple results file: {e}[/yellow]" + ) + run_results["performance"] = None + else: + # Match the actual output format: "performance: 14164 samples_per_second" + # Simple pattern to capture number and metric unit + + # Extract from log file + try: + # Extract performance number: capture digits (with optional decimal/scientific notation) + perf_cmd = ( + "cat " + + log_file_path + + " | grep 'performance:' | sed -n 's/.*performance:[[:space:]]*\\([0-9][0-9.eE+-]*\\)[[:space:]].*/\\1/p'" + ) + run_results["performance"] = self.console.sh( + perf_cmd + ) + + # Extract metric unit: capture the word after the number + metric_cmd = ( + "cat " + + log_file_path + + " | grep 'performance:' | sed -n 's/.*performance:[[:space:]]*[0-9][0-9.eE+-]*[[:space:]]*\\([a-zA-Z_][a-zA-Z0-9_]*\\).*/\\1/p'" + ) + run_results["metric"] = self.console.sh(metric_cmd) + except Exception: + pass # Performance extraction is optional + except Exception as e: + print( + f"Warning: Could not extract performance metrics: {e}" + ) + + # Set status based on performance and error patterns + # First check for obvious failure patterns in the logs + try: + # Check for common failure patterns in the log file + error_patterns = [ + "OutOfMemoryError", + "HIP out of memory", + "CUDA out of memory", + "RuntimeError", + "AssertionError", + "ValueError", + "SystemExit", + "failed (exitcode:", + "Error:", + "FAILED", + "Exception:", + ] + + has_errors = False + if log_file_path and os.path.exists(log_file_path): + try: + # Check for error patterns in the log (exclude our own grep commands and output messages) + for pattern in error_patterns: + # Use grep with -v to exclude our own commands and output to avoid false positives + error_check_cmd = f"grep -v -E '(grep -q.*{pattern}|Found error pattern.*{pattern})' {log_file_path} | grep -q '{pattern}' && echo 'FOUND' || echo 'NOT_FOUND'" + result = self.console.sh( + error_check_cmd, canFail=True + 
) + if result.strip() == "FOUND": + has_errors = True + print( + f"Found error pattern '{pattern}' in logs" + ) + break + except Exception: + pass # Error checking is optional + + # Status logic: Must have performance AND no errors to be considered success + performance_value = run_results.get("performance") + has_performance = ( + performance_value + and performance_value.strip() + and performance_value.strip() != "N/A" + ) + + if has_errors: + run_results["status"] = "FAILURE" + self.rich_console.print( + f"[red]Status: FAILURE (error patterns detected in logs)[/red]" + ) + elif has_performance: + run_results["status"] = "SUCCESS" + self.rich_console.print( + f"[green]Status: SUCCESS (performance metrics found, no errors)[/green]" + ) + else: + run_results["status"] = "FAILURE" + self.rich_console.print(f"[red]Status: FAILURE (no performance metrics)[/red]") + + except Exception as e: + self.rich_console.print(f"[yellow]Warning: Error in status determination: {e}[/yellow]") + # Fallback to simple performance check + run_results["status"] = ( + "SUCCESS" + if run_results.get("performance") + else "FAILURE" + ) + + print( + f"{model_info['name']} performance is {run_results.get('performance', 'N/A')} {run_results.get('metric', '')}" + ) + + # Generate performance results and update perf.csv + self.ensure_perf_csv_exists() + try: + # Create run details dictionary for CSV generation + run_details_dict = self.create_run_details_dict( + model_info, build_info, run_results + ) + + # Handle multiple results if specified + multiple_results = model_info.get("multiple_results", None) + if ( + multiple_results + and run_results.get("status") == "SUCCESS" + ): + # Generate common info JSON for multiple results + common_info = run_details_dict.copy() + # Remove model-specific fields for common info + for key in ["model", "performance", "metric", "status"]: + common_info.pop(key, None) + + with open("common_info.json", "w") as f: + json.dump(common_info, f) + + # Update perf.csv with multiple results + update_perf_csv( + multiple_results=multiple_results, + perf_csv=self.perf_csv_path, + model_name=run_details_dict["model"], + common_info="common_info.json", + ) + print( + f"Updated perf.csv with multiple results for {model_info['name']}" + ) + else: + # Generate single result JSON + with open("perf_entry.json", "w") as f: + json.dump(run_details_dict, f) + + # Update perf.csv with single result + if run_results.get("status") == "SUCCESS": + update_perf_csv( + single_result="perf_entry.json", + perf_csv=self.perf_csv_path, + ) + else: + update_perf_csv( + exception_result="perf_entry.json", + perf_csv=self.perf_csv_path, + ) + print( + f"Updated perf.csv with result for {model_info['name']}" + ) + + except Exception as e: + self.rich_console.print(f"[yellow]Warning: Could not update perf.csv: {e}[/yellow]") + + # Cleanup if not keeping alive + if not keep_alive: + model_docker.sh(f"rm -rf {model_dir}", timeout=240) + else: + model_docker.sh(f"chmod -R a+rw {model_dir}") + print( + f"keep_alive specified; model_dir({model_dir}) is not removed" + ) + + # Explicitly delete model docker to stop the container + del model_docker + + except Exception as e: + self.rich_console.print("[bold red]===== EXCEPTION =====[/bold red]") + self.rich_console.print(f"[red]Exception: {e}[/red]") + import traceback + + traceback.print_exc() + self.rich_console.print("[bold red]=============== =====[/bold red]") + run_results["status"] = "FAILURE" + + # Also update perf.csv for failures + self.ensure_perf_csv_exists() + try: 
+ # Create run details dictionary for failed runs + run_details_dict = self.create_run_details_dict( + model_info, build_info, run_results + ) + + # Generate exception result JSON + with open("perf_entry.json", "w") as f: + json.dump(run_details_dict, f) + + # Update perf.csv with exception result + update_perf_csv( + exception_result="perf_entry.json", + perf_csv=self.perf_csv_path, + ) + print( + f"Updated perf.csv with exception result for {model_info['name']}" + ) + + except Exception as csv_e: + self.rich_console.print(f"[yellow]Warning: Could not update perf.csv with exception: {csv_e}[/yellow]") + + return run_results + + def set_credentials(self, credentials: typing.Dict) -> None: + """Set credentials for model execution. + + Args: + credentials: Credentials dictionary + """ + self.credentials = credentials diff --git a/src/madengine/tools/create_table_db.py b/src/madengine/tools/create_table_db.py index 68aec9e2..bb06c2c9 100644 --- a/src/madengine/tools/create_table_db.py +++ b/src/madengine/tools/create_table_db.py @@ -10,9 +10,11 @@ import argparse import subprocess import typing + # third-party modules import paramiko import socket + # mad-engine modules from madengine.utils.ssh_to_db import SFTPClient, print_ssh_out from madengine.db.logger import setup_logger @@ -26,9 +28,10 @@ class CreateTable: """Class to create tables in the database. - + This class provides the functions to create tables in the database. """ + def __init__(self, args: argparse.Namespace): """Initialize the CreateTable class. @@ -48,10 +51,10 @@ def __init__(self, args: argparse.Namespace): # get the db folder self.db_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../db") - LOGGER.info(f"DB path: {self.db_path}") - self.status = False + LOGGER.info(f"DB path: {self.db_path}") + self.status = False - def run(self, table_name: str='dlm_table') -> None: + def run(self, table_name: str = "dlm_table") -> None: """Create an empty table in the database. Args: @@ -65,7 +68,7 @@ def run(self, table_name: str='dlm_table') -> None: """ print(f"Creating table {table_name} in the database") - if 'localhost' in self.ssh_hostname or '127.0.0.1' in self.ssh_hostname: + if "localhost" in self.ssh_hostname or "127.0.0.1" in self.ssh_hostname: try: self.local_db() self.status = True @@ -81,10 +84,10 @@ def run(self, table_name: str='dlm_table') -> None: except Exception as error: LOGGER.error(f"Error creating table in remote database: {error}") return self.status - + def local_db(self) -> None: """Create a table in the local database. 
- + Returns: None @@ -97,15 +100,17 @@ def local_db(self) -> None: cmd_list = ["cp", "-r", self.db_path, "."] try: - ret = subprocess.Popen(cmd_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + ret = subprocess.Popen( + cmd_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) out, err = ret.communicate() if ret.returncode == 0: if out: - LOGGER.info(out.decode('utf-8')) + LOGGER.info(out.decode("utf-8")) print("Copied scripts to current work path") else: if err: - LOGGER.error(err.decode('utf-8')) + LOGGER.error(err.decode("utf-8")) except Exception as e: LOGGER.error(f"An error occurred: {e}") @@ -117,16 +122,20 @@ def local_db(self) -> None: print(f"ENV_VARS: {env_vars}") try: - ret = subprocess.Popen(cmd_list, env=env_vars, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + ret = subprocess.Popen( + cmd_list, env=env_vars, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) out, err = ret.communicate() if ret.returncode == 0: if out: - LOGGER.info(out.decode('utf-8')) + LOGGER.info(out.decode("utf-8")) else: if err: - LOGGER.error(err.decode('utf-8')) - raise Exception(f"Error updating table in the local database: {err.decode('utf-8')}") + LOGGER.error(err.decode("utf-8")) + raise Exception( + f"Error updating table in the local database: {err.decode('utf-8')}" + ) except Exception as e: LOGGER.error(f"An error occurred: {e}") @@ -134,10 +143,10 @@ def local_db(self) -> None: def remote_db(self) -> None: """Create a table in the remote database. - + Returns: None - + Raises: socket.error: An error occurred connecting to the database. """ @@ -166,7 +175,7 @@ def remote_db(self) -> None: except socket.error as error: print(f"Socket error: {error}") return - + print("SSH client created, connected to the host of database") # print remote dir layout @@ -178,8 +187,10 @@ def remote_db(self) -> None: print(upload_script_path_remote) # clean up previous uploads - print_ssh_out(ssh_client.exec_command("rm -rf {}".format(upload_script_path_remote))) - print_ssh_out(ssh_client.exec_command("ls -l")) + print_ssh_out( + ssh_client.exec_command("rm -rf {}".format(upload_script_path_remote)) + ) + print_ssh_out(ssh_client.exec_command("ls -l")) # upload file sftp_client = SFTPClient.from_transport(ssh_client.get_transport()) diff --git a/src/madengine/tools/csv_to_html.py b/src/madengine/tools/csv_to_html.py index 5a27952a..0af7a6ac 100644 --- a/src/madengine/tools/csv_to_html.py +++ b/src/madengine/tools/csv_to_html.py @@ -15,7 +15,7 @@ def convert_csv_to_html(file_path: str): """Convert the CSV file to an HTML file. - + Args: file_path: The path to the CSV file. 
""" @@ -30,7 +30,18 @@ def convert_csv_to_html(file_path: str): output_name += file_name + ".html" # read csv df = pd.read_csv(file_path) - print(df) + + # Use beautiful formatting for dataframe display + try: + from madengine.utils.log_formatting import print_dataframe_beautiful + + print_dataframe_beautiful(df, f"Converting CSV: {file_name}") + except ImportError: + # Fallback to basic formatting if utils not available + print(f"\n📊 Converting CSV: {file_name}") + print("=" * 80) + print(df.to_string(max_rows=20, max_cols=10)) + print("=" * 80) # Use the .to_html() to get your table in html df_html = df.to_html(index=False) @@ -67,7 +78,18 @@ def run(self): # read csv df = pd.read_csv(file_path) - print(df) + + # Use beautiful formatting for dataframe display + try: + from madengine.utils.log_formatting import print_dataframe_beautiful + + print_dataframe_beautiful(df, f"CSV Data from {file_name}") + except ImportError: + # Fallback to basic formatting if utils not available + print(f"\n📊 CSV Data from {file_name}") + print("=" * 80) + print(df.to_string(max_rows=20, max_cols=10)) + print("=" * 80) # Use the .to_html() to get your table in html df_html = df.to_html(index=False) diff --git a/src/madengine/tools/discover_models.py b/src/madengine/tools/discover_models.py index d6776740..9d47dbb1 100644 --- a/src/madengine/tools/discover_models.py +++ b/src/madengine/tools/discover_models.py @@ -2,6 +2,7 @@ Copyright (c) Advanced Micro Devices, Inc. All rights reserved. """ + # built-in modules import argparse import os @@ -9,6 +10,8 @@ import importlib.util import typing from dataclasses import dataclass, field, asdict +from rich.console import Console as RichConsole + @dataclass class CustomModel: @@ -46,11 +49,12 @@ class DiscoverModels: def __init__(self, args: argparse.Namespace): """Initialize the DiscoverModels class. - + Args: args (argparse.Namespace): Arguments passed to the script. """ self.args = args + self.rich_console = RichConsole() # list of models from models.json and scripts/model_dir/models.json self.models: typing.List[dict] = [] # list of custom models from scripts/model_dir/get_models_json.py @@ -60,9 +64,55 @@ def __init__(self, args: argparse.Namespace): # list of selected models parsed using --tags argument self.selected_models: typing.List[dict] = [] + # Setup MODEL_DIR if environment variable is set + self._setup_model_dir_if_needed() + + def _setup_model_dir_if_needed(self) -> None: + """Setup model directory if MODEL_DIR environment variable is set. + + This copies the contents of MODEL_DIR to the current working directory + to support the model discovery process. This operation is safe for + build-only (CPU) nodes as it only involves file operations. 
+ """ + model_dir_env = os.environ.get("MODEL_DIR") + if model_dir_env: + import subprocess + + cwd_path = os.getcwd() + self.rich_console.print(f"[bold cyan]📁 MODEL_DIR environment variable detected:[/bold cyan] [yellow]{model_dir_env}[/yellow]") + print(f"Copying contents to current working directory: {cwd_path}") + + try: + # Check if source directory exists + if not os.path.exists(model_dir_env): + self.rich_console.print(f"[yellow]⚠️ Warning: MODEL_DIR path does not exist: {model_dir_env}[/yellow]") + return + + # Use cp command similar to the original implementation + # cp -vLR --preserve=all source/* destination/ + cmd = f"cp -vLR --preserve=all {model_dir_env}/* {cwd_path}" + result = subprocess.run( + cmd, shell=True, capture_output=True, text=True, check=True + ) + self.rich_console.print(f"[green]✅ Successfully copied MODEL_DIR contents[/green]") + # Only show verbose output if there are not too many files + if result.stdout and len(result.stdout.splitlines()) < 20: + print(result.stdout) + elif result.stdout: + print(f"Copied {len(result.stdout.splitlines())} files/directories") + print(f"Model dir: {model_dir_env} → current dir: {cwd_path}") + except subprocess.CalledProcessError as e: + self.rich_console.print(f"[yellow]⚠️ Warning: Failed to copy MODEL_DIR contents: {e}[/yellow]") + if e.stderr: + print(f"Error details: {e.stderr}") + # Continue execution even if copy fails + except Exception as e: + self.rich_console.print(f"[yellow]⚠️ Warning: Unexpected error copying MODEL_DIR: {e}[/yellow]") + # Continue execution even if copy fails + def discover_models(self) -> None: """Discover models in models.json and models.json in model_dir under scripts directory. - + Raises: FileNotFoundError: models.json file not found. """ @@ -77,33 +127,45 @@ def discover_models(self) -> None: self.models = model_dict_list self.model_list = [model_dict["name"] for model_dict in model_dict_list] else: + self.rich_console.print("[red]❌ models.json file not found.[/red]") raise FileNotFoundError("models.json file not found.") - + # walk through the subdirs in model_dir/scripts directory to find the models.json file for dirname in os.listdir(os.path.join(model_dir, "scripts")): root = os.path.join(model_dir, "scripts", dirname) if os.path.isdir(root): files = os.listdir(root) - if 'models.json' in files and 'get_models_json.py' in files: - raise ValueError(f"Both models.json and get_models_json.py found in {root}.") + if "models.json" in files and "get_models_json.py" in files: + self.rich_console.print(f"[red]❌ Both models.json and get_models_json.py found in {root}.[/red]") + raise ValueError( + f"Both models.json and get_models_json.py found in {root}." 
+ ) - if 'models.json' in files: + if "models.json" in files: with open(f"{root}/models.json") as f: model_dict_list: typing.List[dict] = json.load(f) for model_dict in model_dict_list: # Update model name using backslash-separated path - model_dict["name"] = dirname + '/' + model_dict["name"] + model_dict["name"] = dirname + "/" + model_dict["name"] # Update relative path for dockerfile and scripts - model_dict["dockerfile"] = os.path.normpath(os.path.join("scripts", dirname, model_dict["dockerfile"])) - model_dict["scripts"] = os.path.normpath(os.path.join("scripts", dirname, model_dict["scripts"])) + model_dict["dockerfile"] = os.path.normpath( + os.path.join( + "scripts", dirname, model_dict["dockerfile"] + ) + ) + model_dict["scripts"] = os.path.normpath( + os.path.join("scripts", dirname, model_dict["scripts"]) + ) self.models.append(model_dict) self.model_list.append(model_dict["name"]) - if 'get_models_json.py' in files: + if "get_models_json.py" in files: try: # load the module get_models_json.py - spec = importlib.util.spec_from_file_location("get_models_json", f"{root}/get_models_json.py") + spec = importlib.util.spec_from_file_location( + "get_models_json", f"{root}/get_models_json.py" + ) get_models_json = importlib.util.module_from_spec(spec) spec.loader.exec_module(get_models_json) assert hasattr( @@ -116,12 +178,14 @@ def discover_models(self) -> None: custom_model, CustomModel ), "Please use or subclass madengine.tools.discover_models.CustomModel to define your custom model." # Update model name using backslash-separated path - custom_model.name = dirname + '/' + custom_model.name + custom_model.name = dirname + "/" + custom_model.name # Defer updating script and dockerfile paths until update_model is called self.custom_models.append(custom_model) self.model_list.append(custom_model.name) except AssertionError: - print("See madengine/tests/fixtures/dummy/scripts/dummy3/get_models_json.py for an example.") + self.rich_console.print( + "[yellow]💡 See madengine/tests/fixtures/dummy/scripts/dummy3/get_models_json.py for an example.[/yellow]" + ) raise def select_models(self) -> None: @@ -136,11 +200,11 @@ def select_models(self) -> None: # models corresponding to the given tag tag_models = [] # split the tags by ':', strip the tags and remove empty tags. - tag_list = [tag_.strip() for tag_ in tag.split(':') if tag_.strip()] + tag_list = [tag_.strip() for tag_ in tag.split(":") if tag_.strip()] model_name = tag_list[0] - # if the length of tag_list is greater than 1, then the rest + # if the length of tag_list is greater than 1, then the rest # of the tags are extra args to be passed into the model script. 
if len(tag_list) > 1: extra_args = [tag_ for tag_ in tag_list[1:]] @@ -149,38 +213,54 @@ def select_models(self) -> None: extra_args = " --" + extra_args else: extra_args = "" - + for model in self.models: - if model["name"] == model_name or tag in model["tags"] or tag == "all": + if ( + model["name"] == model_name + or tag in model["tags"] + or tag == "all" + ): model_dict = model.copy() model_dict["args"] = model_dict["args"] + extra_args tag_models.append(model_dict) for custom_model in self.custom_models: - if custom_model.name == model_name or tag in custom_model.tags or tag == "all": + if ( + custom_model.name == model_name + or tag in custom_model.tags + or tag == "all" + ): custom_model.update_model() # Update relative path for dockerfile and scripts dirname = custom_model.name.split("/")[0] - custom_model.dockerfile = os.path.normpath(os.path.join("scripts", dirname, custom_model.dockerfile)) - custom_model.scripts = os.path.normpath(os.path.join("scripts", dirname, custom_model.scripts)) + custom_model.dockerfile = os.path.normpath( + os.path.join("scripts", dirname, custom_model.dockerfile) + ) + custom_model.scripts = os.path.normpath( + os.path.join("scripts", dirname, custom_model.scripts) + ) model_dict = custom_model.to_dict() model_dict["args"] = model_dict["args"] + extra_args tag_models.append(model_dict) if not tag_models: - raise ValueError(f"No models found corresponding to the given tag: {tag}") - + self.rich_console.print(f"[red]❌ No models found corresponding to the given tag: {tag}[/red]") + raise ValueError( + f"No models found corresponding to the given tag: {tag}" + ) + self.selected_models.extend(tag_models) def print_models(self) -> None: if self.selected_models: # print selected models using parsed tags and adding backslash-separated extra args + self.rich_console.print(f"[bold green]📋 Selected Models ({len(self.selected_models)} models):[/bold green]") print(json.dumps(self.selected_models, indent=4)) else: # print list of all model names - print(f"Number of models in total: {len(self.model_list)}") + self.rich_console.print(f"[bold cyan]📊 Available Models ({len(self.model_list)} total):[/bold cyan]") for model_name in self.model_list: - print(f"{model_name}") + print(f" {model_name}") def run(self, live_output: bool = True): @@ -188,7 +268,5 @@ def run(self, live_output: bool = True): self.select_models() if live_output: self.print_models() - - return self.selected_models - + return self.selected_models diff --git a/src/madengine/tools/distributed_orchestrator.py b/src/madengine/tools/distributed_orchestrator.py new file mode 100644 index 00000000..df0d8d61 --- /dev/null +++ b/src/madengine/tools/distributed_orchestrator.py @@ -0,0 +1,952 @@ +#!/usr/bin/env python3 +""" +Distributed Runner Orchestrator for MADEngine + +This module provides orchestration capabilities for distributed execution +scenarios like Ansible or Kubernetes, where Docker image building and +container execution are separated across different nodes. + +Copyright (c) Advanced Micro Devices, Inc. All rights reserved. 
+""" + +import os +import json +import typing +from rich.console import Console as RichConsole +from madengine.core.console import Console +from madengine.core.context import Context +from madengine.core.dataprovider import Data +from madengine.core.errors import ( + handle_error, create_error_context, ConfigurationError, + BuildError, DiscoveryError, RuntimeError as MADRuntimeError +) +from madengine.tools.discover_models import DiscoverModels +from madengine.tools.docker_builder import DockerBuilder +from madengine.tools.container_runner import ContainerRunner + + +class DistributedOrchestrator: + """Orchestrator for distributed MADEngine workflows.""" + + def __init__(self, args, build_only_mode: bool = False): + """Initialize the distributed orchestrator. + + Args: + args: Command-line arguments + build_only_mode: Whether running in build-only mode (no GPU detection) + """ + self.args = args + self.console = Console(live_output=getattr(args, "live_output", True)) + self.rich_console = RichConsole() + + # Initialize context with appropriate mode + self.context = Context( + additional_context=getattr(args, "additional_context", None), + additional_context_file=getattr(args, "additional_context_file", None), + build_only_mode=build_only_mode, + ) + + # Initialize data provider if data config exists + data_json_file = getattr(args, "data_config_file_name", "data.json") + if os.path.exists(data_json_file): + self.data = Data( + self.context, + filename=data_json_file, + force_mirrorlocal=getattr(args, "force_mirror_local", False), + ) + else: + self.data = None + + # Load credentials + self.credentials = None + try: + credential_file = "credential.json" + if os.path.exists(credential_file): + with open(credential_file) as f: + self.credentials = json.load(f) + print(f"Credentials: {list(self.credentials.keys())}") + except Exception as e: + context = create_error_context( + operation="load_credentials", + component="DistributedOrchestrator", + file_path=credential_file + ) + handle_error( + ConfigurationError( + f"Could not load credentials: {e}", + context=context, + suggestions=["Check if credential.json exists and has valid JSON format"] + ) + ) + + # Check for Docker Hub environment variables and override credentials + docker_hub_user = None + docker_hub_password = None + docker_hub_repo = None + + if "MAD_DOCKERHUB_USER" in os.environ: + docker_hub_user = os.environ["MAD_DOCKERHUB_USER"] + if "MAD_DOCKERHUB_PASSWORD" in os.environ: + docker_hub_password = os.environ["MAD_DOCKERHUB_PASSWORD"] + if "MAD_DOCKERHUB_REPO" in os.environ: + docker_hub_repo = os.environ["MAD_DOCKERHUB_REPO"] + + if docker_hub_user and docker_hub_password: + print("Found Docker Hub credentials in environment variables") + if self.credentials is None: + self.credentials = {} + + # Override or add Docker Hub credentials + self.credentials["dockerhub"] = { + "repository": docker_hub_repo, + "username": docker_hub_user, + "password": docker_hub_password, + } + print("Docker Hub credentials updated from environment variables") + print(f"Docker Hub credentials: {self.credentials['dockerhub']}") + + def build_phase( + self, + registry: str = None, + clean_cache: bool = False, + manifest_output: str = "build_manifest.json", + batch_build_metadata: typing.Optional[dict] = None, + ) -> typing.Dict: + """Execute the build phase - build all Docker images. + + This method supports both build-only mode (for dedicated build nodes) + and full workflow mode. 
In build-only mode, GPU detection is skipped + and docker build args should be provided via --additional-context. + + Args: + registry: Optional registry to push images to + clean_cache: Whether to use --no-cache for builds + manifest_output: Output file for build manifest + batch_build_metadata: Optional batch build metadata for batch builds + + Returns: + dict: Build summary + """ + self.rich_console.print(f"\n[dim]{'=' * 60}[/dim]") + self.rich_console.print("[bold blue]🔨 STARTING BUILD PHASE[/bold blue]") + if self.context._build_only_mode: + self.rich_console.print("[yellow](Build-only mode - no GPU detection)[/yellow]") + self.rich_console.print(f"\n[dim]{'=' * 60}[/dim]") + + # Print the arguments as a dictionary for better readability + print( + f"Building models with args: {vars(self.args) if hasattr(self.args, '__dict__') else self.args}" + ) + + # Discover models + self.rich_console.print(f"\n[dim]{'=' * 60}[/dim]") + self.rich_console.print("[bold cyan]🔍 DISCOVERING MODELS[/bold cyan]") + discover_models = DiscoverModels(args=self.args) + models = discover_models.run() + + print(f"Discovered {len(models)} models to build") + + # Copy scripts for building + self.rich_console.print(f"\n[dim]{'=' * 60}[/dim]") + self.rich_console.print("[bold cyan]📋 COPYING SCRIPTS[/bold cyan]") + self._copy_scripts() + + # Validate build context for build-only mode + if self.context._build_only_mode: + if ( + "MAD_SYSTEM_GPU_ARCHITECTURE" + not in self.context.ctx["docker_build_arg"] + ): + self.rich_console.print( + "[yellow]⚠️ Warning: MAD_SYSTEM_GPU_ARCHITECTURE not provided in build context.[/yellow]" + ) + print( + "For build-only nodes, please provide GPU architecture via --additional-context:" + ) + print( + ' --additional-context \'{"docker_build_arg": {"MAD_SYSTEM_GPU_ARCHITECTURE": "gfx908"}}\'' + ) + + # Initialize builder + builder = DockerBuilder( + self.context, + self.console, + live_output=getattr(self.args, "live_output", False), + ) + + # Determine phase suffix for log files + phase_suffix = ( + ".build" + if hasattr(self.args, "_separate_phases") and self.args._separate_phases + else "" + ) + + # Get target architectures from args if provided + target_archs = getattr(self.args, "target_archs", []) + + # Handle comma-separated architectures in a single string + if target_archs: + processed_archs = [] + for arch_arg in target_archs: + # Split comma-separated values and add to list + processed_archs.extend([arch.strip() for arch in arch_arg.split(',') if arch.strip()]) + target_archs = processed_archs + + # If batch_build_metadata is provided, use it to set per-model registry/registry_image + build_summary = builder.build_all_models( + models, + self.credentials, + clean_cache, + registry, + phase_suffix, + batch_build_metadata=batch_build_metadata, + target_archs=target_archs, + ) + + # Export build manifest with registry information + builder.export_build_manifest(manifest_output, registry, batch_build_metadata) + + self.rich_console.print(f"\n[dim]{'=' * 60}[/dim]") + self.rich_console.print("[bold green]✅ BUILD PHASE COMPLETED[/bold green]") + self.rich_console.print(f" [green]Successful builds: {len(build_summary['successful_builds'])}[/green]") + self.rich_console.print(f" [red]Failed builds: {len(build_summary['failed_builds'])}[/red]") + self.rich_console.print(f" [blue]Total build time: {build_summary['total_build_time']:.2f} seconds[/blue]") + print(f" Manifest saved to: {manifest_output}") + self.rich_console.print(f"\n[dim]{'=' * 60}[/dim]") + + # Cleanup scripts + 
self.cleanup() + + return build_summary + + def generate_local_image_manifest( + self, + container_image: str, + manifest_output: str = "build_manifest.json", + ) -> typing.Dict: + """Generate a build manifest for a local container image. + + This method creates a build manifest that references a local container image, + skipping the build phase entirely. This is useful for legacy compatibility + when using MAD_CONTAINER_IMAGE. + + Args: + container_image: The local container image tag (e.g., 'model:tag') + manifest_output: Output file for build manifest + + Returns: + dict: Build summary compatible with regular build phase + """ + self.rich_console.print(f"\n[dim]{'=' * 60}[/dim]") + self.rich_console.print("[bold blue]🏠 GENERATING LOCAL IMAGE MANIFEST[/bold blue]") + self.rich_console.print(f"Container Image: [yellow]{container_image}[/yellow]") + self.rich_console.print(f"\n[dim]{'=' * 60}[/dim]") + + # Ensure runtime context is initialized for local image mode + self.context.ensure_runtime_context() + + # Discover models to get the model information + self.rich_console.print(f"\n[dim]{'=' * 60}[/dim]") + self.rich_console.print("[bold cyan]🔍 DISCOVERING MODELS[/bold cyan]") + discover_models = DiscoverModels(args=self.args) + models = discover_models.run() + + print(f"Discovered {len(models)} models for local image") + + # Copy scripts for running (even though we're skipping build) + self.rich_console.print(f"\n[dim]{'=' * 60}[/dim]") + self.rich_console.print("[bold cyan]📋 COPYING SCRIPTS[/bold cyan]") + self._copy_scripts() + + # Create manifest entries for all discovered models using the local image + built_images = {} + built_models = {} + successful_builds = [] + + for model in models: + model_name = model["name"] + # Generate a pseudo-image name for compatibility + image_name = f"ci-{model_name.replace('/', '_').lower()}_local" + + # Create build info entry for the local image + built_images[image_name] = { + "model_name": model_name, + "docker_image": container_image, # Use the provided local image + "dockerfile": model.get("dockerfile", ""), + "build_time": 0.0, # No build time for local image + "registry": None, # Local image, no registry + "local_image_mode": True, # Flag to indicate this is a local image + } + + # Create model info entry - use image_name as key for proper mapping + built_models[image_name] = { + "docker_image": container_image, + "image_name": image_name, + **model # Include all original model information + } + + successful_builds.append(model_name) + + # Extract credentials from models + credentials_required = list( + set( + [ + model.get("cred", "") + for model in models + if model.get("cred", "") != "" + ] + ) + ) + + # Create the manifest structure compatible with regular build phase + manifest = { + "built_images": built_images, + "built_models": built_models, + "context": { + "docker_env_vars": self.context.ctx.get("docker_env_vars", {}), + "docker_mounts": self.context.ctx.get("docker_mounts", {}), + "docker_build_arg": self.context.ctx.get("docker_build_arg", {}), + "gpu_vendor": self.context.ctx.get("gpu_vendor", ""), + "docker_gpus": self.context.ctx.get("docker_gpus", ""), + "MAD_CONTAINER_IMAGE": container_image, # Include the local image reference + }, + "credentials_required": credentials_required, + "local_image_mode": True, + "local_container_image": container_image, + } + + # Add multi-node args to context if present + if "build_multi_node_args" in self.context.ctx: + manifest["context"]["multi_node_args"] = self.context.ctx[ + 
"build_multi_node_args" + ] + + # Write the manifest file + with open(manifest_output, "w") as f: + json.dump(manifest, f, indent=2) + + # Create build summary compatible with regular build phase + build_summary = { + "successful_builds": successful_builds, + "failed_builds": [], + "total_build_time": 0.0, + "manifest_file": manifest_output, + "local_image_mode": True, + "container_image": container_image, + } + + self.rich_console.print(f"\n[dim]{'=' * 60}[/dim]") + self.rich_console.print("[bold green]✅ LOCAL IMAGE MANIFEST GENERATED[/bold green]") + self.rich_console.print(f" [green]Models configured: {len(successful_builds)}[/green]") + self.rich_console.print(f" [blue]Container Image: {container_image}[/blue]") + self.rich_console.print(f" [blue]Manifest saved to: {manifest_output}[/blue]") + self.rich_console.print(f"\n[dim]{'=' * 60}[/dim]") + + # Cleanup scripts (optional for local image mode) + self.cleanup() + + return build_summary + + def run_phase( + self, + manifest_file: str = "build_manifest.json", + registry: str = None, + timeout: int = 7200, + keep_alive: bool = False, + ) -> typing.Dict: + """Execute the run phase - run containers with models. + + This method requires GPU context and will initialize runtime context + if not already done. Should only be called on GPU nodes. + + Args: + manifest_file: Build manifest file from build phase + registry: Registry to pull images from (if different from build) + timeout: Execution timeout per model + keep_alive: Whether to keep containers alive after execution + + Returns: + dict: Execution summary + """ + self.rich_console.print(f"\n[dim]{'=' * 60}[/dim]") + self.rich_console.print("[bold blue]🏃 STARTING RUN PHASE[/bold blue]") + self.rich_console.print(f"\n[dim]{'=' * 60}[/dim]") + + # Ensure runtime context is initialized (GPU detection, env vars, etc.) + self.context.ensure_runtime_context() + + print(f"Running models with args {self.args}") + + self.console.sh("echo 'MAD Run Models'") + + # show node rocm info + host_os = self.context.ctx.get("host_os", "") + if host_os.find("HOST_UBUNTU") != -1: + print(self.console.sh("apt show rocm-libs -a", canFail=True)) + elif host_os.find("HOST_CENTOS") != -1: + print(self.console.sh("yum info rocm-libs", canFail=True)) + elif host_os.find("HOST_SLES") != -1: + print(self.console.sh("zypper info rocm-libs", canFail=True)) + elif host_os.find("HOST_AZURE") != -1: + print(self.console.sh("tdnf info rocm-libs", canFail=True)) + else: + self.rich_console.print("[red]❌ ERROR: Unable to detect host OS.[/red]") + + # Load build manifest + if not os.path.exists(manifest_file): + raise FileNotFoundError(f"Build manifest not found: {manifest_file}") + + with open(manifest_file, "r") as f: + manifest = json.load(f) + + print(f"Loaded manifest with {len(manifest['built_images'])} images") + + # Restore context from manifest if present (for tools, pre/post scripts, etc.) 
+ if "context" in manifest: + manifest_context = manifest["context"] + + # Restore tools configuration if present in manifest + if "tools" in manifest_context: + self.context.ctx["tools"] = manifest_context["tools"] + print(f"Restored tools configuration from manifest: {manifest_context['tools']}") + + # Restore pre/post scripts if present in manifest + if "pre_scripts" in manifest_context: + self.context.ctx["pre_scripts"] = manifest_context["pre_scripts"] + print(f"Restored pre_scripts from manifest") + if "post_scripts" in manifest_context: + self.context.ctx["post_scripts"] = manifest_context["post_scripts"] + print(f"Restored post_scripts from manifest") + if "encapsulate_script" in manifest_context: + self.context.ctx["encapsulate_script"] = manifest_context["encapsulate_script"] + print(f"Restored encapsulate_script from manifest") + + # Filter images by GPU architecture compatibility + try: + runtime_gpu_arch = self.context.get_system_gpu_architecture() + print(f"Runtime GPU architecture detected: {runtime_gpu_arch}") + + # Filter manifest images by GPU architecture compatibility + compatible_images = self._filter_images_by_gpu_architecture( + manifest["built_images"], runtime_gpu_arch + ) + + if not compatible_images: + available_archs = list(set( + img.get('gpu_architecture', 'unknown') + for img in manifest['built_images'].values() + )) + available_archs = [arch for arch in available_archs if arch != 'unknown'] + + if available_archs: + error_msg = ( + f"No compatible Docker images found for runtime GPU architecture '{runtime_gpu_arch}'. " + f"Available image architectures: {available_archs}. " + f"Please build images for the target architecture using: " + f"--target-archs {runtime_gpu_arch}" + ) + else: + error_msg = ( + f"No compatible Docker images found for runtime GPU architecture '{runtime_gpu_arch}'. " + f"The manifest contains legacy images without architecture information. " + f"These will be treated as compatible for backward compatibility." 
+ ) + + raise RuntimeError(error_msg) + + # Update manifest to only include compatible images + manifest["built_images"] = compatible_images + print(f"Filtered to {len(compatible_images)} compatible images for GPU architecture '{runtime_gpu_arch}'") + + except Exception as e: + # If GPU architecture detection fails, proceed with all images for backward compatibility + self.rich_console.print( + f"[yellow]Warning: GPU architecture filtering failed: {e}[/yellow]" + ) + self.rich_console.print( + "[yellow]Proceeding with all available images (backward compatibility mode)[/yellow]" + ) + + # Registry is now per-image; CLI registry is fallback + if registry: + print(f"Using registry from CLI: {registry}") + else: + self.rich_console.print( + "[yellow]No registry specified, will use per-image registry or local images only[/yellow]" + ) + + # Copy scripts for running + self._copy_scripts() + + # Initialize runner + runner = ContainerRunner( + self.context, + self.data, + self.console, + live_output=getattr(self.args, "live_output", False), + ) + runner.set_credentials(self.credentials) + + # Set perf.csv output path if specified in args + if hasattr(self.args, "output") and self.args.output: + runner.set_perf_csv_path(self.args.output) + + # Determine phase suffix for log files + phase_suffix = ( + ".run" + if hasattr(self.args, "_separate_phases") and self.args._separate_phases + else "" + ) + + # Use built models from manifest if available, otherwise discover models + if "built_models" in manifest and manifest["built_models"]: + self.rich_console.print("[cyan]Using model information from build manifest[/cyan]") + models = list(manifest["built_models"].values()) + else: + self.rich_console.print( + "[yellow]No model information in manifest, discovering models from current configuration[/yellow]" + ) + # Discover models (to get execution parameters) + discover_models = DiscoverModels(args=self.args) + models = discover_models.run() + + # Create execution summary + execution_summary = { + "successful_runs": [], + "failed_runs": [], + "total_execution_time": 0, + } + + # Map models to their built images + if "built_models" in manifest and manifest["built_models"]: + # Direct mapping from manifest - built_models maps image_name -> model_info + print("Using direct model-to-image mapping from manifest") + for image_name, build_info in manifest["built_images"].items(): + if image_name in manifest["built_models"]: + model_info = manifest["built_models"][image_name] + try: + print( + f"\nRunning model {model_info['name']} with image {image_name}" + ) + + # Check if MAD_CONTAINER_IMAGE is set in context (for local image mode) + if "MAD_CONTAINER_IMAGE" in self.context.ctx: + actual_image = self.context.ctx["MAD_CONTAINER_IMAGE"] + print(f"Using MAD_CONTAINER_IMAGE override: {actual_image}") + print("Warning: User override MAD_CONTAINER_IMAGE. 
Model support on image not guaranteed.") + else: + # Use per-image registry if present, else CLI registry + effective_registry = build_info.get("registry", registry) + registry_image = build_info.get("registry_image") + docker_image = build_info.get("docker_image") + if registry_image: + if effective_registry: + print(f"Pulling image from registry: {registry_image}") + try: + registry_image_str = ( + str(registry_image) if registry_image else "" + ) + docker_image_str = ( + str(docker_image) if docker_image else "" + ) + effective_registry_str = ( + str(effective_registry) + if effective_registry + else "" + ) + runner.pull_image( + registry_image_str, + docker_image_str, + effective_registry_str, + self.credentials, + ) + actual_image = docker_image_str + print( + f"Successfully pulled and tagged as: {docker_image_str}" + ) + except Exception as e: + print( + f"Failed to pull from registry, falling back to local image: {e}" + ) + actual_image = docker_image + else: + print( + f"Attempting to pull registry image as-is: {registry_image}" + ) + try: + registry_image_str = ( + str(registry_image) if registry_image else "" + ) + docker_image_str = ( + str(docker_image) if docker_image else "" + ) + runner.pull_image( + registry_image_str, docker_image_str + ) + actual_image = docker_image_str + print( + f"Successfully pulled and tagged as: {docker_image_str}" + ) + except Exception as e: + print( + f"Failed to pull from registry, falling back to local image: {e}" + ) + actual_image = docker_image + else: + # No registry_image key - run container directly using docker_image + actual_image = build_info["docker_image"] + print( + f"No registry image specified, using local image: {actual_image}" + ) + + # Run the container + run_results = runner.run_container( + model_info, + actual_image, + build_info, + keep_alive=keep_alive, + timeout=timeout, + phase_suffix=phase_suffix, + generate_sys_env_details=getattr( + self.args, "generate_sys_env_details", True + ), + ) + + # Add to appropriate list based on actual status + if run_results.get("status") == "SUCCESS": + execution_summary["successful_runs"].append(run_results) + self.rich_console.print( + f"[green]✅ Successfully completed: {model_info['name']} -> {run_results['status']}[/green]" + ) + else: + execution_summary["failed_runs"].append(run_results) + self.rich_console.print( + f"[red]❌ Failed to complete: {model_info['name']} -> {run_results['status']}[/red]" + ) + + execution_summary["total_execution_time"] += run_results.get( + "test_duration", 0 + ) + + except Exception as e: + self.rich_console.print( + f"[red]❌ Failed to run {model_info['name']} with image {image_name}: {e}[/red]" + ) + execution_summary["failed_runs"].append( + { + "model": model_info["name"], + "image": image_name, + "error": str(e), + } + ) + else: + self.rich_console.print(f"[yellow]⚠️ Warning: No model info found for built image: {image_name}[/yellow]") + else: + # Fallback to name-based matching for backward compatibility + self.rich_console.print("[yellow]Using name-based matching (fallback mode)[/yellow]") + for model_info in models: + model_name = model_info["name"] + + # Find matching built images for this model + matching_images = [] + for image_name, build_info in manifest["built_images"].items(): + if model_name.replace("/", "_").lower() in image_name: + matching_images.append((image_name, build_info)) + + if not matching_images: + self.rich_console.print(f"[red]❌ No built images found for model: {model_name}[/red]") + execution_summary["failed_runs"].append( 
+ {"model": model_name, "error": "No built images found"} + ) + continue + + # Run each matching image + for image_name, build_info in matching_images: + try: + print(f"\nRunning model {model_name} with image {image_name}") + + # Handle registry image pulling and tagging according to manifest + if "registry_image" in build_info: + # Registry image exists - pull it and tag as docker_image, then run with docker_image + registry_image = build_info["registry_image"] + docker_image = build_info["docker_image"] + + # Extract registry from the registry_image format + effective_registry = registry + if not effective_registry and registry_image: + registry_parts = registry_image.split("/") + if len(registry_parts) > 1 and "." in registry_parts[0]: + effective_registry = registry_parts[0] + elif ( + registry_image.startswith("docker.io/") + or "/" in registry_image + ): + effective_registry = "docker.io" + + if effective_registry: + print(f"Pulling image from registry: {registry_image}") + try: + # Ensure all parameters are strings and credentials is properly formatted + registry_image_str = ( + str(registry_image) if registry_image else "" + ) + docker_image_str = ( + str(docker_image) if docker_image else "" + ) + effective_registry_str = ( + str(effective_registry) + if effective_registry + else "" + ) + + # Pull registry image and tag it as docker_image + runner.pull_image( + registry_image_str, + docker_image_str, + effective_registry_str, + self.credentials, + ) + actual_image = docker_image_str + print( + f"Successfully pulled and tagged as: {docker_image_str}" + ) + except Exception as e: + print( + f"Failed to pull from registry, falling back to local image: {e}" + ) + actual_image = docker_image + else: + # Registry image exists but no valid registry found, try to pull as-is and tag + print( + f"Attempting to pull registry image as-is: {registry_image}" + ) + try: + registry_image_str = ( + str(registry_image) if registry_image else "" + ) + docker_image_str = ( + str(docker_image) if docker_image else "" + ) + runner.pull_image( + registry_image_str, docker_image_str + ) + actual_image = docker_image_str + print( + f"Successfully pulled and tagged as: {docker_image_str}" + ) + except Exception as e: + print( + f"Failed to pull from registry, falling back to local image: {e}" + ) + actual_image = docker_image + else: + # No registry_image key - run container directly using docker_image + actual_image = build_info["docker_image"] + print( + f"No registry image specified, using local image: {actual_image}" + ) + + # Run the container + run_results = runner.run_container( + model_info, + actual_image, + build_info, + keep_alive=keep_alive, + timeout=timeout, + phase_suffix=phase_suffix, + generate_sys_env_details=getattr( + self.args, "generate_sys_env_details", True + ), + ) + + # Add to appropriate list based on actual status + if run_results.get("status") == "SUCCESS": + execution_summary["successful_runs"].append(run_results) + self.rich_console.print( + f"[green]✅ Successfully completed: {model_name} -> {run_results['status']}[/green]" + ) + else: + execution_summary["failed_runs"].append(run_results) + self.rich_console.print( + f"[red]❌ Failed to complete: {model_name} -> {run_results['status']}[/red]" + ) + + execution_summary["total_execution_time"] += run_results.get( + "test_duration", 0 + ) + + except Exception as e: + self.rich_console.print( + f"[red]❌ Failed to run {model_name} with image {image_name}: {e}[/red]" + ) + execution_summary["failed_runs"].append( + {"model": 
model_name, "image": image_name, "error": str(e)} + ) + + self.rich_console.print(f"\n[dim]{'=' * 60}[/dim]") + self.rich_console.print("[bold green]✅ RUN PHASE COMPLETED[/bold green]") + self.rich_console.print(f" [green]Successful runs: {len(execution_summary['successful_runs'])}[/green]") + self.rich_console.print(f" [red]Failed runs: {len(execution_summary['failed_runs'])}[/red]") + self.rich_console.print( + f" [blue]Total execution time: {execution_summary['total_execution_time']:.2f} seconds[/blue]" + ) + self.rich_console.print(f"\n[dim]{'=' * 60}[/dim]") + + # Convert output CSV to HTML like run_models.py does + try: + from madengine.tools.csv_to_html import convert_csv_to_html + + perf_csv_path = getattr(self.args, "output", "perf.csv") + if os.path.exists(perf_csv_path): + print("Converting output csv to html...") + convert_csv_to_html(file_path=perf_csv_path) + except Exception as e: + self.rich_console.print(f"[yellow]⚠️ Warning: Could not convert CSV to HTML: {e}[/yellow]") + + # Cleanup scripts + self.cleanup() + + return execution_summary + + def full_workflow( + self, + registry: str = None, + clean_cache: bool = False, + timeout: int = 7200, + keep_alive: bool = False, + ) -> typing.Dict: + """Execute the complete workflow: build then run. + + Args: + registry: Optional registry for image distribution + clean_cache: Whether to use --no-cache for builds + timeout: Execution timeout per model + keep_alive: Whether to keep containers alive after execution + + Returns: + dict: Complete workflow summary + """ + self.rich_console.print(f"\n[dim]{'=' * 80}[/dim]") + self.rich_console.print("[bold magenta]🚀 STARTING COMPLETE DISTRIBUTED WORKFLOW[/bold magenta]") + self.rich_console.print(f"\n[dim]{'=' * 80}[/dim]") + + # Build phase + build_summary = self.build_phase(registry, clean_cache) + + # Run phase + execution_summary = self.run_phase(timeout=timeout, keep_alive=keep_alive) + + # Combine summaries + workflow_summary = { + "build_phase": build_summary, + "run_phase": execution_summary, + "overall_success": ( + len(build_summary["failed_builds"]) == 0 + and len(execution_summary["failed_runs"]) == 0 + ), + } + + self.rich_console.print(f"\n[dim]{'=' * 80}[/dim]") + if workflow_summary['overall_success']: + self.rich_console.print("[bold green]🎉 COMPLETE WORKFLOW FINISHED SUCCESSFULLY[/bold green]") + self.rich_console.print(f" [green]Overall success: {workflow_summary['overall_success']}[/green]") + else: + self.rich_console.print("[bold red]❌ COMPLETE WORKFLOW FINISHED WITH ERRORS[/bold red]") + self.rich_console.print(f" [red]Overall success: {workflow_summary['overall_success']}[/red]") + self.rich_console.print(f"\n[dim]{'=' * 80}[/dim]") + + return workflow_summary + + def _copy_scripts(self) -> None: + """Copy scripts to the current directory.""" + scripts_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "..", "scripts" + ) + print(f"Package path: {scripts_path}") + # copy the scripts to the model directory + self.console.sh(f"cp -vLR --preserve=all {scripts_path} .") + print(f"Scripts copied to {os.getcwd()}/scripts") + + def _filter_images_by_gpu_architecture(self, built_images: typing.Dict, runtime_arch: str) -> typing.Dict: + """Filter built images by GPU architecture compatibility. 
+ + Args: + built_images: Dictionary of built images from manifest + runtime_arch: Runtime GPU architecture (e.g., 'gfx908') + + Returns: + dict: Filtered dictionary containing only compatible images + """ + compatible = {} + + self.rich_console.print(f"[cyan]Filtering images for runtime GPU architecture: {runtime_arch}[/cyan]") + + for image_name, image_info in built_images.items(): + image_arch = image_info.get("gpu_architecture") + + if not image_arch: + # Legacy images without architecture info - assume compatible for backward compatibility + self.rich_console.print( + f"[yellow] Warning: Image {image_name} has no architecture info, assuming compatible (legacy mode)[/yellow]" + ) + compatible[image_name] = image_info + elif image_arch == runtime_arch: + # Exact architecture match + self.rich_console.print( + f"[green] ✓ Compatible: {image_name} (architecture: {image_arch})[/green]" + ) + compatible[image_name] = image_info + else: + # Architecture mismatch + self.rich_console.print( + f"[red] ✗ Incompatible: {image_name} (architecture: {image_arch}, runtime: {runtime_arch})[/red]" + ) + + if not compatible: + self.rich_console.print(f"[red]No compatible images found for runtime architecture: {runtime_arch}[/red]") + else: + self.rich_console.print(f"[green]Found {len(compatible)} compatible image(s)[/green]") + + return compatible + + def cleanup(self) -> None: + """Cleanup the scripts/common directory.""" + # check the directory exists + if os.path.exists("scripts/common"): + # List of directories/files to clean up + cleanup_targets = [ + "scripts/common/tools", + "scripts/common/test_echo.sh", + "scripts/common/pre_scripts", + "scripts/common/post_scripts", + ] + + for target in cleanup_targets: + if os.path.exists(target): + try: + # Try normal removal first + self.console.sh(f"rm -rf {target}", canFail=True) + except Exception: + # If that fails, try to fix permissions and remove + try: + # Fix permissions recursively (ignore errors) + self.console.sh(f"chmod -R u+w {target} 2>/dev/null || true", canFail=True) + # Try removal again (allow failure) + self.console.sh(f"rm -rf {target} 2>/dev/null || true", canFail=True) + + # If directory still exists (e.g., __pycache__ with root-owned files), + # just warn the user instead of failing + if os.path.exists(target): + self.rich_console.print( + f"[yellow]⚠️ Warning: Could not fully remove {target} (permission denied for some files)[/yellow]" + ) + self.rich_console.print( + f"[dim]You may need to manually remove it with: sudo rm -rf {target}[/dim]" + ) + except Exception as e: + # Even permission fixing failed, just warn + self.rich_console.print( + f"[yellow]⚠️ Warning: Could not clean up {target}: {e}[/yellow]" + ) + + print(f"scripts/common directory cleanup attempted.") diff --git a/src/madengine/tools/docker_builder.py b/src/madengine/tools/docker_builder.py new file mode 100644 index 00000000..38f6ac38 --- /dev/null +++ b/src/madengine/tools/docker_builder.py @@ -0,0 +1,1083 @@ +#!/usr/bin/env python3 +""" +Docker Image Builder Module for MADEngine + +This module handles the Docker image building phase separately from execution, +enabling distributed workflows where images are built on a central host +and then distributed to remote nodes for execution. 
+""" + +import os +import time +import json +import re +import typing +from contextlib import redirect_stdout, redirect_stderr +from rich.console import Console as RichConsole +from madengine.core.console import Console +from madengine.core.context import Context +from madengine.utils.ops import PythonicTee + + +class DockerBuilder: + """Class responsible for building Docker images for models.""" + + # GPU architecture variables used in MAD/DLM Dockerfiles + GPU_ARCH_VARIABLES = [ + "MAD_SYSTEM_GPU_ARCHITECTURE", + "PYTORCH_ROCM_ARCH", + "GPU_TARGETS", + "GFX_COMPILATION_ARCH", + "GPU_ARCHS" + ] + + def __init__( + self, context: Context, console: Console = None, live_output: bool = False + ): + """Initialize the Docker Builder. + + Args: + context: The MADEngine context + console: Optional console instance + live_output: Whether to show live output + """ + self.context = context + self.console = console or Console(live_output=live_output) + self.live_output = live_output + self.rich_console = RichConsole() + self.built_images = {} # Track built images + self.built_models = {} # Track built models + + def get_context_path(self, info: typing.Dict) -> str: + """Get the context path for Docker build. + + Args: + info: The model info dict. + + Returns: + str: The context path. + """ + if "dockercontext" in info and info["dockercontext"] != "": + return info["dockercontext"] + else: + return "./docker" + + def get_build_arg(self, run_build_arg: typing.Dict = {}) -> str: + """Get the build arguments. + + Args: + run_build_arg: The run build arguments. + + Returns: + str: The build arguments. + """ + if not run_build_arg and "docker_build_arg" not in self.context.ctx: + return "" + + build_args = "" + for build_arg in self.context.ctx["docker_build_arg"].keys(): + build_args += ( + "--build-arg " + + build_arg + + "='" + + self.context.ctx["docker_build_arg"][build_arg] + + "' " + ) + + if run_build_arg: + for key, value in run_build_arg.items(): + build_args += "--build-arg " + key + "='" + value + "' " + + return build_args + + def build_image( + self, + model_info: typing.Dict, + dockerfile: str, + credentials: typing.Dict = None, + clean_cache: bool = False, + phase_suffix: str = "", + additional_build_args: typing.Dict[str, str] = None, + override_image_name: str = None, + ) -> typing.Dict: + """Build a Docker image for the given model. + + Args: + model_info: The model information dictionary + dockerfile: Path to the Dockerfile + credentials: Optional credentials dictionary + clean_cache: Whether to use --no-cache + phase_suffix: Suffix for log file name (e.g., ".build" or "") + additional_build_args: Additional build arguments to pass to Docker + override_image_name: Override the generated image name + + Returns: + dict: Build information including image name, build duration, etc. 
+ """ + # Generate image name first + if override_image_name: + docker_image = override_image_name + else: + image_docker_name = ( + model_info["name"].replace("/", "_").lower() + + "_" + + os.path.basename(dockerfile).replace(".Dockerfile", "") + ) + docker_image = "ci-" + image_docker_name + + # Create log file for this build + cur_docker_file_basename = os.path.basename(dockerfile).replace( + ".Dockerfile", "" + ) + log_file_path = ( + model_info["name"].replace("/", "_") + + "_" + + cur_docker_file_basename + + phase_suffix + + ".live.log" + ) + # Replace / with _ in log file path (already done above, but keeping for safety) + log_file_path = log_file_path.replace("/", "_") + + self.rich_console.print(f"\n[bold green]🔨 Starting Docker build for model:[/bold green] [bold cyan]{model_info['name']}[/bold cyan]") + print(f"📁 Dockerfile: {dockerfile}") + print(f"🏷️ Target image: {docker_image}") + print(f"📝 Build log: {log_file_path}") + self.rich_console.print(f"[dim]{'='*80}[/dim]") + + # Get docker context + docker_context = self.get_context_path(model_info) + + # Prepare build args + run_build_arg = {} + if "cred" in model_info and model_info["cred"] != "" and credentials: + if model_info["cred"] not in credentials: + raise RuntimeError( + f"Credentials({model_info['cred']}) not found for model {model_info['name']}" + ) + # Add cred to build args + for key_cred, value_cred in credentials[model_info["cred"]].items(): + run_build_arg[model_info["cred"] + "_" + key_cred.upper()] = value_cred + + # Add additional build args if provided (for multi-architecture builds) + if additional_build_args: + run_build_arg.update(additional_build_args) + + build_args = self.get_build_arg(run_build_arg) + + use_cache_str = "--no-cache" if clean_cache else "" + + # Build the image with logging + build_start_time = time.time() + + build_command = ( + f"docker build {use_cache_str} --network=host " + f"-t {docker_image} --pull -f {dockerfile} " + f"{build_args} {docker_context}" + ) + + # Execute build with log redirection + with open(log_file_path, mode="w", buffering=1) as outlog: + with redirect_stdout( + PythonicTee(outlog, self.live_output) + ), redirect_stderr(PythonicTee(outlog, self.live_output)): + print(f"🔨 Executing build command...") + self.console.sh(build_command, timeout=None) + + build_duration = time.time() - build_start_time + + print(f"⏱️ Build Duration: {build_duration:.2f} seconds") + print(f"🏷️ MAD_CONTAINER_IMAGE is {docker_image}") + self.rich_console.print(f"[bold green]✅ Docker build completed successfully[/bold green]") + self.rich_console.print(f"[dim]{'='*80}[/dim]") + + # Get base docker info + base_docker = "" + if ( + "docker_build_arg" in self.context.ctx + and "BASE_DOCKER" in self.context.ctx["docker_build_arg"] + ): + base_docker = self.context.ctx["docker_build_arg"]["BASE_DOCKER"] + else: + base_docker = self.console.sh( + f"grep '^ARG BASE_DOCKER=' {dockerfile} | sed -E 's/ARG BASE_DOCKER=//g'" + ) + + print(f"BASE DOCKER is {base_docker}") + + # Get docker SHA + docker_sha = "" + try: + docker_sha = self.console.sh( + f'docker manifest inspect {base_docker} | grep digest | head -n 1 | cut -d \\" -f 4' + ) + print(f"BASE DOCKER SHA is {docker_sha}") + except Exception as e: + self.rich_console.print(f"[yellow]Warning: Could not get docker SHA: {e}[/yellow]") + + build_info = { + "model": model_info["name"], + "docker_image": docker_image, + "dockerfile": dockerfile, + "base_docker": base_docker, + "docker_sha": docker_sha, + "build_duration": build_duration, + 
"build_command": build_command, + "log_file": log_file_path, + } + + # Store built image info + self.built_images[docker_image] = build_info + + # Store model info linked to the built image + self.built_models[docker_image] = model_info + + self.rich_console.print(f"[bold green]Successfully built image:[/bold green] [cyan]{docker_image}[/cyan]") + + return build_info + + def login_to_registry(self, registry: str, credentials: typing.Dict = None) -> None: + """Login to a Docker registry. + + Args: + registry: Registry URL (e.g., "localhost:5000", "docker.io", or empty for DockerHub) + credentials: Optional credentials dictionary containing username/password + """ + if not credentials: + print("No credentials provided for registry login") + return + + # Check if registry credentials are available + registry_key = registry if registry else "dockerhub" + + # Handle docker.io as dockerhub + if registry and registry.lower() == "docker.io": + registry_key = "dockerhub" + + if registry_key not in credentials: + error_msg = f"No credentials found for registry: {registry_key}" + if registry_key == "dockerhub": + error_msg += f"\nPlease add dockerhub credentials to credential.json:\n" + error_msg += "{\n" + error_msg += ' "dockerhub": {\n' + error_msg += ' "repository": "your-repository",\n' + error_msg += ' "username": "your-dockerhub-username",\n' + error_msg += ' "password": "your-dockerhub-password-or-token"\n' + error_msg += " }\n" + error_msg += "}" + else: + error_msg += ( + f"\nPlease add {registry_key} credentials to credential.json:\n" + ) + error_msg += "{\n" + error_msg += f' "{registry_key}": {{\n' + error_msg += f' "repository": "your-repository",\n' + error_msg += f' "username": "your-{registry_key}-username",\n' + error_msg += f' "password": "your-{registry_key}-password"\n' + error_msg += " }\n" + error_msg += "}" + self.rich_console.print(f"[red]{error_msg}[/red]") + raise RuntimeError(error_msg) + + creds = credentials[registry_key] + + if "username" not in creds or "password" not in creds: + error_msg = f"Invalid credentials format for registry: {registry_key}" + error_msg += f"\nCredentials must contain 'username' and 'password' fields" + self.rich_console.print(f"[red]{error_msg}[/red]") + raise RuntimeError(error_msg) + + # Ensure credential values are strings + username = str(creds["username"]) + password = str(creds["password"]) + + # Perform docker login + login_command = f"echo '{password}' | docker login" + + if registry and registry.lower() not in ["docker.io", "dockerhub"]: + login_command += f" {registry}" + + login_command += f" --username {username} --password-stdin" + + try: + self.console.sh(login_command, secret=True) + self.rich_console.print(f"[green]✅ Successfully logged in to registry: {registry or 'DockerHub'}[/green]") + except Exception as e: + self.rich_console.print(f"[red]❌ Failed to login to registry {registry}: {e}[/red]") + raise + + def push_image( + self, + docker_image: str, + registry: str = None, + credentials: typing.Dict = None, + explicit_registry_image: str = None, + ) -> str: + """Push the built image to a registry. 
+ + Args: + docker_image: The local docker image name + registry: Optional registry URL (e.g., "localhost:5000", "docker.io", or empty for DockerHub) + credentials: Optional credentials dictionary for registry authentication + + Returns: + str: The full registry image name + """ + if not registry: + print(f"No registry specified, image remains local: {docker_image}") + return docker_image + + # Login to registry if credentials are provided + if credentials: + self.login_to_registry(registry, credentials) + + # Determine registry image name (this should match what was already determined) + if explicit_registry_image: + registry_image = explicit_registry_image + else: + registry_image = self._determine_registry_image_name( + docker_image, registry, credentials + ) + + try: + # Tag the image if different from local name + if registry_image != docker_image: + print(f"Tagging image: docker tag {docker_image} {registry_image}") + tag_command = f"docker tag {docker_image} {registry_image}" + self.console.sh(tag_command) + else: + print( + f"No tag needed, docker_image and registry_image are the same: {docker_image}" + ) + + # Push the image + push_command = f"docker push {registry_image}" + self.rich_console.print(f"\n[bold blue]🚀 Starting docker push to registry...[/bold blue]") + print(f"📤 Registry: {registry}") + print(f"🏷️ Image: {registry_image}") + self.console.sh(push_command) + + self.rich_console.print(f"[bold green]✅ Successfully pushed image to registry:[/bold green] [cyan]{registry_image}[/cyan]") + self.rich_console.print(f"[dim]{'='*80}[/dim]") + return registry_image + + except Exception as e: + self.rich_console.print(f"[red]❌ Failed to push image {docker_image} to registry {registry}: {e}[/red]") + raise + + def export_build_manifest( + self, + output_file: str = "build_manifest.json", + registry: str = None, + batch_build_metadata: typing.Optional[dict] = None, + ) -> None: + """Export enhanced build information to a manifest file. + + This creates a comprehensive build manifest that includes all necessary + information for deployment, reducing the need for separate execution configs. 
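+
+ Rough shape of the emitted manifest (illustrative sketch only; optional keys
+ such as "tools", "pre_scripts", "post_scripts", "multi_node_args" and
+ "push_failures" appear only when present):
+
+ {
+   "built_images": {"<image>": {"model": ..., "dockerfile": ..., "registry": ...}},
+   "built_models": {"<image>": {<model_info>}},
+   "context": {"docker_env_vars": {...}, "docker_mounts": {...},
+               "docker_build_arg": {...}, "gpu_vendor": "...", "docker_gpus": "..."},
+   "credentials_required": ["<cred name>", ...]
+ }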
+ + Args: + output_file: Path to output manifest file + registry: Registry used for building (added to each image entry) + batch_build_metadata: Optional metadata for batch builds + """ + # Extract credentials from models + credentials_required = list( + set( + [ + model.get("cred", "") + for model in self.built_models.values() + if model.get("cred", "") != "" + ] + ) + ) + + # Set registry for each built image + for image_name, build_info in self.built_images.items(): + # If registry is not set in build_info, set it from argument + if registry: + build_info["registry"] = registry + + # If registry is set in batch_build_metadata, override it + docker_file = build_info.get("dockerfile", "") + truncated_docker_file = docker_file.split("/")[-1].split(".Dockerfile")[0] + model_name = ( + image_name.split("ci-")[1].split(truncated_docker_file)[0].rstrip("_") + ) + if batch_build_metadata and model_name in batch_build_metadata: + self.rich_console.print( + f"[yellow]Overriding registry for {model_name} from batch_build_metadata[/yellow]" + ) + build_info["registry"] = batch_build_metadata[model_name].get( + "registry" + ) + + manifest = { + "built_images": self.built_images, + "built_models": self.built_models, + "context": { + "docker_env_vars": self.context.ctx.get("docker_env_vars", {}), + "docker_mounts": self.context.ctx.get("docker_mounts", {}), + "docker_build_arg": self.context.ctx.get("docker_build_arg", {}), + "gpu_vendor": self.context.ctx.get("gpu_vendor", ""), + "docker_gpus": self.context.ctx.get("docker_gpus", ""), + }, + "credentials_required": credentials_required, + } + + # Preserve tools configuration if present in context + if "tools" in self.context.ctx: + manifest["context"]["tools"] = self.context.ctx["tools"] + + # Preserve pre/post scripts if present in context + if "pre_scripts" in self.context.ctx: + manifest["context"]["pre_scripts"] = self.context.ctx["pre_scripts"] + if "post_scripts" in self.context.ctx: + manifest["context"]["post_scripts"] = self.context.ctx["post_scripts"] + if "encapsulate_script" in self.context.ctx: + manifest["context"]["encapsulate_script"] = self.context.ctx["encapsulate_script"] + + # Add multi-node args to context if present + if "build_multi_node_args" in self.context.ctx: + manifest["context"]["multi_node_args"] = self.context.ctx[ + "build_multi_node_args" + ] + + # Add push failure summary if any pushes failed + push_failures = [] + for image_name, build_info in self.built_images.items(): + if "push_failed" in build_info and build_info["push_failed"]: + push_failures.append( + { + "image": image_name, + "intended_registry_image": build_info.get("registry_image"), + "error": build_info.get("push_error"), + } + ) + + if push_failures: + manifest["push_failures"] = push_failures + + with open(output_file, "w") as f: + json.dump(manifest, f, indent=2) + + self.rich_console.print(f"[green]Build manifest exported to:[/green] {output_file}") + if push_failures: + self.rich_console.print(f"[yellow]Warning: {len(push_failures)} image(s) failed to push to registry[/yellow]") + for failure in push_failures: + self.rich_console.print( + f"[red] - {failure['image']} -> {failure['intended_registry_image']}: {failure['error']}[/red]" + ) + + def build_all_models( + self, + models: typing.List[typing.Dict], + credentials: typing.Dict = None, + clean_cache: bool = False, + registry: str = None, + phase_suffix: str = "", + batch_build_metadata: typing.Optional[dict] = None, + target_archs: typing.List[str] = None, # New parameter + ) -> typing.Dict: + 
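+ # Hypothetical usage sketch (object and variable names are illustrative, not
+ # part of the documented API):
+ #   summary = builder.build_all_models(models, credentials=creds,
+ #                                      registry="localhost:5000",
+ #                                      target_archs=["gfx90a", "gfx942"])
+ #   summary["failed_builds"]  # -> [{"model": ..., "error": ...}, ...]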
"""Build images for all models, with optional multi-architecture support. + + Args: + models: List of model information dictionaries + credentials: Optional credentials dictionary + clean_cache: Whether to use --no-cache + registry: Optional registry to push images to + phase_suffix: Suffix for log file name (e.g., ".build" or "") + batch_build_metadata: Optional batch build metadata + target_archs: Optional list of target GPU architectures for multi-arch builds + + Returns: + dict: Summary of all built images + """ + self.rich_console.print(f"[bold blue]Building Docker images for {len(models)} models...[/bold blue]") + + if target_archs: + self.rich_console.print(f"[bold cyan]Multi-architecture build mode enabled for: {', '.join(target_archs)}[/bold cyan]") + else: + self.rich_console.print(f"[bold cyan]Single architecture build mode[/bold cyan]") + + build_summary = { + "successful_builds": [], + "failed_builds": [], + "total_build_time": 0, + "successful_pushes": [], + "failed_pushes": [], + } + + for model_info in models: + # Check if MAD_SYSTEM_GPU_ARCHITECTURE is provided in additional_context + # This overrides --target-archs and uses default flow + if ("docker_build_arg" in self.context.ctx and + "MAD_SYSTEM_GPU_ARCHITECTURE" in self.context.ctx["docker_build_arg"]): + self.rich_console.print(f"[yellow]Info: MAD_SYSTEM_GPU_ARCHITECTURE provided in additional_context, " + f"disabling --target-archs and using default flow for model {model_info['name']}[/yellow]") + # Use single architecture build mode regardless of target_archs + try: + single_build_info = self._build_model_single_arch( + model_info, credentials, clean_cache, + registry, phase_suffix, batch_build_metadata + ) + build_summary["successful_builds"].extend(single_build_info) + build_summary["total_build_time"] += sum( + info.get("build_duration", 0) for info in single_build_info + ) + except Exception as e: + build_summary["failed_builds"].append({ + "model": model_info["name"], + "error": str(e) + }) + elif target_archs: + # Multi-architecture build mode - always use architecture suffix + for arch in target_archs: + try: + # Always build with architecture suffix when --target-archs is used + arch_build_info = self._build_model_for_arch( + model_info, arch, credentials, clean_cache, + registry, phase_suffix, batch_build_metadata + ) + + build_summary["successful_builds"].extend(arch_build_info) + build_summary["total_build_time"] += sum( + info.get("build_duration", 0) for info in arch_build_info + ) + except Exception as e: + build_summary["failed_builds"].append({ + "model": model_info["name"], + "architecture": arch, + "error": str(e) + }) + else: + # Single architecture build mode (existing behavior - no validation needed) + try: + single_build_info = self._build_model_single_arch( + model_info, credentials, clean_cache, + registry, phase_suffix, batch_build_metadata + ) + build_summary["successful_builds"].extend(single_build_info) + build_summary["total_build_time"] += sum( + info.get("build_duration", 0) for info in single_build_info + ) + except Exception as e: + build_summary["failed_builds"].append({ + "model": model_info["name"], + "error": str(e) + }) + + return build_summary + + def _check_dockerfile_has_gpu_variables(self, model_info: typing.Dict) -> typing.Tuple[bool, str]: + """ + Check if model's Dockerfile contains GPU architecture variables. 
+ Returns (has_gpu_vars, dockerfile_path) + """ + try: + # Find dockerfiles for this model + dockerfiles = self._get_dockerfiles_for_model(model_info) + + for dockerfile_path in dockerfiles: + with open(dockerfile_path, 'r') as f: + dockerfile_content = f.read() + + # Parse GPU architecture variables from Dockerfile + dockerfile_gpu_vars = self._parse_dockerfile_gpu_variables(dockerfile_content) + + if dockerfile_gpu_vars: + return True, dockerfile_path + else: + return False, dockerfile_path + + # No dockerfiles found + return False, "No Dockerfile found" + + except Exception as e: + self.rich_console.print(f"[yellow]Warning: Error checking GPU variables for model {model_info['name']}: {e}[/yellow]") + return False, "Error reading Dockerfile" + + def _get_dockerfiles_for_model(self, model_info: typing.Dict) -> typing.List[str]: + """Get dockerfiles for a model.""" + try: + all_dockerfiles = self.console.sh( + f"ls {model_info['dockerfile']}.*" + ).split("\n") + + dockerfiles = {} + for cur_docker_file in all_dockerfiles: + # Get context of dockerfile + dockerfiles[cur_docker_file] = self.console.sh( + f"head -n5 {cur_docker_file} | grep '# CONTEXT ' | sed 's/# CONTEXT //g'" + ) + + # Filter dockerfiles based on context + dockerfiles = self.context.filter(dockerfiles) + + return list(dockerfiles.keys()) + + except Exception as e: + self.rich_console.print(f"[yellow]Warning: Error finding dockerfiles for model {model_info['name']}: {e}[/yellow]") + return [] + + def _validate_target_arch_against_dockerfile(self, model_info: typing.Dict, target_arch: str) -> bool: + """ + Validate that target architecture is compatible with model's Dockerfile GPU variables. + Called during build phase when --target-archs is provided. + """ + try: + # Find dockerfiles for this model + dockerfiles = self._get_dockerfiles_for_model(model_info) + + for dockerfile_path in dockerfiles: + with open(dockerfile_path, 'r') as f: + dockerfile_content = f.read() + + # Parse GPU architecture variables from Dockerfile + dockerfile_gpu_vars = self._parse_dockerfile_gpu_variables(dockerfile_content) + + if not dockerfile_gpu_vars: + # No GPU variables found - target arch is acceptable + self.rich_console.print(f"[cyan]Info: No GPU architecture variables found in {dockerfile_path}, " + f"target architecture '{target_arch}' is acceptable[/cyan]") + continue + + # Validate target architecture against each GPU variable + for var_name, var_values in dockerfile_gpu_vars.items(): + if not self._is_target_arch_compatible_with_variable( + var_name, var_values, target_arch + ): + self.rich_console.print(f"[red]Error: Target architecture '{target_arch}' is not compatible " + f"with {var_name}={var_values} in {dockerfile_path}[/red]") + return False + + self.rich_console.print(f"[cyan]Info: Target architecture '{target_arch}' validated successfully " + f"against {dockerfile_path}[/cyan]") + + return True + + except FileNotFoundError as e: + self.rich_console.print(f"[yellow]Warning: Dockerfile not found for model {model_info['name']}: {e}[/yellow]") + return True # Assume compatible if Dockerfile not found + except Exception as e: + self.rich_console.print(f"[yellow]Warning: Error validating target architecture for model {model_info['name']}: {e}[/yellow]") + return True # Assume compatible on parsing errors + + def _parse_dockerfile_gpu_variables(self, dockerfile_content: str) -> typing.Dict[str, typing.List[str]]: + """Parse GPU architecture variables from Dockerfile content.""" + gpu_variables = {} + + for var_name in 
self.GPU_ARCH_VARIABLES: + # Look for ARG declarations + arg_pattern = rf"ARG\s+{var_name}=([^\s\n]+)" + arg_matches = re.findall(arg_pattern, dockerfile_content, re.IGNORECASE) + + # Look for ENV declarations + env_pattern = rf"ENV\s+{var_name}[=\s]+([^\s\n]+)" + env_matches = re.findall(env_pattern, dockerfile_content, re.IGNORECASE) + + # Process found values + all_matches = arg_matches + env_matches + if all_matches: + # Take the last defined value (in case of multiple definitions) + raw_value = all_matches[-1].strip('"\'') + parsed_values = self._parse_gpu_variable_value(var_name, raw_value) + if parsed_values: + gpu_variables[var_name] = parsed_values + + return gpu_variables + + def _parse_gpu_variable_value(self, var_name: str, raw_value: str) -> typing.List[str]: + """Parse GPU variable value based on variable type and format.""" + architectures = [] + + # Handle different variable formats + if var_name in ["GPU_TARGETS", "GPU_ARCHS", "PYTORCH_ROCM_ARCH"]: + # These often contain multiple architectures separated by semicolons or commas + if ";" in raw_value: + architectures = [arch.strip() for arch in raw_value.split(";") if arch.strip()] + elif "," in raw_value: + architectures = [arch.strip() for arch in raw_value.split(",") if arch.strip()] + else: + architectures = [raw_value.strip()] + else: + # Single architecture value (MAD_SYSTEM_GPU_ARCHITECTURE, GFX_COMPILATION_ARCH) + architectures = [raw_value.strip()] + + # Normalize architecture names + normalized_archs = [] + for arch in architectures: + normalized = self._normalize_architecture_name(arch) + if normalized: + normalized_archs.append(normalized) + + return normalized_archs + + def _normalize_architecture_name(self, arch: str) -> str: + """Normalize architecture name to standard format.""" + arch = arch.lower().strip() + + # Handle common variations and aliases + if arch.startswith("gfx"): + return arch + elif arch in ["mi100", "mi-100"]: + return "gfx908" + elif arch in ["mi200", "mi-200", "mi210", "mi250"]: + return "gfx90a" + elif arch in ["mi300", "mi-300", "mi300a"]: + return "gfx940" + elif arch in ["mi300x", "mi-300x"]: + return "gfx942" + elif arch.startswith("mi"): + # Unknown MI series - return as is for potential future support + return arch + + return arch if arch else None + + def _is_target_arch_compatible_with_variable( + self, + var_name: str, + var_values: typing.List[str], + target_arch: str + ) -> bool: + """ + Validate that target architecture is compatible with a specific GPU variable. + Used during build phase validation. 
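+
+ Illustrative example (hypothetical values): with PYTORCH_ROCM_ARCH parsed as
+ ["gfx90a", "gfx942"], a target of "gfx942" is compatible while "gfx908" is not;
+ for MAD_SYSTEM_GPU_ARCHITECTURE the check always passes, since the build
+ overrides that variable with the target architecture.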
+ """ + if var_name == "MAD_SYSTEM_GPU_ARCHITECTURE": + # MAD_SYSTEM_GPU_ARCHITECTURE will be overridden by target_arch, so always compatible + return True + + elif var_name in ["PYTORCH_ROCM_ARCH", "GPU_TARGETS", "GPU_ARCHS"]: + # Multi-architecture variables - target arch must be in the list + return target_arch in var_values + + elif var_name == "GFX_COMPILATION_ARCH": + # Compilation architecture should be compatible with target arch + return len(var_values) == 1 and ( + var_values[0] == target_arch or + self._is_compilation_arch_compatible(var_values[0], target_arch) + ) + + # Unknown variable - assume compatible + return True + + def _is_compilation_arch_compatible(self, compile_arch: str, target_arch: str) -> bool: + """Check if compilation architecture is compatible with target architecture.""" + # Define compatibility rules for compilation + compatibility_matrix = { + "gfx908": ["gfx908"], # MI100 - exact match only + "gfx90a": ["gfx90a"], # MI200 - exact match only + "gfx940": ["gfx940"], # MI300A - exact match only + "gfx941": ["gfx941"], # MI300X - exact match only + "gfx942": ["gfx942"], # MI300X - exact match only + } + + compatible_archs = compatibility_matrix.get(compile_arch, [compile_arch]) + return target_arch in compatible_archs + + def _build_model_single_arch( + self, + model_info: typing.Dict, + credentials: typing.Dict, + clean_cache: bool, + registry: str, + phase_suffix: str, + batch_build_metadata: typing.Optional[dict] + ) -> typing.List[typing.Dict]: + """Build model using existing single architecture flow.""" + + # Use existing build logic - MAD_SYSTEM_GPU_ARCHITECTURE comes from additional_context + # or Dockerfile defaults + dockerfiles = self._get_dockerfiles_for_model(model_info) + + results = [] + for dockerfile in dockerfiles: + build_info = self.build_image( + model_info, + dockerfile, + credentials, + clean_cache, + phase_suffix + ) + + # Extract GPU architecture from build args or context for manifest + gpu_arch = self._get_effective_gpu_architecture(model_info, dockerfile) + if gpu_arch: + build_info["gpu_architecture"] = gpu_arch + + # Handle registry push (existing logic) + if registry: + try: + registry_image = self._create_registry_image_name( + build_info["docker_image"], registry, batch_build_metadata, model_info + ) + self.push_image(build_info["docker_image"], registry, credentials, registry_image) + build_info["registry_image"] = registry_image + except Exception as e: + build_info["push_error"] = str(e) + + results.append(build_info) + + return results + + def _get_effective_gpu_architecture(self, model_info: typing.Dict, dockerfile_path: str) -> str: + """Get effective GPU architecture for single arch builds.""" + # Check if MAD_SYSTEM_GPU_ARCHITECTURE is in build args from additional_context + if ("docker_build_arg" in self.context.ctx and + "MAD_SYSTEM_GPU_ARCHITECTURE" in self.context.ctx["docker_build_arg"]): + return self.context.ctx["docker_build_arg"]["MAD_SYSTEM_GPU_ARCHITECTURE"] + + # Try to extract from Dockerfile defaults + try: + with open(dockerfile_path, 'r') as f: + content = f.read() + + # Look for ARG or ENV declarations + patterns = [ + r"ARG\s+MAD_SYSTEM_GPU_ARCHITECTURE=([^\s\n]+)", + r"ENV\s+MAD_SYSTEM_GPU_ARCHITECTURE=([^\s\n]+)" + ] + + for pattern in patterns: + match = re.search(pattern, content, re.IGNORECASE) + if match: + return match.group(1).strip('"\'') + except Exception: + pass + + return None + + def _create_base_image_name(self, model_info: typing.Dict, dockerfile: str) -> str: + """Create base image 
name from model info and dockerfile.""" + # Extract dockerfile context suffix (e.g., "ubuntu.amd" from "dummy.ubuntu.amd.Dockerfile") + dockerfile_name = os.path.basename(dockerfile) + if '.' in dockerfile_name: + # Remove the .Dockerfile extension and get context + context_parts = dockerfile_name.replace('.Dockerfile', '').split('.')[1:] # Skip model name + context_suffix = '.'.join(context_parts) if context_parts else 'default' + else: + context_suffix = 'default' + + # Create base image name: ci-{model}_{model}.{context} + return f"ci-{model_info['name']}_{model_info['name']}.{context_suffix}" + + def _create_registry_image_name( + self, + image_name: str, + registry: str, + batch_build_metadata: typing.Optional[dict], + model_info: typing.Dict + ) -> str: + """Create registry image name.""" + if batch_build_metadata and model_info["name"] in batch_build_metadata: + meta = batch_build_metadata[model_info["name"]] + if meta.get("registry_image"): + return meta["registry_image"] + + # Default registry naming + return self._determine_registry_image_name(image_name, registry) + + def _create_arch_registry_image_name( + self, + image_name: str, + gpu_arch: str, + registry: str, + batch_build_metadata: typing.Optional[dict], + model_info: typing.Dict + ) -> str: + """Create architecture-specific registry image name.""" + # For multi-arch builds, add architecture to the tag + if batch_build_metadata and model_info["name"] in batch_build_metadata: + meta = batch_build_metadata[model_info["name"]] + if meta.get("registry_image"): + # Append architecture to existing registry image + return f"{meta['registry_image']}_{gpu_arch}" + + # Default arch-specific registry naming + base_registry_name = self._determine_registry_image_name(image_name, registry) + return f"{base_registry_name}" # Architecture already in image_name + + def _determine_registry_image_name( + self, docker_image: str, registry: str, credentials: typing.Dict = None + ) -> str: + """Determine the registry image name that would be used for pushing. 
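+
+ Naming convention implemented below (values are illustrative):
+   docker.io/DockerHub with a "repository" credential -> "<repository>:<docker_image>"
+   other registries with a "repository" credential    -> "<registry>/<repository>:<docker_image>"
+   no matching repository credential                  -> "<registry>/<docker_image>" (or the local name on DockerHub)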
+ + Args: + docker_image: The local docker image name + registry: Registry URL (e.g., "localhost:5000", "docker.io", or empty for DockerHub) + credentials: Optional credentials dictionary for registry authentication + + Returns: + str: The full registry image name that would be used + """ + if not registry: + return docker_image + + # Determine registry image name based on registry type + if registry.lower() in ["docker.io", "dockerhub"]: + # For DockerHub, always use format: repository:tag + # Try to get repository from credentials, fallback to default if not available + if ( + credentials + and "dockerhub" in credentials + and "repository" in credentials["dockerhub"] + ): + registry_image = ( + f"{credentials['dockerhub']['repository']}:{docker_image}" + ) + else: + registry_image = docker_image + else: + # For other registries (local, AWS ECR, etc.), use format: registry/repository:tag + registry_key = registry + if ( + credentials + and registry_key in credentials + and "repository" in credentials[registry_key] + ): + registry_image = f"{registry}/{credentials[registry_key]['repository']}:{docker_image}" + else: + # Fallback to just registry/imagename if no repository specified + registry_image = f"{registry}/{docker_image}" + + return registry_image + + def _is_compilation_arch_compatible(self, compile_arch: str, target_arch: str) -> bool: + """Check if compilation architecture is compatible with target architecture.""" + # Define compatibility rules for compilation + compatibility_matrix = { + "gfx908": ["gfx908"], # MI100 - exact match only + "gfx90a": ["gfx90a"], # MI200 - exact match only + "gfx940": ["gfx940"], # MI300A - exact match only + "gfx941": ["gfx941"], # MI300X - exact match only + "gfx942": ["gfx942"], # MI300X - exact match only + } + + compatible_archs = compatibility_matrix.get(compile_arch, [compile_arch]) + return target_arch in compatible_archs + + def _build_model_for_arch( + self, + model_info: typing.Dict, + gpu_arch: str, + credentials: typing.Dict, + clean_cache: bool, + registry: str, + phase_suffix: str, + batch_build_metadata: typing.Optional[dict] + ) -> typing.List[typing.Dict]: + """Build model for specific GPU architecture with smart image naming.""" + + # Find dockerfiles + dockerfiles = self._get_dockerfiles_for_model(model_info) + + arch_results = [] + for dockerfile in dockerfiles: + # When using --target-archs, always add architecture suffix regardless of GPU variables + # This ensures consistent naming for multi-architecture builds + base_image_name = self._create_base_image_name(model_info, dockerfile) + arch_image_name = f"{base_image_name}_{gpu_arch}" + + # Set MAD_SYSTEM_GPU_ARCHITECTURE for this build + arch_build_args = {"MAD_SYSTEM_GPU_ARCHITECTURE": gpu_arch} + + # Build the image + build_info = self.build_image( + model_info, + dockerfile, + credentials, + clean_cache, + phase_suffix, + additional_build_args=arch_build_args, + override_image_name=arch_image_name + ) + + # Add architecture metadata + build_info["gpu_architecture"] = gpu_arch + + # Handle registry push with architecture-specific tagging + if registry: + registry_image = self._determine_registry_image_name( + arch_image_name, registry, credentials + ) + try: + self.push_image(arch_image_name, registry, credentials, registry_image) + build_info["registry_image"] = registry_image + except Exception as e: + build_info["push_error"] = str(e) + + arch_results.append(build_info) + + return arch_results + + def _build_model_single_arch( + self, + model_info: typing.Dict, + 
credentials: typing.Dict, + clean_cache: bool, + registry: str, + phase_suffix: str, + batch_build_metadata: typing.Optional[dict] + ) -> typing.List[typing.Dict]: + """Build model using existing single architecture flow.""" + + # Find dockerfiles for this model + dockerfiles = self._get_dockerfiles_for_model(model_info) + + results = [] + for dockerfile in dockerfiles: + build_info = self.build_image( + model_info, + dockerfile, + credentials, + clean_cache, + phase_suffix + ) + + # Extract GPU architecture from build args or context for manifest + gpu_arch = self._get_effective_gpu_architecture(model_info, dockerfile) + if gpu_arch: + build_info["gpu_architecture"] = gpu_arch + + # Handle registry push (existing logic) + if registry: + registry_image = self._determine_registry_image_name( + build_info["docker_image"], registry, credentials + ) + try: + self.push_image(build_info["docker_image"], registry, credentials, registry_image) + build_info["registry_image"] = registry_image + except Exception as e: + build_info["push_error"] = str(e) + + results.append(build_info) + + return results + + def _get_effective_gpu_architecture(self, model_info: typing.Dict, dockerfile_path: str) -> str: + """Get effective GPU architecture for single arch builds.""" + # Check if MAD_SYSTEM_GPU_ARCHITECTURE is in build args from additional_context + if ("docker_build_arg" in self.context.ctx and + "MAD_SYSTEM_GPU_ARCHITECTURE" in self.context.ctx["docker_build_arg"]): + return self.context.ctx["docker_build_arg"]["MAD_SYSTEM_GPU_ARCHITECTURE"] + + # Try to extract from Dockerfile defaults + try: + with open(dockerfile_path, 'r') as f: + content = f.read() + + # Look for ARG or ENV declarations + patterns = [ + r"ARG\s+MAD_SYSTEM_GPU_ARCHITECTURE=([^\s\n]+)", + r"ENV\s+MAD_SYSTEM_GPU_ARCHITECTURE=([^\s\n]+)" + ] + + for pattern in patterns: + match = re.search(pattern, content, re.IGNORECASE) + if match: + return match.group(1).strip('"\'') + except Exception: + pass + + return None diff --git a/src/madengine/tools/run_models.py b/src/madengine/tools/run_models.py index a620d96f..500535e8 100644 --- a/src/madengine/tools/run_models.py +++ b/src/madengine/tools/run_models.py @@ -45,7 +45,12 @@ from madengine.core.context import Context from madengine.core.dataprovider import Data from madengine.core.docker import Docker -from madengine.utils.ops import PythonicTee, file_print, substring_found, find_and_replace_pattern +from madengine.utils.ops import ( + PythonicTee, + file_print, + substring_found, + find_and_replace_pattern, +) from madengine.core.constants import MAD_MINIO, MAD_AWS_S3 from madengine.core.constants import MODEL_DIR, PUBLIC_GITHUB_ROCM_KEY from madengine.core.timeout import Timeout @@ -118,7 +123,17 @@ def print_perf(self): Method to print stage perf results of a model. 
""" - print(f"{self.model} performance is {self.performance} {self.metric}") + print("\n" + "=" * 60) + print(f"📊 PERFORMANCE RESULTS") + print("=" * 60) + print(f"🏷️ Model: {self.model}") + print(f"⚡ Performance: {self.performance} {self.metric}") + print(f"📈 Status: {self.status}") + if self.machine_name: + print(f"🖥️ Machine: {self.machine_name}") + if self.gpu_architecture: + print(f"🎮 GPU Architecture: {self.gpu_architecture}") + print("=" * 60 + "\n") # Exports all info in json format to json_name # multiple_results excludes the "model,performance,metric,status" keys @@ -154,9 +169,11 @@ def __init__(self, args): self.return_status = True self.args = args self.console = Console(live_output=True) + # Initialize context in runtime mode (requires GPU detection) self.context = Context( additional_context=args.additional_context, additional_context_file=args.additional_context_file, + build_only_mode=False, # RunModels always needs full runtime context ) # check the data.json file exists data_json_file = args.data_config_file_name @@ -259,10 +276,8 @@ def get_build_arg(self, run_build_arg: typing.Dict = {}) -> str: return build_args def apply_tools( - self, - pre_encapsulate_post_scripts: typing.Dict, - run_env: typing.Dict - ) -> None: + self, pre_encapsulate_post_scripts: typing.Dict, run_env: typing.Dict + ) -> None: """Apply tools to the model. Args: @@ -290,32 +305,37 @@ def apply_tools( if "env_vars" in ctx_tool_config: for env_var in ctx_tool_config["env_vars"]: - tool_config["env_vars"].update({env_var: ctx_tool_config["env_vars"][env_var]}) + tool_config["env_vars"].update( + {env_var: ctx_tool_config["env_vars"][env_var]} + ) print(f"Selected Tool, {tool_name}. Configuration : {str(tool_config)}.") # setup tool before other existing scripts if "pre_scripts" in tool_config: pre_encapsulate_post_scripts["pre_scripts"] = ( - tool_config["pre_scripts"] + pre_encapsulate_post_scripts["pre_scripts"] + tool_config["pre_scripts"] + + pre_encapsulate_post_scripts["pre_scripts"] ) # cleanup tool after other existing scripts if "post_scripts" in tool_config: - pre_encapsulate_post_scripts["post_scripts"] += tool_config["post_scripts"] + pre_encapsulate_post_scripts["post_scripts"] += tool_config[ + "post_scripts" + ] # warning: this will update existing keys from env or other tools if "env_vars" in tool_config: run_env.update(tool_config["env_vars"]) if "cmd" in tool_config: # prepend encapsulate cmd pre_encapsulate_post_scripts["encapsulate_script"] = ( - tool_config["cmd"] + " " + pre_encapsulate_post_scripts["encapsulate_script"] + tool_config["cmd"] + + " " + + pre_encapsulate_post_scripts["encapsulate_script"] ) def gather_system_env_details( - self, - pre_encapsulate_post_scripts: typing.Dict, - model_name: str - ) -> None: + self, pre_encapsulate_post_scripts: typing.Dict, model_name: str + ) -> None: """Gather system environment details. Args: @@ -340,7 +360,9 @@ def gather_system_env_details( def copy_scripts(self) -> None: """Copy scripts to the model directory.""" - scripts_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "scripts") + scripts_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "..", "scripts" + ) print(f"Package path: {scripts_path}") # copy the scripts to the model directory self.console.sh(f"cp -vLR --preserve=all {scripts_path} .") @@ -391,7 +413,7 @@ def get_gpu_arg(self, requested_gpus: str) -> str: gpu_arg = "" # get gpu vendor from context, if not raise exception. 
gpu_vendor = self.context.ctx["docker_env_vars"]["MAD_GPU_VENDOR"] - n_system_gpus = self.context.ctx['docker_env_vars']['MAD_SYSTEM_NGPUS'] + n_system_gpus = self.context.ctx["docker_env_vars"]["MAD_SYSTEM_NGPUS"] gpu_strings = self.context.ctx["docker_gpus"].split(",") # parsing gpu string, example: '{0-4}' -> [0,1,2,3,4] @@ -399,9 +421,11 @@ def get_gpu_arg(self, requested_gpus: str) -> str: # iterate over the list of gpu strings, split range and append to docker_gpus. for gpu_string in gpu_strings: # check if gpu string has range, if so split and append to docker_gpus. - if '-' in gpu_string: - gpu_range = gpu_string.split('-') - docker_gpus += [item for item in range(int(gpu_range[0]),int(gpu_range[1])+1)] + if "-" in gpu_string: + gpu_range = gpu_string.split("-") + docker_gpus += [ + item for item in range(int(gpu_range[0]), int(gpu_range[1]) + 1) + ] else: docker_gpus.append(int(gpu_string)) # sort docker_gpus @@ -409,30 +433,49 @@ def get_gpu_arg(self, requested_gpus: str) -> str: # Check GPU range is valid for system if requested_gpus == "-1": - print("NGPUS requested is ALL (" + ','.join(map(str, docker_gpus)) + ")." ) + print("NGPUS requested is ALL (" + ",".join(map(str, docker_gpus)) + ").") requested_gpus = len(docker_gpus) - print("NGPUS requested is " + str(requested_gpus) + " out of " + str(n_system_gpus) ) + print( + "NGPUS requested is " + + str(requested_gpus) + + " out of " + + str(n_system_gpus) + ) - if int(requested_gpus) > int(n_system_gpus) or int(requested_gpus) > len(docker_gpus): - raise RuntimeError("Too many gpus requested(" + str(requested_gpus) + "). System has " + str(n_system_gpus) + " gpus. Context has " + str(len(docker_gpus)) + " gpus." ) + if int(requested_gpus) > int(n_system_gpus) or int(requested_gpus) > len( + docker_gpus + ): + raise RuntimeError( + "Too many gpus requested(" + + str(requested_gpus) + + "). System has " + + str(n_system_gpus) + + " gpus. Context has " + + str(len(docker_gpus)) + + " gpus." + ) # Exposing number of requested gpus - self.context.ctx['docker_env_vars']['MAD_RUNTIME_NGPUS'] = str(requested_gpus) + self.context.ctx["docker_env_vars"]["MAD_RUNTIME_NGPUS"] = str(requested_gpus) # Create docker arg to assign requested GPUs if gpu_vendor.find("AMD") != -1: - gpu_arg = '--device=/dev/kfd ' + gpu_arg = "--device=/dev/kfd " - gpu_renderDs = self.context.ctx['gpu_renderDs'] + gpu_renderDs = self.context.ctx["gpu_renderDs"] if gpu_renderDs is not None: for idx in range(0, int(requested_gpus)): - gpu_arg += "--device=/dev/dri/renderD" + str(gpu_renderDs[docker_gpus[idx]]) + " " + gpu_arg += ( + "--device=/dev/dri/renderD" + + str(gpu_renderDs[docker_gpus[idx]]) + + " " + ) elif gpu_vendor.find("NVIDIA") != -1: gpu_str = "" for idx in range(0, int(requested_gpus)): - gpu_str += str( docker_gpus[idx] ) + "," + gpu_str += str(docker_gpus[idx]) + "," gpu_arg += "--gpus '\"device=" + gpu_str + "\"' " else: raise RuntimeError("Unable to determine gpu vendor.") @@ -455,7 +498,7 @@ def get_cpu_arg(self) -> str: return "" # get docker_cpus from context, remove spaces and return cpu arguments. cpus = self.context.ctx["docker_cpus"] - cpus = cpus.replace(" ","") + cpus = cpus.replace(" ", "") return "--cpuset-cpus " + cpus + " " def get_env_arg(self, run_env: typing.Dict) -> str: @@ -481,7 +524,13 @@ def get_env_arg(self, run_env: typing.Dict) -> str: # get docker_env_vars from context, if not return env_args. 
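+ # Descriptive note: each docker_env_vars entry below becomes a
+ # "--env KEY='value'" fragment, e.g. a context containing
+ # {"MAD_GPU_VENDOR": "AMD"} contributes --env MAD_GPU_VENDOR='AMD'
+ # to the docker run options.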
if "docker_env_vars" in self.context.ctx: for env_arg in self.context.ctx["docker_env_vars"].keys(): - env_args += "--env " + env_arg + "='" + str(self.context.ctx["docker_env_vars"][env_arg]) + "' " + env_args += ( + "--env " + + env_arg + + "='" + + str(self.context.ctx["docker_env_vars"][env_arg]) + + "' " + ) print(f"Env arguments: {env_args}") return env_args @@ -506,8 +555,13 @@ def get_mount_arg(self, mount_datapaths: typing.List) -> str: for mount_datapath in mount_datapaths: if mount_datapath: # uses --mount to enforce existence of parent directory; data is mounted readonly by default - mount_args += "-v " + mount_datapath["path"] + ":" + mount_datapath["home"] - if "readwrite" in mount_datapath and mount_datapath["readwrite"] == 'true': + mount_args += ( + "-v " + mount_datapath["path"] + ":" + mount_datapath["home"] + ) + if ( + "readwrite" in mount_datapath + and mount_datapath["readwrite"] == "true" + ): mount_args += " " else: mount_args += ":ro " @@ -517,20 +571,31 @@ def get_mount_arg(self, mount_datapaths: typing.List) -> str: # get docker_mounts from context, if not return mount_args. for mount_arg in self.context.ctx["docker_mounts"].keys(): - mount_args += "-v " + self.context.ctx["docker_mounts"][mount_arg] + ":" + mount_arg + " " + mount_args += ( + "-v " + + self.context.ctx["docker_mounts"][mount_arg] + + ":" + + mount_arg + + " " + ) return mount_args def run_pre_post_script(self, model_docker, model_dir, pre_post): for script in pre_post: script_path = script["path"].strip() - model_docker.sh("cp -vLR --preserve=all " + script_path + " " + model_dir, timeout=600) + model_docker.sh( + "cp -vLR --preserve=all " + script_path + " " + model_dir, timeout=600 + ) script_name = os.path.basename(script_path) script_args = "" if "args" in script: script_args = script["args"] script_args.strip() - model_docker.sh("cd " + model_dir + " && bash " + script_name + " " + script_args , timeout=600) + model_docker.sh( + "cd " + model_dir + " && bash " + script_name + " " + script_args, + timeout=600, + ) def run_model_impl( self, info: typing.Dict, dockerfile: str, run_details: RunDetails @@ -548,7 +613,9 @@ def run_model_impl( if "MAD_CONTAINER_IMAGE" not in self.context.ctx: # build docker image image_docker_name = ( - info["name"].replace("/", "_").lower() # replace / with _ for models in scripts/somedir/ from madengine discover + info["name"] + .replace("/", "_") + .lower() # replace / with _ for models in scripts/somedir/ from madengine discover + "_" + os.path.basename(dockerfile).replace(".Dockerfile", "") ) @@ -584,7 +651,9 @@ def run_model_impl( # get docker image name run_details.docker_image = "ci-" + image_docker_name # get container name - container_name = "container_" + re.sub('.*:','', image_docker_name) # remove docker container hub details + container_name = "container_" + re.sub( + ".*:", "", image_docker_name + ) # remove docker container hub details ## Note: --network=host added to fix issue on CentOS+FBK kernel, where iptables is not available self.console.sh( @@ -611,7 +680,9 @@ def run_model_impl( "docker_build_arg" in self.context.ctx and "BASE_DOCKER" in self.context.ctx["docker_build_arg"] ): - run_details.base_docker = self.context.ctx["docker_build_arg"]["BASE_DOCKER"] + run_details.base_docker = self.context.ctx["docker_build_arg"][ + "BASE_DOCKER" + ] else: run_details.base_docker = self.console.sh( "grep '^ARG BASE_DOCKER=' " @@ -621,15 +692,23 @@ def run_model_impl( print(f"BASE DOCKER is {run_details.base_docker}") # print base docker image digest 
- run_details.docker_sha = self.console.sh("docker manifest inspect " + run_details.base_docker + " | grep digest | head -n 1 | cut -d \\\" -f 4") + run_details.docker_sha = self.console.sh( + "docker manifest inspect " + + run_details.base_docker + + ' | grep digest | head -n 1 | cut -d \\" -f 4' + ) print(f"BASE DOCKER SHA is {run_details.docker_sha}") else: - container_name = "container_" + self.context.ctx["MAD_CONTAINER_IMAGE"].replace("/", "_").replace(":", "_") + container_name = "container_" + self.context.ctx[ + "MAD_CONTAINER_IMAGE" + ].replace("/", "_").replace(":", "_") run_details.docker_image = self.context.ctx["MAD_CONTAINER_IMAGE"] print(f"MAD_CONTAINER_IMAGE is {run_details.docker_image}") - print(f"Warning: User override MAD_CONTAINER_IMAGE. Model support on image not guaranteed.") + print( + f"Warning: User override MAD_CONTAINER_IMAGE. Model support on image not guaranteed." + ) # prepare docker run options gpu_vendor = self.context.ctx["gpu_vendor"] @@ -644,24 +723,37 @@ def run_model_impl( raise RuntimeError("Unable to determine gpu vendor.") # initialize pre, encapsulate and post scripts - pre_encapsulate_post_scripts = {"pre_scripts": [], "encapsulate_script": "", "post_scripts": []} + pre_encapsulate_post_scripts = { + "pre_scripts": [], + "encapsulate_script": "", + "post_scripts": [], + } if "pre_scripts" in self.context.ctx: - pre_encapsulate_post_scripts["pre_scripts"] = self.context.ctx["pre_scripts"] + pre_encapsulate_post_scripts["pre_scripts"] = self.context.ctx[ + "pre_scripts" + ] if "post_scripts" in self.context.ctx: - pre_encapsulate_post_scripts["post_scripts"] = self.context.ctx["post_scripts"] + pre_encapsulate_post_scripts["post_scripts"] = self.context.ctx[ + "post_scripts" + ] if "encapsulate_script" in self.context.ctx: - pre_encapsulate_post_scripts["encapsulate_script"] = self.context.ctx["encapsulate_script"] + pre_encapsulate_post_scripts["encapsulate_script"] = self.context.ctx[ + "encapsulate_script" + ] # get docker run options docker_options += "--env MAD_MODEL_NAME='" + info["name"] + "' " # Since we are doing Jenkins level environment collection in the docker container, pass in the jenkins build number. - docker_options += f"--env JENKINS_BUILD_NUMBER='{os.environ.get('BUILD_NUMBER','0')}' " + docker_options += ( + f"--env JENKINS_BUILD_NUMBER='{os.environ.get('BUILD_NUMBER','0')}' " + ) - # gather data - # TODO: probably can use context.ctx instead of another dictionary like run_env here + # Gather data environment variables + # NOTE: run_env is a separate dictionary for model-specific environment variables. + # Consider refactoring to use context.ctx for better consistency across the codebase. 
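+ # Hypothetical example: a tool that contributes {"LD_PRELOAD": "/opt/tool/lib.so"}
+ # through apply_tools() ends up in run_env; such entries are presumably expanded
+ # into --env flags by get_env_arg(run_env) when the docker run options are built.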
run_env = {} mount_datapaths = None @@ -719,10 +811,16 @@ def run_model_impl( with Timeout(timeout): print(f"") - model_docker = Docker(run_details.docker_image, container_name, docker_options, keep_alive=self.args.keep_alive, console=self.console) + model_docker = Docker( + run_details.docker_image, + container_name, + docker_options, + keep_alive=self.args.keep_alive, + console=self.console, + ) # check that user is root whoami = model_docker.sh("whoami") - print( "USER is " + whoami ) + print("USER is " + whoami) # echo gpu smi info if gpu_vendor.find("AMD") != -1: @@ -737,10 +835,10 @@ def run_model_impl( if "url" in info and info["url"] != "": # model_dir is set to string after the last forwardslash in url field # adding for url field with and without trailing forwardslash (/) - model_dir = info['url'].rstrip('/').split('/')[-1] + model_dir = info["url"].rstrip("/").split("/")[-1] # Validate model_dir to make sure there are no special characters - special_char = r'[^a-zA-Z0-9\-\_]' # allow hyphen and underscore + special_char = r"[^a-zA-Z0-9\-\_]" # allow hyphen and underscore if re.search(special_char, model_dir) is not None: warnings.warn("Model url contains special character. Fix url.") @@ -755,84 +853,133 @@ def run_model_impl( print(f"Using cred for {info['cred']}") if info["cred"] not in self.creds: - raise RuntimeError("Credentials(" + info["cred"] + ") to run model not found in credential.json; Please contact the model owner, " + info["owner"] + ".") - - if info['url'].startswith('ssh://'): - model_docker.sh("git -c core.sshCommand='ssh -l " + self.creds[ info["cred"] ]["username"] + - " -i " + self.creds[ info["cred"] ]["ssh_key_file"] + " -o IdentitiesOnly=yes " + - " -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no' " + - " clone " + info['url'], timeout=240 ) - else: # http or https - model_docker.sh("git clone -c credential.helper='!f() { echo username=" + self.creds[ info["cred"] ]["username"] + \ - "; echo password=" + self.creds[ info["cred"] ]["password"] + "; };f' " + \ - info['url'], timeout=240, secret="git clone " + info['url'] ) + raise RuntimeError( + "Credentials(" + + info["cred"] + + ") to run model not found in credential.json; Please contact the model owner, " + + info["owner"] + + "." 
+ ) + + if info["url"].startswith("ssh://"): + model_docker.sh( + "git -c core.sshCommand='ssh -l " + + self.creds[info["cred"]]["username"] + + " -i " + + self.creds[info["cred"]]["ssh_key_file"] + + " -o IdentitiesOnly=yes " + + " -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no' " + + " clone " + + info["url"], + timeout=240, + ) + else: # http or https + model_docker.sh( + "git clone -c credential.helper='!f() { echo username=" + + self.creds[info["cred"]]["username"] + + "; echo password=" + + self.creds[info["cred"]]["password"] + + "; };f' " + + info["url"], + timeout=240, + secret="git clone " + info["url"], + ) else: model_docker.sh("git clone " + info["url"], timeout=240) # set safe.directory for model directory - model_docker.sh("git config --global --add safe.directory /myworkspace/" + model_dir ) + model_docker.sh( + "git config --global --add safe.directory /myworkspace/" + model_dir + ) # echo git commit - run_details.git_commit = model_docker.sh("cd "+ model_dir + " && git rev-parse HEAD") + run_details.git_commit = model_docker.sh( + "cd " + model_dir + " && git rev-parse HEAD" + ) print(f"MODEL GIT COMMIT is {run_details.git_commit}") # update submodule - model_docker.sh("cd "+ model_dir + "; git submodule update --init --recursive") + model_docker.sh( + "cd " + model_dir + "; git submodule update --init --recursive" + ) else: model_docker.sh("mkdir -p " + model_dir) # add system environment collection script to pre_scripts - if self.args.generate_sys_env_details or self.context.ctx.get("gen_sys_env_details"): - self.gather_system_env_details(pre_encapsulate_post_scripts, info['name']) + if self.args.generate_sys_env_details or self.context.ctx.get( + "gen_sys_env_details" + ): + self.gather_system_env_details( + pre_encapsulate_post_scripts, info["name"] + ) # run pre_scripts if pre_encapsulate_post_scripts["pre_scripts"]: - self.run_pre_post_script(model_docker, model_dir, pre_encapsulate_post_scripts["pre_scripts"]) + self.run_pre_post_script( + model_docker, model_dir, pre_encapsulate_post_scripts["pre_scripts"] + ) - scripts_arg = info['scripts'] + scripts_arg = info["scripts"] dir_path = None script_name = None if scripts_arg.endswith(".sh"): dir_path = os.path.dirname(scripts_arg) script_name = "bash " + os.path.basename(scripts_arg) else: - dir_path = info['scripts'] + dir_path = info["scripts"] script_name = "bash run.sh" # add script_prepend_cmd - script_name = pre_encapsulate_post_scripts["encapsulate_script"] + " " + script_name + script_name = ( + pre_encapsulate_post_scripts["encapsulate_script"] + " " + script_name + ) # print repo hash - commit = model_docker.sh("cd "+ dir_path +"; git rev-parse HEAD || true ") + commit = model_docker.sh( + "cd " + dir_path + "; git rev-parse HEAD || true " + ) print("======================================================") - print("MODEL REPO COMMIT: ", commit ) + print("MODEL REPO COMMIT: ", commit) print("======================================================") # copy scripts to model directory - model_docker.sh("cp -vLR --preserve=all "+ dir_path +"/. "+ model_dir +"/") + model_docker.sh( + "cp -vLR --preserve=all " + dir_path + "/. 
" + model_dir + "/" + ) # prepare data inside container - if 'data' in info and info['data'] != "": - self.data.prepare_data( info['data'], model_docker ) + if "data" in info and info["data"] != "": + self.data.prepare_data(info["data"], model_docker) # Capture data provider information from selected_data_provider - if hasattr(self.data, 'selected_data_provider') and self.data.selected_data_provider: - if 'dataname' in self.data.selected_data_provider: - run_details.dataname = self.data.selected_data_provider['dataname'] - if 'data_provider_type' in self.data.selected_data_provider: - run_details.data_provider_type = self.data.selected_data_provider['data_provider_type'] - if 'duration' in self.data.selected_data_provider: - run_details.data_download_duration = self.data.selected_data_provider['duration'] - if 'size' in self.data.selected_data_provider: - run_details.data_size = self.data.selected_data_provider['size'] - print(f"Data Provider Details: {run_details.dataname}, {run_details.data_provider_type}, {run_details.data_size}, {run_details.data_download_duration}s") + if ( + hasattr(self.data, "selected_data_provider") + and self.data.selected_data_provider + ): + if "dataname" in self.data.selected_data_provider: + run_details.dataname = self.data.selected_data_provider[ + "dataname" + ] + if "data_provider_type" in self.data.selected_data_provider: + run_details.data_provider_type = ( + self.data.selected_data_provider["data_provider_type"] + ) + if "duration" in self.data.selected_data_provider: + run_details.data_download_duration = ( + self.data.selected_data_provider["duration"] + ) + if "size" in self.data.selected_data_provider: + run_details.data_size = self.data.selected_data_provider["size"] + print( + f"Data Provider Details: {run_details.dataname}, {run_details.data_provider_type}, {run_details.data_size}, {run_details.data_download_duration}s" + ) selected_data_provider = { "node_name": run_details.machine_name, - "build_number": os.environ.get('BUILD_NUMBER','0'), - "model_name": info["name"] if "name" in info else "" + "build_number": os.environ.get("BUILD_NUMBER", "0"), + "model_name": info["name"] if "name" in info else "", } # Set build number in run_details - run_details.build_number = os.environ.get('BUILD_NUMBER','0') + run_details.build_number = os.environ.get("BUILD_NUMBER", "0") print(f"Build Info::{selected_data_provider}") @@ -875,14 +1022,22 @@ def run_model_impl( # run post_scripts if pre_encapsulate_post_scripts["post_scripts"]: - self.run_pre_post_script(model_docker, model_dir, pre_encapsulate_post_scripts["post_scripts"]) + self.run_pre_post_script( + model_docker, + model_dir, + pre_encapsulate_post_scripts["post_scripts"], + ) # remove model directory if not self.args.keep_alive and not self.args.keep_model_dir: model_docker.sh("rm -rf " + model_dir, timeout=240) else: model_docker.sh("chmod -R a+rw " + model_dir) - print("keep_alive is specified; model_dir(" + model_dir + ") is not removed") + print( + "keep_alive is specified; model_dir(" + + model_dir + + ") is not removed" + ) # explicitly delete model docker to stop the container, without waiting for the in-built garbage collector del model_docker @@ -909,12 +1064,16 @@ def run_model(self, model_info: typing.Dict) -> bool: run_details.training_precision = model_info["training_precision"] run_details.args = model_info["args"] run_details.tags = model_info["tags"] - run_details.additional_docker_run_options = model_info.get("additional_docker_run_options", "") + 
run_details.additional_docker_run_options = model_info.get( + "additional_docker_run_options", "" + ) # gets pipeline variable from jenkinsfile, default value is none run_details.pipeline = os.environ.get("pipeline") # Taking gpu arch from context assumes the host image and container have the same gpu arch. # Environment variable updates for MAD Public CI - run_details.gpu_architecture = self.context.ctx["docker_env_vars"]["MAD_SYSTEM_GPU_ARCHITECTURE"] + run_details.gpu_architecture = self.context.ctx["docker_env_vars"][ + "MAD_SYSTEM_GPU_ARCHITECTURE" + ] # Check the setting of shared memory size if "SHM_SIZE" in self.context.ctx: @@ -927,7 +1086,9 @@ def run_model(self, model_info: typing.Dict) -> bool: if model_info.get("is_deprecated", False): print(f"WARNING: Model {model_info['name']} has been deprecated.") if self.args.ignore_deprecated_flag: - print(f"WARNING: Running deprecated model {model_info['name']} due to --ignore-deprecated-flag.") + print( + f"WARNING: Running deprecated model {model_info['name']} due to --ignore-deprecated-flag." + ) else: print(f"WARNING: Skipping execution. No bypass flags mentioned.") return True # exit early @@ -954,7 +1115,9 @@ def run_model(self, model_info: typing.Dict) -> bool: run_details.status = "SKIPPED" # generate exception for testing run_details.generate_json("perf_entry.json") - update_perf_csv(exception_result="perf_entry.json", perf_csv=self.args.output) + update_perf_csv( + exception_result="perf_entry.json", perf_csv=self.args.output + ) else: print( f"Running model {run_details.model} on {run_details.gpu_architecture} architecture." @@ -984,7 +1147,10 @@ def run_model(self, model_info: typing.Dict) -> bool: # check if dockerfiles are found, if not raise exception. if not dockerfiles: - raise Exception("No dockerfiles matching context found for model " + run_details.model) + raise Exception( + "No dockerfiles matching context found for model " + + run_details.model + ) # run dockerfiles for cur_docker_file in dockerfiles.keys(): @@ -1001,7 +1167,7 @@ def run_model(self, model_info: typing.Dict) -> bool: try: # generate exception for testing - if model_info['args'] == "--exception": + if model_info["args"] == "--exception": raise Exception("Exception test!") print(f"Processing Dockerfile: {cur_docker_file}") @@ -1018,40 +1184,64 @@ def run_model(self, model_info: typing.Dict) -> bool: log_file_path = log_file_path.replace("/", "_") with open(log_file_path, mode="w", buffering=1) as outlog: - with redirect_stdout(PythonicTee(outlog, self.args.live_output)), redirect_stderr(PythonicTee(outlog, self.args.live_output)): - self.run_model_impl(model_info, cur_docker_file, run_details) + with redirect_stdout( + PythonicTee(outlog, self.args.live_output) + ), redirect_stderr( + PythonicTee(outlog, self.args.live_output) + ): + self.run_model_impl( + model_info, cur_docker_file, run_details + ) if self.args.skip_model_run: # move to next dockerfile continue # Check if we are looking for a single result or multiple. 
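+ # Descriptive note: the regexes further below expect the model log to contain a
+ # line of the form "performance: <number> <unit>"; e.g. a hypothetical
+ # "performance: 1234.5 samples_per_second" yields performance "1234.5" and
+ # metric "samples_per_second".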
- multiple_results = (None if "multiple_results" not in model_info else model_info["multiple_results"]) + multiple_results = ( + None + if "multiple_results" not in model_info + else model_info["multiple_results"] + ) # get performance metric from log if multiple_results: run_details.performance = multiple_results else: - perf_regex = ".*performance:\\s*\\([+|-]\?[0-9]*[.]\\?[0-9]*\(e[+|-]\?[0-9]\+\)\?\\)\\s*.*\\s*" - run_details.performance = self.console.sh("cat " + log_file_path + - " | sed -n 's/" + perf_regex + "/\\1/p'") + perf_regex = ".*performance:\\s*\\([+|-]\\?[0-9]*[.]\\?[0-9]*\\(e[+|-]\\?[0-9]\\+\\)\\?\\)\\s*.*\\s*" + run_details.performance = self.console.sh( + "cat " + + log_file_path + + " | sed -n 's/" + + perf_regex + + "/\\1/p'" + ) - metric_regex = ".*performance:\\s*[+|-]\?[0-9]*[.]\\?[0-9]*\(e[+|-]\?[0-9]\+\)\?\\s*\\(\\w*\\)\\s*" - run_details.metric = self.console.sh("cat " + log_file_path + - " | sed -n 's/" + metric_regex + "/\\2/p'") + metric_regex = ".*performance:\\s*[+|-]\\?[0-9]*[.]\\?[0-9]*\\(e[+|-]\\?[0-9]\\+\\)\\?\\s*\\(\\w*\\)\\s*" + run_details.metric = self.console.sh( + "cat " + + log_file_path + + " | sed -n 's/" + + metric_regex + + "/\\2/p'" + ) # check if model passed or failed - run_details.status = 'SUCCESS' if run_details.performance else 'FAILURE' + run_details.status = ( + "SUCCESS" if run_details.performance else "FAILURE" + ) # print stage perf results run_details.print_perf() # add result to output if multiple_results: - run_details.generate_json("common_info.json", multiple_results=True) + run_details.generate_json( + "common_info.json", multiple_results=True + ) update_perf_csv( - multiple_results=model_info['multiple_results'], + multiple_results=model_info["multiple_results"], perf_csv=self.args.output, model_name=run_details.model, common_info="common_info.json", @@ -1063,15 +1253,15 @@ def run_model(self, model_info: typing.Dict) -> bool: perf_csv=self.args.output, ) - self.return_status &= (run_details.status == 'SUCCESS') + self.return_status &= run_details.status == "SUCCESS" except Exception as e: self.return_status = False - print( "===== EXCEPTION =====") - print( "Exception: ", e ) + print("===== EXCEPTION =====") + print("Exception: ", e) traceback.print_exc() - print( "=============== =====") + print("=============== =====") run_details.status = "FAILURE" run_details.generate_json("perf_entry.json") update_perf_csv( @@ -1082,10 +1272,10 @@ def run_model(self, model_info: typing.Dict) -> bool: except Exception as e: self.return_status = False - print( "===== EXCEPTION =====") - print( "Exception: ", e ) + print("===== EXCEPTION =====") + print("Exception: ", e) traceback.print_exc() - print( "=============== =====") + print("=============== =====") run_details.status = "FAILURE" run_details.generate_json("perf_entry.json") update_perf_csv( @@ -1163,7 +1353,7 @@ def run(self) -> bool: if self.return_status: print("All models ran successfully.") else: - print( "===== EXCEPTION =====") + print("===== EXCEPTION =====") print("Some models failed to run.") return self.return_status diff --git a/src/madengine/tools/update_perf_csv.py b/src/madengine/tools/update_perf_csv.py index 5e32e3e2..08285dd1 100644 --- a/src/madengine/tools/update_perf_csv.py +++ b/src/madengine/tools/update_perf_csv.py @@ -9,16 +9,17 @@ import json import argparse import typing + # third-party imports import pandas as pd def df_strip_columns(df: pd.DataFrame) -> pd.DataFrame: """Strip the column names of a DataFrame. 
- + Args: df: The DataFrame to strip the column names of. - + Returns: The DataFrame with stripped column names. """ @@ -28,10 +29,10 @@ def df_strip_columns(df: pd.DataFrame) -> pd.DataFrame: def read_json(js: str) -> dict: """Read a JSON file. - + Args: js: The path to the JSON file. - + Returns: The JSON dictionary. """ @@ -42,7 +43,7 @@ def read_json(js: str) -> dict: def flatten_tags(perf_entry: dict): """Flatten the tags of a performance entry. - + Args: perf_entry: The performance entry. @@ -56,7 +57,7 @@ def flatten_tags(perf_entry: dict): def perf_entry_df_to_csv(perf_entry: pd.DataFrame) -> None: """Write the performance entry DataFrame to a CSV file. - + Args: perf_entry: The performance entry DataFrame. @@ -68,7 +69,7 @@ def perf_entry_df_to_csv(perf_entry: pd.DataFrame) -> None: def perf_entry_dict_to_csv(perf_entry: typing.Dict) -> None: """Write the performance entry dictionary to a CSV file. - + Args: perf_entry: The performance entry dictionary. """ @@ -78,22 +79,19 @@ def perf_entry_dict_to_csv(perf_entry: typing.Dict) -> None: def handle_multiple_results( - perf_csv_df: pd.DataFrame, - multiple_results: str, - common_info: str, - model_name: str - ) -> pd.DataFrame: + perf_csv_df: pd.DataFrame, multiple_results: str, common_info: str, model_name: str +) -> pd.DataFrame: """Handle multiple results. - + Args: perf_csv_df: The performance csv DataFrame. multiple_results: The path to the multiple results CSV file. common_info: The path to the common info JSON file. model_name: The model name. - + Returns: The updated performance csv DataFrame. - + Raises: AssertionError: If the number of columns in the performance csv DataFrame is not equal to the length of the row. """ @@ -141,16 +139,13 @@ def handle_multiple_results( return perf_csv_df -def handle_single_result( - perf_csv_df: pd.DataFrame, - single_result: str - ) -> pd.DataFrame: +def handle_single_result(perf_csv_df: pd.DataFrame, single_result: str) -> pd.DataFrame: """Handle a single result. - + Args: perf_csv_df: The performance csv DataFrame. single_result: The path to the single result JSON file. - + Returns: The updated performance csv DataFrame. @@ -169,15 +164,14 @@ def handle_single_result( def handle_exception_result( - perf_csv_df: pd.DataFrame, - exception_result: str - ) -> pd.DataFrame: + perf_csv_df: pd.DataFrame, exception_result: str +) -> pd.DataFrame: """Handle an exception result. - + Args: perf_csv_df: The performance csv DataFrame. exception_result: The path to the exception result JSON file. - + Returns: The updated performance csv DataFrame. 
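+
+ Illustrative behaviour (hedged): the exception JSON is a single perf entry,
+ typically produced by run_models with status "FAILURE", which gets appended
+ to the DataFrame much like a single result.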
@@ -196,20 +190,25 @@ def handle_exception_result(
 def update_perf_csv(
-    perf_csv: str,
-    multiple_results: typing.Optional[str] = None,
-    single_result: typing.Optional[str] = None,
-    exception_result: typing.Optional[str] = None,
-    common_info: typing.Optional[str] = None,
-    model_name: typing.Optional[str] = None,
-    ):
+    perf_csv: str,
+    multiple_results: typing.Optional[str] = None,
+    single_result: typing.Optional[str] = None,
+    exception_result: typing.Optional[str] = None,
+    common_info: typing.Optional[str] = None,
+    model_name: typing.Optional[str] = None,
+):
     """Update the performance csv file with the latest performance data."""
-    print(f"Attaching performance metrics of models to perf.csv")
+    print("\n" + "=" * 80)
+    print("📈 ATTACHING PERFORMANCE METRICS TO DATABASE")
+    print("=" * 80)
+    print(f"📂 Target file: {perf_csv}")
+
     # read perf.csv
     perf_csv_df = df_strip_columns(pd.read_csv(perf_csv))
     # handle multiple_results, single_result, and exception_result
     if multiple_results:
+        print("🔄 Processing multiple results...")
         perf_csv_df = handle_multiple_results(
             perf_csv_df,
             multiple_results,
@@ -217,17 +216,19 @@
             common_info,
             model_name,
         )
     elif single_result:
+        print("🔄 Processing single result...")
         perf_csv_df = handle_single_result(perf_csv_df, single_result)
     elif exception_result:
-        perf_csv_df = handle_exception_result(
-            perf_csv_df, exception_result
-        )
+        print("⚠️ Processing exception result...")
+        perf_csv_df = handle_exception_result(perf_csv_df, exception_result)
     else:
-        print("No results to update in perf.csv")
+        print("ℹ️ No results to update in perf.csv")
     # write new perf.csv
     # Note that this file will also generate a perf_entry.csv regardless of the output file args.
     perf_csv_df.to_csv(perf_csv, index=False)
+    print(f"✅ Successfully updated: {perf_csv}")
+    print("=" * 80 + "\n")


 class UpdatePerfCsv:
@@ -247,12 +249,17 @@ def __init__(self, args: argparse.Namespace):
     def run(self):
         """Update the performance csv file with the latest performance data."""
-        print(f"Updating performance metrics of models perf.csv to database")
+        print("\n" + "=" * 80)
+        print("📊 UPDATING PERFORMANCE METRICS DATABASE")
+        print("=" * 80)
+        print(f"📂 Processing: {self.args.perf_csv}")
+
         # read perf.csv
         perf_csv_df = df_strip_columns(pd.read_csv(self.args.perf_csv))
         # handle multiple_results, single_result, and exception_result
         if self.args.multiple_results:
+            print("🔄 Processing multiple results...")
             perf_csv_df = handle_multiple_results(
                 perf_csv_df,
                 self.args.multiple_results,
@@ -260,17 +267,22 @@
                 self.args.model_name,
             )
         elif self.args.single_result:
+            print("🔄 Processing single result...")
             perf_csv_df = handle_single_result(perf_csv_df, self.args.single_result)
         elif self.args.exception_result:
+            print("⚠️ Processing exception result...")
             perf_csv_df = handle_exception_result(
                 perf_csv_df, self.args.exception_result
             )
         else:
-            print("No results to update in perf.csv")
+            print("ℹ️ No results to update in perf.csv")
         # write new perf.csv
         # Note that this file will also generate a perf_entry.csv regardless of the output file args.
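(Reviewer aside, not part of the patch.) Conceptually, `update_perf_csv` reads the existing `perf.csv`, folds in whichever result type it was given (multiple results, a single result, or an exception entry), and writes the file back once. The single-result path is not shown in full in this diff, so the following is only a simplified pandas sketch of that flow, with invented file names:

```python
import json

import pandas as pd


def append_single_result(perf_csv: str, single_result_json: str) -> None:
    """Simplified sketch: append one result entry to perf.csv."""
    perf_df = pd.read_csv(perf_csv)
    perf_df.columns = perf_df.columns.str.strip()  # same idea as df_strip_columns

    with open(single_result_json) as f:
        entry = json.load(f)  # e.g. {"model": "dummy", "performance": 1.0, ...}

    perf_df = pd.concat([perf_df, pd.DataFrame([entry])], ignore_index=True)
    perf_df.to_csv(perf_csv, index=False)  # single write at the end
```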
perf_csv_df.to_csv(self.args.perf_csv, index=False) + print(f"✅ Successfully updated: {self.args.perf_csv}") + print("=" * 80 + "\n") + self.return_status = True return self.return_status diff --git a/src/madengine/tools/update_table_db.py b/src/madengine/tools/update_table_db.py index a71bde87..06c82be3 100644 --- a/src/madengine/tools/update_table_db.py +++ b/src/madengine/tools/update_table_db.py @@ -10,9 +10,11 @@ import argparse import subprocess import typing + # third-party modules import paramiko import socket + # MAD Engine modules from madengine.utils.ssh_to_db import SFTPClient, print_ssh_out from madengine.db.logger import setup_logger @@ -26,9 +28,10 @@ class UpdateTable: """Class to update tables in the database. - + This class provides the functions to update tables in the database. """ + def __init__(self, args: argparse.Namespace): """Initialize the UpdateTable class. @@ -44,14 +47,14 @@ def __init__(self, args: argparse.Namespace): self.ssh_user = ENV_VARS["ssh_user"] self.ssh_password = ENV_VARS["ssh_password"] self.ssh_hostname = ENV_VARS["ssh_hostname"] - self.ssh_port = ENV_VARS["ssh_port"] + self.ssh_port = ENV_VARS["ssh_port"] # get the db folder self.db_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../db") - LOGGER.info(f"DB path: {self.db_path}") + LOGGER.info(f"DB path: {self.db_path}") self.status = False - def run(self, table_name: str='dlm_table') -> None: + def run(self, table_name: str = "dlm_table") -> None: """Update a table in the database. Args: @@ -59,13 +62,13 @@ def run(self, table_name: str='dlm_table') -> None: Returns: None - + Raises: Exception: An error occurred updating the table. """ print(f"Updating table {table_name} in the database") - if 'localhost' in self.ssh_hostname or '127.0.0.1' in self.ssh_hostname: + if "localhost" in self.ssh_hostname or "127.0.0.1" in self.ssh_hostname: try: self.local_db() self.status = True @@ -75,18 +78,18 @@ def run(self, table_name: str='dlm_table') -> None: return self.status else: try: - self.remote_db() + self.remote_db() self.status = True - return self.status + return self.status except Exception as error: LOGGER.error(f"Error updating table in the remote database: {error}") return self.status def local_db(self) -> None: """Update a table in the local database. - + This function updates a table in the local database. 
- + Returns: None @@ -99,34 +102,45 @@ def local_db(self) -> None: cmd_list = ["cp", "-r", self.db_path, "."] try: - ret = subprocess.Popen(cmd_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + ret = subprocess.Popen( + cmd_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) out, err = ret.communicate() if ret.returncode == 0: if out: - LOGGER.info(out.decode('utf-8')) + LOGGER.info(out.decode("utf-8")) print("Copied scripts to current work path") else: if err: - LOGGER.error(err.decode('utf-8')) + LOGGER.error(err.decode("utf-8")) except Exception as e: LOGGER.error(f"An error occurred: {e}") # run upload_csv_to_db.py in the db folder with environment variables using subprocess Popen - cmd_list = ["python3", "./db/upload_csv_to_db.py", "--csv-file-path", self.args.csv_file_path] + cmd_list = [ + "python3", + "./db/upload_csv_to_db.py", + "--csv-file-path", + self.args.csv_file_path, + ] # Ensure ENV_VARS is a dictionary env_vars = dict(ENV_VARS) try: - ret = subprocess.Popen(cmd_list, env=env_vars, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + ret = subprocess.Popen( + cmd_list, env=env_vars, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) out, err = ret.communicate() if ret.returncode == 0: if out: - LOGGER.info(out.decode('utf-8')) + LOGGER.info(out.decode("utf-8")) else: if err: - LOGGER.error(err.decode('utf-8')) - raise Exception(f"Error updating table in the local database: {err.decode('utf-8')}") + LOGGER.error(err.decode("utf-8")) + raise Exception( + f"Error updating table in the local database: {err.decode('utf-8')}" + ) except Exception as e: LOGGER.error(f"An error occurred: {e}") @@ -134,9 +148,9 @@ def local_db(self) -> None: def remote_db(self) -> None: """Update a table in the remote database. - + This function updates a table in the remote database. - + Returns: None @@ -182,7 +196,9 @@ def remote_db(self) -> None: print(upload_script_path_remote, csv_file_path_remote, model_json_path_remote) # clean up previous uploads - print_ssh_out(ssh_client.exec_command("rm -rf {}".format(upload_script_path_remote))) + print_ssh_out( + ssh_client.exec_command("rm -rf {}".format(upload_script_path_remote)) + ) print_ssh_out(ssh_client.exec_command("rm -rf {}".format(csv_file_path_remote))) # upload file diff --git a/src/madengine/tools/upload_mongodb.py b/src/madengine/tools/upload_mongodb.py index 6766e3e2..9d375a32 100644 --- a/src/madengine/tools/upload_mongodb.py +++ b/src/madengine/tools/upload_mongodb.py @@ -22,9 +22,10 @@ # Create the logger LOGGER = setup_logger() + class MongoDBHandler: """Class to handle MongoDB operations.""" - + def __init__(self, args: argparse.Namespace) -> None: """Initialize the MongoDBHandler. @@ -56,7 +57,7 @@ def connect(self) -> None: def collection_exists(self) -> bool: """Check if a collection exists in the database. - + Returns: bool: True if the collection exists, False otherwise. """ @@ -69,7 +70,9 @@ def update_collection(self, data: pd.DataFrame) -> None: data (pd.DataFrame): DataFrame containing the data to update. """ if not self.collection_exists(): - LOGGER.info(f"Collection '{self.collection_name}' does not exist. Creating it.") + LOGGER.info( + f"Collection '{self.collection_name}' does not exist. Creating it." 
+ ) self.db.create_collection(self.collection_name) collection = self.db[self.collection_name] @@ -77,11 +80,12 @@ def update_collection(self, data: pd.DataFrame) -> None: for record in records: # Use an appropriate unique identifier for upsert (e.g., "_id" or another field) collection.update_one(record, {"$set": record}, upsert=True) - LOGGER.info(f"Updated collection '{self.collection_name}' with {len(records)} records.") + LOGGER.info( + f"Updated collection '{self.collection_name}' with {len(records)} records." + ) def run(self) -> None: - """Run the process of updating a MongoDB collection with data from a CSV file. - """ + """Run the process of updating a MongoDB collection with data from a CSV file.""" self.connect() data = load_csv_to_dataframe(self.csv_file_path) @@ -97,7 +101,7 @@ def run(self) -> None: # Remove any leading or trailing whitespace from column names data.columns = data.columns.str.strip() - + self.update_collection(data) diff --git a/src/madengine/utils/log_formatting.py b/src/madengine/utils/log_formatting.py new file mode 100644 index 00000000..31673c93 --- /dev/null +++ b/src/madengine/utils/log_formatting.py @@ -0,0 +1,241 @@ +#!/usr/bin/env python3 +""" +Utility functions for formatting and displaying data in logs. + +This module provides enhanced formatting utilities for better log readability, +including dataframe formatting and other display utilities. + +Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +""" + +import pandas as pd +import typing +from rich.table import Table +from rich.console import Console as RichConsole +from rich.text import Text + + +def format_dataframe_for_log( + df: pd.DataFrame, title: str = "DataFrame", max_rows: int = 20, max_cols: int = 10 +) -> str: + """ + Format a pandas DataFrame for beautiful log output. 
+ + Args: + df: The pandas DataFrame to format + title: Title for the dataframe display + max_rows: Maximum number of rows to display (if None, use all rows) + max_cols: Maximum number of columns to display + + Returns: + str: Beautifully formatted string representation of the DataFrame + """ + if df.empty: + return f"\n📊 {title}\n{'='*60}\n❌ DataFrame is empty\n{'='*60}\n" + + # Define key columns to display for performance results + key_columns = [ + "model", + "n_gpus", + "docker_file", + "machine_name", + "gpu_architecture", + "performance", + "metric", + "status", + "dataname", + ] + + # Filter DataFrame to show only key columns that exist + available_columns = [col for col in key_columns if col in df.columns] + if available_columns: + display_df = df[available_columns].copy() + total_columns_note = ( + f"(showing {len(available_columns)} of {len(df.columns)} columns)" + ) + else: + # If no key columns found, show all columns as fallback with truncation + display_df = df.copy() + total_columns_note = f"(showing all {len(df.columns)} columns)" + if len(df.columns) > max_cols: + display_df = display_df.iloc[:, :max_cols] + total_columns_note = ( + f"(showing first {max_cols} of {len(df.columns)} columns)" + ) + + # Use all rows if max_rows is None + if max_rows is None: + max_rows = len(display_df) + + # Truncate rows if necessary (show latest rows) + truncated_rows = False + if len(display_df) > max_rows: + display_df = display_df.tail(max_rows) + truncated_rows = True + + # Create header + header = f"\n📊 {title} {total_columns_note}\n" + header += f"{'='*80}\n" + if available_columns: + header += f"📏 Shape: {df.shape[0]} rows × {len(available_columns)} key columns (total: {df.shape[1]} columns)\n" + else: + header += f"📏 Shape: {df.shape[0]} rows × {df.shape[1]} columns\n" + + if truncated_rows: + header += f"⚠️ Display truncated: showing first {max_rows} rows\n" + + header += f"{'='*80}\n" + + # Format the DataFrame with nice styling + formatted_df = display_df.to_string( + index=True, max_rows=max_rows, width=None, float_format="{:.4f}".format + ) + + # Add some visual separators + footer = f"\n{'='*80}\n" + + return header + formatted_df + footer + + +def format_dataframe_rich( + df: pd.DataFrame, title: str = "DataFrame", max_rows: int = 20 +) -> None: + """ + Display a pandas DataFrame using Rich formatting for enhanced readability. 
+ + Args: + df: The pandas DataFrame to display + title: Title for the table + max_rows: Maximum number of rows to display + """ + console = RichConsole() + + if df.empty: + console.print( + f"📊 [bold cyan]{title}[/bold cyan]: [red]DataFrame is empty[/red]" + ) + return + + # Define key columns to display for performance results + key_columns = [ + "model", + "n_gpus", + "machine_name", + "gpu_architecture", + "performance", + "metric", + "status", + "dataname", + ] + + # Filter DataFrame to show only key columns that exist + available_columns = [col for col in key_columns if col in df.columns] + if available_columns: + display_df = df[available_columns] + total_columns_note = ( + f"(showing {len(available_columns)} of {len(df.columns)} columns)" + ) + else: + # If no key columns found, show all columns as fallback + display_df = df + total_columns_note = f"(showing all {len(df.columns)} columns)" + + # Create Rich table + table = Table( + title=f"📊 {title} {total_columns_note}", + show_header=True, + header_style="bold magenta", + ) + + # Add index column + table.add_column("Index", style="dim", width=8) + + # Add data columns + for col in display_df.columns: + table.add_column(str(col), style="cyan") + + # Add rows (truncate if necessary, show latest rows) + if len(display_df) > max_rows: + truncated_df = display_df.tail(max_rows) + truncated_indices = truncated_df.index + display_rows = max_rows + else: + truncated_df = display_df + truncated_indices = truncated_df.index + display_rows = len(truncated_df) + + for i in range(display_rows): + row_data = [str(truncated_indices[i])] + for col in truncated_df.columns: + value = truncated_df.iloc[i][col] + if pd.isna(value): + row_data.append("[dim]NaN[/dim]") + elif isinstance(value, float): + row_data.append(f"{value:.4f}") + else: + row_data.append(str(value)) + table.add_row(*row_data) + + # Show truncation info + if len(display_df) > max_rows: + table.add_row(*["..." for _ in range(len(truncated_df.columns) + 1)]) + console.print( + f"[yellow]⚠️ Showing latest {max_rows} of {len(display_df)} rows[/yellow]" + ) + + console.print(table) + console.print( + f"[green]✨ DataFrame shape: {df.shape[0]} rows × {len(available_columns)} key columns (total: {df.shape[1]} columns)[/green]" + ) + + +def print_dataframe_beautiful( + df: pd.DataFrame, title: str = "Data", use_rich: bool = True +) -> None: + """ + Print a pandas DataFrame with beautiful formatting. + + Args: + df: The pandas DataFrame to print + title: Title for the display + use_rich: Whether to use Rich formatting (if available) or fall back to simple formatting + """ + try: + if use_rich: + format_dataframe_rich(df, title) + else: + raise ImportError("Fallback to simple formatting") + except (ImportError, Exception): + # Fallback to simple but nice formatting + formatted_output = format_dataframe_for_log(df, title) + print(formatted_output) + + +def highlight_log_section(title: str, content: str, style: str = "info") -> str: + """ + Create a highlighted log section with borders and styling. 
+ + Args: + title: Section title + content: Section content + style: Style type ('info', 'success', 'warning', 'error') + + Returns: + str: Formatted log section + """ + styles = { + "info": {"emoji": "ℹ️", "border": "-"}, + "success": {"emoji": "✅", "border": "="}, + "warning": {"emoji": "⚠️", "border": "!"}, + "error": {"emoji": "❌", "border": "#"}, + } + + style_config = styles.get(style, styles["info"]) + emoji = style_config["emoji"] + border_char = style_config["border"] + + border = border_char * 80 + header = f"\n{border}\n{emoji} {title.upper()}\n{border}" + footer = f"{border}\n" + + return f"{header}\n{content}\n{footer}" diff --git a/src/madengine/utils/ops.py b/src/madengine/utils/ops.py index 4a0f6a45..7b32ec9f 100644 --- a/src/madengine/utils/ops.py +++ b/src/madengine/utils/ops.py @@ -54,17 +54,15 @@ def flush(self) -> None: def find_and_replace_pattern( - dictionary: typing.Dict, - substring: str, - replacement: str - ) -> typing.Dict: + dictionary: typing.Dict, substring: str, replacement: str +) -> typing.Dict: """Find and replace a substring in a dictionary. - + Args: dictionary: The dictionary. substring: The substring to find. replacement: The replacement string. - + Returns: The updated dictionary. """ @@ -78,16 +76,13 @@ def find_and_replace_pattern( return updated_dict -def substring_found( - dictionary: typing.Dict, - substring: str - ) -> bool: +def substring_found(dictionary: typing.Dict, substring: str) -> bool: """Check if a substring is found in the dictionary. - + Args: dictionary: The dictionary. substring: The substring to find. - + Returns: True if the substring is found, False otherwise. """ diff --git a/src/madengine/utils/ssh_to_db.py b/src/madengine/utils/ssh_to_db.py index c5f694fa..255ae58a 100644 --- a/src/madengine/utils/ssh_to_db.py +++ b/src/madengine/utils/ssh_to_db.py @@ -4,9 +4,11 @@ Copyright (c) Advanced Micro Devices, Inc. All rights reserved. """ + # built-in modules import os import socket + # third-party modules import paramiko @@ -65,10 +67,10 @@ def mkdir(self, path: str, mode: int = 511, ignore_existing: bool = False) -> No def print_ssh_out(client_output: tuple) -> None: """Print the output from the SSH client. - + Args: client_output (tuple): The output from the SSH client. 
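(Reviewer aside, not part of the patch.) The new `madengine.utils.log_formatting` module added above is self-contained, so a short usage sketch may help; the DataFrame contents and titles here are invented:

```python
import pandas as pd

from madengine.utils.log_formatting import highlight_log_section, print_dataframe_beautiful

# Toy results table whose columns overlap the module's "key columns".
df = pd.DataFrame(
    {
        "model": ["dummy", "dummy_multi"],
        "performance": [123.4, 567.8],
        "metric": ["samples_per_sec", "samples_per_sec"],
        "status": ["SUCCESS", "SUCCESS"],
    }
)

# Rich table when available, plain-text formatting otherwise.
print_dataframe_beautiful(df, title="Perf results", use_rich=True)

# Bordered section for a plain log file.
print(highlight_log_section("run summary", "2 models ran successfully", style="success"))
```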
- + Returns: None """ diff --git a/tests/fixtures/dummy/credential.json b/tests/fixtures/dummy/credential.json index 1b8a56df..b53e0597 100644 --- a/tests/fixtures/dummy/credential.json +++ b/tests/fixtures/dummy/credential.json @@ -17,5 +17,15 @@ "PASSWORD": "admin-secret-key", "MINIO_ENDPOINT": "http://127.0.1:9000", "AWS_ENDPOINT_URL_S3": "http://127.0.1:9000" - } + }, + "dockerhub": { + "repository": "your-repository", + "username": "your-dockerhub-username", + "password": "your-dockerhub-password-or-token" + }, + "localhost:5000": { + "repository": "your-repository", + "username": "your-local-registry-username", + "password": "your-local-registry-password" + } } \ No newline at end of file diff --git a/tests/fixtures/utils.py b/tests/fixtures/utils.py index 617c305d..1e9f7d49 100644 --- a/tests/fixtures/utils.py +++ b/tests/fixtures/utils.py @@ -13,11 +13,9 @@ import pytest from unittest.mock import MagicMock - MODEL_DIR = "tests/fixtures/dummy" BASE_DIR = os.path.join(os.path.dirname(__file__), "..", "..") sys.path.insert(1, BASE_DIR) -print(f'BASE DIR:: {BASE_DIR}') # Cache variables to avoid repeated system checks during collection _gpu_vendor_cache = None @@ -25,11 +23,60 @@ _num_gpus_cache = None _num_cpus_cache = None +# GPU detection cache to avoid multiple expensive calls +_has_gpu_cache = None + + +def has_gpu() -> bool: + """Simple function to check if GPU is available for testing. + + This is the primary function for test skipping decisions. + Uses caching to avoid repeated expensive detection calls. + + Returns: + bool: True if GPU is available, False if CPU-only machine + """ + global _has_gpu_cache + + if _has_gpu_cache is not None: + return _has_gpu_cache + + try: + # Ultra-simple file existence check (no subprocess calls) + # This is safe for pytest collection and avoids hanging + nvidia_exists = os.path.exists("/usr/bin/nvidia-smi") + amd_rocm_exists = os.path.exists("/opt/rocm/bin/rocm-smi") or os.path.exists( + "/usr/local/bin/rocm-smi" + ) + + _has_gpu_cache = nvidia_exists or amd_rocm_exists + + except Exception: + # If file checks fail, assume no GPU (safe default for tests) + _has_gpu_cache = False + + return _has_gpu_cache + + +def requires_gpu(reason: str = "test requires GPU functionality"): + """Simple decorator to skip tests that require GPU. + + This is the only decorator needed for GPU-dependent tests. + + Args: + reason: Custom reason for skipping + + Returns: + pytest.mark.skipif decorator + """ + return pytest.mark.skipif(not has_gpu(), reason=reason) + @pytest.fixture def global_data(): # Lazy import to avoid collection issues - from madengine.core.console import Console + if "Console" not in globals(): + from madengine.core.console import Console return {"console": Console(live_output=True)} @@ -47,6 +94,55 @@ def clean_test_temp_files(request): os.remove(file_path) +def generate_additional_context_for_machine() -> dict: + """Generate appropriate additional context based on detected machine capabilities. 
+ + Returns: + dict: Additional context with gpu_vendor and guest_os suitable for current machine + """ + if has_gpu(): + # Simple vendor detection for GPU machines + vendor = "NVIDIA" if os.path.exists("/usr/bin/nvidia-smi") else "AMD" + return {"gpu_vendor": vendor, "guest_os": "UBUNTU"} + else: + # On CPU-only machines, use defaults suitable for build-only operations + return { + "gpu_vendor": "AMD", # Default for build-only nodes + "guest_os": "UBUNTU", # Default OS + } + + +def generate_additional_context_json() -> str: + """Generate JSON string of additional context for current machine. + + Returns: + str: JSON string representation of additional context + """ + return json.dumps(generate_additional_context_for_machine()) + + +def create_mock_args_with_auto_context(**kwargs) -> MagicMock: + """Create mock args with automatically generated additional context. + + Args: + **kwargs: Additional attributes to set on the mock args + + Returns: + MagicMock: Mock args object with auto-generated additional context + """ + mock_args = MagicMock() + + # Set auto-generated context + mock_args.additional_context = generate_additional_context_json() + mock_args.additional_context_file = None + + # Set any additional attributes + for key, value in kwargs.items(): + setattr(mock_args, key, value) + + return mock_args + + def is_nvidia() -> bool: """Check if the GPU is NVIDIA or not. diff --git a/tests/test_cli_error_integration.py b/tests/test_cli_error_integration.py new file mode 100644 index 00000000..f0601357 --- /dev/null +++ b/tests/test_cli_error_integration.py @@ -0,0 +1,383 @@ +#!/usr/bin/env python3 +""" +Unit tests for MADEngine CLI error handling integration. + +Tests the integration of unified error handling in mad_cli.py and +distributed_orchestrator.py components. 
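(Reviewer aside, not part of the patch.) The GPU-detection helpers added to `tests/fixtures/utils.py` above are meant to be consumed from other test modules under `tests/`; the test names below are invented:

```python
# Hypothetical test module placed under tests/, mirroring the imports used elsewhere.
from .fixtures.utils import (
    create_mock_args_with_auto_context,
    generate_additional_context_json,
    requires_gpu,
)


@requires_gpu("needs a real GPU to launch the container")
def test_something_gpu_bound():
    # Skipped automatically on CPU-only machines via the cached has_gpu() check.
    ...


def test_build_only_args():
    # Mock CLI args whose additional context matches the current machine.
    args = create_mock_args_with_auto_context(tags=["dummy"], live_output=False)
    assert args.additional_context == generate_additional_context_json()
```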
+""" + +import pytest +import json +import os +import tempfile +from unittest.mock import Mock, patch, MagicMock, mock_open +from rich.console import Console + +# Add src to path for imports +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) + +from madengine.core.errors import ( + ErrorHandler, + ConfigurationError, + set_error_handler, + get_error_handler, + create_error_context +) + + +class TestMadCLIErrorIntegration: + """Test mad_cli.py error handling integration.""" + + @patch('madengine.mad_cli.Console') + def test_setup_logging_creates_error_handler(self, mock_console_class): + """Test that setup_logging initializes the unified error handler.""" + from madengine.mad_cli import setup_logging + + mock_console = Mock(spec=Console) + mock_console_class.return_value = mock_console + + # Clear any existing global error handler + set_error_handler(None) + + # Call setup_logging + setup_logging(verbose=True) + + # Verify error handler was set + handler = get_error_handler() + assert handler is not None + assert isinstance(handler, ErrorHandler) + assert handler.verbose is True + + def test_setup_logging_verbose_flag(self): + """Test that verbose flag is properly passed to error handler.""" + from madengine.mad_cli import setup_logging + + # Test with verbose=False + setup_logging(verbose=False) + handler = get_error_handler() + assert handler.verbose is False + + # Test with verbose=True + setup_logging(verbose=True) + handler = get_error_handler() + assert handler.verbose is True + + def test_build_command_error_handling(self): + """Test that build command imports and can use unified error handling.""" + from madengine.mad_cli import ExitCode + + # Test that the import works and error handling is available + try: + # This tests the actual import in mad_cli.py + from madengine.mad_cli import setup_logging + + # Verify error handler can be set up + setup_logging(verbose=False) + + # Verify handle_error can be imported in the context where it's used + from madengine.core.errors import handle_error, create_error_context + + # Create a test error to ensure the system works + error = Exception("Test build error") + context = create_error_context( + operation="build", + phase="build", + component="CLI" + ) + + # This should not raise an exception + handle_error(error, context=context) + + except ImportError as e: + pytest.fail(f"Error handling integration failed: {e}") + + @patch('madengine.mad_cli.console') + def test_cli_error_display_consistency(self, mock_console): + """Test that CLI errors are displayed consistently through unified handler.""" + from madengine.mad_cli import setup_logging + + # Setup logging to initialize error handler + setup_logging(verbose=False) + + # Get the initialized error handler + handler = get_error_handler() + + # Create a test error + error = ConfigurationError( + "Invalid configuration", + context=create_error_context( + operation="cli_command", + component="CLI", + phase="validation" + ) + ) + + # Handle the error through the unified system + handler.handle_error(error) + + # The error should be displayed through Rich console + # (Note: The actual console calls depend on the handler implementation) + assert handler.console is not None + + +class TestDistributedOrchestratorErrorIntegration: + """Test distributed_orchestrator.py error handling integration.""" + + def test_orchestrator_imports_error_handling(self): + """Test that distributed_orchestrator imports unified error handling.""" + try: + from 
madengine.tools.distributed_orchestrator import ( + handle_error, create_error_context, ConfigurationError + ) + # If import succeeds, the integration is working + assert handle_error is not None + assert create_error_context is not None + assert ConfigurationError is not None + except ImportError as e: + pytest.fail(f"Error handling imports failed in distributed_orchestrator: {e}") + + @patch('madengine.tools.distributed_orchestrator.handle_error') + @patch('builtins.open', side_effect=FileNotFoundError("File not found")) + @patch('os.path.exists', return_value=True) + def test_orchestrator_credential_loading_error_handling(self, mock_exists, mock_open, mock_handle_error): + """Test that credential loading uses unified error handling.""" + from madengine.tools.distributed_orchestrator import DistributedOrchestrator + + # Mock args object + mock_args = Mock() + mock_args.tags = ["test"] + mock_args.registry = None + mock_args.additional_context = "{}" + mock_args.additional_context_file = None + mock_args.clean_docker_cache = False + mock_args.manifest_output = "test.json" + mock_args.live_output = False + mock_args.output = "test.csv" + mock_args.ignore_deprecated_flag = False + mock_args.data_config_file_name = "data.json" + mock_args.tools_json_file_name = "tools.json" + mock_args.generate_sys_env_details = True + mock_args.force_mirror_local = None + mock_args.disable_skip_gpu_arch = False + mock_args.verbose = False + mock_args._separate_phases = True + + # Create orchestrator (should trigger credential loading) + with patch('madengine.tools.distributed_orchestrator.Context'): + with patch('madengine.tools.distributed_orchestrator.Data'): + try: + orchestrator = DistributedOrchestrator(mock_args) + except Exception: + # Expected to fail due to mocking, but error handling should be called + pass + + # Verify that handle_error was called for credential loading failure + assert mock_handle_error.called + + def test_orchestrator_error_context_creation(self): + """Test that orchestrator creates proper error contexts.""" + from madengine.tools.distributed_orchestrator import create_error_context + + context = create_error_context( + operation="load_credentials", + component="DistributedOrchestrator", + file_path="credential.json" + ) + + assert context.operation == "load_credentials" + assert context.component == "DistributedOrchestrator" + assert context.file_path == "credential.json" + + @patch('madengine.tools.distributed_orchestrator.handle_error') + def test_orchestrator_configuration_error_handling(self, mock_handle_error): + """Test that configuration errors are properly handled with context.""" + from madengine.tools.distributed_orchestrator import ( + ConfigurationError, create_error_context + ) + + # Simulate configuration error handling in orchestrator + error_context = create_error_context( + operation="load_credentials", + component="DistributedOrchestrator", + file_path="credential.json" + ) + + config_error = ConfigurationError( + "Could not load credentials: File not found", + context=error_context, + suggestions=["Check if credential.json exists and has valid JSON format"] + ) + + # Handle the error + mock_handle_error(config_error) + + # Verify the error was handled + mock_handle_error.assert_called_once_with(config_error) + + # Verify error structure + called_error = mock_handle_error.call_args[0][0] + assert isinstance(called_error, ConfigurationError) + assert called_error.context.operation == "load_credentials" + assert called_error.context.component == 
"DistributedOrchestrator" + assert called_error.suggestions[0] == "Check if credential.json exists and has valid JSON format" + + +class TestErrorHandlingWorkflow: + """Test complete error handling workflow across components.""" + + @patch('madengine.mad_cli.console') + def test_end_to_end_error_flow(self, mock_console): + """Test complete error flow from CLI through orchestrator.""" + from madengine.mad_cli import setup_logging + from madengine.core.errors import ValidationError + + # Setup unified error handling + setup_logging(verbose=True) + handler = get_error_handler() + + # Create an error that might occur in the orchestrator + orchestrator_error = ValidationError( + "Invalid model tag format", + context=create_error_context( + operation="model_discovery", + component="DistributedOrchestrator", + phase="validation", + model_name="invalid::tag" + ), + suggestions=[ + "Use format: model_name:version", + "Check model name contains only alphanumeric characters" + ] + ) + + # Handle the error through the unified system + handler.handle_error(orchestrator_error) + + # Verify the error was processed + assert handler.console is not None + assert orchestrator_error.context.operation == "model_discovery" + assert orchestrator_error.context.component == "DistributedOrchestrator" + assert len(orchestrator_error.suggestions) == 2 + + def test_error_logging_integration(self): + """Test that errors are properly logged with structured data.""" + from madengine.mad_cli import setup_logging + from madengine.core.errors import BuildError + + # Setup logging + setup_logging(verbose=False) + handler = get_error_handler() + + # Create a build error with rich context + build_error = BuildError( + "Docker build failed", + context=create_error_context( + operation="docker_build", + component="DockerBuilder", + phase="build", + model_name="test_model", + additional_info={"dockerfile": "Dockerfile.ubuntu.amd"} + ), + suggestions=["Check Dockerfile syntax", "Verify base image availability"] + ) + + # Mock the logger to capture log calls + with patch.object(handler, 'logger') as mock_logger: + handler.handle_error(build_error) + + # Verify logging was called with structured data + mock_logger.error.assert_called_once() + log_call_args = mock_logger.error.call_args + + # Check the log message + assert "build: Docker build failed" in log_call_args[0][0] + + # Check the extra structured data + extra_data = log_call_args[1]['extra'] + assert extra_data['context']['operation'] == "docker_build" + assert extra_data['context']['component'] == "DockerBuilder" + assert extra_data['recoverable'] is False # BuildError is not recoverable + assert len(extra_data['suggestions']) == 2 + + def test_error_context_serialization(self): + """Test that error contexts can be serialized for logging and debugging.""" + from madengine.core.errors import RuntimeError + + context = create_error_context( + operation="model_execution", + component="ContainerRunner", + phase="runtime", + model_name="llama2", + node_id="worker-node-01", + file_path="/models/llama2/run.sh", + additional_info={ + "container_id": "abc123", + "gpu_count": 2, + "timeout": 3600 + } + ) + + error = RuntimeError( + "Model execution failed with exit code 1", + context=context + ) + + # Test that context can be serialized + context_dict = error.context.__dict__ + json_str = json.dumps(context_dict, default=str) + + # Verify all context information is in the serialized form + assert "model_execution" in json_str + assert "ContainerRunner" in json_str + assert "runtime" 
in json_str + assert "llama2" in json_str + assert "worker-node-01" in json_str + assert "abc123" in json_str + + +class TestErrorHandlingPerformance: + """Test performance aspects of error handling.""" + + def test_error_handler_initialization_performance(self): + """Test that error handler initialization is fast.""" + import time + from madengine.core.errors import ErrorHandler + from rich.console import Console + + start_time = time.time() + + # Create multiple error handlers + for _ in range(100): + console = Console() + handler = ErrorHandler(console=console, verbose=False) + + end_time = time.time() + + # Should be able to create 100 handlers in under 1 second + assert end_time - start_time < 1.0 + + def test_error_context_creation_performance(self): + """Test that error context creation is efficient.""" + import time + + start_time = time.time() + + # Create many error contexts + for i in range(1000): + context = create_error_context( + operation=f"operation_{i}", + component=f"Component_{i}", + phase="test", + model_name=f"model_{i}", + additional_info={"iteration": i} + ) + + end_time = time.time() + + # Should be able to create 1000 contexts in under 0.1 seconds + assert end_time - start_time < 0.1 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) \ No newline at end of file diff --git a/tests/test_cli_features.py b/tests/test_cli_features.py new file mode 100644 index 00000000..1a20fa7b --- /dev/null +++ b/tests/test_cli_features.py @@ -0,0 +1,133 @@ +"""Test various CLI features and command-line arguments. + +This module tests various command-line argument behaviors including: +- Output file path specification (-o flag) +- GPU architecture checking and skip flags +- Multiple results output handling + +Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +""" + +# built-in modules +import os +import sys +import csv +import pandas as pd + +# 3rd party modules +import pytest + +# project modules +from .fixtures.utils import BASE_DIR, MODEL_DIR +from .fixtures.utils import global_data +from .fixtures.utils import clean_test_temp_files + + +class TestCLIFeatures: + """Test various CLI features and command-line argument behaviors.""" + + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf_test.csv", "perf_test.html"]], indirect=True + ) + def test_output_commandline_argument_writes_csv_correctly( + self, global_data, clean_test_temp_files + ): + """ + Test that -o/--output command-line argument writes CSV file to specified path. + """ + output = global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy -o perf_test.csv" + ) + success = False + with open(os.path.join(BASE_DIR, "perf_test.csv"), "r") as csv_file: + csv_reader = csv.DictReader(csv_file) + for row in csv_reader: + if row["model"] == "dummy": + if row["status"] == "SUCCESS": + success = True + break + else: + pytest.fail("model in perf_test.csv did not run successfully.") + if not success: + pytest.fail("model, dummy, not found in perf_test.csv.") + + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf_test.csv", "perf_test.html"]], indirect=True + ) + def test_commandline_argument_skip_gpu_arch( + self, global_data, clean_test_temp_files + ): + """ + Test that skip_gpu_arch command-line argument skips GPU architecture check. 
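(Reviewer aside, not part of the patch.) The tests in `tests/test_cli_error_integration.py` above pin down how the unified error-handling API is meant to be used from other components. A condensed sketch of that pattern, with an invented error message and suggestion text:

```python
from madengine.core.errors import ConfigurationError, create_error_context, get_error_handler
from madengine.mad_cli import setup_logging

# setup_logging() registers the global ErrorHandler used by the CLI.
setup_logging(verbose=True)

# Build a structured context and attach it to a typed error with suggestions...
context = create_error_context(
    operation="load_credentials",
    component="DistributedOrchestrator",
    file_path="credential.json",
)
error = ConfigurationError(
    "Could not load credentials: file not found",  # invented message
    context=context,
    suggestions=["Check that credential.json exists and contains valid JSON"],
)

# ...then route it through the shared handler for consistent console output and logging.
get_error_handler().handle_error(error)
```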
+ """ + output = global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_skip_gpu_arch" + ) + if "Skipping model" not in output: + pytest.fail("Enable skipping gpu arch for running model is failed.") + + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf_test.csv", "perf_test.html"]], indirect=True + ) + def test_commandline_argument_disable_skip_gpu_arch_fail( + self, global_data, clean_test_temp_files + ): + """ + Test that --disable-skip-gpu-arch fails GPU architecture check as expected. + """ + output = global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_skip_gpu_arch --disable-skip-gpu-arch" + ) + # Check if exception with message 'Skipping model' is thrown + if "Skipping model" in output: + pytest.fail("Disable skipping gpu arch for running model is failed.") + + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf_test.csv", "perf_test.html"]], indirect=True + ) + def test_output_multi_results(self, global_data, clean_test_temp_files): + """ + Test that multiple results are correctly written and merged into output CSV. + """ + output = global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_multi") + # Check if multiple results are written to perf_dummy.csv + success = False + # Read the csv file to a dataframe using pandas + multi_df = pd.read_csv(os.path.join(BASE_DIR, 'perf_dummy.csv')) + # Check the number of rows in the dataframe is 4, and columns is 4 + if multi_df.shape == (4, 4): + success = True + if not success: + pytest.fail("The generated multi results is not correct.") + # Check if multiple results from perf_dummy.csv get copied over to perf.csv + perf_df = pd.read_csv(os.path.join(BASE_DIR, 'perf.csv')) + # Get the corresponding rows and columns from perf.csv + perf_df = perf_df[multi_df.columns] + perf_df = perf_df.iloc[-4:, :] + # Drop model columns from both dataframes; these will not match + # if multiple results csv has {model}, then perf csv has {tag_name}_{model} + multi_df = multi_df.drop('model', axis=1) + perf_df = perf_df.drop('model', axis=1) + if all(perf_df.columns == multi_df.columns): + success = True + if not success: + pytest.fail("The columns of the generated multi results do not match perf.csv.") + diff --git a/tests/test_console.py b/tests/test_console.py index 6ed0cb79..e6a700a0 100644 --- a/tests/test_console.py +++ b/tests/test_console.py @@ -4,25 +4,29 @@ Copyright (c) Advanced Micro Devices, Inc. All rights reserved. """ + # built-in modules import subprocess import typing + # third-party modules import pytest import typing_extensions + # project modules from madengine.core import console class TestConsole: """Test the console module. - + test_sh: Test the console.sh function with echo command. 
""" + def test_sh(self): obj = console.Console() assert obj.sh("echo MAD Engine") == "MAD Engine" - + def test_sh_fail(self): obj = console.Console() try: @@ -47,7 +51,9 @@ def test_sh_secret(self): def test_sh_env(self): obj = console.Console() - assert obj.sh("echo $MAD_ENGINE", env={"MAD_ENGINE": "MAD Engine"}) == "MAD Engine" + assert ( + obj.sh("echo $MAD_ENGINE", env={"MAD_ENGINE": "MAD Engine"}) == "MAD Engine" + ) def test_sh_verbose(self): obj = console.Console(shellVerbose=False) diff --git a/tests/test_container_runner.py b/tests/test_container_runner.py new file mode 100644 index 00000000..0df2831f --- /dev/null +++ b/tests/test_container_runner.py @@ -0,0 +1,445 @@ +"""Test the container runner module. + +This module tests the Docker container execution functionality for distributed execution. + +Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +""" + +# built-in modules +import os +import json +import tempfile +import unittest.mock +from unittest.mock import patch, MagicMock, mock_open + +# third-party modules +import pytest + +# project modules +from madengine.tools.container_runner import ContainerRunner +from madengine.core.context import Context +from madengine.core.console import Console +from madengine.core.dataprovider import Data +from .fixtures.utils import BASE_DIR, MODEL_DIR + + +class TestContainerRunner: + """Test the container runner module.""" + + @patch("madengine.core.context.Context") + def test_container_runner_initialization(self, mock_context_class): + """Test ContainerRunner initialization.""" + mock_context = MagicMock() + mock_context_class.return_value = mock_context + context = mock_context_class() + console = Console() + data = MagicMock() + + runner = ContainerRunner(context, data, console) + + assert runner.context == context + assert runner.data == data + assert runner.console == console + assert runner.credentials is None + + def test_container_runner_initialization_minimal(self): + """Test ContainerRunner initialization with minimal parameters.""" + runner = ContainerRunner() + + assert runner.context is None + assert runner.data is None + assert isinstance(runner.console, Console) + assert runner.credentials is None + + def test_load_build_manifest(self): + """Test loading build manifest from file.""" + runner = ContainerRunner() + + manifest_data = { + "images": { + "model1": "localhost:5000/ci-model1:latest", + "model2": "localhost:5000/ci-model2:latest", + }, + "metadata": { + "build_time": "2023-01-01T12:00:00Z", + "registry": "localhost:5000", + }, + } + + with patch("builtins.open", mock_open(read_data=json.dumps(manifest_data))): + result = runner.load_build_manifest("test_manifest.json") + + assert result == manifest_data + assert "images" in result + assert "model1" in result["images"] + + @patch.object(Console, "sh") + def test_pull_image(self, mock_sh): + """Test pulling image from registry.""" + runner = ContainerRunner() + + mock_sh.return_value = "Pull successful" + + result = runner.pull_image("localhost:5000/test:latest") + + assert result == "localhost:5000/test:latest" + mock_sh.assert_called_with("docker pull localhost:5000/test:latest") + + @patch.object(Console, "sh") + def test_pull_image_with_local_name(self, mock_sh): + """Test pulling image with local name tagging.""" + runner = ContainerRunner() + + mock_sh.return_value = "Success" + + result = runner.pull_image("localhost:5000/test:latest", "local-test") + + assert result == "local-test" + # Should have called pull and tag + expected_calls = [ + 
unittest.mock.call("docker pull localhost:5000/test:latest"), + unittest.mock.call("docker tag localhost:5000/test:latest local-test"), + ] + mock_sh.assert_has_calls(expected_calls) + + @patch("madengine.core.context.Context") + def test_get_gpu_arg_all_gpus(self, mock_context_class): + """Test get_gpu_arg with all GPUs requested.""" + mock_context = MagicMock() + mock_context.ctx = { + "docker_env_vars": {"MAD_GPU_VENDOR": "AMD", "MAD_SYSTEM_NGPUS": "4"}, + "docker_gpus": "0,1,2,3", + "gpu_renderDs": [128, 129, 130, 131], # Mock render device IDs for AMD GPUs + } + mock_context_class.return_value = mock_context + runner = ContainerRunner(mock_context) + + result = runner.get_gpu_arg("-1") + + # Should return GPU args for all available GPUs + assert "--device=/dev/kfd" in result and "renderD" in result + + @patch("madengine.core.context.Context") + def test_get_gpu_arg_specific_gpus(self, mock_context_class): + """Test get_gpu_arg with specific GPUs requested.""" + mock_context = MagicMock() + mock_context.ctx = { + "docker_env_vars": {"MAD_GPU_VENDOR": "NVIDIA", "MAD_SYSTEM_NGPUS": "4"}, + "docker_gpus": "0,1,2,3", + } + mock_context_class.return_value = mock_context + runner = ContainerRunner(mock_context) + + result = runner.get_gpu_arg("2") + + # Should return GPU args for 2 GPUs + assert "gpu" in result.lower() + + @patch("madengine.core.context.Context") + def test_get_gpu_arg_range_format(self, mock_context_class): + """Test get_gpu_arg with range format.""" + mock_context = MagicMock() + mock_context.ctx = { + "docker_env_vars": {"MAD_GPU_VENDOR": "NVIDIA", "MAD_SYSTEM_NGPUS": "4"}, + "docker_gpus": "0-3", + } + mock_context_class.return_value = mock_context + runner = ContainerRunner(mock_context) + + result = runner.get_gpu_arg("2") + + # Should handle range format correctly + assert isinstance(result, str) + + @patch("madengine.core.context.Context") + @patch.object(Console, "sh") + @patch("madengine.tools.container_runner.Docker") + def test_run_container_success( + self, mock_docker_class, mock_sh, mock_context_class + ): + """Test successful container run.""" + # Mock context to avoid GPU detection + mock_context = MagicMock() + mock_context.ctx = { + "docker_env_vars": {"MAD_GPU_VENDOR": "NVIDIA", "MAD_SYSTEM_NGPUS": "2"}, + "docker_gpus": "0,1", + "gpu_vendor": "NVIDIA", + } + mock_context_class.return_value = mock_context + runner = ContainerRunner(mock_context) + + # Mock Docker instance + mock_docker = MagicMock() + mock_docker.sh.return_value = "Command output" + mock_docker_class.return_value = mock_docker + + mock_sh.return_value = "hostname" + + model_info = { + "name": "test_model", + "n_gpus": "1", + "scripts": "test_script.sh", + "args": "", + } + + with patch.object(runner, "get_gpu_arg", return_value="--gpus device=0"): + with patch.object(runner, "get_cpu_arg", return_value=""): + with patch.object(runner, "get_env_arg", return_value=""): + with patch.object(runner, "get_mount_arg", return_value=""): + result = runner.run_container( + model_info, "test-image", timeout=300 + ) + + assert result["status"] == "SUCCESS" + assert "test_duration" in result + assert mock_docker_class.called + + @patch("madengine.core.context.Context") + @patch.object(Console, "sh") + @patch("madengine.tools.container_runner.Docker") + def test_run_container_timeout( + self, mock_docker_class, mock_sh, mock_context_class + ): + """Test container run with timeout.""" + # Mock context to avoid GPU detection + mock_context = MagicMock() + mock_context.ctx = { + "docker_env_vars": 
{"MAD_GPU_VENDOR": "NVIDIA", "MAD_SYSTEM_NGPUS": "2"}, + "docker_gpus": "0,1", + "gpu_vendor": "NVIDIA", + } + mock_context_class.return_value = mock_context + runner = ContainerRunner(mock_context) + + # Mock Docker instance that raises TimeoutError + mock_docker = MagicMock() + mock_docker.sh.side_effect = TimeoutError("Timeout occurred") + mock_docker_class.return_value = mock_docker + + mock_sh.return_value = "hostname" + + model_info = { + "name": "test_model", + "n_gpus": "1", + "scripts": "test_script.sh", + "args": "", + } + + with patch.object(runner, "get_gpu_arg", return_value="--gpus device=0"): + with patch.object(runner, "get_cpu_arg", return_value=""): + with patch.object(runner, "get_env_arg", return_value=""): + with patch.object(runner, "get_mount_arg", return_value=""): + # run_container catches exceptions and returns results with status + result = runner.run_container( + model_info, "test-image", timeout=10 + ) + assert result["status"] == "FAILURE" + + @patch("madengine.core.context.Context") + @patch.object(Console, "sh") + @patch("madengine.tools.container_runner.Docker") + def test_run_container_failure( + self, mock_docker_class, mock_sh, mock_context_class + ): + """Test container run failure.""" + # Mock context to avoid GPU detection + mock_context = MagicMock() + mock_context.ctx = { + "docker_env_vars": {"MAD_GPU_VENDOR": "NVIDIA", "MAD_SYSTEM_NGPUS": "2"}, + "docker_gpus": "0,1", + "gpu_vendor": "NVIDIA", + } + mock_context_class.return_value = mock_context + runner = ContainerRunner(mock_context) + + # Mock Docker instance that raises RuntimeError + mock_docker = MagicMock() + mock_docker.sh.side_effect = RuntimeError("Container failed to start") + mock_docker_class.return_value = mock_docker + + mock_sh.return_value = "hostname" + + model_info = { + "name": "test_model", + "n_gpus": "1", + "scripts": "test_script.sh", + "args": "", + } + + with patch.object(runner, "get_gpu_arg", return_value="--gpus device=0"): + with patch.object(runner, "get_cpu_arg", return_value=""): + with patch.object(runner, "get_env_arg", return_value=""): + with patch.object(runner, "get_mount_arg", return_value=""): + # run_container catches exceptions and returns results with status + result = runner.run_container( + model_info, "test-image", timeout=300 + ) + assert result["status"] == "FAILURE" + + @patch("madengine.core.context.Context") + def test_load_credentials(self, mock_context_class): + """Test setting credentials for container runner.""" + # Mock context to avoid GPU detection + mock_context = MagicMock() + mock_context_class.return_value = mock_context + runner = ContainerRunner(mock_context) + + credentials = {"github": {"username": "testuser", "password": "testpass"}} + + runner.set_credentials(credentials) + + assert runner.credentials == credentials + + @patch("madengine.core.context.Context") + def test_login_to_registry(self, mock_context_class): + """Test login to Docker registry.""" + # Mock context to avoid GPU detection + mock_context = MagicMock() + mock_context_class.return_value = mock_context + runner = ContainerRunner(mock_context) + + credentials = { + "localhost:5000": {"username": "testuser", "password": "testpass"} + } + + with patch.object(runner.console, "sh") as mock_sh: + mock_sh.return_value = "Login Succeeded" + runner.login_to_registry("localhost:5000", credentials) + + # Verify login command was called + assert mock_sh.called + + @patch("madengine.core.context.Context") + def test_get_gpu_arg_specific_gpu(self, mock_context_class): + 
"""Test getting GPU arguments for specific GPU count.""" + # Mock context to avoid GPU detection + mock_context = MagicMock() + mock_context.ctx = { + "docker_env_vars": {"MAD_GPU_VENDOR": "NVIDIA", "MAD_SYSTEM_NGPUS": "4"}, + "docker_gpus": "0,1,2,3", + } + mock_context_class.return_value = mock_context + runner = ContainerRunner(mock_context) + + result = runner.get_gpu_arg("2") + + # Should return GPU args for 2 GPUs + assert "gpu" in result.lower() or "device" in result.lower() + + @patch("madengine.core.context.Context") + def test_get_cpu_arg(self, mock_context_class): + """Test getting CPU arguments for docker run.""" + # Mock context to avoid GPU detection + mock_context = MagicMock() + mock_context.ctx = {"docker_cpus": "0,1,2,3"} + mock_context_class.return_value = mock_context + runner = ContainerRunner(mock_context) + + result = runner.get_cpu_arg() + + assert "--cpuset-cpus" in result + assert "0,1,2,3" in result + + @patch("madengine.core.context.Context") + def test_get_env_arg(self, mock_context_class): + """Test getting environment variables for container.""" + # Mock context to avoid GPU detection + mock_context = MagicMock() + mock_context.ctx = { + "docker_env_vars": { + "MAD_GPU_VENDOR": "NVIDIA", + "MAD_MODEL_NAME": "test_model", + "CUSTOM_VAR": "custom_value", + } + } + mock_context_class.return_value = mock_context + runner = ContainerRunner(mock_context) + + custom_env = {"EXTRA_VAR": "extra_value"} + result = runner.get_env_arg(custom_env) + + assert "--env MAD_GPU_VENDOR=" in result + assert "--env EXTRA_VAR=" in result + + @patch("madengine.core.context.Context") + def test_get_mount_arg(self, mock_context_class): + """Test getting mount arguments for container.""" + # Mock context to avoid GPU detection + mock_context = MagicMock() + mock_context.ctx = { + "docker_mounts": { + "/container/data": "/host/data", + "/container/output": "/host/output", + } + } + mock_context_class.return_value = mock_context + runner = ContainerRunner(mock_context) + + mount_datapaths = [ + {"path": "/host/input", "home": "/container/input", "readwrite": "false"} + ] + + result = runner.get_mount_arg(mount_datapaths) + + assert "-v /host/input:/container/input:ro" in result + assert "-v /host/data:/container/data" in result + + def test_apply_tools_without_tools_config(self): + """Test applying tools when no tools configuration exists.""" + runner = ContainerRunner() + + # Mock context without tools + runner.context = MagicMock() + runner.context.ctx = {} + + pre_encapsulate_post_scripts = { + "pre_scripts": [], + "encapsulate_script": "", + "post_scripts": [], + } + run_env = {} + + # Should not raise any exception + runner.apply_tools(pre_encapsulate_post_scripts, run_env, "nonexistent.json") + + # Scripts should remain unchanged + assert pre_encapsulate_post_scripts["pre_scripts"] == [] + assert pre_encapsulate_post_scripts["encapsulate_script"] == "" + assert run_env == {} + + def test_run_pre_post_script(self): + """Test running pre/post scripts.""" + runner = ContainerRunner() + + # Mock Docker instance + mock_docker = MagicMock() + mock_docker.sh = MagicMock() + + scripts = [ + {"path": "/path/to/script1.sh", "args": "arg1 arg2"}, + {"path": "/path/to/script2.sh"}, + ] + + runner.run_pre_post_script(mock_docker, "model_dir", scripts) + + # Verify scripts were copied and executed + assert mock_docker.sh.call_count == 4 # 2 copies + 2 executions + + # Check if copy commands were called + copy_calls = [ + call for call in mock_docker.sh.call_args_list if "cp -vLR" in str(call) + 
] + assert len(copy_calls) == 2 + + def test_initialization_with_all_parameters(self): + """Test ContainerRunner initialization with all parameters.""" + context = MagicMock() + console = Console() + data = MagicMock() + + runner = ContainerRunner(context, data, console) + + assert runner.context == context + assert runner.data == data + assert runner.console == console + assert runner.credentials is None diff --git a/tests/test_contexts.py b/tests/test_contexts.py index 45ba117f..5942e24a 100644 --- a/tests/test_contexts.py +++ b/tests/test_contexts.py @@ -2,12 +2,15 @@ Copyright (c) Advanced Micro Devices, Inc. All rights reserved. """ + # built-in modules import os import sys import csv + # third-party modules import pytest + # project modules from .fixtures.utils import BASE_DIR, MODEL_DIR from .fixtures.utils import global_data @@ -15,237 +18,385 @@ from .fixtures.utils import get_gpu_nodeid_map from .fixtures.utils import get_num_gpus from .fixtures.utils import get_num_cpus +from .fixtures.utils import requires_gpu + from madengine.core.context import Context class TestContexts: - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html']], indirect=True) - def test_dockerfile_picked_on_detected_context_0(self, global_data, clean_test_temp_files): - """ + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html"]], indirect=True + ) + def test_dockerfile_picked_on_detected_context_0( + self, global_data, clean_test_temp_files + ): + """ picks dockerfile based on detected context and only those """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_ctxtest ") + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_ctxtest " + ) success = False - with open(os.path.join(BASE_DIR, 'perf.csv'), 'r') as csv_file: + with open(os.path.join(BASE_DIR, "perf.csv"), "r") as csv_file: csv_reader = csv.DictReader(csv_file) for row in csv_reader: - if row['model'] == 'dummy_ctxtest': - if row['status'] == 'SUCCESS' and row['performance'] == '0': + if row["model"] == "dummy_ctxtest": + if row["status"] == "SUCCESS" and row["performance"] == "0": success = True else: pytest.fail("model in perf_test.csv did not run successfully.") if not success: pytest.fail("model did not pick correct context.") - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html', 'ctx_test']], indirect=True) - def test_dockerfile_picked_on_detected_context_1(self, global_data, clean_test_temp_files): - """ + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html", "ctx_test"]], indirect=True + ) + def test_dockerfile_picked_on_detected_context_1( + self, global_data, clean_test_temp_files + ): + """ picks dockerfile based on detected context and only those """ - with open(os.path.join(BASE_DIR, 'ctx_test'), 'w') as ctx_test_file: + with open(os.path.join(BASE_DIR, "ctx_test"), "w") as ctx_test_file: print("1", file=ctx_test_file) - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_ctxtest ") + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_ctxtest " + ) success = False - with open(os.path.join(BASE_DIR, 'perf.csv'), 'r') as csv_file: + with open(os.path.join(BASE_DIR, "perf.csv"), "r") as 
csv_file: csv_reader = csv.DictReader(csv_file) for row in csv_reader: - if row['model'] == 'dummy_ctxtest': - if row['status'] == 'SUCCESS' and row['performance'] == '1': + if row["model"] == "dummy_ctxtest": + if row["status"] == "SUCCESS" and row["performance"] == "1": success = True else: pytest.fail("model in perf_test.csv did not run successfully.") if not success: pytest.fail("model did not pick correct context.") - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html', 'ctx_test']], indirect=True) - def test_all_dockerfiles_matching_context_executed(self, global_data, clean_test_temp_files): + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html", "ctx_test"]], indirect=True + ) + def test_all_dockerfiles_matching_context_executed( + self, global_data, clean_test_temp_files + ): """ All dockerfiles matching context is executed """ - with open(os.path.join(BASE_DIR, 'ctx_test'), 'w') as ctx_test_file: + with open(os.path.join(BASE_DIR, "ctx_test"), "w") as ctx_test_file: print("2", file=ctx_test_file) - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_ctxtest ") + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_ctxtest " + ) foundDockerfiles = [] - with open(os.path.join(BASE_DIR, 'perf.csv'), 'r') as csv_file: + with open(os.path.join(BASE_DIR, "perf.csv"), "r") as csv_file: csv_reader = csv.DictReader(csv_file) for row in csv_reader: - if row['model'] == 'dummy_ctxtest': - if row['status'] == 'SUCCESS' and row['performance'] == '2': - foundDockerfiles.append(row['docker_file'].replace(f'{MODEL_DIR}/', '')) + if row["model"] == "dummy_ctxtest": + if row["status"] == "SUCCESS" and row["performance"] == "2": + foundDockerfiles.append( + row["docker_file"].replace(f"{MODEL_DIR}/", "") + ) else: pytest.fail("model in perf_test.csv did not run successfully.") - if not ("docker/dummy_ctxtest.ctx2a.ubuntu.amd.Dockerfile" in foundDockerfiles and - "docker/dummy_ctxtest.ctx2b.ubuntu.amd.Dockerfile" in foundDockerfiles ): - pytest.fail("All dockerfiles matching context is not executed. Executed dockerfiles are " + ' '.join(foundDockerfiles)) + if not ( + "docker/dummy_ctxtest.ctx2a.ubuntu.amd.Dockerfile" in foundDockerfiles + and "docker/dummy_ctxtest.ctx2b.ubuntu.amd.Dockerfile" in foundDockerfiles + ): + pytest.fail( + "All dockerfiles matching context is not executed. 
Executed dockerfiles are " + + " ".join(foundDockerfiles) + ) def test_dockerfile_executed_if_contexts_keys_are_not_common(self): """ - Dockerfile is executed even if all context keys are not common but common keys match + Dockerfile is executed even if all context keys are not common but common keys match """ # already tested in test_dockerfile_picked_on_detected_context_0 pass - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html']], indirect=True) - def test_can_override_context_with_additionalContext_commandline(self, global_data, clean_test_temp_files): + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html"]], indirect=True + ) + def test_can_override_context_with_additionalContext_commandline( + self, global_data, clean_test_temp_files + ): """ - Context can be overridden through additional-context command-line argument + Context can be overridden through additional-context command-line argument """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_ctxtest --additional-context \"{'ctx_test': '1'}\" ") + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_ctxtest --additional-context \"{'ctx_test': '1'}\" " + ) success = False - with open(os.path.join(BASE_DIR, 'perf.csv'), 'r') as csv_file: + with open(os.path.join(BASE_DIR, "perf.csv"), "r") as csv_file: csv_reader = csv.DictReader(csv_file) for row in csv_reader: - if row['model'] == 'dummy_ctxtest': - if row['status'] == 'SUCCESS' and row['performance'] == '1': + if row["model"] == "dummy_ctxtest": + if row["status"] == "SUCCESS" and row["performance"] == "1": success = True else: pytest.fail("model in perf_test.csv did not run successfully.") if not success: pytest.fail("model did not pick correct context.") - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html', 'ctx.json']], indirect=True) - def test_can_override_context_with_additionalContextFile_commandline(self, global_data, clean_test_temp_files): + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html", "ctx.json"]], indirect=True + ) + def test_can_override_context_with_additionalContextFile_commandline( + self, global_data, clean_test_temp_files + ): """ - Context can be overridden through additional-context-file + Context can be overridden through additional-context-file """ - with open(os.path.join(BASE_DIR, 'ctx.json'), 'w') as ctx_json_file: - print("{ \"ctx_test\": \"1\" }", file=ctx_json_file) - - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_ctxtest --additional-context-file ctx.json ") + with open(os.path.join(BASE_DIR, "ctx.json"), "w") as ctx_json_file: + print('{ "ctx_test": "1" }', file=ctx_json_file) + + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_ctxtest --additional-context-file ctx.json " + ) success = False - with open(os.path.join(BASE_DIR, 'perf.csv'), 'r') as csv_file: + with open(os.path.join(BASE_DIR, "perf.csv"), "r") as csv_file: csv_reader = csv.DictReader(csv_file) for row in csv_reader: - if row['model'] == 'dummy_ctxtest': - if row['status'] == 'SUCCESS' and row['performance'] == '1': + if row["model"] == "dummy_ctxtest": + if row["status"] == "SUCCESS" and row["performance"] == "1": 
success = True else: pytest.fail("model in perf_test.csv did not run successfully.") if not success: pytest.fail("model did not pick correct context.") - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html', 'ctx.json']], indirect=True) - def test_additionalContext_commandline_overrides_additionalContextFile(self, global_data, clean_test_temp_files): + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html", "ctx.json"]], indirect=True + ) + def test_additionalContext_commandline_overrides_additionalContextFile( + self, global_data, clean_test_temp_files + ): """ additional-context command-line argument has priority over additional-context-file """ - with open(os.path.join(BASE_DIR, 'ctx.json'), 'w') as ctx_json_file: - print("{ \"ctx_test\": \"2\" }", file=ctx_json_file) - - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_ctxtest --additional-context-file ctx.json --additional-context \"{'ctx_test': '1'}\" ") + with open(os.path.join(BASE_DIR, "ctx.json"), "w") as ctx_json_file: + print('{ "ctx_test": "2" }', file=ctx_json_file) + + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_ctxtest --additional-context-file ctx.json --additional-context \"{'ctx_test': '1'}\" " + ) success = False - with open(os.path.join(BASE_DIR, 'perf.csv'), 'r') as csv_file: + with open(os.path.join(BASE_DIR, "perf.csv"), "r") as csv_file: csv_reader = csv.DictReader(csv_file) for row in csv_reader: - if row['model'] == 'dummy_ctxtest': - if row['status'] == 'SUCCESS' and row['performance'] == '1': + if row["model"] == "dummy_ctxtest": + if row["status"] == "SUCCESS" and row["performance"] == "1": success = True else: pytest.fail("model in perf_test.csv did not run successfully.") if not success: pytest.fail("model did not pick correct context.") - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html']], indirect=True) + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html"]], indirect=True + ) def test_base_docker_override(self, global_data, clean_test_temp_files): """ BASE_DOCKER overrides base docker """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_ctxtest --additional-context \"{'docker_build_arg':{'BASE_DOCKER':'rocm/tensorflow' }}\" ") + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_ctxtest --additional-context \"{'docker_build_arg':{'BASE_DOCKER':'rocm/tensorflow' }}\" " + ) foundBaseDocker = [] - with open(os.path.join(BASE_DIR, 'perf.csv'), 'r') as csv_file: + with open(os.path.join(BASE_DIR, "perf.csv"), "r") as csv_file: csv_reader = csv.DictReader(csv_file) for row in csv_reader: - if row['model'] == 'dummy_ctxtest': - if row['status'] == 'SUCCESS' and row['performance'] == '0': - foundBaseDocker.append(row['base_docker']) + if row["model"] == "dummy_ctxtest": + if row["status"] == "SUCCESS" and row["performance"] == "0": + foundBaseDocker.append(row["base_docker"]) else: pytest.fail("model in perf_test.csv did not run successfully.") if not "rocm/tensorflow" in foundBaseDocker: - pytest.fail("BASE_DOCKER does not override base docker. 
Expected: rocm/tensorflow Found:" + foundBaseDocker) - - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html']], indirect=True) + pytest.fail( + "BASE_DOCKER does not override base docker. Expected: rocm/tensorflow Found:" + + foundBaseDocker + ) + + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html"]], indirect=True + ) def test_docker_image_override(self, global_data, clean_test_temp_files): """ Using user-provided image passed in with MAD_CONTAINER_IMAGE """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_ctxtest --additional-context \"{'docker_env_vars':{'ctxtest':'1'},'MAD_CONTAINER_IMAGE':'rocm/tensorflow:latest' }\" ") + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_ctxtest --additional-context \"{'docker_env_vars':{'ctxtest':'1'},'MAD_CONTAINER_IMAGE':'rocm/tensorflow:latest' }\" " + ) foundLocalImage = None - with open(os.path.join(BASE_DIR, 'perf.csv'), 'r') as csv_file: + with open(os.path.join(BASE_DIR, "perf.csv"), "r") as csv_file: csv_reader = csv.DictReader(csv_file) for row in csv_reader: - if row['model'] == 'dummy_ctxtest': - if row['status'] == 'SUCCESS' and row['performance'] == '1': - foundLocalImage = row['docker_image'] + if row["model"] == "dummy_ctxtest": + if row["status"] == "SUCCESS" and row["performance"] == "1": + foundLocalImage = row["docker_image"] else: pytest.fail("model in perf_test.csv did not run successfully.") if not "rocm/tensorflow:latest" in foundLocalImage: - pytest.fail("MAD_CONTAINER_IMAGE does not override docker image. Expected: rocm/tensorflow:latest Found:" + foundLocalImage) - - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html']], indirect=True) + pytest.fail( + "MAD_CONTAINER_IMAGE does not override docker image. 
Expected: rocm/tensorflow:latest Found:" + + foundLocalImage + ) + + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html"]], indirect=True + ) def test_docker_env_vars_override(self, global_data, clean_test_temp_files): """ - docker_env_vars pass environment variables into docker container + docker_env_vars pass environment variables into docker container """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_ctxtest --additional-context \"{'docker_env_vars':{'ctxtest':'1'} }\" ") + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_ctxtest --additional-context \"{'docker_env_vars':{'ctxtest':'1'} }\" " + ) success = False - with open(os.path.join(BASE_DIR, 'perf.csv'), 'r') as csv_file: + with open(os.path.join(BASE_DIR, "perf.csv"), "r") as csv_file: csv_reader = csv.DictReader(csv_file) for row in csv_reader: - if row['model'] == 'dummy_ctxtest': - if row['status'] == 'SUCCESS' and row['performance'] == '1': + if row["model"] == "dummy_ctxtest": + if row["status"] == "SUCCESS" and row["performance"] == "1": success = True else: pytest.fail("model in perf_test.csv did not run successfully.") if not success: - pytest.fail("docker_env_vars did not pass environment variables into docker container.") - - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html']], indirect=True) - def test_docker_mounts_mount_host_paths_in_docker_container(self, global_data, clean_test_temp_files): + pytest.fail( + "docker_env_vars did not pass environment variables into docker container." + ) + + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html"]], indirect=True + ) + def test_docker_mounts_mount_host_paths_in_docker_container( + self, global_data, clean_test_temp_files + ): """ - docker_mounts mount host paths inside docker containers + docker_mounts mount host paths inside docker containers """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_mountpath --additional-context \"{'docker_env_vars':{'MAD_DATAHOME':'/data'}, 'docker_mounts':{'/data':'/tmp'} }\" ") + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_mountpath --additional-context \"{'docker_env_vars':{'MAD_DATAHOME':'/data'}, 'docker_mounts':{'/data':'/tmp'} }\" " + ) success = False - with open(os.path.join(BASE_DIR, 'perf.csv'), 'r') as csv_file: + with open(os.path.join(BASE_DIR, "perf.csv"), "r") as csv_file: csv_reader = csv.DictReader(csv_file) for row in csv_reader: - if row['model'] == 'dummy_mountpath': - if row['status'] == 'SUCCESS': + if row["model"] == "dummy_mountpath": + if row["status"] == "SUCCESS": success = True else: pytest.fail("model in perf_test.csv did not run successfully.") if not success: - pytest.fail("docker_mounts did not mount host paths inside docker container.") - - @pytest.mark.skipif(get_num_gpus() < 8, reason="test requires atleast 8 gpus") - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html','results_dummy_gpubind.csv']], indirect=True) + pytest.fail( + "docker_mounts did not mount host paths inside docker container." 
+ ) + + @requires_gpu("docker gpus requires GPU hardware") + @pytest.mark.skipif( + lambda: get_num_gpus() < 8, reason="test requires atleast 8 gpus" + ) + @pytest.mark.parametrize( + "clean_test_temp_files", + [["perf.csv", "perf.html", "results_dummy_gpubind.csv"]], + indirect=True, + ) def test_docker_gpus(self, global_data, clean_test_temp_files): """ docker_gpus binds gpus to docker containers """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_gpubind --additional-context \"{'docker_gpus':'0,2-4,5-5,7'}\" ") + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_gpubind --additional-context \"{'docker_gpus':'0,2-4,5-5,7'}\" " + ) gpu_nodeid_map = get_gpu_nodeid_map() - with open(os.path.join(BASE_DIR, 'perf.csv'), 'r') as csv_file: + with open(os.path.join(BASE_DIR, "perf.csv"), "r") as csv_file: csv_reader = csv.DictReader(csv_file) gpu_node_ids = [] for row in csv_reader: - if 'dummy_gpubind' in row['model']: - if row['status'] == 'SUCCESS': - gpu_node_ids.append(row['performance']) + if "dummy_gpubind" in row["model"]: + if row["status"] == "SUCCESS": + gpu_node_ids.append(row["performance"]) else: pytest.fail("model in perf_test.csv did not run successfully.") @@ -263,21 +414,38 @@ def test_docker_gpus(self, global_data, clean_test_temp_files): if sorted_gpus != [0, 2, 3, 4, 5, 7]: pytest.fail(f"docker_gpus did not bind expected gpus in docker container. Expected: [0, 2, 3, 4, 5, 7], Got: {sorted_gpus}, Raw node IDs: {gpu_node_ids}, Mapping: {gpu_nodeid_map}") - @pytest.mark.skipif(get_num_cpus() < 64, reason="test requires atleast 64 cpus") - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html','results_dummy_cpubind.csv']], indirect=True) + @pytest.mark.skipif( + lambda: get_num_cpus() < 64, reason="test requires atleast 64 cpus" + ) + @pytest.mark.parametrize( + "clean_test_temp_files", + [["perf.csv", "perf.html", "results_dummy_cpubind.csv"]], + indirect=True, + ) def test_docker_cpus(self, global_data, clean_test_temp_files): """ docker_cpus binds cpus to docker containers """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_cpubind --additional-context \"{'docker_cpus':'14-18,32,44-44,62'}\" ") + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_cpubind --additional-context \"{'docker_cpus':'14-18,32,44-44,62'}\" " + ) success = False - with open(os.path.join(BASE_DIR, 'perf.csv'), 'r') as csv_file: + with open(os.path.join(BASE_DIR, "perf.csv"), "r") as csv_file: csv_reader = csv.DictReader(csv_file) for row in csv_reader: - if 'dummy_cpubind' in row['model']: - if row['status'] == 'SUCCESS' and row['performance']=="14-18|32|44|62": + if "dummy_cpubind" in row["model"]: + if ( + row["status"] == "SUCCESS" + and row["performance"] == "14-18|32|44|62" + ): success = True else: pytest.fail("model in perf_test.csv did not run successfully.") diff --git a/tests/test_custom_timeouts.py b/tests/test_custom_timeouts.py index 09ba62ea..79a9ad61 100644 --- a/tests/test_custom_timeouts.py +++ b/tests/test_custom_timeouts.py @@ -2,6 +2,7 @@ Copyright (c) Advanced Micro Devices, Inc. All rights reserved. 
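The GPU- and CPU-count guards above rely on `pytest.mark.skipif`, whose condition is a boolean (or a string expression) evaluated once at collection time; a callable such as a `lambda` is not invoked by pytest and, being truthy, would skip the test unconditionally. A minimal sketch of the intended gating, where `get_num_gpus()` is a hypothetical stand-in for the helper in `tests/fixtures/utils.py`:

```python
# Sketch: hardware-gated test markers. get_num_gpus() here is a hypothetical
# stand-in for the fixture helper, not madengine code.
import shutil

import pytest


def get_num_gpus() -> int:
    """Hypothetical stand-in: report 8 GPUs only if rocm-smi is on PATH."""
    return 8 if shutil.which("rocm-smi") else 0


# skipif takes a boolean (or a string expression) evaluated at collection
# time; wrapping the check in a lambda would make the condition always truthy.
@pytest.mark.skipif(get_num_gpus() < 8, reason="test requires at least 8 gpus")
def test_docker_gpus_sketch():
    assert get_num_gpus() >= 8
```

The same consideration applies to the 64-CPU guard built on `get_num_cpus()`.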
""" + import pytest import os import re @@ -13,19 +14,38 @@ from .fixtures.utils import clean_test_temp_files from .fixtures.utils import is_nvidia + class TestCustomTimeoutsFunctionality: - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html']], indirect=True) + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html"]], indirect=True + ) def test_default_model_timeout_2hrs(self, global_data, clean_test_temp_files): - """ + """ default model timeout is 2 hrs This test only checks if the timeout is set; it does not actually time the model. """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy") - - regexp = re.compile(r'Setting timeout to ([0-9]*) seconds.') + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy" + ) + + regexp = re.compile(r"Setting timeout to ([0-9]*) seconds.") foundTimeout = None - with open( os.path.join(BASE_DIR, "dummy_dummy.ubuntu." + ("amd" if not is_nvidia() else "nvidia") + ".live.log" ), 'r') as f: + with open( + os.path.join( + BASE_DIR, + "dummy_dummy.ubuntu." + + ("amd" if not is_nvidia() else "nvidia") + + ".live.log", + ), + "r", + ) as f: while True: line = f.readline() if not line: @@ -33,20 +53,38 @@ def test_default_model_timeout_2hrs(self, global_data, clean_test_temp_files): match = regexp.search(line) if match: foundTimeout = match.groups()[0] - if foundTimeout != '7200': + if foundTimeout != "7200": pytest.fail("default model timeout is not 2 hrs (" + foundTimeout + "s).") - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html']], indirect=True) + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html"]], indirect=True + ) def test_can_override_timeout_in_model(self, global_data, clean_test_temp_files): """ - timeout can be overridden in model + timeout can be overridden in model This test only checks if the timeout is set; it does not actually time the model. """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_timeout") - - regexp = re.compile(r'Setting timeout to ([0-9]*) seconds.') + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_timeout" + ) + + regexp = re.compile(r"Setting timeout to ([0-9]*) seconds.") foundTimeout = None - with open( os.path.join(BASE_DIR, "dummy_timeout_dummy.ubuntu." + ("amd" if not is_nvidia() else "nvidia") + ".live.log" ), 'r') as f: + with open( + os.path.join( + BASE_DIR, + "dummy_timeout_dummy.ubuntu." + + ("amd" if not is_nvidia() else "nvidia") + + ".live.log", + ), + "r", + ) as f: while True: line = f.readline() if not line: @@ -54,20 +92,44 @@ def test_can_override_timeout_in_model(self, global_data, clean_test_temp_files) match = regexp.search(line) if match: foundTimeout = match.groups()[0] - if foundTimeout != '360': - pytest.fail("timeout in models.json (360s) could not override actual timeout (" + foundTimeout + "s).") - - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html']], indirect=True) - def test_can_override_timeout_in_commandline(self, global_data, clean_test_temp_files): + if foundTimeout != "360": + pytest.fail( + "timeout in models.json (360s) could not override actual timeout (" + + foundTimeout + + "s)." 
+ ) + + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html"]], indirect=True + ) + def test_can_override_timeout_in_commandline( + self, global_data, clean_test_temp_files + ): """ timeout command-line argument overrides default timeout This test only checks if the timeout is set; it does not actually time the model. """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy --timeout 120") - - regexp = re.compile(r'Setting timeout to ([0-9]*) seconds.') + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy --timeout 120" + ) + + regexp = re.compile(r"Setting timeout to ([0-9]*) seconds.") foundTimeout = None - with open( os.path.join(BASE_DIR, "dummy_dummy.ubuntu." + ("amd" if not is_nvidia() else "nvidia") + ".live.log" ), 'r') as f: + with open( + os.path.join( + BASE_DIR, + "dummy_dummy.ubuntu." + + ("amd" if not is_nvidia() else "nvidia") + + ".live.log", + ), + "r", + ) as f: while True: line = f.readline() if not line: @@ -75,20 +137,44 @@ def test_can_override_timeout_in_commandline(self, global_data, clean_test_temp_ match = regexp.search(line) if match: foundTimeout = match.groups()[0] - if foundTimeout != '120': - pytest.fail("timeout command-line argument (120s) could not override actual timeout (" + foundTimeout + "s).") - - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html']], indirect=True) - def test_commandline_timeout_overrides_model_timeout(self, global_data, clean_test_temp_files): + if foundTimeout != "120": + pytest.fail( + "timeout command-line argument (120s) could not override actual timeout (" + + foundTimeout + + "s)." + ) + + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html"]], indirect=True + ) + def test_commandline_timeout_overrides_model_timeout( + self, global_data, clean_test_temp_files + ): """ timeout command-line argument overrides model timeout This test only checks if the timeout is set; it does not actually time the model. """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_timeout --timeout 120") - - regexp = re.compile(r'Setting timeout to ([0-9]*) seconds.') + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_timeout --timeout 120" + ) + + regexp = re.compile(r"Setting timeout to ([0-9]*) seconds.") foundTimeout = None - with open( os.path.join(BASE_DIR, "dummy_timeout_dummy.ubuntu." + ("amd" if not is_nvidia() else "nvidia") + ".live.log" ), 'r') as f: + with open( + os.path.join( + BASE_DIR, + "dummy_timeout_dummy.ubuntu." 
+ + ("amd" if not is_nvidia() else "nvidia") + + ".live.log", + ), + "r", + ) as f: while True: line = f.readline() if not line: @@ -96,31 +182,65 @@ def test_commandline_timeout_overrides_model_timeout(self, global_data, clean_te match = regexp.search(line) if match: foundTimeout = match.groups()[0] - if foundTimeout != '120': - pytest.fail("timeout in command-line argument (360s) could not override model.json timeout (" + foundTimeout + "s).") - - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html', 'run_directory']], indirect=True) - def test_timeout_in_commandline_timesout_correctly(self, global_data, clean_test_temp_files): + if foundTimeout != "120": + pytest.fail( + "timeout in command-line argument (360s) could not override model.json timeout (" + + foundTimeout + + "s)." + ) + + @pytest.mark.parametrize( + "clean_test_temp_files", + [["perf.csv", "perf.html", "run_directory"]], + indirect=True, + ) + def test_timeout_in_commandline_timesout_correctly( + self, global_data, clean_test_temp_files + ): """ timeout command-line argument times model out correctly """ start_time = time.time() - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_sleep --timeout 60", canFail = True, timeout = 180) + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_sleep --timeout 60", + canFail=True, + timeout=180, + ) test_duration = time.time() - start_time assert test_duration == pytest.approx(60, 10) - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html', 'run_directory']], indirect=True) - def test_timeout_in_model_timesout_correctly(self, global_data, clean_test_temp_files): + @pytest.mark.parametrize( + "clean_test_temp_files", + [["perf.csv", "perf.html", "run_directory"]], + indirect=True, + ) + def test_timeout_in_model_timesout_correctly( + self, global_data, clean_test_temp_files + ): """ timeout in models.json times model out correctly """ start_time = time.time() - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_sleep", canFail = True, timeout = 180) + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_sleep", + canFail=True, + timeout=180, + ) test_duration = time.time() - start_time assert test_duration == pytest.approx(120, 20) - - diff --git a/tests/test_data_provider.py b/tests/test_data_provider.py index ba45be5a..34d290a8 100644 --- a/tests/test_data_provider.py +++ b/tests/test_data_provider.py @@ -2,6 +2,7 @@ Copyright (c) Advanced Micro Devices, Inc. All rights reserved. 
""" + # built-in modules import os import sys @@ -9,8 +10,10 @@ import re import json import tempfile + # third-party modules import pytest + # project modules from .fixtures.utils import BASE_DIR, MODEL_DIR from .fixtures.utils import global_data @@ -25,86 +28,121 @@ def test_reorder_data_provider_config(self): Test the reorder_data_provider_config function to ensure it correctly orders data provider types """ # Create a temporary data.json file with shuffled data provider types - with tempfile.NamedTemporaryFile(mode='w+', suffix='.json', delete=False) as temp_file: + with tempfile.NamedTemporaryFile( + mode="w+", suffix=".json", delete=False + ) as temp_file: test_data = { "test_data": { "aws": {"path": "s3://bucket/path"}, "local": {"path": "/local/path"}, "nas": {"path": "/nas/path"}, "custom": {"path": "scripts/custom.sh"}, - "minio": {"path": "minio://bucket/path"} + "minio": {"path": "minio://bucket/path"}, } } json.dump(test_data, temp_file) temp_file_path = temp_file.name - + try: # Create Data object with the test file data_obj = Data(filename=temp_file_path) - + # Check the initial order (should be as defined in the test_data) original_keys = list(data_obj.data_provider_config["test_data"].keys()) - + # Call the reorder function data_obj.reorder_data_provider_config("test_data") - + # Check the order after reordering reordered_keys = list(data_obj.data_provider_config["test_data"].keys()) expected_order = ["custom", "local", "minio", "nas", "aws"] - + # Filter expected_order to only include keys that exist in original_keys expected_filtered = [k for k in expected_order if k in original_keys] - + # Assert that the reordering happened correctly - assert reordered_keys == expected_filtered, f"Expected order {expected_filtered}, got {reordered_keys}" - + assert ( + reordered_keys == expected_filtered + ), f"Expected order {expected_filtered}, got {reordered_keys}" + # Specifically check that custom comes first, if it exists if "custom" in original_keys: - assert reordered_keys[0] == "custom", "Custom should be first in the order" - + assert ( + reordered_keys[0] == "custom" + ), "Custom should be first in the order" + # Check that the order matches the expected priority for i, key in enumerate(reordered_keys): expected_index = expected_order.index(key) - for j, other_key in enumerate(reordered_keys[i+1:], i+1): + for j, other_key in enumerate(reordered_keys[i + 1 :], i + 1): other_expected_index = expected_order.index(other_key) - assert expected_index < other_expected_index, f"{key} should come before {other_key}" - + assert ( + expected_index < other_expected_index + ), f"{key} should come before {other_key}" + finally: # Clean up the temporary file os.unlink(temp_file_path) - - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html']], indirect=True) - def test_local_data_provider_runs_successfully(self, global_data, clean_test_temp_files): + + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html"]], indirect=True + ) + def test_local_data_provider_runs_successfully( + self, global_data, clean_test_temp_files + ): """ - local data provider gets data from local disk + local data provider gets data from local disk """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_data_local ") + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_data_local " + ) success = 
False - with open(os.path.join(BASE_DIR, 'perf.csv'), 'r') as csv_file: + with open(os.path.join(BASE_DIR, "perf.csv"), "r") as csv_file: csv_reader = csv.DictReader(csv_file) for row in csv_reader: - if row['model'] == 'dummy_data_local': - if row['status'] == 'SUCCESS': + if row["model"] == "dummy_data_local": + if row["status"] == "SUCCESS": success = True else: pytest.fail("model in perf_test.csv did not run successfully.") if not success: pytest.fail("local data provider test failed") - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html', 'run_directory']], indirect=True) - def test_model_executes_even_if_data_provider_fails(self, global_data, clean_test_temp_files): + @pytest.mark.parametrize( + "clean_test_temp_files", + [["perf.csv", "perf.html", "run_directory"]], + indirect=True, + ) + def test_model_executes_even_if_data_provider_fails( + self, global_data, clean_test_temp_files + ): """ - model executes even if data provider fails + model executes even if data provider fails """ - output = global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_data_local_fail --additional-context \"{'docker_env_vars':{'MAD_DATAHOME':'/data'} }\" --live-output ", canFail=True) + output = global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_data_local_fail --additional-context \"{'docker_env_vars':{'MAD_DATAHOME':'/data'} }\" --live-output ", + canFail=True, + ) success = False - with open(os.path.join(BASE_DIR, 'perf.csv'), 'r') as csv_file: + with open(os.path.join(BASE_DIR, "perf.csv"), "r") as csv_file: csv_reader = csv.DictReader(csv_file) for row in csv_reader: - if row['model'] == 'dummy_data_local_fail': - if row['status'] == 'FAILURE': + if row["model"] == "dummy_data_local_fail": + if row["status"] == "FAILURE": success = True else: pytest.fail("model in perf_test.csv did not run successfully.") @@ -112,30 +150,43 @@ def test_model_executes_even_if_data_provider_fails(self, global_data, clean_tes pytest.fail("local data provider fail test passed") # Search for "/data is NOT mounted" to ensure model script ran - regexp = re.compile(r'is NOT mounted') + regexp = re.compile(r"is NOT mounted") if not regexp.search(output): pytest.fail("model did not execute after data provider failed") - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html', 'dataLocal']], indirect=True) - def test_local_data_provider_mirrorlocal_does_not_mirror_data(self, global_data, clean_test_temp_files): + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html", "dataLocal"]], indirect=True + ) + def test_local_data_provider_mirrorlocal_does_not_mirror_data( + self, global_data, clean_test_temp_files + ): """ In local data provider, mirrorlocal field in data.json does not mirror data in local disk """ mirrorPath = os.path.join(BASE_DIR, "dataLocal") - os.mkdir( mirrorPath ) - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_data_local --force-mirror-local " + mirrorPath ) + os.mkdir(mirrorPath) + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_data_local --force-mirror-local " + + mirrorPath + ) success = False - with open(os.path.join(BASE_DIR, 'perf.csv'), 'r') as csv_file: + with 
open(os.path.join(BASE_DIR, "perf.csv"), "r") as csv_file: csv_reader = csv.DictReader(csv_file) for row in csv_reader: - if row['model'] == 'dummy_data_local': - if row['status'] == 'SUCCESS': + if row["model"] == "dummy_data_local": + if row["status"] == "SUCCESS": success = True else: pytest.fail("model in perf_test.csv did not run successfully.") if not success: pytest.fail("local data provider test failed") - if os.path.exists( os.path.join(mirrorPath, "dummy_data_local") ): + if os.path.exists(os.path.join(mirrorPath, "dummy_data_local")): pytest.fail("custom data provider did mirror data locally") diff --git a/tests/test_debugging.py b/tests/test_debugging.py index 3eda2ba7..f20435e8 100644 --- a/tests/test_debugging.py +++ b/tests/test_debugging.py @@ -2,6 +2,7 @@ Copyright (c) Advanced Micro Devices, Inc. All rights reserved. """ + import pytest import os import re @@ -15,75 +16,188 @@ class TestDebuggingFunctionality: """""" - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html', 'run_directory']], indirect=True) + @pytest.mark.parametrize( + "clean_test_temp_files", + [["perf.csv", "perf.html", "run_directory"]], + indirect=True, + ) def test_keepAlive_keeps_docker_alive(self, global_data, clean_test_temp_files): - """ - keep-alive command-line argument keeps the docker container alive """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy --keep-alive") - output = global_data['console'].sh("docker ps -aqf 'name=container_dummy_dummy.ubuntu." + ("amd" if not is_nvidia() else "nvidia") + "'") - - if not output: + keep-alive command-line argument keeps the docker container alive + """ + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy --keep-alive" + ) + output = global_data["console"].sh( + "docker ps -aqf 'name=container_dummy_dummy.ubuntu." + + ("amd" if not is_nvidia() else "nvidia") + + "'" + ) + + if not output: pytest.fail("docker container not found after keep-alive argument.") - global_data['console'].sh("docker container stop --time=1 container_dummy_dummy.ubuntu." + ("amd" if not is_nvidia() else "nvidia") ) - global_data['console'].sh("docker container rm -f container_dummy_dummy.ubuntu." + ("amd" if not is_nvidia() else "nvidia") ) - - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html', 'run_directory']], indirect=True) - def test_no_keepAlive_does_not_keep_docker_alive(self, global_data, clean_test_temp_files): - """ + global_data["console"].sh( + "docker container stop --time=1 container_dummy_dummy.ubuntu." + + ("amd" if not is_nvidia() else "nvidia") + ) + global_data["console"].sh( + "docker container rm -f container_dummy_dummy.ubuntu." + + ("amd" if not is_nvidia() else "nvidia") + ) + + @pytest.mark.parametrize( + "clean_test_temp_files", + [["perf.csv", "perf.html", "run_directory"]], + indirect=True, + ) + def test_no_keepAlive_does_not_keep_docker_alive( + self, global_data, clean_test_temp_files + ): + """ without keep-alive command-line argument, the docker container is not kept alive """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy") - output = global_data['console'].sh("docker ps -aqf 'name=container_dummy_dummy.ubuntu." 
+ ("amd" if not is_nvidia() else "nvidia") + "'") - - if output: - global_data['console'].sh("docker container stop --time=1 container_dummy_dummy.ubuntu." + ("amd" if not is_nvidia() else "nvidia") ) - global_data['console'].sh("docker container rm -f container_dummy_dummy.ubuntu." + ("amd" if not is_nvidia() else "nvidia") ) - pytest.fail("docker container found after not specifying keep-alive argument.") - - - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html', 'run_directory']], indirect=True) + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy" + ) + output = global_data["console"].sh( + "docker ps -aqf 'name=container_dummy_dummy.ubuntu." + + ("amd" if not is_nvidia() else "nvidia") + + "'" + ) + + if output: + global_data["console"].sh( + "docker container stop --time=1 container_dummy_dummy.ubuntu." + + ("amd" if not is_nvidia() else "nvidia") + ) + global_data["console"].sh( + "docker container rm -f container_dummy_dummy.ubuntu." + + ("amd" if not is_nvidia() else "nvidia") + ) + pytest.fail( + "docker container found after not specifying keep-alive argument." + ) + + @pytest.mark.parametrize( + "clean_test_temp_files", + [["perf.csv", "perf.html", "run_directory"]], + indirect=True, + ) def test_keepAlive_preserves_model_dir(self, global_data, clean_test_temp_files): """ keep-alive command-line argument will keep model directory after run """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy --keep-alive") - - global_data['console'].sh("docker container stop --time=1 container_dummy_dummy.ubuntu." + ("amd" if not is_nvidia() else "nvidia") ) - global_data['console'].sh("docker container rm -f container_dummy_dummy.ubuntu." + ("amd" if not is_nvidia() else "nvidia") ) - if not os.path.exists( os.path.join(BASE_DIR, "run_directory")): + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy --keep-alive" + ) + + global_data["console"].sh( + "docker container stop --time=1 container_dummy_dummy.ubuntu." + + ("amd" if not is_nvidia() else "nvidia") + ) + global_data["console"].sh( + "docker container rm -f container_dummy_dummy.ubuntu." 
+ + ("amd" if not is_nvidia() else "nvidia") + ) + if not os.path.exists(os.path.join(BASE_DIR, "run_directory")): pytest.fail("model directory not left over after keep-alive argument.") - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html', 'run_directory']], indirect=True) + @pytest.mark.parametrize( + "clean_test_temp_files", + [["perf.csv", "perf.html", "run_directory"]], + indirect=True, + ) def test_keepModelDir_keeps_model_dir(self, global_data, clean_test_temp_files): """ keep-model-dir command-line argument keeps model directory after run """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy --keep-model-dir") - - if not os.path.exists( os.path.join(BASE_DIR, "run_directory")): + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy --keep-model-dir" + ) + + if not os.path.exists(os.path.join(BASE_DIR, "run_directory")): pytest.fail("model directory not left over after keep-model-dir argument.") - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html', 'run_directory']], indirect=True) - def test_no_keepModelDir_does_not_keep_model_dir(self, global_data, clean_test_temp_files): + @pytest.mark.parametrize( + "clean_test_temp_files", + [["perf.csv", "perf.html", "run_directory"]], + indirect=True, + ) + def test_no_keepModelDir_does_not_keep_model_dir( + self, global_data, clean_test_temp_files + ): """ keep-model-dir command-line argument keeps model directory after run """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy") - - if os.path.exists( os.path.join(BASE_DIR, "run_directory")): - pytest.fail("model directory left over after not specifying keep-model-dir (or keep-alive) argument.") - - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html', 'run_directory']], indirect=True) + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy" + ) + + if os.path.exists(os.path.join(BASE_DIR, "run_directory")): + pytest.fail( + "model directory left over after not specifying keep-model-dir (or keep-alive) argument." + ) + + @pytest.mark.parametrize( + "clean_test_temp_files", + [["perf.csv", "perf.html", "run_directory"]], + indirect=True, + ) def test_skipModelRun_does_not_run_model(self, global_data, clean_test_temp_files): """ - skip-model-run command-line argument does not run model + skip-model-run command-line argument does not run model """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy --skip-model-run") - - regexp = re.compile(r'performance: [0-9]* samples_per_second') - with open( os.path.join(BASE_DIR, "dummy_dummy.ubuntu." + ("amd" if not is_nvidia() else "nvidia") + ".live.log" ), 'r') as f: + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy --skip-model-run" + ) + + regexp = re.compile(r"performance: [0-9]* samples_per_second") + with open( + os.path.join( + BASE_DIR, + "dummy_dummy.ubuntu." 
+ + ("amd" if not is_nvidia() else "nvidia") + + ".live.log", + ), + "r", + ) as f: while True: line = f.readline() if not line: diff --git a/tests/test_discover.py b/tests/test_discover.py index d0643985..617a506e 100644 --- a/tests/test_discover.py +++ b/tests/test_discover.py @@ -27,7 +27,15 @@ def test_static(self, global_data, clean_test_temp_files): """ test a tag from a models.json file """ - global_data["console"].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy2/model2 ") + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy2/model2 " + ) success = False with open(os.path.join(BASE_DIR, "perf.csv"), "r") as csv_file: @@ -45,7 +53,15 @@ def test_dynamic(self, global_data, clean_test_temp_files): """ test a tag from a get_models_json.py file """ - global_data["console"].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy3/model4 ") + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy3/model4 " + ) success = False with open(os.path.join(BASE_DIR, "perf.csv"), "r") as csv_file: @@ -63,13 +79,25 @@ def test_additional_args(self, global_data, clean_test_temp_files): """ passes additional args specified in the command line to the model """ - global_data["console"].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy2/model2:batch-size=32 ") + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy2/model2:batch-size=32 " + ) success = False with open(os.path.join(BASE_DIR, "perf.csv"), "r") as csv_file: csv_reader = csv.DictReader(csv_file) for row in csv_reader: - if row["model"] == "dummy2/model2" and row["status"] == "SUCCESS" and "--batch-size 32" in row["args"]: + if ( + row["model"] == "dummy2/model2" + and row["status"] == "SUCCESS" + and "--batch-size 32" in row["args"] + ): success = True if not success: pytest.fail("dummy2/model2:batch-size=32 did not run successfully.") @@ -81,7 +109,15 @@ def test_multiple(self, global_data, clean_test_temp_files): """ test multiple tags from top-level models.json, models.json in a script subdir, and get_models_json.py """ - global_data["console"].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_test_group_1 dummy_test_group_2 dummy_test_group_3 ") + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_test_group_1 dummy_test_group_2 dummy_test_group_3 " + ) success = False with open(os.path.join(BASE_DIR, "perf.csv"), "r") as csv_file: @@ -103,4 +139,4 @@ def test_multiple(self, global_data, clean_test_temp_files): ]: success = True if not success: - pytest.fail("multiple tags did not run successfully.") \ No newline at end of file + pytest.fail("multiple tags did not run successfully.") diff --git a/tests/test_distributed_orchestrator.py b/tests/test_distributed_orchestrator.py new file mode 100644 index 00000000..acb2e687 --- /dev/null +++ b/tests/test_distributed_orchestrator.py @@ -0,0 +1,316 @@ +"""Test the distributed orchestrator module. + +This module tests the distributed orchestrator functionality. 
+ +Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +""" + +# built-in modules +import os +import json +import tempfile +import unittest.mock +from unittest.mock import patch, MagicMock, mock_open + +# third-party modules +import pytest + +# project modules +from madengine.tools.distributed_orchestrator import DistributedOrchestrator +from madengine.core.context import Context +from madengine.core.console import Console +from .fixtures.utils import BASE_DIR, MODEL_DIR + + +class TestDistributedOrchestrator: + """Test the distributed orchestrator module.""" + + @patch("madengine.tools.distributed_orchestrator.Context") + def test_orchestrator_initialization(self, mock_context): + """Test orchestrator initialization with minimal args.""" + mock_args = MagicMock() + mock_args.additional_context = None + mock_args.additional_context_file = None + mock_args.data_config_file_name = "data.json" + mock_args.force_mirror_local = False + mock_args.live_output = True + + # Mock context instance + mock_context_instance = MagicMock() + mock_context.return_value = mock_context_instance + + with patch("os.path.exists", return_value=False): + orchestrator = DistributedOrchestrator(mock_args) + + assert orchestrator.args == mock_args + assert isinstance(orchestrator.console, Console) + assert orchestrator.context == mock_context_instance + assert orchestrator.data is None + assert orchestrator.credentials is None + + @patch( + "builtins.open", + new_callable=mock_open, + read_data='{"registry": "test", "token": "abc123"}', + ) + @patch("os.path.exists") + @patch("madengine.tools.distributed_orchestrator.Context") + def test_orchestrator_with_credentials(self, mock_context, mock_exists, mock_file): + """Test orchestrator initialization with credentials.""" + mock_args = MagicMock() + mock_args.additional_context = None + mock_args.additional_context_file = None + mock_args.data_config_file_name = "data.json" + mock_args.force_mirror_local = False + mock_args.live_output = True + + # Mock context instance + mock_context_instance = MagicMock() + mock_context.return_value = mock_context_instance + + # Mock credential.json exists + def exists_side_effect(path): + return path == "credential.json" + + mock_exists.side_effect = exists_side_effect + + orchestrator = DistributedOrchestrator(mock_args) + + assert orchestrator.credentials == {"registry": "test", "token": "abc123"} + + @patch("madengine.tools.distributed_orchestrator.DiscoverModels") + @patch("madengine.tools.distributed_orchestrator.DockerBuilder") + @patch("madengine.tools.distributed_orchestrator.Context") + def test_build_phase( + self, mock_context_class, mock_docker_builder, mock_discover_models + ): + """Test the build phase functionality.""" + # Setup mocks + mock_args = MagicMock() + mock_args.additional_context = None + mock_args.additional_context_file = None + mock_args.data_config_file_name = "data.json" + mock_args.force_mirror_local = False + mock_args.live_output = True + + # Mock context + mock_context = MagicMock() + mock_context_class.return_value = mock_context + + # Mock discover models + mock_discover_instance = MagicMock() + mock_discover_models.return_value = mock_discover_instance + mock_discover_instance.run.return_value = [ + {"name": "model1", "dockerfile": "Dockerfile1"}, + {"name": "model2", "dockerfile": "Dockerfile2"}, + ] + + # Mock docker builder + mock_builder_instance = MagicMock() + mock_docker_builder.return_value = mock_builder_instance + mock_builder_instance.build_all_models.return_value = { + 
"successful_builds": ["model1", "model2"], + "failed_builds": [], + "total_build_time": 120.5, + } + + with patch("os.path.exists", return_value=False): + orchestrator = DistributedOrchestrator(mock_args) + + with patch.object(orchestrator, "_copy_scripts"): + result = orchestrator.build_phase( + registry="localhost:5000", + clean_cache=True, + manifest_output="test_manifest.json", + ) + + # Verify the flow + mock_discover_models.assert_called_once_with(args=mock_args) + mock_discover_instance.run.assert_called_once() + mock_docker_builder.assert_called_once() + mock_builder_instance.build_all_models.assert_called_once() + mock_builder_instance.export_build_manifest.assert_called_once_with( + "test_manifest.json", "localhost:5000", unittest.mock.ANY + ) + + assert result["successful_builds"] == ["model1", "model2"] + assert result["failed_builds"] == [] + + @patch("madengine.tools.distributed_orchestrator.ContainerRunner") + @patch("madengine.tools.distributed_orchestrator.DiscoverModels") + @patch("madengine.tools.distributed_orchestrator.Context") + def test_run_phase(self, mock_context, mock_discover_models, mock_container_runner): + """Test the run phase functionality.""" + mock_args = MagicMock() + mock_args.additional_context = None + mock_args.additional_context_file = None + mock_args.data_config_file_name = "data.json" + mock_args.force_mirror_local = False + mock_args.live_output = True + + # Mock context instance + mock_context_instance = MagicMock() + mock_context.return_value = mock_context_instance + + # Mock discover models + mock_discover_instance = MagicMock() + mock_discover_models.return_value = mock_discover_instance + mock_discover_instance.run.return_value = [ + { + "name": "dummy", + "dockerfile": "docker/dummy", + "scripts": "scripts/dummy/run.sh", + } + ] + + # Mock container runner + mock_runner_instance = MagicMock() + mock_container_runner.return_value = mock_runner_instance + mock_runner_instance.load_build_manifest.return_value = { + "images": {"dummy": "localhost:5000/dummy:latest"} + } + mock_runner_instance.run_container.return_value = { + "status": "completed", + "test_duration": 120.5, + "model": "dummy", + "exit_code": 0, + } + mock_runner_instance.run_all_containers.return_value = { + "successful_runs": ["dummy"], + "failed_runs": [], + } + + with patch("os.path.exists", return_value=False): + orchestrator = DistributedOrchestrator(mock_args) + + # Mock manifest file existence and content + manifest_content = '{"built_images": {"dummy": {"image": "localhost:5000/dummy:latest", "build_time": 120}}}' + + with patch.object(orchestrator, "_copy_scripts"), patch( + "os.path.exists" + ) as mock_exists, patch("builtins.open", mock_open(read_data=manifest_content)): + + # Mock manifest file exists but credential.json doesn't + def exists_side_effect(path): + return path == "manifest.json" + + mock_exists.side_effect = exists_side_effect + + result = orchestrator.run_phase( + manifest_file="manifest.json", + registry="localhost:5000", + timeout=1800, + keep_alive=False, + ) + + # Verify the flow + mock_discover_models.assert_called_once_with(args=mock_args) + mock_discover_instance.run.assert_called_once() + mock_container_runner.assert_called_once() + + assert "successful_runs" in result + assert "failed_runs" in result + + @patch("madengine.tools.distributed_orchestrator.DiscoverModels") + @patch("madengine.tools.distributed_orchestrator.DockerBuilder") + @patch("madengine.tools.distributed_orchestrator.ContainerRunner") + 
@patch("madengine.tools.distributed_orchestrator.Context") + def test_full_workflow( + self, + mock_context_class, + mock_container_runner, + mock_docker_builder, + mock_discover_models, + ): + """Test the full workflow functionality.""" + mock_args = MagicMock() + mock_args.additional_context = None + mock_args.additional_context_file = None + mock_args.data_config_file_name = "data.json" + mock_args.force_mirror_local = False + mock_args.live_output = True + + # Mock context + mock_context = MagicMock() + mock_context_class.return_value = mock_context + + # Mock discover models + mock_discover_instance = MagicMock() + mock_discover_models.return_value = mock_discover_instance + mock_discover_instance.run.return_value = [{"name": "model1"}] + + # Mock docker builder + mock_builder_instance = MagicMock() + mock_docker_builder.return_value = mock_builder_instance + mock_builder_instance.build_all_models.return_value = { + "successful_builds": ["model1"], + "failed_builds": [], + "total_build_time": 120.5, + } + mock_builder_instance.get_build_manifest.return_value = { + "images": {"model1": "ci-model1:latest"} + } + + # Mock container runner + mock_runner_instance = MagicMock() + mock_container_runner.return_value = mock_runner_instance + mock_runner_instance.run_container.return_value = { + "status": "SUCCESS", + "test_duration": 120.5, + "model": "model1", + "exit_code": 0, + } + mock_runner_instance.run_all_containers.return_value = { + "successful_runs": ["model1"], + "failed_runs": [], + } + + with patch("os.path.exists", return_value=False): + orchestrator = DistributedOrchestrator(mock_args) + + # Mock manifest file content for run phase + manifest_content = """{"built_images": {"model1": {"docker_image": "ci-model1", "build_time": 120}}, "built_models": {"model1": {"name": "model1", "scripts": "scripts/model1/run.sh"}}}""" + + with patch.object(orchestrator, "_copy_scripts"), patch( + "os.path.exists" + ) as mock_exists, patch("builtins.open", mock_open(read_data=manifest_content)): + + # Mock build_manifest.json exists for run phase + def exists_side_effect(path): + return path == "build_manifest.json" + + mock_exists.side_effect = exists_side_effect + + result = orchestrator.full_workflow( + registry="localhost:5000", + clean_cache=True, + timeout=3600, + keep_alive=False, + ) + + # Verify the complete flow + assert result["overall_success"] is True + assert "build_phase" in result + assert "run_phase" in result + + @patch("madengine.tools.distributed_orchestrator.Context") + def test_copy_scripts_method(self, mock_context): + """Test the _copy_scripts method.""" + mock_args = MagicMock() + mock_args.additional_context = None + mock_args.additional_context_file = None + mock_args.data_config_file_name = "data.json" + mock_args.force_mirror_local = False + mock_args.live_output = True + + # Mock context instance + mock_context_instance = MagicMock() + mock_context.return_value = mock_context_instance + + with patch("os.path.exists", return_value=False): + orchestrator = DistributedOrchestrator(mock_args) + + with patch.object(orchestrator.console, "sh") as mock_sh: + with patch("os.path.exists", return_value=True): + orchestrator._copy_scripts() + mock_sh.assert_called_once() diff --git a/tests/test_docker_builder.py b/tests/test_docker_builder.py new file mode 100644 index 00000000..8b1338eb --- /dev/null +++ b/tests/test_docker_builder.py @@ -0,0 +1,821 @@ +"""Test the Docker builder module. + +This module tests the Docker image building functionality for distributed execution. 
+ +Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +""" + +# built-in modules +import os +import json +import tempfile +import unittest.mock +from unittest.mock import patch, MagicMock, mock_open + +# third-party modules +import pytest + +# project modules +from madengine.tools.docker_builder import DockerBuilder +from madengine.core.context import Context +from madengine.core.console import Console +from .fixtures.utils import BASE_DIR, MODEL_DIR + + +class TestDockerBuilder: + """Test the Docker builder module.""" + + @patch.object(Context, "get_gpu_vendor", return_value="AMD") + @patch.object(Context, "get_system_ngpus", return_value=1) + @patch.object(Context, "get_system_gpu_architecture", return_value="gfx908") + @patch.object(Context, "get_system_hip_version", return_value="5.4") + @patch.object(Context, "get_docker_gpus", return_value="all") + @patch.object(Context, "get_gpu_renderD_nodes", return_value=["renderD128"]) + def test_docker_builder_initialization( + self, mock_render, mock_docker_gpu, mock_hip, mock_arch, mock_ngpus, mock_vendor + ): + """Test DockerBuilder initialization.""" + context = Context() + console = Console() + + builder = DockerBuilder(context, console) + + assert builder.context == context + assert builder.console == console + assert builder.built_images == {} + + @patch.object(Context, "get_gpu_vendor", return_value="AMD") + @patch.object(Context, "get_system_ngpus", return_value=1) + @patch.object(Context, "get_system_gpu_architecture", return_value="gfx908") + @patch.object(Context, "get_system_hip_version", return_value="5.4") + @patch.object(Context, "get_docker_gpus", return_value="all") + @patch.object(Context, "get_gpu_renderD_nodes", return_value=["renderD128"]) + def test_docker_builder_initialization_without_console( + self, mock_render, mock_docker_gpu, mock_hip, mock_arch, mock_ngpus, mock_vendor + ): + """Test DockerBuilder initialization without console.""" + context = Context() + + builder = DockerBuilder(context) + + assert builder.context == context + assert isinstance(builder.console, Console) + assert builder.built_images == {} + + @patch.object(Context, "get_gpu_vendor", return_value="AMD") + @patch.object(Context, "get_system_ngpus", return_value=1) + @patch.object(Context, "get_system_gpu_architecture", return_value="gfx908") + @patch.object(Context, "get_system_hip_version", return_value="5.4") + @patch.object(Context, "get_docker_gpus", return_value="all") + @patch.object(Context, "get_gpu_renderD_nodes", return_value=["renderD128"]) + def test_get_context_path_with_dockercontext( + self, mock_render, mock_docker_gpu, mock_hip, mock_arch, mock_ngpus, mock_vendor + ): + """Test get_context_path when dockercontext is specified.""" + context = Context() + builder = DockerBuilder(context) + + info = {"dockercontext": "/custom/context"} + result = builder.get_context_path(info) + + assert result == "/custom/context" + + @patch.object(Context, "get_gpu_vendor", return_value="AMD") + @patch.object(Context, "get_system_ngpus", return_value=1) + @patch.object(Context, "get_system_gpu_architecture", return_value="gfx908") + @patch.object(Context, "get_system_hip_version", return_value="5.4") + @patch.object(Context, "get_docker_gpus", return_value="all") + @patch.object(Context, "get_gpu_renderD_nodes", return_value=["renderD128"]) + def test_get_context_path_without_dockercontext( + self, mock_render, mock_docker_gpu, mock_hip, mock_arch, mock_ngpus, mock_vendor + ): + """Test get_context_path when dockercontext is not 
specified.""" + context = Context() + builder = DockerBuilder(context) + + info = {} + result = builder.get_context_path(info) + + assert result == "./docker" + + @patch.object(Context, "get_gpu_vendor", return_value="AMD") + @patch.object(Context, "get_system_ngpus", return_value=1) + @patch.object(Context, "get_system_gpu_architecture", return_value="gfx908") + @patch.object(Context, "get_system_hip_version", return_value="5.4") + @patch.object(Context, "get_docker_gpus", return_value="all") + @patch.object(Context, "get_gpu_renderD_nodes", return_value=["renderD128"]) + def test_get_context_path_with_empty_dockercontext( + self, mock_render, mock_docker_gpu, mock_hip, mock_arch, mock_ngpus, mock_vendor + ): + """Test get_context_path when dockercontext is empty.""" + context = Context() + builder = DockerBuilder(context) + + info = {"dockercontext": ""} + result = builder.get_context_path(info) + + assert result == "./docker" + + @patch.object(Context, "get_gpu_vendor", return_value="AMD") + @patch.object(Context, "get_system_ngpus", return_value=1) + @patch.object(Context, "get_system_gpu_architecture", return_value="gfx908") + @patch.object(Context, "get_system_hip_version", return_value="5.4") + @patch.object(Context, "get_docker_gpus", return_value="all") + @patch.object(Context, "get_gpu_renderD_nodes", return_value=["renderD128"]) + def test_get_build_arg_no_args( + self, mock_render, mock_docker_gpu, mock_hip, mock_arch, mock_ngpus, mock_vendor + ): + """Test get_build_arg with no additional runtime build arguments.""" + context = Context() + builder = DockerBuilder(context) + + result = builder.get_build_arg() + + # Context automatically includes system GPU architecture + assert "MAD_SYSTEM_GPU_ARCHITECTURE" in result + assert "--build-arg" in result + + @patch.object(Context, "get_gpu_vendor", return_value="AMD") + @patch.object(Context, "get_system_ngpus", return_value=1) + @patch.object(Context, "get_system_gpu_architecture", return_value="gfx908") + @patch.object(Context, "get_system_hip_version", return_value="5.4") + @patch.object(Context, "get_docker_gpus", return_value="all") + @patch.object(Context, "get_gpu_renderD_nodes", return_value=["renderD128"]) + def test_get_build_arg_with_context_args( + self, mock_render, mock_docker_gpu, mock_hip, mock_arch, mock_ngpus, mock_vendor + ): + """Test get_build_arg with context build arguments.""" + context = Context() + context.ctx = {"docker_build_arg": {"ARG1": "value1", "ARG2": "value2"}} + builder = DockerBuilder(context) + + result = builder.get_build_arg() + + assert "--build-arg ARG1='value1'" in result + assert "--build-arg ARG2='value2'" in result + + @patch.object(Context, "get_gpu_vendor", return_value="AMD") + @patch.object(Context, "get_system_ngpus", return_value=1) + @patch.object(Context, "get_system_gpu_architecture", return_value="gfx908") + @patch.object(Context, "get_system_hip_version", return_value="5.4") + @patch.object(Context, "get_docker_gpus", return_value="all") + @patch.object(Context, "get_gpu_renderD_nodes", return_value=["renderD128"]) + def test_get_build_arg_with_run_args( + self, mock_render, mock_docker_gpu, mock_hip, mock_arch, mock_ngpus, mock_vendor + ): + """Test get_build_arg with runtime build arguments.""" + context = Context() + builder = DockerBuilder(context) + + run_build_arg = {"RUNTIME_ARG": "runtime_value"} + result = builder.get_build_arg(run_build_arg) + + assert "--build-arg RUNTIME_ARG='runtime_value'" in result + + @patch.object(Context, "get_gpu_vendor", return_value="AMD") 
+ @patch.object(Context, "get_system_ngpus", return_value=1) + @patch.object(Context, "get_system_gpu_architecture", return_value="gfx908") + @patch.object(Context, "get_system_hip_version", return_value="5.4") + @patch.object(Context, "get_docker_gpus", return_value="all") + @patch.object(Context, "get_gpu_renderD_nodes", return_value=["renderD128"]) + def test_get_build_arg_with_both_args( + self, mock_render, mock_docker_gpu, mock_hip, mock_arch, mock_ngpus, mock_vendor + ): + """Test get_build_arg with both context and runtime arguments.""" + context = Context() + context.ctx = {"docker_build_arg": {"CONTEXT_ARG": "context_value"}} + builder = DockerBuilder(context) + + run_build_arg = {"RUNTIME_ARG": "runtime_value"} + result = builder.get_build_arg(run_build_arg) + + assert "--build-arg CONTEXT_ARG='context_value'" in result + assert "--build-arg RUNTIME_ARG='runtime_value'" in result + + @patch.object(Context, "get_gpu_vendor", return_value="AMD") + @patch.object(Context, "get_system_ngpus", return_value=1) + @patch.object(Context, "get_system_gpu_architecture", return_value="gfx908") + @patch.object(Context, "get_system_hip_version", return_value="5.4") + @patch.object(Context, "get_docker_gpus", return_value="all") + @patch.object(Context, "get_gpu_renderD_nodes", return_value=["renderD128"]) + @patch.object(Console, "sh") + def test_build_image_success( + self, + mock_sh, + mock_render, + mock_docker_gpu, + mock_hip, + mock_arch, + mock_ngpus, + mock_vendor, + ): + """Test successful Docker image build.""" + context = Context() + console = Console() + builder = DockerBuilder(context, console) + + # Mock the console.sh calls + mock_sh.return_value = "Build successful" + + model_info = {"name": "test/model", "dockercontext": "./docker"} + dockerfile = "./docker/Dockerfile" + + with patch.object(builder, "get_build_arg", return_value=""): + result = builder.build_image(model_info, dockerfile) + + # Verify the image name generation + expected_image_name = "ci-test_model_Dockerfile" + assert result["docker_image"] == expected_image_name + assert "build_duration" in result + + @patch.object(Context, "get_gpu_vendor", return_value="AMD") + @patch.object(Context, "get_system_ngpus", return_value=1) + @patch.object(Context, "get_system_gpu_architecture", return_value="gfx908") + @patch.object(Context, "get_system_hip_version", return_value="5.4") + @patch.object(Context, "get_docker_gpus", return_value="all") + @patch.object(Context, "get_gpu_renderD_nodes", return_value=["renderD128"]) + @patch.object(Console, "sh") + def test_build_image_with_registry_push( + self, + mock_sh, + mock_render, + mock_docker_gpu, + mock_hip, + mock_arch, + mock_ngpus, + mock_vendor, + ): + """Test Docker image build with registry push.""" + context = Context() + console = Console() + builder = DockerBuilder(context, console) + + # Mock successful build and push + mock_sh.return_value = "Success" + + model_info = {"name": "test_model"} + dockerfile = "./docker/Dockerfile" + registry = "localhost:5000" + + with patch.object(builder, "get_build_arg", return_value=""): + with patch.object(builder, "get_context_path", return_value="./docker"): + with patch.object( + builder, "push_image", return_value="localhost:5000/ci-test_model" + ) as mock_push: + result = builder.build_image(model_info, dockerfile) + registry_image = builder.push_image( + result["docker_image"], registry + ) + + # Should have called docker build + build_calls = [ + call for call in mock_sh.call_args_list if "docker build" in str(call) + ] + 
assert len(build_calls) >= 1 + assert registry_image == "localhost:5000/ci-test_model" + + @patch.object(Context, "get_gpu_vendor", return_value="AMD") + @patch.object(Context, "get_system_ngpus", return_value=1) + @patch.object(Context, "get_system_gpu_architecture", return_value="gfx908") + @patch.object(Context, "get_system_hip_version", return_value="5.4") + @patch.object(Context, "get_docker_gpus", return_value="all") + @patch.object(Context, "get_gpu_renderD_nodes", return_value=["renderD128"]) + @patch.object(Console, "sh") + def test_build_image_failure( + self, + mock_sh, + mock_render, + mock_docker_gpu, + mock_hip, + mock_arch, + mock_ngpus, + mock_vendor, + ): + """Test Docker image build failure.""" + context = Context() + console = Console() + builder = DockerBuilder(context, console) + + # Mock build failure + mock_sh.side_effect = RuntimeError("Build failed") + + model_info = {"name": "test_model"} + dockerfile = "./docker/Dockerfile" + + with patch.object(builder, "get_build_arg", return_value=""): + with patch.object(builder, "get_context_path", return_value="./docker"): + # Test that the exception is raised + with pytest.raises(RuntimeError, match="Build failed"): + builder.build_image(model_info, dockerfile) + + @patch.object(Context, "get_gpu_vendor", return_value="AMD") + @patch.object(Context, "get_system_ngpus", return_value=1) + @patch.object(Context, "get_system_gpu_architecture", return_value="gfx908") + @patch.object(Context, "get_system_hip_version", return_value="5.4") + @patch.object(Context, "get_docker_gpus", return_value="all") + @patch.object(Context, "get_gpu_renderD_nodes", return_value=["renderD128"]) + def test_build_all_models( + self, mock_render, mock_docker_gpu, mock_hip, mock_arch, mock_ngpus, mock_vendor + ): + """Test building all models.""" + context = Context() + builder = DockerBuilder(context) + + models = [ + {"name": "model1", "dockerfile": "./docker/Dockerfile1"}, + {"name": "model2", "dockerfile": "./docker/Dockerfile2"}, + ] + + # Mock console.sh calls for dockerfile listing + def mock_sh_side_effect(command, **kwargs): + if "ls ./docker/Dockerfile1.*" in command: + return "./docker/Dockerfile1" + elif "ls ./docker/Dockerfile2.*" in command: + return "./docker/Dockerfile2" + elif "head -n5" in command: + return "# CONTEXT AMD" + else: + return "success" + + # Mock context filter to return only the specific dockerfile for each model + def mock_filter_side_effect(dockerfiles): + # Return only the dockerfile that was requested for each model + if "./docker/Dockerfile1" in dockerfiles: + return {"./docker/Dockerfile1": "AMD"} + elif "./docker/Dockerfile2" in dockerfiles: + return {"./docker/Dockerfile2": "AMD"} + return dockerfiles + + # Mock successful builds + with patch.object(builder.console, "sh", side_effect=mock_sh_side_effect): + with patch.object(context, "filter", side_effect=mock_filter_side_effect): + with patch.object(builder, "build_image") as mock_build: + mock_build.return_value = { + "docker_image": "test_image", + "build_duration": 30.0, + } + + result = builder.build_all_models(models) + + assert len(result["successful_builds"]) == 2 + assert len(result["failed_builds"]) == 0 + assert mock_build.call_count == 2 + + @patch.object(Context, "get_gpu_vendor", return_value="AMD") + @patch.object(Context, "get_system_ngpus", return_value=1) + @patch.object(Context, "get_system_gpu_architecture", return_value="gfx908") + @patch.object(Context, "get_system_hip_version", return_value="5.4") + @patch.object(Context, 
"get_docker_gpus", return_value="all") + @patch.object(Context, "get_gpu_renderD_nodes", return_value=["renderD128"]) + def test_build_all_models_with_failures( + self, mock_render, mock_docker_gpu, mock_hip, mock_arch, mock_ngpus, mock_vendor + ): + """Test building all models with some failures.""" + context = Context() + builder = DockerBuilder(context) + + models = [ + {"name": "model1", "dockerfile": "./docker/Dockerfile1"}, + {"name": "model2", "dockerfile": "./docker/Dockerfile2"}, + ] + + # Mock console.sh calls for dockerfile listing + def mock_sh_side_effect(command, **kwargs): + if "ls ./docker/Dockerfile1.*" in command: + return "./docker/Dockerfile1" + elif "ls ./docker/Dockerfile2.*" in command: + return "./docker/Dockerfile2" + elif "head -n5" in command: + return "# CONTEXT AMD" + else: + return "success" + + # Mock context filter to return only the specific dockerfile for each model + def mock_filter_side_effect(dockerfiles): + # Return only the dockerfile that was requested for each model + if "./docker/Dockerfile1" in dockerfiles: + return {"./docker/Dockerfile1": "AMD"} + elif "./docker/Dockerfile2" in dockerfiles: + return {"./docker/Dockerfile2": "AMD"} + return dockerfiles + + # Mock one success, one failure + def mock_build_side_effect(model_info, dockerfile, *args, **kwargs): + if model_info["name"] == "model1" and "Dockerfile1" in dockerfile: + return {"docker_image": "model1_image", "build_duration": 30.0} + else: + raise RuntimeError("Build failed") + + with patch.object(builder.console, "sh", side_effect=mock_sh_side_effect): + with patch.object(context, "filter", side_effect=mock_filter_side_effect): + with patch.object( + builder, "build_image", side_effect=mock_build_side_effect + ): + result = builder.build_all_models(models) + + assert len(result["successful_builds"]) == 1 + assert len(result["failed_builds"]) == 1 # 1 failure: model2/Dockerfile2 + + @patch.object(Context, "get_gpu_vendor", return_value="AMD") + @patch.object(Context, "get_system_ngpus", return_value=1) + @patch.object(Context, "get_system_gpu_architecture", return_value="gfx908") + @patch.object(Context, "get_system_hip_version", return_value="5.4") + @patch.object(Context, "get_docker_gpus", return_value="all") + @patch.object(Context, "get_gpu_renderD_nodes", return_value=["renderD128"]) + def test_export_build_manifest( + self, mock_render, mock_docker_gpu, mock_hip, mock_arch, mock_ngpus, mock_vendor + ): + """Test exporting build manifest.""" + context = Context() + builder = DockerBuilder(context) + + # Set up some built images (key should match real DockerBuilder output) + builder.built_images = { + "ci-model1": {"docker_image": "ci-model1", "dockerfile": "./docker/Dockerfile"} + } + + with patch("builtins.open", mock_open()) as mock_file: + with patch("json.dump") as mock_json_dump: + builder.export_build_manifest("manifest.json") + + # Verify file was opened and JSON was written + mock_file.assert_called_once_with("manifest.json", "w") + mock_json_dump.assert_called_once() + + @patch.object(Context, "get_gpu_vendor", return_value="AMD") + @patch.object(Context, "get_system_ngpus", return_value=1) + @patch.object(Context, "get_system_gpu_architecture", return_value="gfx908") + @patch.object(Context, "get_system_hip_version", return_value="5.4") + @patch.object(Context, "get_docker_gpus", return_value="all") + @patch.object(Context, "get_gpu_renderD_nodes", return_value=["renderD128"]) + @patch.object(Console, "sh") + def test_build_image_with_credentials( + self, + mock_sh, + 
mock_render, + mock_docker_gpu, + mock_hip, + mock_arch, + mock_ngpus, + mock_vendor, + ): + """Test Docker image build with credentials.""" + context = Context() + builder = DockerBuilder(context) + + mock_sh.return_value = "Success" + + model_info = {"name": "test_model", "cred": "testcred"} + dockerfile = "./docker/Dockerfile" + credentials = {"testcred": {"username": "testuser", "password": "testpass"}} + + with patch.object(builder, "get_build_arg") as mock_get_build_arg: + with patch.object(builder, "get_context_path", return_value="./docker"): + result = builder.build_image( + model_info, dockerfile, credentials=credentials + ) + + # Verify credentials were passed to build args + mock_get_build_arg.assert_called_once() + call_args = mock_get_build_arg.call_args[0][0] + assert "testcred_USERNAME" in call_args + assert "testcred_PASSWORD" in call_args + + @patch.object(Context, "get_gpu_vendor", return_value="AMD") + @patch.object(Context, "get_system_ngpus", return_value=1) + @patch.object(Context, "get_system_gpu_architecture", return_value="gfx908") + @patch.object(Context, "get_system_hip_version", return_value="5.4") + @patch.object(Context, "get_docker_gpus", return_value="all") + @patch.object(Context, "get_gpu_renderD_nodes", return_value=["renderD128"]) + def test_clean_cache_option( + self, mock_render, mock_docker_gpu, mock_hip, mock_arch, mock_ngpus, mock_vendor + ): + """Test clean cache option in build.""" + context = Context() + builder = DockerBuilder(context) + + model_info = {"name": "test_model"} + dockerfile = "./docker/Dockerfile" + + with patch.object(builder.console, "sh") as mock_sh: + with patch.object(builder, "get_build_arg", return_value=""): + with patch.object(builder, "get_context_path", return_value="./docker"): + builder.build_image(model_info, dockerfile, clean_cache=True) + + # Verify --no-cache was used + build_calls = [ + call for call in mock_sh.call_args_list if "docker build" in str(call) + ] + assert any("--no-cache" in str(call) for call in build_calls) + + @patch.object(Context, "get_gpu_vendor", return_value="AMD") + @patch.object(Context, "get_system_ngpus", return_value=1) + @patch.object(Context, "get_system_gpu_architecture", return_value="gfx908") + @patch.object(Context, "get_system_hip_version", return_value="5.4") + @patch.object(Context, "get_docker_gpus", return_value="all") + @patch.object(Context, "get_gpu_renderD_nodes", return_value=["renderD128"]) + @patch.object(Console, "sh") + def test_push_image_dockerhub_with_repository( + self, + mock_sh, + mock_render, + mock_docker_gpu, + mock_hip, + mock_arch, + mock_ngpus, + mock_vendor, + ): + """Test pushing image to DockerHub with repository specified in credentials.""" + context = Context() + console = Console() + builder = DockerBuilder(context, console) + + docker_image = "ci-dummy_dummy.ubuntu.amd" + registry = "dockerhub" + credentials = { + "dockerhub": { + "repository": "your-repository", + "username": "your-dockerhub-username", + "password": "your-dockerhub-password-or-token", + } + } + + # Mock successful operations + mock_sh.return_value = "Success" + + result = builder.push_image(docker_image, registry, credentials) + + # Verify the correct tag and push commands were called + expected_tag = "your-repository:ci-dummy_dummy.ubuntu.amd" + tag_calls = [ + call for call in mock_sh.call_args_list if "docker tag" in str(call) + ] + push_calls = [ + call for call in mock_sh.call_args_list if "docker push" in str(call) + ] + + assert len(tag_calls) == 1 + assert expected_tag in 
str(tag_calls[0]) + assert len(push_calls) == 1 + assert expected_tag in str(push_calls[0]) + assert result == expected_tag + + @patch.object(Context, "get_gpu_vendor", return_value="AMD") + @patch.object(Context, "get_system_ngpus", return_value=1) + @patch.object(Context, "get_system_gpu_architecture", return_value="gfx908") + @patch.object(Context, "get_system_hip_version", return_value="5.4") + @patch.object(Context, "get_docker_gpus", return_value="all") + @patch.object(Context, "get_gpu_renderD_nodes", return_value=["renderD128"]) + @patch.object(Console, "sh") + def test_push_image_local_registry_with_repository( + self, + mock_sh, + mock_render, + mock_docker_gpu, + mock_hip, + mock_arch, + mock_ngpus, + mock_vendor, + ): + """Test pushing image to local registry with repository specified in credentials.""" + context = Context() + console = Console() + builder = DockerBuilder(context, console) + + docker_image = "ci-dummy_dummy.ubuntu.amd" + registry = "localhost:5000" + credentials = { + "localhost:5000": { + "repository": "your-repository", + "username": "your-local-registry-username", + "password": "your-local-registry-password", + } + } + + # Mock successful operations + mock_sh.return_value = "Success" + + result = builder.push_image(docker_image, registry, credentials) + + # Verify the correct tag and push commands were called + expected_tag = "localhost:5000/your-repository:ci-dummy_dummy.ubuntu.amd" + tag_calls = [ + call for call in mock_sh.call_args_list if "docker tag" in str(call) + ] + push_calls = [ + call for call in mock_sh.call_args_list if "docker push" in str(call) + ] + + assert len(tag_calls) == 1 + assert expected_tag in str(tag_calls[0]) + assert len(push_calls) == 1 + assert expected_tag in str(push_calls[0]) + assert result == expected_tag + + @patch.object(Context, "get_gpu_vendor", return_value="AMD") + @patch.object(Context, "get_system_ngpus", return_value=1) + @patch.object(Context, "get_system_gpu_architecture", return_value="gfx908") + @patch.object(Context, "get_system_hip_version", return_value="5.4") + @patch.object(Context, "get_docker_gpus", return_value="all") + @patch.object(Context, "get_gpu_renderD_nodes", return_value=["renderD128"]) + @patch.object(Console, "sh") + def test_push_image_dockerhub_no_repository( + self, + mock_sh, + mock_render, + mock_docker_gpu, + mock_hip, + mock_arch, + mock_ngpus, + mock_vendor, + ): + """Test pushing image to DockerHub without repository specified in credentials.""" + context = Context() + console = Console() + builder = DockerBuilder(context, console) + + docker_image = "ci-dummy_dummy.ubuntu.amd" + registry = "dockerhub" + credentials = { + "dockerhub": { + "username": "your-dockerhub-username", + "password": "your-dockerhub-password-or-token", + } + } + + # Mock successful operations + mock_sh.return_value = "Success" + + result = builder.push_image(docker_image, registry, credentials) + + # DockerHub without repository should just use the image name (no tagging needed) + push_calls = [ + call for call in mock_sh.call_args_list if "docker push" in str(call) + ] + assert len(push_calls) == 1 + assert docker_image in str(push_calls[0]) + assert result == docker_image + + @patch.object(Context, "get_gpu_vendor", return_value="AMD") + @patch.object(Context, "get_system_ngpus", return_value=1) + @patch.object(Context, "get_system_gpu_architecture", return_value="gfx908") + @patch.object(Context, "get_system_hip_version", return_value="5.4") + @patch.object(Context, "get_docker_gpus", return_value="all") + 
@patch.object(Context, "get_gpu_renderD_nodes", return_value=["renderD128"]) + @patch.object(Console, "sh") + def test_push_image_local_registry_no_repository( + self, + mock_sh, + mock_render, + mock_docker_gpu, + mock_hip, + mock_arch, + mock_ngpus, + mock_vendor, + ): + """Test pushing image to local registry without repository specified in credentials.""" + context = Context() + console = Console() + builder = DockerBuilder(context, console) + + docker_image = "ci-dummy_dummy.ubuntu.amd" + registry = "localhost:5000" + credentials = { + "localhost:5000": { + "username": "your-local-registry-username", + "password": "your-local-registry-password", + } + } + + # Mock successful operations + mock_sh.return_value = "Success" + + result = builder.push_image(docker_image, registry, credentials) + + # Should fallback to registry/imagename format + expected_tag = "localhost:5000/ci-dummy_dummy.ubuntu.amd" + tag_calls = [ + call for call in mock_sh.call_args_list if "docker tag" in str(call) + ] + push_calls = [ + call for call in mock_sh.call_args_list if "docker push" in str(call) + ] + + assert len(tag_calls) == 1 + assert expected_tag in str(tag_calls[0]) + assert len(push_calls) == 1 + assert expected_tag in str(push_calls[0]) + assert result == expected_tag + + @patch.object(Context, "get_gpu_vendor", return_value="AMD") + @patch.object(Context, "get_system_ngpus", return_value=1) + @patch.object(Context, "get_system_gpu_architecture", return_value="gfx908") + @patch.object(Context, "get_system_hip_version", return_value="5.4") + @patch.object(Context, "get_docker_gpus", return_value="all") + @patch.object(Context, "get_gpu_renderD_nodes", return_value=["renderD128"]) + @patch.object(Console, "sh") + def test_push_image_no_registry( + self, + mock_sh, + mock_render, + mock_docker_gpu, + mock_hip, + mock_arch, + mock_ngpus, + mock_vendor, + ): + """Test pushing image with no registry specified.""" + context = Context() + console = Console() + builder = DockerBuilder(context, console) + + docker_image = "ci-dummy_dummy.ubuntu.amd" + + result = builder.push_image(docker_image) + + # Should not call docker tag or push commands and return the original image name + docker_calls = [ + call + for call in mock_sh.call_args_list + if "docker tag" in str(call) or "docker push" in str(call) + ] + assert len(docker_calls) == 0 + assert result == docker_image + + @patch.object(Context, "get_gpu_vendor", return_value="AMD") + @patch.object(Context, "get_system_ngpus", return_value=1) + @patch.object(Context, "get_system_gpu_architecture", return_value="gfx908") + @patch.object(Context, "get_system_hip_version", return_value="5.4") + @patch.object(Context, "get_docker_gpus", return_value="all") + @patch.object(Context, "get_gpu_renderD_nodes", return_value=["renderD128"]) + @patch.object(Console, "sh") + def test_build_manifest_with_tagged_image( + self, + mock_sh, + mock_render, + mock_docker_gpu, + mock_hip, + mock_arch, + mock_ngpus, + mock_vendor, + ): + """Test that build manifest includes registry_image when pushing to registry.""" + import tempfile + import os + + # Mock successful operations BEFORE creating Context + # to avoid MagicMock objects being stored during initialization + mock_sh.return_value = "Success" + + context = Context() + console = Console() + builder = DockerBuilder(context, console) + + model_info = {"name": "test_model"} + dockerfile = "./docker/Dockerfile" + registry = "localhost:5000" + credentials = { + "localhost:5000": { + "repository": "test-repository", + "username": 
"test-user", + "password": "test-password", + } + } + + with patch.object(builder, "get_build_arg", return_value=""): + with patch.object(builder, "get_context_path", return_value="./docker"): + # Build image + build_info = builder.build_image(model_info, dockerfile, credentials) + local_image = build_info["docker_image"] + + # Push to registry + registry_image = builder.push_image(local_image, registry, credentials) + + # Update built_images with tagged image (simulating what build_all_models does) + if local_image in builder.built_images: + builder.built_images[local_image]["registry_image"] = registry_image + + # Export manifest to temporary file + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False + ) as tmp_file: + builder.export_build_manifest(tmp_file.name, registry) + + # Read and verify the manifest + with open(tmp_file.name, "r") as f: + import json + + manifest = json.load(f) + + # Clean up + os.unlink(tmp_file.name) + + # Verify the manifest contains the tagged image + assert local_image in manifest["built_images"] + assert "registry_image" in manifest["built_images"][local_image] + assert manifest["built_images"][local_image]["registry_image"] == registry_image + assert manifest["built_images"][local_image]["registry"] == registry + + # Verify the tagged image format is correct + expected_tagged_image = f"localhost:5000/test-repository:{local_image}" + assert registry_image == expected_tagged_image diff --git a/tests/test_error_handling.py b/tests/test_error_handling.py new file mode 100644 index 00000000..1b905657 --- /dev/null +++ b/tests/test_error_handling.py @@ -0,0 +1,448 @@ +#!/usr/bin/env python3 +""" +Unit tests for MADEngine unified error handling system. + +Tests the core error handling functionality including error types, +context management, Rich console integration, and error propagation. 
+""" + +import pytest +import json +import io +from unittest.mock import Mock, patch, MagicMock +from rich.console import Console +from rich.text import Text + +# Add src to path for imports +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) + +from madengine.core.errors import ( + ErrorCategory, + ErrorContext, + MADEngineError, + ValidationError, + ConnectionError, + AuthenticationError, + RuntimeError, + BuildError, + DiscoveryError, + OrchestrationError, + RunnerError, + ConfigurationError, + TimeoutError, + ErrorHandler, + set_error_handler, + get_error_handler, + handle_error, + create_error_context +) + + +class TestErrorCategories: + """Test error category enumeration.""" + + def test_error_categories_exist(self): + """Test that all required error categories are defined.""" + expected_categories = [ + "validation", "connection", "authentication", "runtime", + "build", "discovery", "orchestration", "runner", + "configuration", "timeout" + ] + + for category in expected_categories: + assert hasattr(ErrorCategory, category.upper()) + assert ErrorCategory[category.upper()].value == category + + +class TestErrorContext: + """Test error context data structure.""" + + def test_error_context_creation(self): + """Test basic error context creation.""" + context = ErrorContext( + operation="test_operation", + phase="test_phase", + component="test_component" + ) + + assert context.operation == "test_operation" + assert context.phase == "test_phase" + assert context.component == "test_component" + assert context.model_name is None + assert context.node_id is None + assert context.file_path is None + assert context.additional_info is None + + def test_error_context_full(self): + """Test error context with all fields.""" + additional_info = {"key": "value", "number": 42} + context = ErrorContext( + operation="complex_operation", + phase="execution", + component="TestComponent", + model_name="test_model", + node_id="node-001", + file_path="/path/to/file.json", + additional_info=additional_info + ) + + assert context.operation == "complex_operation" + assert context.phase == "execution" + assert context.component == "TestComponent" + assert context.model_name == "test_model" + assert context.node_id == "node-001" + assert context.file_path == "/path/to/file.json" + assert context.additional_info == additional_info + + def test_create_error_context_function(self): + """Test create_error_context convenience function.""" + context = create_error_context( + operation="test_op", + phase="test_phase", + model_name="test_model" + ) + + assert isinstance(context, ErrorContext) + assert context.operation == "test_op" + assert context.phase == "test_phase" + assert context.model_name == "test_model" + + +class TestMADEngineErrorHierarchy: + """Test MADEngine error class hierarchy.""" + + def test_base_madengine_error(self): + """Test base MADEngine error functionality.""" + context = ErrorContext(operation="test") + error = MADEngineError( + message="Test error", + category=ErrorCategory.RUNTIME, + context=context, + recoverable=True, + suggestions=["Try again", "Check logs"] + ) + + assert str(error) == "Test error" + assert error.message == "Test error" + assert error.category == ErrorCategory.RUNTIME + assert error.context == context + assert error.recoverable is True + assert error.suggestions == ["Try again", "Check logs"] + assert error.cause is None + + def test_validation_error(self): + """Test ValidationError specific functionality.""" + error = 
ValidationError("Invalid input") + + assert isinstance(error, MADEngineError) + assert error.category == ErrorCategory.VALIDATION + assert error.recoverable is True + assert str(error) == "Invalid input" + + def test_connection_error(self): + """Test ConnectionError specific functionality.""" + context = create_error_context(operation="connect", node_id="node-1") + error = ConnectionError("Connection failed", context=context) + + assert isinstance(error, MADEngineError) + assert error.category == ErrorCategory.CONNECTION + assert error.recoverable is True + assert error.context.node_id == "node-1" + + def test_build_error(self): + """Test BuildError specific functionality.""" + error = BuildError("Build failed") + + assert isinstance(error, MADEngineError) + assert error.category == ErrorCategory.BUILD + assert error.recoverable is False + + def test_runner_error(self): + """Test RunnerError specific functionality.""" + error = RunnerError("Runner execution failed") + + assert isinstance(error, MADEngineError) + assert error.category == ErrorCategory.RUNNER + assert error.recoverable is True + + def test_error_with_cause(self): + """Test error with underlying cause.""" + original_error = ValueError("Original error") + mad_error = RuntimeError("Runtime failure", cause=original_error) + + assert mad_error.cause == original_error + assert str(mad_error) == "Runtime failure" + + +class TestErrorHandler: + """Test ErrorHandler functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_console = Mock(spec=Console) + self.error_handler = ErrorHandler(console=self.mock_console, verbose=False) + + def test_error_handler_creation(self): + """Test ErrorHandler initialization.""" + assert self.error_handler.console == self.mock_console + assert self.error_handler.verbose is False + assert self.error_handler.logger is not None + + def test_handle_madengine_error(self): + """Test handling of MADEngine structured errors.""" + context = create_error_context( + operation="test_operation", + component="TestComponent", + model_name="test_model" + ) + error = ValidationError( + "Test validation error", + context=context, + suggestions=["Check input", "Verify format"] + ) + + self.error_handler.handle_error(error) + + # Verify console.print was called for the error panel + self.mock_console.print.assert_called() + call_args = self.mock_console.print.call_args[0] + + # Check that a Rich Panel was created + assert len(call_args) > 0 + panel = call_args[0] + assert hasattr(panel, 'title') + assert "Validation Error" in panel.title + + def test_handle_generic_error(self): + """Test handling of generic Python exceptions.""" + error = ValueError("Generic Python error") + context = create_error_context(operation="test_op") + + self.error_handler.handle_error(error, context=context) + + # Verify console.print was called + self.mock_console.print.assert_called() + call_args = self.mock_console.print.call_args[0] + + # Check that a Rich Panel was created + assert len(call_args) > 0 + panel = call_args[0] + assert hasattr(panel, 'title') + assert "ValueError" in panel.title + + def test_handle_error_verbose_mode(self): + """Test error handling in verbose mode.""" + verbose_handler = ErrorHandler(console=self.mock_console, verbose=True) + # Create error with a cause to trigger print_exception + original_error = ValueError("Original error") + error = RuntimeError("Test runtime error", cause=original_error) + + verbose_handler.handle_error(error, show_traceback=True) + + # Verify both print and 
print_exception were called + assert self.mock_console.print.call_count >= 2 + self.mock_console.print_exception.assert_called() + + def test_error_categorization_display(self): + """Test that different error categories display with correct styling.""" + test_cases = [ + (ValidationError("Validation failed"), "⚠️", "Validation Error"), + (ConnectionError("Connection failed"), "🔌", "Connection Error"), + (BuildError("Build failed"), "🔨", "Build Error"), + (RunnerError("Runner failed"), "🚀", "Runner Error"), + ] + + for error, expected_emoji, expected_title in test_cases: + self.mock_console.reset_mock() + self.error_handler.handle_error(error) + + # Verify console.print was called + self.mock_console.print.assert_called() + call_args = self.mock_console.print.call_args[0] + panel = call_args[0] + + assert expected_emoji in panel.title + assert expected_title in panel.title + + +class TestGlobalErrorHandler: + """Test global error handler functionality.""" + + def test_set_and_get_error_handler(self): + """Test setting and getting global error handler.""" + mock_console = Mock(spec=Console) + handler = ErrorHandler(console=mock_console) + + set_error_handler(handler) + retrieved_handler = get_error_handler() + + assert retrieved_handler == handler + + def test_handle_error_function(self): + """Test global handle_error function.""" + mock_console = Mock(spec=Console) + handler = ErrorHandler(console=mock_console) + set_error_handler(handler) + + error = ValidationError("Test error") + context = create_error_context(operation="test") + + handle_error(error, context=context) + + # Verify the handler was used + mock_console.print.assert_called() + + def test_handle_error_no_global_handler(self): + """Test handle_error function when no global handler is set.""" + # Clear global handler + set_error_handler(None) + + with patch('madengine.core.errors.logging') as mock_logging: + error = ValueError("Test error") + handle_error(error) + + # Should fallback to logging + mock_logging.error.assert_called_once() + + +class TestErrorContextPropagation: + """Test error context propagation through call stack.""" + + def test_context_preservation_through_hierarchy(self): + """Test that context is preserved when creating derived errors.""" + original_context = create_error_context( + operation="original_op", + component="OriginalComponent", + model_name="test_model" + ) + + # Create a base error with context + base_error = MADEngineError( + "Base error", + ErrorCategory.RUNTIME, + context=original_context + ) + + # Create a derived error that should preserve context + derived_error = ValidationError( + "Derived error", + context=original_context, + cause=base_error + ) + + assert derived_error.context == original_context + assert derived_error.cause == base_error + assert derived_error.context.operation == "original_op" + assert derived_error.context.component == "OriginalComponent" + + def test_context_enrichment(self): + """Test adding additional context information.""" + base_context = create_error_context(operation="base_op") + + # Create enriched context + enriched_context = ErrorContext( + operation=base_context.operation, + phase="enriched_phase", + component="EnrichedComponent", + additional_info={"enriched": True} + ) + + error = RuntimeError("Test error", context=enriched_context) + + assert error.context.operation == "base_op" + assert error.context.phase == "enriched_phase" + assert error.context.component == "EnrichedComponent" + assert error.context.additional_info["enriched"] is True + + +class 
TestErrorRecoveryAndSuggestions: + """Test error recovery indicators and suggestions.""" + + def test_recoverable_errors(self): + """Test that certain error types are marked as recoverable.""" + recoverable_errors = [ + ValidationError("Validation error"), + ConnectionError("Connection error"), + AuthenticationError("Auth error"), + ConfigurationError("Config error"), + TimeoutError("Timeout error"), + ] + + for error in recoverable_errors: + assert error.recoverable is True, f"{type(error).__name__} should be recoverable" + + def test_non_recoverable_errors(self): + """Test that certain error types are marked as non-recoverable.""" + non_recoverable_errors = [ + RuntimeError("Runtime error"), + BuildError("Build error"), + OrchestrationError("Orchestration error"), + ] + + for error in non_recoverable_errors: + assert error.recoverable is False, f"{type(error).__name__} should not be recoverable" + + def test_suggestions_in_errors(self): + """Test that suggestions are properly included in errors.""" + suggestions = ["Check configuration", "Verify credentials", "Try again"] + error = ValidationError( + "Validation failed", + suggestions=suggestions + ) + + assert error.suggestions == suggestions + + # Test handling displays suggestions + mock_console = Mock(spec=Console) + handler = ErrorHandler(console=mock_console) + handler.handle_error(error) + + # Verify console.print was called and suggestions are in output + mock_console.print.assert_called() + + +class TestErrorIntegration: + """Test error handling integration scenarios.""" + + def test_error_serialization_context(self): + """Test that error context can be serialized for logging.""" + context = create_error_context( + operation="test_operation", + phase="test_phase", + component="TestComponent", + model_name="test_model", + additional_info={"key": "value"} + ) + + error = ValidationError("Test error", context=context) + + # Context should be serializable + context_dict = error.context.__dict__ + json_str = json.dumps(context_dict, default=str) + + assert "test_operation" in json_str + assert "test_phase" in json_str + assert "TestComponent" in json_str + assert "test_model" in json_str + + def test_nested_error_handling(self): + """Test handling of nested exceptions.""" + original_error = ConnectionError("Network timeout") + wrapped_error = RuntimeError("Operation failed", cause=original_error) + final_error = OrchestrationError("Orchestration failed", cause=wrapped_error) + + assert final_error.cause == wrapped_error + assert wrapped_error.cause == original_error + + # Test that the handler can display nested error information + mock_console = Mock(spec=Console) + handler = ErrorHandler(console=mock_console) + handler.handle_error(final_error) + + mock_console.print.assert_called() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) \ No newline at end of file diff --git a/tests/test_error_system_integration.py b/tests/test_error_system_integration.py new file mode 100644 index 00000000..96d70bb9 --- /dev/null +++ b/tests/test_error_system_integration.py @@ -0,0 +1,303 @@ +#!/usr/bin/env python3 +""" +Integration tests for MADEngine unified error handling system. + +This test file focuses on testing the integration without requiring +optional dependencies like paramiko, ansible-runner, or kubernetes. 
+""" + +import pytest +import json +from unittest.mock import Mock, patch, MagicMock + +# Add src to path for imports +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) + +from madengine.core.errors import ( + ErrorHandler, + MADEngineError, + ValidationError, + ConfigurationError, + RunnerError, + set_error_handler, + get_error_handler, + create_error_context +) + + +class TestUnifiedErrorSystem: + """Test the unified error handling system integration.""" + + def test_error_system_basic_functionality(self): + """Test basic error system functionality works.""" + # Create error handler + mock_console = Mock() + handler = ErrorHandler(console=mock_console, verbose=False) + + # Create error with context + context = create_error_context( + operation="test_operation", + component="TestComponent", + model_name="test_model" + ) + + error = ValidationError("Test validation error", context=context) + + # Handle the error + handler.handle_error(error) + + # Verify it was handled + mock_console.print.assert_called_once() + + # Verify error structure + assert error.context.operation == "test_operation" + assert error.context.component == "TestComponent" + assert error.recoverable is True + + def test_mad_cli_error_handler_setup(self): + """Test that mad_cli properly sets up error handling.""" + from madengine.mad_cli import setup_logging + + # Clear existing handler + set_error_handler(None) + + # Setup logging + setup_logging(verbose=True) + + # Verify handler was created + handler = get_error_handler() + assert handler is not None + assert isinstance(handler, ErrorHandler) + assert handler.verbose is True + + def test_distributed_orchestrator_error_imports(self): + """Test that distributed_orchestrator can import error handling.""" + try: + from madengine.tools.distributed_orchestrator import ( + handle_error, create_error_context, ConfigurationError + ) + + # Test that we can create and handle errors + context = create_error_context( + operation="test_import", + component="DistributedOrchestrator" + ) + + error = ConfigurationError("Test config error", context=context) + + # This should not raise an exception + assert error.context.operation == "test_import" + assert error.context.component == "DistributedOrchestrator" + + except ImportError as e: + pytest.fail(f"Error handling imports failed: {e}") + + def test_runner_error_base_class(self): + """Test that RunnerError base class works properly.""" + context = create_error_context( + operation="runner_test", + component="TestRunner" + ) + + error = RunnerError("Test runner error", context=context) + + assert isinstance(error, MADEngineError) + assert error.recoverable is True + assert error.context.operation == "runner_test" + assert error.context.component == "TestRunner" + + def test_error_context_serialization(self): + """Test that error contexts can be serialized.""" + context = create_error_context( + operation="serialization_test", + component="TestComponent", + model_name="test_model", + node_id="test_node", + additional_info={"key": "value", "number": 42} + ) + + error = ValidationError("Test error", context=context) + + # Test serialization + context_dict = error.context.__dict__ + json_str = json.dumps(context_dict, default=str) + + # Verify content + assert "serialization_test" in json_str + assert "TestComponent" in json_str + assert "test_model" in json_str + assert "test_node" in json_str + assert "key" in json_str + assert "42" in json_str + + def test_error_hierarchy_consistency(self): + 
"""Test that all error types maintain consistent behavior.""" + from madengine.core.errors import ( + ValidationError, ConnectionError, AuthenticationError, + RuntimeError, BuildError, DiscoveryError, OrchestrationError, + RunnerError, ConfigurationError, TimeoutError + ) + + error_classes = [ + ValidationError, ConnectionError, AuthenticationError, + RuntimeError, BuildError, DiscoveryError, OrchestrationError, + RunnerError, ConfigurationError, TimeoutError + ] + + for error_class in error_classes: + error = error_class("Test error message") + + # All should inherit from MADEngineError + assert isinstance(error, MADEngineError) + + # All should have context (even if default) + assert error.context is not None + + # All should have category + assert error.category is not None + + # All should have recoverable flag + assert isinstance(error.recoverable, bool) + + def test_global_error_handler_workflow(self): + """Test the complete global error handler workflow.""" + from madengine.core.errors import handle_error + + # Create and set global handler + mock_console = Mock() + handler = ErrorHandler(console=mock_console, verbose=False) + set_error_handler(handler) + + # Create error + error = ValidationError( + "Global handler test", + context=create_error_context( + operation="global_test", + component="TestGlobalHandler" + ) + ) + + # Use global handle_error function + handle_error(error) + + # Verify it was handled through the global handler + mock_console.print.assert_called_once() + + def test_error_suggestions_and_recovery(self): + """Test error suggestions and recovery information.""" + suggestions = [ + "Check your configuration file", + "Verify network connectivity", + "Try running with --verbose flag" + ] + + error = ConfigurationError( + "Configuration validation failed", + context=create_error_context( + operation="config_validation", + file_path="/path/to/config.json" + ), + suggestions=suggestions + ) + + assert error.suggestions == suggestions + assert error.recoverable is True + assert error.context.file_path == "/path/to/config.json" + + # Test error display includes suggestions + mock_console = Mock() + handler = ErrorHandler(console=mock_console) + handler.handle_error(error) + + # Should have been called to display the error + mock_console.print.assert_called_once() + + def test_nested_error_handling(self): + """Test handling of nested errors with causes.""" + from madengine.core.errors import RuntimeError as MADRuntimeError, OrchestrationError + + # Create a chain of errors + original_error = ConnectionError("Network timeout") + runtime_error = MADRuntimeError("Operation failed", cause=original_error) + final_error = OrchestrationError("Orchestration failed", cause=runtime_error) + + # Test the chain + assert final_error.cause == runtime_error + assert runtime_error.cause == original_error + + # Test handling preserves the chain + mock_console = Mock() + handler = ErrorHandler(console=mock_console, verbose=True) + handler.handle_error(final_error, show_traceback=True) + + # Should display error and potentially traceback + assert mock_console.print.call_count >= 1 + + def test_error_performance(self): + """Test that error handling is performant.""" + import time + + mock_console = Mock() + handler = ErrorHandler(console=mock_console) + + start_time = time.time() + + # Create and handle many errors + for i in range(100): + error = ValidationError( + f"Test error {i}", + context=create_error_context( + operation=f"test_op_{i}", + component="PerformanceTest" + ) + ) + 
handler.handle_error(error) + + end_time = time.time() + + # Should handle 100 errors in under 1 second + assert end_time - start_time < 1.0 + + # Verify all errors were handled + assert mock_console.print.call_count == 100 + + +class TestErrorSystemBackwardCompatibility: + """Test backward compatibility of the error system.""" + + def test_legacy_exception_handling_still_works(self): + """Test that legacy exception patterns still work.""" + try: + # Simulate old-style exception raising + raise ValueError("Legacy error") + except Exception as e: + # Should be able to handle with new system + mock_console = Mock() + handler = ErrorHandler(console=mock_console) + + context = create_error_context( + operation="legacy_handling", + component="LegacyTest" + ) + + handler.handle_error(e, context=context) + + # Should handle gracefully + mock_console.print.assert_called_once() + + def test_error_system_without_rich(self): + """Test error system fallback when Rich is not available.""" + # This test verifies the system degrades gracefully + # In practice, Rich is a hard dependency, but we test the concept + + with patch('madengine.core.errors.Console', side_effect=ImportError): + # Should still be able to create basic errors + error = ValidationError("Test without Rich") + assert str(error) == "Test without Rich" + assert error.recoverable is True + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) \ No newline at end of file diff --git a/tests/test_live_output.py b/tests/test_live_output.py index 76a0c4f4..bd04880f 100644 --- a/tests/test_live_output.py +++ b/tests/test_live_output.py @@ -2,9 +2,11 @@ Copyright (c) Advanced Micro Devices, Inc. All rights reserved. """ + # built-in modules import re import pytest + # project modules from .fixtures.utils import global_data from .fixtures.utils import BASE_DIR, MODEL_DIR @@ -13,29 +15,51 @@ class TestLiveOutputFunctionality: """Test the live output functionality.""" - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html']], indirect=True) + + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html"]], indirect=True + ) def test_default_silent_run(self, global_data, clean_test_temp_files): - """ + """ default run is silent """ - output = global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy") + output = global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy" + ) - regexp = re.compile(r'performance: [0-9]* samples_per_second') + regexp = re.compile(r"performance: [0-9]* samples_per_second") if regexp.search(output): pytest.fail("default run is not silent") if "ARG BASE_DOCKER=" in output: pytest.fail("default run is not silent") - - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html']], indirect=True) - def test_liveOutput_prints_output_to_screen(self, global_data, clean_test_temp_files): + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html"]], indirect=True + ) + def test_liveOutput_prints_output_to_screen( + self, global_data, clean_test_temp_files + ): """ - live_output prints output to screen + live_output prints output to screen """ - output = global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy --live-output") + output = global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + 
MODEL_DIR
+            + " "
+            + "python3 src/madengine/mad.py run --tags dummy --live-output"
+        )
-        regexp = re.compile(r'performance: [0-9]* samples_per_second')
+        regexp = re.compile(r"performance: [0-9]* samples_per_second")
         if not regexp.search(output):
             pytest.fail("default run is silent")
diff --git a/tests/test_mad.py b/tests/test_mad.py
index 055eb212..845de34f 100644
--- a/tests/test_mad.py
+++ b/tests/test_mad.py
@@ -1,67 +1,136 @@
-"""Test the mad module.
+"""Test the legacy mad.py module (argparse-based CLI).
+
+This module tests the LEGACY argparse-based command-line interface for
+backward compatibility. The legacy mad.py uses argparse and provides the
+original MADEngine command structure.
+
+For NEW Typer-based CLI tests, see test_mad_cli.py.
+
+NOTE: Both interfaces are maintained for backward compatibility:
+- mad.py (legacy) - argparse-based, original interface
+- mad_cli.py (modern) - Typer-based, enhanced interface with Rich output
 
 Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
 """
+
 # built-in modules
 import os
 import sys
 import subprocess
 import typing
+
 # third-party modules
 import pytest
+
 # project modules
 from madengine import mad
 
 
-class TestMad:
-    """Test the mad module.
+class TestLegacyMad:
+    """Test the legacy mad.py module (argparse-based).
 
-    test_run_model: run python3 mad.py --help
+    These tests ensure backward compatibility with the original
+    argparse-based CLI. All tests run the script directly via subprocess
+    to verify the entry point works correctly.
     """
+
     def test_mad_cli(self):
+        """Test legacy mad.py --help command."""
         # Construct the path to the script
-        script_path = os.path.join(os.path.dirname(__file__), "../src/madengine", "mad.py")
+        script_path = os.path.join(
+            os.path.dirname(__file__), "../src/madengine", "mad.py"
+        )
         # Run the script with arguments using subprocess.run
-        result = subprocess.run([sys.executable, script_path, "--help"], stdout=subprocess.PIPE)
-        print(result.stdout.decode("utf-8"))
+        result = subprocess.run(
+            [sys.executable, script_path, "--help"],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE
+        )
+        output = result.stdout.decode("utf-8")
+        print(output)
         assert result.returncode == 0
+        assert "Models automation and dashboarding" in output or "command-line tool" in output
 
     def test_mad_run_cli(self):
-        # Construct the path to the script
-        script_path = os.path.join(os.path.dirname(__file__), "../src/madengine", "mad.py")
-        # Run the script with arguments using subprocess.run
-        result = subprocess.run([sys.executable, script_path, "run", "--help"], stdout=subprocess.PIPE)
-        print(result.stdout.decode("utf-8"))
+        """Test legacy mad.py run --help command."""
+        script_path = os.path.join(
+            os.path.dirname(__file__), "../src/madengine", "mad.py"
+        )
+        result = subprocess.run(
+            [sys.executable, script_path, "run", "--help"],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE
+        )
+        output = result.stdout.decode("utf-8")
+        print(output)
         assert result.returncode == 0
+        assert "--tags" in output  # Verify run command has expected options
 
     def test_mad_report_cli(self):
-        # Construct the path to the script
-        script_path = os.path.join(os.path.dirname(__file__), "../src/madengine", "mad.py")
-        # Run the script with arguments using subprocess.run
-        result = subprocess.run([sys.executable, script_path, "report", "--help"], stdout=subprocess.PIPE)
-        print(result.stdout.decode("utf-8"))
+        """Test legacy mad.py report --help command."""
+        script_path = os.path.join(
+            os.path.dirname(__file__), "../src/madengine", "mad.py"
+        )
+        result = subprocess.run(
+            [sys.executable, script_path, "report", "--help"],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE
+        )
+        output = result.stdout.decode("utf-8")
+        print(output)
         assert result.returncode == 0
 
     def test_mad_database_cli(self):
-        # Construct the path to the script
-        script_path = os.path.join(os.path.dirname(__file__), "../src/madengine", "mad.py")
-        # Run the script with arguments using subprocess.run
-        result = subprocess.run([sys.executable, script_path, "database", "--help"], stdout=subprocess.PIPE)
-        print(result.stdout.decode("utf-8"))
+        """Test legacy mad.py database --help command."""
+        script_path = os.path.join(
+            os.path.dirname(__file__), "../src/madengine", "mad.py"
+        )
+        result = subprocess.run(
+            [sys.executable, script_path, "database", "--help"],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE
+        )
+        output = result.stdout.decode("utf-8")
+        print(output)
         assert result.returncode == 0
 
     def test_mad_discover_cli(self):
-        # Construct the path to the script
-        script_path = os.path.join(os.path.dirname(__file__), "../src/madengine", "mad.py")
-        # Run the script with arguments using subprocess.run
-        result = subprocess.run([sys.executable, script_path, "discover", "--help"], stdout=subprocess.PIPE)
-        print(result.stdout.decode("utf-8"))
-        assert result.returncode == 0
+        """Test legacy mad.py discover --help command."""
+        script_path = os.path.join(
+            os.path.dirname(__file__), "../src/madengine", "mad.py"
+        )
+        result = subprocess.run(
+            [sys.executable, script_path, "discover", "--help"],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE
+        )
+        output = result.stdout.decode("utf-8")
+        print(output)
+        assert result.returncode == 0
 
     def test_mad_version_cli(self):
-        # Construct the path to the script
-        script_path = os.path.join(os.path.dirname(__file__), "../src/madengine", "mad.py")
-        # Run the script with arguments using subprocess.run
-        result = subprocess.run([sys.executable, script_path, "--version"], stdout=subprocess.PIPE)
-        print(result.stdout.decode("utf-8"))
+        """Test legacy mad.py --version command."""
+        script_path = os.path.join(
+            os.path.dirname(__file__), "../src/madengine", "mad.py"
+        )
+        result = subprocess.run(
+            [sys.executable, script_path, "--version"],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE
+        )
+        output = result.stdout.decode("utf-8")
+        print(output)
         assert result.returncode == 0
+        # Version should be printed (could be "dev" or actual version)
+        assert len(output.strip()) > 0
+
+    def test_legacy_and_modern_cli_both_work(self):
+        """Integration test: Verify both CLI interfaces are accessible."""
+        # Test legacy can be imported
+        from madengine import mad
+        assert hasattr(mad, 'main')
+
+        # Test modern can be imported
+        from madengine import mad_cli
+        assert hasattr(mad_cli, 'app')
+        assert hasattr(mad_cli, 'cli_main')
diff --git a/tests/test_mad_cli.py b/tests/test_mad_cli.py
new file mode 100644
index 00000000..cf3c89a7
--- /dev/null
+++ b/tests/test_mad_cli.py
@@ -0,0 +1,1161 @@
+"""Test the mad_cli module.
+
+This module tests the modern Typer-based command-line interface functionality.
+
+GPU Hardware Support:
+- Tests automatically detect if the machine has GPU hardware
+- GPU-dependent tests are skipped on CPU-only machines using @requires_gpu decorator
+- Tests use auto-generated additional context appropriate for the current machine
+- CPU-only machines default to AMD GPU vendor for build compatibility
+
+Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
+""" + +# built-in modules +import json +import os +import sys +import tempfile +import unittest.mock +from pathlib import Path +from unittest.mock import MagicMock, Mock, patch, mock_open + +# third-party modules +import pytest +import typer +from typer.testing import CliRunner + +# project modules +from madengine import mad_cli +from madengine.mad_cli import ( + app, + setup_logging, + create_args_namespace, + validate_additional_context, + save_summary_with_feedback, + display_results_table, + ExitCode, + VALID_GPU_VENDORS, + VALID_GUEST_OS, + DEFAULT_MANIFEST_FILE, + DEFAULT_PERF_OUTPUT, + DEFAULT_DATA_CONFIG, + DEFAULT_TOOLS_CONFIG, + DEFAULT_ANSIBLE_OUTPUT, + DEFAULT_TIMEOUT, +) +from .fixtures.utils import ( + BASE_DIR, + MODEL_DIR, + has_gpu, + requires_gpu, + generate_additional_context_for_machine, +) + + +class TestSetupLogging: + """Test the setup_logging function.""" + + @patch("madengine.mad_cli.logging.basicConfig") + def test_setup_logging_verbose(self, mock_basic_config): + """Test logging setup with verbose mode enabled.""" + setup_logging(verbose=True) + + mock_basic_config.assert_called_once() + call_args = mock_basic_config.call_args + assert call_args[1]["level"] == 10 # logging.DEBUG + + @patch("madengine.mad_cli.logging.basicConfig") + def test_setup_logging_normal(self, mock_basic_config): + """Test logging setup with normal mode.""" + setup_logging(verbose=False) + + mock_basic_config.assert_called_once() + call_args = mock_basic_config.call_args + assert call_args[1]["level"] == 20 # logging.INFO + + +class TestCreateArgsNamespace: + """Test the create_args_namespace function.""" + + def test_create_args_namespace_basic(self): + """Test creating args namespace with basic parameters.""" + args = create_args_namespace( + tags=["dummy"], registry="localhost:5000", verbose=True + ) + + assert args.tags == ["dummy"] + assert args.registry == "localhost:5000" + assert args.verbose is True + + def test_create_args_namespace_empty(self): + """Test creating args namespace with no parameters.""" + args = create_args_namespace() + + # Should create an object with no attributes + assert not hasattr(args, "tags") + + def test_create_args_namespace_complex(self): + """Test creating args namespace with complex parameters.""" + args = create_args_namespace( + tags=["model1", "model2"], + additional_context='{"gpu_vendor": "AMD", "guest_os": "UBUNTU"}', + timeout=300, + keep_alive=True, + verbose=False, + ) + + assert args.tags == ["model1", "model2"] + assert args.additional_context == '{"gpu_vendor": "AMD", "guest_os": "UBUNTU"}' + assert args.timeout == 300 + assert args.keep_alive is True + assert args.verbose is False + + +class TestValidateAdditionalContext: + """Test the validate_additional_context function.""" + + def test_validate_additional_context_valid_string(self): + """Test validation with valid additional context from string.""" + # Use auto-generated context for current machine + context = generate_additional_context_for_machine() + context_json = json.dumps(context) + + with patch("madengine.mad_cli.console") as mock_console: + result = validate_additional_context(context_json) + + assert result == context + mock_console.print.assert_called() + + def test_validate_additional_context_valid_file(self): + """Test validation with valid additional context from file.""" + # Use auto-generated context for current machine + context = generate_additional_context_for_machine() + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + 
json.dump(context, f) + temp_file = f.name + + try: + with patch("madengine.mad_cli.console") as mock_console: + result = validate_additional_context("{}", temp_file) + + assert result == context + mock_console.print.assert_called() + finally: + os.unlink(temp_file) + + def test_validate_additional_context_string_overrides_file(self): + """Test that string context overrides file context.""" + # Use auto-generated context for current machine + context = generate_additional_context_for_machine() + context_json = json.dumps(context) + + # Create file with different context + file_context = {"gpu_vendor": "NVIDIA", "guest_os": "CENTOS"} + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(file_context, f) + temp_file = f.name + + try: + with patch("madengine.mad_cli.console") as mock_console: + result = validate_additional_context(context_json, temp_file) + + assert result == context + finally: + os.unlink(temp_file) + + def test_validate_additional_context_invalid_json(self): + """Test validation with invalid JSON.""" + with patch("madengine.mad_cli.console") as mock_console: + with pytest.raises(typer.Exit) as exc_info: + validate_additional_context("invalid json") + + assert exc_info.value.exit_code == ExitCode.INVALID_ARGS + mock_console.print.assert_called() + + def test_validate_additional_context_missing_gpu_vendor(self): + """Test validation with missing gpu_vendor.""" + with patch("madengine.mad_cli.console") as mock_console: + with pytest.raises(typer.Exit) as exc_info: + validate_additional_context('{"guest_os": "UBUNTU"}') + + assert exc_info.value.exit_code == ExitCode.INVALID_ARGS + mock_console.print.assert_called() + + def test_validate_additional_context_missing_guest_os(self): + """Test validation with missing guest_os.""" + with patch("madengine.mad_cli.console") as mock_console: + with pytest.raises(typer.Exit) as exc_info: + validate_additional_context('{"gpu_vendor": "AMD"}') + + assert exc_info.value.exit_code == ExitCode.INVALID_ARGS + mock_console.print.assert_called() + + def test_validate_additional_context_invalid_gpu_vendor(self): + """Test validation with invalid gpu_vendor.""" + with patch("madengine.mad_cli.console") as mock_console: + with pytest.raises(typer.Exit) as exc_info: + validate_additional_context( + '{"gpu_vendor": "INVALID", "guest_os": "UBUNTU"}' + ) + + assert exc_info.value.exit_code == ExitCode.INVALID_ARGS + mock_console.print.assert_called() + + def test_validate_additional_context_invalid_guest_os(self): + """Test validation with invalid guest_os.""" + with patch("madengine.mad_cli.console") as mock_console: + with pytest.raises(typer.Exit) as exc_info: + validate_additional_context( + '{"gpu_vendor": "AMD", "guest_os": "INVALID"}' + ) + + assert exc_info.value.exit_code == ExitCode.INVALID_ARGS + mock_console.print.assert_called() + + def test_validate_additional_context_case_insensitive(self): + """Test validation with case insensitive values.""" + with patch("madengine.mad_cli.console") as mock_console: + result = validate_additional_context( + '{"gpu_vendor": "amd", "guest_os": "ubuntu"}' + ) + + assert result == {"gpu_vendor": "amd", "guest_os": "ubuntu"} + mock_console.print.assert_called() + + def test_validate_additional_context_empty_context(self): + """Test validation with empty context.""" + with patch("madengine.mad_cli.console") as mock_console: + with pytest.raises(typer.Exit) as exc_info: + validate_additional_context("{}") + + assert exc_info.value.exit_code == ExitCode.INVALID_ARGS + 
mock_console.print.assert_called() + + def test_validate_additional_context_file_not_found(self): + """Test validation with non-existent file.""" + with patch("madengine.mad_cli.console") as mock_console: + with pytest.raises(typer.Exit) as exc_info: + validate_additional_context("{}", "non_existent_file.json") + + assert exc_info.value.exit_code == ExitCode.INVALID_ARGS + mock_console.print.assert_called() + + +class TestSaveSummaryWithFeedback: + """Test the save_summary_with_feedback function.""" + + def test_save_summary_success(self): + """Test successful summary saving.""" + summary = {"successful_builds": ["model1", "model2"], "failed_builds": []} + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + temp_file = f.name + + try: + with patch("madengine.mad_cli.console") as mock_console: + save_summary_with_feedback(summary, temp_file, "Build") + + # Verify file was written + with open(temp_file, "r") as f: + saved_data = json.load(f) + assert saved_data == summary + + mock_console.print.assert_called() + finally: + os.unlink(temp_file) + + def test_save_summary_no_output_path(self): + """Test summary saving with no output path.""" + summary = {"successful_builds": ["model1"], "failed_builds": []} + + with patch("madengine.mad_cli.console") as mock_console: + save_summary_with_feedback(summary, None, "Build") + + # Should not call console.print for saving + mock_console.print.assert_not_called() + + def test_save_summary_io_error(self): + """Test summary saving with IO error.""" + summary = {"successful_builds": ["model1"], "failed_builds": []} + + with patch("madengine.mad_cli.console") as mock_console: + with pytest.raises(typer.Exit) as exc_info: + save_summary_with_feedback(summary, "/invalid/path/file.json", "Build") + + assert exc_info.value.exit_code == ExitCode.FAILURE + mock_console.print.assert_called() + + +class TestDisplayResultsTable: + """Test the display_results_table function.""" + + def test_display_results_table_build_success(self): + """Test displaying build results table with successes.""" + summary = {"successful_builds": ["model1", "model2"], "failed_builds": []} + + with patch("madengine.mad_cli.console") as mock_console: + display_results_table(summary, "Build Results") + + mock_console.print.assert_called() + + def test_display_results_table_build_failures(self): + """Test displaying build results table with failures.""" + summary = { + "successful_builds": ["model1"], + "failed_builds": ["model2", "model3"], + } + + with patch("madengine.mad_cli.console") as mock_console: + display_results_table(summary, "Build Results") + + mock_console.print.assert_called() + + def test_display_results_table_run_results(self): + """Test displaying run results table.""" + summary = { + "successful_runs": [ + {"model": "model1", "status": "success"}, + {"model": "model2", "status": "success"}, + ], + "failed_runs": [{"model": "model3", "status": "failed"}], + } + + with patch("madengine.mad_cli.console") as mock_console: + display_results_table(summary, "Run Results") + + mock_console.print.assert_called() + + def test_display_results_table_empty_results(self): + """Test displaying empty results table.""" + summary = {"successful_builds": [], "failed_builds": []} + + with patch("madengine.mad_cli.console") as mock_console: + display_results_table(summary, "Empty Results") + + mock_console.print.assert_called() + + def test_display_results_table_many_items(self): + """Test displaying results table with many items (truncation).""" + summary = { + 
"successful_builds": [f"model{i}" for i in range(10)], + "failed_builds": [], + } + + with patch("madengine.mad_cli.console") as mock_console: + display_results_table(summary, "Many Results") + + mock_console.print.assert_called() + + +class TestBuildCommand: + """Test the build command.""" + + def setup_method(self): + """Set up test fixtures.""" + self.runner = CliRunner() + + @patch("madengine.mad_cli.DistributedOrchestrator") + @patch("madengine.mad_cli.validate_additional_context") + def test_build_command_success(self, mock_validate, mock_orchestrator_class): + """Test successful build command.""" + # Use auto-generated context for current machine + context = generate_additional_context_for_machine() + context_json = json.dumps(context) + + # Mock validation + mock_validate.return_value = context + + # Mock orchestrator + mock_orchestrator = MagicMock() + mock_orchestrator.build_phase.return_value = { + "successful_builds": ["model1"], + "failed_builds": [], + } + mock_orchestrator_class.return_value = mock_orchestrator + + result = self.runner.invoke( + app, ["build", "--tags", "dummy", "--additional-context", context_json] + ) + + assert result.exit_code == ExitCode.SUCCESS + mock_validate.assert_called_once() + mock_orchestrator.build_phase.assert_called_once() + + @patch("madengine.mad_cli.DistributedOrchestrator") + @patch("madengine.mad_cli.validate_additional_context") + def test_build_command_failure(self, mock_validate, mock_orchestrator_class): + """Test build command with failures.""" + # Use auto-generated context for current machine + context = generate_additional_context_for_machine() + context_json = json.dumps(context) + + # Mock validation + mock_validate.return_value = context + + # Mock orchestrator with failures + mock_orchestrator = MagicMock() + mock_orchestrator.build_phase.return_value = { + "successful_builds": [], + "failed_builds": ["model1", "model2"], + } + mock_orchestrator_class.return_value = mock_orchestrator + + result = self.runner.invoke( + app, ["build", "--tags", "dummy", "--additional-context", context_json] + ) + + assert result.exit_code == ExitCode.BUILD_FAILURE + + def test_build_command_invalid_context(self): + """Test build command with invalid context.""" + result = self.runner.invoke( + app, ["build", "--tags", "dummy", "--additional-context", "invalid json"] + ) + + assert result.exit_code == ExitCode.INVALID_ARGS + + def test_build_command_missing_context(self): + """Test build command with missing context.""" + result = self.runner.invoke(app, ["build", "--tags", "dummy"]) + + assert result.exit_code == ExitCode.INVALID_ARGS + + @patch("madengine.mad_cli.DistributedOrchestrator") + @patch("madengine.mad_cli.validate_additional_context") + def test_build_command_with_registry(self, mock_validate, mock_orchestrator_class): + """Test build command with registry option.""" + # Use auto-generated context for current machine + context = generate_additional_context_for_machine() + context_json = json.dumps(context) + + # Mock validation + mock_validate.return_value = context + + # Mock orchestrator + mock_orchestrator = MagicMock() + mock_orchestrator.build_phase.return_value = { + "successful_builds": ["model1"], + "failed_builds": [], + } + mock_orchestrator_class.return_value = mock_orchestrator + + result = self.runner.invoke( + app, + [ + "build", + "--tags", + "dummy", + "--registry", + "localhost:5000", + "--additional-context", + context_json, + ], + ) + + assert result.exit_code == ExitCode.SUCCESS + # Verify registry was passed to 
build_phase + mock_orchestrator.build_phase.assert_called_once() + call_args = mock_orchestrator.build_phase.call_args + assert call_args[1]["registry"] == "localhost:5000" + + @patch("madengine.mad_cli.DistributedOrchestrator") + @patch("madengine.mad_cli.validate_additional_context") + def test_build_command_exception_handling( + self, mock_validate, mock_orchestrator_class + ): + """Test build command exception handling.""" + # Use auto-generated context for current machine + context = generate_additional_context_for_machine() + context_json = json.dumps(context) + + # Mock validation + mock_validate.return_value = context + + # Mock orchestrator to raise exception + mock_orchestrator_class.side_effect = Exception("Test error") + + result = self.runner.invoke( + app, ["build", "--tags", "dummy", "--additional-context", context_json] + ) + + assert result.exit_code == ExitCode.FAILURE + + +class TestRunCommand: + """Test the run command.""" + + def setup_method(self): + """Set up test fixtures.""" + self.runner = CliRunner() + + @patch("madengine.mad_cli.os.path.exists") + @patch("madengine.mad_cli.DistributedOrchestrator") + def test_run_command_execution_only(self, mock_orchestrator_class, mock_exists): + """Test run command in execution-only mode (manifest exists).""" + # Mock manifest file exists + mock_exists.return_value = True + + # Mock orchestrator + mock_orchestrator = MagicMock() + mock_orchestrator.run_phase.return_value = { + "successful_runs": [{"model": "model1"}], + "failed_runs": [], + } + mock_orchestrator_class.return_value = mock_orchestrator + + result = self.runner.invoke( + app, ["run", "--manifest-file", "test_manifest.json"] + ) + + assert result.exit_code == ExitCode.SUCCESS + mock_orchestrator.run_phase.assert_called_once() + + @patch("madengine.mad_cli.os.path.exists") + @patch("madengine.mad_cli.DistributedOrchestrator") + @patch("madengine.mad_cli.validate_additional_context") + def test_run_command_full_workflow( + self, mock_validate, mock_orchestrator_class, mock_exists + ): + """Test run command in full workflow mode (no manifest).""" + # Mock manifest file doesn't exist + mock_exists.return_value = False + + # Use auto-generated context for current machine + context = generate_additional_context_for_machine() + context_json = json.dumps(context) + + # Mock validation + mock_validate.return_value = context + + # Mock orchestrator + mock_orchestrator = MagicMock() + mock_orchestrator.build_phase.return_value = { + "successful_builds": ["model1"], + "failed_builds": [], + } + mock_orchestrator.run_phase.return_value = { + "successful_runs": [{"model": "model1"}], + "failed_runs": [], + } + mock_orchestrator_class.return_value = mock_orchestrator + + result = self.runner.invoke( + app, ["run", "--tags", "dummy", "--additional-context", context_json] + ) + + assert result.exit_code == ExitCode.SUCCESS + mock_orchestrator.build_phase.assert_called_once() + mock_orchestrator.run_phase.assert_called_once() + + @patch("madengine.mad_cli.os.path.exists") + @patch("madengine.mad_cli.DistributedOrchestrator") + @patch("madengine.mad_cli.validate_additional_context") + def test_run_command_build_failure( + self, mock_validate, mock_orchestrator_class, mock_exists + ): + """Test run command with build failure in full workflow.""" + # Mock manifest file doesn't exist + mock_exists.return_value = False + + # Use auto-generated context for current machine + context = generate_additional_context_for_machine() + context_json = json.dumps(context) + + # Mock validation + 
mock_validate.return_value = context + + # Mock orchestrator with build failure + mock_orchestrator = MagicMock() + mock_orchestrator.build_phase.return_value = { + "successful_builds": [], + "failed_builds": ["model1"], + } + mock_orchestrator_class.return_value = mock_orchestrator + + result = self.runner.invoke( + app, ["run", "--tags", "dummy", "--additional-context", context_json] + ) + + assert result.exit_code == ExitCode.BUILD_FAILURE + mock_orchestrator.build_phase.assert_called_once() + # run_phase should not be called if build fails + mock_orchestrator.run_phase.assert_not_called() + + @requires_gpu("GPU execution tests require GPU hardware") + @patch("madengine.mad_cli.os.path.exists") + @patch("madengine.mad_cli.DistributedOrchestrator") + def test_run_command_execution_failure(self, mock_orchestrator_class, mock_exists): + """Test run command with execution failure.""" + # Mock manifest file exists + mock_exists.return_value = True + + # Mock orchestrator with execution failure + mock_orchestrator = MagicMock() + mock_orchestrator.run_phase.return_value = { + "successful_runs": [], + "failed_runs": [{"model": "model1"}], + } + mock_orchestrator_class.return_value = mock_orchestrator + + result = self.runner.invoke( + app, ["run", "--manifest-file", "test_manifest.json"] + ) + + assert result.exit_code == ExitCode.RUN_FAILURE + + def test_run_command_invalid_timeout(self): + """Test run command with invalid timeout.""" + result = self.runner.invoke(app, ["run", "--timeout", "-5"]) + + assert result.exit_code == ExitCode.INVALID_ARGS + + @requires_gpu("GPU execution tests require GPU hardware") + @patch("madengine.mad_cli.os.path.exists") + @patch("madengine.mad_cli.DistributedOrchestrator") + def test_run_command_with_options(self, mock_orchestrator_class, mock_exists): + """Test run command with various options.""" + # Mock manifest file exists + mock_exists.return_value = True + + # Mock orchestrator + mock_orchestrator = MagicMock() + mock_orchestrator.run_phase.return_value = { + "successful_runs": [{"model": "model1"}], + "failed_runs": [], + } + mock_orchestrator_class.return_value = mock_orchestrator + + result = self.runner.invoke( + app, + [ + "run", + "--manifest-file", + "test_manifest.json", + "--timeout", + "300", + "--keep-alive", + "--keep-model-dir", + "--verbose", + ], + ) + + assert result.exit_code == ExitCode.SUCCESS + # Verify options were passed + call_args = mock_orchestrator.run_phase.call_args + assert call_args[1]["timeout"] == 300 + assert call_args[1]["keep_alive"] is True + + +class TestGenerateAnsibleCommand: + """Test the generate ansible command.""" + + def setup_method(self): + """Set up test fixtures.""" + self.runner = CliRunner() + + @patch("madengine.mad_cli.generate_ansible_setup") + @patch("madengine.mad_cli.os.path.exists") + def test_generate_ansible_success(self, mock_exists, mock_generate_ansible): + """Test successful ansible generation.""" + # Mock manifest file exists + mock_exists.return_value = True + + # Mock the return value of generate_ansible_setup + mock_generate_ansible.return_value = { + "playbook": "ansible-setup/madengine_playbook.yml" + } + + result = self.runner.invoke( + app, + [ + "generate", + "ansible", + "--manifest-file", + "test_manifest.json", + "--output", + "test_playbook.yml", + ], + ) + + assert result.exit_code == ExitCode.SUCCESS + mock_generate_ansible.assert_called_once_with( + manifest_file="test_manifest.json", environment="default", output_dir="." 
+ ) + + @patch("madengine.mad_cli.os.path.exists") + def test_generate_ansible_manifest_not_found(self, mock_exists): + """Test ansible generation with missing manifest.""" + # Mock manifest file doesn't exist + mock_exists.return_value = False + + result = self.runner.invoke( + app, ["generate", "ansible", "--manifest-file", "missing_manifest.json"] + ) + + assert result.exit_code == ExitCode.FAILURE + + @patch("madengine.mad_cli.generate_ansible_setup") + @patch("madengine.mad_cli.os.path.exists") + def test_generate_ansible_exception(self, mock_exists, mock_generate_ansible): + """Test ansible generation with exception.""" + # Mock manifest file exists + mock_exists.return_value = True + + # Mock exception in ansible generation + mock_generate_ansible.side_effect = Exception("Test error") + + result = self.runner.invoke( + app, ["generate", "ansible", "--manifest-file", "test_manifest.json"] + ) + + assert result.exit_code == ExitCode.FAILURE + + @patch("madengine.mad_cli.generate_ansible_setup") + @patch("madengine.mad_cli.os.path.exists") + def test_generate_ansible_default_values(self, mock_exists, mock_generate_ansible): + """Test ansible generation with default values.""" + # Mock manifest file exists + mock_exists.return_value = True + + # Mock the return value of generate_ansible_setup + mock_generate_ansible.return_value = { + "playbook": "ansible-setup/madengine_playbook.yml" + } + + result = self.runner.invoke(app, ["generate", "ansible"]) + + assert result.exit_code == ExitCode.SUCCESS + mock_generate_ansible.assert_called_once_with( + manifest_file=DEFAULT_MANIFEST_FILE, environment="default", output_dir="." + ) + + +class TestGenerateK8sCommand: + """Test the generate k8s command.""" + + def setup_method(self): + """Set up test fixtures.""" + self.runner = CliRunner() + + @patch("madengine.mad_cli.generate_k8s_setup") + @patch("madengine.mad_cli.os.path.exists") + def test_generate_k8s_success(self, mock_exists, mock_generate_k8s): + """Test successful k8s generation.""" + # Mock manifest file exists + mock_exists.return_value = True + + # Mock the return value of generate_k8s_setup + mock_generate_k8s.return_value = { + "deployment": ["k8s-setup/deployment.yml"], + "service": ["k8s-setup/service.yml"], + } + + result = self.runner.invoke( + app, + [ + "generate", + "k8s", + "--manifest-file", + "test_manifest.json", + "--output-dir", + "test-k8s", + ], + ) + + assert result.exit_code == ExitCode.SUCCESS + mock_generate_k8s.assert_called_once_with( + manifest_file="test_manifest.json", + environment="default", + output_dir="test-k8s", + ) + + @patch("madengine.mad_cli.os.path.exists") + def test_generate_k8s_manifest_not_found(self, mock_exists): + """Test k8s generation with missing manifest.""" + # Mock manifest file doesn't exist + mock_exists.return_value = False + + result = self.runner.invoke( + app, ["generate", "k8s", "--manifest-file", "missing_manifest.json"] + ) + + assert result.exit_code == ExitCode.FAILURE + + @patch("madengine.mad_cli.generate_k8s_setup") + @patch("madengine.mad_cli.os.path.exists") + def test_generate_k8s_exception(self, mock_exists, mock_generate_k8s): + """Test k8s generation with exception.""" + # Mock manifest file exists + mock_exists.return_value = True + + # Mock exception in k8s generation + mock_generate_k8s.side_effect = Exception("Test error") + + result = self.runner.invoke( + app, ["generate", "k8s", "--manifest-file", "test_manifest.json"] + ) + + assert result.exit_code == ExitCode.FAILURE + + 
@patch("madengine.mad_cli.generate_k8s_setup") + @patch("madengine.mad_cli.os.path.exists") + def test_generate_k8s_default_values(self, mock_exists, mock_generate_k8s): + """Test k8s generation with default values.""" + # Mock manifest file exists + mock_exists.return_value = True + + # Mock the return value of generate_k8s_setup + mock_generate_k8s.return_value = { + "deployment": ["k8s-setup/deployment.yml"], + "service": ["k8s-setup/service.yml"], + } + + result = self.runner.invoke(app, ["generate", "k8s"]) + + assert result.exit_code == ExitCode.SUCCESS + mock_generate_k8s.assert_called_once_with( + manifest_file=DEFAULT_MANIFEST_FILE, + environment="default", + output_dir="k8s-setup", + ) + + +class TestMainCallback: + """Test the main callback function.""" + + def setup_method(self): + """Set up test fixtures.""" + self.runner = CliRunner() + + def test_main_version_flag(self): + """Test main callback with version flag.""" + result = self.runner.invoke(app, ["--version"]) + + assert result.exit_code == ExitCode.SUCCESS + assert "madengine-cli" in result.stdout + assert "version" in result.stdout + + def test_main_help(self): + """Test main callback shows help when no command.""" + result = self.runner.invoke(app, []) + + # Should show help and exit + assert "madengine Distributed Orchestrator" in result.stdout + + +class TestConstants: + """Test module constants.""" + + def test_exit_codes(self): + """Test exit code constants.""" + assert ExitCode.SUCCESS == 0 + assert ExitCode.FAILURE == 1 + assert ExitCode.BUILD_FAILURE == 2 + assert ExitCode.RUN_FAILURE == 3 + assert ExitCode.INVALID_ARGS == 4 + + def test_valid_values(self): + """Test valid value constants.""" + assert "AMD" in VALID_GPU_VENDORS + assert "NVIDIA" in VALID_GPU_VENDORS + assert "INTEL" in VALID_GPU_VENDORS + + assert "UBUNTU" in VALID_GUEST_OS + assert "CENTOS" in VALID_GUEST_OS + assert "ROCKY" in VALID_GUEST_OS + + def test_default_values(self): + """Test default value constants.""" + assert DEFAULT_MANIFEST_FILE == "build_manifest.json" + assert DEFAULT_PERF_OUTPUT == "perf.csv" + assert DEFAULT_DATA_CONFIG == "data.json" + assert DEFAULT_TOOLS_CONFIG == "./scripts/common/tools.json" + assert DEFAULT_ANSIBLE_OUTPUT == "madengine_distributed.yml" + assert DEFAULT_TIMEOUT == -1 + + +class TestCliMain: + """Test the cli_main function.""" + + @patch("madengine.mad_cli.app") + def test_cli_main_success(self, mock_app): + """Test successful cli_main execution.""" + mock_app.return_value = None + + # Should not raise any exception + mad_cli.cli_main() + + mock_app.assert_called_once() + + @patch("madengine.mad_cli.app") + @patch("madengine.mad_cli.sys.exit") + def test_cli_main_keyboard_interrupt(self, mock_exit, mock_app): + """Test cli_main with keyboard interrupt.""" + mock_app.side_effect = KeyboardInterrupt() + + mad_cli.cli_main() + + mock_exit.assert_called_once_with(ExitCode.FAILURE) + + @patch("madengine.mad_cli.app") + @patch("madengine.mad_cli.sys.exit") + @patch("madengine.mad_cli.console") + def test_cli_main_unexpected_exception(self, mock_console, mock_exit, mock_app): + """Test cli_main with unexpected exception.""" + mock_app.side_effect = Exception("Test error") + + mad_cli.cli_main() + + mock_exit.assert_called_once_with(ExitCode.FAILURE) + mock_console.print.assert_called() + mock_console.print_exception.assert_called_once() + + +class TestIntegration: + """Integration tests for the CLI.""" + + def setup_method(self): + """Set up test fixtures.""" + self.runner = CliRunner() + + def 
test_help_command(self): + """Test help command works.""" + result = self.runner.invoke(app, ["--help"]) + + assert result.exit_code == 0 + assert "madengine Distributed Orchestrator" in result.stdout + + def test_build_help(self): + """Test build command help.""" + result = self.runner.invoke(app, ["build", "--help"]) + + assert result.exit_code == 0 + assert "Build Docker images" in result.stdout + + def test_run_help(self): + """Test run command help.""" + result = self.runner.invoke(app, ["run", "--help"]) + + assert result.exit_code == 0 + assert "Run model containers" in result.stdout + + def test_generate_help(self): + """Test generate command help.""" + result = self.runner.invoke(app, ["generate", "--help"]) + + assert result.exit_code == 0 + assert "Generate orchestration files" in result.stdout + + def test_generate_ansible_help(self): + """Test generate ansible command help.""" + result = self.runner.invoke(app, ["generate", "ansible", "--help"]) + + assert result.exit_code == 0 + assert "Generate Ansible playbook" in result.stdout + + def test_generate_k8s_help(self): + """Test generate k8s command help.""" + result = self.runner.invoke(app, ["generate", "k8s", "--help"]) + + assert result.exit_code == 0 + assert "Generate Kubernetes manifests" in result.stdout + + +class TestCpuOnlyMachine: + """Tests specifically for CPU-only machines.""" + + def setup_method(self): + """Set up test fixtures.""" + self.runner = CliRunner() + + def test_cpu_only_machine_detection(self): + """Test that GPU detection works.""" + # This test should always pass, regardless of hardware + has_gpu_available = has_gpu() + assert isinstance(has_gpu_available, bool) + + def test_auto_context_generation_cpu_only(self): + """Test that auto-generated context is appropriate for CPU-only machines.""" + context = generate_additional_context_for_machine() + + # Should always have required fields + assert "gpu_vendor" in context + assert "guest_os" in context + + # On CPU-only machines, should use default AMD for build compatibility + if not has_gpu(): + assert context["gpu_vendor"] == "AMD" + assert context["guest_os"] == "UBUNTU" + + @patch("madengine.mad_cli.DistributedOrchestrator") + @patch("madengine.mad_cli.validate_additional_context") + def test_build_on_cpu_only_machine(self, mock_validate, mock_orchestrator_class): + """Test build command works on CPU-only machines.""" + # Use auto-generated context for current machine + context = generate_additional_context_for_machine() + context_json = json.dumps(context) + + # Mock validation + mock_validate.return_value = context + + # Mock orchestrator + mock_orchestrator = MagicMock() + mock_orchestrator.build_phase.return_value = { + "successful_builds": ["model1"], + "failed_builds": [], + } + mock_orchestrator_class.return_value = mock_orchestrator + + result = self.runner.invoke( + app, ["build", "--tags", "dummy", "--additional-context", context_json] + ) + + # Should work on CPU-only machines for build phase + assert result.exit_code == ExitCode.SUCCESS + mock_validate.assert_called_once() + mock_orchestrator.build_phase.assert_called_once() + + +class TestGpuRequiredTests: + """Tests that require GPU hardware.""" + + def setup_method(self): + """Set up test fixtures.""" + self.runner = CliRunner() + + @requires_gpu("Test requires GPU hardware") + @patch("madengine.mad_cli.os.path.exists") + @patch("madengine.mad_cli.DistributedOrchestrator") + def test_run_with_gpu_required(self, mock_orchestrator_class, mock_exists): + """Test run command that 
requires GPU hardware.""" + # Mock manifest file exists + mock_exists.return_value = True + + # Mock orchestrator + mock_orchestrator = MagicMock() + mock_orchestrator.run_phase.return_value = { + "successful_runs": [{"model": "model1"}], + "failed_runs": [], + } + mock_orchestrator_class.return_value = mock_orchestrator + + result = self.runner.invoke( + app, ["run", "--manifest-file", "test_manifest.json"] + ) + + assert result.exit_code == ExitCode.SUCCESS + mock_orchestrator.run_phase.assert_called_once() + + @requires_gpu("Test requires AMD GPU hardware") + @patch("madengine.mad_cli.os.path.exists") + @patch("madengine.mad_cli.DistributedOrchestrator") + def test_run_with_amd_gpu_required(self, mock_orchestrator_class, mock_exists): + """Test run command that requires AMD GPU hardware.""" + # Mock manifest file exists + mock_exists.return_value = True + + # Mock orchestrator + mock_orchestrator = MagicMock() + mock_orchestrator.run_phase.return_value = { + "successful_runs": [{"model": "model1"}], + "failed_runs": [], + } + mock_orchestrator_class.return_value = mock_orchestrator + + result = self.runner.invoke( + app, ["run", "--manifest-file", "test_manifest.json"] + ) + + assert result.exit_code == ExitCode.SUCCESS + mock_orchestrator.run_phase.assert_called_once() + + @requires_gpu("Test requires NVIDIA GPU hardware") + @patch("madengine.mad_cli.os.path.exists") + @patch("madengine.mad_cli.DistributedOrchestrator") + def test_run_with_nvidia_gpu_required(self, mock_orchestrator_class, mock_exists): + """Test run command that requires NVIDIA GPU hardware.""" + # Mock manifest file exists + mock_exists.return_value = True + + # Mock orchestrator + mock_orchestrator = MagicMock() + mock_orchestrator.run_phase.return_value = { + "successful_runs": [{"model": "model1"}], + "failed_runs": [], + } + mock_orchestrator_class.return_value = mock_orchestrator + + result = self.runner.invoke( + app, ["run", "--manifest-file", "test_manifest.json"] + ) + + assert result.exit_code == ExitCode.SUCCESS + mock_orchestrator.run_phase.assert_called_once() + + +class TestEdgeCases: + """Test edge cases and error conditions.""" + + def setup_method(self): + """Set up test fixtures.""" + self.runner = CliRunner() + + def test_build_empty_tags(self): + """Test build command with empty tags list.""" + # Use auto-generated context for current machine + context = generate_additional_context_for_machine() + context_json = json.dumps(context) + + result = self.runner.invoke( + app, ["build", "--additional-context", context_json] + ) + + # Should handle empty tags gracefully + assert result.exit_code in [ + ExitCode.SUCCESS, + ExitCode.BUILD_FAILURE, + ExitCode.INVALID_ARGS, + ] + + def test_run_zero_timeout(self): + """Test run command with zero timeout.""" + result = self.runner.invoke(app, ["run", "--timeout", "0"]) + + # Zero timeout should be valid (no timeout) + # Exit code depends on other factors but shouldn't be INVALID_ARGS for timeout + assert ( + result.exit_code != ExitCode.INVALID_ARGS or "Timeout" not in result.stdout + ) + + @patch("madengine.mad_cli.validate_additional_context") + def test_context_file_and_string_both_provided(self, mock_validate): + """Test providing both context file and string.""" + # Use auto-generated context for current machine + context = generate_additional_context_for_machine() + context_json = json.dumps(context) + + mock_validate.return_value = context + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump({"gpu_vendor": 
"NVIDIA", "guest_os": "CENTOS"}, f) + temp_file = f.name + + try: + result = self.runner.invoke( + app, + [ + "build", + "--additional-context", + context_json, + "--additional-context-file", + temp_file, + ], + ) + + # Should call validate with both parameters + mock_validate.assert_called_once() + finally: + os.unlink(temp_file) diff --git a/tests/test_misc.py b/tests/test_misc.py deleted file mode 100644 index 11a6fa81..00000000 --- a/tests/test_misc.py +++ /dev/null @@ -1,86 +0,0 @@ -"""Test the misc modules. - -Copyright (c) Advanced Micro Devices, Inc. All rights reserved. -""" -# built-in modules -import os -import sys -import csv -import pandas as pd -# 3rd party modules -import pytest -# project modules -from .fixtures.utils import BASE_DIR, MODEL_DIR -from .fixtures.utils import global_data -from .fixtures.utils import clean_test_temp_files - - -class TestMiscFunctionality: - - @pytest.mark.parametrize('clean_test_temp_files', [['perf_test.csv', 'perf_test.html']], indirect=True) - def test_output_commandline_argument_writes_csv_correctly(self, global_data, clean_test_temp_files): - """ - output command-line argument writes csv file to specified output path - """ - output = global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy -o perf_test.csv") - success = False - with open(os.path.join(BASE_DIR, 'perf_test.csv'), 'r') as csv_file: - csv_reader = csv.DictReader(csv_file) - for row in csv_reader: - if row['model'] == 'dummy': - if row['status'] == 'SUCCESS': - success = True - break - else: - pytest.fail("model in perf_test.csv did not run successfully.") - if not success: - pytest.fail("model, dummy, not found in perf_test.csv.") - - @pytest.mark.parametrize('clean_test_temp_files', [['perf_test.csv', 'perf_test.html']], indirect=True) - def test_commandline_argument_skip_gpu_arch(self, global_data, clean_test_temp_files): - """ - skip_gpu_arch command-line argument skips GPU architecture check - """ - output = global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_skip_gpu_arch") - if 'Skipping model' not in output: - pytest.fail("Enable skipping gpu arch for running model is failed.") - - @pytest.mark.parametrize('clean_test_temp_files', [['perf_test.csv', 'perf_test.html']], indirect=True) - def test_commandline_argument_disable_skip_gpu_arch_fail(self, global_data, clean_test_temp_files): - """ - skip_gpu_arch command-line argument fails GPU architecture check - """ - output = global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_skip_gpu_arch --disable-skip-gpu-arch") - # Check if exception with message 'Skipping model' is thrown - if 'Skipping model' in output: - pytest.fail("Disable skipping gpu arch for running model is failed.") - - @pytest.mark.parametrize('clean_test_temp_files', [['perf_test.csv', 'perf_test.html']], indirect=True) - def test_output_multi_results(self, global_data, clean_test_temp_files): - """ - test output multiple results - """ - output = global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_multi") - # Check if multiple results are written to perf_dummy.csv - success = False - # Read the csv file to a dataframe using pandas - multi_df = pd.read_csv(os.path.join(BASE_DIR, 'perf_dummy.csv')) - # Check the number of rows in the dataframe is 4, 
and columns is 4 - if multi_df.shape == (4, 4): - success = True - if not success: - pytest.fail("The generated multi results is not correct.") - # Check if multiple results from perf_dummy.csv get copied over to perf.csv - perf_df = pd.read_csv(os.path.join(BASE_DIR, 'perf.csv')) - # Get the corresponding rows and columns from perf.csv - perf_df = perf_df[multi_df.columns] - perf_df = perf_df.iloc[-4:, :] - # Drop model columns from both dataframes; these will not match - # if multiple results csv has {model}, then perf csv has {tag_name}_{model} - multi_df = multi_df.drop('model', axis=1) - perf_df = perf_df.drop('model', axis=1) - if all(perf_df.columns == multi_df.columns): - success = True - if not success: - pytest.fail("The columns of the generated multi results do not match perf.csv.") - diff --git a/tests/test_multi_gpu_arch.py b/tests/test_multi_gpu_arch.py new file mode 100644 index 00000000..e46d8e10 --- /dev/null +++ b/tests/test_multi_gpu_arch.py @@ -0,0 +1,162 @@ +"""Comprehensive unit tests for multi-GPU architecture support in MADEngine. + +Covers: +- Multi-arch DockerBuilder logic (image naming, manifest, legacy/override) +- Dockerfile GPU variable parsing/validation +- Target architecture normalization and compatibility +- Run-phase manifest filtering by gpu_architecture + +All tests are logic/unit tests and do not require GPU hardware. +""" +import pytest +from unittest.mock import MagicMock, patch +from madengine.tools.docker_builder import DockerBuilder +from madengine.tools.distributed_orchestrator import DistributedOrchestrator + +class TestMultiGPUArch: + def setup_method(self): + self.context = MagicMock() + self.console = MagicMock() + self.builder = DockerBuilder(self.context, self.console) + + # Mock args for DistributedOrchestrator to avoid file reading issues + mock_args = MagicMock() + mock_args.additional_context = None + mock_args.additional_context_file = None + mock_args.live_output = True + mock_args.data_config_file_name = "data.json" + + # Create orchestrator with mocked args and build_only_mode to avoid GPU detection + self.orchestrator = DistributedOrchestrator(mock_args, build_only_mode=True) + + # --- DockerBuilder Multi-Arch Logic --- + @patch.object(DockerBuilder, "_get_dockerfiles_for_model") + @patch.object(DockerBuilder, "_check_dockerfile_has_gpu_variables") + @patch.object(DockerBuilder, "build_image") + def test_multi_arch_build_image_naming(self, mock_build_image, mock_check_gpu_vars, mock_get_dockerfiles): + model_info = {"name": "dummy", "dockerfile": "docker/dummy.Dockerfile"} + mock_get_dockerfiles.return_value = ["docker/dummy.Dockerfile"] + # GPU variable present + mock_check_gpu_vars.return_value = (True, "docker/dummy.Dockerfile") + mock_build_image.return_value = {"docker_image": "ci-dummy_dummy.ubuntu.amd_gfx908", "build_duration": 1.0} + result = self.builder._build_model_for_arch(model_info, "gfx908", None, False, None, "", None) + assert result[0]["docker_image"].endswith("_gfx908") + # GPU variable absent + mock_check_gpu_vars.return_value = (False, "docker/dummy.Dockerfile") + mock_build_image.return_value = {"docker_image": "ci-dummy_dummy.ubuntu.amd", "build_duration": 1.0} + result = self.builder._build_model_for_arch(model_info, "gfx908", None, False, None, "", None) + assert not result[0]["docker_image"].endswith("_gfx908") + + @patch.object(DockerBuilder, "_get_dockerfiles_for_model") + @patch.object(DockerBuilder, "_check_dockerfile_has_gpu_variables") + @patch.object(DockerBuilder, "build_image") + def 
test_multi_arch_manifest_fields(self, mock_build_image, mock_check_gpu_vars, mock_get_dockerfiles): + model_info = {"name": "dummy", "dockerfile": "docker/dummy.Dockerfile"} + mock_get_dockerfiles.return_value = ["docker/dummy.Dockerfile"] + mock_check_gpu_vars.return_value = (True, "docker/dummy.Dockerfile") + mock_build_image.return_value = {"docker_image": "ci-dummy_dummy.ubuntu.amd_gfx908", "build_duration": 1.0} + result = self.builder._build_model_for_arch(model_info, "gfx908", None, False, None, "", None) + assert result[0]["gpu_architecture"] == "gfx908" + + @patch.object(DockerBuilder, "_get_dockerfiles_for_model") + @patch.object(DockerBuilder, "build_image") + def test_legacy_single_arch_build(self, mock_build_image, mock_get_dockerfiles): + model_info = {"name": "dummy", "dockerfile": "docker/dummy.Dockerfile"} + mock_get_dockerfiles.return_value = ["docker/dummy.Dockerfile"] + mock_build_image.return_value = {"docker_image": "ci-dummy_dummy.ubuntu.amd", "build_duration": 1.0} + result = self.builder._build_model_single_arch(model_info, None, False, None, "", None) + assert result[0]["docker_image"] == "ci-dummy_dummy.ubuntu.amd" + + @patch.object(DockerBuilder, "_build_model_single_arch") + def test_additional_context_overrides_target_archs(self, mock_single_arch): + self.context.ctx = {"docker_build_arg": {"MAD_SYSTEM_GPU_ARCHITECTURE": "gfx908"}} + model_info = {"name": "dummy", "dockerfile": "docker/dummy.Dockerfile"} + mock_single_arch.return_value = [{"docker_image": "ci-dummy_dummy.ubuntu.amd", "build_duration": 1.0}] + result = self.builder.build_all_models([model_info], target_archs=["gfx908", "gfx90a"]) + assert result["successful_builds"][0]["docker_image"] == "ci-dummy_dummy.ubuntu.amd" + + # --- Dockerfile GPU Variable Parsing/Validation --- + def test_parse_dockerfile_gpu_variables(self): + dockerfile_content = """ + ARG MAD_SYSTEM_GPU_ARCHITECTURE=gfx908 + ENV PYTORCH_ROCM_ARCH=gfx908;gfx90a + ARG GPU_TARGETS=gfx908,gfx942 + ENV GFX_COMPILATION_ARCH=gfx908 + ARG GPU_ARCHS=gfx908;gfx90a;gfx942 + """ + result = self.builder._parse_dockerfile_gpu_variables(dockerfile_content) + assert result["MAD_SYSTEM_GPU_ARCHITECTURE"] == ["gfx908"] + assert result["PYTORCH_ROCM_ARCH"] == ["gfx908", "gfx90a"] + assert result["GPU_TARGETS"] == ["gfx908", "gfx942"] + assert result["GFX_COMPILATION_ARCH"] == ["gfx908"] + assert result["GPU_ARCHS"] == ["gfx908", "gfx90a", "gfx942"] + + def test_parse_dockerfile_gpu_variables_env_delimiter(self): + dockerfile_content = "ENV PYTORCH_ROCM_ARCH = gfx908,gfx90a" + result = self.builder._parse_dockerfile_gpu_variables(dockerfile_content) + assert result["PYTORCH_ROCM_ARCH"] == ["gfx908", "gfx90a"] + + def test_parse_malformed_dockerfile(self): + dockerfile_content = "ENV BAD_LINE\nARG MAD_SYSTEM_GPU_ARCHITECTURE=\nENV PYTORCH_ROCM_ARCH=\n" + result = self.builder._parse_dockerfile_gpu_variables(dockerfile_content) + assert isinstance(result, dict) + + # --- Target Architecture Normalization/Compatibility --- + def test_normalize_architecture_name(self): + cases = { + "gfx908": "gfx908", + "GFX908": "gfx908", + "mi100": "gfx908", + "mi-100": "gfx908", + "mi200": "gfx90a", + "mi-200": "gfx90a", + "mi210": "gfx90a", + "mi250": "gfx90a", + "mi300": "gfx940", + "mi-300": "gfx940", + "mi300a": "gfx940", + "mi300x": "gfx942", + "mi-300x": "gfx942", + "unknown": "unknown", + "": None, + } + for inp, expected in cases.items(): + assert self.builder._normalize_architecture_name(inp) == expected + + def 
test_is_target_arch_compatible_with_variable(self): + assert self.builder._is_target_arch_compatible_with_variable("MAD_SYSTEM_GPU_ARCHITECTURE", ["gfx908"], "gfx942") + assert self.builder._is_target_arch_compatible_with_variable("PYTORCH_ROCM_ARCH", ["gfx908", "gfx942"], "gfx942") + assert not self.builder._is_target_arch_compatible_with_variable("PYTORCH_ROCM_ARCH", ["gfx908"], "gfx942") + assert self.builder._is_target_arch_compatible_with_variable("GFX_COMPILATION_ARCH", ["gfx908"], "gfx908") + assert not self.builder._is_target_arch_compatible_with_variable("GFX_COMPILATION_ARCH", ["gfx908"], "gfx942") + assert self.builder._is_target_arch_compatible_with_variable("UNKNOWN_VAR", ["foo"], "bar") + + def test_is_compilation_arch_compatible(self): + assert self.builder._is_compilation_arch_compatible("gfx908", "gfx908") + assert not self.builder._is_compilation_arch_compatible("gfx908", "gfx942") + assert self.builder._is_compilation_arch_compatible("foo", "foo") + + # --- Run-Phase Manifest Filtering --- + def test_filter_images_by_gpu_architecture(self): + orch = self.orchestrator + + # Test exact match + built_images = {"img1": {"gpu_architecture": "gfx908"}, "img2": {"gpu_architecture": "gfx90a"}} + filtered = orch._filter_images_by_gpu_architecture(built_images, "gfx908") + assert "img1" in filtered and "img2" not in filtered + + # Test legacy image (no arch field) + built_images = {"img1": {}, "img2": {"gpu_architecture": "gfx90a"}} + filtered = orch._filter_images_by_gpu_architecture(built_images, "gfx908") + assert "img1" in filtered # Legacy images should be included for backward compatibility + assert "img2" not in filtered + + # Test no match case + built_images = {"img1": {"gpu_architecture": "gfx90a"}, "img2": {"gpu_architecture": "gfx942"}} + filtered = orch._filter_images_by_gpu_architecture(built_images, "gfx908") + assert len(filtered) == 0 + + # Test all matching case + built_images = {"img1": {"gpu_architecture": "gfx908"}, "img2": {"gpu_architecture": "gfx908"}} + filtered = orch._filter_images_by_gpu_architecture(built_images, "gfx908") + assert len(filtered) == 2 diff --git a/tests/test_pre_post_scripts.py b/tests/test_pre_post_scripts.py index 50d64b30..db396ed4 100644 --- a/tests/test_pre_post_scripts.py +++ b/tests/test_pre_post_scripts.py @@ -2,13 +2,16 @@ Copyright (c) Advanced Micro Devices, Inc. All rights reserved. 
""" + # built-in modules import os import re import csv import time + # 3rd party modules import pytest + # project modules from .fixtures.utils import BASE_DIR, MODEL_DIR from .fixtures.utils import global_data @@ -18,16 +21,34 @@ class TestPrePostScriptsFunctionality: - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html']], indirect=True) + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html"]], indirect=True + ) def test_pre_scripts_run_before_model(self, global_data, clean_test_temp_files): - """ + """ pre_scripts are run in docker container before model execution """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy --additional-context \"{ 'pre_scripts':[{'path':'scripts/common/pre_scripts/pre_test.sh'}] }\" ") + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy --additional-context \"{ 'pre_scripts':[{'path':'scripts/common/pre_scripts/pre_test.sh'}] }\" " + ) - regexp = re.compile(r'Pre-Script test called ([0-9]*)') + regexp = re.compile(r"Pre-Script test called ([0-9]*)") foundLine = None - with open( os.path.join(BASE_DIR, "dummy_dummy.ubuntu." + ("amd" if not is_nvidia() else "nvidia") + ".live.log" ), 'r') as f: + with open( + os.path.join( + BASE_DIR, + "dummy_dummy.ubuntu." + + ("amd" if not is_nvidia() else "nvidia") + + ".live.log", + ), + "r", + ) as f: while True: line = f.readline() if not line: @@ -35,19 +56,39 @@ def test_pre_scripts_run_before_model(self, global_data, clean_test_temp_files): match = regexp.search(line) if match: foundLine = match.groups()[0] - if foundLine != '0': - pytest.fail("pre_scripts specification did not run the selected pre-script.") + if foundLine != "0": + pytest.fail( + "pre_scripts specification did not run the selected pre-script." + ) - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html']], indirect=True) + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html"]], indirect=True + ) def test_post_scripts_run_after_model(self, global_data, clean_test_temp_files): """ post_scripts are run in docker container after model execution """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy --additional-context \"{ 'post_scripts':[{'path':'scripts/common/post_scripts/post_test.sh'}] }\" ") + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy --additional-context \"{ 'post_scripts':[{'path':'scripts/common/post_scripts/post_test.sh'}] }\" " + ) - regexp = re.compile(r'Post-Script test called ([0-9]*)') + regexp = re.compile(r"Post-Script test called ([0-9]*)") foundLine = None - with open( os.path.join(BASE_DIR, "dummy_dummy.ubuntu." + ("amd" if not is_nvidia() else "nvidia") + ".live.log" ), 'r') as f: + with open( + os.path.join( + BASE_DIR, + "dummy_dummy.ubuntu." 
+ + ("amd" if not is_nvidia() else "nvidia") + + ".live.log", + ), + "r", + ) as f: while True: line = f.readline() if not line: @@ -55,19 +96,39 @@ def test_post_scripts_run_after_model(self, global_data, clean_test_temp_files): match = regexp.search(line) if match: foundLine = match.groups()[0] - if foundLine != '0': - pytest.fail("post_scripts specification did not run the selected post-script.") + if foundLine != "0": + pytest.fail( + "post_scripts specification did not run the selected post-script." + ) - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html']], indirect=True) + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html"]], indirect=True + ) def test_pre_scripts_accept_arguments(self, global_data, clean_test_temp_files): - """ + """ pre_scripts are run in docker container before model execution and accept arguments """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy --additional-context \"{ 'pre_scripts':[{'path':'scripts/common/pre_scripts/pre_test.sh', 'args':'1'}] }\" ") + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy --additional-context \"{ 'pre_scripts':[{'path':'scripts/common/pre_scripts/pre_test.sh', 'args':'1'}] }\" " + ) - regexp = re.compile(r'Pre-Script test called ([0-9]*)') + regexp = re.compile(r"Pre-Script test called ([0-9]*)") foundLine = None - with open( os.path.join(BASE_DIR, "dummy_dummy.ubuntu." + ("amd" if not is_nvidia() else "nvidia") + ".live.log" ), 'r') as f: + with open( + os.path.join( + BASE_DIR, + "dummy_dummy.ubuntu." + + ("amd" if not is_nvidia() else "nvidia") + + ".live.log", + ), + "r", + ) as f: while True: line = f.readline() if not line: @@ -75,19 +136,39 @@ def test_pre_scripts_accept_arguments(self, global_data, clean_test_temp_files): match = regexp.search(line) if match: foundLine = match.groups()[0] - if foundLine != '1': - pytest.fail("pre_scripts specification did not run the selected pre-script.") + if foundLine != "1": + pytest.fail( + "pre_scripts specification did not run the selected pre-script." + ) - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html']], indirect=True) + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html"]], indirect=True + ) def test_post_scripts_accept_arguments(self, global_data, clean_test_temp_files): """ post_scripts are run in docker container after model execution and accept arguments """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy --additional-context \"{ 'post_scripts':[{'path':'scripts/common/post_scripts/post_test.sh', 'args':'1'}] }\" ") + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy --additional-context \"{ 'post_scripts':[{'path':'scripts/common/post_scripts/post_test.sh', 'args':'1'}] }\" " + ) - regexp = re.compile(r'Post-Script test called ([0-9]*)') + regexp = re.compile(r"Post-Script test called ([0-9]*)") foundLine = None - with open( os.path.join(BASE_DIR, "dummy_dummy.ubuntu." + ("amd" if not is_nvidia() else "nvidia") + ".live.log" ), 'r') as f: + with open( + os.path.join( + BASE_DIR, + "dummy_dummy.ubuntu." 
+ + ("amd" if not is_nvidia() else "nvidia") + + ".live.log", + ), + "r", + ) as f: while True: line = f.readline() if not line: @@ -95,19 +176,41 @@ def test_post_scripts_accept_arguments(self, global_data, clean_test_temp_files) match = regexp.search(line) if match: foundLine = match.groups()[0] - if foundLine != '1': - pytest.fail("post_scripts specification did not run the selected post-script.") + if foundLine != "1": + pytest.fail( + "post_scripts specification did not run the selected post-script." + ) - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html']], indirect=True) - def test_both_pre_and_post_scripts_run_before_and_after_model(self, global_data, clean_test_temp_files): + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html"]], indirect=True + ) + def test_both_pre_and_post_scripts_run_before_and_after_model( + self, global_data, clean_test_temp_files + ): """ post_scripts are run in docker container after model execution """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy --additional-context \"{ 'pre_scripts':[{'path':'scripts/common/pre_scripts/pre_test.sh'}], 'post_scripts':[{'path':'scripts/common/post_scripts/post_test.sh'}] }\" ") + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy --additional-context \"{ 'pre_scripts':[{'path':'scripts/common/pre_scripts/pre_test.sh'}], 'post_scripts':[{'path':'scripts/common/post_scripts/post_test.sh'}] }\" " + ) - regexp = re.compile(r'Pre-Script test called ([0-9]*)') + regexp = re.compile(r"Pre-Script test called ([0-9]*)") foundLine = None - with open( os.path.join(BASE_DIR, "dummy_dummy.ubuntu." + ("amd" if not is_nvidia() else "nvidia") + ".live.log" ), 'r') as f: + with open( + os.path.join( + BASE_DIR, + "dummy_dummy.ubuntu." + + ("amd" if not is_nvidia() else "nvidia") + + ".live.log", + ), + "r", + ) as f: while True: line = f.readline() if not line: @@ -115,12 +218,22 @@ def test_both_pre_and_post_scripts_run_before_and_after_model(self, global_data, match = regexp.search(line) if match: foundLine = match.groups()[0] - if foundLine != '0': - pytest.fail("pre_scripts specification did not run the selected pre-script.") + if foundLine != "0": + pytest.fail( + "pre_scripts specification did not run the selected pre-script." + ) - regexp = re.compile(r'Post-Script test called ([0-9]*)') + regexp = re.compile(r"Post-Script test called ([0-9]*)") foundLine = None - with open( os.path.join(BASE_DIR, "dummy_dummy.ubuntu." + ("amd" if not is_nvidia() else "nvidia") + ".live.log" ), 'r') as f: + with open( + os.path.join( + BASE_DIR, + "dummy_dummy.ubuntu." + + ("amd" if not is_nvidia() else "nvidia") + + ".live.log", + ), + "r", + ) as f: while True: line = f.readline() if not line: @@ -128,20 +241,40 @@ def test_both_pre_and_post_scripts_run_before_and_after_model(self, global_data, match = regexp.search(line) if match: foundLine = match.groups()[0] - if foundLine != '0': - pytest.fail("post_scripts specification did not run the selected post-script.") + if foundLine != "0": + pytest.fail( + "post_scripts specification did not run the selected post-script." 
+ ) - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html']], indirect=True) + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html"]], indirect=True + ) def test_all_pre_scripts_run_in_order(self, global_data, clean_test_temp_files): """ all pre_scripts are run in order """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy --additional-context \"{ 'pre_scripts':[{'path':'scripts/common/pre_scripts/pre_test.sh', 'args':'1'}, {'path':'scripts/common/pre_scripts/pre_test.sh', 'args':'2'} ] }\" ") + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy --additional-context \"{ 'pre_scripts':[{'path':'scripts/common/pre_scripts/pre_test.sh', 'args':'1'}, {'path':'scripts/common/pre_scripts/pre_test.sh', 'args':'2'} ] }\" " + ) - regexp = re.compile(r'Pre-Script test called ([0-9]*)') + regexp = re.compile(r"Pre-Script test called ([0-9]*)") foundLine = None pre_post_script_count = 0 - with open( os.path.join(BASE_DIR, "dummy_dummy.ubuntu." + ("amd" if not is_nvidia() else "nvidia") + ".live.log" ), 'r') as f: + with open( + os.path.join( + BASE_DIR, + "dummy_dummy.ubuntu." + + ("amd" if not is_nvidia() else "nvidia") + + ".live.log", + ), + "r", + ) as f: while True: line = f.readline() if not line: @@ -151,22 +284,45 @@ def test_all_pre_scripts_run_in_order(self, global_data, clean_test_temp_files): foundLine = match.groups()[0] pre_post_script_count += 1 if foundLine != str(pre_post_script_count): - pytest.fail("pre_scripts run in order. Did not find " + str(pre_post_script_count) ) + pytest.fail( + "pre_scripts run in order. Did not find " + + str(pre_post_script_count) + ) - if foundLine != '2': - pytest.fail("pre_scripts specification did not run the selected pre-script.") + if foundLine != "2": + pytest.fail( + "pre_scripts specification did not run the selected pre-script." + ) - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html']], indirect=True) + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html"]], indirect=True + ) def test_all_post_scripts_run_in_order(self, global_data, clean_test_temp_files): """ - all post_scripts are run in order + all post_scripts are run in order """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy --additional-context \"{ 'post_scripts':[{'path':'scripts/common/post_scripts/post_test.sh', 'args':'1'}, {'path':'scripts/common/post_scripts/post_test.sh', 'args':'2'} ] }\" ") + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy --additional-context \"{ 'post_scripts':[{'path':'scripts/common/post_scripts/post_test.sh', 'args':'1'}, {'path':'scripts/common/post_scripts/post_test.sh', 'args':'2'} ] }\" " + ) - regexp = re.compile(r'Post-Script test called ([0-9]*)') + regexp = re.compile(r"Post-Script test called ([0-9]*)") foundLine = None pre_post_script_count = 0 - with open( os.path.join(BASE_DIR, "dummy_dummy.ubuntu." + ("amd" if not is_nvidia() else "nvidia") + ".live.log" ), 'r') as f: + with open( + os.path.join( + BASE_DIR, + "dummy_dummy.ubuntu." 
+ + ("amd" if not is_nvidia() else "nvidia") + + ".live.log", + ), + "r", + ) as f: while True: line = f.readline() if not line: @@ -176,7 +332,12 @@ def test_all_post_scripts_run_in_order(self, global_data, clean_test_temp_files) foundLine = match.groups()[0] pre_post_script_count += 1 if foundLine != str(pre_post_script_count): - pytest.fail("post_scripts run in order. Did not find " + str(pre_post_script_count) ) + pytest.fail( + "post_scripts run in order. Did not find " + + str(pre_post_script_count) + ) - if foundLine != '2': - pytest.fail("post_scripts specification did not run the selected post-script.") + if foundLine != "2": + pytest.fail( + "post_scripts specification did not run the selected post-script." + ) diff --git a/tests/test_profiling.py b/tests/test_profiling.py index 85aca389..5df1a6c7 100644 --- a/tests/test_profiling.py +++ b/tests/test_profiling.py @@ -2,78 +2,169 @@ Copyright (c) Advanced Micro Devices, Inc. All rights reserved. """ + # built-in modules import os import re import sys import csv + # third-party modules import pytest + # project modules -from .fixtures.utils import BASE_DIR, MODEL_DIR -from .fixtures.utils import global_data -from .fixtures.utils import clean_test_temp_files -from .fixtures.utils import is_nvidia +from .fixtures.utils import ( + BASE_DIR, + MODEL_DIR, + global_data, + clean_test_temp_files, + requires_gpu, + is_nvidia, +) class TestProfilingFunctionality: @pytest.mark.skipif(is_nvidia(), reason="test does not run on NVIDIA") - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html', 'rocprof_output']], indirect=True) - def test_rocprof_profiling_tool_runs_correctly(self, global_data, clean_test_temp_files): - """ - specifying a profiling tool runs respective pre and post scripts + @pytest.mark.parametrize( + "clean_test_temp_files", + [["perf.csv", "perf.html", "rocprof_output"]], + indirect=True, + ) + def test_rocprof_profiling_tool_runs_correctly( + self, global_data, clean_test_temp_files + ): + """ + specifying a profiling tool runs respective pre and post scripts """ # canFail is set to True because rocProf mode is failing the full DLM run; this test will test if the correct output files are generated - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_prof --additional-context \"{ 'tools': [{ 'name': 'rocprof' }] }\" ", canFail=True) + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_prof --additional-context \"{ 'tools': [{ 'name': 'rocprof' }] }\" ", + canFail=True, + ) - if not os.path.exists( os.path.join(BASE_DIR, "rocprof_output", "results.csv") ): - pytest.fail("rocprof_output/results.csv not generated with rocprof profiling run.") + if not os.path.exists(os.path.join(BASE_DIR, "rocprof_output", "results.csv")): + pytest.fail( + "rocprof_output/results.csv not generated with rocprof profiling run." 
+ ) @pytest.mark.skipif(is_nvidia(), reason="test does not run on NVIDIA") - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html', 'rpd_output']], indirect=True) - def test_rpd_profiling_tool_runs_correctly(self, global_data, clean_test_temp_files): - """ - specifying a profiling tool runs respective pre and post scripts + @pytest.mark.parametrize( + "clean_test_temp_files", + [["perf.csv", "perf.html", "rpd_output"]], + indirect=True, + ) + def test_rpd_profiling_tool_runs_correctly( + self, global_data, clean_test_temp_files + ): + """ + specifying a profiling tool runs respective pre and post scripts """ # canFail is set to True because rpd mode is failing the full DLM run; this test will test if the correct output files are generated - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_prof --additional-context \"{ 'tools': [{ 'name': 'rpd' }] }\" ", canFail=True) + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_prof --additional-context \"{ 'tools': [{ 'name': 'rpd' }] }\" ", + canFail=True, + ) - if not os.path.exists( os.path.join(BASE_DIR, "rpd_output", "trace.rpd") ): + if not os.path.exists(os.path.join(BASE_DIR, "rpd_output", "trace.rpd")): pytest.fail("rpd_output/trace.rpd not generated with rpd profiling run.") - - @pytest.mark.skip(reason="Skipping this test for debugging purposes") - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html', 'gpu_info_power_profiler_output.csv']], indirect=True) - def test_gpu_info_power_profiling_tool_runs_correctly(self, global_data, clean_test_temp_files): - """ - specifying a profiling tool runs respective pre and post scripts + + @requires_gpu("gpu_info_power_profiler requires GPU hardware") + @pytest.mark.parametrize( + "clean_test_temp_files", + [["perf.csv", "perf.html", "gpu_info_power_profiler_output.csv"]], + indirect=True, + ) + def test_gpu_info_power_profiling_tool_runs_correctly( + self, global_data, clean_test_temp_files + ): + """ + specifying a profiling tool runs respective pre and post scripts """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_prof --additional-context \"{ 'tools': [{ 'name': 'gpu_info_power_profiler' }] }\" ", canFail=False) + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_prof --additional-context \"{ 'tools': [{ 'name': 'gpu_info_power_profiler' }] }\" ", + canFail=False, + ) - if not os.path.exists( os.path.join(BASE_DIR, "gpu_info_power_profiler_output.csv") ): - pytest.fail("gpu_info_power_profiler_output.csv not generated with gpu_info_power_profiler run.") - - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html', 'gpu_info_vram_profiler_output.csv']], indirect=True) - def test_gpu_info_vram_profiling_tool_runs_correctly(self, global_data, clean_test_temp_files): - """ - specifying a profiling tool runs respective pre and post scripts + if not os.path.exists( + os.path.join(BASE_DIR, "gpu_info_power_profiler_output.csv") + ): + pytest.fail( + "gpu_info_power_profiler_output.csv not generated with gpu_info_power_profiler run." 
+ ) + + @requires_gpu("gpu_info_vram_profiler requires GPU hardware") + @pytest.mark.parametrize( + "clean_test_temp_files", + [["perf.csv", "perf.html", "gpu_info_vram_profiler_output.csv"]], + indirect=True, + ) + def test_gpu_info_vram_profiling_tool_runs_correctly( + self, global_data, clean_test_temp_files + ): + """ + specifying a profiling tool runs respective pre and post scripts """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_prof --additional-context \"{ 'tools': [{ 'name': 'gpu_info_vram_profiler' }] }\" ", canFail=False) + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_prof --additional-context \"{ 'tools': [{ 'name': 'gpu_info_vram_profiler' }] }\" ", + canFail=False, + ) - if not os.path.exists( os.path.join(BASE_DIR, "gpu_info_vram_profiler_output.csv") ): - pytest.fail("gpu_info_vram_profiler_output.csv not generated with gpu_info_vram_profiler run.") + if not os.path.exists( + os.path.join(BASE_DIR, "gpu_info_vram_profiler_output.csv") + ): + pytest.fail( + "gpu_info_vram_profiler_output.csv not generated with gpu_info_vram_profiler run." + ) @pytest.mark.skipif(is_nvidia(), reason="test does not run on NVIDIA") - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html', 'library_trace.csv']], indirect=True) + @pytest.mark.parametrize( + "clean_test_temp_files", + [["perf.csv", "perf.html", "library_trace.csv"]], + indirect=True, + ) def test_rocblas_trace_runs_correctly(self, global_data, clean_test_temp_files): - """ - specifying a profiling tool runs respective pre and post scripts """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_prof --additional-context \"{ 'tools': [{ 'name': 'rocblas_trace' }] }\" ", canFail=False) + specifying a profiling tool runs respective pre and post scripts + """ + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_prof --additional-context \"{ 'tools': [{ 'name': 'rocblas_trace' }] }\" ", + canFail=False, + ) - regexp = re.compile(r'rocblas-bench') + regexp = re.compile(r"rocblas-bench") foundMatch = None - with open( os.path.join(BASE_DIR, "library_trace.csv" ), 'r') as f: + with open(os.path.join(BASE_DIR, "library_trace.csv"), "r") as f: while True: line = f.readline() if not line: @@ -82,19 +173,34 @@ def test_rocblas_trace_runs_correctly(self, global_data, clean_test_temp_files): if match: foundMatch = True if not foundMatch: - pytest.fail("could not detect rocblas-bench in output log file with rocblas trace tool.") + pytest.fail( + "could not detect rocblas-bench in output log file with rocblas trace tool." 
+ ) @pytest.mark.skipif(is_nvidia(), reason="test does not run on NVIDIA") - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html', 'library_trace.csv']], indirect=True) + @pytest.mark.parametrize( + "clean_test_temp_files", + [["perf.csv", "perf.html", "library_trace.csv"]], + indirect=True, + ) def test_tensile_trace_runs_correctly(self, global_data, clean_test_temp_files): - """ - specifying a profiling tool runs respective pre and post scripts """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_prof --additional-context \"{ 'tools': [{ 'name': 'tensile_trace' }] }\" ", canFail=False) + specifying a profiling tool runs respective pre and post scripts + """ + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_prof --additional-context \"{ 'tools': [{ 'name': 'tensile_trace' }] }\" ", + canFail=False, + ) - regexp = re.compile(r'tensile,Cijk') + regexp = re.compile(r"tensile,Cijk") foundMatch = None - with open( os.path.join(BASE_DIR, "library_trace.csv" ), 'r') as f: + with open(os.path.join(BASE_DIR, "library_trace.csv"), "r") as f: while True: line = f.readline() if not line: @@ -103,19 +209,34 @@ def test_tensile_trace_runs_correctly(self, global_data, clean_test_temp_files): if match: foundMatch = True if not foundMatch: - pytest.fail("could not detect tensile call in output log file with tensile trace tool.") + pytest.fail( + "could not detect tensile call in output log file with tensile trace tool." + ) @pytest.mark.skipif(is_nvidia(), reason="test does not run on NVIDIA") - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html', 'library_trace.csv']], indirect=True) + @pytest.mark.parametrize( + "clean_test_temp_files", + [["perf.csv", "perf.html", "library_trace.csv"]], + indirect=True, + ) def test_miopen_trace_runs_correctly(self, global_data, clean_test_temp_files): - """ - specifying a profiling tool runs respective pre and post scripts """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_prof --additional-context \"{ 'tools': [{ 'name': 'miopen_trace' }] }\" ", canFail=False) + specifying a profiling tool runs respective pre and post scripts + """ + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_prof --additional-context \"{ 'tools': [{ 'name': 'miopen_trace' }] }\" ", + canFail=False, + ) - regexp = re.compile(r'MIOpenDriver') + regexp = re.compile(r"MIOpenDriver") foundMatch = None - with open( os.path.join(BASE_DIR, "library_trace.csv" ), 'r') as f: + with open(os.path.join(BASE_DIR, "library_trace.csv"), "r") as f: while True: line = f.readline() if not line: @@ -124,19 +245,40 @@ def test_miopen_trace_runs_correctly(self, global_data, clean_test_temp_files): if match: foundMatch = True if not foundMatch: - pytest.fail("could not detect miopen call in output log file with miopen trace tool.") + pytest.fail( + "could not detect miopen call in output log file with miopen trace tool." 
+ ) @pytest.mark.skipif(is_nvidia(), reason="test does not run on NVIDIA") - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html']], indirect=True) + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html"]], indirect=True + ) def test_rccl_trace_runs_correctly(self, global_data, clean_test_temp_files): - """ - specifying a profiling tool runs respective pre and post scripts """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_prof_rccl --additional-context \"{ 'tools': [{ 'name': 'rccl_trace' }] }\" ", canFail=False) + specifying a profiling tool runs respective pre and post scripts + """ + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_prof_rccl --additional-context \"{ 'tools': [{ 'name': 'rccl_trace' }] }\" ", + canFail=False, + ) - regexp = re.compile(r'NCCL INFO AllReduce:') + regexp = re.compile(r"NCCL INFO AllReduce:") foundMatch = None - with open( os.path.join(BASE_DIR, "dummy_prof_rccl_dummy.ubuntu." + ("amd" if not is_nvidia() else "nvidia") + ".live.log" ), 'r') as f: + with open( + os.path.join( + BASE_DIR, + "dummy_prof_rccl_dummy.ubuntu." + + ("amd" if not is_nvidia() else "nvidia") + + ".live.log", + ), + "r", + ) as f: while True: line = f.readline() if not line: @@ -145,27 +287,48 @@ def test_rccl_trace_runs_correctly(self, global_data, clean_test_temp_files): if match: foundMatch = True if not foundMatch: - pytest.fail("could not detect rccl call in output log file with rccl trace tool.") + pytest.fail( + "could not detect rccl call in output log file with rccl trace tool." + ) - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html']], indirect=True) + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html"]], indirect=True + ) def test_toolA_runs_correctly(self, global_data, clean_test_temp_files): - """ - specifying a profiling tool runs respective pre and post scripts """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy --additional-context \"{ 'tools': [{ 'name': 'test_tools_A' }] }\" ", canFail=False) + specifying a profiling tool runs respective pre and post scripts + """ + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy --additional-context \"{ 'tools': [{ 'name': 'test_tools_A' }] }\" ", + canFail=False, + ) - match_str_array = ['^pre_script A$', '^cmd_A$', '^post_script A$'] + match_str_array = ["^pre_script A$", "^cmd_A$", "^post_script A$"] match_str_idx = 0 regexp = re.compile(match_str_array[match_str_idx]) - with open( os.path.join(BASE_DIR, "dummy_dummy.ubuntu." + ("amd" if not is_nvidia() else "nvidia") + ".live.log" ), 'r') as f: + with open( + os.path.join( + BASE_DIR, + "dummy_dummy.ubuntu." 
+ + ("amd" if not is_nvidia() else "nvidia") + + ".live.log", + ), + "r", + ) as f: while True: line = f.readline() if not line: break match = regexp.search(line) if match: - print("MATCH = ", line ) + print("MATCH = ", line) match_str_idx = match_str_idx + 1 if match_str_idx == len(match_str_array): break @@ -174,44 +337,88 @@ def test_toolA_runs_correctly(self, global_data, clean_test_temp_files): print("Matched up to ", match_str_idx) pytest.fail("all strings were not matched in toolA test.") - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html']], indirect=True) + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html"]], indirect=True + ) def test_stackable_design_runs_correctly(self, global_data, clean_test_temp_files): - """ - specifying a profiling tool runs respective pre and post scripts """ - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy --additional-context \"{ 'tools': [{ 'name': 'test_tools_A' }, { 'name': 'test_tools_B' } ] }\" ", canFail=False) + specifying a profiling tool runs respective pre and post scripts + """ + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy --additional-context \"{ 'tools': [{ 'name': 'test_tools_A' }, { 'name': 'test_tools_B' } ] }\" ", + canFail=False, + ) - match_str_array = [ '^pre_script B$', '^pre_script A$', '^cmd_B$', '^cmd_A$', '^post_script A$', '^post_script B$'] + match_str_array = [ + "^pre_script B$", + "^pre_script A$", + "^cmd_B$", + "^cmd_A$", + "^post_script A$", + "^post_script B$", + ] match_str_idx = 0 regexp = re.compile(match_str_array[match_str_idx]) - with open( os.path.join(BASE_DIR, "dummy_dummy.ubuntu." + ("amd" if not is_nvidia() else "nvidia") + ".live.log" ), 'r') as f: + with open( + os.path.join( + BASE_DIR, + "dummy_dummy.ubuntu." + + ("amd" if not is_nvidia() else "nvidia") + + ".live.log", + ), + "r", + ) as f: while True: line = f.readline() if not line: break match = regexp.search(line) if match: - print("MATCH = ", line ) + print("MATCH = ", line) match_str_idx = match_str_idx + 1 if match_str_idx == len(match_str_array): break regexp = re.compile(match_str_array[match_str_idx]) if match_str_idx != len(match_str_array): print("Matched up to ", match_str_idx) - pytest.fail("all strings were not matched in the stacked test using toolA and toolB.") - + pytest.fail( + "all strings were not matched in the stacked test using toolA and toolB." 
+ ) @pytest.mark.skipif(is_nvidia(), reason="test does not run on NVIDIA") - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html', 'rocprof_output']], indirect=True) - def test_can_change_default_behavior_of_profiling_tool_with_additionalContext(self, global_data, clean_test_temp_files): + @pytest.mark.parametrize( + "clean_test_temp_files", + [["perf.csv", "perf.html", "rocprof_output"]], + indirect=True, + ) + def test_can_change_default_behavior_of_profiling_tool_with_additionalContext( + self, global_data, clean_test_temp_files + ): """ default behavior of a profiling tool can be changed from additional-context """ # canFail is set to True because rocProf is failing; this test will test if the correct output files are generated - global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_prof --additional-context \"{ 'tools': [{ 'name': 'rocprof', 'cmd': 'rocprof --hsa-trace' }] }\" ", canFail=True) - - if not os.path.exists( os.path.join(BASE_DIR, "rocprof_output", "results.hsa_stats.csv") ): - pytest.fail("rocprof_output/results.hsa_stats.csv not generated with rocprof --hsa-trace profiling run.") - + global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_prof --additional-context \"{ 'tools': [{ 'name': 'rocprof', 'cmd': 'rocprof --hsa-trace' }] }\" ", + canFail=True, + ) + if not os.path.exists( + os.path.join(BASE_DIR, "rocprof_output", "results.hsa_stats.csv") + ): + pytest.fail( + "rocprof_output/results.hsa_stats.csv not generated with rocprof --hsa-trace profiling run." + ) diff --git a/tests/test_runner_errors.py b/tests/test_runner_errors.py new file mode 100644 index 00000000..1a60b4a1 --- /dev/null +++ b/tests/test_runner_errors.py @@ -0,0 +1,370 @@ +#!/usr/bin/env python3 +""" +Unit tests for MADEngine runner error standardization. + +Tests the unified error handling across all distributed runners without +requiring optional dependencies. 
+""" + +import pytest +from unittest.mock import Mock, patch, MagicMock + +# Add src to path for imports +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) + +from madengine.core.errors import ( + ErrorCategory, + ConnectionError as MADConnectionError, + RunnerError, + create_error_context +) + + +class TestRunnerErrorConcepts: + """Test runner error concepts without requiring optional dependencies.""" + + def test_runner_error_base_class(self): + """Test that RunnerError base class works correctly.""" + context = create_error_context( + operation="runner_test", + component="TestRunner", + node_id="test-node" + ) + + error = RunnerError("Test runner error", context=context) + + # Test inheritance + assert isinstance(error, RunnerError) + assert error.category == ErrorCategory.RUNNER + assert error.recoverable is True + + # Test context + assert error.context.operation == "runner_test" + assert error.context.component == "TestRunner" + assert error.context.node_id == "test-node" + + def test_connection_error_for_ssh_like_scenarios(self): + """Test connection error that SSH runner would use.""" + context = create_error_context( + operation="ssh_connection", + component="SSHRunner", + node_id="remote-host", + additional_info={"error_type": "timeout"} + ) + + error = MADConnectionError( + "SSH timeout error on remote-host: Connection timed out", + context=context + ) + + # Test structure + assert isinstance(error, MADConnectionError) + assert error.category == ErrorCategory.CONNECTION + assert error.recoverable is True + assert error.context.node_id == "remote-host" + assert error.context.additional_info["error_type"] == "timeout" + + def test_runner_error_for_ansible_like_scenarios(self): + """Test runner error that Ansible runner would use.""" + context = create_error_context( + operation="ansible_execution", + component="AnsibleRunner", + file_path="/path/to/playbook.yml" + ) + + error = RunnerError( + "Ansible execution error in playbook.yml: Playbook failed", + context=context, + suggestions=["Check playbook syntax", "Verify inventory file"] + ) + + # Test structure + assert isinstance(error, RunnerError) + assert error.category == ErrorCategory.RUNNER + assert error.recoverable is True + assert error.context.file_path == "/path/to/playbook.yml" + assert len(error.suggestions) == 2 + + def test_runner_error_for_k8s_like_scenarios(self): + """Test runner error that Kubernetes runner would use.""" + context = create_error_context( + operation="kubernetes_execution", + component="KubernetesRunner", + additional_info={ + "resource_type": "Pod", + "resource_name": "madengine-job-001" + } + ) + + error = RunnerError( + "Kubernetes error in Pod/madengine-job-001: Pod creation failed", + context=context + ) + + # Test structure + assert isinstance(error, RunnerError) + assert error.category == ErrorCategory.RUNNER + assert error.recoverable is True + assert error.context.additional_info["resource_type"] == "Pod" + assert error.context.additional_info["resource_name"] == "madengine-job-001" + + +class TestRunnerErrorHandling: + """Test unified error handling for runner scenarios.""" + + def test_all_runner_scenarios_use_unified_system(self): + """Test that all runner scenarios can use the unified error system.""" + from madengine.core.errors import ErrorHandler + from rich.console import Console + + mock_console = Mock(spec=Console) + handler = ErrorHandler(console=mock_console) + + # Create different runner-like errors + ssh_error = 
MADConnectionError( + "SSH connection failed", + context=create_error_context( + operation="ssh_connection", + component="SSHRunner", + node_id="host1" + ) + ) + + ansible_error = RunnerError( + "Ansible playbook failed", + context=create_error_context( + operation="ansible_execution", + component="AnsibleRunner", + file_path="/playbook.yml" + ) + ) + + k8s_error = RunnerError( + "Kubernetes pod failed", + context=create_error_context( + operation="kubernetes_execution", + component="KubernetesRunner" + ) + ) + + errors = [ssh_error, ansible_error, k8s_error] + + # All should be handleable by unified handler + for error in errors: + mock_console.reset_mock() + handler.handle_error(error) + + # Verify error was handled + mock_console.print.assert_called_once() + + # Verify Rich panel was created + call_args = mock_console.print.call_args[0] + panel = call_args[0] + assert hasattr(panel, 'title') + + def test_runner_error_context_consistency(self): + """Test that all runner errors have consistent context structure.""" + runner_scenarios = [ + ("ssh_connection", "SSHRunner", "host1"), + ("ansible_execution", "AnsibleRunner", "host2"), + ("kubernetes_execution", "KubernetesRunner", "cluster1") + ] + + for operation, component, node_id in runner_scenarios: + context = create_error_context( + operation=operation, + component=component, + node_id=node_id + ) + + if "connection" in operation: + error = MADConnectionError("Connection failed", context=context) + else: + error = RunnerError("Execution failed", context=context) + + # All should have consistent context structure + assert error.context.operation == operation + assert error.context.component == component + assert error.context.node_id == node_id + assert error.recoverable is True + + def test_runner_error_suggestions_work(self): + """Test that runner errors can include helpful suggestions.""" + suggestions = [ + "Check network connectivity", + "Verify authentication credentials", + "Try running with --verbose flag" + ] + + error = RunnerError( + "Distributed execution failed", + context=create_error_context( + operation="distributed_execution", + component="GenericRunner" + ), + suggestions=suggestions + ) + + assert error.suggestions == suggestions + + # Test that suggestions are displayed + from madengine.core.errors import ErrorHandler + mock_console = Mock() + handler = ErrorHandler(console=mock_console) + handler.handle_error(error) + + # Should have called print to display error with suggestions + mock_console.print.assert_called_once() + + +class TestActualRunnerIntegration: + """Test integration with actual runner modules where possible.""" + + def test_ssh_runner_error_class_if_available(self): + """Test SSH runner error class if the module can be imported.""" + try: + # Try to import without optional dependencies + with patch('paramiko.SSHClient'), patch('scp.SCPClient'): + from madengine.runners.ssh_runner import SSHConnectionError + + error = SSHConnectionError("test-host", "connection", "failed") + + # Should inherit from unified error system + assert isinstance(error, MADConnectionError) + assert error.hostname == "test-host" + assert error.error_type == "connection" + + except ImportError: + # Expected when dependencies aren't installed + pytest.skip("SSH runner dependencies not available") + + def test_ansible_runner_error_class_if_available(self): + """Test Ansible runner error class if the module can be imported.""" + try: + # Try to import without optional dependencies + with patch('ansible_runner.run'): + from 
madengine.runners.ansible_runner import AnsibleExecutionError + + error = AnsibleExecutionError("failed", "/playbook.yml") + + # Should inherit from unified error system + assert isinstance(error, RunnerError) + assert error.playbook_path == "/playbook.yml" + + except ImportError: + # Expected when dependencies aren't installed + pytest.skip("Ansible runner dependencies not available") + + def test_k8s_runner_error_class_if_available(self): + """Test Kubernetes runner error class if the module can be imported.""" + try: + # Try to import without optional dependencies + with patch('kubernetes.client'), patch('kubernetes.config'): + from madengine.runners.k8s_runner import KubernetesExecutionError + + error = KubernetesExecutionError("failed", "Pod", "test-pod") + + # Should inherit from unified error system + assert isinstance(error, RunnerError) + assert error.resource_type == "Pod" + assert error.resource_name == "test-pod" + + except ImportError: + # Expected when dependencies aren't installed + pytest.skip("Kubernetes runner dependencies not available") + + +class TestImportErrorHandling: + """Test that import errors are handled gracefully.""" + + def test_import_error_messages_are_informative(self): + """Test that import errors provide helpful information.""" + # Test the actual import behavior when dependencies are missing + + # SSH runner + with pytest.raises(ImportError) as exc_info: + import madengine.runners.ssh_runner + + error_msg = str(exc_info.value) + assert "SSH runner requires" in error_msg or "No module named" in error_msg + + # Ansible runner + with pytest.raises(ImportError) as exc_info: + import madengine.runners.ansible_runner + + error_msg = str(exc_info.value) + assert "Ansible runner requires" in error_msg or "No module named" in error_msg + + # Kubernetes runner + with pytest.raises(ImportError) as exc_info: + import madengine.runners.k8s_runner + + error_msg = str(exc_info.value) + assert "Kubernetes runner requires" in error_msg or "No module named" in error_msg + + def test_runner_factory_handles_missing_runners(self): + """Test that runner factory gracefully handles missing optional runners.""" + try: + from madengine.runners.factory import RunnerFactory + + # Should not crash even if optional runners aren't available + # This tests the import warnings but doesn't require the runners to work + assert RunnerFactory is not None + + except ImportError as e: + # If the factory itself can't be imported, that's a different issue + pytest.fail(f"Runner factory should be importable: {e}") + + +class TestErrorSystemRobustness: + """Test that the error system is robust to various scenarios.""" + + def test_error_system_works_without_optional_modules(self): + """Test that core error system works even without optional modules.""" + from madengine.core.errors import ( + ErrorHandler, RunnerError, ConnectionError, ValidationError + ) + + # Should work without any runner modules + mock_console = Mock() + handler = ErrorHandler(console=mock_console) + + error = ValidationError("Test error") + handler.handle_error(error) + + mock_console.print.assert_called_once() + + def test_error_context_serialization_robustness(self): + """Test that error context serialization handles various data types.""" + import json + + context = create_error_context( + operation="robust_test", + component="TestComponent", + additional_info={ + "string": "value", + "number": 42, + "boolean": True, + "none": None, + "list": [1, 2, 3], + "dict": {"nested": "value"} + } + ) + + error = RunnerError("Test 
error", context=context) + + # Should be serializable + context_dict = error.context.__dict__ + json_str = json.dumps(context_dict, default=str) + + # Should contain all the data + assert "robust_test" in json_str + assert "TestComponent" in json_str + assert "42" in json_str + assert "nested" in json_str + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) \ No newline at end of file diff --git a/tests/test_runners_base.py b/tests/test_runners_base.py new file mode 100644 index 00000000..c7c70b8f --- /dev/null +++ b/tests/test_runners_base.py @@ -0,0 +1,394 @@ +#!/usr/bin/env python3 +""" +Tests for the distributed runner base classes and factory. +""" + +import json +import os +import tempfile +import unittest +from unittest.mock import patch, MagicMock + +import pytest + +from madengine.runners.base import ( + NodeConfig, + WorkloadSpec, + ExecutionResult, + DistributedResult, + BaseDistributedRunner, +) +from madengine.runners.factory import RunnerFactory + + +class TestNodeConfig: + """Test NodeConfig dataclass.""" + + def test_valid_node_config(self): + """Test valid node configuration.""" + node = NodeConfig( + hostname="test-node", + address="192.168.1.100", + port=22, + username="root", + gpu_count=4, + gpu_vendor="AMD", + ) + + assert node.hostname == "test-node" + assert node.address == "192.168.1.100" + assert node.port == 22 + assert node.username == "root" + assert node.gpu_count == 4 + assert node.gpu_vendor == "AMD" + + def test_invalid_gpu_vendor(self): + """Test invalid GPU vendor raises ValueError.""" + with pytest.raises(ValueError, match="Invalid gpu_vendor"): + NodeConfig( + hostname="test-node", address="192.168.1.100", gpu_vendor="INVALID" + ) + + def test_missing_required_fields(self): + """Test missing required fields raises ValueError.""" + with pytest.raises(ValueError, match="hostname and address are required"): + NodeConfig(hostname="", address="192.168.1.100") + + +class TestWorkloadSpec: + """Test WorkloadSpec dataclass.""" + + def test_valid_workload_spec(self): + """Test valid workload specification.""" + # Create temporary manifest file + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump({"built_images": {}}, f) + manifest_file = f.name + + try: + workload = WorkloadSpec( + model_tags=["dummy"], + manifest_file=manifest_file, + timeout=3600, + registry="localhost:5000", + ) + + assert workload.model_tags == ["dummy"] + assert workload.manifest_file == manifest_file + assert workload.timeout == 3600 + assert workload.registry == "localhost:5000" + finally: + os.unlink(manifest_file) + + def test_empty_model_tags(self): + """Test empty model tags raises ValueError.""" + with pytest.raises(ValueError, match="model_tags cannot be empty"): + WorkloadSpec(model_tags=[], manifest_file="nonexistent.json") + + def test_missing_manifest_file(self): + """Test missing manifest file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError, match="Manifest file not found"): + WorkloadSpec(model_tags=["dummy"], manifest_file="nonexistent.json") + + +class TestExecutionResult: + """Test ExecutionResult dataclass.""" + + def test_execution_result_to_dict(self): + """Test ExecutionResult to_dict method.""" + result = ExecutionResult( + node_id="test-node", + model_tag="dummy", + status="SUCCESS", + duration=123.45, + performance_metrics={"fps": 30.5}, + error_message=None, + ) + + result_dict = result.to_dict() + + assert result_dict["node_id"] == "test-node" + assert result_dict["model_tag"] == "dummy" + 
assert result_dict["status"] == "SUCCESS" + assert result_dict["duration"] == 123.45 + assert result_dict["performance_metrics"] == {"fps": 30.5} + assert result_dict["error_message"] is None + + +class TestDistributedResult: + """Test DistributedResult dataclass.""" + + def test_add_successful_result(self): + """Test adding successful result.""" + dist_result = DistributedResult( + total_nodes=2, + successful_executions=0, + failed_executions=0, + total_duration=0.0, + ) + + result = ExecutionResult( + node_id="test-node", model_tag="dummy", status="SUCCESS", duration=100.0 + ) + + dist_result.add_result(result) + + assert dist_result.successful_executions == 1 + assert dist_result.failed_executions == 0 + assert len(dist_result.node_results) == 1 + + def test_add_failed_result(self): + """Test adding failed result.""" + dist_result = DistributedResult( + total_nodes=2, + successful_executions=0, + failed_executions=0, + total_duration=0.0, + ) + + result = ExecutionResult( + node_id="test-node", + model_tag="dummy", + status="FAILURE", + duration=100.0, + error_message="Test error", + ) + + dist_result.add_result(result) + + assert dist_result.successful_executions == 0 + assert dist_result.failed_executions == 1 + assert len(dist_result.node_results) == 1 + + +class MockDistributedRunner(BaseDistributedRunner): + """Mock implementation of BaseDistributedRunner for testing.""" + + def setup_infrastructure(self, workload): + return True + + def execute_workload(self, workload): + result = DistributedResult( + total_nodes=len(self.nodes), + successful_executions=0, + failed_executions=0, + total_duration=0.0, + ) + + for node in self.nodes: + for model_tag in workload.model_tags: + result.add_result( + ExecutionResult( + node_id=node.hostname, + model_tag=model_tag, + status="SUCCESS", + duration=100.0, + ) + ) + + return result + + def cleanup_infrastructure(self, workload): + return True + + +class TestBaseDistributedRunner: + """Test BaseDistributedRunner abstract base class.""" + + def test_load_json_inventory(self): + """Test loading JSON inventory file.""" + inventory_data = { + "nodes": [ + {"hostname": "node1", "address": "192.168.1.101", "gpu_vendor": "AMD"}, + { + "hostname": "node2", + "address": "192.168.1.102", + "gpu_vendor": "NVIDIA", + }, + ] + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(inventory_data, f) + inventory_file = f.name + + try: + runner = MockDistributedRunner(inventory_file) + + assert len(runner.nodes) == 2 + assert runner.nodes[0].hostname == "node1" + assert runner.nodes[0].gpu_vendor == "AMD" + assert runner.nodes[1].hostname == "node2" + assert runner.nodes[1].gpu_vendor == "NVIDIA" + finally: + os.unlink(inventory_file) + + def test_load_yaml_inventory(self): + """Test loading YAML inventory file.""" + inventory_content = """ + gpu_nodes: + - hostname: node1 + address: 192.168.1.101 + gpu_vendor: AMD + - hostname: node2 + address: 192.168.1.102 + gpu_vendor: NVIDIA + """ + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) as f: + f.write(inventory_content) + inventory_file = f.name + + try: + runner = MockDistributedRunner(inventory_file) + + assert len(runner.nodes) == 2 + assert runner.nodes[0].hostname == "node1" + assert runner.nodes[0].gpu_vendor == "AMD" + assert runner.nodes[1].hostname == "node2" + assert runner.nodes[1].gpu_vendor == "NVIDIA" + finally: + os.unlink(inventory_file) + + def test_filter_nodes(self): + """Test node filtering functionality.""" + 
inventory_data = { + "nodes": [ + { + "hostname": "amd-node", + "address": "192.168.1.101", + "gpu_vendor": "AMD", + "labels": {"datacenter": "dc1"}, + }, + { + "hostname": "nvidia-node", + "address": "192.168.1.102", + "gpu_vendor": "NVIDIA", + "labels": {"datacenter": "dc2"}, + }, + ] + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(inventory_data, f) + inventory_file = f.name + + try: + runner = MockDistributedRunner(inventory_file) + + # Test GPU vendor filtering + amd_nodes = runner.filter_nodes({"gpu_vendor": "AMD"}) + assert len(amd_nodes) == 1 + assert amd_nodes[0].hostname == "amd-node" + + # Test label filtering + dc1_nodes = runner.filter_nodes({"datacenter": "dc1"}) + assert len(dc1_nodes) == 1 + assert dc1_nodes[0].hostname == "amd-node" + finally: + os.unlink(inventory_file) + + def test_validate_workload(self): + """Test workload validation.""" + inventory_data = { + "nodes": [ + {"hostname": "node1", "address": "192.168.1.101", "gpu_vendor": "AMD"} + ] + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(inventory_data, f) + inventory_file = f.name + + # Create manifest file + manifest_data = {"built_images": {"dummy": {}}} + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(manifest_data, f) + manifest_file = f.name + + try: + runner = MockDistributedRunner(inventory_file) + + workload = WorkloadSpec(model_tags=["dummy"], manifest_file=manifest_file) + + assert runner.validate_workload(workload) == True + finally: + os.unlink(inventory_file) + os.unlink(manifest_file) + + def test_run_workflow(self): + """Test complete run workflow.""" + inventory_data = { + "nodes": [ + {"hostname": "node1", "address": "192.168.1.101", "gpu_vendor": "AMD"} + ] + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(inventory_data, f) + inventory_file = f.name + + # Create manifest file + manifest_data = {"built_images": {"dummy": {}}} + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(manifest_data, f) + manifest_file = f.name + + try: + runner = MockDistributedRunner(inventory_file) + + workload = WorkloadSpec(model_tags=["dummy"], manifest_file=manifest_file) + + result = runner.run(workload) + + assert result.total_nodes == 1 + assert result.successful_executions == 1 + assert result.failed_executions == 0 + assert len(result.node_results) == 1 + assert result.node_results[0].status == "SUCCESS" + finally: + os.unlink(inventory_file) + os.unlink(manifest_file) + + +class TestRunnerFactory: + """Test RunnerFactory class.""" + + def test_register_and_create_runner(self): + """Test registering and creating a runner.""" + # Register mock runner + RunnerFactory.register_runner("mock", MockDistributedRunner) + + # Create temporary inventory + inventory_data = { + "nodes": [ + {"hostname": "node1", "address": "192.168.1.101", "gpu_vendor": "AMD"} + ] + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(inventory_data, f) + inventory_file = f.name + + try: + # Create runner instance + runner = RunnerFactory.create_runner("mock", inventory_path=inventory_file) + + assert isinstance(runner, MockDistributedRunner) + assert len(runner.nodes) == 1 + assert runner.nodes[0].hostname == "node1" + finally: + os.unlink(inventory_file) + + def test_unknown_runner_type(self): + """Test creating unknown runner type raises ValueError.""" + with 
pytest.raises(ValueError, match="Unknown runner type"): + RunnerFactory.create_runner("unknown", inventory_path="test.json") + + def test_get_available_runners(self): + """Test getting available runner types.""" + available_runners = RunnerFactory.get_available_runners() + + # Should include default runners if dependencies are available + assert isinstance(available_runners, list) + assert len(available_runners) > 0 diff --git a/tests/test_tags.py b/tests/test_tags.py index 39eecaf3..df37a2fc 100644 --- a/tests/test_tags.py +++ b/tests/test_tags.py @@ -1,6 +1,7 @@ """ Copyright (c) Advanced Micro Devices, Inc. All rights reserved. """ + import pytest import os import sys @@ -10,14 +11,27 @@ from .fixtures.utils import global_data from .fixtures.utils import clean_test_temp_files + class TestTagsFunctionality: - - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html']], indirect=True) - def test_can_select_model_subset_with_commandline_tag_argument(self, global_data, clean_test_temp_files): + + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html"]], indirect=True + ) + def test_can_select_model_subset_with_commandline_tag_argument( + self, global_data, clean_test_temp_files + ): """ can select subset of models with tag with command-line argument """ - output = global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_group_1") + output = global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_group_1" + ) if "Running model dummy" not in output: pytest.fail("dummy tag not selected with commandline --tags argument") @@ -25,12 +39,24 @@ def test_can_select_model_subset_with_commandline_tag_argument(self, global_data if "Running model dummy2" not in output: pytest.fail("dummy2 tag not selected with commandline --tags argument") - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html']], indirect=True) - def test_all_models_matching_any_tag_selected_with_multiple_tags(self, global_data, clean_test_temp_files): + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html"]], indirect=True + ) + def test_all_models_matching_any_tag_selected_with_multiple_tags( + self, global_data, clean_test_temp_files + ): """ if multiple tags are specified, all models that match any tag will be selected """ - output = global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy_group_1 dummy_group_2") + output = global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy_group_1 dummy_group_2" + ) if "Running model dummy" not in output: pytest.fail("dummy tag not selected with commandline --tags argument") @@ -41,13 +67,24 @@ def test_all_models_matching_any_tag_selected_with_multiple_tags(self, global_da if "Running model dummy3" not in output: pytest.fail("dummy3 tag not selected with commandline --tags argument") - @pytest.mark.parametrize('clean_test_temp_files', [['perf.csv', 'perf.html']], indirect=True) - def test_model_names_are_automatically_tags(self, global_data, clean_test_temp_files): + @pytest.mark.parametrize( + "clean_test_temp_files", [["perf.csv", "perf.html"]], indirect=True + ) + def test_model_names_are_automatically_tags( + self, global_data, clean_test_temp_files + ): """ - Each model 
name is automatically a tag + Each model name is automatically a tag """ - output = global_data['console'].sh("cd " + BASE_DIR + "; " + "MODEL_DIR=" + MODEL_DIR + " " + "python3 src/madengine/mad.py run --tags dummy") + output = global_data["console"].sh( + "cd " + + BASE_DIR + + "; " + + "MODEL_DIR=" + + MODEL_DIR + + " " + + "python3 src/madengine/mad.py run --tags dummy" + ) if "Running model dummy" not in output: pytest.fail("dummy tag not selected with commandline --tags argument") - diff --git a/tests/test_templates.py b/tests/test_templates.py new file mode 100644 index 00000000..d6c57f9b --- /dev/null +++ b/tests/test_templates.py @@ -0,0 +1,359 @@ +"""Tests for the template generator module. + +This module tests the Jinja2-based template generation functionality +for Ansible playbooks and Kubernetes manifests. + +Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +""" + +import os +import json +import tempfile +import shutil +import unittest +from unittest.mock import patch, mock_open, MagicMock +import pytest + +from madengine.runners.template_generator import ( + TemplateGenerator, + create_ansible_playbook, + create_kubernetes_manifests, +) + + +class TestTemplateGenerator(unittest.TestCase): + """Test the template generator functionality.""" + + def setUp(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.template_dir = os.path.join(self.temp_dir, "templates") + self.values_dir = os.path.join(self.temp_dir, "values") + + # Create template directories + os.makedirs(os.path.join(self.template_dir, "ansible")) + os.makedirs(os.path.join(self.template_dir, "k8s")) + os.makedirs(self.values_dir) + + # Create sample templates + self.create_sample_templates() + self.create_sample_values() + + # Create sample manifest + self.manifest_data = { + "built_images": { + "dummy_model": { + "docker_image": "dummy:latest", + "registry_image": "registry.example.com/dummy:latest", + "build_time": 120.5, + } + }, + "built_models": { + "dummy_model": { + "name": "dummy", + "dockerfile": "docker/dummy.Dockerfile", + "scripts": "scripts/dummy/run.sh", + } + }, + "context": { + "gpu_vendor": "nvidia", + "docker_build_arg": {"MAD_SYSTEM_GPU_ARCHITECTURE": "gfx908"}, + "docker_env_vars": {"CUDA_VISIBLE_DEVICES": "0"}, + "docker_mounts": {"/tmp": "/tmp"}, + "docker_gpus": "all", + }, + "registry": "registry.example.com", + "build_timestamp": "2023-01-01T00:00:00Z", + } + + self.manifest_file = os.path.join(self.temp_dir, "build_manifest.json") + with open(self.manifest_file, "w") as f: + json.dump(self.manifest_data, f) + + def tearDown(self): + """Clean up test fixtures.""" + shutil.rmtree(self.temp_dir) + + def create_sample_templates(self): + """Create sample template files.""" + # Ansible playbook template + ansible_template = """--- +- name: MADEngine Test Playbook + hosts: {{ ansible.target_hosts | default('test_nodes') }} + vars: + registry: "{{ registry | default('') }}" + gpu_vendor: "{{ gpu_vendor | default('') }}" + tasks: + - name: Test task + debug: + msg: "Environment: {{ environment | default('test') }}" +""" + + with open( + os.path.join(self.template_dir, "ansible", "playbook.yml.j2"), "w" + ) as f: + f.write(ansible_template) + + # K8s namespace template + k8s_namespace = """apiVersion: v1 +kind: Namespace +metadata: + name: {{ k8s.namespace | default('madengine-test') }} + labels: + environment: {{ environment | default('test') }} +""" + + with open( + os.path.join(self.template_dir, "k8s", "namespace.yaml.j2"), "w" + ) as f: + 
f.write(k8s_namespace) + + def create_sample_values(self): + """Create sample values files.""" + default_values = { + "environment": "test", + "ansible": {"target_hosts": "test_nodes", "become": False}, + "k8s": {"namespace": "madengine-test"}, + "execution": {"timeout": 1800, "keep_alive": False}, + } + + with open(os.path.join(self.values_dir, "default.yaml"), "w") as f: + import yaml + + yaml.dump(default_values, f) + + dev_values = { + "environment": "dev", + "ansible": {"target_hosts": "dev_nodes", "become": True}, + "k8s": {"namespace": "madengine-dev"}, + "execution": {"timeout": 3600, "keep_alive": True}, + } + + with open(os.path.join(self.values_dir, "dev.yaml"), "w") as f: + yaml.dump(dev_values, f) + + def test_template_generator_initialization(self): + """Test template generator initialization.""" + generator = TemplateGenerator(self.template_dir, self.values_dir) + + assert str(generator.template_dir) == self.template_dir + assert str(generator.values_dir) == self.values_dir + assert generator.env is not None + + def test_load_values_default(self): + """Test loading default values.""" + generator = TemplateGenerator(self.template_dir, self.values_dir) + values = generator.load_values("default") + + assert values["environment"] == "test" + assert values["ansible"]["target_hosts"] == "test_nodes" + assert values["k8s"]["namespace"] == "madengine-test" + + def test_load_values_dev(self): + """Test loading dev values.""" + generator = TemplateGenerator(self.template_dir, self.values_dir) + values = generator.load_values("dev") + + assert values["environment"] == "dev" + assert values["ansible"]["target_hosts"] == "dev_nodes" + assert values["k8s"]["namespace"] == "madengine-dev" + + def test_load_values_nonexistent(self): + """Test loading non-existent values file.""" + generator = TemplateGenerator(self.template_dir, self.values_dir) + + with pytest.raises(FileNotFoundError): + generator.load_values("nonexistent") + + def test_merge_values(self): + """Test merging values with manifest data.""" + generator = TemplateGenerator(self.template_dir, self.values_dir) + base_values = generator.load_values("default") + + merged = generator.merge_values(base_values, self.manifest_data) + + assert merged["environment"] == "test" + assert merged["registry"] == "registry.example.com" + assert merged["gpu_vendor"] == "nvidia" + assert merged["images"]["dummy_model"]["docker_image"] == "dummy:latest" + assert "generation" in merged + assert "timestamp" in merged["generation"] + + def test_generate_ansible_playbook(self): + """Test generating Ansible playbook.""" + generator = TemplateGenerator(self.template_dir, self.values_dir) + + output_file = os.path.join(self.temp_dir, "test_playbook.yml") + content = generator.generate_ansible_playbook( + self.manifest_file, "default", output_file + ) + + assert os.path.exists(output_file) + assert "MADEngine Test Playbook" in content + assert "test_nodes" in content + assert "registry.example.com" in content + assert "nvidia" in content + + def test_generate_kubernetes_manifests(self): + """Test generating Kubernetes manifests.""" + generator = TemplateGenerator(self.template_dir, self.values_dir) + + output_dir = os.path.join(self.temp_dir, "k8s_output") + generated_files = generator.generate_kubernetes_manifests( + self.manifest_file, "default", output_dir + ) + + assert os.path.exists(output_dir) + assert len(generated_files) > 0 + + # Check namespace file + namespace_file = os.path.join(output_dir, "namespace.yaml") + if 
os.path.exists(namespace_file): + with open(namespace_file, "r") as f: + content = f.read() + assert "madengine-test" in content + assert "environment: test" in content + + def test_list_templates(self): + """Test listing available templates.""" + generator = TemplateGenerator(self.template_dir, self.values_dir) + templates = generator.list_templates() + + assert "ansible" in templates + assert "k8s" in templates + assert "playbook.yml.j2" in templates["ansible"] + assert "namespace.yaml.j2" in templates["k8s"] + + def test_validate_template_valid(self): + """Test validating a valid template.""" + generator = TemplateGenerator(self.template_dir, self.values_dir) + + # Create a simple valid template + template_content = "Hello {{ name | default('World') }}!" + template_file = os.path.join(self.template_dir, "test_template.j2") + with open(template_file, "w") as f: + f.write(template_content) + + is_valid = generator.validate_template("test_template.j2") + assert is_valid is True + + def test_validate_template_invalid(self): + """Test validating an invalid template.""" + generator = TemplateGenerator(self.template_dir, self.values_dir) + + # Create an invalid template + template_content = "Hello {{ name | invalid_filter }}!" + template_file = os.path.join(self.template_dir, "invalid_template.j2") + with open(template_file, "w") as f: + f.write(template_content) + + is_valid = generator.validate_template("invalid_template.j2") + assert is_valid is False + + def test_custom_filters(self): + """Test custom Jinja2 filters.""" + generator = TemplateGenerator(self.template_dir, self.values_dir) + + # Test to_yaml filter + template = generator.env.from_string("{{ data | to_yaml }}") + result = template.render(data={"key": "value"}) + assert "key: value" in result + + # Test to_json filter (check for JSON structure, allowing for HTML escaping) + template = generator.env.from_string("{{ data | to_json }}") + result = template.render(data={"key": "value"}) + assert "key" in result and "value" in result + + # Test basename filter + template = generator.env.from_string("{{ path | basename }}") + result = template.render(path="/path/to/file.txt") + assert result == "file.txt" + + def test_generate_with_dev_environment(self): + """Test generation with dev environment.""" + generator = TemplateGenerator(self.template_dir, self.values_dir) + + output_file = os.path.join(self.temp_dir, "dev_playbook.yml") + content = generator.generate_ansible_playbook( + self.manifest_file, "dev", output_file + ) + + assert "dev_nodes" in content + assert "registry.example.com" in content + + +class TestBackwardCompatibility(unittest.TestCase): + """Test backward compatibility functions.""" + + def setUp(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.manifest_file = os.path.join(self.temp_dir, "build_manifest.json") + + # Create sample manifest + manifest_data = { + "built_images": {"dummy": {"docker_image": "dummy:latest"}}, + "context": {"gpu_vendor": "nvidia"}, + "registry": "localhost:5000", + } + + with open(self.manifest_file, "w") as f: + json.dump(manifest_data, f) + + def tearDown(self): + """Clean up test fixtures.""" + shutil.rmtree(self.temp_dir) + + @patch("madengine.runners.template_generator.TemplateGenerator") + def test_create_ansible_playbook_backward_compatibility(self, mock_generator_class): + """Test backward compatibility for create_ansible_playbook.""" + mock_generator = MagicMock() + mock_generator_class.return_value = mock_generator + + # Change to temp directory + 
original_cwd = os.getcwd() + os.chdir(self.temp_dir) + + try: + create_ansible_playbook( + manifest_file=self.manifest_file, + environment="test", + playbook_file="test.yml", + ) + + mock_generator_class.assert_called_once() + mock_generator.generate_ansible_playbook.assert_called_once_with( + self.manifest_file, "test", "test.yml" + ) + finally: + os.chdir(original_cwd) + + @patch("madengine.runners.template_generator.TemplateGenerator") + def test_create_kubernetes_manifests_backward_compatibility( + self, mock_generator_class + ): + """Test backward compatibility for create_kubernetes_manifests.""" + mock_generator = MagicMock() + mock_generator_class.return_value = mock_generator + + # Change to temp directory + original_cwd = os.getcwd() + os.chdir(self.temp_dir) + + try: + create_kubernetes_manifests( + manifest_file=self.manifest_file, + environment="test", + output_dir="test-k8s", + ) + + mock_generator_class.assert_called_once() + mock_generator.generate_kubernetes_manifests.assert_called_once_with( + self.manifest_file, "test", "test-k8s" + ) + finally: + os.chdir(original_cwd) + + +if __name__ == "__main__": + unittest.main()
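
For readers skimming this patch, the sketch below shows how the `TemplateGenerator` API exercised by these tests might be driven directly, outside of pytest. It only uses calls that appear in the tests above (`generate_ansible_playbook`, `generate_kubernetes_manifests`); the directory paths, file names, and the `"default"` environment are illustrative placeholders mirroring the test fixtures, not values taken from the repository.

```python
# Illustrative usage sketch; paths and the "default" environment are assumptions.
from madengine.runners.template_generator import TemplateGenerator

# Template/values layout mirrors the test fixtures: templates/ansible, templates/k8s,
# and per-environment YAML files under values/.
generator = TemplateGenerator("deploy/templates", "deploy/values")  # assumed paths

# Render an Ansible playbook from a build manifest with built_images/context/registry keys.
playbook_content = generator.generate_ansible_playbook(
    "build_manifest.json",     # manifest produced by the build step
    "default",                 # values file: deploy/values/default.yaml
    "generated_playbook.yml",  # output path
)

# Render Kubernetes manifests for the same manifest and environment.
generated_files = generator.generate_kubernetes_manifests(
    "build_manifest.json", "default", "generated-k8s"
)
print(generated_files)
```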