diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 19d61d4..fb45114 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -23,7 +23,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install black isort flake8 mypy + pip install "black>=26,<27" isort flake8 mypy - name: Check formatting with black run: black --check --diff . @@ -32,7 +32,7 @@ jobs: run: isort --check-only --diff . - name: Lint with flake8 - run: flake8 spkmc tests --max-line-length=100 --extend-ignore=E203,W503 + run: flake8 spkmc tests --max-line-length=100 --extend-ignore=E203,W503,E741,E501,E402,F401,F841,B007,E722,B001,F811,F541,B028,E266,F821 - name: Type check with mypy run: mypy spkmc --ignore-missing-imports --no-error-summary || true @@ -44,17 +44,13 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, macos-latest, windows-latest] - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + python-version: ["3.9", "3.10", "3.11", "3.12"] exclude: # Reduce matrix for faster CI - test all versions on Ubuntu, latest on others - - os: macos-latest - python-version: "3.8" - os: macos-latest python-version: "3.9" - os: macos-latest python-version: "3.10" - - os: windows-latest - python-version: "3.8" - os: windows-latest python-version: "3.9" - os: windows-latest @@ -76,7 +72,7 @@ jobs: pip install -e ".[dev]" - name: Run tests with coverage - run: pytest --cov=spkmc --cov-report=xml --cov-report=term-missing -v + run: pytest --cov=spkmc --cov-report=xml --cov-report=term-missing -v --ignore=tests/e2e - name: Upload coverage to Codecov if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.11' @@ -111,3 +107,35 @@ jobs: - name: Check package with twine run: twine check dist/* + + e2e: + name: E2E Tests + runs-on: ubuntu-latest + needs: [lint, test] + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[dev,e2e]" + + - name: Install Playwright browsers + run: playwright install --with-deps chromium firefox + + - name: Run E2E tests + run: pytest tests/e2e/ -v --tb=short --browser chromium --browser firefox + + - name: Upload artifacts on failure + if: failure() + uses: actions/upload-artifact@v4 + with: + name: playwright-artifacts + path: test-results/ diff --git a/.gitignore b/.gitignore index 37ece35..92732d8 100644 --- a/.gitignore +++ b/.gitignore @@ -130,8 +130,21 @@ results/ # Experiments (user-created, not tracked in git) experiments/ +!tests/e2e/fixtures/experiments/ # AI-generated analysis files cross_experiment_analysis.md **/.DS_Store + +# Claude Code +CLAUDE.md +.claude/ + +# Web interface runtime +.spkmc_web/ +.streamlit/ + +# Temporary files +tmp_*.html +tmp_*.json diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 778cfe9..639765e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ repos: - id: debug-statements - repo: https://github.com/psf/black - rev: 24.3.0 + rev: 26.1.0 hooks: - id: black args: ['--line-length=100'] diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index f287640..0000000 --- a/CLAUDE.md +++ /dev/null @@ -1,228 +0,0 @@ -# CLAUDE.md - SPKMC Project Guide - -## Project Overview - -**SPKMC** (Shortest Path Kinetic Monte Carlo) is a Python implementation for simulating epidemic propagation on networks using the SIR (Susceptible-Infected-Recovered) model. The algorithm uses shortest path calculations on weighted graphs to efficiently model disease spread dynamics. - -**Version:** 1.0.0 -**Python:** 3.8+ -**License:** MIT - -## Quick Commands - -```bash -# Run tests -pytest - -# Run tests with coverage -pytest --cov=spkmc --cov-report=term-missing - -# Run a simulation -python -m spkmc.cli run -n er -d gamma --nodes 1000 --samples 50 - -# Run batch scenarios (interactive experiment menu) -python -m spkmc.cli batch - -# List saved results -python -m spkmc.cli info --list -``` - -## Project Architecture - -``` -spkmc/ -├── spkmc/ # Main package -│ ├── cli/ # Command-line interface (Click-based) -│ │ ├── commands.py # CLI commands: run, plot, info, compare, batch -│ │ ├── validators.py # Parameter validation callbacks -│ │ └── formatting.py # Rich terminal output formatting -│ ├── core/ # Core algorithm implementation -│ │ ├── simulation.py # SPKMC class - main algorithm -│ │ ├── distributions.py # Gamma & Exponential distributions -│ │ └── networks.py # NetworkFactory for graph creation -│ ├── io/ # Input/output operations -│ │ ├── export.py # Multi-format export (CSV, JSON, Excel, MD, HTML) -│ │ └── results.py # Result persistence and loading -│ ├── visualization/ # Plotting -│ │ └── plots.py # Matplotlib-based SIR curve visualization -│ └── utils/ # Utilities -│ └── numba_utils.py # Numba JIT-optimized functions -├── tests/ # Comprehensive test suite -├── docs/ # Documentation -├── examples/ # Usage examples -└── experiments/ # Structured experiments with data.json configs -``` - -## Core Concepts - -### Network Types -- `er` - Erdos-Renyi (random networks) -- `sf` - Scale-Free Networks (power-law degree distribution) -- `cg` - Complete Graph (fully connected) -- `rrn` - Random Regular Network (uniform degree) - -### Distributions -- `gamma` - Gamma distribution for recovery times (parameters: shape, scale, lambda) -- `exponential` - Exponential distribution (parameters: mu, lambda) - -### Algorithm Flow -1. Create network topology via `NetworkFactory` -2. Sample recovery times from distribution -3. Compute infection transmission times for edges -4. Run Dijkstra's algorithm on sparse weighted graph -5. Classify node states (S, I, R) at each timestep -6. Aggregate statistics across samples - -## Key Patterns - -### Factory Pattern -```python -# Network creation -network = NetworkFactory.create_network(network_type, nodes, k_avg, exponent) - -# Distribution creation -dist = create_distribution(dist_type, shape=shape, scale=scale, mu=mu, lambda_param=lambda_param) -``` - -### Abstract Base Class for Distributions -```python -class Distribution(ABC): - @abstractmethod - def get_recovery_weights(self, num_nodes: int) -> np.ndarray: ... - - @abstractmethod - def get_infection_times(self, weights: np.ndarray) -> np.ndarray: ... -``` - -### Numba JIT for Performance -Critical loops use `@njit(parallel=True)` decorators for speed: -```python -@njit(parallel=True) -def compute_infection_times_gamma(weights: np.ndarray, shape: float, scale: float) -> np.ndarray: - ... -``` - -## Coding Conventions - -### Style -- Python 3.8+ with type hints throughout -- Line length: 100 characters (Black formatter) -- snake_case for functions and variables -- PascalCase for classes -- Comprehensive docstrings on public methods - -### Import Organization (isort) -1. Standard library -2. Third-party packages -3. Local imports - -### Error Handling -- Use Click parameter callbacks for CLI validation -- Raise informative exceptions with context -- Graceful degradation for optional dependencies (openpyxl, pandas) - -### Testing -- pytest framework with fixtures -- Test files mirror source structure: `test_.py` -- Integration tests in `test_integration.py` -- Run with: `pytest -v` - -## Dependencies - -### Core -- `numpy>=1.20.0` - Numerical arrays -- `scipy>=1.7.0` - Sparse graphs, Dijkstra algorithm -- `networkx>=2.6.0` - Graph creation -- `matplotlib>=3.4.0` - Plotting -- `numba>=0.54.0` - JIT compilation - -### CLI -- `click>=8.0.0` - CLI framework -- `rich>=10.0.0` - Terminal formatting -- `tqdm>=4.60.0` - Progress bars - -### Data -- `pandas>=1.3.0` - DataFrame operations -- `openpyxl>=3.0.7` - Excel export (optional) -- `joblib>=1.0.1` - Parallel batch execution - -## CLI Commands - -### `run` - Execute simulation -```bash -spkmc run -n -d [OPTIONS] - -n, --network-type Network type (er|sf|cg|rrn) - -d, --dist-type Distribution (gamma|exponential) - -N, --nodes Number of nodes (default: 1000) - -s, --samples Samples per run (default: 50) - --shape, --scale Gamma parameters - --mu, --lambda Exponential/infection parameters - -o, --output Save results to file - --no-plot Skip visualization -``` - -### `plot` - Visualize results -```bash -spkmc plot [--save ] [--states S,I,R] -``` - -### `info` - List/inspect results -```bash -spkmc info --list # List all results -spkmc info --result-file # Show specific result -``` - -### `compare` - Compare multiple runs -```bash -spkmc compare ... [-o output] -``` - -### `batch` - Run multiple scenarios -```bash -spkmc batch # Interactive experiment menu -spkmc batch scenarios.json # File mode -``` - -## Result Storage - -Results are stored in standard locations: -``` -data/experiments// # Experiment results -data/runs/ # Individual run results -``` - -## Important Files - -| File | Purpose | -|------|---------| -| `spkmc/core/simulation.py` | Main SPKMC algorithm implementation | -| `spkmc/core/distributions.py` | Probability distribution classes | -| `spkmc/core/networks.py` | Network topology factory | -| `spkmc/cli/commands.py` | All CLI command definitions | -| `spkmc/utils/numba_utils.py` | Performance-critical JIT functions | -| `experiments/*/data.json` | Experiment configurations | - -## Development Guidelines - -1. **Never disable lint rules** - Fix the underlying issue instead -2. **Use established libraries** - Rely on NumPy, SciPy, NetworkX patterns -3. **Maintain type hints** - All public functions should be typed -4. **Write tests** - Add tests for new functionality -5. **Keep modules focused** - Separation of concerns between core/cli/io/visualization -6. **Preserve Numba compatibility** - JIT functions have restrictions on Python features - -## Performance Considerations - -- Use sparse matrices (SciPy CSR) for large networks -- Leverage Numba `@njit(parallel=True)` for loops over nodes/edges -- Avoid Python loops in hot paths - use NumPy vectorization -- NetworkX graphs convert to sparse adjacency matrices for computation - -## Validation Callbacks - -CLI validators in `spkmc/cli/validators.py`: -- `validate_percentage()` - Ensures 0 <= value <= 1 -- `validate_positive()` - Ensures value > 0 -- `validate_positive_int()` - Ensures positive integer -- `validate_network_type()` - Validates network type string -- `validate_distribution_type()` - Validates distribution string diff --git a/LICENSE b/LICENSE index a6623ca..87da325 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2025 mcaxtr +Copyright (c) 2025 Marcus Castro Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index 27dfeaf..89e2b37 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,116 @@ The simulation tracks how the proportions of S, I, and R change over time as the - **High performance**: Uses Numba JIT compilation for speed, with optional GPU acceleration - **Publication-quality plots**: Generate professional visualizations of epidemic dynamics - **Multiple export formats**: Save results as JSON, CSV, Excel, Markdown, or HTML +- **Web interface**: Interactive Streamlit dashboard for managing experiments, viewing results, and running AI analysis in the browser + +## Web Interface + +SPKMC includes a full-featured web dashboard built with Streamlit for managing experiments, viewing results, and running AI analysis directly in the browser. + +### Quick Start + +```bash +# Launch the web interface (opens browser at http://localhost:8501) +spkmc web + +# Custom port +spkmc web --port 8080 + +# Headless mode (no browser auto-open) +spkmc web --no-browser + +# Bind to all interfaces (for remote access) +spkmc web --host 0.0.0.0 +``` + +### Features + +- **Experiment Dashboard** -- Browse all experiments with summary stats (total experiments, scenarios, completion rates). Create new experiments through a guided modal with network, distribution, and simulation parameter configuration. Click any experiment card to drill into details. + +- **Scenario Management** -- View scenario cards with live status badges (Pending / Running / Completed / Failed). Add new scenarios with parameter overrides, edit existing ones, or run simulations directly from the browser. Parameters that differ from global defaults are visually highlighted. + +- **Interactive Charts** -- Plotly-based SIR curve visualization with toggleable S/I/R traces, chart type switching (Line / Area / Scatter), error bands for multi-run results, and multi-scenario comparison overlays. All charts support zoom, pan, and image export. + +- **AI Analysis** -- Generate academic-style analysis reports for experiments and individual scenarios using OpenAI models. Reports include epidemic dynamics interpretation, key findings, and actionable insights. Requires an OpenAI API key configured in Preferences. + +- **Export** -- Download scenario results in JSON, CSV, Excel, Markdown, or HTML format directly from the scenario detail modal. + +- **Preferences** -- Configure chart colors, height, and template; set default simulation parameters; manage directory paths; select AI model; and store API keys. All settings auto-save on change. + +### Architecture + +``` +spkmc/web/ +├── app.py # Streamlit entry point, sidebar navigation, CSS injection +├── config.py # Configuration management (JSON prefs + Streamlit secrets) +├── state.py # Typed session state accessors (prevents st.session_state spaghetti) +├── plotting.py # Core Plotly figure builders (SIR curves, comparisons) +├── components.py # Reusable UI components (forms, metric cards, badges) +├── styles.py # Design system (CSS, card renderers, color tokens) +├── runner.py # Subprocess-based simulation runner +├── analysis_runner.py # AI analysis subprocess runner +└── pages/ + ├── dashboard.py # Experiments list, stats cards, create modal + ├── experiment_detail.py # Single experiment view, scenario cards, detail modal + └── settings.py # Preferences page (AI, chart, simulation, storage) +``` + +### Design Decisions + +**Everything is an experiment.** Even a single simulation run is treated as an experiment with one scenario. This unified model simplifies the codebase and provides consistent storage patterns. + +**Subprocess execution.** Simulations run in background subprocesses that survive browser refresh, page navigation, and UI interactions. Progress is tracked via filesystem-based IPC (`.spkmc_web/status/*.json`), not in-memory state. + +**Filesystem-first storage.** No database required. Experiments are stored as `experiments//data.json`, results as `experiments//.json`, and status as `.spkmc_web/status/.json`. All files are portable JSON. + +**Parameter inheritance.** Global parameters are defined at the experiment level. Each scenario only specifies what differs from the defaults, keeping configurations DRY and making overrides instantly visible in the UI. + +### Configuration + +**User preferences** are stored at `~/.spkmc/web_config.json`: +- Directory paths (data, experiments) +- Default simulation parameters +- Chart styling (height, colors, template) +- Export format preference + +Override the config file location with an environment variable: + +```bash +SPKMC_WEB_CONFIG_FILE=/path/to/config.json spkmc web +``` + +**API keys** are stored in `.streamlit/secrets.toml` (managed through the Preferences page): + +```toml +OPENAI_API_KEY = "sk-your-key-here" +``` + +### Workflow + +1. Open the Dashboard and create a new experiment with global parameters +2. Add scenarios -- each can override any parameter from the global defaults +3. Run individual scenarios or all at once from the experiment detail page +4. View interactive Plotly charts with toggleable S/I/R traces +5. Compare multiple scenarios with overlaid charts +6. Generate AI analysis reports (optional, requires API key) +7. Export results in your preferred format + +### Troubleshooting + +**Simulations not starting** -- Check `.spkmc_web/status/*.json` files for error messages. Common issues include Numba compilation errors, missing parameters, or invalid network/distribution combinations. + +**Browser doesn't open** -- Use `spkmc web --no-browser` and navigate to the URL shown in terminal output. + +**Charts not displaying** -- Ensure `plotly>=5.18.0` is installed: `pip install --upgrade plotly`. + +### Extending the Web Interface + +To add a new page: + +1. Create `spkmc/web/pages/my_page.py` with a `render()` function +2. Register the page in `spkmc/web/pages/__init__.py` +3. Add sidebar navigation in `spkmc/web/app.py` +4. Add routing logic in the `main()` function ## Installation @@ -174,7 +284,7 @@ The timing of infection and recovery events follows probability distributions: ### Parameter Reference -The following parameters apply to both `spkmc run` and batch experiment scenarios. +The following parameters apply to both `spkmc run` and experiment scenarios. #### Network Parameters @@ -243,7 +353,7 @@ spkmc run -n er -d gamma -o my_results.json # Save results as CSV instead of JSON spkmc run -n er -d gamma -o my_results --export csv -# Run without displaying the plot (useful for batch processing or servers) +# Run without displaying the plot (useful for automated processing or servers) spkmc run -n er -d gamma -o results.json --no-plot ``` @@ -742,6 +852,106 @@ The `_err` fields contain standard errors and are only present when `num_runs > --- +## Project Structure + +``` +spkmc/ +├── analysis/ # AI-powered analysis +│ ├── ai_analyzer.py # OpenAI integration for experiment and scenario analysis +│ ├── metrics.py # Metric extraction from simulation results +│ └── prompts.py # LLM prompt templates +├── cli/ # Command-line interface (Click-based) +│ ├── commands.py # CLI commands: run, plot, info, compare, experiments, web +│ ├── validators.py # Parameter validation callbacks +│ └── formatting.py # Rich terminal output formatting +├── core/ # Core algorithm implementation +│ ├── simulation.py # SPKMC class - main simulation algorithm +│ ├── distributions.py # Gamma & Exponential distribution classes +│ └── networks.py # NetworkFactory for graph creation +├── io/ # Input/output operations +│ ├── export.py # Multi-format export (CSV, JSON, Excel, MD, HTML) +│ ├── data_manager.py # Result persistence, loading, and report generation +│ └── results.py # Result file discovery and metadata +├── models/ # Data models +│ ├── experiment.py # Experiment and Scenario Pydantic models +│ └── scenario.py # Scenario configuration model +├── visualization/ # Plotting (Plotly-based) +│ └── plots.py # SIR curve visualization for CLI and programmatic use +├── web/ # Streamlit web interface +│ ├── app.py # Entry point, sidebar navigation, CSS injection +│ ├── config.py # Configuration management (JSON prefs + secrets) +│ ├── state.py # Typed session state accessors +│ ├── plotting.py # Core Plotly figure builders +│ ├── components.py # Reusable UI components (forms, badges, cards) +│ ├── styles.py # Design system (CSS, card renderers, color tokens) +│ ├── runner.py # Subprocess-based simulation runner +│ ├── analysis_runner.py # AI analysis subprocess runner +│ └── pages/ +│ ├── dashboard.py # Experiments list, stats cards, create modal +│ ├── experiment_detail.py # Experiment view, scenario cards, detail modal +│ └── settings.py # Preferences page +└── utils/ # Utilities + └── numba_utils.py # Numba JIT-optimized functions + +tests/ +├── test_web/ # Unit tests for web modules +│ ├── test_state.py # Session state management tests +│ ├── test_config.py # Configuration management tests +│ ├── test_runner.py # Simulation runner tests +│ ├── test_analysis_runner.py # Analysis runner tests +│ ├── test_plotting.py # Plotly figure builder tests +│ └── test_experiment_detail.py # Experiment detail logic tests +└── e2e/ # Playwright end-to-end tests + ├── conftest.py # Server lifecycle, page helpers, fixture seeding + ├── fixtures/ # Pre-seeded experiment data + ├── test_navigation.py # Sidebar nav, page routing, title + ├── test_dashboard.py # Stats cards, create modal, experiment cards + ├── test_experiment_detail.py # Params, scenario cards, modal, charts + └── test_settings.py # Preference sections, inputs, reset +``` + +--- + +## Development + +### Running Tests + +```bash +# Run all unit tests with coverage +pytest + +# Run web module tests only +pytest tests/test_web/ -v + +# Run E2E tests (requires playwright browsers installed) +pip install -e ".[e2e]" +playwright install chromium +pytest tests/e2e/ -v --browser chromium +``` + +### Dependencies + +**Core:** +- `numpy`, `scipy`, `networkx` -- Numerical computation and graph algorithms +- `numba` -- JIT compilation for performance-critical loops +- `plotly` -- Interactive visualization (CLI and web) +- `streamlit` -- Web interface framework +- `pydantic` -- Data validation and models + +**CLI:** +- `click` -- CLI framework +- `rich` -- Terminal formatting +- `tqdm` -- Progress bars + +**Data:** +- `pandas`, `openpyxl` -- DataFrame operations and Excel export +- `joblib` -- Parallel experiment execution +- `humanize` -- Human-readable formatting +- `psutil` -- Process management +- `openai` -- AI analysis integration + +--- + ## Performance Tips ### Choosing Sample Sizes @@ -846,6 +1056,7 @@ And optionally, this software implementation: ```bibtex @software{spkmc, title = {SPKMC: Shortest Path Kinetic Monte Carlo for Epidemic Simulation}, + author = {Castro, Marcus}, url = {https://github.com/mcaxtr/spkmc} } ``` diff --git a/docs/usage.md b/docs/usage.md index 1809d4e..026b58b 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -11,7 +11,7 @@ This document provides detailed information on how to use the SPKMC package for - NumPy: Efficient numerical operations - SciPy: Scientific and mathematical algorithms - NetworkX: Creation and manipulation of networks - - Matplotlib: Result visualization + - Plotly: Interactive result visualization - Numba: Python code acceleration - tqdm: Progress bars - Click: Command-line interface @@ -402,21 +402,20 @@ if has_error: #### Basic Visualization ```python -import matplotlib.pyplot as plt - -plt.figure(figsize=(10, 6)) -plt.plot(time_steps, S, 'b-', label='Susceptible') -plt.plot(time_steps, I, 'r-', label='Infected') -plt.plot(time_steps, R, 'g-', label='Recovered') - -plt.xlabel('Time') -plt.ylabel('Proportion of Individuals') -plt.title('SIR Model Dynamics Over Time') -plt.legend() -plt.grid(True, alpha=0.3) - -plt.tight_layout() -plt.show() +import plotly.graph_objects as go + +fig = go.Figure() +fig.add_trace(go.Scatter(x=time_steps, y=S, mode='lines', name='Susceptible')) +fig.add_trace(go.Scatter(x=time_steps, y=I, mode='lines', name='Infected')) +fig.add_trace(go.Scatter(x=time_steps, y=R, mode='lines', name='Recovered')) + +fig.update_layout( + title='SIR Model Dynamics Over Time', + xaxis_title='Time', + yaxis_title='Proportion of Individuals', + template='plotly_white', +) +fig.show() ``` #### Visualization with Error Bars diff --git a/pyproject.toml b/pyproject.toml index a654369..44ea434 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ dynamic = ["version"] description = "Shortest Path Kinetic Monte Carlo (SPKMC) for simulating epidemic spread on networks" readme = "README.md" authors = [ - {name = "mcaxtr", email = "mcaxtr@gmail.com"} + {name = "Marcus Castro", email = "mcaxtr@gmail.com"} ] license = {text = "MIT"} classifiers = [ @@ -19,13 +19,11 @@ classifiers = [ "Topic :: Scientific/Engineering :: Physics", "Topic :: Scientific/Engineering :: Information Analysis", ] -requires-python = ">=3.8" +requires-python = ">=3.9" dependencies = [ "numpy>=1.20.0", "scipy>=1.7.0", "networkx>=2.6.0", - "matplotlib>=3.4.0", - "seaborn>=0.12.0", "numba>=0.54.0", "tqdm>=4.60.0", "click>=8.0.0", @@ -38,6 +36,10 @@ dependencies = [ "psutil>=5.8.0", "openai>=1.0.0", "pydantic>=2.0.0", + "streamlit>=1.48.0", + "plotly>=5.18.0", + "kaleido>=0.2.1", + "humanize>=4.0.0", ] [project.optional-dependencies] @@ -56,12 +58,15 @@ gpu = [ "cudf-cu12>=24.0.0", "cugraph-cu12>=24.0.0", ] +e2e = [ + "pytest-playwright>=0.5.0", +] [project.scripts] spkmc = "spkmc.cli.commands:cli" [tool.setuptools] -packages = ["spkmc", "spkmc.analysis", "spkmc.cli", "spkmc.core", "spkmc.io", "spkmc.utils", "spkmc.visualization"] +packages = ["spkmc", "spkmc.analysis", "spkmc.cli", "spkmc.core", "spkmc.io", "spkmc.models", "spkmc.utils", "spkmc.visualization", "spkmc.web", "spkmc.web.pages"] [tool.setuptools_scm] write_to = "spkmc/_version.py" @@ -70,7 +75,7 @@ local_scheme = "no-local-version" [tool.black] line-length = 100 -target-version = ["py38"] +target-version = ["py39"] [tool.isort] profile = "black" @@ -86,3 +91,6 @@ disallow_incomplete_defs = true [tool.pytest.ini_options] testpaths = ["tests"] python_files = "test_*.py" +markers = [ + "e2e: end-to-end tests requiring a running Streamlit server", +] diff --git a/pytest.ini b/pytest.ini index 9ea0c4b..68015f6 100644 --- a/pytest.ini +++ b/pytest.ini @@ -3,4 +3,6 @@ testpaths = tests python_files = test_*.py python_classes = Test* python_functions = test_* -addopts = -v --cov=spkmc --cov-report=term-missing +addopts = -v --cov=spkmc --cov-report=term-missing --ignore=tests/e2e +markers = + e2e: end-to-end tests requiring a running Streamlit server diff --git a/requirements.txt b/requirements.txt index 51b5c71..a33d259 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ numpy>=1.20.0 scipy>=1.7.0 networkx>=2.6.0 -matplotlib>=3.4.0 numba>=0.54.0 tqdm>=4.60.0 click>=8.0.0 @@ -9,6 +8,12 @@ colorama>=0.4.4 rich>=10.0.0 pandas>=1.3.0 openpyxl>=3.0.7 -pytest>=6.2.5 -pytest-cov>=2.12.1 joblib>=1.0.1 +questionary>=1.10.0 +psutil>=5.8.0 +openai>=1.0.0 +pydantic>=2.0.0 +streamlit>=1.48.0 +plotly>=5.18.0 +kaleido>=0.2.1 +humanize>=4.0.0 diff --git a/setup.cfg b/setup.cfg index 8eae0d5..95888dc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -4,7 +4,7 @@ name = spkmc description = Shortest Path Kinetic Monte Carlo (SPKMC) for simulating epidemic spread on networks long_description = file: README.md long_description_content_type = text/markdown -author = mcaxtr +author = Marcus Castro author_email = mcaxtr@gmail.com license = MIT license_file = LICENSE @@ -18,12 +18,11 @@ classifiers = [options] packages = find: -python_requires = >=3.8 +python_requires = >=3.9 install_requires = numpy>=1.20.0 scipy>=1.7.0 networkx>=2.6.0 - matplotlib>=3.4.0 numba>=0.54.0 tqdm>=4.60.0 click>=8.0.0 @@ -32,6 +31,10 @@ install_requires = rich>=10.0.0 openpyxl>=3.0.7 joblib>=1.0.1 + plotly>=5.18.0 + streamlit>=1.35.0 + humanize>=4.0.0 + pydantic>=2.0.0 [options.entry_points] console_scripts = @@ -50,7 +53,7 @@ max-line-length = 100 exclude = .git,__pycache__,build,dist [mypy] -python_version = 3.8 +python_version = 3.9 warn_return_any = True warn_unused_configs = True disallow_untyped_defs = True diff --git a/spkmc/analysis/ai_analyzer.py b/spkmc/analysis/ai_analyzer.py index dcbb51a..78c0d48 100644 --- a/spkmc/analysis/ai_analyzer.py +++ b/spkmc/analysis/ai_analyzer.py @@ -10,13 +10,19 @@ from pathlib import Path from typing import Any, Dict, List, Optional -from spkmc.analysis.metrics import ExperimentMetrics, extract_experiment_metrics +from spkmc.analysis.metrics import ( + ExperimentMetrics, + extract_experiment_metrics, + extract_scenario_metrics, +) from spkmc.analysis.prompts import ( CROSS_EXPERIMENT_SYSTEM_PROMPT, + SCENARIO_SYSTEM_PROMPT, SYSTEM_PROMPT, build_collection_prompt, build_cross_experiment_prompt, build_experiment_prompt, + build_scenario_prompt, ) @@ -126,7 +132,7 @@ def analyze_experiment( {"role": "user", "content": prompt}, ], temperature=0.3, # Lower for more consistent scientific writing - max_tokens=2000, + max_tokens=2500, ) analysis_text = response.choices[0].message.content @@ -142,6 +148,62 @@ def analyze_experiment( return str(analysis_path) + def analyze_scenario( + self, + scenario_label: str, + result: Dict[str, Any], + results_dir: Path, + ) -> Optional[str]: + """ + Generate AI analysis for a single scenario. + + Args: + scenario_label: Label of the scenario + result: The loaded result dictionary for this scenario + results_dir: Path to experiment results directory + + Returns: + Path to generated analysis file, or None if skipped/failed + """ + from spkmc.models import Scenario + + normalized = Scenario.normalize_label(scenario_label) + analysis_path = results_dir / f"{normalized}_analysis.md" + + # Skip if analysis already exists + if analysis_path.exists(): + return None + + # Extract metrics + scenario_metrics = extract_scenario_metrics(result) + + # Build prompt + prompt = build_scenario_prompt(scenario_metrics) + + # Call OpenAI API + client = self._get_client() + response = client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": SCENARIO_SYSTEM_PROMPT}, + {"role": "user", "content": prompt}, + ], + temperature=0.3, + max_tokens=2000, + ) + + analysis_text = response.choices[0].message.content + + # Write analysis file + with open(analysis_path, "w", encoding="utf-8") as f: + f.write(f"# Scenario Analysis: {scenario_label}\n\n") + f.write("---\n\n") + f.write(str(analysis_text) if analysis_text else "") + f.write("\n\n---\n\n") + f.write(f"*Generated by AI analysis (model: {self.model})*\n") + + return str(analysis_path) + def generate_collection_summary( self, all_experiments_metrics: List[ExperimentMetrics] ) -> Optional[str]: @@ -217,7 +279,7 @@ def generate_cross_experiment_analysis( {"role": "user", "content": prompt}, ], temperature=0.3, - max_tokens=3000, + max_tokens=3500, ) analysis_text = response.choices[0].message.content diff --git a/spkmc/analysis/prompts.py b/spkmc/analysis/prompts.py index 8842b39..b49ddb9 100644 --- a/spkmc/analysis/prompts.py +++ b/spkmc/analysis/prompts.py @@ -25,6 +25,14 @@ 3. **Discussion** - Epidemiological interpretation of the patterns observed 4. **Conclusion** - Direct answer to the research question with main takeaways +Formatting Guidelines: +- Use emoji icons to mark section headers (e.g., "## 🔬 Results") +- **Bold** key numerical findings and critical terms +- Use > blockquotes for the single most important takeaway in each section +- Use bullet points for lists of findings +- End with a "💡 Key Takeaway" blockquote summarizing the most actionable insight +- Use --- horizontal rules between major sections for visual separation + Keep the analysis focused and concise (approximately 400-600 words).""" @@ -157,6 +165,73 @@ def build_collection_prompt(all_experiment_metrics: List[ExperimentMetrics]) -> return prompt +SCENARIO_SYSTEM_PROMPT = """You are a computational epidemiologist analyzing a single SIR model \ +simulation scenario on a complex network. Your task is to provide rigorous scientific analysis \ +of the results for this specific parameter configuration. + +Writing Style: +- Use formal academic/scientific language +- Be precise and quantitative - always cite specific numbers +- Use proper epidemiological terminology (basic reproduction number, epidemic threshold, \ +attack rate, herd immunity threshold, network topology, degree distribution, etc.) +- Focus on what these specific parameters reveal about epidemic dynamics + +Structure your analysis with these sections: +1. **Configuration Summary** - Brief overview of the network and distribution setup (2-3 sentences) +2. **Epidemic Dynamics** - Analysis of the SIR curves: peak timing, growth rate, decay behavior +3. **Key Findings** - What this parameter set reveals about epidemic spread on this network +4. **Implications** - Practical meaning of these results + +Formatting Guidelines: +- Use emoji icons to mark section headers (e.g., "## 🔬 Epidemic Dynamics") +- **Bold** key numerical findings and critical terms +- Use > blockquotes for the single most important takeaway in each section +- Use bullet points for lists of findings +- End with a "💡 Key Takeaway" blockquote summarizing the most actionable insight +- Use --- horizontal rules between major sections for visual separation + +Keep the analysis focused and concise (approximately 300-400 words).""" + + +def build_scenario_prompt(scenario: ScenarioMetrics) -> str: + """ + Build the prompt for single scenario analysis. + + Args: + scenario: Extracted scenario metrics + + Returns: + Formatted prompt string for the LLM + """ + network_type = scenario.network_type.upper() + network_names = { + "ER": "Erdos-Renyi", + "SF": "Scale-free (Power-law)", + "RRN": "Random Regular", + "CG": "Complete Graph", + } + network_name = network_names.get(network_type, network_type) + + return f"""Analyze the following single epidemic simulation scenario: + +## Scenario: {scenario.label} + +## Configuration +- **Network**: {network_name} ({_format_network_info(scenario)}) +- **Recovery Distribution**: {_format_distribution_info(scenario)} +- **Simulation**: {scenario.samples} samples, {scenario.num_runs} runs, \ +initial infected = {scenario.initial_perc:.1%} + +## Results +- **Peak Infection**: {scenario.peak_infection:.4f} (at t = {scenario.peak_infection_time:.2f}) +- **Final Outbreak Size**: {scenario.final_outbreak_size:.4f} +- **Attack Rate**: {scenario.attack_rate:.2%} +- **Epidemic Duration**: {scenario.epidemic_duration:.2f} time units + +Please analyze the epidemic dynamics for this specific scenario, focusing on what the \ +SIR curve shape and metrics reveal about disease spread on this network topology.""" + + CROSS_EXPERIMENT_SYSTEM_PROMPT = """You are a computational epidemiologist synthesizing findings \ from multiple epidemic modeling experiments. You are given the individual AI-generated analyses \ for each experiment. Your task is to create a comprehensive meta-analysis that identifies \ @@ -176,6 +251,14 @@ def build_collection_prompt(all_experiment_metrics: List[ExperimentMetrics]) -> 4. **Unified Conclusions** - What the collection of experiments tells us as a whole 5. **Implications** - Practical implications for understanding epidemic dynamics on networks +Formatting Guidelines: +- Use emoji icons to mark section headers (e.g., "## 🔬 Cross-Experiment Patterns") +- **Bold** key numerical findings and critical terms +- Use > blockquotes for the single most important takeaway in each section +- Use bullet points for lists of findings +- End with a "💡 Key Takeaway" blockquote summarizing the most actionable insight +- Use --- horizontal rules between major sections for visual separation + Keep the analysis focused and insightful (approximately 800-1200 words).""" diff --git a/spkmc/cli/commands.py b/spkmc/cli/commands.py index 84f902e..ee5410e 100644 --- a/spkmc/cli/commands.py +++ b/spkmc/cli/commands.py @@ -695,7 +695,7 @@ def _execute_single_scenario( if network_type == "sf": simulation_params["exponent"] = exponent - # Disable inner progress bars during batch execution + # Disable inner progress bars during experiment execution simulation_params["show_progress"] = False # Execute simulation with progress callback for per-sample updates @@ -2565,10 +2565,10 @@ def experiment( # FILE MODE: Traditional execution with a scenarios file # ============================================================ - # Record batch execution start + # Record experiment execution start start_time = time.time() log_debug( - f"Starting batch execution at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", + f"Starting experiment execution at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", verbose_only=False, ) @@ -2794,3 +2794,73 @@ def clean( console.print() log_success(f"Cleanup completed. {cleaned_count} location(s) cleaned.") + + +@cli.command(help="Launch the web interface") +@click.option("--port", "-p", default=8501, type=int, help="Port to run the server on") +@click.option("--host", default="localhost", type=str, help="Host to bind to") +@click.option("--no-browser", is_flag=True, help="Do not open browser automatically") +def web(port: int, host: str, no_browser: bool) -> None: + """Launch the Streamlit web interface.""" + import subprocess + import sys + from pathlib import Path + + log_info("Starting SPKMC web interface...") + + # Find the app.py file + web_app = Path(__file__).parent.parent / "web" / "app.py" + + if not web_app.exists(): + log_error(f"Web app not found at {web_app}") + log_error("Web interface files are missing. Reinstall SPKMC: pip install --upgrade spkmc") + sys.exit(1) + + # Build streamlit command with all config as CLI flags + # (avoids requiring a .streamlit/config.toml file on disk) + cmd = [ + sys.executable, + "-m", + "streamlit", + "run", + str(web_app), + "--server.port", + str(port), + "--server.address", + host, + "--server.headless", + "true" if no_browser else "false", + "--server.fileWatcherType", + "none", + "--browser.gatherUsageStats", + "false", + "--client.toolbarMode", + "minimal", + "--runner.magicEnabled", + "false", + "--theme.base", + "light", + "--theme.primaryColor", + "#2D7A6E", + "--theme.backgroundColor", + "#F7F8FA", + "--theme.secondaryBackgroundColor", + "#FFFFFF", + "--theme.textColor", + "#111827", + "--theme.font", + "sans serif", + ] + + log_info(f"Launching at http://{host}:{port}") + + try: + result = subprocess.run(cmd) + if result.returncode != 0: + log_error(f"Web interface exited with code {result.returncode}") + sys.exit(result.returncode) + except KeyboardInterrupt: + log_info("Web interface stopped") + except Exception as e: + log_error(f"Failed to start web interface: {e}") + sys.exit(1) diff --git a/spkmc/io/data_manager.py b/spkmc/io/data_manager.py index 088b17b..a376a89 100644 --- a/spkmc/io/data_manager.py +++ b/spkmc/io/data_manager.py @@ -9,7 +9,7 @@ import os from datetime import datetime from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, cast +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, cast import numpy as np @@ -242,28 +242,24 @@ def _save_excel(cls, result: Dict[str, Any], path: str) -> None: df_stats.to_excel(writer, sheet_name="Statistics", index=False) @classmethod - def _save_markdown(cls, result: Dict[str, Any], path: str, include_plot: bool = True) -> None: - """Save result as Markdown report.""" - # Extract metadata + def _build_markdown_content(cls, result: Dict[str, Any]) -> str: + """Build Markdown report content string (without plot reference).""" metadata = result.get("metadata", {}) network_type = metadata.get("network", "").upper() dist_type = metadata.get("distribution", "").capitalize() N = metadata.get("N", "") - # Extract data time_steps = np.array(result.get("time", [])) s_vals = np.array(result.get("S_val", [])) i_vals = np.array(result.get("I_val", [])) r_vals = np.array(result.get("R_val", [])) - # Calculate statistics max_infected = np.max(i_vals) if len(i_vals) > 0 else 0 max_infected_time = ( time_steps[np.argmax(i_vals)] if len(i_vals) > 0 and len(time_steps) > 0 else 0 ) final_recovered = r_vals[-1] if len(r_vals) > 0 else 0 - # Build Markdown content timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") md_content = f"""# SPKMC Simulation Report @@ -278,13 +274,10 @@ def _save_markdown(cls, result: Dict[str, Any], path: str, include_plot: bool = | Distribution | {dist_type} | | Number of Nodes (N) | {N} | """ - - # Add specific parameters for key, value in metadata.items(): if key not in ["network", "distribution", "N"]: md_content += f"| {key} | {value} |\n" - # Add statistics md_content += f""" ## Statistics @@ -294,55 +287,39 @@ def _save_markdown(cls, result: Dict[str, Any], path: str, include_plot: bool = | Time to Infection Peak | {max_infected_time:.4f} | | Final Recovered | {final_recovered:.4f} | -""" - - # Add plot reference if requested - if include_plot: - plot_path = path.replace(".md", ".png") - cls._generate_plot(result, plot_path) - md_content += f""" -## Visualization - -![Simulation Plot]({os.path.basename(plot_path)}) - -""" - - # Add data tables (first and last 5 points) - md_content += """ ## Simulation Data -### First 5 points - | Time | Susceptible | Infected | Recovered | |-------|-------------|------------|-------------| """ - - for idx in range(min(5, len(time_steps))): + for idx in range(len(time_steps)): md_content += ( f"| {time_steps[idx]:.4f} | {s_vals[idx]:.4f} " f"| {i_vals[idx]:.4f} | {r_vals[idx]:.4f} |\n" ) - md_content += """ -### Last 5 points + return md_content -| Time | Susceptible | Infected | Recovered | -|-------|-------------|------------|-------------| -""" + @classmethod + def _save_markdown(cls, result: Dict[str, Any], path: str, include_plot: bool = True) -> None: + """Save result as Markdown report.""" + md_content = cls._build_markdown_content(result) - for idx in range(max(0, len(time_steps) - 5), len(time_steps)): - md_content += ( - f"| {time_steps[idx]:.4f} | {s_vals[idx]:.4f} " - f"| {i_vals[idx]:.4f} | {r_vals[idx]:.4f} |\n" - ) + if include_plot: + plot_path = path.replace(".md", ".png") + try: + cls._generate_plot(result, plot_path) + actual_plot = os.path.basename(plot_path) + md_content += f"\n## Visualization\n\n![Simulation Plot]({actual_plot})\n\n" + except RuntimeError: + pass # Plot generation failed (e.g. kaleido missing); skip image - # Save Markdown file with open(path, "w") as f: f.write(md_content) @classmethod - def _save_html(cls, result: Dict[str, Any], path: str, include_plot: bool = True) -> None: - """Save result as HTML report.""" + def _build_html_content(cls, result: Dict[str, Any]) -> str: + """Build HTML report content string.""" try: import pandas as pd except ImportError: @@ -350,20 +327,11 @@ def _save_html(cls, result: Dict[str, Any], path: str, include_plot: bool = True "Pandas is required for HTML export. Install with: pip install pandas" ) - # First export to Markdown - md_path = path.replace(".html", "_temp.md") - cls._save_markdown(result, md_path, include_plot) - - # Read markdown content - with open(md_path, "r") as f: - md_content = f.read() - - # Convert to simple HTML table + md_content = cls._build_markdown_content(result) df = pd.DataFrame({"markdown": [md_content]}) html = df.to_html(escape=False, index=False, header=False) - # Add CSS styles - html_content = f""" + return f""" @@ -408,12 +376,71 @@ def _save_html(cls, result: Dict[str, Any], path: str, include_plot: bool = True """ - # Save HTML file + @classmethod + def _save_html(cls, result: Dict[str, Any], path: str, include_plot: bool = True) -> None: + """Save result as HTML report.""" + # Build HTML (without embedded plot for simplicity) + html_content = cls._build_html_content(result) + + if include_plot: + # Generate plot alongside the HTML file + plot_path = path.replace(".html", ".png") + try: + cls._generate_plot(result, plot_path) + except RuntimeError: + pass # Plot generation failed (e.g. kaleido missing); skip image + with open(path, "w") as f: f.write(html_content) - # Remove temporary Markdown file - os.remove(md_path) + @classmethod + def to_bytes(cls, result: Dict[str, Any], fmt: str) -> Tuple[bytes, str, str]: + """ + Serialize a result dict to bytes for in-memory download. + + Args: + result: Result dictionary with SIR data and metadata. + fmt: One of "json", "csv", "excel", "md", "html". + + Returns: + Tuple of (data_bytes, mime_type, file_extension). + """ + import io as _io + + if fmt == "csv": + csv_buf = _io.StringIO() + cls._result_to_dataframe(result).to_csv(csv_buf, index=False) + return csv_buf.getvalue().encode("utf-8"), "text/csv", ".csv" + + if fmt == "excel": + try: + import openpyxl # noqa: F401 + import pandas as pd + except ImportError as exc: + raise ImportError( + "Excel export requires pandas and openpyxl: pip install pandas openpyxl" + ) from exc + excel_buf = _io.BytesIO() + df_data = cls._result_to_dataframe(result) + metadata = result.get("metadata", {}) + df_meta = pd.DataFrame([{"Parameter": k, "Value": v} for k, v in metadata.items()]) + with pd.ExcelWriter(excel_buf, engine="openpyxl") as writer: + df_data.to_excel(writer, sheet_name="Data", index=False) + df_meta.to_excel(writer, sheet_name="Metadata", index=False) + mime = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" + return excel_buf.getvalue(), mime, ".xlsx" + + if fmt == "md": + content = cls._build_markdown_content(result) + return content.encode("utf-8"), "text/markdown", ".md" + + if fmt == "html": + content = cls._build_html_content(result) + return content.encode("utf-8"), "text/html", ".html" + + # Default: JSON + content = json.dumps(result, indent=2, cls=NumpyJSONEncoder) + return content.encode("utf-8"), "application/json", ".json" @classmethod def _generate_plot(cls, result: Dict[str, Any], output_path: str) -> None: diff --git a/spkmc/models/experiment.py b/spkmc/models/experiment.py index 0ec4a8c..91c2122 100644 --- a/spkmc/models/experiment.py +++ b/spkmc/models/experiment.py @@ -92,6 +92,7 @@ class Experiment(BaseModel): plot_config: Optional[PlotConfig] = Field( default=None, exclude=True ) # Alias for backward compat + parameters: Dict[str, Any] = Field(default_factory=dict) # Global default parameters scenarios: List[Any] = Field(min_length=1) # Accept raw dicts or Scenario objects path: Optional[Path] = None @@ -198,6 +199,7 @@ def from_config(cls, config: ExperimentConfig, path: Optional[Path] = None) -> " name=config.name, description=config.description, plot=config.plot, + parameters=config.parameters, scenarios=scenarios, path=path, ) diff --git a/spkmc/visualization/plots.py b/spkmc/visualization/plots.py index f4e7954..694834c 100644 --- a/spkmc/visualization/plots.py +++ b/spkmc/visualization/plots.py @@ -4,170 +4,89 @@ This module contains functions to visualize SPKMC simulation results, including time-evolution plots of SIR states and comparisons between simulations. -Uses seaborn and matplotlib with publication-quality styling suitable for -academic papers and presentations. +Uses Plotly for interactive visualizations that work in both CLI (opens browser) +and web interface (embedded in Streamlit). """ -import contextlib import os -import sys -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple -import matplotlib.pyplot as plt -import networkx as nx import numpy as np -import seaborn as sns - -# Publication-quality color palettes (colorblind-friendly) -# Based on Paul Tol's colorblind-safe palette -COLORBLIND_PALETTE = [ - "#4477AA", # blue - "#EE6677", # red/pink - "#228833", # green - "#CCBB44", # yellow - "#66CCEE", # cyan - "#AA3377", # purple - "#BBBBBB", # grey -] - -# SIR-specific colors (semantically meaningful and colorblind-friendly) -SIR_COLORS = { - "S": "#4477AA", # blue for susceptible - "I": "#EE6677", # red/pink for infected - "R": "#228833", # green for recovered -} - -# Line styles for distinguishing curves -LINE_STYLES = { - "S": "-", # solid for susceptible - "I": "-", # solid for infected - "R": "--", # dashed for recovered -} +import plotly.graph_objects as go +import plotly.io as pio + +# Import from web plotting module to reuse code +from spkmc.web.plotting import ( + COLOR_I, + COLOR_R, + COLOR_S, + STATE_COLORS, + create_comparison_figure, + create_sir_figure, +) # Default DPI for saved figures (single source of truth) DEFAULT_PLOT_DPI = 300 - -def _setup_publication_style() -> None: - """Configure matplotlib and seaborn for publication-quality figures.""" - # Use seaborn's whitegrid style as base - sns.set_theme(style="whitegrid", context="paper", font_scale=1.2) - sns.set_palette(COLORBLIND_PALETTE) - - # Additional matplotlib customizations - plt.rcParams.update( - { - # Figure - "figure.facecolor": "white", - "figure.edgecolor": "white", - "figure.dpi": 150, - # Font - "font.family": "sans-serif", - "font.sans-serif": ["Arial", "DejaVu Sans", "Helvetica", "sans-serif"], - "font.size": 11, - # Axes - "axes.linewidth": 1.2, - "axes.labelsize": 12, - "axes.titlesize": 14, - "axes.titleweight": "bold", - "axes.spines.top": False, - "axes.spines.right": False, - "axes.grid": True, - "axes.axisbelow": True, - # Grid - "grid.alpha": 0.4, - "grid.linestyle": "-", - "grid.linewidth": 0.8, - # Legend - "legend.frameon": True, - "legend.framealpha": 0.9, - "legend.edgecolor": "0.8", - "legend.fontsize": 10, - "legend.title_fontsize": 11, - # Ticks - "xtick.labelsize": 10, - "ytick.labelsize": 10, - "xtick.major.width": 1.2, - "ytick.major.width": 1.2, - # Lines - "lines.linewidth": 2.0, - "lines.markersize": 6, - # Saving - "savefig.dpi": 300, - "savefig.bbox": "tight", - "savefig.facecolor": "white", - "savefig.edgecolor": "white", - } - ) - - -def _get_scenario_colors(n_scenarios: int) -> List[str]: - """Get a list of colorblind-friendly colors for scenarios.""" - if n_scenarios <= len(COLORBLIND_PALETTE): - return COLORBLIND_PALETTE[:n_scenarios] - - # If we need more colors, cycle through the palette - colors = [] - for i in range(n_scenarios): - colors.append(COLORBLIND_PALETTE[i % len(COLORBLIND_PALETTE)]) - return colors +# Configure Plotly defaults for publication quality +pio.templates.default = "plotly_white" -@contextlib.contextmanager -def _suppress_macos_warning() -> Generator[None, None, None]: +def _save_or_show( + fig: go.Figure, + save_path: Optional[str] = None, + format: str = "png", + dpi: int = DEFAULT_PLOT_DPI, + width: int = 800, + height: int = 500, +) -> None: """ - Context manager to suppress macOS ApplePersistenceIgnoreState warning. - - This warning is printed by macOS Cocoa layer, not Python, so we redirect - the file descriptor directly rather than using Python's sys.stderr. - """ - if sys.platform != "darwin": - yield - return - - # On macOS, redirect stderr at the file descriptor level - # to suppress Cocoa framework warnings - stderr_fd = sys.stderr.fileno() - try: - # Save the original stderr - saved_stderr = os.dup(stderr_fd) - # Open /dev/null - devnull = os.open(os.devnull, os.O_WRONLY) - # Replace stderr with /dev/null - os.dup2(devnull, stderr_fd) - os.close(devnull) - yield - finally: - # Restore original stderr - os.dup2(saved_stderr, stderr_fd) - os.close(saved_stderr) - - -def _show_plot() -> None: - """Show plot with suppressed macOS warnings.""" - with _suppress_macos_warning(): - plt.show() - - -def _create_figure( - figsize: Tuple[float, float] = (8, 5), **kwargs: Any -) -> Tuple[plt.Figure, plt.Axes]: - """ - Create a publication-quality figure with proper styling. + Save figure to file or open in browser. Args: - figsize: Figure size in inches (width, height) - **kwargs: Additional arguments passed to plt.subplots - - Returns: - Tuple of (figure, axes) + fig: Plotly figure + save_path: Path to save the figure (if None, opens in browser) + format: Output format ('png', 'jpg', 'svg', 'pdf', 'html') + dpi: Resolution for raster formats + width: Width in pixels + height: Height in pixels """ - _setup_publication_style() - - with _suppress_macos_warning(): - fig, ax = plt.subplots(figsize=figsize, **kwargs) + if save_path: + # Determine format from extension if not specified + if save_path.endswith(".html"): + format = "html" + elif save_path.endswith(".svg"): + format = "svg" + elif save_path.endswith(".pdf"): + format = "pdf" + elif save_path.endswith(".jpg") or save_path.endswith(".jpeg"): + format = "jpg" + else: + format = "png" - return fig, ax + if format == "html": + # Save as standalone HTML + fig.write_html(save_path, include_plotlyjs="cdn") + else: + # Save as static image (requires kaleido) + try: + scale = dpi / 96 # Convert DPI to scale factor (96 is default) + fig.write_image( + save_path, + format=format, + width=width, + height=height, + scale=scale, + ) + except (ValueError, ImportError) as e: + raise RuntimeError( + f"Failed to save plot as {format}: {e}. " + "Install kaleido for static image export: pip install kaleido" + ) from e + else: + # Open in browser + fig.show() if TYPE_CHECKING: @@ -175,7 +94,7 @@ def _create_figure( class Visualizer: - """Class for visualizing simulation results with publication-quality plots.""" + """Class for visualizing simulation results with interactive Plotly plots.""" @staticmethod def plot_result_with_error( @@ -188,11 +107,11 @@ def plot_result_with_error( time: np.ndarray, title: Optional[str] = None, save_path: Optional[str] = None, - states_to_plot: Optional[set] = None, + states_to_plot: Optional[Set[str]] = None, dpi: int = DEFAULT_PLOT_DPI, ) -> None: """ - Plot results with shaded error bands (publication-quality). + Plot results with shaded error bands (interactive Plotly). Args: S: Proportion of susceptible @@ -203,86 +122,37 @@ def plot_result_with_error( R_err: Standard error for recovered time: Time steps title: Plot title (optional) - save_path: Path to save the plot (optional) + save_path: Path to save the plot (optional, opens in browser if None) states_to_plot: Set of states to plot ('S', 'I', 'R') dpi: Resolution in dots per inch for saved figures (default: 300) """ if states_to_plot is None: states_to_plot = {"S", "I", "R"} - fig, ax = _create_figure(figsize=(8, 5)) - - # Plot with shaded error bands (more elegant than error bars) - if "S" in states_to_plot: - ax.plot( - time, - S, - color=SIR_COLORS["S"], - linestyle=LINE_STYLES["S"], - linewidth=2.0, - label="Susceptible", - ) - ax.fill_between( - time, - S - S_err, - S + S_err, - color=SIR_COLORS["S"], - alpha=0.2, - ) - - if "I" in states_to_plot: - ax.plot( - time, - I, - color=SIR_COLORS["I"], - linestyle=LINE_STYLES["I"], - linewidth=2.0, - label="Infected", - ) - ax.fill_between( - time, - I - I_err, - I + I_err, - color=SIR_COLORS["I"], - alpha=0.2, - ) - - if "R" in states_to_plot: - ax.plot( - time, - R, - color=SIR_COLORS["R"], - linestyle=LINE_STYLES["R"], - linewidth=2.0, - label="Recovered", - ) - ax.fill_between( - time, - R - R_err, - R + R_err, - color=SIR_COLORS["R"], - alpha=0.2, - ) - - ax.set_xlabel("Time", fontweight="medium") - ax.set_ylabel("Proportion of Population", fontweight="medium") - ax.set_ylim(0, 1.05) - ax.set_xlim(time[0], time[-1]) - - if title: - ax.set_title(title, pad=15) - else: - ax.set_title("SIR Dynamics with Confidence Bands", pad=15) - - ax.legend(loc="best", framealpha=0.9) + # Convert to list for Plotly + states_list = list(states_to_plot) + + # Build result dict + result_dict = { + "time": time.tolist() if isinstance(time, np.ndarray) else time, + "S_val": S.tolist() if isinstance(S, np.ndarray) else S, + "I_val": I.tolist() if isinstance(I, np.ndarray) else I, + "R_val": R.tolist() if isinstance(R, np.ndarray) else R, + "S_err": S_err.tolist() if isinstance(S_err, np.ndarray) else S_err, + "I_err": I_err.tolist() if isinstance(I_err, np.ndarray) else I_err, + "R_err": R_err.tolist() if isinstance(R_err, np.ndarray) else R_err, + } - plt.tight_layout() + # Create figure using web plotting module + fig = create_sir_figure( + result_dict, + title=title or "SIR Dynamics with Confidence Bands", + states=states_list, + show_error_bands=True, + height=500, + ) - if save_path: - fig.savefig(save_path, dpi=dpi, bbox_inches="tight", facecolor="white") - plt.close(fig) - else: - _show_plot() + _save_or_show(fig, save_path, dpi=dpi) @staticmethod def plot_result( @@ -292,11 +162,11 @@ def plot_result( time: np.ndarray, title: Optional[str] = None, save_path: Optional[str] = None, - states_to_plot: Optional[set] = None, + states_to_plot: Optional[Set[str]] = None, dpi: int = DEFAULT_PLOT_DPI, ) -> None: """ - Plot results without error bands (publication-quality). + Plot results without error bands (interactive Plotly). Args: S: Proportion of susceptible @@ -304,64 +174,34 @@ def plot_result( R: Proportion of recovered time: Time steps title: Plot title (optional) - save_path: Path to save the plot (optional) + save_path: Path to save the plot (optional, opens in browser if None) states_to_plot: Set of states to plot ('S', 'I', 'R') dpi: Resolution in dots per inch for saved figures (default: 300) """ if states_to_plot is None: states_to_plot = {"S", "I", "R"} - fig, ax = _create_figure(figsize=(8, 5)) - - if "S" in states_to_plot: - ax.plot( - time, - S, - color=SIR_COLORS["S"], - linestyle=LINE_STYLES["S"], - linewidth=2.0, - label="Susceptible", - ) - - if "I" in states_to_plot: - ax.plot( - time, - I, - color=SIR_COLORS["I"], - linestyle=LINE_STYLES["I"], - linewidth=2.0, - label="Infected", - ) - - if "R" in states_to_plot: - ax.plot( - time, - R, - color=SIR_COLORS["R"], - linestyle=LINE_STYLES["R"], - linewidth=2.0, - label="Recovered", - ) - - ax.set_xlabel("Time", fontweight="medium") - ax.set_ylabel("Proportion of Population", fontweight="medium") - ax.set_ylim(0, 1.05) - ax.set_xlim(time[0], time[-1]) + # Convert to list for Plotly + states_list = list(states_to_plot) - if title: - ax.set_title(title, pad=15) - else: - ax.set_title("SIR Model Dynamics", pad=15) - - ax.legend(loc="best", framealpha=0.9) + # Build result dict + result_dict = { + "time": time.tolist() if isinstance(time, np.ndarray) else time, + "S_val": S.tolist() if isinstance(S, np.ndarray) else S, + "I_val": I.tolist() if isinstance(I, np.ndarray) else I, + "R_val": R.tolist() if isinstance(R, np.ndarray) else R, + } - plt.tight_layout() + # Create figure using web plotting module + fig = create_sir_figure( + result_dict, + title=title or "SIR Model Dynamics", + states=states_list, + show_error_bands=False, + height=500, + ) - if save_path: - fig.savefig(save_path, dpi=dpi, bbox_inches="tight", facecolor="white") - plt.close(fig) - else: - _show_plot() + _save_or_show(fig, save_path, dpi=dpi) @staticmethod def compare_results( @@ -369,20 +209,17 @@ def compare_results( labels: List[str], title: Optional[str] = None, save_path: Optional[str] = None, - states_to_plot: Optional[set] = None, + states_to_plot: Optional[Set[str]] = None, dpi: int = DEFAULT_PLOT_DPI, ) -> None: """ - Compare results from multiple simulations (publication-quality). - - Uses distinct colors for each scenario and different line styles - for each SIR state. Colors are colorblind-friendly. + Compare results from multiple simulations (interactive Plotly). Args: results: List of dictionaries with results labels: List of labels for each result title: Plot title (optional) - save_path: Path to save the plot (optional) + save_path: Path to save the plot (optional, opens in browser if None) states_to_plot: Set of states to plot ('S', 'I', 'R') dpi: Resolution in dots per inch for saved figures (default: 300) """ @@ -395,94 +232,19 @@ def compare_results( if states_to_plot is None: states_to_plot = {"S", "I", "R"} - # Adjust figure size based on number of scenarios (need room for legend) - fig_width = 9 if len(results) <= 4 else 10 - fig, ax = _create_figure(figsize=(fig_width, 5.5)) - - # Get colorblind-friendly colors for scenarios - scenario_colors = _get_scenario_colors(len(results)) - - # Line styles for states (to distinguish S, I, R within same scenario) - state_styles = {"S": ":", "I": "-", "R": "--"} - state_widths = {"S": 1.8, "I": 2.2, "R": 1.8} - - for idx, (result, label) in enumerate(zip(results, labels)): - if not all(key in result for key in ["S_val", "I_val", "R_val", "time"]): - raise ValueError(f"Result {idx} does not contain all required data") - - s_vals = np.array(result["S_val"]) - i_vals = np.array(result["I_val"]) - r_vals = np.array(result["R_val"]) - time = np.array(result["time"]) - - color = scenario_colors[idx] - - if "S" in states_to_plot: - ax.plot( - time, - s_vals, - color=color, - linestyle=state_styles["S"], - linewidth=state_widths["S"], - alpha=0.85, - label=f"S — {label}", - ) - if "I" in states_to_plot: - ax.plot( - time, - i_vals, - color=color, - linestyle=state_styles["I"], - linewidth=state_widths["I"], - alpha=0.95, - label=f"I — {label}", - ) - if "R" in states_to_plot: - ax.plot( - time, - r_vals, - color=color, - linestyle=state_styles["R"], - linewidth=state_widths["R"], - alpha=0.85, - label=f"R — {label}", - ) - - ax.set_xlabel("Time", fontweight="medium") - ax.set_ylabel("Proportion of Population", fontweight="medium") - ax.set_ylim(0, 1.05) - - if title: - ax.set_title(title, pad=15) - else: - ax.set_title("Epidemic Dynamics Comparison", pad=15) - - # Position legend: outside for many scenarios, inside for few - if len(results) > 3: - ax.legend( - bbox_to_anchor=(1.02, 1), - loc="upper left", - fontsize=9, - framealpha=0.9, - title="State — Scenario", - title_fontsize=10, - ) - else: - ax.legend( - loc="best", - fontsize=9, - framealpha=0.9, - title="State — Scenario", - title_fontsize=10, - ) + # Convert to list for Plotly + states_list = list(states_to_plot) - plt.tight_layout() + # Create figure using web plotting module + fig = create_comparison_figure( + results, + labels, + title=title or "Epidemic Dynamics Comparison", + states=states_list, + height=600, + ) - if save_path: - fig.savefig(save_path, dpi=dpi, bbox_inches="tight", facecolor="white") - plt.close(fig) - else: - _show_plot() + _save_or_show(fig, save_path, dpi=dpi, height=600) @staticmethod def compare_results_with_config( @@ -498,7 +260,7 @@ def compare_results_with_config( results: List of dictionaries with results labels: List of labels for each result plot_config: Custom plot configuration - save_path: Path to save the plot (optional) + save_path: Path to save the plot (optional, opens in browser if None) """ if not results: raise ValueError("The results list is empty") @@ -508,111 +270,63 @@ def compare_results_with_config( # Use config values states_to_plot = ( - set(plot_config.states_to_plot) if plot_config.states_to_plot else {"S", "I", "R"} + plot_config.states_to_plot if plot_config.states_to_plot else ["S", "I", "R"] ) - figsize_tuple: Tuple[float, float] = (plot_config.figsize[0], plot_config.figsize[1]) - fig, ax = _create_figure(figsize=figsize_tuple) - - # Get colorblind-friendly colors for scenarios - scenario_colors = _get_scenario_colors(len(results)) - - # Line styles for states - state_styles = {"S": ":", "I": "-", "R": "--"} - state_widths = {"S": 1.8, "I": 2.2, "R": 1.8} - - for idx, (result, label) in enumerate(zip(results, labels)): - if not all(key in result for key in ["S_val", "I_val", "R_val", "time"]): - raise ValueError(f"Result {idx} does not contain all required data") - - s_vals = np.array(result["S_val"]) - i_vals = np.array(result["I_val"]) - r_vals = np.array(result["R_val"]) - time = np.array(result["time"]) - - color = scenario_colors[idx] - - if "S" in states_to_plot: - ax.plot( - time, - s_vals, - color=color, - linestyle=state_styles["S"], - linewidth=state_widths["S"], - alpha=0.85, - label=f"S — {label}", - ) - if "I" in states_to_plot: - ax.plot( - time, - i_vals, - color=color, - linestyle=state_styles["I"], - linewidth=state_widths["I"], - alpha=0.95, - label=f"I — {label}", - ) - if "R" in states_to_plot: - ax.plot( - time, - r_vals, - color=color, - linestyle=state_styles["R"], - linewidth=state_widths["R"], - alpha=0.85, - label=f"R — {label}", - ) - - ax.set_xlabel(plot_config.xlabel, fontweight="medium") - ax.set_ylabel(plot_config.ylabel, fontweight="medium") - ax.set_ylim(0, 1.05) + # Create figure using web plotting module + fig = create_comparison_figure( + results, + labels, + title=plot_config.title or "Epidemic Dynamics Comparison", + states=states_to_plot, + height=int(plot_config.figsize[1] * 100), # Convert to pixels + ) - if plot_config.title: - ax.set_title(plot_config.title, pad=15) - else: - ax.set_title("Epidemic Dynamics Comparison", pad=15) - - # Position legend based on number of scenarios - if len(results) > 4: - ax.legend( - bbox_to_anchor=(1.02, 1), - loc="upper left", - fontsize=9, - framealpha=0.9, - title="State — Scenario", - title_fontsize=10, - ) - else: - ax.legend( - loc=plot_config.legend_position, - fontsize=9, - framealpha=0.9, + # Apply labels, grid, and legend from config + grid_color = f"rgba(0,0,0,{plot_config.grid_alpha})" if plot_config.grid else None + fig.update_layout( + xaxis_title=plot_config.xlabel, + yaxis_title=plot_config.ylabel, + xaxis_showgrid=plot_config.grid, + yaxis_showgrid=plot_config.grid, + ) + if plot_config.grid and grid_color: + fig.update_layout( + xaxis_gridcolor=grid_color, + yaxis_gridcolor=grid_color, ) - - if plot_config.grid: - ax.grid(True, alpha=plot_config.grid_alpha, linestyle="-", linewidth=0.8) - - plt.tight_layout() - - if save_path: - fig.savefig(save_path, dpi=plot_config.dpi, bbox_inches="tight", facecolor="white") - plt.close(fig) - else: - _show_plot() + # Map matplotlib-style legend_position to Plotly + _LEGEND_MAP = { + "best": dict(x=1.02, y=1, orientation="v"), + "upper right": dict(x=1, y=1, xanchor="right"), + "upper left": dict(x=0, y=1, xanchor="left"), + "lower left": dict(x=0, y=0, xanchor="left", yanchor="bottom"), + "lower right": dict(x=1, y=0, xanchor="right", yanchor="bottom"), + "center": dict(x=0.5, y=0.5, xanchor="center", yanchor="middle"), + } + legend_kw = _LEGEND_MAP.get(plot_config.legend_position, {}) + if legend_kw: + fig.update_layout(legend=legend_kw) + + _save_or_show( + fig, + save_path, + dpi=plot_config.dpi, + width=int(plot_config.figsize[0] * 100), + height=int(plot_config.figsize[1] * 100), + ) @staticmethod - def plot_network( - G: nx.DiGraph, title: Optional[str] = None, save_path: Optional[str] = None - ) -> None: + def plot_network(G: Any, title: Optional[str] = None, save_path: Optional[str] = None) -> None: """ - Plot the network used in the simulation (publication-quality). + Plot the network used in the simulation (interactive Plotly). Args: - G: Network graph + G: Network graph (NetworkX) title: Plot title (optional) - save_path: Path to save the plot (optional) + save_path: Path to save the plot (optional, opens in browser if None) """ - fig, ax = _create_figure(figsize=(8, 7)) + import networkx as nx # Limit the number of nodes for visualization if G.number_of_nodes() > 100: @@ -625,55 +339,73 @@ def plot_network( ) G = nx.DiGraph(G.subgraph(list(G.nodes())[:100])) + # Get layout pos = nx.spring_layout(G, seed=42, k=1.5 / np.sqrt(G.number_of_nodes())) - # Draw edges first (behind nodes) - nx.draw_networkx_edges( - G, - pos, - ax=ax, - edge_color="#CCCCCC", - arrows=True, - arrowsize=8, - alpha=0.6, - width=0.8, - connectionstyle="arc3,rad=0.1", + # Create edge trace + edge_x = [] + edge_y = [] + for edge in G.edges(): + x0, y0 = pos[edge[0]] + x1, y1 = pos[edge[1]] + edge_x.extend([x0, x1, None]) + edge_y.extend([y0, y1, None]) + + edge_trace = go.Scatter( + x=edge_x, + y=edge_y, + line=dict(width=0.5, color="#888"), + hoverinfo="none", + mode="lines", ) - # Draw nodes - nx.draw_networkx_nodes( - G, - pos, - ax=ax, - node_size=80, - node_color=COLORBLIND_PALETTE[0], - edgecolors="white", - linewidths=1.0, - alpha=0.9, + # Create node trace + node_x = [] + node_y = [] + for node in G.nodes(): + x, y = pos[node] + node_x.append(x) + node_y.append(y) + + node_trace = go.Scatter( + x=node_x, + y=node_y, + mode="markers", + hoverinfo="text", + marker=dict( + showscale=False, + colorscale="YlGnBu", + size=10, + color=COLOR_S, + line_width=2, + ), ) - if title: - ax.set_title(title, pad=15) - else: - ax.set_title( - f"Network Structure ({G.number_of_nodes()} nodes, " f"{G.number_of_edges()} edges)", - pad=15, - ) - - ax.axis("off") - - plt.tight_layout() + # Create figure + fig = go.Figure( + data=[edge_trace, node_trace], + layout=go.Layout( + title=dict( + text=title + or f"Network Structure ({G.number_of_nodes()} nodes, {G.number_of_edges()} edges)", + x=0.5, + xanchor="center", + ), + showlegend=False, + hovermode="closest", + margin=dict(b=0, l=0, r=0, t=40), + xaxis=dict(showgrid=False, zeroline=False, showticklabels=False), + yaxis=dict(showgrid=False, zeroline=False, showticklabels=False), + height=700, + ), + ) - if save_path: - fig.savefig(save_path, dpi=300, bbox_inches="tight", facecolor="white") - plt.close(fig) - else: - _show_plot() + _save_or_show(fig, save_path, width=800, height=700) @staticmethod def create_summary_plot(result_path: str, output_dir: Optional[str] = None) -> str: """ - Create a publication-quality summary plot from a results file. + Create an interactive summary plot from a results file. Args: result_path: Path to the results file @@ -700,10 +432,10 @@ def create_summary_plot(result_path: str, output_dir: Optional[str] = None) -> s # Create output directory if it doesn't exist if output_dir: os.makedirs(output_dir, exist_ok=True) - base_name = os.path.basename(result_path).replace(".json", ".png") + base_name = os.path.basename(result_path).replace(".json", ".html") save_path = os.path.join(output_dir, base_name) else: - save_path = result_path.replace(".json", ".png") + save_path = result_path.replace(".json", ".html") # Extract metadata for the title metadata = result.get("metadata", {}) diff --git a/spkmc/web/__init__.py b/spkmc/web/__init__.py new file mode 100644 index 0000000..556ae5b --- /dev/null +++ b/spkmc/web/__init__.py @@ -0,0 +1,31 @@ +""" +SPKMC Web Interface. + +This package provides a Streamlit-based web interface for managing and running +SPKMC epidemic simulations through a browser. +""" + +from __future__ import annotations + +import json +import os +from pathlib import Path +from typing import Any, Dict + + +def atomic_json_write(path: Path, data: Dict[str, Any], indent: int = 2) -> None: + """Write a JSON file atomically via a temp-file + os.replace(). + + Prevents partial/corrupt files when the process is interrupted mid-write. + """ + tmp = path.with_suffix(".json.tmp") + try: + with open(tmp, "w", encoding="utf-8") as f: + json.dump(data, f, indent=indent) + os.replace(str(tmp), str(path)) + except BaseException: + tmp.unlink(missing_ok=True) + raise + + +__all__ = ["app", "config", "state", "plotting", "components", "runner"] diff --git a/spkmc/web/analysis_runner.py b/spkmc/web/analysis_runner.py new file mode 100644 index 0000000..11eeaab --- /dev/null +++ b/spkmc/web/analysis_runner.py @@ -0,0 +1,543 @@ +""" +Subprocess-based AI analysis runner for the web interface. + +Runs AI analyses in background subprocesses so they survive browser refresh +and UI interactions. Follows the same pattern as SimulationRunner. +""" + +from __future__ import annotations + +import json +import os +import subprocess +import sys +import time +from pathlib import Path +from typing import Any, Dict, Optional, cast + +import psutil +import streamlit as st + + +class AnalysisRunner: + """Manages subprocess-based AI analysis execution.""" + + def __init__(self) -> None: + """Initialize the analysis runner.""" + self.status_dir = Path(".spkmc_web") / "status" + self.status_dir.mkdir(parents=True, exist_ok=True) + # Retain Popen handles so we can reap children and avoid zombies + self._processes: Dict[str, subprocess.Popen] = {} # type: ignore[type-arg] + + def run_experiment_analysis( + self, + experiment_path: Path, + experiment_name: str, + experiment_description: str, + model: str, + api_key: str, + ) -> Optional[str]: + """ + Launch a subprocess to run AI analysis on an entire experiment. + + Args: + experiment_path: Path to the experiment directory + experiment_name: Display name of the experiment + experiment_description: Research question / description + model: OpenAI model to use + api_key: OpenAI API key + + Returns: + Run ID if launched successfully, None otherwise + """ + run_id = f"exp_analysis--{experiment_path.name}--{time.time_ns()}" + + status_file = self.status_dir / f"{run_id}.json" + status_data = { + "run_id": run_id, + "type": "analysis", + "analysis_type": "experiment", + "experiment_name": experiment_path.name, + "scenario_normalized": "", + "status": "starting", + "start_time": time.time(), + } + + with open(status_file, "w") as f: + json.dump(status_data, f) + + script_content = self._build_experiment_script( + experiment_path, experiment_name, experiment_description, model, run_id + ) + + script_file = self.status_dir / f"{run_id}_script.py" + with open(script_file, "w") as f: + f.write(script_content) + + # Pass API key via environment so it never touches disk + child_env = {**os.environ, "OPENAI_API_KEY": api_key} + + try: + process = subprocess.Popen( + [sys.executable, str(script_file)], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + env=child_env, + ) + + status_data["status"] = "running" + status_data["pid"] = process.pid + self._processes[run_id] = process + + with open(status_file, "w") as f: + json.dump(status_data, f) + + return run_id + + except Exception as e: + status_data["status"] = "failed" + status_data["error"] = str(e) + + with open(status_file, "w") as f: + json.dump(status_data, f) + + st.error(f"Failed to start analysis: {str(e)}") + return None + + def run_scenario_analysis( + self, + experiment_path: Path, + scenario_label: str, + scenario_normalized: str, + model: str, + api_key: str, + ) -> Optional[str]: + """ + Launch a subprocess to run AI analysis on a single scenario. + + Args: + experiment_path: Path to the experiment directory + scenario_label: Display label of the scenario + scenario_normalized: Normalized label for file naming + model: OpenAI model to use + api_key: OpenAI API key + + Returns: + Run ID if launched successfully, None otherwise + """ + run_id = f"sc_analysis--{experiment_path.name}--{scenario_normalized}--{time.time_ns()}" + + status_file = self.status_dir / f"{run_id}.json" + status_data = { + "run_id": run_id, + "type": "analysis", + "analysis_type": "scenario", + "experiment_name": experiment_path.name, + "scenario_normalized": scenario_normalized, + "status": "starting", + "start_time": time.time(), + } + + with open(status_file, "w") as f: + json.dump(status_data, f) + + script_content = self._build_scenario_script( + experiment_path, scenario_label, scenario_normalized, model, run_id + ) + + script_file = self.status_dir / f"{run_id}_script.py" + with open(script_file, "w") as f: + f.write(script_content) + + # Pass API key via environment so it never touches disk + child_env = {**os.environ, "OPENAI_API_KEY": api_key} + + try: + process = subprocess.Popen( + [sys.executable, str(script_file)], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + env=child_env, + ) + + status_data["status"] = "running" + status_data["pid"] = process.pid + self._processes[run_id] = process + + with open(status_file, "w") as f: + json.dump(status_data, f) + + return run_id + + except Exception as e: + status_data["status"] = "failed" + status_data["error"] = str(e) + + with open(status_file, "w") as f: + json.dump(status_data, f) + + st.error(f"Failed to start analysis: {str(e)}") + return None + + def get_status(self, run_id: str) -> Optional[Dict[str, Any]]: + """Get the status of a running or completed analysis.""" + status_file = self.status_dir / f"{run_id}.json" + + if not status_file.exists(): + return None + + try: + with open(status_file, "r") as f: + return cast(Dict[str, Any], json.load(f)) + except (json.JSONDecodeError, IOError): + return None + + def cleanup_status(self, run_id: str) -> None: + """Clean up status files and reap child process for a completed run.""" + # Reap child process to prevent zombies + proc = self._processes.pop(run_id, None) + if proc is not None: + proc.poll() # Non-blocking reap + + status_file = self.status_dir / f"{run_id}.json" + script_file = self.status_dir / f"{run_id}_script.py" + + if status_file.exists(): + status_file.unlink() + if script_file.exists(): + script_file.unlink() + + def check_completion( + self, + experiment_name: str, + analysis_type: str, + scenario_normalized: str = "", + ) -> bool: + """ + Check if an analysis has completed by looking for the .md file. + + Args: + experiment_name: Name of the experiment + analysis_type: "experiment" or "scenario" + scenario_normalized: Normalized scenario label (for scenario type) + + Returns: + True if analysis file exists + """ + from spkmc.web.config import WebConfig + + config = WebConfig() + exp_path = config.get_experiments_path() / experiment_name + if analysis_type == "experiment": + return (exp_path / "analysis.md").exists() + return (exp_path / f"{scenario_normalized}_analysis.md").exists() + + def _build_experiment_script( + self, + experiment_path: Path, + experiment_name: str, + experiment_description: str, + model: str, + run_id: str, + ) -> str: + """Build a Python script to run experiment-level analysis.""" + # Use repr() for safe embedding — handles quotes, newlines, backslashes + exp_path_repr = repr(str(experiment_path)) + exp_name_repr = repr(experiment_name) + exp_desc_repr = repr(experiment_description) + model_repr = repr(model) + # Pass exact status file path so subprocess doesn't need prefix-glob discovery + status_file_repr = repr(str(self.status_dir / f"{run_id}.json")) + + return f""" +import sys +import json +import os +import re +import time +from pathlib import Path + +# Add package to path if needed +sys.path.insert(0, str(Path.cwd())) + +# OPENAI_API_KEY is passed via subprocess environment (never written to disk) + +# Exact status file path (set by runner before launching subprocess) +STATUS_FILE = {status_file_repr} + +def _write_status(status, error=None): + if STATUS_FILE is None: + return + try: + with open(STATUS_FILE, "r") as fh: + data = json.load(fh) + data["status"] = status + if error: + data["error"] = error + tmp = STATUS_FILE + ".tmp" + with open(tmp, "w") as fh: + json.dump(data, fh) + os.replace(tmp, STATUS_FILE) + except Exception: + pass + +def _normalize_label(label): + normalized = label.lower().strip() + normalized = re.sub(r"[\\s\\-]+", "_", normalized) + normalized = re.sub(r"[^\\w]", "", normalized) + return normalized + +_write_status("running") + +experiment_path = Path({exp_path_repr}) + +# Load all completed scenario results +results = [] +data_file = experiment_path / "data.json" +try: + with open(data_file, "r") as fh: + data = json.load(fh) + for sc in data.get("scenarios", []): + label = sc.get("label", "") + normalized = _normalize_label(label) + result_file = experiment_path / f"{{normalized}}.json" + if result_file.exists(): + with open(result_file, "r") as rfh: + results.append(json.load(rfh)) +except Exception as e: + _write_status("failed", error=f"Failed to load results: {{e}}") + sys.exit(1) + +if not results: + _write_status("failed", error="No completed scenarios to analyze") + sys.exit(1) + +# Initialize backup paths before the try block so the except handler +# can always reference them without risking UnboundLocalError. +old_analysis = experiment_path / "analysis.md" +backup_analysis = experiment_path / "analysis.md.bak" + +try: + from spkmc.analysis.ai_analyzer import AIAnalyzer + + # Preserve existing analysis as backup so a failed re-analysis doesn't + # permanently destroy the previous report. + if old_analysis.exists(): + old_analysis.rename(backup_analysis) + + analyzer = AIAnalyzer(model={model_repr}) + analysis_path = analyzer.analyze_experiment( + experiment_name={exp_name_repr}, + experiment_description={exp_desc_repr}, + results=results, + results_dir=experiment_path, + ) + + if analysis_path: + # Success — discard backup + if backup_analysis.exists(): + backup_analysis.unlink() + _write_status("completed") + print("Analysis completed successfully") + else: + # Restore backup when analysis returns None + if backup_analysis.exists(): + backup_analysis.rename(old_analysis) + _write_status("failed", error="Analysis returned None (may already exist)") + sys.exit(0) +except Exception as e: + # Restore backup on any failure + if backup_analysis.exists(): + backup_analysis.rename(old_analysis) + _write_status("failed", error=str(e)) + print(f"Analysis failed: {{e}}", file=sys.stderr) + sys.exit(1) +""" + + def _build_scenario_script( + self, + experiment_path: Path, + scenario_label: str, + scenario_normalized: str, + model: str, + run_id: str, + ) -> str: + """Build a Python script to run scenario-level analysis.""" + # Use repr() for safe embedding — handles quotes, newlines, backslashes + exp_path_repr = repr(str(experiment_path)) + label_repr = repr(scenario_label) + norm_repr = repr(scenario_normalized) + model_repr = repr(model) + # Pass exact status file path so subprocess doesn't need prefix-glob discovery + status_file_repr = repr(str(self.status_dir / f"{run_id}.json")) + + return f""" +import sys +import json +import os +import time +from pathlib import Path + +# Add package to path if needed +sys.path.insert(0, str(Path.cwd())) + +# OPENAI_API_KEY is passed via subprocess environment (never written to disk) + +# Exact status file path (set by runner before launching subprocess) +STATUS_FILE = {status_file_repr} + +def _write_status(status, error=None): + if STATUS_FILE is None: + return + try: + with open(STATUS_FILE, "r") as fh: + data = json.load(fh) + data["status"] = status + if error: + data["error"] = error + tmp = STATUS_FILE + ".tmp" + with open(tmp, "w") as fh: + json.dump(data, fh) + os.replace(tmp, STATUS_FILE) + except Exception: + pass + +_write_status("running") + +experiment_path = Path({exp_path_repr}) +result_file = experiment_path / ({norm_repr} + ".json") + +try: + with open(result_file, "r") as fh: + result_dict = json.load(fh) +except Exception as e: + _write_status("failed", error=f"Failed to load result: {{e}}") + sys.exit(1) + +# Initialize backup paths before the try block so the except handler +# can always reference them without risking UnboundLocalError. +old_analysis = experiment_path / ({norm_repr} + "_analysis.md") +backup_analysis = experiment_path / ({norm_repr} + "_analysis.md.bak") + +try: + from spkmc.analysis.ai_analyzer import AIAnalyzer + + # Preserve existing analysis as backup so a failed re-analysis doesn't + # permanently destroy the previous report. + if old_analysis.exists(): + old_analysis.rename(backup_analysis) + + analyzer = AIAnalyzer(model={model_repr}) + analysis_path = analyzer.analyze_scenario( + scenario_label={label_repr}, + result=result_dict, + results_dir=experiment_path, + ) + + if analysis_path: + # Success — discard backup + if backup_analysis.exists(): + backup_analysis.unlink() + _write_status("completed") + print("Analysis completed successfully") + else: + # Restore backup when analysis returns None + if backup_analysis.exists(): + backup_analysis.rename(old_analysis) + _write_status("failed", error="Analysis returned None (may already exist)") + sys.exit(0) +except Exception as e: + # Restore backup on any failure + if backup_analysis.exists(): + backup_analysis.rename(old_analysis) + _write_status("failed", error=str(e)) + print(f"Analysis failed: {{e}}", file=sys.stderr) + sys.exit(1) +""" + + +def poll_running_analyses() -> bool: + """ + Poll all running analyses and update session state. + + Reads status from files and marks completed/failed analyses. + Called by the scenario cards fragment every ~2 seconds. + + Returns: + True if any analysis transitioned to completed or failed (caller + should trigger a full page rerun so sections outside the fragment + re-render with updated state). + """ + if "analysis_runner" not in st.session_state: + st.session_state.analysis_runner = AnalysisRunner() + + runner: AnalysisRunner = st.session_state.analysis_runner + + from spkmc.web.state import SessionState + + running = st.session_state.get("running_analyses", {}) + changed = False + + for analysis_id, info in list(running.items()): + exp_name = info.get("experiment_name") + analysis_type = info.get("analysis_type", "experiment") + sc_normalized = info.get("scenario_normalized", "") + run_id = info.get("run_id", analysis_id) + + if not exp_name: + continue + + # Read status file + status = runner.get_status(run_id) + if status: + file_status = status.get("status", "running") + + # Check if status file reports completion. + # Do NOT fall back to check_completion() here — while status is + # "running", a stale .md from the previous run may still exist on + # disk (the subprocess deletes it shortly after starting). + if file_status == "completed": + SessionState.mark_analysis_completed(analysis_id) + label = "experiment" if analysis_type == "experiment" else sc_normalized + st.toast(f"Analysis complete: {label}") + runner.cleanup_status(run_id) + changed = True + continue + + # Check if status file reports failure + if file_status == "failed": + error_msg = status.get("error", "Unknown error") + SessionState.mark_analysis_failed(analysis_id, error_msg) + st.toast(f"Analysis failed: {error_msg}") + runner.cleanup_status(run_id) + changed = True + continue + + # Check if subprocess died without writing terminal status + if file_status == "running": + pid = status.get("pid") + if pid is not None and not psutil.pid_exists(pid): + # Process no longer exists -- check if output was written + if runner.check_completion(exp_name, analysis_type, sc_normalized): + SessionState.mark_analysis_completed(analysis_id) + label = "experiment" if analysis_type == "experiment" else sc_normalized + st.toast(f"Analysis complete: {label}") + else: + SessionState.mark_analysis_failed( + analysis_id, + "Analysis process exited unexpectedly", + ) + st.toast("Analysis failed: process exited unexpectedly") + runner.cleanup_status(run_id) + changed = True + continue + + # Fallback: check result file directly + elif runner.check_completion(exp_name, analysis_type, sc_normalized): + SessionState.mark_analysis_completed(analysis_id) + runner.cleanup_status(run_id) + changed = True + + return changed diff --git a/spkmc/web/app.py b/spkmc/web/app.py new file mode 100644 index 0000000..4bb7735 --- /dev/null +++ b/spkmc/web/app.py @@ -0,0 +1,126 @@ +""" +SPKMC Web Interface - Main Application. + +This is the entry point for the Streamlit web interface. It handles page routing, +sidebar navigation, and applies custom CSS styling. +""" + +from __future__ import annotations + +import streamlit as st + +from spkmc import __version__ +from spkmc.web.config import WebConfig +from spkmc.web.state import SessionState +from spkmc.web.styles import get_global_styles + +# Page configuration must be first Streamlit command +st.set_page_config( + page_title="SPKMC - Epidemic Simulation Manager", + page_icon="S", + layout="wide", + initial_sidebar_state="expanded", +) + + +def render_sidebar() -> None: + """Render the sidebar navigation with brand, nav items, and footer.""" + with st.sidebar: + current_page = SessionState.get_current_page() + + # ── Brand ──────────────────────────────── + st.markdown( + '
' + "
SPKMC
' + "
' + "Epidemic Simulation Manager
" + "
", + unsafe_allow_html=True, + ) + + # ── Navigation ────────────────────────── + if st.button( + "Experiments", + key="nav_experiments", + width="stretch", + type="primary" if current_page == "dashboard" else "secondary", + ): + SessionState.set_selected_experiment(None) + SessionState.set_current_page("dashboard") + st.rerun() + + if st.button( + "Preferences", + key="nav_settings", + width="stretch", + type="primary" if current_page == "settings" else "secondary", + ): + SessionState.set_selected_experiment(None) + SessionState.set_current_page("settings") + st.rerun() + + # ── Version footer (fixed to sidebar bottom) ── + st.markdown( + '", + unsafe_allow_html=True, + ) + + +def main() -> None: + """Main application entry point.""" + # Apply global styles + st.markdown(get_global_styles(), unsafe_allow_html=True) + + # Initialize session state + SessionState.init() + + # Load configuration + if "config" not in st.session_state: + st.session_state.config = WebConfig() + + # Restore running simulations and analyses from disk (survives refresh) + if not st.session_state.get("_sims_restored"): + SessionState.restore_running_simulations() + SessionState.restore_running_analyses() + st.session_state._sims_restored = True + + # Render sidebar + render_sidebar() + + # Page routing + current_page = SessionState.get_current_page() + + if current_page == "dashboard": + from spkmc.web.pages import dashboard + + if SessionState.get_selected_experiment(): + from spkmc.web.pages import experiment_detail + + experiment_detail.render() + else: + dashboard.render() + + elif current_page == "settings": + from spkmc.web.pages import settings + + settings.render() + + else: + from spkmc.web.pages import dashboard + + dashboard.render() + + +if __name__ == "__main__": + main() diff --git a/spkmc/web/components.py b/spkmc/web/components.py new file mode 100644 index 0000000..d82a8e1 --- /dev/null +++ b/spkmc/web/components.py @@ -0,0 +1,347 @@ +""" +Reusable UI components for the web interface. + +Provides form builders, cards, and other reusable widgets used across pages. +""" + +from __future__ import annotations + +from typing import Any, Dict, Optional + +import streamlit as st + +from spkmc.web.config import WebConfig + + +def status_badge(status: str) -> str: + """ + Generate HTML for a status badge. + + Args: + status: One of 'pending', 'running', 'completed', 'failed' + + Returns: + HTML string for the status badge + """ + status_map = { + "pending": ("Pending", "status-badge status-pending"), + "created": ("Created", "status-badge status-created"), + "running": ("Running", "status-badge status-running"), + "completed": ("Completed", "status-badge status-completed"), + "failed": ("Failed", "status-badge status-failed"), + } + + text, css_class = status_map.get(status, ("Unknown", "status-badge")) + return f'{text}' + + +def network_config_form(key_prefix: str = "network") -> Dict[str, Any]: + """ + Render dynamic form fields for network configuration. + + Args: + key_prefix: Unique prefix for form element keys + + Returns: + Dictionary of network configuration values + """ + config = st.session_state.config + + col1, col2 = st.columns(2) + + with col1: + network_type = st.selectbox( + "Network Type", + options=["er", "sf", "cg", "rrn"], + format_func=lambda x: { + "er": "Erdős-Rényi (Random)", + "sf": "Scale-Free", + "cg": "Complete Graph", + "rrn": "Random Regular", + }[x], + index=0, + key=f"{key_prefix}_type", + help="Network topology structure", + ) + + with col2: + nodes = st.number_input( + "Number of Nodes", + min_value=10, + max_value=100000, + value=config.get("default_nodes", 1000), + step=100, + key=f"{key_prefix}_nodes", + help="Size of the network (population)", + ) + + result = {"network": network_type, "nodes": nodes} + + # Network-specific parameters + if network_type in ["er", "sf", "rrn"]: + col1, col2 = st.columns(2) + with col1: + k_avg = st.number_input( + "Average Degree (k_avg)", + min_value=1.0, + max_value=float(nodes), + value=float(config.get("default_k_avg", 10.0)), + step=1.0, + key=f"{key_prefix}_k_avg", + help="Average number of connections per node", + ) + result["k_avg"] = k_avg + + if network_type == "sf": + with col2: + exponent = st.number_input( + "Power-law Exponent", + min_value=2.0, + max_value=5.0, + value=float(config.get("default_exponent", 2.5)), + step=0.1, + key=f"{key_prefix}_exponent", + help="Controls hub concentration (lower = more hubs)", + ) + result["exponent"] = exponent + + return result + + +def distribution_config_form(key_prefix: str = "distribution") -> Dict[str, Any]: + """ + Render dynamic form fields for distribution configuration. + + Args: + key_prefix: Unique prefix for form element keys + + Returns: + Dictionary of distribution configuration values + """ + config = st.session_state.config + + col1, col2 = st.columns(2) + + with col1: + dist_type = st.selectbox( + "Distribution Type", + options=["gamma", "exponential"], + format_func=lambda x: x.capitalize(), + index=0, + key=f"{key_prefix}_type", + help="Recovery time distribution", + ) + + with col2: + lambda_param = st.number_input( + "Infection Rate (λ)", + min_value=0.01, + max_value=10.0, + value=float(config.get("default_lambda", 1.0)), + step=0.1, + key=f"{key_prefix}_lambda", + help="Transmission rate along edges", + ) + + result = {"distribution": dist_type, "lambda": lambda_param} + + # Distribution-specific parameters + col1, col2 = st.columns(2) + + if dist_type == "gamma": + with col1: + shape = st.number_input( + "Shape Parameter", + min_value=0.1, + max_value=10.0, + value=float(config.get("default_shape", 2.0)), + step=0.1, + key=f"{key_prefix}_shape", + help="Controls recovery time distribution shape", + ) + with col2: + scale = st.number_input( + "Scale Parameter", + min_value=0.1, + max_value=10.0, + value=float(config.get("default_scale", 1.0)), + step=0.1, + key=f"{key_prefix}_scale", + help="Controls recovery time scale", + ) + result["shape"] = shape + result["scale"] = scale + + elif dist_type == "exponential": + with col1: + mu = st.number_input( + "Recovery Rate (μ)", + min_value=0.01, + max_value=10.0, + value=float(config.get("default_mu", 1.0)), + step=0.1, + key=f"{key_prefix}_mu", + help="Exponential recovery rate", + ) + result["mu"] = mu + + return result + + +def simulation_params_form(key_prefix: str = "simulation") -> Dict[str, Any]: + """ + Render form fields for simulation parameters. + + Args: + key_prefix: Unique prefix for form element keys + + Returns: + Dictionary of simulation configuration values + """ + config = st.session_state.config + + col1, col2, col3 = st.columns(3) + + with col1: + samples = st.number_input( + "Samples", + min_value=1, + max_value=10000, + value=config.get("default_samples", 50), + step=10, + key=f"{key_prefix}_samples", + help="Monte Carlo samples per run", + ) + + with col2: + num_runs = st.number_input( + "Number of Runs", + min_value=1, + max_value=100, + value=config.get("default_num_runs", 2), + step=1, + key=f"{key_prefix}_num_runs", + help="Independent runs for error estimation", + ) + + with col3: + initial_perc = ( + st.number_input( + "Initial Infected (%)", + min_value=0.01, + max_value=100.0, + value=float(config.get("default_initial_perc", 0.01)) * 100, + step=0.1, + key=f"{key_prefix}_initial_perc", + help="Percentage of initially infected nodes", + ) + / 100.0 + ) + + col1, col2 = st.columns(2) + + with col1: + t_max = st.number_input( + "Max Time", + min_value=0.1, + max_value=1000.0, + value=float(config.get("default_t_max", 10.0)), + step=1.0, + key=f"{key_prefix}_t_max", + help="Simulation duration", + ) + + with col2: + steps = st.number_input( + "Time Steps", + min_value=10, + max_value=10000, + value=config.get("default_steps", 100), + step=10, + key=f"{key_prefix}_steps", + help="Number of time points to record", + ) + + return { + "samples": samples, + "num_runs": num_runs, + "initial_perc": initial_perc, + "t_max": t_max, + "steps": steps, + } + + +def result_metric_cards(result_dict: Dict[str, Any]) -> None: + """ + Display key metrics from a simulation result. + + Args: + result_dict: Result dictionary with S_val, I_val, R_val arrays + """ + import numpy as np + + if "I_val" not in result_dict or "time" not in result_dict: + st.warning("No result data available") + return + + I_val = np.array(result_dict["I_val"]) + R_val = np.array(result_dict["R_val"]) + time = np.array(result_dict["time"]) + + peak_infected = float(np.max(I_val)) + peak_time = float(time[np.argmax(I_val)]) + final_recovered = float(R_val[-1]) + + col1, col2, col3 = st.columns(3) + + with col1: + st.metric( + label="Peak Infected", + value=f"{peak_infected:.1%}", + help="Maximum proportion of infected individuals", + ) + + with col2: + st.metric( + label="Peak Time", + value=f"{peak_time:.2f}", + help="Time at which peak infection occurred", + ) + + with col3: + st.metric( + label="Final Epidemic Size", + value=f"{final_recovered:.1%}", + help="Proportion of population that was infected", + ) + + +def experiment_status_badge(experiment: Any) -> str: + """ + Generate status badge for an experiment based on its scenarios. + + Args: + experiment: Experiment object with scenarios + + Returns: + HTML string for the experiment status badge + """ + # Guard against None path and empty scenario list + if experiment.path is None or not experiment.scenarios: + return status_badge("pending") + + # Check scenario statuses + has_results = any( + (experiment.path / f"{s.normalized_label}.json").exists() for s in experiment.scenarios + ) + + all_complete = all( + (experiment.path / f"{s.normalized_label}.json").exists() for s in experiment.scenarios + ) + + if all_complete: + return status_badge("completed") + elif has_results: + return status_badge("running") # Partial completion + else: + return status_badge("pending") diff --git a/spkmc/web/config.py b/spkmc/web/config.py new file mode 100644 index 0000000..a4ab783 --- /dev/null +++ b/spkmc/web/config.py @@ -0,0 +1,182 @@ +""" +Web interface configuration management. + +Handles loading and saving web preferences (stored as JSON) and reading secrets +from Streamlit's secrets.toml file. +""" + +from __future__ import annotations + +import json +import os +import re +from pathlib import Path +from typing import Any, Dict, Optional, cast + +import streamlit as st + +# In-memory override for the API key, used to bypass st.secrets caching. +# st.secrets is process-cached and has no public invalidation API; writing +# to secrets.toml does not update the in-memory singleton. After +# set_openai_api_key(), all subsequent reads in this process see the new +# value via this override. On process restart the override resets to None +# and the freshly-loaded st.secrets provides the correct value from disk. +_api_key_override: Optional[str] = None + + +class WebConfig: + """Manages web interface configuration and secrets.""" + + CONFIG_FILE: Path = Path( + os.environ.get( + "SPKMC_WEB_CONFIG_FILE", + str(Path.home() / ".spkmc" / "web_config.json"), + ) + ) + + # Default configuration values + DEFAULTS = { + "data_directory": "data", + "experiments_directory": "experiments", + "theme": "light", + "chart_height": 500, + "chart_color_s": "#4477AA", + "chart_color_i": "#EE6677", + "chart_color_r": "#228833", + "chart_template": "plotly_white", + "default_network_type": "er", + "default_distribution": "gamma", + "default_nodes": 1000, + "default_k_avg": 10.0, + "default_samples": 50, + "default_num_runs": 2, + "default_initial_perc": 0.01, + "default_t_max": 10.0, + "default_steps": 100, + "default_shape": 2.0, + "default_scale": 1.0, + "default_mu": 1.0, + "default_lambda": 1.0, + "default_exponent": 2.5, + "ai_model": "gpt-4o-mini", + } + + def __init__(self) -> None: + """Initialize configuration manager.""" + self.config: Dict[str, Any] = {} + self.load() + + def load(self) -> None: + """Load configuration from JSON file, creating with defaults if not found.""" + if self.CONFIG_FILE.exists(): + try: + with open(self.CONFIG_FILE, "r") as f: + loaded = json.load(f) + # Merge with defaults to ensure all keys exist + merged = {**self.DEFAULTS, **loaded} + # Coerce types to match defaults (JSON may deserialize + # 10.0 as int 10, which causes StreamlitMixedNumericTypesError) + for key, default_val in self.DEFAULTS.items(): + if key in merged: + if isinstance(default_val, float) and isinstance(merged[key], int): + merged[key] = float(merged[key]) + elif isinstance(default_val, int) and isinstance(merged[key], float): + merged[key] = int(merged[key]) + self.config = merged + except (json.JSONDecodeError, IOError): + # If file is corrupted, start with defaults + self.config = self.DEFAULTS.copy() + else: + # Create config directory if it doesn't exist + self.CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True) + self.config = self.DEFAULTS.copy() + self.save() + + def save(self) -> None: + """Save current configuration to JSON file.""" + from spkmc.web import atomic_json_write + + self.CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True) + atomic_json_write(self.CONFIG_FILE, self.config) + + def get(self, key: str, default: Any = None) -> Any: + """Get a configuration value.""" + return self.config.get(key, default) + + def set(self, key: str, value: Any) -> None: + """Set a configuration value and save.""" + self.config[key] = value + self.save() + + def update(self, updates: Dict[str, Any]) -> None: + """Update multiple configuration values at once.""" + self.config.update(updates) + self.save() + + @staticmethod + def get_openai_api_key() -> Optional[str]: + """ + Get OpenAI API key. + + Returns the in-memory override (set by ``set_openai_api_key``) if + present, otherwise falls back to ``st.secrets``. + + Returns: + API key if found, None otherwise + """ + if _api_key_override is not None: + return _api_key_override + try: + return cast(Optional[str], st.secrets.get("OPENAI_API_KEY", None)) + except (FileNotFoundError, KeyError): + return None + + @staticmethod + def set_openai_api_key(api_key: str) -> None: + """ + Set OpenAI API key in Streamlit secrets and update in-memory cache. + + Writes to ``.streamlit/secrets.toml`` for persistence across restarts, + and updates ``_api_key_override`` so subsequent reads in this process + see the new value immediately (st.secrets is process-cached with no + public invalidation API). + + Args: + api_key: The OpenAI API key to save + """ + global _api_key_override + _api_key_override = api_key + secrets_file = Path(".streamlit") / "secrets.toml" + secrets_file.parent.mkdir(exist_ok=True) + + # Escape the value for TOML (double-quote string) + escaped_value = api_key.replace("\\", "\\\\").replace('"', '\\"') + new_line = f'OPENAI_API_KEY = "{escaped_value}"' + + # Pattern that matches an existing OPENAI_API_KEY assignment + key_pattern = re.compile(r"^OPENAI_API_KEY\s*=\s*.*$", re.MULTILINE) + + if secrets_file.exists(): + content = secrets_file.read_text() + if key_pattern.search(content): + # Replace existing key in-place, preserving all other content + content = key_pattern.sub(new_line, content) + else: + # Append to end, ensuring a leading newline + if content and not content.endswith("\n"): + content += "\n" + content += new_line + "\n" + secrets_file.write_text(content) + else: + # Create new file with just this key + secrets_file.write_text( + "# Streamlit secrets for SPKMC web interface\n" + new_line + "\n" + ) + + def get_data_path(self) -> Path: + """Get the data directory path.""" + return Path(self.get("data_directory", "data")) + + def get_experiments_path(self) -> Path: + """Get the experiments directory path.""" + return Path(self.get("experiments_directory", "experiments")) diff --git a/spkmc/web/pages/__init__.py b/spkmc/web/pages/__init__.py new file mode 100644 index 0000000..9522f7b --- /dev/null +++ b/spkmc/web/pages/__init__.py @@ -0,0 +1,7 @@ +""" +SPKMC Web Interface Pages. + +This package contains the individual page modules for the Streamlit web interface. +""" + +__all__ = ["dashboard", "experiment_detail", "settings"] diff --git a/spkmc/web/pages/dashboard.py b/spkmc/web/pages/dashboard.py new file mode 100644 index 0000000..c54c05b --- /dev/null +++ b/spkmc/web/pages/dashboard.py @@ -0,0 +1,702 @@ +""" +Dashboard page - main experiments list view. + +Shows all experiments, summary stats, and provides "Create Experiment" functionality. +""" + +from __future__ import annotations + +from datetime import datetime, timedelta +from pathlib import Path +from typing import Any, Dict, List + +import streamlit as st + +from spkmc.io.experiments import ExperimentManager +from spkmc.models import Experiment, ExperimentConfig, ScenarioOverride +from spkmc.web.components import ( + distribution_config_form, + experiment_status_badge, + network_config_form, + simulation_params_form, +) +from spkmc.web.runner import poll_running_simulations +from spkmc.web.state import SessionState +from spkmc.web.styles import ( + ICONS, + empty_state, + experiment_card, + page_header, + section_header, + stat_card, +) + + +def render() -> None: + """Render the dashboard page.""" + # Page header + st.markdown( + page_header("Experiments", subtitle="Manage and run SPKMC epidemic simulation experiments"), + unsafe_allow_html=True, + ) + + # Load experiments + config = st.session_state.config + exp_manager = ExperimentManager(str(config.get_experiments_path())) + experiments = exp_manager.list_experiments() + + # Summary stats row with beautiful cards + render_summary_stats(experiments) + + # Spacer between stats and experiments + st.markdown('
', unsafe_allow_html=True) + + # Experiments list or empty state + if experiments: + render_experiments_list(experiments) + else: + render_empty_state_ui() + + +def render_summary_stats(experiments: List[Experiment]) -> None: + """Render beautiful summary statistics cards.""" + total_experiments = len(experiments) + total_scenarios = sum(len(exp.scenarios) for exp in experiments) + + # Count completed scenarios + completed_scenarios = 0 + for exp in experiments: + if exp.path is None: + continue + for scenario in exp.scenarios: + result_file = exp.path / f"{scenario.normalized_label}.json" + if result_file.exists(): + completed_scenarios += 1 + + # Recent activity (last modified experiment) + last_activity = "Never" + if experiments: + most_recent = max( + (exp.path for exp in experiments if exp.path is not None and exp.path.exists()), + key=lambda p: p.stat().st_mtime, + default=None, + ) + if most_recent: + last_modified = datetime.fromtimestamp(most_recent.stat().st_mtime) + last_activity = last_modified.strftime("%Y-%m-%d %H:%M") + + # Use columns for grid layout + col1, col2, col3, col4 = st.columns(4) + + with col1: + st.markdown( + stat_card("Total Experiments", str(total_experiments), ICONS["flask"]), + unsafe_allow_html=True, + ) + + with col2: + st.markdown( + stat_card("Total Scenarios", str(total_scenarios), ICONS["file"]), + unsafe_allow_html=True, + ) + + with col3: + st.markdown( + stat_card("Completed Scenarios", str(completed_scenarios), ICONS["check"]), + unsafe_allow_html=True, + ) + + with col4: + st.markdown( + stat_card("Last Activity", last_activity, ICONS["clock"]), unsafe_allow_html=True + ) + + +def render_experiments_list(experiments: List[Experiment]) -> None: + """Render header + create button, then delegate cards to polling fragment.""" + col_header, col_create = st.columns([8, 2], vertical_alignment="bottom") + with col_header: + st.markdown( + section_header("All Experiments"), + unsafe_allow_html=True, + ) + with col_create: + if st.button( + "Create Experiment", + type="primary", + width="stretch", + key="btn_create_exp", + ): + show_create_experiment_modal() + + _live_experiment_cards(experiments) + + +@st.fragment(run_every=timedelta(seconds=2)) +def _live_experiment_cards(experiments: List[Experiment]) -> None: + """Fragment that polls running simulations and re-renders experiment cards.""" + poll_running_simulations() + + for idx, exp in enumerate(experiments): + if exp.path is None: + continue + + exp_path = exp.path + scenario_count = len(exp.scenarios) + + # Count completed scenarios + completed = sum( + 1 for s in exp.scenarios if (exp_path / f"{s.normalized_label}.json").exists() + ) + + # Get last modified time + if exp_path.exists(): + last_mod = datetime.fromtimestamp(exp_path.stat().st_mtime) + last_modified = last_mod.strftime("%Y-%m-%d %H:%M") + else: + last_modified = "Unknown" + + # Determine status by checking actual running simulations + scenario_statuses = [ + SessionState.get_simulation_status(f"sim--{exp_path.name}--{s.normalized_label}") + for s in exp.scenarios + ] + any_running = "running" in scenario_statuses + any_failed = "failed" in scenario_statuses + + if any_running: + status = "running" + elif scenario_count > 0 and completed == scenario_count: + status = "complete" + elif any_failed: + status = "failed" + else: + status = "pending" + + # Clickable card container: invisible button overlays the card HTML + with st.container(key=f"exp_card_{idx}"): + st.markdown( + experiment_card( + name=exp.name, + description=exp.description or "No description provided", + scenarios_complete=completed, + scenarios_total=scenario_count, + last_run=last_modified, + status=status, + ), + unsafe_allow_html=True, + ) + if st.button("select", key=f"exp_btn_{idx}"): + SessionState.set_selected_experiment(exp_path.name) + st.rerun() + + +def render_empty_state_ui() -> None: + """Render beautiful empty state when no experiments exist.""" + st.markdown( + empty_state( + title="No experiments yet", + message="Create your first experiment to start running epidemic simulations on networks. " + "Each experiment can contain multiple scenarios with different parameters.", + ), + unsafe_allow_html=True, + ) + + # Add some spacing and the CTA button + st.markdown('
', unsafe_allow_html=True) + col1, col2, col3 = st.columns([1, 1, 1]) + with col2: + if st.button("Create Your First Experiment", type="primary", width="stretch"): + show_create_experiment_modal() + st.markdown("
", unsafe_allow_html=True) + + +def _init_scenario_state() -> None: + """Initialize session state for scenario list if not present.""" + if "create_exp_scenarios" not in st.session_state: + st.session_state.create_exp_scenarios = [] + st.session_state.create_exp_sc_counter = 0 + + +def _add_scenario() -> None: + """Append a new scenario to the session state list.""" + counter = st.session_state.create_exp_sc_counter + sc_id = f"sc_{counter}" + st.session_state.create_exp_scenarios.append( + { + "id": sc_id, + "label": "", + } + ) + st.session_state.create_exp_sc_counter = counter + 1 + st.session_state.create_exp_last_added = sc_id + + +def _remove_scenario(sc_id: str) -> None: + """Remove a scenario from the session state list by its ID.""" + st.session_state.create_exp_scenarios = [ + s for s in st.session_state.create_exp_scenarios if s["id"] != sc_id + ] + + +def _render_scenario( + sc_id: str, + default_label: str, + index: int, + can_remove: bool, +) -> None: + """ + Render a single scenario expander with override toggles and forms. + + Args: + sc_id: Unique scenario ID (e.g. "sc_0") + default_label: Default label text for this scenario + index: Display index (1-based) + can_remove: Whether the remove button should be enabled + """ + label_key = f"{sc_id}_label" + current_label = st.session_state.get(label_key, default_label) + display_label = current_label if current_label else "Untitled" + header = f"Scenario {index}: {display_label}" + + last_added = st.session_state.get("create_exp_last_added") + with st.expander(header, expanded=(sc_id == last_added)): + st.text_input( + "Label *", + value=default_label, + key=label_key, + placeholder="e.g., High Infection Rate", + help="Required. Name for this scenario", + ) + + # Override toggle checkboxes + col_net, col_dist, col_sim = st.columns(3) + with col_net: + override_net = st.checkbox( + "Override Network", + key=f"{sc_id}_override_net", + ) + with col_dist: + override_dist = st.checkbox( + "Override Distribution", + key=f"{sc_id}_override_dist", + ) + with col_sim: + override_sim = st.checkbox( + "Override Simulation", + key=f"{sc_id}_override_sim", + ) + + # Render override forms when toggled + if override_net: + st.markdown("---") + st.caption("Network Overrides") + network_config_form(key_prefix=f"{sc_id}_net") + + if override_dist: + st.markdown("---") + st.caption("Distribution Overrides") + distribution_config_form(key_prefix=f"{sc_id}_dist") + + if override_sim: + st.markdown("---") + st.caption("Simulation Overrides") + simulation_params_form(key_prefix=f"{sc_id}_sim") + + if not (override_net or override_dist or override_sim): + st.caption("Using all global defaults") + + # Remove button + if can_remove: + st.button( + "Remove", + key=f"{sc_id}_remove", + on_click=_remove_scenario, + args=(sc_id,), + ) + + +def _collect_scenario_overrides( + sc_id: str, + global_params: Dict[str, Any], +) -> Dict[str, Any]: + """ + Collect override dict for one scenario by reading widget state. + + Only includes keys whose values actually differ from global_params. + + Args: + sc_id: Unique scenario ID + global_params: The global parameter dict to diff against + + Returns: + Dict with "label" and any differing override keys + """ + result: Dict[str, Any] = { + "label": st.session_state.get(f"{sc_id}_label", "Untitled"), + } + + override_net = st.session_state.get(f"{sc_id}_override_net", False) + override_dist = st.session_state.get(f"{sc_id}_override_dist", False) + override_sim = st.session_state.get(f"{sc_id}_override_sim", False) + + if override_net: + net_params = _read_form_values_network(f"{sc_id}_net") + for key, value in net_params.items(): + if global_params.get(key) != value: + result[key] = value + + if override_dist: + dist_params = _read_form_values_distribution(f"{sc_id}_dist") + for key, value in dist_params.items(): + if global_params.get(key) != value: + result[key] = value + + if override_sim: + sim_params = _read_form_values_simulation(f"{sc_id}_sim") + for key, value in sim_params.items(): + if global_params.get(key) != value: + result[key] = value + + return result + + +def _read_form_values_network(key_prefix: str) -> Dict[str, Any]: + """Read network form widget values from session state. + + Only reads conditional parameters (k_avg, exponent) when the + current network type actually uses them, avoiding stale session + state from previously-rendered conditional widgets. + """ + result: Dict[str, Any] = {} + network_type = st.session_state.get(f"{key_prefix}_type") + if network_type is not None: + result["network"] = network_type + nodes = st.session_state.get(f"{key_prefix}_nodes") + if nodes is not None: + result["nodes"] = nodes + # k_avg only exists for er, sf, rrn + if network_type in ("er", "sf", "rrn"): + k_avg = st.session_state.get(f"{key_prefix}_k_avg") + if k_avg is not None: + result["k_avg"] = k_avg + # exponent only exists for sf + if network_type == "sf": + exponent = st.session_state.get(f"{key_prefix}_exponent") + if exponent is not None: + result["exponent"] = exponent + return result + + +def _read_form_values_distribution(key_prefix: str) -> Dict[str, Any]: + """Read distribution form widget values from session state. + + Only reads conditional parameters (shape/scale for gamma, mu for + exponential) when the current distribution type uses them. + """ + result: Dict[str, Any] = {} + dist_type = st.session_state.get(f"{key_prefix}_type") + if dist_type is not None: + result["distribution"] = dist_type + lambda_val = st.session_state.get(f"{key_prefix}_lambda") + if lambda_val is not None: + result["lambda"] = lambda_val + # shape and scale only exist for gamma + if dist_type == "gamma": + shape = st.session_state.get(f"{key_prefix}_shape") + if shape is not None: + result["shape"] = shape + scale = st.session_state.get(f"{key_prefix}_scale") + if scale is not None: + result["scale"] = scale + # mu only exists for exponential + elif dist_type == "exponential": + mu = st.session_state.get(f"{key_prefix}_mu") + if mu is not None: + result["mu"] = mu + return result + + +def _read_form_values_simulation(key_prefix: str) -> Dict[str, Any]: + """Read simulation form widget values from session state.""" + result: Dict[str, Any] = {} + samples = st.session_state.get(f"{key_prefix}_samples") + if samples is not None: + result["samples"] = samples + num_runs = st.session_state.get(f"{key_prefix}_num_runs") + if num_runs is not None: + result["num_runs"] = num_runs + initial_perc = st.session_state.get(f"{key_prefix}_initial_perc") + if initial_perc is not None: + # The widget stores percentage (0-100), convert back to fraction + result["initial_perc"] = initial_perc / 100.0 + t_max = st.session_state.get(f"{key_prefix}_t_max") + if t_max is not None: + result["t_max"] = t_max + steps = st.session_state.get(f"{key_prefix}_steps") + if steps is not None: + result["steps"] = steps + return result + + +def _cleanup_scenario_state() -> None: + """Remove all dialog-related keys from session state after dialog closes.""" + prefixes = ("sc_", "create_network_", "create_dist_", "create_sim_") + explicit_keys = ( + "create_exp_scenarios", + "create_exp_sc_counter", + "create_exp_baseline", + "create_exp_last_added", + ) + keys_to_remove = [k for k in st.session_state if k.startswith(prefixes) or k in explicit_keys] + for key in keys_to_remove: + del st.session_state[key] + + +@st.dialog("Create New Experiment", width="large") +def show_create_experiment_modal() -> None: + """Show the create experiment modal dialog.""" + _init_scenario_state() + + st.markdown("### Experiment Configuration") + + # Basic info + st.subheader("Basic Information") + name = st.text_input( + "Experiment Name", + placeholder="e.g., Network Comparison Study", + help="Descriptive name for your experiment", + ) + + description = st.text_area( + "Description", + placeholder="What are you testing?", + help="Brief description of the experiment's purpose", + ) + + # Global parameters + st.subheader("Global Parameters") + st.caption("These parameters will be inherited by all scenarios (can be overridden)") + + with st.expander("Network Configuration", expanded=True): + network_params = network_config_form(key_prefix="create_network") + + with st.expander("Distribution Configuration", expanded=True): + dist_params = distribution_config_form(key_prefix="create_dist") + + with st.expander("Simulation Parameters", expanded=True): + sim_params = simulation_params_form(key_prefix="create_sim") + + # Scenarios section + st.subheader("Scenarios") + st.caption( + "Each scenario inherits the global parameters above. " + "Override specific values to create different conditions." + ) + + include_baseline = st.checkbox( + "Include Baseline scenario", + value=True, + key="create_exp_baseline", + help="Adds a Baseline scenario using all global defaults", + ) + + # Show baseline preview when checkbox is checked + if include_baseline: + with st.expander("Scenario 1: Baseline", expanded=False): + st.caption("Uses all global defaults (no overrides)") + + # Render each scenario + scenario_list = st.session_state.create_exp_scenarios + offset = 1 if include_baseline else 0 + + for idx, sc in enumerate(scenario_list): + _render_scenario( + sc_id=sc["id"], + default_label=sc["label"], + index=idx + 1 + offset, + can_remove=True, + ) + + # Add Scenario button (below all scenarios) + btn_col1, btn_col2 = st.columns([3, 1]) + with btn_col2: + st.button( + "+ Add Scenario", + on_click=_add_scenario, + width="stretch", + ) + + # Action buttons + st.divider() + spacer, col_cancel, col_create = st.columns([6, 2, 2]) + + with col_cancel: + if st.button("Cancel", width="stretch"): + _cleanup_scenario_state() + st.rerun() + + with col_create: + if st.button("Create Experiment", type="primary", width="stretch"): + if not name: + st.error("Please provide an experiment name") + return + + # Validate all scenario labels are non-empty + for sc in scenario_list: + sc_label = st.session_state.get(f"{sc['id']}_label", "").strip() + if not sc_label: + st.error("All scenarios must have a label.") + return + + # Validate normalized label uniqueness and non-emptiness + from spkmc.models.scenario import Scenario as ScenarioModel + + seen_normalized: Dict[str, str] = {} + if include_baseline: + seen_normalized["baseline"] = "Baseline" + for sc in scenario_list: + sc_label = st.session_state.get(f"{sc['id']}_label", "").strip() + norm = ScenarioModel.normalize_label(sc_label) + if not norm: + st.error( + f"Scenario label '{sc_label}' normalizes to an empty filename. " + "Use a label with at least one alphanumeric character." + ) + return + if norm in seen_normalized: + st.error( + f"Scenario labels '{seen_normalized[norm]}' and '{sc_label}' " + f"conflict (both normalize to '{norm}'). Use distinct names." + ) + return + seen_normalized[norm] = sc_label + + global_params = {**network_params, **dist_params, **sim_params} + + # Collect scenario overrides + scenarios = [ + _collect_scenario_overrides(sc["id"], global_params) for sc in scenario_list + ] + + # Prepend baseline if checked + if include_baseline: + scenarios.insert(0, {"label": "Baseline"}) + + if not scenarios: + st.error("Add at least one scenario or include baseline") + return + + # Create the experiment + try: + exp_path = create_experiment( + name=name, + description=description, + global_params=global_params, + scenarios=scenarios, + ) + + # Auto-run baseline scenario + if include_baseline: + _auto_run_baseline(exp_path) + + _cleanup_scenario_state() + st.success(f"Experiment '{name}' created successfully!") + st.rerun() + except Exception as e: + st.error(f"Failed to create experiment: {str(e)}") + + +def create_experiment( + name: str, + description: str, + global_params: Dict[str, Any], + scenarios: List[Dict[str, Any]], +) -> Path: + """ + Create a new experiment in the experiments directory. + + Args: + name: Experiment name + description: Experiment description + global_params: Global parameters dictionary + scenarios: List of scenario override dictionaries + + Returns: + Path to the created experiment directory + """ + config = st.session_state.config + exp_dir = config.get_experiments_path() + + # Create normalized directory name + from spkmc.models.scenario import Scenario + + dir_name = Scenario.normalize_label(name) + if not dir_name: + raise ValueError( + f"Experiment name '{name}' normalizes to an empty directory name. " + "Use a name with at least one alphanumeric character." + ) + exp_path = Path(exp_dir) / dir_name + + # Check if already exists + if exp_path.exists(): + raise ValueError(f"Experiment '{dir_name}' already exists") + + # Create directory + exp_path.mkdir(parents=True, exist_ok=True) + + # Build experiment config + config_dict = { + "name": name, + "description": description, + "parameters": global_params, + "scenarios": scenarios, + } + + # Write data.json (atomic to prevent corruption on crash) + from spkmc.web import atomic_json_write + + data_file = exp_path / "data.json" + atomic_json_write(data_file, config_dict) + + return exp_path + + +def _auto_run_baseline(exp_path: Path) -> None: + """Start the baseline scenario run for a freshly created experiment. + + Args: + exp_path: Path to the experiment directory + """ + from spkmc.web.runner import SimulationRunner + + config = st.session_state.config + exp_manager = ExperimentManager(str(config.get_experiments_path())) + + try: + experiment = exp_manager.load_experiment(exp_path.name) + except Exception: + return + + # Find the baseline scenario + baseline = None + for sc in experiment.scenarios: + if sc.label == "Baseline": + baseline = sc + break + + if baseline is None: + return + + if "simulation_runner" not in st.session_state: + st.session_state.simulation_runner = SimulationRunner() + runner: SimulationRunner = st.session_state.simulation_runner + + run_id = runner.run_scenario(experiment, baseline, show_progress=True) + if run_id: + status_info = runner.get_status(run_id) + if status_info and experiment.path is not None: + scenario_id = f"sim--{experiment.path.name}--{baseline.normalized_label}" + status_info["run_id"] = run_id + SessionState.add_running_simulation(scenario_id, status_info) diff --git a/spkmc/web/pages/experiment_detail.py b/spkmc/web/pages/experiment_detail.py new file mode 100644 index 0000000..7e61e25 --- /dev/null +++ b/spkmc/web/pages/experiment_detail.py @@ -0,0 +1,1989 @@ +""" +Experiment detail page - view and manage a single experiment. + +Shows experiment overview, scenario cards, scenario detail modals, +and comparison functionality. +""" + +from __future__ import annotations + +import base64 +import html as _html +import json +from datetime import timedelta +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import pandas as pd +import streamlit as st + +from spkmc.io.data_manager import DataManager +from spkmc.io.experiments import ExperimentManager +from spkmc.models import Experiment, Scenario +from spkmc.web.analysis_runner import AnalysisRunner, poll_running_analyses +from spkmc.web.components import result_metric_cards +from spkmc.web.config import WebConfig +from spkmc.web.plotting import create_comparison_figure, create_sir_figure +from spkmc.web.runner import SimulationRunner, poll_running_simulations +from spkmc.web.state import SessionState +from spkmc.web.styles import ( + COLORS, + FONTS, + _dedent, + circular_progress_html, + params_card, + scenario_card, + section_header, +) + +# SVG icons used on this page +_ICON_NETWORK = ( + '' + '' + '' + '' +) +_ICON_DIST = ( + '' + '' +) +_ICON_SIM = ( + '' +) +_ICON_AI = ( + '' + '' + '' + "" +) + + +def _values_equal(a: Any, b: Any) -> bool: + """Compare two values with numeric type normalization. + + Handles the int/float mismatch that occurs when Pydantic coerces + JSON integers but Python code uses floats (e.g. 10 vs 10.0). + """ + if a is None or b is None: + return a is b + if isinstance(a, (int, float)) and isinstance(b, (int, float)): + return float(a) == float(b) + return bool(a == b) + + +def _download_anchor(data: bytes, filename: str, mime: str, label: str = "Download") -> str: + """ + Return an HTML anchor styled as a button using a base64 data URI. + + Avoids Streamlit's media file storage, which causes KeyError when a + st.download_button is re-rendered inside a popover (stale file ID). + """ + b64 = base64.b64encode(data).decode() + safe_filename = _html.escape(filename, quote=True) + safe_label = _html.escape(label) + return ( + f'' + f"⬇ {safe_label}" + ) + + +def render() -> None: + """Render the experiment detail page.""" + exp_name = SessionState.get_selected_experiment() + if not exp_name: + st.error("No experiment selected") + if st.button("Back to Dashboard", key="detail_back_err"): + SessionState.set_selected_experiment(None) + st.rerun() + return + + config = st.session_state.config + exp_manager = ExperimentManager(str(config.get_experiments_path())) + + try: + experiment = exp_manager.load_experiment(exp_name) + except Exception as e: + st.error(f"Failed to load experiment: {str(e)}") + if st.button("Back to Dashboard", key="detail_back_err2"): + SessionState.set_selected_experiment(None) + st.rerun() + return + + exp_path = experiment.path + assert exp_path is not None + + # -- Header -- + with st.container(key="detail_back"): + if st.button( + "Back", + key="detail_back_btn", + icon=":material/arrow_back:", + ): + SessionState.set_selected_experiment(None) + st.rerun() + + api_key = WebConfig.get_openai_api_key() + exp_analysis_id = f"exp_analysis--{exp_path.name}" + analysis_status = SessionState.get_analysis_status(exp_analysis_id) + analysis_running = analysis_status == "running" + analysis_file = exp_path / "analysis.md" + has_analysis = analysis_file.exists() + + if analysis_running: + ai_label = "Analyzing..." + ai_icon = ":material/sync:" + ai_disabled = True + ai_help = "Analysis in progress..." + elif has_analysis: + ai_label = "Re-analyze" + ai_icon = ":material/auto_awesome:" + ai_disabled = not api_key + ai_help = "Re-generate AI analysis" if api_key else "Set OpenAI API key in Preferences" + else: + ai_label = "Analyze experiment" + ai_icon = ":material/auto_awesome:" + ai_disabled = not api_key + ai_help = "Generate AI analysis" if api_key else "Set OpenAI API key in Preferences" + + col_title, col_ai = st.columns([8, 2]) + with col_title: + st.markdown( + _dedent(f""" +
+

+{experiment.name}

+
+"""), + unsafe_allow_html=True, + ) + with col_ai: + with st.container(key="action_ai"): + if st.button( + ai_label, + key="btn_ai", + disabled=ai_disabled, + help=ai_help, + icon=ai_icon, + width="stretch", + ): + if api_key: + run_ai_analysis(experiment) + + if experiment.description: + st.caption(experiment.description) + + # -- Global Parameters -- + render_experiment_overview(experiment) + + # -- Spacer between sections -- + st.markdown('
', unsafe_allow_html=True) + + # -- Action bar + Scenarios -- + render_action_bar(experiment) + _live_scenario_cards(experiment) + + # -- AI Analysis (always visible) -- + st.markdown(section_header("AI Analysis"), unsafe_allow_html=True) + if has_analysis: + with st.expander("View Analysis", expanded=True): + try: + with open(analysis_file, "r") as f: + st.markdown(f.read()) + except Exception as e: + st.error(f"Failed to load analysis: {str(e)}") + elif analysis_running: + st.markdown( + _dedent(f""" +
+
+ +Generating analysis... This may take a moment. +
+"""), + unsafe_allow_html=True, + ) + else: + st.markdown( + _dedent(f""" +
+

+No analysis generated yet. Click "Analyze experiment" above to generate one.

+
+"""), + unsafe_allow_html=True, + ) + + +def render_experiment_overview(experiment: Experiment) -> None: + """Render global parameters as three refined cards.""" + st.markdown(section_header("Global Parameters"), unsafe_allow_html=True) + + params = experiment.parameters + network_names = { + "er": "Erdos-Renyi", + "sf": "Scale-Free", + "cg": "Complete Graph", + "rrn": "Random Regular", + } + + # Build rows for each card + net_rows = [("Type", network_names.get(params.get("network", ""), "N/A"))] + if "nodes" in params: + net_rows.append(("Nodes", str(params["nodes"]))) + if "k_avg" in params: + net_rows.append(("Avg Degree", str(params["k_avg"]))) + if "exponent" in params: + net_rows.append(("Exponent", str(params["exponent"]))) + + dist_rows = [("Type", params.get("distribution", "N/A").capitalize())] + if "lambda" in params: + dist_rows.append(("Infection Rate", str(params["lambda"]))) + if params.get("distribution") == "gamma": + if "shape" in params: + dist_rows.append(("Shape", str(params["shape"]))) + if "scale" in params: + dist_rows.append(("Scale", str(params["scale"]))) + elif params.get("distribution") == "exponential": + if "mu" in params: + dist_rows.append(("Recovery Rate", str(params["mu"]))) + + sim_rows = [] + if "samples" in params: + sim_rows.append(("Samples", str(params["samples"]))) + if "num_runs" in params: + sim_rows.append(("Runs", str(params["num_runs"]))) + if "t_max" in params: + sim_rows.append(("Max Time", str(params["t_max"]))) + if "steps" in params: + sim_rows.append(("Steps", str(params["steps"]))) + + with st.container(key="params_section"): + col1, col2, col3 = st.columns(3) + with col1: + st.markdown( + params_card("Network", _ICON_NETWORK, net_rows), + unsafe_allow_html=True, + ) + with col2: + st.markdown( + params_card("Distribution", _ICON_DIST, dist_rows), + unsafe_allow_html=True, + ) + with col3: + st.markdown( + params_card("Simulation", _ICON_SIM, sim_rows), + unsafe_allow_html=True, + ) + + +def render_action_bar(experiment: Experiment) -> None: + """Render the scenarios section header with Add Scenario button.""" + col_title, col_add = st.columns([8, 2]) + with col_title: + st.markdown( + section_header("Scenarios"), + unsafe_allow_html=True, + ) + with col_add: + with st.container(key="action_add_scenario"): + if st.button( + "Add Scenario", + key="btn_add_scenario_bar", + width="stretch", + icon=":material/add:", + ): + show_add_scenario_modal(experiment) + + +@st.fragment(run_every=timedelta(seconds=2)) +def _live_scenario_cards(experiment: Experiment) -> None: + """Fragment wrapper that polls progress and re-renders scenario cards. + + Runs every 2 seconds to check subprocess status files and update + progress bars without triggering a full page rerun. + """ + poll_running_simulations() + analysis_changed = poll_running_analyses() + if analysis_changed: + # Full page rerun so AI section (outside this fragment) re-renders + st.rerun() + render_scenario_cards(experiment) + + +def _get_scenario_entry(experiment: Experiment, label: str) -> dict[str, Any] | None: + """Read a scenario's raw entry from data.json.""" + exp_path = experiment.path + assert exp_path is not None + data_file = exp_path / "data.json" + try: + with open(data_file, "r", encoding="utf-8") as f: + data = json.load(f) + for s in data.get("scenarios", []): + if s.get("label") == label: + result: dict[str, Any] = s + return result + except Exception: + pass + return None + + +def render_scenario_cards(experiment: Experiment) -> None: + """Render scenarios as clickable cards with run and delete buttons.""" + if not experiment.scenarios: + st.markdown( + _dedent(f""" +
+No scenarios defined yet. Add one above. +
+"""), + unsafe_allow_html=True, + ) + return + + exp_path = experiment.path + assert exp_path is not None + + if "simulation_runner" not in st.session_state: + st.session_state.simulation_runner = SimulationRunner() + runner: SimulationRunner = st.session_state.simulation_runner + + for sc in experiment.scenarios: + result_file = exp_path / f"{sc.normalized_label}.json" + has_result = result_file.exists() + scenario_id = f"sim--{exp_path.name}--{sc.normalized_label}" + sc_status = "completed" if has_result else "created" + if not has_result: + sc_entry = _get_scenario_entry(experiment, sc.label) + if sc_entry and sc_entry.get("status") == "edited": + sc_status = "edited" + sim_state = SessionState.get_simulation_status(scenario_id) + if sim_state == "running": + sc_status = "running" + elif sim_state == "failed": + sc_status = "failed" + + override_text = get_override_summary(sc, experiment.parameters) + + # Calculate progress fraction for running scenarios + progress_frac = -1.0 + if sc_status == "running": + prog_info = SessionState.get_simulation_progress(scenario_id) + if prog_info and prog_info["total"] > 0: + progress_frac = prog_info["progress"] / prog_info["total"] + + # Card container with overlay button + with st.container(key=f"sc_card_{scenario_id}"): + is_baseline = sc.label == "Baseline" + col_body, col_run, col_edit, col_del = st.columns([8.5, 0.5, 0.5, 0.5]) + + with col_body: + st.markdown( + scenario_card( + label=sc.label, + override_text=override_text, + status=sc_status, + progress=progress_frac, + ), + unsafe_allow_html=True, + ) + + with col_run: + with st.container(key=f"sc_run_{scenario_id}"): + is_running = sim_state == "running" + if st.button( + "", + key=f"run_sc_{scenario_id}", + width="stretch", + disabled=is_running, + icon=":material/play_arrow:", + help=( + "Running..." + if is_running + else "Re-run this scenario" if has_result else "Run this scenario" + ), + ): + if has_result: + show_rerun_scenario_dialog(experiment, sc, runner) + else: + _start_scenario_run(experiment, sc, runner) + + with col_edit: + if not is_baseline: + with st.container(key=f"sc_edit_{scenario_id}"): + if st.button( + "", + key=f"edit_sc_{scenario_id}", + width="stretch", + disabled=is_running, + icon=":material/edit:", + help="Running..." if is_running else "Edit this scenario", + ): + show_edit_scenario_modal(experiment, sc) + + with col_del: + if not is_baseline: + with st.container(key=f"sc_del_{scenario_id}"): + if st.button( + "", + key=f"del_sc_{scenario_id}", + width="stretch", + disabled=is_running, + icon=":material/delete:", + help="Running..." if is_running else "Delete this scenario", + ): + show_delete_scenario_dialog(experiment, sc) + + # Invisible overlay button for card click + if st.button("open", key=f"sc_btn_{scenario_id}"): + show_scenario_detail_modal(experiment, sc) + + +def get_override_summary(scenario: Scenario, global_params: Dict[str, Any]) -> str: + """ + Get a summary of parameters that differ from global parameters. + + Args: + scenario: The scenario to check + global_params: Global experiment parameters + + Returns: + String summary of overridden parameters + """ + overrides = [] + skip_keys = {"label", "experiment_name", "output_path"} + + # Check each parameter + scenario_dict = scenario.model_dump(by_alias=True) + for key, value in scenario_dict.items(): + if key in skip_keys: + continue + if value is None: + continue + + # Show if key doesn't exist in global (new param) or value differs + if key not in global_params or not _values_equal(value, global_params[key]): + overrides.append(f"{key}: {value}") + + return " | ".join(overrides) if overrides else "" + + +@st.dialog("Scenario Details", width="large") +def show_scenario_detail_modal(experiment: Experiment, scenario: Scenario) -> None: + """Show detailed modal for a single scenario.""" + _modal_body_fragment(experiment, scenario) + + +@st.fragment(run_every=timedelta(seconds=2)) +def _modal_body_fragment(experiment: Experiment, scenario: Scenario) -> None: + """Fragment handling the full modal body. + + Renders title row (with AI/Export when results exist), parameters, + Run button, and content area. Runs every 2 seconds to poll progress. + Uses st.rerun(scope="fragment") for clean DOM transitions. + """ + exp_path = experiment.path + assert exp_path is not None + + scenario_id = f"sim--{exp_path.name}--{scenario.normalized_label}" + modal_running_key = f"_modal_running_{scenario_id}" + + poll_running_simulations() + analysis_changed = poll_running_analyses() + if analysis_changed: + st.rerun(scope="fragment") + + result_file = exp_path / f"{scenario.normalized_label}.json" + analysis_file = exp_path / f"{scenario.normalized_label}_analysis.md" + + sim_status = SessionState.get_simulation_status(scenario_id) + has_result = result_file.exists() + + # Clear modal running latch when simulation finishes (success or failure). + # The latch bridges the gap between "user clicked Run" and "status file written". + # Once sim_status reflects a terminal state, the latch is no longer needed. + latch_on = st.session_state.get(modal_running_key, False) + if latch_on and sim_status != "running": + # Simulation finished (completed/failed/pending) — clear latch and refresh + if has_result or sim_status in ("failed", "completed"): + st.session_state.pop(modal_running_key, None) + st.rerun(scope="fragment") + + is_running = sim_status == "running" or latch_on + + # Pre-read result data for title row actions + content + result_json = None + result_dict = None + if has_result: + try: + with open(result_file, "r") as f: + result_json = f.read() + result_dict = json.loads(result_json) + except Exception as e: + st.error(f"Failed to load results: {str(e)}") + return + + # -- Title row: title + action buttons (AI, Export, Run) -- + is_baseline = scenario.label == "Baseline" + show_run = True # All scenarios (including Baseline) can be run + sc_key = scenario_id + + if has_result and result_json: + api_key = WebConfig.get_openai_api_key() + has_analysis = analysis_file.exists() + sc_analysis_id = f"sc_analysis--{exp_path.name}--{scenario.normalized_label}" + sc_analysis_status = SessionState.get_analysis_status(sc_analysis_id) + sc_analysis_running = sc_analysis_status == "running" + + if sc_analysis_running: + sc_ai_label = "Analyzing..." + sc_ai_icon = ":material/sync:" + sc_ai_disabled = True + sc_ai_help = "Analysis in progress..." + elif has_analysis: + sc_ai_label = "Re-analyze" + sc_ai_icon = ":material/auto_awesome:" + sc_ai_disabled = not api_key + sc_ai_help = ( + "Re-generate AI analysis" if api_key else "Set OpenAI API key in Preferences" + ) + else: + sc_ai_label = "Analyze scenario" + sc_ai_icon = ":material/auto_awesome:" + sc_ai_disabled = not api_key + sc_ai_help = ( + "Analyze this scenario with AI" if api_key else "Set OpenAI API key in Preferences" + ) + + if show_run: + cols = st.columns([4, 1.5, 1.5, 1.5]) + col_title, col_ai, col_export, col_run = cols + else: + cols = st.columns([6, 2, 2]) + col_title, col_ai, col_export = cols + with col_title: + st.title(scenario.label) + with col_ai: + with st.container(key=f"modal_action_ai_{sc_key}"): + if st.button( + sc_ai_label, + key=f"modal_btn_ai_{sc_key}", + disabled=sc_ai_disabled, + help=sc_ai_help, + icon=sc_ai_icon, + width="stretch", + ): + run_scenario_ai_analysis(experiment, scenario, result_file) + with col_export: + with st.container(key=f"modal_action_export_{sc_key}"): + with st.popover( + "Export", + icon=":material/download:", + use_container_width=True, + ): + _export_fmt = st.radio( + "Format", + options=["json", "csv", "excel", "md", "html"], + horizontal=True, + label_visibility="collapsed", + key=f"export_fmt_{sc_key}", + ) + assert result_dict is not None + _export_data, _export_mime, _export_ext = DataManager.to_bytes( + result_dict, _export_fmt + ) + st.markdown( + _download_anchor( + _export_data, + f"{scenario.normalized_label}{_export_ext}", + _export_mime, + ), + unsafe_allow_html=True, + ) + else: + if show_run: + col_title, col_run = st.columns([8, 2]) + else: + col_title = st.columns([1])[0] + col_run = None + with col_title: + st.title(scenario.label) + + # -- Run button in title row (for non-Baseline) -- + if show_run: + if "simulation_runner" not in st.session_state: + st.session_state.simulation_runner = SimulationRunner() + runner: SimulationRunner = st.session_state.simulation_runner + assert col_run is not None + with col_run: + with st.container(key=f"modal_action_run_{sc_key}"): + run_label = "Re-run scenario" if has_result else "Run scenario" + if st.button( + run_label, + type="primary", + key=f"modal_btn_run_{sc_key}", + width="stretch", + icon=":material/play_arrow:", + disabled=is_running, + ): + # Move stale artifacts to .bak BEFORE spawning so + # fast scenarios don't race. Restore if launch fails. + bak_r, bak_a = _backup_scenario_artifacts( + result_file, analysis_file, has_result + ) + sid = _start_scenario_run_no_rerun(experiment, scenario, runner) + _finalize_artifact_backups( + result_file, analysis_file, bak_r, bak_a, sid is not None + ) + if sid: + st.session_state[modal_running_key] = True + st.rerun(scope="fragment") + + # -- Parameters -- + render_scenario_parameters(scenario, experiment.parameters, experiment_name=exp_path.name) + + # -- Content area -- + if has_result and result_dict: + _render_result_content(result_dict, experiment, scenario, analysis_file) + elif is_running: + prog_info = SessionState.get_simulation_progress(scenario_id) + progress = 0.0 + if prog_info and prog_info["total"] > 0: + progress = prog_info["progress"] / prog_info["total"] + st.markdown( + circular_progress_html(progress, "Running simulation..."), + unsafe_allow_html=True, + ) + else: + st.markdown( + _dedent(f""" +
+

+No results available

+

+Run this scenario to generate simulation results.

+
+"""), + unsafe_allow_html=True, + ) + + +def _render_result_content( + result_dict: dict[str, Any], + experiment: Experiment, + scenario: Scenario, + analysis_file: Path, +) -> None: + """Render result metrics, chart, comparison, and AI analysis.""" + exp_path = experiment.path + assert exp_path is not None + + has_analysis = analysis_file.exists() + + # -- Key Metrics -- + st.subheader("Key Metrics") + result_metric_cards(result_dict) + + st.divider() + + # -- SIR Dynamics -- + st.subheader("SIR Dynamics") + + # -- Chart controls (single row) -- + ( + col_s, + col_i, + col_r, + col_spacer, + col_err, + col_type, + ) = st.columns([0.8, 0.8, 0.8, 1.875, 0.6, 0.8]) + sc_key = f"sim--{exp_path.name}--{scenario.normalized_label}" + with col_s: + show_s = st.checkbox("Susceptible", value=True, key=f"modal_show_s_{sc_key}") + with col_i: + show_i = st.checkbox("Infected", value=True, key=f"modal_show_i_{sc_key}") + with col_r: + show_r = st.checkbox("Recovered", value=True, key=f"modal_show_r_{sc_key}") + with col_err: + show_errors = st.checkbox( + "Error bars", + value=True, + key=f"modal_show_errors_{sc_key}", + ) + with col_type: + chart_type_label = st.selectbox( + "Chart Type", + ["Lines", "Lines + Markers", "Area"], + key=f"modal_chart_mode_{sc_key}", + label_visibility="collapsed", + ) + + # Map selectbox label to chart_mode parameter + chart_mode_map = { + "Lines": "lines", + "Lines + Markers": "lines+markers", + "Area": "area", + } + chart_mode = chart_mode_map.get(chart_type_label, "lines") + + states_to_plot = [] + if show_s: + states_to_plot.append("S") + if show_i: + states_to_plot.append("I") + if show_r: + states_to_plot.append("R") + + # -- Discover other scenarios with results -- + other_scenarios = [] + for other_sc in experiment.scenarios: + if other_sc.label == scenario.label: + continue + other_file = exp_path / f"{other_sc.normalized_label}.json" + if other_file.exists(): + other_scenarios.append(other_sc) + + # -- Reserve visual space for chart (rendered after controls) -- + chart_container = st.container() + + # -- Compare controls (execute first to set state) -- + comparing = False + comp_results: List[Dict] = [] + comp_labels: List[str] = [] + + if other_scenarios: + st.subheader("Compare with Other Scenarios") + compare_options = [sc.label for sc in other_scenarios] + + compare_key = f"modal_compare_{exp_path.name}_{scenario.normalized_label}" + selected_labels = st.multiselect( + "Select scenarios to compare", + options=compare_options, + key=compare_key, + label_visibility="collapsed", + ) + + # Auto-trigger comparison when scenarios are selected + if selected_labels: + comp_results.append(result_dict) + comp_labels.append(scenario.label) + + for sel_label in selected_labels: + for other_sc in other_scenarios: + if other_sc.label == sel_label: + sel_file = exp_path / f"{other_sc.normalized_label}.json" + try: + with open(sel_file, "r") as f: + comp_results.append(json.load(f)) + comp_labels.append(sel_label) + except Exception: + continue + + if len(comp_results) >= 2: + comparing = True + + # -- Render chart into reserved container -- + with chart_container: + if not states_to_plot: + st.warning("Select at least one state to display") + elif comparing: + config = st.session_state.config + fig = create_comparison_figure( + comp_results, + comp_labels, + title=f"Comparison: {experiment.name}", + states=states_to_plot, + height=config.get("chart_height", 500), + template=config.get("chart_template", "plotly_white"), + ) + st.plotly_chart(fig, width="stretch") + else: + config = st.session_state.config + fig = create_sir_figure( + result_dict, + title=scenario.label, + states=states_to_plot, + show_error_bands=show_errors and "S_err" in result_dict, + height=config.get("chart_height", 500), + chart_mode=chart_mode, + state_colors={ + "S": config.get("chart_color_s", "#4477AA"), + "I": config.get("chart_color_i", "#EE6677"), + "R": config.get("chart_color_r", "#228833"), + }, + template=config.get("chart_template", "plotly_white"), + ) + st.plotly_chart(fig, width="stretch") + + # -- Comparison statistics -- + if comparing: + st.subheader("Comparison Statistics") + render_comparison_stats(comp_results, comp_labels) + + # -- AI Analysis (always visible) -- + st.divider() + st.subheader("AI Analysis") + + sc_analysis_id = f"sc_analysis--{exp_path.name}--{scenario.normalized_label}" + sc_analysis_running = SessionState.get_analysis_status(sc_analysis_id) == "running" + + if has_analysis: + try: + with open(analysis_file, "r") as f: + st.markdown(f.read()) + except Exception as e: + st.error(f"Failed to load analysis: {str(e)}") + elif sc_analysis_running: + st.markdown( + _dedent(f""" +
+
+ +Generating analysis... +
+"""), + unsafe_allow_html=True, + ) + else: + st.markdown( + _dedent(f""" +
+

+No analysis generated yet. Click "Analyze scenario" above to generate one.

+
+"""), + unsafe_allow_html=True, + ) + + +def render_scenario_parameters( + scenario: Scenario, + global_params: Dict[str, Any], + experiment_name: str = "", +) -> None: + """Render scenario parameters with visual distinction for overrides.""" + scenario_dict = scenario.model_dump(by_alias=True) + + network_keys = ["network", "nodes", "k_avg", "exponent"] + dist_keys = ["distribution", "shape", "scale", "mu", "lambda"] + sim_keys = ["samples", "num_runs", "t_max", "steps", "initial_perc"] + + def _build_rows(keys: list) -> list: + rows = [] + for key in keys: + if key in scenario_dict and scenario_dict[key] is not None: + val = scenario_dict[key] + is_override = not _values_equal(global_params.get(key), val) + rows.append((key, str(val), is_override)) + return rows + + sc_key = ( + f"{experiment_name}_{scenario.normalized_label}" + if experiment_name + else scenario.normalized_label + ) + with st.container(key=f"modal_params_section_{sc_key}"): + col1, col2, col3 = st.columns(3) + with col1: + st.markdown( + params_card("Network", _ICON_NETWORK, _build_rows(network_keys)), + unsafe_allow_html=True, + ) + with col2: + st.markdown( + params_card("Distribution", _ICON_DIST, _build_rows(dist_keys)), + unsafe_allow_html=True, + ) + with col3: + st.markdown( + params_card("Simulation", _ICON_SIM, _build_rows(sim_keys)), + unsafe_allow_html=True, + ) + + +def render_comparison_stats(results: List[Dict], labels: List[str]) -> None: + """Render a comparison table of key statistics.""" + from datetime import timedelta + + import humanize + import numpy as np + + stats = [] + for result_dict, label in zip(results, labels): + I_val = np.array(result_dict["I_val"]) + R_val = np.array(result_dict["R_val"]) + time = np.array(result_dict["time"]) + + exec_time = result_dict.get("metadata", {}).get("execution_time") + if exec_time is not None: + duration = humanize.precisedelta( + timedelta(seconds=exec_time), minimum_unit="seconds", format="%0.0f" + ) + else: + duration = "N/A" + + stats.append( + { + "Scenario": label, + "Peak Infected": f"{np.max(I_val):.2%}", + "Peak Time": f"{time[np.argmax(I_val)]:.2f}", + "Final Size": f"{R_val[-1]:.2%}", + "Duration": duration, + } + ) + + df = pd.DataFrame(stats) + st.dataframe(df, width="stretch", hide_index=True) + + +@st.dialog("Add Scenario", width="large") +def show_add_scenario_modal(experiment: Experiment) -> None: + """Show modal to add a new scenario to the experiment.""" + exp_path = experiment.path + assert exp_path is not None + exp_key = exp_path.name + st.title("Add New Scenario") + + label = st.text_input( + "Scenario Label", + placeholder="e.g., High Infection Rate", + help="Descriptive name for this scenario", + ) + + st.subheader("Parameter Overrides") + st.caption( + "Values are pre-filled with experiment defaults. " + "Only changed values will be saved as overrides." + ) + + global_params = experiment.parameters + override_params: Dict[str, Any] = {} + + # -- Network Overrides -- + with st.expander("Network Overrides", expanded=False): + network_options = ["er", "sf", "cg", "rrn"] + network_names = { + "er": "Erdos-Renyi", + "sf": "Scale-Free", + "cg": "Complete Graph", + "rrn": "Random Regular", + } + global_network = global_params.get("network", "er") + global_idx = ( + network_options.index(global_network) if global_network in network_options else 0 + ) + override_network = st.selectbox( + "Network Type", + options=network_options, + format_func=lambda x: network_names.get(x, x), + index=global_idx, + key=f"add_sc_network_{exp_key}", + help=f"Experiment default: {network_names.get(global_network, global_network)}", + ) + network_changed = override_network != global_network + if network_changed: + override_params["network"] = override_network + + col_n1, col_n2 = st.columns(2) + with col_n1: + global_nodes = int(global_params.get("nodes", 1000)) + override_nodes = st.number_input( + "Nodes", + min_value=1, + value=global_nodes, + step=100, + key=f"add_sc_nodes_{exp_key}", + help=f"Experiment default: {global_nodes}", + ) + if override_nodes != global_nodes: + override_params["nodes"] = override_nodes + + with col_n2: + global_k_avg = float(global_params.get("k_avg", 10.0)) + override_k_avg = st.number_input( + "Average Degree (k_avg)", + min_value=0.1, + value=global_k_avg, + step=1.0, + key=f"add_sc_k_avg_{exp_key}", + help=f"Experiment default: {global_k_avg}", + ) + # Always include k_avg when network type is overridden (required for er/sf/rrn) + if override_k_avg != global_k_avg or network_changed: + override_params["k_avg"] = override_k_avg + + # Exponent only relevant for scale-free networks + effective_network = override_network or global_network + if effective_network == "sf": + global_exponent = float(global_params.get("exponent", 2.5)) + override_exponent = st.number_input( + "Power-law Exponent", + min_value=0.1, + value=global_exponent, + step=0.1, + key=f"add_sc_exponent_{exp_key}", + help=f"Experiment default: {global_exponent}", + ) + # Always include exponent when network type is overridden to sf + if override_exponent != global_exponent or network_changed: + override_params["exponent"] = override_exponent + + # -- Distribution Overrides -- + with st.expander("Distribution Overrides", expanded=False): + dist_options = ["gamma", "exponential"] + global_dist = global_params.get("distribution", "gamma") + global_dist_idx = dist_options.index(global_dist) if global_dist in dist_options else 0 + override_dist = st.selectbox( + "Distribution Type", + options=dist_options, + format_func=lambda x: x.capitalize(), + index=global_dist_idx, + key=f"add_sc_distribution_{exp_key}", + help=f"Experiment default: {global_dist.capitalize()}", + ) + if override_dist != global_dist: + override_params["distribution"] = override_dist + + global_lambda = float(global_params.get("lambda", 1.0)) + override_lambda = st.number_input( + "Infection Rate (lambda)", + min_value=0.01, + value=global_lambda, + step=0.1, + key=f"add_sc_lambda_{exp_key}", + help=f"Experiment default: {global_lambda}", + ) + if override_lambda != global_lambda: + override_params["lambda"] = override_lambda + + effective_dist = override_dist or global_dist + # When distribution is overridden, always include the + # distribution-specific required params so that + # Scenario.from_merged() won't fail validation. + dist_changed = override_dist != global_dist + if effective_dist == "gamma": + col_d1, col_d2 = st.columns(2) + with col_d1: + global_shape = float(global_params.get("shape", 2.0)) + override_shape = st.number_input( + "Shape", + min_value=0.01, + value=global_shape, + step=0.1, + key=f"add_sc_shape_{exp_key}", + help=f"Experiment default: {global_shape}", + ) + if override_shape != global_shape or dist_changed: + override_params["shape"] = override_shape + with col_d2: + global_scale = float(global_params.get("scale", 1.0)) + override_scale = st.number_input( + "Scale", + min_value=0.01, + value=global_scale, + step=0.1, + key=f"add_sc_scale_{exp_key}", + help=f"Experiment default: {global_scale}", + ) + if override_scale != global_scale or dist_changed: + override_params["scale"] = override_scale + elif effective_dist == "exponential": + global_mu = float(global_params.get("mu", 1.0)) + override_mu = st.number_input( + "Recovery Rate (mu)", + min_value=0.01, + value=global_mu, + step=0.1, + key=f"add_sc_mu_{exp_key}", + help=f"Experiment default: {global_mu}", + ) + if override_mu != global_mu or dist_changed: + override_params["mu"] = override_mu + + # -- Simulation Overrides -- + with st.expander("Simulation Overrides", expanded=False): + col_s1, col_s2 = st.columns(2) + with col_s1: + global_samples = int(global_params.get("samples", 50)) + override_samples = st.number_input( + "Samples", + min_value=1, + value=global_samples, + step=10, + key=f"add_sc_samples_{exp_key}", + help=f"Experiment default: {global_samples}", + ) + if override_samples != global_samples: + override_params["samples"] = override_samples + with col_s2: + global_num_runs = int(global_params.get("num_runs", 2)) + override_num_runs = st.number_input( + "Number of Runs", + min_value=1, + value=global_num_runs, + step=1, + key=f"add_sc_num_runs_{exp_key}", + help=f"Experiment default: {global_num_runs}", + ) + if override_num_runs != global_num_runs: + override_params["num_runs"] = override_num_runs + + col_s3, col_s4 = st.columns(2) + with col_s3: + global_t_max = float(global_params.get("t_max", 10.0)) + override_t_max = st.number_input( + "Max Time (t_max)", + min_value=0.01, + value=global_t_max, + step=1.0, + key=f"add_sc_t_max_{exp_key}", + help=f"Experiment default: {global_t_max}", + ) + if override_t_max != global_t_max: + override_params["t_max"] = override_t_max + with col_s4: + global_steps = int(global_params.get("steps", 100)) + override_steps = st.number_input( + "Steps", + min_value=1, + value=global_steps, + step=10, + key=f"add_sc_steps_{exp_key}", + help=f"Experiment default: {global_steps}", + ) + if override_steps != global_steps: + override_params["steps"] = override_steps + + global_initial_perc = float(global_params.get("initial_perc", 0.01)) + override_initial_perc = st.number_input( + "Initial Infected Fraction", + min_value=0.001, + max_value=1.0, + value=global_initial_perc, + step=0.01, + format="%.3f", + key=f"add_sc_initial_perc_{exp_key}", + help=f"Experiment default: {global_initial_perc}", + ) + if override_initial_perc != global_initial_perc: + override_params["initial_perc"] = override_initial_perc + + # Action buttons (pinned to bottom via CSS on modal_actions container) + with st.container(key=f"modal_actions_add_{exp_key}"): + st.divider() + col1, col2 = st.columns(2) + + with col1: + if st.button("Cancel", width="stretch"): + st.rerun() + + with col2: + if st.button( + "Add Scenario", + type="primary", + width="stretch", + icon=":material/add:", + ): + if not label: + st.error("Please provide a scenario label") + return + + try: + add_scenario_to_experiment(experiment, label, override_params) + st.success(f"Scenario '{label}' added successfully!") + st.rerun() + except Exception as e: + st.error(f"Failed to add scenario: {str(e)}") + + +def add_scenario_to_experiment( + experiment: Experiment, label: str, override_params: Dict[str, Any] +) -> None: + """ + Add a new scenario to an existing experiment. + + Args: + experiment: The experiment to add to + label: Scenario label + override_params: Parameters that override global settings + + Raises: + ValueError: If a scenario with the same normalized label already exists + """ + from spkmc.models.scenario import Scenario as ScenarioModel + + exp_path = experiment.path + assert exp_path is not None + + # Load current data.json + data_file = exp_path / "data.json" + with open(data_file, "r", encoding="utf-8") as f: + data = json.load(f) + + # Check that label normalizes to a non-empty filename + new_norm = ScenarioModel.normalize_label(label) + if not new_norm: + raise ValueError( + f"Scenario label '{label}' normalizes to an empty filename. " + "Use a label with at least one alphanumeric character." + ) + + # Check for normalized label collision + for sc in data.get("scenarios", []): + existing_norm = ScenarioModel.normalize_label(sc.get("label", "")) + if existing_norm == new_norm: + raise ValueError( + f"A scenario with a conflicting name already exists: '{sc.get('label')}' " + f"(both normalize to '{new_norm}')" + ) + + # Add new scenario. + # When a global `parameters` block exists, store only overrides (label + diffs). + # For legacy experiments without globals, include the full effective parameter + # set so the scenario entry remains valid on reload. + global_params = data.get("parameters", {}) + if global_params: + new_scenario: Dict[str, Any] = {"label": label, **override_params} + else: + # Derive defaults from the first existing scenario (minus meta keys) + existing = data.get("scenarios", []) + base_params: Dict[str, Any] = {} + if existing: + meta_keys = {"label", "status"} + base_params = {k: v for k, v in existing[0].items() if k not in meta_keys} + new_scenario = {"label": label, **base_params, **override_params} + data.setdefault("scenarios", []).append(new_scenario) + + # Write back (atomic to prevent corruption on crash) + from spkmc.web import atomic_json_write + + atomic_json_write(data_file, data) + + +def _start_scenario_run( + experiment: Experiment, scenario: Scenario, runner: SimulationRunner +) -> None: + """Start a scenario simulation run.""" + exp_path = experiment.path + assert exp_path is not None + + run_id = runner.run_scenario(experiment, scenario, show_progress=True) + if run_id: + status_info = runner.get_status(run_id) + if status_info: + # Key by scenario_id (stable, matches render lookups). + # Store run_id inside info for status file lookups. + scenario_id = f"sim--{exp_path.name}--{scenario.normalized_label}" + status_info["run_id"] = run_id + SessionState.add_running_simulation(scenario_id, status_info) + st.rerun() + + +def _backup_scenario_artifacts( + result_file: Path, analysis_file: Path, has_result: bool +) -> Tuple[Optional[Path], Optional[Path]]: + """Rename scenario result/analysis to ``.bak`` so the UI sees them as absent. + + Returns ``(backup_result, backup_analysis)`` paths (None when the + original did not exist). Call :func:`_finalize_artifact_backups` + after the launch attempt to discard or restore backups. + """ + backup_result: Optional[Path] = None + backup_analysis: Optional[Path] = None + if has_result and result_file.exists(): + backup_result = result_file.with_suffix(".json.bak") + result_file.rename(backup_result) + if analysis_file.exists(): + backup_analysis = analysis_file.with_suffix(".md.bak") + analysis_file.rename(backup_analysis) + return backup_result, backup_analysis + + +def _finalize_artifact_backups( + result_file: Path, + analysis_file: Path, + backup_result: Optional[Path], + backup_analysis: Optional[Path], + launch_ok: bool, +) -> None: + """Keep backups on successful launch, or restore originals on failure. + + On launch success the `.bak` files are intentionally kept: the subprocess + has only *started* — it may still fail later. The backups are harmless + (the UI only checks canonical filenames, never `.bak`) and act as a + safety net. They are naturally superseded when the scenario is re-run + successfully or when the experiment/scenario is deleted. + """ + if not launch_ok: + if backup_result is not None and backup_result.exists(): + backup_result.rename(result_file) + if backup_analysis is not None and backup_analysis.exists(): + backup_analysis.rename(analysis_file) + + +def _start_scenario_run_no_rerun( + experiment: Experiment, scenario: Scenario, runner: SimulationRunner +) -> str | None: + """Start a scenario run without calling st.rerun(). + + Returns the scenario_id if started successfully, None otherwise. + Used by the modal to keep the dialog open while polling progress. + """ + exp_path = experiment.path + assert exp_path is not None + + run_id = runner.run_scenario(experiment, scenario, show_progress=True) + if run_id: + status_info = runner.get_status(run_id) + if status_info: + scenario_id = f"sim--{exp_path.name}--{scenario.normalized_label}" + status_info["run_id"] = run_id + SessionState.add_running_simulation(scenario_id, status_info) + return scenario_id + return None + + +@st.dialog("Re-run Scenario") +def show_rerun_scenario_dialog( + experiment: Experiment, + scenario: Scenario, + runner: SimulationRunner, +) -> None: + """Show confirmation dialog before re-running a completed scenario.""" + exp_path = experiment.path + assert exp_path is not None + + scope = f"sim--{exp_path.name}--{scenario.normalized_label}" + st.markdown( + f"**{scenario.label}** already has results. " + "Re-running will overwrite the existing data.", + ) + + col_cancel, col_rerun = st.columns(2) + with col_cancel: + if st.button("Cancel", key=f"rerun_cancel_{scope}", width="stretch"): + st.rerun() + with col_rerun: + if st.button( + "Re-run", + key=f"rerun_confirm_{scope}", + type="primary", + width="stretch", + icon=":material/play_arrow:", + ): + # Move stale artifacts to .bak BEFORE spawning so + # fast scenarios don't race. Restore if launch fails. + result_file = exp_path / f"{scenario.normalized_label}.json" + analysis_file = exp_path / f"{scenario.normalized_label}_analysis.md" + bak_r, bak_a = _backup_scenario_artifacts( + result_file, analysis_file, result_file.exists() + ) + sid = _start_scenario_run_no_rerun(experiment, scenario, runner) + _finalize_artifact_backups(result_file, analysis_file, bak_r, bak_a, sid is not None) + st.rerun() + + +@st.dialog("Delete Scenario") +def show_delete_scenario_dialog(experiment: Experiment, scenario: Scenario) -> None: + """Show confirmation dialog before deleting a scenario.""" + exp_path = experiment.path + assert exp_path is not None + + scope = f"sim--{exp_path.name}--{scenario.normalized_label}" + st.markdown( + f"Are you sure you want to delete **{scenario.label}**?", + ) + + # Block deletion of the last scenario (would make experiment unloadable) + if len(experiment.scenarios) <= 1: + st.error("Cannot delete the only scenario in an experiment.") + if st.button("Close", key=f"del_close_{scope}", width="stretch"): + st.rerun() + return + + result_file = exp_path / f"{scenario.normalized_label}.json" + if result_file.exists(): + st.warning("This scenario has results that will also be deleted.") + + col_cancel, col_delete = st.columns(2) + with col_cancel: + if st.button("Cancel", key=f"del_cancel_{scope}", width="stretch"): + st.rerun() + with col_delete: + if st.button( + "Delete", + key=f"del_confirm_{scope}", + type="primary", + width="stretch", + ): + delete_scenario_from_experiment(experiment, scenario) + st.rerun() + + +def delete_scenario_from_experiment(experiment: Experiment, scenario: Scenario) -> None: + """ + Remove a scenario from an experiment. + + Deletes the scenario entry from data.json and removes the result file + if it exists. + + Args: + experiment: The parent experiment + scenario: The scenario to delete + """ + exp_path = experiment.path + assert exp_path is not None + + # Remove from data.json + data_file = exp_path / "data.json" + with open(data_file, "r", encoding="utf-8") as f: + data = json.load(f) + + scenarios_list = data.get("scenarios", []) + new_scenarios = [s for s in scenarios_list if s.get("label") != scenario.label] + if not new_scenarios: + raise ValueError("Cannot delete the last scenario in an experiment") + data["scenarios"] = new_scenarios + + # Write back (atomic to prevent corruption on crash) + from spkmc.web import atomic_json_write + + atomic_json_write(data_file, data) + + # Remove result and analysis files if they exist + result_file = exp_path / f"{scenario.normalized_label}.json" + if result_file.exists(): + result_file.unlink() + analysis_file = exp_path / f"{scenario.normalized_label}_analysis.md" + if analysis_file.exists(): + analysis_file.unlink() + + +@st.dialog("Edit Scenario", width="large") +def show_edit_scenario_modal(experiment: Experiment, scenario: Scenario) -> None: + """Show modal to edit an existing scenario.""" + exp_path = experiment.path + assert exp_path is not None + + edit_scope = f"sim--{exp_path.name}--{scenario.normalized_label}" + st.title(f"Edit: {scenario.label}") + + original_label = scenario.label + label = st.text_input( + "Scenario Label", + value=scenario.label, + help="Descriptive name for this scenario", + key=f"edit_sc_label_{edit_scope}", + ) + + st.subheader("Parameter Overrides") + st.caption( + "Values are pre-filled with this scenario's current settings. " + "Only values that differ from experiment defaults will be saved as overrides." + ) + + global_params = experiment.parameters + scenario_dict = scenario.model_dump(by_alias=True) + override_params: Dict[str, Any] = {} + + # -- Network Overrides -- + with st.expander("Network Overrides", expanded=False): + network_options = ["er", "sf", "cg", "rrn"] + network_names = { + "er": "Erdos-Renyi", + "sf": "Scale-Free", + "cg": "Complete Graph", + "rrn": "Random Regular", + } + current_network = scenario_dict.get("network", global_params.get("network", "er")) + global_network = global_params.get("network", "er") + current_idx = ( + network_options.index(current_network) if current_network in network_options else 0 + ) + override_network = st.selectbox( + "Network Type", + options=network_options, + format_func=lambda x: network_names.get(x, x), + index=current_idx, + key=f"edit_sc_network_{edit_scope}", + help=f"Experiment default: {network_names.get(global_network, global_network)}", + ) + network_changed = not _values_equal(override_network, global_network) + if network_changed: + override_params["network"] = override_network + + col_n1, col_n2 = st.columns(2) + with col_n1: + global_nodes = int(global_params.get("nodes", 1000)) + current_nodes = int(scenario_dict.get("nodes", global_nodes)) + override_nodes = st.number_input( + "Nodes", + min_value=1, + value=current_nodes, + step=100, + key=f"edit_sc_nodes_{edit_scope}", + help=f"Experiment default: {global_nodes}", + ) + if not _values_equal(override_nodes, global_nodes): + override_params["nodes"] = override_nodes + + with col_n2: + global_k_avg = float(global_params.get("k_avg", 10.0)) + current_k_avg = float(scenario_dict.get("k_avg", global_k_avg)) + override_k_avg = st.number_input( + "Average Degree (k_avg)", + min_value=0.1, + value=current_k_avg, + step=1.0, + key=f"edit_sc_k_avg_{edit_scope}", + help=f"Experiment default: {global_k_avg}", + ) + # Always include k_avg when network type is overridden (required for er/sf/rrn) + if not _values_equal(override_k_avg, global_k_avg) or network_changed: + override_params["k_avg"] = override_k_avg + + effective_network = override_network or current_network + if effective_network == "sf": + global_exponent = float(global_params.get("exponent", 2.5)) + current_exponent = float(scenario_dict.get("exponent", global_exponent)) + override_exponent = st.number_input( + "Power-law Exponent", + min_value=0.1, + value=current_exponent, + step=0.1, + key=f"edit_sc_exponent_{edit_scope}", + help=f"Experiment default: {global_exponent}", + ) + # Always include exponent when network type is overridden to sf + if not _values_equal(override_exponent, global_exponent) or network_changed: + override_params["exponent"] = override_exponent + + # -- Distribution Overrides -- + with st.expander("Distribution Overrides", expanded=False): + dist_options = ["gamma", "exponential"] + global_dist = global_params.get("distribution", "gamma") + current_dist = scenario_dict.get("distribution", global_dist) + current_dist_idx = dist_options.index(current_dist) if current_dist in dist_options else 0 + override_dist = st.selectbox( + "Distribution Type", + options=dist_options, + format_func=lambda x: x.capitalize(), + index=current_dist_idx, + key=f"edit_sc_distribution_{edit_scope}", + help=f"Experiment default: {global_dist.capitalize()}", + ) + if not _values_equal(override_dist, global_dist): + override_params["distribution"] = override_dist + + global_lambda = float(global_params.get("lambda", 1.0)) + current_lambda = float(scenario_dict.get("lambda", global_lambda)) + override_lambda = st.number_input( + "Infection Rate (lambda)", + min_value=0.01, + value=current_lambda, + step=0.1, + key=f"edit_sc_lambda_{edit_scope}", + help=f"Experiment default: {global_lambda}", + ) + if not _values_equal(override_lambda, global_lambda): + override_params["lambda"] = override_lambda + + effective_dist = override_dist or current_dist + # When distribution is overridden, always include the + # distribution-specific required params so that + # Scenario.from_merged() won't fail validation. + dist_changed = not _values_equal(override_dist, global_dist) + if effective_dist == "gamma": + col_d1, col_d2 = st.columns(2) + with col_d1: + global_shape = float(global_params.get("shape", 2.0)) + current_shape = float(scenario_dict.get("shape", global_shape)) + override_shape = st.number_input( + "Shape", + min_value=0.01, + value=current_shape, + step=0.1, + key=f"edit_sc_shape_{edit_scope}", + help=f"Experiment default: {global_shape}", + ) + if not _values_equal(override_shape, global_shape) or dist_changed: + override_params["shape"] = override_shape + with col_d2: + global_scale = float(global_params.get("scale", 1.0)) + current_scale = float(scenario_dict.get("scale", global_scale)) + override_scale = st.number_input( + "Scale", + min_value=0.01, + value=current_scale, + step=0.1, + key=f"edit_sc_scale_{edit_scope}", + help=f"Experiment default: {global_scale}", + ) + if not _values_equal(override_scale, global_scale) or dist_changed: + override_params["scale"] = override_scale + elif effective_dist == "exponential": + global_mu = float(global_params.get("mu", 1.0)) + current_mu = float(scenario_dict.get("mu", global_mu)) + override_mu = st.number_input( + "Recovery Rate (mu)", + min_value=0.01, + value=current_mu, + step=0.1, + key=f"edit_sc_mu_{edit_scope}", + help=f"Experiment default: {global_mu}", + ) + if not _values_equal(override_mu, global_mu) or dist_changed: + override_params["mu"] = override_mu + + # -- Simulation Overrides -- + with st.expander("Simulation Overrides", expanded=False): + col_s1, col_s2 = st.columns(2) + with col_s1: + global_samples = int(global_params.get("samples", 50)) + current_samples = int(scenario_dict.get("samples", global_samples)) + override_samples = st.number_input( + "Samples", + min_value=1, + value=current_samples, + step=10, + key=f"edit_sc_samples_{edit_scope}", + help=f"Experiment default: {global_samples}", + ) + if not _values_equal(override_samples, global_samples): + override_params["samples"] = override_samples + with col_s2: + global_num_runs = int(global_params.get("num_runs", 2)) + current_num_runs = int(scenario_dict.get("num_runs", global_num_runs)) + override_num_runs = st.number_input( + "Number of Runs", + min_value=1, + value=current_num_runs, + step=1, + key=f"edit_sc_num_runs_{edit_scope}", + help=f"Experiment default: {global_num_runs}", + ) + if not _values_equal(override_num_runs, global_num_runs): + override_params["num_runs"] = override_num_runs + + col_s3, col_s4 = st.columns(2) + with col_s3: + global_t_max = float(global_params.get("t_max", 10.0)) + current_t_max = float(scenario_dict.get("t_max", global_t_max)) + override_t_max = st.number_input( + "Max Time (t_max)", + min_value=0.01, + value=current_t_max, + step=1.0, + key=f"edit_sc_t_max_{edit_scope}", + help=f"Experiment default: {global_t_max}", + ) + if not _values_equal(override_t_max, global_t_max): + override_params["t_max"] = override_t_max + with col_s4: + global_steps = int(global_params.get("steps", 100)) + current_steps = int(scenario_dict.get("steps", global_steps)) + override_steps = st.number_input( + "Steps", + min_value=1, + value=current_steps, + step=10, + key=f"edit_sc_steps_{edit_scope}", + help=f"Experiment default: {global_steps}", + ) + if not _values_equal(override_steps, global_steps): + override_params["steps"] = override_steps + + global_initial_perc = float(global_params.get("initial_perc", 0.01)) + current_initial_perc = float(scenario_dict.get("initial_perc", global_initial_perc)) + override_initial_perc = st.number_input( + "Initial Infected Fraction", + min_value=0.001, + max_value=1.0, + value=current_initial_perc, + step=0.01, + format="%.3f", + key=f"edit_sc_initial_perc_{edit_scope}", + help=f"Experiment default: {global_initial_perc}", + ) + if not _values_equal(override_initial_perc, global_initial_perc): + override_params["initial_perc"] = override_initial_perc + + # Action buttons (pinned to bottom via CSS on modal_actions container) + with st.container(key=f"modal_actions_edit_{edit_scope}"): + st.divider() + col1, col2 = st.columns(2) + + with col1: + if st.button("Cancel", width="stretch", key=f"edit_sc_cancel_{edit_scope}"): + st.rerun() + + with col2: + if st.button( + "Save Changes", + type="primary", + width="stretch", + icon=":material/save:", + key=f"edit_sc_save_{edit_scope}", + ): + if not label: + st.error("Please provide a scenario label") + return + + try: + update_scenario_in_experiment( + experiment, original_label, label, override_params + ) + st.success(f"Scenario '{label}' updated successfully!") + st.rerun() + except Exception as e: + st.error(f"Failed to update scenario: {str(e)}") + + +def update_scenario_in_experiment( + experiment: Experiment, + original_label: str, + new_label: str, + override_params: Dict[str, Any], +) -> None: + """Update an existing scenario in the experiment's data.json. + + Args: + experiment: The parent experiment + original_label: The scenario's current label (for lookup) + new_label: The new label (may be same as original) + override_params: Parameters that override global settings + + Raises: + ValueError: If new_label normalizes to the same value as another scenario + """ + from spkmc.models.scenario import Scenario as ScenarioModel + + exp_path = experiment.path + assert exp_path is not None + + data_file = exp_path / "data.json" + with open(data_file, "r", encoding="utf-8") as f: + data = json.load(f) + + # Check that label normalizes to a non-empty filename + new_norm = ScenarioModel.normalize_label(new_label) + if not new_norm: + raise ValueError( + f"Scenario label '{new_label}' normalizes to an empty filename. " + "Use a label with at least one alphanumeric character." + ) + + # Check for normalized label collision (excluding the scenario being edited) + for sc in data.get("scenarios", []): + if sc.get("label") == original_label: + continue + existing_norm = ScenarioModel.normalize_label(sc.get("label", "")) + if existing_norm == new_norm: + raise ValueError( + f"A scenario with a conflicting name already exists: '{sc.get('label')}' " + f"(both normalize to '{new_norm}')" + ) + + # Detect whether anything actually changed before writing. + scenarios_list = data.get("scenarios", []) + old_entry: Dict[str, Any] = {} + old_index = -1 + for i, s in enumerate(scenarios_list): + if s.get("label") == original_label: + old_entry = dict(s) + old_index = i + break + + old_norm = ScenarioModel.normalize_label(original_label) + label_changed = new_label != original_label + + # Compare *effective* parameters (globals + overrides) so legacy experiments + # that store full params in scenario entries don't trigger false positives. + global_params = data.get("parameters", {}) + meta_keys = {"label", "status"} + old_overrides = {k: v for k, v in old_entry.items() if k not in meta_keys} + effective_old = {**global_params, **old_overrides} + if global_params: + # Modern format: effective_new is globals + new overrides. + effective_new = {**global_params, **override_params} + else: + # Legacy format: override_params only contains values that differ from + # hardcoded form defaults — a SUBSET of the full param set. Start from + # effective_old so that keys present in old_overrides but matching the + # hardcoded defaults don't produce false-positive diffs. + effective_new = {**effective_old, **override_params} + params_changed = effective_old != effective_new + + if not label_changed and not params_changed: + return # No-op: nothing changed, preserve result files and data.json + + # Replace the scenario entry and mark as edited. + # When a global `parameters` block exists, store only overrides. + # For legacy experiments without globals, preserve the full parameter + # set so the scenario entry remains valid on reload. + if old_index >= 0: + if global_params: + scenarios_list[old_index] = { + "label": new_label, + "status": "edited", + **override_params, + } + else: + # Legacy format: keep all existing params, apply overrides on top + saved_entry = {k: v for k, v in old_entry.items() if k not in meta_keys} + scenarios_list[old_index] = { + "label": new_label, + "status": "edited", + **saved_entry, + **override_params, + } + + # Write back (atomic to prevent corruption on crash) + from spkmc.web import atomic_json_write + + atomic_json_write(data_file, data) + + # Delete stale result/analysis files since parameters or label changed. + old_result = exp_path / f"{old_norm}.json" + if old_result.exists(): + old_result.unlink() + old_analysis = exp_path / f"{old_norm}_analysis.md" + if old_analysis.exists(): + old_analysis.unlink() + + +def run_ai_analysis(experiment: Experiment) -> None: + """ + Launch subprocess-based AI analysis on an experiment. + + Args: + experiment: The experiment to analyze + """ + exp_path = experiment.path + assert exp_path is not None + + api_key = WebConfig.get_openai_api_key() + if not api_key: + st.error("OpenAI API key not found. Please set it in Preferences.") + return + + # Check there are completed scenarios + has_results = any( + (exp_path / f"{sc.normalized_label}.json").exists() for sc in experiment.scenarios + ) + if not has_results: + st.warning("No completed scenarios to analyze. Run some scenarios first.") + return + + config = st.session_state.config + model = config.get("ai_model", "gpt-4o-mini") + + if "analysis_runner" not in st.session_state: + st.session_state.analysis_runner = AnalysisRunner() + runner: AnalysisRunner = st.session_state.analysis_runner + + run_id = runner.run_experiment_analysis( + experiment_path=exp_path, + experiment_name=experiment.name, + experiment_description=experiment.description or "No description provided", + model=model, + api_key=api_key, + ) + + if run_id: + analysis_id = f"exp_analysis--{exp_path.name}" + SessionState.add_running_analysis( + analysis_id, + { + "experiment_name": exp_path.name, + "analysis_type": "experiment", + "scenario_normalized": "", + "run_id": run_id, + "status": "running", + }, + ) + st.toast("Analysis started...") + st.rerun() + + +def run_scenario_ai_analysis( + experiment: Experiment, + scenario: Scenario, + result_file: Path, +) -> None: + """ + Launch subprocess-based AI analysis on a single scenario. + + Args: + experiment: The parent experiment + scenario: The scenario to analyze + result_file: Path to the scenario's result JSON file + """ + exp_path = experiment.path + assert exp_path is not None + + api_key = WebConfig.get_openai_api_key() + if not api_key: + st.error("OpenAI API key not found. Please set it in Preferences.") + return + + config = st.session_state.config + model = config.get("ai_model", "gpt-4o-mini") + + if "analysis_runner" not in st.session_state: + st.session_state.analysis_runner = AnalysisRunner() + runner: AnalysisRunner = st.session_state.analysis_runner + + run_id = runner.run_scenario_analysis( + experiment_path=exp_path, + scenario_label=scenario.label, + scenario_normalized=scenario.normalized_label, + model=model, + api_key=api_key, + ) + + if run_id: + analysis_id = f"sc_analysis--{exp_path.name}--{scenario.normalized_label}" + SessionState.add_running_analysis( + analysis_id, + { + "experiment_name": exp_path.name, + "analysis_type": "scenario", + "scenario_normalized": scenario.normalized_label, + "run_id": run_id, + "status": "running", + }, + ) + st.toast("Scenario analysis started...") + st.rerun(scope="fragment") diff --git a/spkmc/web/pages/settings.py b/spkmc/web/pages/settings.py new file mode 100644 index 0000000..19b581e --- /dev/null +++ b/spkmc/web/pages/settings.py @@ -0,0 +1,476 @@ +""" +Settings page - configure web interface preferences and API keys. + +Manages OpenAI API keys, AI model selection, directory paths, +chart preferences, default simulation parameters, and export format. + +All changes auto-save on widget interaction (no Save button required). +""" + +from __future__ import annotations + +import textwrap + +import streamlit as st + +from spkmc.web.config import WebConfig +from spkmc.web.styles import COLORS, FONTS, page_header + + +def _dedent(html: str) -> str: + """Strip leading whitespace from HTML to prevent Markdown code-block rendering.""" + return textwrap.dedent(html).strip() + + +# ── SVG icons for section headers (Feather/Lucide style, 16x16) ── + +_ICON_AI = ( + '' +) + +_ICON_CHART = ( + '' + '' +) + +_ICON_SLIDERS = ( + '' + '' + '' +) + +_ICON_FOLDER = ( + '' +) + +_ICON_ALERT = ( + '' + '' + '' +) + + +# ── HTML helpers ─────────────────────────────────────────── + + +def _section_icon( + title: str, + subtitle: str, + icon_svg: str, + icon_bg: str = "", + icon_color: str = "", +) -> str: + """Create a section header with icon for the preferences page.""" + bg = icon_bg or COLORS["teal_100"] + color = icon_color or COLORS["teal_500"] + return _dedent(f""" +
+
{icon_svg}
+
+
{title}
+
{subtitle}
+
+
+""") + + +def _sublabel(title: str) -> str: + """Create a small uppercase subsection label inside a card.""" + return _dedent(f""" +
{title}
+""") + + +def _status_badge(configured: bool) -> str: + """Create an API key status badge.""" + if configured: + bg = COLORS["success_bg"] + color = COLORS["success"] + icon = ( + '' + ) + text = "Configured" + else: + bg = COLORS["warning_bg"] + color = COLORS["warning"] + icon = ( + '' + '' + '' + ) + text = "Not configured" + + return _dedent(f""" +
{icon} {text}
+""") + + +# ── Main render ──────────────────────────────────────────── + + +def render() -> None: + """Render the settings page. All values auto-save on change.""" + st.markdown( + page_header( + "Preferences", + subtitle="Configure web interface and simulation defaults", + ), + unsafe_allow_html=True, + ) + + config: WebConfig = st.session_state.config + + # Consume the post-reset flag so auto-save doesn't overwrite defaults. + skip_autosave = st.session_state.pop("pref_skip_autosave", False) + + # ── AI & Intelligence ───────────────────────────────── + st.markdown( + _section_icon( + "AI & Intelligence", + "API key and model for AI-powered analysis", + _ICON_AI, + ), + unsafe_allow_html=True, + ) + + with st.container(key="pref_card_ai"): + current_key = WebConfig.get_openai_api_key() + st.markdown(_status_badge(bool(current_key)), unsafe_allow_html=True) + + col_key, col_model = st.columns([3, 1]) + + with col_key: + new_key = st.text_input( + "API Key", + value=current_key or "", + type="password", + placeholder="sk-...", + help="Your OpenAI API key for AI analysis features", + ) + + with col_model: + model_options = [ + "gpt-4o-mini", + "gpt-4o", + "gpt-4.1-mini", + "gpt-4.1", + "o3-mini", + ] + current_model = config.get("ai_model", "gpt-4o-mini") + model_index = ( + model_options.index(current_model) if current_model in model_options else 0 + ) + selected_model = st.selectbox( + "AI Model", + options=model_options, + index=model_index, + help="OpenAI model used for AI analysis", + ) + + # ── Visualization ───────────────────────────────────── + st.markdown( + _section_icon( + "Visualization", + "Chart appearance and SIR state colors", + _ICON_CHART, + ), + unsafe_allow_html=True, + ) + + with st.container(key="pref_card_viz"): + col_chart, col_sep, col_colors = st.columns([10, 1, 5]) + + with col_chart: + sub_h, sub_t = st.columns(2) + + with sub_h: + chart_height = st.number_input( + "Default Height (px)", + min_value=300, + max_value=1000, + value=config.get("chart_height", 500), + step=50, + ) + + with sub_t: + template_options = [ + "plotly_white", + "plotly_dark", + "simple_white", + "ggplot2", + ] + current_template = config.get("chart_template", "plotly_white") + template_index = ( + template_options.index(current_template) + if current_template in template_options + else 0 + ) + chart_template = st.selectbox( + "Template", + options=template_options, + index=template_index, + ) + + with col_sep: + st.markdown( + _dedent(""" +
+
+
+"""), + unsafe_allow_html=True, + ) + + with col_colors: + sub_s, sub_i, sub_r = st.columns([1, 1, 1], gap="small") + with sub_s: + color_s = st.color_picker( + "Susceptible", + value=config.get("chart_color_s", "#4477AA"), + ) + with sub_i: + color_i = st.color_picker( + "Infected", + value=config.get("chart_color_i", "#EE6677"), + ) + with sub_r: + color_r = st.color_picker( + "Recovered", + value=config.get("chart_color_r", "#228833"), + ) + + # ── Simulation Defaults ─────────────────────────────── + st.markdown( + _section_icon( + "Simulation Defaults", + "Default parameters for new experiments", + _ICON_SLIDERS, + ), + unsafe_allow_html=True, + ) + + with st.container(key="pref_card_sim"): + col_net, col_dist, col_sim = st.columns(3) + + with col_net: + st.markdown(_sublabel("Network"), unsafe_allow_html=True) + + default_nodes = st.number_input( + "Nodes", + min_value=10, + max_value=100000, + value=config.get("default_nodes", 1000), + step=100, + ) + default_k_avg = st.number_input( + "k_avg", + min_value=1.0, + max_value=100.0, + value=float(config.get("default_k_avg", 10.0)), + step=1.0, + ) + default_exponent = st.number_input( + "Exponent (SF)", + min_value=2.0, + max_value=5.0, + value=float(config.get("default_exponent", 2.5)), + step=0.1, + ) + + with col_dist: + st.markdown(_sublabel("Distribution"), unsafe_allow_html=True) + + default_shape = st.number_input( + "Shape (Gamma)", + min_value=0.1, + max_value=10.0, + value=float(config.get("default_shape", 2.0)), + step=0.1, + ) + default_scale = st.number_input( + "Scale (Gamma)", + min_value=0.1, + max_value=10.0, + value=float(config.get("default_scale", 1.0)), + step=0.1, + ) + default_mu = st.number_input( + "\u03bc (Exponential)", + min_value=0.01, + max_value=10.0, + value=float(config.get("default_mu", 1.0)), + step=0.1, + ) + default_lambda = st.number_input( + "\u03bb", + min_value=0.01, + max_value=10.0, + value=float(config.get("default_lambda", 1.0)), + step=0.1, + ) + + with col_sim: + st.markdown(_sublabel("Simulation"), unsafe_allow_html=True) + + default_samples = st.number_input( + "Samples", + min_value=1, + max_value=10000, + value=config.get("default_samples", 50), + step=10, + ) + default_num_runs = st.number_input( + "Runs per scenario", + min_value=1, + max_value=100, + value=config.get("default_num_runs", 2), + step=1, + ) + default_t_max = st.number_input( + "t_max", + min_value=0.1, + max_value=1000.0, + value=float(config.get("default_t_max", 10.0)), + step=1.0, + ) + default_steps = st.number_input( + "Steps", + min_value=10, + max_value=10000, + value=config.get("default_steps", 100), + step=10, + ) + default_initial_perc = ( + st.number_input( + "Initial % infected", + min_value=0.01, + max_value=100.0, + value=float(config.get("default_initial_perc", 0.01)) * 100, + step=0.1, + ) + / 100.0 + ) + + # ── Storage & Export ────────────────────────────────── + st.markdown( + _section_icon( + "Storage & Export", + "Directory paths and default export format", + _ICON_FOLDER, + ), + unsafe_allow_html=True, + ) + + with st.container(key="pref_card_storage"): + st.markdown(_sublabel("Directories"), unsafe_allow_html=True) + col_data, col_exp = st.columns(2) + + with col_data: + data_dir = st.text_input( + "Data Directory", + value=config.get("data_directory", "data"), + help="Where simulation results are stored", + ) + + with col_exp: + experiments_dir = st.text_input( + "Experiments Directory", + value=config.get("experiments_directory", "experiments"), + help="Where experiment configurations are stored", + ) + + # ── Auto-save ───────────────────────────────────────── + + # Save API key when it changes (writes to secrets.toml). + # Rerun immediately to refresh the status badge. + old_key = current_key or "" + new_key_str = new_key or "" + if new_key_str != old_key: + try: + WebConfig.set_openai_api_key(new_key_str) + st.rerun() + except Exception as e: + st.error(f"Failed to save API key: {str(e)}") + + # Auto-save all other config values on every render pass, + # except the render immediately after a reset (to preserve defaults). + if not skip_autosave: + config.update( + { + "ai_model": selected_model, + "chart_height": chart_height, + "chart_template": chart_template, + "chart_color_s": color_s, + "chart_color_i": color_i, + "chart_color_r": color_r, + "default_nodes": default_nodes, + "default_k_avg": default_k_avg, + "default_exponent": default_exponent, + "default_shape": default_shape, + "default_scale": default_scale, + "default_mu": default_mu, + "default_lambda": default_lambda, + "default_samples": default_samples, + "default_num_runs": default_num_runs, + "default_t_max": default_t_max, + "default_steps": default_steps, + "default_initial_perc": default_initial_perc, + "data_directory": data_dir, + "experiments_directory": experiments_dir, + } + ) + + # ── Danger Zone ─────────────────────────────────────── + st.markdown( + _section_icon( + "Danger Zone", + "Irreversible actions", + _ICON_ALERT, + icon_bg=COLORS["error_bg"], + icon_color=COLORS["error"], + ), + unsafe_allow_html=True, + ) + + with st.container(key="pref_card_danger"): + col_text, col_btn = st.columns([3, 1]) + + with col_text: + st.markdown( + _dedent(f""" +
+Reset all preferences to their default values. This action cannot be undone. +
+"""), + unsafe_allow_html=True, + ) + + with col_btn: + with st.container(key="pref_reset"): + if st.button("Reset all", type="secondary"): + config.config = WebConfig.DEFAULTS.copy() + config.save() + # Skip auto-save on the next render so defaults are preserved. + st.session_state["pref_skip_autosave"] = True + st.rerun() diff --git a/spkmc/web/plotting.py b/spkmc/web/plotting.py new file mode 100644 index 0000000..26491ee --- /dev/null +++ b/spkmc/web/plotting.py @@ -0,0 +1,291 @@ +""" +Plotly figure builders for the web interface. + +Provides functions to create interactive Plotly charts for SIR simulation results +and comparisons. +""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +import numpy as np +import plotly.graph_objects as go +from plotly.subplots import make_subplots + +# SIR state colors (matching the plan) +COLOR_S = "#4477AA" +COLOR_I = "#EE6677" +COLOR_R = "#228833" + +STATE_COLORS = { + "S": COLOR_S, + "I": COLOR_I, + "R": COLOR_R, +} + + +def _hex_to_rgba(hex_color: str, alpha: float = 1.0) -> str: + """Convert a hex color string to an rgba() string.""" + hex_color = hex_color.lstrip("#") + r, g, b = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16) + return f"rgba({r}, {g}, {b}, {alpha})" + + +def create_sir_figure( + result_dict: Dict[str, Any], + title: str = "SIR Dynamics", + states: Optional[List[str]] = None, + show_error_bands: bool = True, + height: int = 500, + chart_mode: str = "lines", + state_colors: Optional[Dict[str, str]] = None, + template: str = "plotly_white", +) -> go.Figure: + """ + Create an interactive Plotly figure for a single SIR simulation result. + + Args: + result_dict: Result dictionary containing S_val, I_val, R_val, time, etc. + title: Plot title + states: List of states to plot (default: ['S', 'I', 'R']) + show_error_bands: Whether to show error bands (if S_err, I_err, R_err exist) + height: Figure height in pixels + chart_mode: One of "lines", "lines+markers", or "area" + state_colors: Override colors for S/I/R states (keys: "S", "I", "R") + template: Plotly template name + + Returns: + Plotly Figure object + """ + if states is None: + states = ["S", "I", "R"] + + effective_colors = {**STATE_COLORS, **(state_colors or {})} + + fig = go.Figure() + + time = result_dict.get("time", []) + has_errors = show_error_bands and "S_err" in result_dict + + for state in states: + state_upper = state.upper() + val_key = f"{state_upper}_val" + err_key = f"{state_upper}_err" + + if val_key not in result_dict: + continue + + y_val = result_dict[val_key] + color = effective_colors.get(state_upper, "#666666") + + # Determine trace mode and fill from chart_mode + if chart_mode == "lines+markers": + trace_mode = "lines+markers" + trace_fill = None + trace_fillcolor = None + elif chart_mode == "area": + trace_mode = "lines" + trace_fill = "tozeroy" + trace_fillcolor = _hex_to_rgba(color, 0.15) + else: + trace_mode = "lines" + trace_fill = None + trace_fillcolor = None + + # Build error_y config if applicable + error_y_config = None + if has_errors and err_key in result_dict: + y_err = result_dict[err_key] + error_y_config = dict( + type="data", + array=y_err, + visible=True, + color=color, + thickness=1.5, + width=4, + ) + + # Main line (with optional error bars attached) + fig.add_trace( + go.Scatter( + x=time, + y=y_val, + mode=trace_mode, + name=state_upper, + line=dict(color=color, width=2), + fill=trace_fill, + fillcolor=trace_fillcolor, + error_y=error_y_config, + hovertemplate=f"{state_upper}: %{{y:.4f}}
Time: %{{x:.2f}}", + ) + ) + + # Compute explicit x-axis range to prevent layout shift + # when error bars or markers add visual padding + x_max = float(max(time)) if len(time) > 0 else 1 + x_range = [0, x_max] + + # Layout + fig.update_layout( + title=dict(text=title, x=0.5, xanchor="center"), + xaxis=dict( + title="Time", + showgrid=True, + gridcolor="rgba(0,0,0,0.1)", + range=x_range, + ), + yaxis=dict( + title="Proportion of Population", + showgrid=True, + gridcolor="rgba(0,0,0,0.1)", + range=[0, 1], + ), + template=template, + height=height, + hovermode="x unified", + legend=dict( + orientation="h", + yanchor="bottom", + y=1.02, + xanchor="right", + x=1, + ), + ) + + return fig + + +def create_comparison_figure( + results: List[Dict[str, Any]], + labels: List[str], + title: str = "Scenario Comparison", + states: Optional[List[str]] = None, + height: int = 600, + template: str = "plotly_white", +) -> go.Figure: + """ + Create an interactive Plotly figure comparing multiple SIR simulation results. + + Args: + results: List of result dictionaries + labels: List of labels for each result + title: Plot title + states: List of states to plot (default: ['S', 'I', 'R']) + height: Figure height in pixels + template: Plotly template name + + Returns: + Plotly Figure object + """ + if states is None: + states = ["S", "I", "R"] + + fig = go.Figure() + + # Color palette for different scenarios (cycling if needed) + scenario_colors = [ + "#1f77b4", + "#ff7f0e", + "#2ca02c", + "#d62728", + "#9467bd", + "#8c564b", + "#e377c2", + "#7f7f7f", + "#bcbd22", + "#17becf", + ] + + for idx, (result_dict, label) in enumerate(zip(results, labels)): + time = result_dict.get("time", []) + base_color = scenario_colors[idx % len(scenario_colors)] + + for state in states: + state_upper = state.upper() + val_key = f"{state_upper}_val" + + if val_key not in result_dict: + continue + + y_val = result_dict[val_key] + + # Use different line styles for different states + line_dash = "solid" if state_upper == "I" else "dot" if state_upper == "S" else "dash" + + fig.add_trace( + go.Scatter( + x=time, + y=y_val, + mode="lines", + name=f"{label} - {state_upper}", + line=dict(color=base_color, width=2, dash=line_dash), + hovertemplate=f"{label} - {state_upper}: %{{y:.4f}}
Time: %{{x:.2f}}", + ) + ) + + # Layout + fig.update_layout( + title=dict(text=title, x=0.5, xanchor="center"), + xaxis=dict( + title="Time", + showgrid=True, + gridcolor="rgba(0,0,0,0.1)", + ), + yaxis=dict( + title="Proportion of Population", + showgrid=True, + gridcolor="rgba(0,0,0,0.1)", + range=[0, 1], + ), + template=template, + height=height, + hovermode="x unified", + legend=dict( + orientation="v", + yanchor="top", + y=1, + xanchor="left", + x=1.02, + ), + ) + + return fig + + +def create_metric_card_figure( + value: float, + title: str, + subtitle: str = "", + color: str = "#4477AA", +) -> go.Figure: + """ + Create a simple metric card figure for displaying key statistics. + + Args: + value: The metric value to display + title: Metric title + subtitle: Optional subtitle/description + color: Color for the metric value + + Returns: + Plotly Figure object (minimal chart used as a card) + """ + fig = go.Figure() + + # Create a minimal figure that just displays the metric + fig.add_trace( + go.Indicator( + mode="number", + value=value, + title={"text": f"{title}
{subtitle}"}, + number={"font": {"size": 48, "color": color}}, + ) + ) + + fig.update_layout( + height=150, + margin=dict(l=20, r=20, t=40, b=20), + ) + + return fig diff --git a/spkmc/web/runner.py b/spkmc/web/runner.py new file mode 100644 index 0000000..87cc346 --- /dev/null +++ b/spkmc/web/runner.py @@ -0,0 +1,442 @@ +""" +Subprocess-based simulation runner for the web interface. + +Runs simulations in background subprocesses so they survive browser refresh +and UI interactions. +""" + +from __future__ import annotations + +import json +import subprocess +import sys +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, cast + +import psutil +import streamlit as st + +from spkmc.models import Experiment, Scenario + + +class SimulationRunner: + """Manages subprocess-based simulation execution.""" + + def __init__(self) -> None: + """Initialize the simulation runner.""" + self.status_dir = Path(".spkmc_web") / "status" + self.status_dir.mkdir(parents=True, exist_ok=True) + # Retain Popen handles so we can reap children and avoid zombies + self._processes: Dict[str, subprocess.Popen] = {} # type: ignore[type-arg] + + def run_scenario( + self, experiment: Experiment, scenario: Scenario, show_progress: bool = True + ) -> Optional[str]: + """ + Launch a subprocess to run a single scenario. + + Args: + experiment: The parent experiment + scenario: Scenario to execute + show_progress: Whether to show progress indicators + + Returns: + Subprocess ID if launched successfully, None otherwise + """ + assert experiment.path is not None, "Experiment must have a path to run scenarios" + + # Generate unique ID for this run + run_id = f"sim--{experiment.path.name}--{scenario.normalized_label}--{time.time_ns()}" + + # Create status file + status_file = self.status_dir / f"{run_id}.json" + status_data = { + "run_id": run_id, + "experiment_name": experiment.path.name, + "scenario_label": scenario.label, + "scenario_normalized": scenario.normalized_label, + "status": "starting", + "progress": 0, + "total": scenario.total_samples(), + "start_time": time.time(), + } + + with open(status_file, "w") as f: + json.dump(status_data, f) + + # Build command to execute scenario + # We'll create a simple Python script that calls execute_scenario + script_content = self._build_execution_script(experiment, scenario, run_id) + + # Write temporary script + script_file = self.status_dir / f"{run_id}_script.py" + with open(script_file, "w") as f: + f.write(script_content) + + # Launch subprocess + try: + process = subprocess.Popen( + [sys.executable, str(script_file)], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + # Store process info and retain handle for reaping + status_data["status"] = "running" + status_data["pid"] = process.pid + self._processes[run_id] = process + + with open(status_file, "w") as f: + json.dump(status_data, f) + + if show_progress: + st.toast(f"Started: {scenario.label}") + + return run_id + + except Exception as e: + # Mark as failed + status_data["status"] = "failed" + status_data["error"] = str(e) + + with open(status_file, "w") as f: + json.dump(status_data, f) + + st.error(f"Failed to start simulation: {str(e)}") + return None + + def run_all_scenarios(self, experiment: Experiment, show_progress: bool = True) -> List[str]: + """ + Launch subprocesses to run all scenarios in an experiment. + + Args: + experiment: The experiment to run + show_progress: Whether to show progress indicators + + Returns: + List of run IDs for all launched simulations + """ + assert experiment.path is not None, "Experiment must have a path to run scenarios" + + run_ids = [] + + for scenario in experiment.scenarios: + # Skip if already has results + result_file = experiment.path / f"{scenario.normalized_label}.json" + if result_file.exists(): + continue + + run_id = self.run_scenario(experiment, scenario, show_progress=False) + if run_id: + run_ids.append(run_id) + + if show_progress and run_ids: + st.toast(f"Started {len(run_ids)} simulations") + + return run_ids + + def get_status(self, run_id: str) -> Optional[Dict[str, Any]]: + """ + Get the status of a running or completed simulation. + + Args: + run_id: The run ID to check + + Returns: + Status dictionary or None if not found + """ + status_file = self.status_dir / f"{run_id}.json" + + if not status_file.exists(): + return None + + try: + with open(status_file, "r") as f: + return cast(Dict[str, Any], json.load(f)) + except (json.JSONDecodeError, IOError): + return None + + def is_running(self, run_id: str) -> bool: + """ + Check if a simulation is still running. + + Args: + run_id: The run ID to check + + Returns: + True if running, False otherwise + """ + status = self.get_status(run_id) + return status is not None and status.get("status") == "running" + + def check_completion(self, experiment_name: str, scenario_label: str) -> bool: + """ + Check if a scenario has completed by looking for its result file. + + Args: + experiment_name: Name of the experiment + scenario_label: Label of the scenario + + Returns: + True if result file exists, False otherwise + """ + from spkmc.models.scenario import Scenario + from spkmc.web.config import WebConfig + + normalized = Scenario.normalize_label(scenario_label) + config = WebConfig() + exp_path = config.get_experiments_path() / experiment_name + result_file = exp_path / f"{normalized}.json" + + return result_file.exists() + + def cleanup_status(self, run_id: str) -> None: + """ + Clean up status files and reap the child process for a completed run. + + Args: + run_id: The run ID to clean up + """ + # Reap child process to prevent zombies + proc = self._processes.pop(run_id, None) + if proc is not None: + proc.poll() # Non-blocking reap + + status_file = self.status_dir / f"{run_id}.json" + script_file = self.status_dir / f"{run_id}_script.py" + + if status_file.exists(): + status_file.unlink() + if script_file.exists(): + script_file.unlink() + + def _build_execution_script( + self, experiment: Experiment, scenario: Scenario, run_id: str + ) -> str: + """ + Build a Python script to execute a scenario. + + Args: + experiment: The parent experiment + scenario: Scenario to execute + run_id: Unique run identifier (matches the status file name) + + Returns: + Python script as a string + """ + assert experiment.path is not None, "Experiment must have a path to build script" + + # Pass the exact status file path so the subprocess doesn't need + # to discover it via glob (which is ambiguous for prefix-overlapping labels). + status_file_repr = repr(str(self.status_dir / f"{run_id}.json")) + + script = f""" +import sys +import json +import os +import time +from pathlib import Path + +# Add package to path if needed +sys.path.insert(0, str(Path.cwd())) + +from spkmc.core.engine import ExecutionContext, ExecutionEngine +from spkmc.models import Scenario + +# Exact status file path (set by runner before launching subprocess) +STATUS_FILE = {status_file_repr} + +_progress_count = 0 +_last_write = 0.0 + +def _progress_callback(completed): + global _progress_count, _last_write + _progress_count += completed + now = time.time() + if now - _last_write >= 0.5: + _last_write = now + _write_progress(_progress_count, "running") + +def _write_progress(progress, status, error=None): + if STATUS_FILE is None: + return + try: + with open(STATUS_FILE, "r") as f: + data = json.load(f) + data["progress"] = progress + data["status"] = status + if error: + data["error"] = error + tmp = STATUS_FILE + ".tmp" + with open(tmp, "w") as f: + json.dump(data, f) + os.replace(tmp, STATUS_FILE) + except Exception: + pass + +# Load scenario from experiment +experiment_path = Path({repr(str(experiment.path))}) +scenario_data = {repr(scenario.model_dump_json())} +scenario = Scenario.model_validate_json(scenario_data) + +# Set experiment context +scenario.experiment_name = {repr(experiment.path.name)} +scenario.output_path = str(experiment_path / {repr(scenario.normalized_label + '.json')}) + +# Create execution context with progress callback +context = ExecutionContext( + scenarios=[scenario], + experiment_name={repr(experiment.path.name)}, + results_dir=experiment_path, + no_plot=True, + export_format="json", + on_sample_progress=_progress_callback, +) + +# Execute +engine = ExecutionEngine(verbose=False) +try: + results = engine.execute(context) + _write_progress(scenario.total_samples(), "completed") + print("Execution completed successfully") + sys.exit(0) +except Exception as e: + _write_progress(_progress_count, "failed", error=str(e)) + print(f"Execution failed: {{e}}", file=sys.stderr) + sys.exit(1) +""" + return script + + def get_progress(self, run_id: str) -> Optional[tuple]: + """ + Get progress for a running simulation. + + Args: + run_id: The run ID to check + + Returns: + (progress, total) tuple or None if not available + """ + status = self.get_status(run_id) + if status is None: + return None + progress = status.get("progress", 0) + total = status.get("total", 0) + return (progress, total) + + +def _settle_scenario_backups(exp_name: str, scenario_label: str, succeeded: bool) -> None: + """Clean up or restore ``.bak`` artifacts after a simulation terminates. + + On success the backups are stale and can be removed. On failure the + backups are restored so the user retains the previous successful result. + """ + from spkmc.models.scenario import Scenario as ScenarioModel + from spkmc.web.config import WebConfig + + config = WebConfig() + exp_path = config.get_experiments_path() / exp_name + normalized = ScenarioModel.normalize_label(scenario_label) + + result_bak = exp_path / f"{normalized}.json.bak" + analysis_bak = exp_path / f"{normalized}_analysis.md.bak" + + if succeeded: + result_bak.unlink(missing_ok=True) + analysis_bak.unlink(missing_ok=True) + else: + result_file = exp_path / f"{normalized}.json" + analysis_file = exp_path / f"{normalized}_analysis.md" + if result_bak.exists() and not result_file.exists(): + result_bak.rename(result_file) + else: + result_bak.unlink(missing_ok=True) + if analysis_bak.exists() and not analysis_file.exists(): + analysis_bak.rename(analysis_file) + else: + analysis_bak.unlink(missing_ok=True) + + +def poll_running_simulations() -> None: + """ + Poll all running simulations and update session state. + + Reads progress from status files and marks completed/failed simulations. + Called by the scenario cards fragment every ~2 seconds. + """ + if "simulation_runner" not in st.session_state: + st.session_state.simulation_runner = SimulationRunner() + + runner: SimulationRunner = st.session_state.simulation_runner + + from spkmc.web.state import SessionState + + running_sims = st.session_state.get("running_simulations", {}) + + # Dict is keyed by scenario_id; run_id stored inside info + for scenario_id, info in list(running_sims.items()): + exp_name = info.get("experiment_name") + scenario_label = info.get("scenario_label") + run_id = info.get("run_id", scenario_id) + + if not (exp_name and scenario_label): + continue + + # Read status file for progress + status = runner.get_status(run_id) + if status: + progress = status.get("progress", 0) + total = status.get("total", 0) + file_status = status.get("status", "running") + + # Update progress in session state + if total > 0: + SessionState.set_simulation_progress(scenario_id, progress, total) + + # Check if status file reports completion + if file_status == "completed" or runner.check_completion(exp_name, scenario_label): + SessionState.mark_simulation_completed(scenario_id) + SessionState.clear_simulation_progress(scenario_id) + _settle_scenario_backups(exp_name, scenario_label, succeeded=True) + st.toast(f"Completed: {scenario_label}") + runner.cleanup_status(run_id) + continue + + # Check if status file reports failure + if file_status == "failed": + error_msg = status.get("error", "Unknown error") + SessionState.mark_simulation_failed(scenario_id, error_msg) + SessionState.clear_simulation_progress(scenario_id) + _settle_scenario_backups(exp_name, scenario_label, succeeded=False) + st.toast(f"Failed: {scenario_label}") + runner.cleanup_status(run_id) + continue + + # Check if subprocess died without writing terminal status + if file_status == "running": + pid = status.get("pid") + if pid is not None and not psutil.pid_exists(pid): + # Process no longer exists -- check if output was written + completed = runner.check_completion(exp_name, scenario_label) + if completed: + SessionState.mark_simulation_completed(scenario_id) + st.toast(f"Completed: {scenario_label}") + else: + SessionState.mark_simulation_failed( + scenario_id, "Process exited unexpectedly" + ) + st.toast(f"Failed: {scenario_label}") + SessionState.clear_simulation_progress(scenario_id) + _settle_scenario_backups(exp_name, scenario_label, succeeded=completed) + runner.cleanup_status(run_id) + continue + + # Fallback: check result file directly + elif runner.check_completion(exp_name, scenario_label): + SessionState.mark_simulation_completed(scenario_id) + SessionState.clear_simulation_progress(scenario_id) + _settle_scenario_backups(exp_name, scenario_label, succeeded=True) + st.toast(f"Completed: {scenario_label}") + runner.cleanup_status(run_id) diff --git a/spkmc/web/state.py b/spkmc/web/state.py new file mode 100644 index 0000000..4464ad4 --- /dev/null +++ b/spkmc/web/state.py @@ -0,0 +1,426 @@ +""" +Session state management for the web interface. + +Provides typed access to Streamlit's session state and initialization of +session-level variables. +""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, Dict, Optional, Set, cast + +import psutil +import streamlit as st + + +class SessionState: + """Typed accessor for Streamlit session state.""" + + @staticmethod + def init() -> None: + """Initialize session state with default values.""" + if "initialized" not in st.session_state: + # Navigation state — restore from URL query params if present + st.session_state.current_page = st.query_params.get("page", "dashboard") + st.session_state.selected_experiment = None + + # UI state + st.session_state.selected_scenarios = set() + st.session_state.show_comparison_modal = False + st.session_state.show_scenario_detail_modal = False + st.session_state.selected_scenario_id = None + + # Simulation state + st.session_state.running_simulations = {} # Dict[str, subprocess_info] + st.session_state.completed_simulations = set() # Set[str] + st.session_state.failed_simulations = {} # Dict[str, error_message] + st.session_state.simulation_progress = {} # Dict[str, progress_info] + + # Analysis state (parallel to simulation state) + st.session_state.running_analyses = {} # Dict[str, subprocess_info] + st.session_state.completed_analyses = set() # Set[str] + st.session_state.failed_analyses = {} # Dict[str, error_message] + + # Form state + st.session_state.creating_experiment = False + st.session_state.creating_scenario = False + + # Mark as initialized + st.session_state.initialized = True + + @staticmethod + def get_current_page() -> str: + """Get the current page name.""" + return cast(str, st.session_state.get("current_page", "dashboard")) + + @staticmethod + def set_current_page(page: str) -> None: + """Set the current page and sync to URL query params for refresh persistence.""" + st.session_state.current_page = page + st.query_params["page"] = page + + @staticmethod + def get_selected_experiment() -> Optional[str]: + """Get the currently selected experiment name. + + Falls back to st.query_params to survive page refresh. + """ + name: Optional[str] = cast(Optional[str], st.session_state.get("selected_experiment", None)) + if name is None: + name = cast(Optional[str], st.query_params.get("experiment", None)) + if name: + st.session_state.selected_experiment = name + return name + + @staticmethod + def set_selected_experiment(experiment_name: Optional[str]) -> None: + """Set the currently selected experiment. + + Also syncs to st.query_params so the selection survives refresh. + """ + st.session_state.selected_experiment = experiment_name + # Clear scenario selections and stale UI flags when switching experiments + st.session_state.selected_scenarios = set() + st.session_state.show_comparison_modal = False + st.session_state.show_scenario_detail_modal = False + st.session_state.selected_scenario_id = None + # Sync to query params for refresh persistence + if experiment_name: + st.query_params["experiment"] = experiment_name + else: + st.query_params.pop("experiment", None) + + @staticmethod + def get_selected_scenarios() -> Set[str]: + """Get the set of selected scenario IDs.""" + return cast(Set[str], st.session_state.get("selected_scenarios", set())) + + @staticmethod + def toggle_scenario_selection(scenario_id: str) -> None: + """Toggle a scenario's selection state.""" + selected = st.session_state.get("selected_scenarios", set()) + if scenario_id in selected: + selected.remove(scenario_id) + else: + selected.add(scenario_id) + st.session_state.selected_scenarios = selected + + @staticmethod + def clear_scenario_selections() -> None: + """Clear all scenario selections.""" + st.session_state.selected_scenarios = set() + + @staticmethod + def is_simulation_running(simulation_id: str) -> bool: + """Check if a simulation is currently running.""" + running = st.session_state.get("running_simulations", {}) + return simulation_id in running + + @staticmethod + def add_running_simulation(simulation_id: str, info: Dict[str, Any]) -> None: + """Add a simulation to the running set.""" + if "running_simulations" not in st.session_state: + st.session_state.running_simulations = {} + st.session_state.running_simulations[simulation_id] = info + + @staticmethod + def remove_running_simulation(simulation_id: str) -> None: + """Remove a simulation from the running set.""" + running = st.session_state.get("running_simulations", {}) + if simulation_id in running: + del running[simulation_id] + + @staticmethod + def mark_simulation_completed(simulation_id: str) -> None: + """Mark a simulation as completed.""" + if "completed_simulations" not in st.session_state: + st.session_state.completed_simulations = set() + st.session_state.completed_simulations.add(simulation_id) + # Clear opposite terminal state so reruns don't get stale status + st.session_state.get("failed_simulations", {}).pop(simulation_id, None) + SessionState.remove_running_simulation(simulation_id) + + @staticmethod + def mark_simulation_failed(simulation_id: str, error_message: str) -> None: + """Mark a simulation as failed.""" + if "failed_simulations" not in st.session_state: + st.session_state.failed_simulations = {} + st.session_state.failed_simulations[simulation_id] = error_message + # Clear opposite terminal state so reruns don't get stale status + st.session_state.get("completed_simulations", set()).discard(simulation_id) + SessionState.remove_running_simulation(simulation_id) + + @staticmethod + def get_simulation_status(simulation_id: str) -> str: + """ + Get the status of a simulation. + + Returns: + One of: 'pending', 'running', 'completed', 'failed' + """ + if SessionState.is_simulation_running(simulation_id): + return "running" + # Check failed before completed — a rerun failure must not be masked + # by a stale completion from a previous run. + if simulation_id in st.session_state.get("failed_simulations", {}): + return "failed" + if simulation_id in st.session_state.get("completed_simulations", set()): + return "completed" + return "pending" + + @staticmethod + def set_simulation_progress(sim_id: str, progress: int, total: int) -> None: + """Store progress for a running simulation.""" + if "simulation_progress" not in st.session_state: + st.session_state.simulation_progress = {} + st.session_state.simulation_progress[sim_id] = { + "progress": progress, + "total": total, + } + + @staticmethod + def get_simulation_progress(sim_id: str) -> Optional[Dict[str, int]]: + """Return progress dict or None if not tracked.""" + return cast( + Optional[Dict[str, int]], + st.session_state.get("simulation_progress", {}).get(sim_id), + ) + + @staticmethod + def clear_simulation_progress(sim_id: str) -> None: + """Remove progress tracking for a completed simulation.""" + progress = st.session_state.get("simulation_progress", {}) + progress.pop(sim_id, None) + + # ── Analysis tracking ────────────────────────────────────── + + @staticmethod + def is_analysis_running(analysis_id: str) -> bool: + """Check if an analysis is currently running.""" + running = st.session_state.get("running_analyses", {}) + return analysis_id in running + + @staticmethod + def add_running_analysis(analysis_id: str, info: Dict[str, Any]) -> None: + """Add an analysis to the running set.""" + if "running_analyses" not in st.session_state: + st.session_state.running_analyses = {} + st.session_state.running_analyses[analysis_id] = info + + @staticmethod + def remove_running_analysis(analysis_id: str) -> None: + """Remove an analysis from the running set.""" + running = st.session_state.get("running_analyses", {}) + if analysis_id in running: + del running[analysis_id] + + @staticmethod + def mark_analysis_completed(analysis_id: str) -> None: + """Mark an analysis as completed.""" + if "completed_analyses" not in st.session_state: + st.session_state.completed_analyses = set() + st.session_state.completed_analyses.add(analysis_id) + # Clear opposite terminal state so reruns don't get stale status + st.session_state.get("failed_analyses", {}).pop(analysis_id, None) + SessionState.remove_running_analysis(analysis_id) + + @staticmethod + def mark_analysis_failed(analysis_id: str, error_message: str) -> None: + """Mark an analysis as failed.""" + if "failed_analyses" not in st.session_state: + st.session_state.failed_analyses = {} + st.session_state.failed_analyses[analysis_id] = error_message + # Clear opposite terminal state so reruns don't get stale status + st.session_state.get("completed_analyses", set()).discard(analysis_id) + SessionState.remove_running_analysis(analysis_id) + + @staticmethod + def get_analysis_status(analysis_id: str) -> str: + """ + Get the status of an analysis. + + Returns: + One of: 'pending', 'running', 'completed', 'failed' + """ + if SessionState.is_analysis_running(analysis_id): + return "running" + # Check failed before completed — a rerun failure must not be masked + # by a stale completion from a previous run. + if analysis_id in st.session_state.get("failed_analyses", {}): + return "failed" + if analysis_id in st.session_state.get("completed_analyses", set()): + return "completed" + return "pending" + + @staticmethod + def restore_running_analyses() -> None: + """Restore running analyses from status files on disk. + + Scans .spkmc_web/status/ for analysis status files, verifies the PID + is still alive, and adds them back to session state. Called once on + session init to survive page refresh. + """ + status_dir = Path(".spkmc_web") / "status" + if not status_dir.exists(): + return + + for status_file in sorted( + list(status_dir.glob("exp_analysis--*.json")) + + list(status_dir.glob("sc_analysis--*.json")) + ): + + try: + with open(status_file, "r") as f: + data = json.load(f) + except (json.JSONDecodeError, IOError): + continue + + # Only process analysis-type status files + if data.get("type") != "analysis": + continue + + file_status = data.get("status") + if file_status not in ("running", "starting"): + continue + + pid = data.get("pid") + if not pid: + continue + + # Check if process is still alive + if not _is_pid_alive(pid): + # Process died -- check if it completed by looking for result file + exp_name = data.get("experiment_name", "") + analysis_type = data.get("analysis_type", "") + sc_normalized = data.get("scenario_normalized", "") + + if analysis_type == "experiment": + analysis_id = f"exp_analysis--{exp_name}" + else: + analysis_id = f"sc_analysis--{exp_name}--{sc_normalized}" + + # Check if analysis file was actually written + from spkmc.web.config import WebConfig + + config = WebConfig() + exp_path = config.get_experiments_path() / exp_name + + if analysis_type == "experiment": + result_exists = (exp_path / "analysis.md").exists() + else: + result_exists = (exp_path / f"{sc_normalized}_analysis.md").exists() + + if result_exists: + SessionState.mark_analysis_completed(analysis_id) + else: + SessionState.mark_analysis_failed(analysis_id, "Process exited unexpectedly") + + # Clean up stale status file + run_id = data.get("run_id", "") + status_file.unlink(missing_ok=True) + script_file = status_dir / f"{run_id}_script.py" + if script_file.exists(): + script_file.unlink(missing_ok=True) + continue + + # Process is alive -- restore into session state + exp_name = data.get("experiment_name", "") + analysis_type = data.get("analysis_type", "") + sc_normalized = data.get("scenario_normalized", "") + + if analysis_type == "experiment": + analysis_id = f"exp_analysis--{exp_name}" + else: + analysis_id = f"sc_analysis--{exp_name}--{sc_normalized}" + + run_id = data.get("run_id", analysis_id) + info = { + "experiment_name": exp_name, + "analysis_type": analysis_type, + "scenario_normalized": sc_normalized, + "run_id": run_id, + "status": "running", + "pid": pid, + } + SessionState.add_running_analysis(analysis_id, info) + + @staticmethod + def restore_running_simulations() -> None: + """Restore running simulations from status files on disk. + + Scans .spkmc_web/status/ for status files with running processes, + verifies the PID is still alive, and adds them back to session state. + Called once on session init to survive page refresh. + """ + status_dir = Path(".spkmc_web") / "status" + if not status_dir.exists(): + return + + for status_file in status_dir.glob("sim--*.json"): + try: + with open(status_file, "r") as f: + data = json.load(f) + except (json.JSONDecodeError, IOError): + continue + + file_status = data.get("status") + if file_status not in ("running", "starting"): + continue + + pid = data.get("pid") + if not pid: + continue + + # Check if process is still alive + if not _is_pid_alive(pid): + # Process died -- check if it completed by looking for result + exp_name = data.get("experiment_name", "") + sc_normalized = data.get("scenario_normalized", "") + if exp_name and sc_normalized: + from spkmc.web.config import WebConfig + + config = WebConfig() + result_path = config.get_experiments_path() / exp_name / f"{sc_normalized}.json" + scenario_id = f"sim--{exp_name}--{sc_normalized}" + if result_path.exists(): + SessionState.mark_simulation_completed(scenario_id) + else: + SessionState.mark_simulation_failed( + scenario_id, "Process exited unexpectedly" + ) + # Clean up stale status file + run_id = data.get("run_id", "") + status_file.unlink(missing_ok=True) + script_file = status_dir / f"{run_id}_script.py" + if script_file.exists(): + script_file.unlink(missing_ok=True) + continue + + # Process is alive -- restore into session state + exp_name = data.get("experiment_name", "") + sc_normalized = data.get("scenario_normalized", "") + scenario_id = f"sim--{exp_name}--{sc_normalized}" + run_id = data.get("run_id", scenario_id) + + info = { + "experiment_name": exp_name, + "scenario_label": data.get("scenario_label", ""), + "scenario_normalized": sc_normalized, + "run_id": run_id, + "status": "running", + "pid": pid, + } + SessionState.add_running_simulation(scenario_id, info) + + # Restore progress + progress = data.get("progress", 0) + total = data.get("total", 0) + if total > 0: + SessionState.set_simulation_progress(scenario_id, progress, total) + + +def _is_pid_alive(pid: int) -> bool: + """Check if a process with the given PID is still running.""" + return psutil.pid_exists(pid) diff --git a/spkmc/web/styles.py b/spkmc/web/styles.py new file mode 100644 index 0000000..b541848 --- /dev/null +++ b/spkmc/web/styles.py @@ -0,0 +1,1440 @@ +""" +Design system for SPKMC web interface. + +Clean, professional aesthetic inspired by modern SaaS dashboards. +Soft teal accents, dark sidebar, generous whitespace, refined typography. +""" + +import base64 +import textwrap + +# SVG Icons - Simple, professional stroke icons (Feather/Lucide style) +ICONS = { + "flask": ( + '' + ), + "file": ( + '' + ), + "check": ( + '' + ), + "clock": ( + '' + '' + ), + "settings": ( + '' + '' + ), +} + +# Refined color palette - soft teals and clean neutrals +COLORS = { + # Primary palette - soft teal/sage + "teal_700": "#1E5F55", + "teal_600": "#2D7A6E", + "teal_500": "#4A9E8E", + "teal_400": "#5FB5A6", + "teal_300": "#8ECFC3", + "teal_100": "#E8F5F3", + "teal_50": "#F0F9F7", + # Neutrals - clean grays + "gray_950": "#0B0F19", + "gray_900": "#111827", + "gray_800": "#1F2937", + "gray_700": "#374151", + "gray_600": "#4B5563", + "gray_500": "#6B7280", + "gray_400": "#9CA3AF", + "gray_300": "#D1D5DB", + "gray_200": "#E5E7EB", + "gray_100": "#F3F4F6", + "gray_50": "#F9FAFB", + # Background + "bg_primary": "#F7F8FA", + "bg_secondary": "#FAFBFC", + # White + "white": "#FFFFFF", + # Status colors - muted and professional + "success": "#10B981", + "success_bg": "#D1FAE5", + "warning": "#F59E0B", + "warning_bg": "#FEF3C7", + "error": "#EF4444", + "error_bg": "#FEE2E2", + "info": "#3B82F6", + "info_bg": "#DBEAFE", +} + +FONTS = { + "body": "'Plus Jakarta Sans', 'DM Sans', -apple-system, BlinkMacSystemFont, sans-serif", + "mono": "'JetBrains Mono', 'Fira Code', 'Courier New', monospace", +} + + +def _dedent(html: str) -> str: + """Strip leading whitespace from HTML to prevent Markdown code-block rendering.""" + return textwrap.dedent(html).strip() + + +def _svg_data_uri(svg: str) -> str: + """Convert raw SVG string to a CSS-safe base64 data URI.""" + encoded = base64.b64encode(svg.encode("utf-8")).decode("ascii") + return f'url("data:image/svg+xml;base64,{encoded}")' + + +def get_global_styles() -> str: + """ + Returns comprehensive CSS for the entire application. + Clean, professional aesthetic with soft teal accents and dark sidebar. + """ + # SVG data URIs for sidebar nav icons + experiments_icon_svg = ( + '' + '' + ) + experiments_icon_active_svg = ( + '' + '' + ) + settings_icon_svg = ( + '' + '' + '' + ) + settings_icon_active_svg = settings_icon_svg.replace( + 'stroke="rgba(255,255,255,0.5)"', 'stroke="#5FB5A6"' + ) + + exp_icon = _svg_data_uri(experiments_icon_svg) + exp_icon_active = _svg_data_uri(experiments_icon_active_svg) + set_icon = _svg_data_uri(settings_icon_svg) + set_icon_active = _svg_data_uri(settings_icon_active_svg) + + return _dedent(f""" + +""") + + +def stat_card(label: str, value: str, icon_svg: str = "") -> str: + """Create a clean stat card with minimal teal accent.""" + icon_html = "" + if icon_svg: + icon_html = ( + f'
' + f"{icon_svg}
" + ) + + return _dedent(f""" +
+
+{icon_html} +
{label}
+
+
{value}
+
+""") + + +def experiment_card( + name: str, + description: str, + scenarios_complete: int, + scenarios_total: int, + last_run: str, + status: str = "pending", +) -> str: + """Create a clean experiment card with subtle hover-ready styling.""" + status_colors = { + "pending": COLORS["gray_600"], + "running": COLORS["info"], + "complete": COLORS["success"], + "failed": COLORS["error"], + } + + status_bg = { + "pending": COLORS["gray_100"], + "running": COLORS["info_bg"], + "complete": COLORS["success_bg"], + "failed": COLORS["error_bg"], + } + + progress = (scenarios_complete / scenarios_total * 100) if scenarios_total > 0 else 0 + + return _dedent(f""" +
+
+
{name}
+
{status}
+
+
{description}
+
+
+
+
+
{scenarios_complete}/{scenarios_total} scenarios
+
{last_run}
+
+
+""") + + +def page_header(title: str, subtitle: str = "") -> str: + """Create a clean page header with proper hierarchy.""" + sub = "" + if subtitle: + sub = ( + f'

{subtitle}

' + ) + + return _dedent(f""" +
+

{title}

+{sub} +
+""") + + +def empty_state(title: str, message: str) -> str: + """Create a clean empty state with centered content.""" + return _dedent(f""" +
+
+ +
+

{title}

+

{message}

+
+""") + + +def scenario_card( + label: str, + override_text: str, + status: str = "created", + progress: float = -1.0, +) -> str: + """Create a scenario card with label, override summary, and status badge. + + Args: + label: Scenario display name + override_text: Summary of overridden parameters + status: One of 'created', 'pending', 'running', 'completed', 'failed' + progress: Progress fraction 0.0-1.0 when running, -1.0 for no bar + """ + status_colors = { + "created": (COLORS["teal_500"], COLORS["teal_100"]), + "edited": (COLORS["teal_500"], COLORS["teal_100"]), + "pending": (COLORS["gray_600"], COLORS["gray_100"]), + "running": (COLORS["info"], COLORS["info_bg"]), + "completed": (COLORS["success"], COLORS["success_bg"]), + "failed": (COLORS["error"], COLORS["error_bg"]), + } + + s_color, s_bg = status_colors.get(status, (COLORS["gray_600"], COLORS["gray_100"])) + s_text = status.upper() + + # Animated pulsing dot for running status + badge_prefix = "" + if status == "running": + badge_prefix = ( + '' + ) + + override_html = "" + if override_text: + override_html = ( + f'
{override_text}
' + ) + else: + override_html = ( + f'
Using all global defaults
' + ) + + # Inline progress bar for running simulations + progress_html = "" + if status == "running" and progress >= 0: + pct = max(0.0, min(1.0, progress)) * 100 + pct_text = f"{pct:.0f}%" + progress_html = ( + f'
' + '
' + f'
' + "
" + f'{pct_text}' + "
" + ) + + badge_html = ( + f'
{badge_prefix}{s_text}
' + ) + + title_html = ( + f'
{label}
' + ) + + return _dedent(f""" +
+
+{title_html} +{progress_html} +{badge_html} +
+{override_html} +
+""") + + +def params_card(title: str, icon_svg: str, rows: list) -> str: + """Create a parameter display card with key-value rows. + + Args: + title: Card title (e.g. "Network", "Distribution") + icon_svg: SVG icon HTML string + rows: List of (key, value) or (key, value, is_override) tuples + """ + rows_html = "" + for row in rows: + key, val = row[0], row[1] + is_override = row[2] if len(row) > 2 else False + key_class = "params-card-key-override" if is_override else "params-card-key" + val_class = "params-card-val-override" if is_override else "params-card-val" + rows_html += ( + f'
' + f'{key}' + f'{val}' + f"
" + ) + + icon_html = "" + if icon_svg: + icon_html = ( + f'{icon_svg}' + ) + + return _dedent(f""" +
+
{icon_html}{title}
+{rows_html} +
+""") + + +def circular_progress_html(progress: float, label: str = "Running simulation...") -> str: + """Create a CSS-only circular progress ring. + + Args: + progress: Fraction between 0.0 and 1.0 + label: Text shown below the ring + """ + pct = max(0.0, min(1.0, progress)) + deg = int(pct * 360) + pct_text = f"{int(pct * 100)}%" + + return _dedent(f""" +
+
+
{pct_text}
+
+

{label}

+
+""") + + +def section_header(title: str, subtitle: str = "") -> str: + """Create a section header for content areas.""" + sub = "" + if subtitle: + sub = ( + f'

{subtitle}

' + ) + + return _dedent(f""" +
+

{title}

+{sub} +
+""") diff --git a/tests/e2e/__init__.py b/tests/e2e/__init__.py new file mode 100644 index 0000000..39e97b1 --- /dev/null +++ b/tests/e2e/__init__.py @@ -0,0 +1 @@ +"""End-to-end tests for the SPKMC web interface.""" diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py new file mode 100644 index 0000000..f0e2653 --- /dev/null +++ b/tests/e2e/conftest.py @@ -0,0 +1,233 @@ +""" +E2E test configuration and fixtures. + +Manages the Streamlit server lifecycle, pre-seeds experiment fixture data, +and provides reusable page navigation helpers. +""" + +from __future__ import annotations + +import json +import os +import shutil +import socket +import subprocess +import sys +import time +from pathlib import Path +from typing import Dict, Generator + +import numpy as np +import pytest + +# ── Paths ────────────────────────────────────────────── + +E2E_DIR = Path(__file__).parent +FIXTURES_DIR = E2E_DIR / "fixtures" +PROJECT_ROOT = E2E_DIR.parent.parent + + +# ── SIR result data generator ────────────────────────── + + +def _make_sir_result(n: int = 50) -> Dict: + """Generate synthetic SIR simulation result data. + + Produces mathematically plausible S/I/R curves that look like a + real epidemic simulation result, suitable for chart rendering tests. + """ + t = np.linspace(0, 10, n) + s = np.exp(-t * 0.3) + i = 0.3 * np.exp(-((t - 3) ** 2) / 4) + r = 1.0 - s - i + r = np.clip(r, 0, 1) + err = np.ones(n) * 0.01 + return { + "time": t.tolist(), + "S_val": s.tolist(), + "I_val": i.tolist(), + "R_val": r.tolist(), + "S_err": err.tolist(), + "I_err": err.tolist(), + "R_err": err.tolist(), + } + + +# ── Fixture: temp environment ────────────────────────── + + +@pytest.fixture(scope="session") +def app_env(tmp_path_factory: pytest.TempPathFactory) -> Path: + """Create a temporary environment with fixture experiments and pre-seeded results. + + Returns the temp directory root that contains: + - experiments//data.json (copied from fixtures) + - experiments//baseline.json (generated SIR result) + - experiments//high_lambda.json (generated SIR result) + - web_config.json (pointing experiments_directory to the temp experiments dir) + """ + tmp = tmp_path_factory.mktemp("spkmc_e2e") + + # Copy fixture experiments + src_experiments = FIXTURES_DIR / "experiments" + dst_experiments = tmp / "experiments" + shutil.copytree(src_experiments, dst_experiments) + + # Generate result JSONs for the smoke experiment + smoke_exp_dir = dst_experiments / "e2e_smoke_exp" + for filename in ("baseline.json", "high_lambda.json"): + result_path = smoke_exp_dir / filename + result_path.write_text(json.dumps(_make_sir_result(), indent=2)) + + # Write web config pointing to the temp experiments directory + config_path = tmp / "web_config.json" + config = { + "data_directory": str(tmp / "data"), + "experiments_directory": str(dst_experiments), + "default_nodes": 100, + "default_k_avg": 5.0, + "default_samples": 10, + "default_num_runs": 1, + "default_t_max": 5.0, + "default_steps": 50, + "default_initial_perc": 0.01, + "chart_height": 500, + "chart_template": "plotly_white", + "chart_color_s": "#4477AA", + "chart_color_i": "#EE6677", + "chart_color_r": "#228833", + "ai_model": "gpt-4o-mini", + } + config_path.write_text(json.dumps(config, indent=2)) + + # Create data directory + (tmp / "data").mkdir(exist_ok=True) + + return tmp + + +# ── Fixture: Streamlit server ────────────────────────── + + +def _wait_for_port(host: str, port: int, timeout: float = 60.0) -> None: + """Block until a TCP port accepts connections or timeout is reached.""" + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + try: + with socket.create_connection((host, port), timeout=2): + return + except OSError: + time.sleep(0.5) + raise TimeoutError(f"Streamlit server did not start on {host}:{port} within {timeout}s") + + +@pytest.fixture(scope="session") +def app_url(app_env: Path) -> Generator[str, None, None]: + """Start the Streamlit server and yield its base URL. + + The server runs with: + - SPKMC_WEB_CONFIG_FILE pointing to the temp config + - Headless mode enabled, file watcher disabled + - Port 8502 to avoid conflicts with a dev server on 8501 + """ + port = 8502 + config_path = app_env / "web_config.json" + + env = {**os.environ} + env["SPKMC_WEB_CONFIG_FILE"] = str(config_path) + + app_path = str(PROJECT_ROOT / "spkmc" / "web" / "app.py") + + proc = subprocess.Popen( + [ + sys.executable, + "-m", + "streamlit", + "run", + app_path, + "--server.port", + str(port), + "--server.headless", + "true", + "--server.fileWatcherType", + "none", + "--browser.gatherUsageStats", + "false", + ], + cwd=str(app_env), + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + try: + _wait_for_port("localhost", port, timeout=60) + yield f"http://localhost:{port}" + finally: + proc.terminate() + try: + proc.wait(timeout=10) + except subprocess.TimeoutExpired: + proc.kill() + proc.wait(timeout=5) + + +# ── Fixture: Playwright base URL ─────────────────────── + + +@pytest.fixture(scope="session") +def base_url(app_url: str) -> str: + """Provide the base URL for pytest-playwright's page.goto().""" + return app_url + + +# ── Fixture: page with sidebar ready ────────────────── + + +@pytest.fixture +def app_page(page, app_url: str): + """Navigate to the app root and wait for the sidebar to be ready. + + Returns the Playwright page object after Streamlit has fully loaded. + """ + page.goto(app_url) + page.wait_for_selector("[data-testid='stSidebar']", timeout=15000) + # Wait a bit for Streamlit to settle its initial render + page.wait_for_timeout(1000) + return page + + +# ── Navigation helpers ───────────────────────────────── + + +def navigate_to_settings(page) -> None: + """Click the Preferences nav button in the sidebar.""" + page.locator(".st-key-nav_settings button").click() + page.wait_for_timeout(1500) + + +def navigate_to_dashboard(page) -> None: + """Click the Experiments nav button in the sidebar.""" + page.locator(".st-key-nav_experiments button").click() + page.wait_for_timeout(1500) + + +def open_experiment(page, idx: int = 0) -> None: + """Click the experiment card at the given index to open its detail view.""" + btn = page.locator(f".st-key-exp_btn_{idx} button") + btn.wait_for(state="visible", timeout=10000) + btn.click() + page.wait_for_timeout(1500) + + +def open_scenario_detail(page, experiment_dir: str, scenario_label: str) -> None: + """Click a scenario card to open its detail modal. + + Args: + page: Playwright page object + experiment_dir: The experiment directory name (e.g. "e2e_smoke_exp") + scenario_label: The normalized scenario label (e.g. "baseline", "high_lambda") + """ + scenario_id = f"sim--{experiment_dir}--{scenario_label}" + page.locator(f".st-key-sc_btn_{scenario_id} button").click() + page.wait_for_timeout(1500) diff --git a/tests/e2e/fixtures/experiments/e2e_smoke_exp/data.json b/tests/e2e/fixtures/experiments/e2e_smoke_exp/data.json new file mode 100644 index 0000000..d4fd1d7 --- /dev/null +++ b/tests/e2e/fixtures/experiments/e2e_smoke_exp/data.json @@ -0,0 +1,22 @@ +{ + "name": "E2E Smoke Test Experiment", + "description": "Pre-seeded experiment for E2E testing", + "parameters": { + "network": "er", + "nodes": 100, + "k_avg": 5.0, + "distribution": "gamma", + "lambda": 1.0, + "shape": 2.0, + "scale": 1.0, + "samples": 10, + "num_runs": 1, + "initial_perc": 0.01, + "t_max": 5.0, + "steps": 50 + }, + "scenarios": [ + {"label": "Baseline"}, + {"label": "High Lambda", "lambda": 2.0} + ] +} diff --git a/tests/e2e/test_dashboard.py b/tests/e2e/test_dashboard.py new file mode 100644 index 0000000..cb2d363 --- /dev/null +++ b/tests/e2e/test_dashboard.py @@ -0,0 +1,150 @@ +""" +E2E tests for the dashboard page: stats cards, experiment cards, +and the Create Experiment modal. +""" + +from __future__ import annotations + +import pytest +from playwright.sync_api import expect + +from tests.e2e.conftest import open_experiment + +pytestmark = pytest.mark.e2e + + +# ── Stats cards ──────────────────────────────────────── + + +def test_stats_cards_render(app_page): + """The dashboard shows 4 stat card columns.""" + # Stat cards are rendered as raw HTML via st.markdown(unsafe_allow_html=True). + # Use locator() with :text() pseudo-class and .first to handle potential + # multi-element matches from nested DOM structure. + expect(app_page.locator(":text('Total Experiments')").first).to_be_visible(timeout=10000) + expect(app_page.locator(":text('Total Scenarios')").first).to_be_visible(timeout=5000) + expect(app_page.locator(":text('Completed Scenarios')").first).to_be_visible(timeout=5000) + expect(app_page.locator(":text('Last Activity')").first).to_be_visible(timeout=5000) + + +# ── Create Experiment button & modal ─────────────────── + + +def test_create_button_visible(app_page): + """The Create Experiment button is present on the dashboard.""" + btn = app_page.locator(".st-key-btn_create_exp button") + expect(btn).to_be_visible(timeout=8000) + + +def test_create_modal_opens(app_page): + """Clicking Create Experiment opens a modal dialog.""" + app_page.locator(".st-key-btn_create_exp button").click() + dialog = app_page.locator("[data-testid='stDialog']") + expect(dialog).to_be_visible(timeout=8000) + + +def test_create_modal_has_name_input(app_page): + """The create modal contains a text input for the experiment name.""" + app_page.locator(".st-key-btn_create_exp button").click() + dialog = app_page.locator("[data-testid='stDialog']") + expect(dialog).to_be_visible(timeout=8000) + # The dialog should contain a text input (name field) + name_input = dialog.locator("input[type='text']").first + expect(name_input).to_be_visible() + + +def test_create_modal_has_network_config(app_page): + """The create modal includes network type configuration.""" + app_page.locator(".st-key-btn_create_exp button").click() + dialog = app_page.locator("[data-testid='stDialog']") + expect(dialog).to_be_visible(timeout=8000) + # Network type is configured via a selectbox with key="create_network_type" + expect(app_page.locator(".st-key-create_network_type")).to_be_visible() + + +def test_create_modal_cancel_closes(app_page): + """Closing/dismissing the create modal makes it disappear.""" + app_page.locator(".st-key-btn_create_exp button").click() + dialog = app_page.locator("[data-testid='stDialog']") + expect(dialog).to_be_visible(timeout=8000) + + # Press Escape to close the modal + app_page.keyboard.press("Escape") + expect(dialog).not_to_be_visible(timeout=5000) + + +# ── Experiment cards ─────────────────────────────────── + + +def test_experiment_card_renders(app_page): + """The pre-seeded experiment card is visible on the dashboard.""" + card = app_page.locator(".st-key-exp_card_0") + expect(card).to_be_visible(timeout=8000) + + +def test_experiment_card_shows_name(app_page): + """The experiment card displays the experiment name.""" + # Use text-based lookup — card index can shift when other tests create experiments + expect(app_page.get_by_text("E2E Smoke Test Experiment").first).to_be_visible(timeout=8000) + + +def test_experiment_card_shows_scenario_count(app_page): + """The experiment card shows the correct number of scenarios.""" + # Find the card containing the pre-seeded experiment name, then check scenario count + cards = app_page.locator("[class*='st-key-exp_card_']") + card_count = cards.count() + found = False + for i in range(card_count): + card = cards.nth(i) + if "E2E Smoke Test Experiment" in (card.text_content() or ""): + expect(card).to_contain_text("2") + found = True + break + assert found, "Pre-seeded experiment card not found" + + +def test_experiment_card_clickable(app_page): + """Clicking an experiment card navigates to the detail view.""" + # Find the card container that has the pre-seeded experiment, then click its button. + # Card indices can shift when other tests create experiments, so we search by text. + cards = app_page.locator("[class*='st-key-exp_card_']") + card_count = cards.count() + clicked = False + for i in range(card_count): + card = cards.nth(i) + if "E2E Smoke Test Experiment" in (card.text_content() or ""): + card.locator("button").click() + clicked = True + break + assert clicked, "Pre-seeded experiment card not found" + # Wait for the detail-specific back button to confirm navigation succeeded + back_btn = app_page.locator(".st-key-detail_back_btn button") + expect(back_btn).to_be_visible(timeout=15000) + # The detail view should show the experiment name + expect(app_page.locator(":text('E2E Smoke Test Experiment')").first).to_be_visible() + + +# ── Create experiment flow ───────────────────────────── + + +def test_create_experiment_flow(app_page): + """Creating a new experiment adds a card to the dashboard.""" + # Open the create modal + app_page.locator(".st-key-btn_create_exp button").click() + dialog = app_page.locator("[data-testid='stDialog']") + expect(dialog).to_be_visible(timeout=8000) + + # Fill in the experiment name + name_input = dialog.locator("input[type='text']").first + name_input.fill("E2E Created Experiment") + + # Submit — look for a Create/Save button inside the dialog + create_btn = dialog.get_by_role("button", name="Create") + expect(create_btn).to_be_visible(timeout=5000) + create_btn.click() + + # Wait for the dialog to close and the page to rerender + app_page.wait_for_timeout(2000) + + # The new experiment should appear somewhere on the page + expect(app_page.get_by_text("E2E Created Experiment")).to_be_visible(timeout=8000) diff --git a/tests/e2e/test_experiment_detail.py b/tests/e2e/test_experiment_detail.py new file mode 100644 index 0000000..da41067 --- /dev/null +++ b/tests/e2e/test_experiment_detail.py @@ -0,0 +1,287 @@ +""" +E2E tests for the experiment detail page: parameters, scenario cards, +scenario detail modal, chart controls, comparison, and export. +""" + +from __future__ import annotations + +import pytest +from playwright.sync_api import expect + +from tests.e2e.conftest import open_experiment, open_scenario_detail + +pytestmark = pytest.mark.e2e + +# The pre-seeded experiment directory name and scenario labels +EXP_DIR = "e2e_smoke_exp" +SC_BASELINE = "baseline" +SC_HIGH_LAMBDA = "high_lambda" + +# Scenario IDs match the pattern used by experiment_detail.py for widget keys +SC_ID_BASELINE = f"sim--{EXP_DIR}--{SC_BASELINE}" +SC_ID_HIGH_LAMBDA = f"sim--{EXP_DIR}--{SC_HIGH_LAMBDA}" + + +@pytest.fixture +def detail_page(app_page): + """Navigate to the experiment detail page for the pre-seeded experiment.""" + # Find the pre-seeded experiment card by name (not by index, which may shift). + # Use .first to avoid strict mode violations when get_by_text matches multiple + # DOM levels (the inner text div and its ancestor containers). + app_page.locator(":text('E2E Smoke Test Experiment')").first.wait_for( + state="visible", timeout=10000 + ) + # Iterate through experiment card containers to find the right one + for idx in range(10): + container = app_page.locator(f".st-key-exp_card_{idx}") + if container.count() == 0: + break + if container.locator(":text('E2E Smoke Test Experiment')").count() > 0: + app_page.locator(f".st-key-exp_btn_{idx} button").click() + break + # Wait for the detail page to render — use the back button as the definitive signal + back_btn = app_page.locator(".st-key-detail_back_btn button") + expect(back_btn).to_be_visible(timeout=15000) + return app_page + + +# ── Detail page basics ───────────────────────────────── + + +def test_detail_page_renders(detail_page): + """The experiment name is displayed as a heading.""" + expect(detail_page.locator(":text('E2E Smoke Test Experiment')").first).to_be_visible() + + +def test_back_button_returns_to_dashboard(detail_page): + """Clicking the back button returns to the dashboard.""" + detail_page.locator(".st-key-detail_back_btn button").click() + detail_page.wait_for_timeout(1500) + # Dashboard should be visible again (Create Experiment button as indicator) + expect(detail_page.locator(".st-key-btn_create_exp button")).to_be_visible(timeout=8000) + + +# ── Global parameters ────────────────────────────────── + + +def test_global_params_three_cards(detail_page): + """The global parameters section shows 3 param cards (Network, Distribution, Simulation).""" + params_section = detail_page.locator(".st-key-params_section") + expect(params_section).to_be_visible(timeout=8000) + + # The three cards should contain these titles + expect(params_section.get_by_text("Network")).to_be_visible() + expect(params_section.get_by_text("Distribution")).to_be_visible() + expect(params_section.get_by_text("Simulation")).to_be_visible() + + +# ── Scenario cards ───────────────────────────────────── + + +def test_scenario_cards_render(detail_page): + """Two scenario cards are visible for the pre-seeded experiment.""" + baseline_card = detail_page.locator(f".st-key-sc_card_sim--{EXP_DIR}--{SC_BASELINE}") + high_lambda_card = detail_page.locator(f".st-key-sc_card_sim--{EXP_DIR}--{SC_HIGH_LAMBDA}") + expect(baseline_card).to_be_visible(timeout=8000) + expect(high_lambda_card).to_be_visible() + + +def test_scenario_card_labels(detail_page): + """Scenario cards display the correct labels.""" + # Scope text assertions to specific card containers to avoid strict mode violations + baseline_card = detail_page.locator(f".st-key-sc_card_{SC_ID_BASELINE}") + high_lambda_card = detail_page.locator(f".st-key-sc_card_{SC_ID_HIGH_LAMBDA}") + expect(baseline_card).to_contain_text("Baseline") + expect(high_lambda_card).to_contain_text("High Lambda") + + +def test_completed_status_badge(detail_page): + """Pre-seeded scenarios show 'Completed' status badges.""" + # Both scenarios have result files, so they should show completed status + baseline_card = detail_page.locator(f".st-key-sc_card_sim--{EXP_DIR}--{SC_BASELINE}") + expect(baseline_card).to_contain_text("Completed", ignore_case=True) + + +# ── Add Scenario ─────────────────────────────────────── + + +def test_add_scenario_button_visible(detail_page): + """The Add Scenario button is present.""" + btn = detail_page.locator(".st-key-btn_add_scenario_bar button") + expect(btn).to_be_visible(timeout=8000) + + +def test_add_scenario_modal_opens(detail_page): + """Clicking Add Scenario opens a dialog.""" + detail_page.locator(".st-key-btn_add_scenario_bar button").click() + dialog = detail_page.locator("[data-testid='stDialog']") + expect(dialog).to_be_visible(timeout=8000) + + +# ── Scenario detail modal ───────────────────────────── + + +def test_scenario_card_opens_detail_modal(detail_page): + """Clicking a scenario card opens the detail modal.""" + open_scenario_detail(detail_page, EXP_DIR, SC_BASELINE) + dialog = detail_page.locator("[data-testid='stDialog']") + expect(dialog).to_be_visible(timeout=8000) + + +def test_modal_shows_scenario_name(detail_page): + """The detail modal header shows the scenario label.""" + open_scenario_detail(detail_page, EXP_DIR, SC_BASELINE) + dialog = detail_page.locator("[data-testid='stDialog']") + expect(dialog).to_be_visible(timeout=8000) + # st.title renders inside a stHeading container (Streamlit 1.54+) + title = dialog.locator("[data-testid='stHeading']").first + expect(title).to_contain_text("Baseline", timeout=8000) + + +def test_modal_sir_chart_renders(detail_page): + """The detail modal contains a Plotly chart for a completed scenario.""" + open_scenario_detail(detail_page, EXP_DIR, SC_BASELINE) + dialog = detail_page.locator("[data-testid='stDialog']") + expect(dialog).to_be_visible(timeout=8000) + # Plotly renders with a class containing 'js-plotly-plot' + chart = dialog.locator("[class*='js-plotly-plot']") + expect(chart).to_be_visible(timeout=10000) + + +def test_modal_state_checkboxes_present(detail_page): + """S, I, R checkboxes are present in the detail modal.""" + open_scenario_detail(detail_page, EXP_DIR, SC_BASELINE) + dialog = detail_page.locator("[data-testid='stDialog']") + expect(dialog).to_be_visible(timeout=8000) + + expect(detail_page.locator(f".st-key-modal_show_s_{SC_ID_BASELINE}")).to_be_visible( + timeout=8000 + ) + expect(detail_page.locator(f".st-key-modal_show_i_{SC_ID_BASELINE}")).to_be_visible() + expect(detail_page.locator(f".st-key-modal_show_r_{SC_ID_BASELINE}")).to_be_visible() + + +def test_modal_chart_type_selectbox(detail_page): + """The chart type dropdown is visible in the modal.""" + open_scenario_detail(detail_page, EXP_DIR, SC_BASELINE) + dialog = detail_page.locator("[data-testid='stDialog']") + expect(dialog).to_be_visible(timeout=8000) + + chart_mode = detail_page.locator(f".st-key-modal_chart_mode_{SC_ID_BASELINE}") + expect(chart_mode).to_be_visible(timeout=8000) + + +def test_modal_comparison_multiselect(detail_page): + """The comparison multiselect is visible when other scenarios have results.""" + open_scenario_detail(detail_page, EXP_DIR, SC_BASELINE) + dialog = detail_page.locator("[data-testid='stDialog']") + expect(dialog).to_be_visible(timeout=8000) + + # The comparison section should be present since High Lambda also has results + compare_key = f"modal_compare_{EXP_DIR}_{SC_BASELINE}" + compare_widget = detail_page.locator(f".st-key-{compare_key}") + expect(compare_widget).to_be_visible(timeout=8000) + + +def test_modal_select_comparison_scenario(detail_page): + """Selecting a comparison scenario triggers a comparison chart.""" + open_scenario_detail(detail_page, EXP_DIR, SC_BASELINE) + dialog = detail_page.locator("[data-testid='stDialog']") + expect(dialog).to_be_visible(timeout=8000) + + # Click on the comparison multiselect to open its dropdown + compare_key = f"modal_compare_{EXP_DIR}_{SC_BASELINE}" + compare_widget = detail_page.locator(f".st-key-{compare_key}") + expect(compare_widget).to_be_visible(timeout=8000) + compare_widget.click() + detail_page.wait_for_timeout(500) + + # Select "High Lambda" from the dropdown options (target the virtual dropdown) + dropdown = detail_page.locator("[data-testid='stSelectboxVirtualDropdown']") + dropdown.get_by_text("High Lambda").click() + detail_page.wait_for_timeout(2000) + + # A comparison chart should now be visible (second Plotly chart) + charts = dialog.locator("[class*='js-plotly-plot']") + expect(charts.first).to_be_visible(timeout=8000) + + +def test_modal_export_popover(detail_page): + """The export button reveals format radio options.""" + open_scenario_detail(detail_page, EXP_DIR, SC_BASELINE) + dialog = detail_page.locator("[data-testid='stDialog']") + expect(dialog).to_be_visible(timeout=8000) + + # Click the Export popover trigger + export_container = detail_page.locator(f".st-key-modal_action_export_{SC_ID_BASELINE}") + expect(export_container).to_be_visible(timeout=8000) + export_container.locator("button").click() + detail_page.wait_for_timeout(1000) + + # The export format radio should appear + export_fmt = detail_page.locator(f".st-key-export_fmt_{SC_ID_BASELINE}") + expect(export_fmt).to_be_visible(timeout=5000) + + +def test_uncheck_infected_updates_chart(detail_page): + """Unchecking the Infected checkbox updates the chart (fewer traces).""" + open_scenario_detail(detail_page, EXP_DIR, SC_BASELINE) + dialog = detail_page.locator("[data-testid='stDialog']") + expect(dialog).to_be_visible(timeout=8000) + + # Wait for chart to render + chart = dialog.locator("[class*='js-plotly-plot']") + expect(chart).to_be_visible(timeout=10000) + + # Uncheck the Infected checkbox + infected_cb = detail_page.locator(f".st-key-modal_show_i_{SC_ID_BASELINE}") + infected_cb.click() + detail_page.wait_for_timeout(2000) + + # Chart should still be visible (it re-renders with fewer traces) + expect(chart).to_be_visible() + + +def test_chart_mode_area(detail_page): + """Switching chart mode to Area updates the chart.""" + open_scenario_detail(detail_page, EXP_DIR, SC_BASELINE) + dialog = detail_page.locator("[data-testid='stDialog']") + expect(dialog).to_be_visible(timeout=8000) + + chart = dialog.locator("[class*='js-plotly-plot']") + expect(chart).to_be_visible(timeout=10000) + + # Click the chart mode selectbox to change it + chart_mode = detail_page.locator(f".st-key-modal_chart_mode_{SC_ID_BASELINE}") + chart_mode.click() + detail_page.wait_for_timeout(500) + + # Select "Area" from the dropdown options (target the virtual dropdown) + dropdown = detail_page.locator("[data-testid='stSelectboxVirtualDropdown']") + dropdown.get_by_text("Area", exact=True).click() + detail_page.wait_for_timeout(2000) + + # Chart should still render + expect(chart).to_be_visible() + + +def test_ai_button_disabled_without_key(detail_page): + """The AI analyze button is disabled when no API key is configured.""" + # The experiment-level AI button (use .first to avoid tooltip wrapper duplicates) + ai_btn = detail_page.locator(".st-key-btn_ai button").first + expect(ai_btn).to_be_visible(timeout=8000) + expect(ai_btn).to_be_disabled() + + +def test_modal_close(detail_page): + """Dismissing the modal returns to the detail page.""" + open_scenario_detail(detail_page, EXP_DIR, SC_BASELINE) + dialog = detail_page.locator("[data-testid='stDialog']") + expect(dialog).to_be_visible(timeout=8000) + + # Press Escape to close + detail_page.keyboard.press("Escape") + expect(dialog).not_to_be_visible(timeout=5000) + + # The detail page should still be showing + expect(detail_page.locator(":text('E2E Smoke Test Experiment')").first).to_be_visible() diff --git a/tests/e2e/test_navigation.py b/tests/e2e/test_navigation.py new file mode 100644 index 0000000..2c3d0d9 --- /dev/null +++ b/tests/e2e/test_navigation.py @@ -0,0 +1,67 @@ +""" +E2E tests for sidebar navigation, page routing, and app chrome. +""" + +from __future__ import annotations + +import pytest +from playwright.sync_api import expect + +from tests.e2e.conftest import navigate_to_dashboard, navigate_to_settings + +pytestmark = pytest.mark.e2e + + +def test_app_loads(app_page): + """The app loads without crash and the sidebar is visible.""" + sidebar = app_page.locator("[data-testid='stSidebar']") + expect(sidebar).to_be_visible() + + +def test_page_title(app_page): + """The browser tab title contains SPKMC.""" + assert "SPKMC" in app_page.title() + + +def test_dashboard_is_default_page(app_page): + """The dashboard is shown by default on first load.""" + # Dashboard renders stat cards — use a specific stat card label as indicator. + # Use .first to handle potential multi-element matches from nested HTML containers. + expect(app_page.locator(":text('Total Experiments')").first).to_be_visible(timeout=10000) + + +def test_navigate_to_settings(app_page): + """Clicking the Preferences button navigates to the settings page.""" + navigate_to_settings(app_page) + # Settings page has a unique subtitle + expect(app_page.get_by_text("Configure web interface and simulation defaults")).to_be_visible( + timeout=8000 + ) + + +def test_navigate_back_to_dashboard(app_page): + """Clicking the Experiments button returns to the dashboard.""" + navigate_to_settings(app_page) + navigate_to_dashboard(app_page) + expect(app_page.locator(":text('Total Experiments')").first).to_be_visible(timeout=10000) + + +def test_sidebar_brand_visible(app_page): + """The SPKMC brand text is visible in the sidebar.""" + sidebar = app_page.locator("[data-testid='stSidebar']") + expect(sidebar.get_by_text("SPKMC")).to_be_visible() + + +def test_sidebar_version_visible(app_page): + """The version footer is displayed in the sidebar.""" + sidebar = app_page.locator("[data-testid='stSidebar']") + version = sidebar.locator(".sidebar-version-footer") + expect(version).to_be_visible() + # Version text should start with 'v' + expect(version).to_contain_text("v") + + +def test_query_params_reflect_page(app_page): + """URL query params reflect the current page after navigation.""" + navigate_to_settings(app_page) + assert "page=settings" in app_page.url diff --git a/tests/e2e/test_settings.py b/tests/e2e/test_settings.py new file mode 100644 index 0000000..9221f66 --- /dev/null +++ b/tests/e2e/test_settings.py @@ -0,0 +1,116 @@ +""" +E2E tests for the settings (Preferences) page: section cards, +inputs, defaults, and the reset button. +""" + +from __future__ import annotations + +import pytest +from playwright.sync_api import expect + +from tests.e2e.conftest import navigate_to_settings + +pytestmark = pytest.mark.e2e + + +@pytest.fixture +def settings_page(app_page): + """Navigate to the settings page.""" + navigate_to_settings(app_page) + # Use the unique subtitle as indicator that settings loaded + expect(app_page.get_by_text("Configure web interface and simulation defaults")).to_be_visible( + timeout=8000 + ) + return app_page + + +# ── Page structure ───────────────────────────────────── + + +def test_settings_page_renders(settings_page): + """All section cards are visible on the settings page.""" + # 4 main section cards + danger zone = 5 containers + expect(settings_page.locator(".st-key-pref_card_ai")).to_be_visible(timeout=8000) + expect(settings_page.locator(".st-key-pref_card_viz")).to_be_visible() + expect(settings_page.locator(".st-key-pref_card_sim")).to_be_visible() + expect(settings_page.locator(".st-key-pref_card_storage")).to_be_visible() + + +# ── AI & Intelligence section ────────────────────────── + + +def test_ai_section_shows_not_configured(settings_page): + """The AI section shows 'Not configured' when no API key is set.""" + ai_card = settings_page.locator(".st-key-pref_card_ai") + expect(ai_card).to_contain_text("Not configured", ignore_case=True) + + +def test_model_selectbox_options(settings_page): + """The AI model dropdown is present with model options.""" + ai_card = settings_page.locator(".st-key-pref_card_ai") + # The selectbox should show the default model + expect(ai_card).to_contain_text("gpt-4o-mini") + + +# ── Visualization section ────────────────────────────── + + +def test_chart_height_input_visible(settings_page): + """The chart height number input is present.""" + viz_card = settings_page.locator(".st-key-pref_card_viz") + expect(viz_card).to_contain_text("Default Height") + + +def test_template_selectbox_visible(settings_page): + """The chart template dropdown is present.""" + viz_card = settings_page.locator(".st-key-pref_card_viz") + expect(viz_card).to_contain_text("plotly_white") + + +def test_color_pickers_visible(settings_page): + """Three color picker inputs are visible for S, I, R.""" + viz_card = settings_page.locator(".st-key-pref_card_viz") + expect(viz_card.get_by_text("Susceptible")).to_be_visible() + expect(viz_card.get_by_text("Infected")).to_be_visible() + expect(viz_card.get_by_text("Recovered")).to_be_visible() + + +# ── Simulation Defaults section ──────────────────────── + + +def test_simulation_defaults_section(settings_page): + """The simulation defaults section shows Network, Distribution, Simulation subsections.""" + sim_card = settings_page.locator(".st-key-pref_card_sim") + expect(sim_card.get_by_text("Network")).to_be_visible() + expect(sim_card.get_by_text("Distribution")).to_be_visible() + expect(sim_card.get_by_text("Simulation")).to_be_visible() + + +def test_default_nodes_reflects_config(settings_page): + """The default nodes input reflects the value from the test config (100).""" + sim_card = settings_page.locator(".st-key-pref_card_sim") + # Find the Nodes input and check its value + nodes_input = sim_card.locator("input[type='number']").first + expect(nodes_input).to_have_value("100") + + +# ── Storage & Export section ─────────────────────────── + + +def test_storage_inputs_visible(settings_page): + """Data and Experiments directory inputs are present.""" + storage_card = settings_page.locator(".st-key-pref_card_storage") + expect(storage_card.get_by_text("Data Directory")).to_be_visible() + expect(storage_card.get_by_text("Experiments Directory")).to_be_visible() + + +# ── Danger Zone ──────────────────────────────────────── + + +def test_danger_zone_reset_button(settings_page): + """The Reset all button is present in the danger zone.""" + reset_container = settings_page.locator(".st-key-pref_reset") + expect(reset_container).to_be_visible(timeout=8000) + reset_btn = reset_container.locator("button") + expect(reset_btn).to_be_visible() + expect(reset_btn).to_contain_text("Reset all") diff --git a/tests/test_plot_improvements.py b/tests/test_plot_improvements.py index aa43304..9764e7c 100644 --- a/tests/test_plot_improvements.py +++ b/tests/test_plot_improvements.py @@ -282,8 +282,6 @@ def mock_export(*args, **kwargs): def test_visualizer_states_filter(): """Directly test Visualizer functions with state filters.""" - import matplotlib.pyplot as plt - # Test data s_vals = np.array([0.99, 0.95, 0.90, 0.85, 0.80]) i_vals = np.array([0.01, 0.04, 0.05, 0.05, 0.04]) @@ -300,25 +298,22 @@ def test_visualizer_states_filter(): r_vals, time, "Test", - save_path="test_all.png", + save_path="test_all.html", states_to_plot={"S", "I", "R"}, ) - plt.close() # Only infected Visualizer.plot_result( - s_vals, i_vals, r_vals, time, "Test", save_path="test_i.png", states_to_plot={"I"} + s_vals, i_vals, r_vals, time, "Test", save_path="test_i.html", states_to_plot={"I"} ) - plt.close() # Infected and recovered Visualizer.plot_result( - s_vals, i_vals, r_vals, time, "Test", save_path="test_ir.png", states_to_plot={"I", "R"} + s_vals, i_vals, r_vals, time, "Test", save_path="test_ir.html", states_to_plot={"I", "R"} ) - plt.close() # Remove test files if created - for file in ["test_all.png", "test_i.png", "test_ir.png"]: + for file in ["test_all.html", "test_i.html", "test_ir.html"]: if os.path.exists(file): os.remove(file) diff --git a/tests/test_web/__init__.py b/tests/test_web/__init__.py new file mode 100644 index 0000000..a30ae59 --- /dev/null +++ b/tests/test_web/__init__.py @@ -0,0 +1 @@ +"""Tests for SPKMC web interface.""" diff --git a/tests/test_web/test_analysis_runner.py b/tests/test_web/test_analysis_runner.py new file mode 100644 index 0000000..aa47e5a --- /dev/null +++ b/tests/test_web/test_analysis_runner.py @@ -0,0 +1,336 @@ +""" +Tests for AnalysisRunner. + +Tests cover file-based status management, completion detection (experiment +vs. scenario analysis types), script generation with correct content, and +cleanup. API keys are passed via subprocess env, not embedded in scripts. +Subprocess execution is not tested. +""" + +from __future__ import annotations + +import json +from pathlib import Path +from unittest.mock import patch + +import pytest + +# ── Fixtures ────────────────────────────────────────────────────────────────── + + +@pytest.fixture() +def runner(tmp_path): + """AnalysisRunner with status_dir isolated to tmp_path.""" + from spkmc.web.analysis_runner import AnalysisRunner + + r = AnalysisRunner.__new__(AnalysisRunner) + r.status_dir = tmp_path / "status" + r.status_dir.mkdir() + r._processes = {} + return r + + +# ── get_status ──────────────────────────────────────────────────────────────── + + +class TestGetStatus: + def test_returns_none_for_missing_run_id(self, runner): + assert runner.get_status("nonexistent") is None + + def test_returns_parsed_dict_for_valid_file(self, runner): + data = { + "run_id": "exp_analysis--exp1--1", + "type": "analysis", + "status": "running", + } + (runner.status_dir / "exp_analysis--exp1--1.json").write_text(json.dumps(data)) + assert runner.get_status("exp_analysis--exp1--1") == data + + def test_returns_none_for_corrupted_json(self, runner): + (runner.status_dir / "bad.json").write_text("{not: valid}") + assert runner.get_status("bad") is None + + +# ── cleanup_status ──────────────────────────────────────────────────────────── + + +class TestCleanupStatus: + def test_removes_both_status_and_script_files(self, runner): + run_id = "exp_analysis--exp1--99" + (runner.status_dir / f"{run_id}.json").write_text("{}") + (runner.status_dir / f"{run_id}_script.py").write_text("pass") + + runner.cleanup_status(run_id) + + assert not (runner.status_dir / f"{run_id}.json").exists() + assert not (runner.status_dir / f"{run_id}_script.py").exists() + + def test_cleanup_is_idempotent_when_files_absent(self, runner): + runner.cleanup_status("never_existed") + runner.cleanup_status("never_existed") + + +# ── check_completion ────────────────────────────────────────────────────────── + + +class TestCheckCompletion: + def test_experiment_analysis_checks_analysis_md(self, runner, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + exp_dir = tmp_path / "experiments" / "exp1" + exp_dir.mkdir(parents=True) + (exp_dir / "analysis.md").touch() + + assert runner.check_completion("exp1", "experiment") is True + + def test_experiment_analysis_returns_false_when_file_missing( + self, runner, tmp_path, monkeypatch + ): + monkeypatch.chdir(tmp_path) + exp_dir = tmp_path / "experiments" / "exp1" + exp_dir.mkdir(parents=True) + + assert runner.check_completion("exp1", "experiment") is False + + def test_scenario_analysis_checks_scenario_analysis_md(self, runner, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + exp_dir = tmp_path / "experiments" / "exp1" + exp_dir.mkdir(parents=True) + (exp_dir / "baseline_analysis.md").touch() + + assert runner.check_completion("exp1", "scenario", "baseline") is True + + def test_scenario_analysis_returns_false_when_file_missing(self, runner, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + exp_dir = tmp_path / "experiments" / "exp1" + exp_dir.mkdir(parents=True) + + assert runner.check_completion("exp1", "scenario", "baseline") is False + + def test_experiment_and_scenario_paths_do_not_collide(self, runner, tmp_path, monkeypatch): + """analysis.md and baseline_analysis.md are distinct files.""" + monkeypatch.chdir(tmp_path) + exp_dir = tmp_path / "experiments" / "exp1" + exp_dir.mkdir(parents=True) + (exp_dir / "analysis.md").touch() + # Only experiment analysis exists; scenario analysis must report False + + assert runner.check_completion("exp1", "scenario", "baseline") is False + + +# ── _build_experiment_script ────────────────────────────────────────────────── + + +class TestBuildExperimentScript: + @pytest.fixture() + def script(self, runner, tmp_path): + exp_path = tmp_path / "experiments" / "my_exp" + exp_path.mkdir(parents=True) + return runner._build_experiment_script( + experiment_path=exp_path, + experiment_name="My Experiment", + experiment_description="Does X spread faster on SF networks?", + model="gpt-4o-mini", + run_id="test_exp_run_id", + ) + + def test_script_does_not_embed_api_key(self, script): + """API key is passed via subprocess env, never written to script file.""" + assert "OPENAI_API_KEY" in script # env var reference exists + assert "sk-" not in script # but no actual key value + + def test_script_references_experiment_path(self, script, tmp_path): + assert "my_exp" in script + + def test_script_references_model_name(self, script): + assert "gpt-4o-mini" in script + + def test_script_imports_ai_analyzer(self, script): + assert "AIAnalyzer" in script + + def test_script_calls_analyze_experiment(self, script): + assert "analyze_experiment" in script + + def test_script_is_valid_python_syntax(self, script): + import ast + + ast.parse(script) + + def test_multiline_description_is_safely_embedded(self, runner, tmp_path): + """P1 bugfix: multiline descriptions must not break the generated script.""" + exp_path = tmp_path / "experiments" / "exp1" + exp_path.mkdir(parents=True) + multiline_desc = "Line one\nLine two\nLine three with 'quotes'" + script = runner._build_experiment_script( + experiment_path=exp_path, + experiment_name="Test\nNewline", + experiment_description=multiline_desc, + model="gpt-4o-mini", + run_id="test_multiline_run_id", + ) + import ast + + ast.parse(script) + # repr() must be used for all user-provided strings + assert repr(multiline_desc) in script + assert repr("Test\nNewline") in script + + def test_script_writes_running_status(self, script): + assert "_write_status" in script + assert '"running"' in script or "'running'" in script + + def test_script_uses_exact_status_file_path(self, runner, tmp_path): + """P1 bugfix: script must use exact status file path, not prefix-glob discovery.""" + exp_path = tmp_path / "experiments" / "my_exp" + exp_path.mkdir(parents=True) + run_id = "exp_analysis--my_exp--1700000000" + script = runner._build_experiment_script( + experiment_path=exp_path, + experiment_name="My Experiment", + experiment_description="Test", + model="gpt-4o-mini", + run_id=run_id, + ) + assert f"{run_id}.json" in script + assert "glob(" not in script + + +# ── _build_scenario_script ──────────────────────────────────────────────────── + + +class TestBuildScenarioScript: + @pytest.fixture() + def script(self, runner, tmp_path): + exp_path = tmp_path / "experiments" / "my_exp" + exp_path.mkdir(parents=True) + return runner._build_scenario_script( + experiment_path=exp_path, + scenario_label="High Transmission", + scenario_normalized="high_transmission", + model="gpt-4o", + run_id="test_scenario_run_id", + ) + + def test_script_does_not_embed_api_key(self, script): + """API key is passed via subprocess env, never written to script file.""" + assert "OPENAI_API_KEY" in script # env var reference exists + assert "sk-" not in script # but no actual key value + + def test_script_references_scenario_normalized_label(self, script): + assert "high_transmission" in script + + def test_script_calls_analyze_scenario(self, script): + assert "analyze_scenario" in script + + def test_script_is_valid_python_syntax(self, script): + import ast + + ast.parse(script) + + def test_script_loads_result_file(self, script): + assert "result_file" in script + + def test_script_uses_exact_status_file_path(self, runner, tmp_path): + """P1 bugfix: script must use exact status file path, not prefix-glob discovery.""" + exp_path = tmp_path / "experiments" / "my_exp" + exp_path.mkdir(parents=True) + run_id = "sc_analysis--my_exp--high_transmission--1700000000" + script = runner._build_scenario_script( + experiment_path=exp_path, + scenario_label="High Transmission", + scenario_normalized="high_transmission", + model="gpt-4o", + run_id=run_id, + ) + assert f"{run_id}.json" in script + assert "glob(" not in script + + +# ── poll_running_analyses: dead process with output file ───────────────────── + + +class _DictAttr(dict): + """Dict that also supports attribute access (like Streamlit session_state).""" + + def __getattr__(self, name): + try: + return self[name] + except KeyError: + raise AttributeError(name) + + +class TestPollDeadAnalysisProcessWithOutputFile: + """Regression: a dead analysis PID must be marked completed when .md exists.""" + + @staticmethod + def _make_mock_st(runner, run_id): + """Build a mock ``st`` module whose ``session_state`` behaves like a dict.""" + from unittest.mock import MagicMock + + mock_st = MagicMock() + mock_st.session_state = _DictAttr( + analysis_runner=runner, + running_analyses={ + "analysis--exp1": { + "experiment_name": "exp1", + "analysis_type": "experiment", + "scenario_normalized": "", + "run_id": run_id, + } + }, + ) + return mock_st + + def test_dead_process_with_output_file_marks_completed(self, runner, tmp_path, monkeypatch): + from unittest.mock import MagicMock + + import spkmc.web.analysis_runner as ar_mod + + run_id = "exp_analysis--exp1--1700000000" + status_data = { + "run_id": run_id, + "status": "running", + "pid": 99999999, + } + (runner.status_dir / f"{run_id}.json").write_text(json.dumps(status_data)) + + # Simulate output file existing (check_completion returns True) + runner.check_completion = lambda *a, **kw: True + + mock_st = self._make_mock_st(runner, run_id) + monkeypatch.setattr(ar_mod, "st", mock_st) + + # SessionState is imported locally inside poll_running_analyses + mock_session = MagicMock() + monkeypatch.setattr("spkmc.web.state.SessionState", mock_session) + + ar_mod.poll_running_analyses() + + mock_session.mark_analysis_completed.assert_called_once_with("analysis--exp1") + mock_session.mark_analysis_failed.assert_not_called() + + def test_dead_process_without_output_file_marks_failed(self, runner, tmp_path, monkeypatch): + from unittest.mock import MagicMock + + import spkmc.web.analysis_runner as ar_mod + + run_id = "exp_analysis--exp1--1700000000" + status_data = { + "run_id": run_id, + "status": "running", + "pid": 99999999, + } + (runner.status_dir / f"{run_id}.json").write_text(json.dumps(status_data)) + + # Simulate output file NOT existing (check_completion returns False) + runner.check_completion = lambda *a, **kw: False + + mock_st = self._make_mock_st(runner, run_id) + monkeypatch.setattr(ar_mod, "st", mock_st) + + mock_session = MagicMock() + monkeypatch.setattr("spkmc.web.state.SessionState", mock_session) + + ar_mod.poll_running_analyses() + + mock_session.mark_analysis_failed.assert_called_once() + mock_session.mark_analysis_completed.assert_not_called() diff --git a/tests/test_web/test_config.py b/tests/test_web/test_config.py new file mode 100644 index 0000000..38f716a --- /dev/null +++ b/tests/test_web/test_config.py @@ -0,0 +1,345 @@ +"""Tests for web configuration management.""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +# ── Helpers ─────────────────────────────────────────────────────────────────── + + +def _make_config(tmp_path): + """Return a WebConfig instance whose CONFIG_FILE lives in tmp_path.""" + from spkmc.web.config import WebConfig + + cfg = WebConfig() + cfg.CONFIG_FILE = tmp_path / "web_config.json" + cfg.config = WebConfig.DEFAULTS.copy() + return cfg + + +# ── Defaults ────────────────────────────────────────────────────────────────── + + +def test_web_config_defaults(): + """WebConfig initializes with expected default values.""" + from spkmc.web.config import WebConfig + + config = WebConfig() + + assert config.get("data_directory") == "data" + assert config.get("experiments_directory") == "experiments" + assert config.get("default_nodes") == 1000 + + +def test_all_defaults_are_present(): + """Every key in DEFAULTS must be accessible via get().""" + from spkmc.web.config import WebConfig + + config = WebConfig() + for key in WebConfig.DEFAULTS: + assert config.get(key) is not None or WebConfig.DEFAULTS[key] is None + + +# ── Save / load round-trip ──────────────────────────────────────────────────── + + +def test_save_and_load_round_trip(tmp_path): + """A value written via set() must survive a reload from disk.""" + cfg = _make_config(tmp_path) + cfg.set("my_key", "my_value") + + cfg2 = _make_config(tmp_path) + cfg2.load() + + assert cfg2.get("my_key") == "my_value" + + +def test_update_persists_multiple_keys(tmp_path): + """update() must persist all supplied keys to disk.""" + cfg = _make_config(tmp_path) + cfg.update({"key1": "v1", "key2": 42}) + + cfg2 = _make_config(tmp_path) + cfg2.load() + + assert cfg2.get("key1") == "v1" + assert cfg2.get("key2") == 42 + + +# ── Type coercion ───────────────────────────────────────────────────────────── + + +def test_json_integer_coerced_to_float_when_default_is_float(tmp_path): + """ + JSON may deserialize 10.0 as int 10. + WebConfig must coerce it back to float to avoid Streamlit type errors. + """ + cfg = _make_config(tmp_path) + # Write a file where a float default is stored as JSON integer + data = {**cfg.config, "default_k_avg": 10} # int instead of float + tmp_path.joinpath("web_config.json").write_text(json.dumps(data)) + + cfg.load() + + value = cfg.get("default_k_avg") + assert isinstance(value, float), f"Expected float, got {type(value)}" + assert value == 10.0 + + +def test_json_float_coerced_to_int_when_default_is_int(tmp_path): + """ + If a default is int and the file stores a float, coerce back to int. + """ + cfg = _make_config(tmp_path) + data = {**cfg.config, "default_nodes": 1000.0} # float instead of int + tmp_path.joinpath("web_config.json").write_text(json.dumps(data)) + + cfg.load() + + value = cfg.get("default_nodes") + assert isinstance(value, int), f"Expected int, got {type(value)}" + assert value == 1000 + + +# ── Resilience ──────────────────────────────────────────────────────────────── + + +def test_corrupted_config_file_falls_back_to_defaults(tmp_path): + """A malformed JSON config must not crash — fall back to DEFAULTS.""" + from spkmc.web.config import WebConfig + + config_file = tmp_path / "web_config.json" + config_file.write_text("{this is not valid json}") + + cfg = _make_config(tmp_path) + cfg.load() + + assert cfg.get("data_directory") == WebConfig.DEFAULTS["data_directory"] + + +def test_missing_key_returns_provided_default(tmp_path): + """get() must return the caller-supplied default for absent keys.""" + cfg = _make_config(tmp_path) + assert cfg.get("nonexistent_key", "fallback") == "fallback" + + +def test_missing_key_returns_none_by_default(tmp_path): + """get() must return None (not raise) for an absent key.""" + cfg = _make_config(tmp_path) + assert cfg.get("nonexistent_key") is None + + +def test_load_merges_file_with_defaults(tmp_path): + """ + A config file that only contains some keys must be merged with DEFAULTS + so all expected keys remain present. + """ + from spkmc.web.config import WebConfig + + partial = {"data_directory": "custom_data"} + tmp_path.joinpath("web_config.json").write_text(json.dumps(partial)) + + cfg = _make_config(tmp_path) + cfg.load() + + # Custom value was kept + assert cfg.get("data_directory") == "custom_data" + # Default for a key absent from the file is still present + assert cfg.get("default_nodes") == WebConfig.DEFAULTS["default_nodes"] + + +# ── Path helpers ────────────────────────────────────────────────────────────── + + +def test_get_data_path_returns_path_instance(tmp_path): + cfg = _make_config(tmp_path) + result = cfg.get_data_path() + assert isinstance(result, Path) + + +def test_get_data_path_reflects_configured_value(tmp_path): + cfg = _make_config(tmp_path) + cfg.set("data_directory", "my_data") + assert cfg.get_data_path() == Path("my_data") + + +def test_get_experiments_path_returns_path_instance(tmp_path): + cfg = _make_config(tmp_path) + result = cfg.get_experiments_path() + assert isinstance(result, Path) + + +def test_get_experiments_path_reflects_configured_value(tmp_path): + cfg = _make_config(tmp_path) + cfg.set("experiments_directory", "my_experiments") + assert cfg.get_experiments_path() == Path("my_experiments") + + +# ── OpenAI secrets ──────────────────────────────────────────────────────────── + + +def test_set_and_read_openai_api_key_round_trip(tmp_path, monkeypatch): + """ + set_openai_api_key writes to .streamlit/secrets.toml inside tmp_path. + The subsequent read must return the same key. + """ + from unittest.mock import MagicMock, patch + + from spkmc.web.config import WebConfig + + monkeypatch.chdir(tmp_path) + + WebConfig.set_openai_api_key("sk-test-abc123") + + secrets_file = tmp_path / ".streamlit" / "secrets.toml" + assert secrets_file.exists() + + content = secrets_file.read_text() + assert "sk-test-abc123" in content + + +def test_set_openai_api_key_preserves_existing_secrets(tmp_path, monkeypatch): + """ + Writing a new API key must not clobber other secrets already in the file. + """ + from spkmc.web.config import WebConfig + + monkeypatch.chdir(tmp_path) + secrets_dir = tmp_path / ".streamlit" + secrets_dir.mkdir() + (secrets_dir / "secrets.toml").write_text('OTHER_SECRET = "keep_me"\n') + + WebConfig.set_openai_api_key("sk-new-key") + + content = (secrets_dir / "secrets.toml").read_text() + assert "keep_me" in content + assert "sk-new-key" in content + + +def test_set_openai_api_key_preserves_toml_structure(tmp_path, monkeypatch): + """Structured TOML (sections, typed values, comments) must survive API key update.""" + from spkmc.web.config import WebConfig + + monkeypatch.chdir(tmp_path) + secrets_dir = tmp_path / ".streamlit" + secrets_dir.mkdir() + + original = ( + "# Top comment\n" + 'OPENAI_API_KEY = "sk-old"\n' + "\n" + "[database]\n" + 'host = "localhost"\n' + "port = 5432\n" + ) + (secrets_dir / "secrets.toml").write_text(original) + + WebConfig.set_openai_api_key("sk-new") + + content = (secrets_dir / "secrets.toml").read_text() + # New key must be present, old value gone + assert "sk-new" in content + assert "sk-old" not in content + # TOML structure must be preserved verbatim + assert "[database]" in content + assert 'host = "localhost"' in content + assert "port = 5432" in content + assert "# Top comment" in content + + +def test_set_openai_api_key_appends_when_key_absent(tmp_path, monkeypatch): + """When OPENAI_API_KEY is not yet in the file, append without disturbing content.""" + from spkmc.web.config import WebConfig + + monkeypatch.chdir(tmp_path) + secrets_dir = tmp_path / ".streamlit" + secrets_dir.mkdir() + + original = "[other]\nfoo = true\n" + (secrets_dir / "secrets.toml").write_text(original) + + WebConfig.set_openai_api_key("sk-appended") + + content = (secrets_dir / "secrets.toml").read_text() + assert "sk-appended" in content + assert "[other]" in content + assert "foo = true" in content + + +def test_get_openai_api_key_returns_override_after_set(tmp_path, monkeypatch): + """After set_openai_api_key(), get_openai_api_key() returns the new value.""" + import spkmc.web.config as config_mod + from spkmc.web.config import WebConfig + + monkeypatch.chdir(tmp_path) + # Reset module-level override to isolate from other tests + monkeypatch.setattr(config_mod, "_api_key_override", None) + + WebConfig.set_openai_api_key("sk-first") + assert WebConfig.get_openai_api_key() == "sk-first" + + WebConfig.set_openai_api_key("sk-second") + assert WebConfig.get_openai_api_key() == "sk-second" + + +# ── atomic_json_write tests ────────────────────────────────────────────────── + + +class TestAtomicJsonWrite: + """Tests for the atomic_json_write helper.""" + + def test_writes_valid_json(self, tmp_path): + from spkmc.web import atomic_json_write + + path = tmp_path / "test.json" + data = {"key": "value", "num": 42} + atomic_json_write(path, data) + + with open(path) as f: + loaded = json.load(f) + assert loaded == data + + def test_overwrites_existing_file(self, tmp_path): + from spkmc.web import atomic_json_write + + path = tmp_path / "test.json" + atomic_json_write(path, {"v": 1}) + atomic_json_write(path, {"v": 2}) + + with open(path) as f: + loaded = json.load(f) + assert loaded == {"v": 2} + + def test_no_temp_file_left_on_success(self, tmp_path): + from spkmc.web import atomic_json_write + + path = tmp_path / "test.json" + atomic_json_write(path, {"a": 1}) + + tmp_file = path.with_suffix(".json.tmp") + assert not tmp_file.exists() + + def test_preserves_original_on_failure(self, tmp_path): + from spkmc.web import atomic_json_write + + path = tmp_path / "test.json" + original = {"original": True} + atomic_json_write(path, original) + + # Attempt to write non-serializable data + class BadObj: + pass + + with pytest.raises(TypeError): + atomic_json_write(path, {"bad": BadObj()}) + + # Original should still be intact + with open(path) as f: + loaded = json.load(f) + assert loaded == original + + # Temp file should be cleaned up + assert not path.with_suffix(".json.tmp").exists() diff --git a/tests/test_web/test_experiment_detail.py b/tests/test_web/test_experiment_detail.py new file mode 100644 index 0000000..91b81ed --- /dev/null +++ b/tests/test_web/test_experiment_detail.py @@ -0,0 +1,352 @@ +""" +Tests for experiment_detail page logic. + +Covers update_scenario_in_experiment and related functions that manage +scenario editing, label collision detection, and result file lifecycle. +""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, Dict + +import pytest + +from spkmc.models.experiment import Experiment +from spkmc.models.scenario import Scenario + +# ── Helpers ────────────────────────────────────────────────────────────────── + + +def _write_data_json(exp_path: Path, data: Dict[str, Any]) -> None: + """Write a data.json file for an experiment.""" + (exp_path / "data.json").write_text(json.dumps(data, indent=2)) + + +def _read_data_json(exp_path: Path) -> Dict[str, Any]: + """Read and parse an experiment's data.json.""" + return json.loads((exp_path / "data.json").read_text()) + + +def _make_legacy_experiment(tmp_path: Path) -> Experiment: + """Create a legacy experiment (no global ``parameters`` block). + + Returns an Experiment whose data.json stores full params in each scenario + entry — the format used before the web interface introduced global params. + """ + exp_path = tmp_path / "experiments" / "legacy_exp" + exp_path.mkdir(parents=True) + + data = { + "name": "Legacy Experiment", + "description": "Pre-web-interface experiment", + "scenarios": [ + { + "label": "Baseline", + "network": "er", + "distribution": "gamma", + "nodes": 500, + "k_avg": 5.0, + "lambda": 0.5, + "shape": 2.0, + "scale": 1.0, + "samples": 10, + "num_runs": 1, + "initial_perc": 0.01, + "t_max": 5.0, + "steps": 50, + }, + { + "label": "High Lambda", + "network": "er", + "distribution": "gamma", + "nodes": 500, + "k_avg": 5.0, + "lambda": 2.0, + "shape": 2.0, + "scale": 1.0, + "samples": 10, + "num_runs": 1, + "initial_perc": 0.01, + "t_max": 5.0, + "steps": 50, + }, + ], + } + _write_data_json(exp_path, data) + + scenarios = [ + Scenario( + label="Baseline", + network="er", + distribution="gamma", + nodes=500, + k_avg=5.0, + shape=2.0, + scale=1.0, + samples=10, + initial_perc=0.01, + t_max=5.0, + steps=50, + **{"lambda": 0.5}, + ), + Scenario( + label="High Lambda", + network="er", + distribution="gamma", + nodes=500, + k_avg=5.0, + shape=2.0, + scale=1.0, + samples=10, + initial_perc=0.01, + t_max=5.0, + steps=50, + **{"lambda": 2.0}, + ), + ] + + return Experiment(name="Legacy Experiment", scenarios=scenarios, path=exp_path) + + +def _make_modern_experiment(tmp_path: Path) -> Experiment: + """Create a modern experiment with a global ``parameters`` block.""" + exp_path = tmp_path / "experiments" / "modern_exp" + exp_path.mkdir(parents=True) + + data = { + "name": "Modern Experiment", + "description": "Experiment with global params", + "parameters": { + "network": "er", + "distribution": "gamma", + "nodes": 1000, + "k_avg": 10.0, + "lambda": 0.5, + "shape": 2.0, + "scale": 1.0, + "samples": 50, + "num_runs": 1, + "initial_perc": 0.01, + "t_max": 10.0, + "steps": 100, + }, + "scenarios": [ + {"label": "Baseline"}, + {"label": "High Lambda", "lambda": 2.0}, + ], + } + _write_data_json(exp_path, data) + + scenarios = [ + Scenario( + label="Baseline", + network="er", + distribution="gamma", + nodes=1000, + k_avg=10.0, + shape=2.0, + scale=1.0, + samples=50, + initial_perc=0.01, + t_max=10.0, + steps=100, + **{"lambda": 0.5}, + ), + Scenario( + label="High Lambda", + network="er", + distribution="gamma", + nodes=1000, + k_avg=10.0, + shape=2.0, + scale=1.0, + samples=50, + initial_perc=0.01, + t_max=10.0, + steps=100, + **{"lambda": 2.0}, + ), + ] + + return Experiment( + name="Modern Experiment", + scenarios=scenarios, + path=exp_path, + parameters=data["parameters"], + ) + + +# ── update_scenario_in_experiment ──────────────────────────────────────────── + + +class TestUpdateScenarioInExperiment: + """Tests for update_scenario_in_experiment().""" + + def test_noop_edit_on_legacy_experiment_preserves_results(self, tmp_path): + """P1 regression: a no-op edit on a legacy experiment must NOT delete result files.""" + exp = _make_legacy_experiment(tmp_path) + exp_path = exp.path + assert exp_path is not None + + # Create result and analysis files that should be preserved + result_file = exp_path / "baseline.json" + analysis_file = exp_path / "baseline_analysis.md" + result_file.write_text('{"S_val": [1]}') + analysis_file.write_text("# Analysis") + + from spkmc.web.pages.experiment_detail import update_scenario_in_experiment + + # Simulate a no-op edit: same label, empty overrides (matching hardcoded defaults). + # For legacy experiments the form produces override_params containing only + # values that differ from hardcoded defaults — NOT all stored params. + # A no-op edit where some stored params happen to match defaults yields + # a sparse override_params dict. + update_scenario_in_experiment( + experiment=exp, + original_label="Baseline", + new_label="Baseline", + override_params={ + # Only include params that differ from hardcoded defaults. + # For legacy scenarios these are the values the form would emit. + "nodes": 500, # differs from hardcoded default of 1000 + "k_avg": 5.0, # differs from hardcoded default of 10.0 + "t_max": 5.0, # differs from hardcoded default of 10.0 + "steps": 50, # differs from hardcoded default of 100 + "samples": 10, # differs from hardcoded default of 50 + }, + ) + + # Result files must still exist (no-op edit should not delete them) + assert result_file.exists(), "Result file was deleted by a no-op edit!" + assert analysis_file.exists(), "Analysis file was deleted by a no-op edit!" + + def test_noop_edit_on_modern_experiment_preserves_results(self, tmp_path): + """Modern experiments: no-op edit must NOT delete result files.""" + exp = _make_modern_experiment(tmp_path) + exp_path = exp.path + assert exp_path is not None + + result_file = exp_path / "high_lambda.json" + analysis_file = exp_path / "high_lambda_analysis.md" + result_file.write_text('{"S_val": [1]}') + analysis_file.write_text("# Analysis") + + from spkmc.web.pages.experiment_detail import update_scenario_in_experiment + + # The override is the same as the existing one (lambda: 2.0) + update_scenario_in_experiment( + experiment=exp, + original_label="High Lambda", + new_label="High Lambda", + override_params={"lambda": 2.0}, + ) + + assert result_file.exists(), "Result file was deleted by a no-op edit!" + assert analysis_file.exists(), "Analysis file was deleted by a no-op edit!" + + def test_actual_edit_deletes_stale_results(self, tmp_path): + """When params actually change, stale result files must be deleted.""" + exp = _make_modern_experiment(tmp_path) + exp_path = exp.path + assert exp_path is not None + + result_file = exp_path / "high_lambda.json" + analysis_file = exp_path / "high_lambda_analysis.md" + result_file.write_text('{"S_val": [1]}') + analysis_file.write_text("# Analysis") + + from spkmc.web.pages.experiment_detail import update_scenario_in_experiment + + # Change lambda from 2.0 to 3.0 + update_scenario_in_experiment( + experiment=exp, + original_label="High Lambda", + new_label="High Lambda", + override_params={"lambda": 3.0}, + ) + + assert not result_file.exists(), "Result file was NOT deleted after param change!" + assert not analysis_file.exists(), "Analysis file was NOT deleted after param change!" + + def test_label_rename_deletes_old_results(self, tmp_path): + """Renaming a scenario must delete the old result files.""" + exp = _make_modern_experiment(tmp_path) + exp_path = exp.path + assert exp_path is not None + + old_result = exp_path / "high_lambda.json" + old_analysis = exp_path / "high_lambda_analysis.md" + old_result.write_text('{"S_val": [1]}') + old_analysis.write_text("# Analysis") + + from spkmc.web.pages.experiment_detail import update_scenario_in_experiment + + update_scenario_in_experiment( + experiment=exp, + original_label="High Lambda", + new_label="Very High Lambda", + override_params={"lambda": 2.0}, + ) + + assert not old_result.exists(), "Old result file was NOT deleted after rename!" + assert not old_analysis.exists(), "Old analysis file was NOT deleted after rename!" + + def test_label_collision_raises_error(self, tmp_path): + """Renaming to an existing scenario's normalized label must raise ValueError.""" + exp = _make_modern_experiment(tmp_path) + + from spkmc.web.pages.experiment_detail import update_scenario_in_experiment + + with pytest.raises(ValueError, match="conflicting name"): + update_scenario_in_experiment( + experiment=exp, + original_label="High Lambda", + new_label="Baseline", + override_params={}, + ) + + def test_empty_label_raises_error(self, tmp_path): + """A label that normalizes to empty string must raise ValueError.""" + exp = _make_modern_experiment(tmp_path) + + from spkmc.web.pages.experiment_detail import update_scenario_in_experiment + + with pytest.raises(ValueError, match="normalizes to an empty"): + update_scenario_in_experiment( + experiment=exp, + original_label="High Lambda", + new_label="!!!", + override_params={}, + ) + + def test_legacy_actual_edit_deletes_stale_results(self, tmp_path): + """Legacy experiment: actual param change must delete stale results.""" + exp = _make_legacy_experiment(tmp_path) + exp_path = exp.path + assert exp_path is not None + + result_file = exp_path / "baseline.json" + analysis_file = exp_path / "baseline_analysis.md" + result_file.write_text('{"S_val": [1]}') + analysis_file.write_text("# Analysis") + + from spkmc.web.pages.experiment_detail import update_scenario_in_experiment + + # Change nodes from 500 to 600 (an actual parameter change) + update_scenario_in_experiment( + experiment=exp, + original_label="Baseline", + new_label="Baseline", + override_params={ + "nodes": 600, # CHANGED from 500 + "k_avg": 5.0, + "t_max": 5.0, + "steps": 50, + "samples": 10, + }, + ) + + assert not result_file.exists(), "Result file was NOT deleted after param change!" + assert not analysis_file.exists(), "Analysis file was NOT deleted after param change!" diff --git a/tests/test_web/test_plotting.py b/tests/test_web/test_plotting.py new file mode 100644 index 0000000..72d9ae4 --- /dev/null +++ b/tests/test_web/test_plotting.py @@ -0,0 +1,354 @@ +"""Tests for web plotting functions.""" + +from __future__ import annotations + +import numpy as np +import pytest + +pytestmark = pytest.mark.skipif( + pytest.importorskip("plotly", reason="plotly not installed") is None, + reason="plotly not installed", +) + + +# ── Fixtures ────────────────────────────────────────────────────────────────── + + +@pytest.fixture() +def sir_result(): + """Minimal SIR result dict with S, I, R values.""" + t = np.linspace(0, 10, 100).tolist() + return { + "time": t, + "S_val": np.linspace(1.0, 0.5, 100).tolist(), + "I_val": np.linspace(0.0, 0.3, 100).tolist(), + "R_val": np.linspace(0.0, 0.2, 100).tolist(), + } + + +@pytest.fixture() +def sir_result_with_errors(sir_result): + """SIR result dict including error band arrays.""" + sir_result["S_err"] = (np.ones(100) * 0.01).tolist() + sir_result["I_err"] = (np.ones(100) * 0.01).tolist() + sir_result["R_err"] = (np.ones(100) * 0.01).tolist() + return sir_result + + +# ── _hex_to_rgba ────────────────────────────────────────────────────────────── + + +class TestHexToRgba: + def test_converts_black(self): + from spkmc.web.plotting import _hex_to_rgba + + assert _hex_to_rgba("#000000") == "rgba(0, 0, 0, 1.0)" + + def test_converts_white(self): + from spkmc.web.plotting import _hex_to_rgba + + assert _hex_to_rgba("#ffffff") == "rgba(255, 255, 255, 1.0)" + + def test_converts_known_color(self): + from spkmc.web.plotting import _hex_to_rgba + + # #4477AA → r=68, g=119, b=170 + result = _hex_to_rgba("#4477AA") + assert result == "rgba(68, 119, 170, 1.0)" + + def test_applies_alpha(self): + from spkmc.web.plotting import _hex_to_rgba + + result = _hex_to_rgba("#ffffff", alpha=0.15) + assert "0.15" in result + + def test_strips_hash_prefix(self): + from spkmc.web.plotting import _hex_to_rgba + + # With and without # must produce the same result + assert _hex_to_rgba("#4477AA") == _hex_to_rgba("#4477AA") + + +# ── create_sir_figure ───────────────────────────────────────────────────────── + + +class TestCreateSirFigure: + def test_produces_three_traces_by_default(self, sir_result): + from spkmc.web.plotting import create_sir_figure + + fig = create_sir_figure(sir_result) + assert len(fig.data) == 3 + + def test_state_subset_returns_only_requested_traces(self, sir_result): + from spkmc.web.plotting import create_sir_figure + + fig = create_sir_figure(sir_result, states=["I"]) + assert len(fig.data) == 1 + assert fig.data[0].name == "I" + + def test_two_state_subset(self, sir_result): + from spkmc.web.plotting import create_sir_figure + + fig = create_sir_figure(sir_result, states=["S", "R"]) + names = [t.name for t in fig.data] + assert "S" in names + assert "R" in names + assert "I" not in names + + def test_error_bands_are_added_when_present(self, sir_result_with_errors): + from spkmc.web.plotting import create_sir_figure + + fig = create_sir_figure(sir_result_with_errors, show_error_bands=True) + for trace in fig.data: + assert trace.error_y is not None + assert trace.error_y.visible is True + + def test_error_bands_absent_when_disabled(self, sir_result_with_errors): + from spkmc.web.plotting import create_sir_figure + + fig = create_sir_figure(sir_result_with_errors, show_error_bands=False) + for trace in fig.data: + # Plotly represents "no error bars" as ErrorY with visible=None/False, + # not as Python None. Check that visible is not True. + assert trace.error_y.visible is not True + + def test_custom_state_colors_override_defaults(self, sir_result): + from spkmc.web.plotting import create_sir_figure + + custom_colors = {"I": "#FF0000"} + fig = create_sir_figure(sir_result, states=["I"], state_colors=custom_colors) + i_trace = next(t for t in fig.data if t.name == "I") + assert i_trace.line.color == "#FF0000" + + def test_default_colors_are_unchanged_when_not_overridden(self, sir_result): + from spkmc.web.plotting import COLOR_S, create_sir_figure + + fig = create_sir_figure(sir_result, states=["S"]) + s_trace = fig.data[0] + assert s_trace.line.color == COLOR_S + + def test_custom_template_propagates_to_layout(self, sir_result): + from spkmc.web.plotting import create_sir_figure + + fig = create_sir_figure(sir_result, template="plotly_dark") + assert fig.layout.template.layout.colorway is not None or True + # The template name is resolved by Plotly; verify the call didn't raise + assert fig is not None + + def test_height_is_applied_to_layout(self, sir_result): + from spkmc.web.plotting import create_sir_figure + + fig = create_sir_figure(sir_result, height=800) + assert fig.layout.height == 800 + + def test_area_chart_mode_sets_fill_on_traces(self, sir_result): + from spkmc.web.plotting import create_sir_figure + + fig = create_sir_figure(sir_result, chart_mode="area") + for trace in fig.data: + assert trace.fill == "tozeroy" + + def test_lines_plus_markers_mode_is_set(self, sir_result): + from spkmc.web.plotting import create_sir_figure + + fig = create_sir_figure(sir_result, chart_mode="lines+markers") + for trace in fig.data: + assert trace.mode == "lines+markers" + + def test_missing_state_key_is_silently_skipped(self, sir_result): + """Requesting a state not present in result_dict must not raise.""" + from spkmc.web.plotting import create_sir_figure + + del sir_result["I_val"] + fig = create_sir_figure(sir_result, states=["S", "I", "R"]) + names = [t.name for t in fig.data] + assert "I" not in names + assert "S" in names + + def test_numpy_array_inputs_do_not_raise(self): + """NumPy arrays in result_dict must not cause a truthiness ValueError.""" + from spkmc.web.plotting import create_sir_figure + + result = { + "time": np.array([0, 1, 2, 3, 4]), + "S_val": np.array([1.0, 0.9, 0.8, 0.7, 0.6]), + "I_val": np.array([0.0, 0.05, 0.1, 0.15, 0.2]), + "R_val": np.array([0.0, 0.05, 0.1, 0.15, 0.2]), + } + fig = create_sir_figure(result) + assert len(fig.data) == 3 + assert fig.layout.xaxis.range == (0, 4.0) + + +# ── create_comparison_figure ────────────────────────────────────────────────── + + +class TestCreateComparisonFigure: + @pytest.fixture() + def two_results(self): + t = np.linspace(0, 10, 100).tolist() + return [ + { + "time": t, + "S_val": np.linspace(1.0, 0.5, 100).tolist(), + "I_val": np.linspace(0.0, 0.3, 100).tolist(), + "R_val": np.linspace(0.0, 0.2, 100).tolist(), + }, + { + "time": t, + "S_val": np.linspace(1.0, 0.4, 100).tolist(), + "I_val": np.linspace(0.0, 0.4, 100).tolist(), + "R_val": np.linspace(0.0, 0.2, 100).tolist(), + }, + ] + + def test_two_scenarios_three_states_produces_six_traces(self, two_results): + from spkmc.web.plotting import create_comparison_figure + + fig = create_comparison_figure( + two_results, ["Scenario A", "Scenario B"], states=["S", "I", "R"] + ) + assert len(fig.data) == 6 + + def test_single_scenario_produces_correct_trace_count(self): + from spkmc.web.plotting import create_comparison_figure + + t = np.linspace(0, 10, 50).tolist() + result = { + "time": t, + "I_val": np.linspace(0.0, 0.3, 50).tolist(), + } + fig = create_comparison_figure([result], ["Solo"], states=["I"]) + assert len(fig.data) == 1 + + def test_trace_names_include_scenario_label_and_state(self, two_results): + from spkmc.web.plotting import create_comparison_figure + + fig = create_comparison_figure(two_results, ["Alpha", "Beta"], states=["I"]) + names = [t.name for t in fig.data] + assert any("Alpha" in n for n in names) + assert any("Beta" in n for n in names) + + def test_custom_template_applied(self, two_results): + from spkmc.web.plotting import create_comparison_figure + + fig = create_comparison_figure(two_results, ["A", "B"], template="plotly_dark") + assert fig is not None + + def test_state_subset_limits_traces(self, two_results): + from spkmc.web.plotting import create_comparison_figure + + fig = create_comparison_figure(two_results, ["A", "B"], states=["I"]) + assert len(fig.data) == 2 # 2 scenarios × 1 state + + +# ── create_metric_card_figure ───────────────────────────────────────────────── + + +class TestCreateMetricCardFigure: + def test_returns_figure(self): + from spkmc.web.plotting import create_metric_card_figure + + fig = create_metric_card_figure(0.42, "Peak Infected") + assert fig is not None + + def test_contains_indicator_trace(self): + import plotly.graph_objects as go + + from spkmc.web.plotting import create_metric_card_figure + + fig = create_metric_card_figure(0.75, "Final Recovered") + assert len(fig.data) == 1 + assert isinstance(fig.data[0], go.Indicator) + + def test_value_is_stored_in_indicator(self): + from spkmc.web.plotting import create_metric_card_figure + + fig = create_metric_card_figure(0.33, "Some Metric") + assert fig.data[0].value == pytest.approx(0.33) + + def test_title_appears_in_figure(self): + from spkmc.web.plotting import create_metric_card_figure + + fig = create_metric_card_figure(0.5, "My Title", subtitle="Details here") + title_text = fig.data[0].title.text + assert "My Title" in title_text + + def test_custom_color_is_applied(self): + from spkmc.web.plotting import create_metric_card_figure + + fig = create_metric_card_figure(0.5, "Metric", color="#FF0000") + assert fig.data[0].number.font.color == "#FF0000" + + def test_height_is_compact(self): + from spkmc.web.plotting import create_metric_card_figure + + fig = create_metric_card_figure(0.5, "Compact") + assert fig.layout.height == 150 + + +# ── Visualizer.compare_results_with_config ─────────────────────────────────── + + +class TestCompareResultsWithConfig: + """Verify that PlotConfig settings propagate through the Plotly refactor.""" + + @pytest.fixture() + def two_results(self): + t = np.linspace(0, 10, 50).tolist() + return [ + { + "time": t, + "S_val": np.linspace(1, 0.5, 50).tolist(), + "I_val": np.linspace(0, 0.3, 50).tolist(), + "R_val": np.linspace(0, 0.2, 50).tolist(), + }, + { + "time": t, + "S_val": np.linspace(1, 0.4, 50).tolist(), + "I_val": np.linspace(0, 0.4, 50).tolist(), + "R_val": np.linspace(0, 0.2, 50).tolist(), + }, + ] + + def _capture_figure(self, results, labels, plot_config): + """Run compare_results_with_config and capture the figure.""" + import spkmc.visualization.plots as vp + from spkmc.models.config import PlotConfig + from spkmc.visualization.plots import Visualizer + + captured = {} + orig = vp._save_or_show + + def fake_save(fig, *a, **kw): + captured["fig"] = fig + + vp._save_or_show = fake_save + try: + Visualizer.compare_results_with_config(results, labels, plot_config) + finally: + vp._save_or_show = orig + return captured["fig"] + + def test_grid_disabled_propagates(self, two_results): + from spkmc.models.config import PlotConfig + + pc = PlotConfig(grid=False) + fig = self._capture_figure(two_results, ["A", "B"], pc) + assert fig.layout.xaxis.showgrid is False + assert fig.layout.yaxis.showgrid is False + + def test_grid_alpha_propagates(self, two_results): + from spkmc.models.config import PlotConfig + + pc = PlotConfig(grid=True, grid_alpha=0.7) + fig = self._capture_figure(two_results, ["A", "B"], pc) + assert "0.7" in fig.layout.xaxis.gridcolor + + def test_legend_position_center(self, two_results): + from spkmc.models.config import PlotConfig + + pc = PlotConfig(legend_position="center") + fig = self._capture_figure(two_results, ["A", "B"], pc) + assert fig.layout.legend.x == 0.5 + assert fig.layout.legend.y == 0.5 diff --git a/tests/test_web/test_runner.py b/tests/test_web/test_runner.py new file mode 100644 index 0000000..8930b4c --- /dev/null +++ b/tests/test_web/test_runner.py @@ -0,0 +1,382 @@ +""" +Tests for SimulationRunner. + +Tests cover file-based status management, completion detection, progress +reading, and cleanup. Subprocess execution is NOT tested — that belongs to +integration tests. The runner fixture bypasses __init__ so no real +.spkmc_web/ directory is created during the test suite. +""" + +from __future__ import annotations + +import json +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +# ── Fixtures ────────────────────────────────────────────────────────────────── + + +@pytest.fixture() +def runner(tmp_path): + """SimulationRunner with status_dir isolated to tmp_path.""" + from spkmc.web.runner import SimulationRunner + + r = SimulationRunner.__new__(SimulationRunner) + r.status_dir = tmp_path / "status" + r.status_dir.mkdir() + r._processes = {} + return r + + +@pytest.fixture() +def minimal_scenario(): + from spkmc.models.scenario import Scenario + + return Scenario( + label="Baseline", + network="er", + distribution="gamma", + nodes=1000, + samples=50, + k_avg=10.0, + **{"lambda": 0.5}, + shape=2.0, + scale=1.0, + t_max=10.0, + steps=100, + initial_perc=0.01, + ) + + +@pytest.fixture() +def minimal_experiment(tmp_path, minimal_scenario): + from spkmc.models.experiment import Experiment + + exp_path = tmp_path / "experiments" / "test_experiment" + exp_path.mkdir(parents=True) + return Experiment( + name="Test Experiment", + scenarios=[minimal_scenario], + path=exp_path, + ) + + +# ── get_status ──────────────────────────────────────────────────────────────── + + +class TestGetStatus: + def test_returns_none_for_missing_run_id(self, runner): + assert runner.get_status("nonexistent_run") is None + + def test_returns_parsed_dict_for_valid_status_file(self, runner): + data = {"run_id": "run_1", "status": "running", "progress": 5, "total": 100} + (runner.status_dir / "run_1.json").write_text(json.dumps(data)) + + result = runner.get_status("run_1") + assert result == data + + def test_returns_none_for_corrupted_json(self, runner): + (runner.status_dir / "bad.json").write_text("{not: valid}") + assert runner.get_status("bad") is None + + def test_returns_none_for_empty_file(self, runner): + (runner.status_dir / "empty.json").write_text("") + assert runner.get_status("empty") is None + + +# ── is_running ──────────────────────────────────────────────────────────────── + + +class TestIsRunning: + def test_returns_true_when_status_is_running(self, runner): + (runner.status_dir / "r1.json").write_text( + json.dumps({"run_id": "r1", "status": "running"}) + ) + assert runner.is_running("r1") is True + + def test_returns_false_when_status_is_completed(self, runner): + (runner.status_dir / "r1.json").write_text( + json.dumps({"run_id": "r1", "status": "completed"}) + ) + assert runner.is_running("r1") is False + + def test_returns_false_when_status_is_failed(self, runner): + (runner.status_dir / "r1.json").write_text(json.dumps({"run_id": "r1", "status": "failed"})) + assert runner.is_running("r1") is False + + def test_returns_false_for_nonexistent_run(self, runner): + assert runner.is_running("ghost") is False + + +# ── check_completion ────────────────────────────────────────────────────────── + + +class TestCheckCompletion: + def test_returns_true_when_result_file_exists(self, runner, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + exp_dir = tmp_path / "experiments" / "my_exp" + exp_dir.mkdir(parents=True) + (exp_dir / "baseline.json").touch() + + assert runner.check_completion("my_exp", "Baseline") is True + + def test_returns_false_when_result_file_is_missing(self, runner, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + exp_dir = tmp_path / "experiments" / "my_exp" + exp_dir.mkdir(parents=True) + + assert runner.check_completion("my_exp", "Baseline") is False + + def test_label_is_normalized_before_checking(self, runner, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + exp_dir = tmp_path / "experiments" / "exp1" + exp_dir.mkdir(parents=True) + # "High Risk Scenario" normalizes to "high_risk_scenario" + (exp_dir / "high_risk_scenario.json").touch() + + assert runner.check_completion("exp1", "High Risk Scenario") is True + + +# ── cleanup_status ──────────────────────────────────────────────────────────── + + +class TestCleanupStatus: + def test_removes_status_json_and_script_files(self, runner): + run_id = "run_cleanup" + (runner.status_dir / f"{run_id}.json").write_text("{}") + (runner.status_dir / f"{run_id}_script.py").write_text("pass") + + runner.cleanup_status(run_id) + + assert not (runner.status_dir / f"{run_id}.json").exists() + assert not (runner.status_dir / f"{run_id}_script.py").exists() + + def test_cleanup_is_idempotent_when_files_already_absent(self, runner): + # Must not raise when called with a run_id that has no files + runner.cleanup_status("nonexistent_run") + runner.cleanup_status("nonexistent_run") + + def test_cleanup_handles_missing_script_file_gracefully(self, runner): + run_id = "partial" + (runner.status_dir / f"{run_id}.json").write_text("{}") + # No script file + + runner.cleanup_status(run_id) + assert not (runner.status_dir / f"{run_id}.json").exists() + + +# ── get_progress ────────────────────────────────────────────────────────────── + + +class TestGetProgress: + def test_returns_progress_tuple_from_status_file(self, runner): + data = {"run_id": "r1", "status": "running", "progress": 30, "total": 100} + (runner.status_dir / "r1.json").write_text(json.dumps(data)) + + assert runner.get_progress("r1") == (30, 100) + + def test_returns_none_for_missing_status_file(self, runner): + assert runner.get_progress("ghost") is None + + def test_defaults_to_zero_when_progress_keys_absent(self, runner): + (runner.status_dir / "r1.json").write_text( + json.dumps({"run_id": "r1", "status": "running"}) + ) + assert runner.get_progress("r1") == (0, 0) + + def test_returns_full_progress_at_completion(self, runner): + data = {"run_id": "r1", "status": "completed", "progress": 100, "total": 100} + (runner.status_dir / "r1.json").write_text(json.dumps(data)) + + progress, total = runner.get_progress("r1") + assert progress == total == 100 + + +# ── _build_execution_script ─────────────────────────────────────────────────── + + +class TestBuildExecutionScript: + def test_script_references_experiment_path(self, runner, minimal_experiment, minimal_scenario): + script = runner._build_execution_script(minimal_experiment, minimal_scenario, "test_run_id") + assert repr(str(minimal_experiment.path)) in script + + def test_script_contains_scenario_normalized_label( + self, runner, minimal_experiment, minimal_scenario + ): + script = runner._build_execution_script(minimal_experiment, minimal_scenario, "test_run_id") + assert minimal_scenario.normalized_label in script + + def test_script_contains_execution_engine_import( + self, runner, minimal_experiment, minimal_scenario + ): + script = runner._build_execution_script(minimal_experiment, minimal_scenario, "test_run_id") + assert "ExecutionEngine" in script + + def test_script_is_valid_python_syntax(self, runner, minimal_experiment, minimal_scenario): + import ast + + script = runner._build_execution_script(minimal_experiment, minimal_scenario, "test_run_id") + # Must not raise SyntaxError + ast.parse(script) + + def test_script_contains_progress_callback(self, runner, minimal_experiment, minimal_scenario): + script = runner._build_execution_script(minimal_experiment, minimal_scenario, "test_run_id") + assert "_progress_callback" in script + + def test_script_uses_exact_status_file_path(self, runner, minimal_experiment, minimal_scenario): + """P1 bugfix: script must use exact status file path, not prefix-glob discovery.""" + run_id = "sim--myexp--baseline--1700000000" + script = runner._build_execution_script(minimal_experiment, minimal_scenario, run_id) + assert f"{run_id}.json" in script + assert "glob(" not in script + + def test_scenario_with_apostrophe_in_label_is_safe(self, runner, tmp_path): + """P1 bugfix: scenario JSON with quotes must not break the generated script.""" + from spkmc.models.experiment import Experiment + from spkmc.models.scenario import Scenario + + scenario = Scenario( + label="O'Brien's Test", + network="er", + distribution="gamma", + nodes=100, + samples=10, + k_avg=5.0, + **{"lambda": 1.0}, + shape=2.0, + scale=1.0, + t_max=5.0, + steps=50, + initial_perc=0.01, + ) + exp_path = tmp_path / "experiments" / "apos_exp" + exp_path.mkdir(parents=True) + experiment = Experiment(name="Apostrophe Exp", scenarios=[scenario], path=exp_path) + + import ast + + script = runner._build_execution_script(experiment, scenario, "test_apos_run_id") + ast.parse(script) + + +# ── run_all_scenarios skips existing results ────────────────────────────────── + + +class TestRunAllScenariosSkipsExistingResults: + def test_skips_scenario_with_existing_result( + self, runner, minimal_experiment, tmp_path, monkeypatch + ): + """run_all_scenarios must not re-launch scenarios that already have results.""" + # Pre-create the result file for the baseline scenario + result_file = minimal_experiment.path / "baseline.json" + result_file.touch() + + launched = [] + + def mock_run_scenario(exp, sc, show_progress=False): + launched.append(sc.normalized_label) + return "fake_run_id" + + runner.run_scenario = mock_run_scenario + + # Patch st.toast so it doesn't fail outside Streamlit + with patch("spkmc.web.runner.st") as mock_st: + run_ids = runner.run_all_scenarios(minimal_experiment, show_progress=False) + + assert "baseline" not in launched + assert run_ids == [] + + +# ── poll_running_simulations: dead process with output file ────────────────── + + +class _DictAttr(dict): + """Dict that also supports attribute access (like Streamlit session_state).""" + + def __getattr__(self, name): + try: + return self[name] + except KeyError: + raise AttributeError(name) + + +class TestPollDeadProcessWithOutputFile: + """Regression: a dead PID must be marked completed when the result file exists.""" + + @staticmethod + def _make_mock_st(runner, run_id): + """Build a mock ``st`` module whose ``session_state`` behaves like a dict.""" + from unittest.mock import MagicMock + + mock_st = MagicMock() + mock_st.session_state = _DictAttr( + simulation_runner=runner, + running_simulations={ + "sim--test_exp--baseline": { + "experiment_name": "test_exp", + "scenario_label": "Baseline", + "run_id": run_id, + } + }, + ) + return mock_st + + def test_dead_process_with_result_file_marks_completed(self, runner, tmp_path, monkeypatch): + from unittest.mock import MagicMock + + import spkmc.web.runner as runner_mod + + # Write a status file claiming "running" with a dead PID + run_id = "sim--test_exp--baseline--1700000000" + status_data = { + "run_id": run_id, + "status": "running", + "pid": 99999999, # PID that does not exist + } + (runner.status_dir / f"{run_id}.json").write_text(json.dumps(status_data)) + + # Simulate result file existing (check_completion returns True) + runner.check_completion = lambda *a, **kw: True + + mock_st = self._make_mock_st(runner, run_id) + monkeypatch.setattr(runner_mod, "st", mock_st) + + # SessionState is imported locally inside poll_running_simulations + mock_session = MagicMock() + monkeypatch.setattr("spkmc.web.state.SessionState", mock_session) + + runner_mod.poll_running_simulations() + + # Must call mark_simulation_completed (not mark_simulation_failed) + mock_session.mark_simulation_completed.assert_called_once_with("sim--test_exp--baseline") + mock_session.mark_simulation_failed.assert_not_called() + + def test_dead_process_without_result_file_marks_failed(self, runner, tmp_path, monkeypatch): + from unittest.mock import MagicMock + + import spkmc.web.runner as runner_mod + + run_id = "sim--test_exp--baseline--1700000000" + status_data = { + "run_id": run_id, + "status": "running", + "pid": 99999999, + } + (runner.status_dir / f"{run_id}.json").write_text(json.dumps(status_data)) + + # Simulate result file NOT existing (check_completion returns False) + runner.check_completion = lambda *a, **kw: False + + mock_st = self._make_mock_st(runner, run_id) + monkeypatch.setattr(runner_mod, "st", mock_st) + + mock_session = MagicMock() + monkeypatch.setattr("spkmc.web.state.SessionState", mock_session) + + runner_mod.poll_running_simulations() + + # Must call mark_simulation_failed (no result file to rescue) + mock_session.mark_simulation_failed.assert_called_once() + mock_session.mark_simulation_completed.assert_not_called() diff --git a/tests/test_web/test_state.py b/tests/test_web/test_state.py new file mode 100644 index 0000000..7cd5b39 --- /dev/null +++ b/tests/test_web/test_state.py @@ -0,0 +1,578 @@ +""" +Tests for SessionState business logic. + +Tests cover state machine transitions, scenario selection, progress tracking, +and disk-based restoration. Streamlit is patched at the module level so no +Streamlit runtime is required. +""" + +from __future__ import annotations + +import json +import os +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +# ── Helpers ────────────────────────────────────────────────────────────────── + + +class _FakeState(dict): + """Minimal st.session_state substitute: dict with attribute access.""" + + def __getattr__(self, key: str): + try: + return self[key] + except KeyError: + raise AttributeError(key) + + def __setattr__(self, key: str, value): + self[key] = value + + def __delattr__(self, key: str): + try: + del self[key] + except KeyError: + raise AttributeError(key) + + +class _FakeQueryParams(dict): + """Minimal st.query_params substitute.""" + + def pop(self, key, default=None): # type: ignore[override] + return super().pop(key, default) + + +@pytest.fixture() +def session(monkeypatch): + """ + Patch st.session_state and st.query_params with plain dict substitutes. + + Returns the fake session_state dict so tests can pre-populate it. + """ + state = _FakeState() + params = _FakeQueryParams() + + import spkmc.web.state as state_module + + monkeypatch.setattr(state_module.st, "session_state", state) + monkeypatch.setattr(state_module.st, "query_params", params) + return state + + +# ── PID detection ───────────────────────────────────────────────────────────── + + +class TestIsPidAlive: + def test_current_process_is_alive(self): + from spkmc.web.state import _is_pid_alive + + assert _is_pid_alive(os.getpid()) is True + + def test_unreachable_pid_is_not_alive(self): + from spkmc.web.state import _is_pid_alive + + # PID space on modern OSes is typically limited to ~4 million + assert _is_pid_alive(999_999_999) is False + + def test_zero_pid_does_not_raise(self): + from spkmc.web.state import _is_pid_alive + + # PID 0 means "same process group" on POSIX — we just care it doesn't raise + result = _is_pid_alive(0) + assert isinstance(result, bool) + + +# ── Scenario selection ──────────────────────────────────────────────────────── + + +class TestScenarioSelection: + @pytest.fixture(autouse=True) + def _init(self, session): + session["selected_scenarios"] = set() + + def test_toggle_adds_unselected_scenario(self, session): + from spkmc.web.state import SessionState + + SessionState.toggle_scenario_selection("sc_1") + assert "sc_1" in session["selected_scenarios"] + + def test_toggle_removes_already_selected_scenario(self, session): + from spkmc.web.state import SessionState + + session["selected_scenarios"] = {"sc_1"} + SessionState.toggle_scenario_selection("sc_1") + assert "sc_1" not in session["selected_scenarios"] + + def test_double_toggle_restores_original_state(self, session): + from spkmc.web.state import SessionState + + SessionState.toggle_scenario_selection("sc_1") + SessionState.toggle_scenario_selection("sc_1") + assert "sc_1" not in session["selected_scenarios"] + + def test_clear_empties_all_selections(self, session): + from spkmc.web.state import SessionState + + session["selected_scenarios"] = {"sc_1", "sc_2", "sc_3"} + SessionState.clear_scenario_selections() + assert session["selected_scenarios"] == set() + + def test_get_returns_empty_set_when_key_absent(self, session): + from spkmc.web.state import SessionState + + session.pop("selected_scenarios", None) + assert SessionState.get_selected_scenarios() == set() + + def test_independent_scenarios_do_not_interfere(self, session): + from spkmc.web.state import SessionState + + SessionState.toggle_scenario_selection("sc_a") + SessionState.toggle_scenario_selection("sc_b") + SessionState.toggle_scenario_selection("sc_a") # remove sc_a + assert "sc_a" not in session["selected_scenarios"] + assert "sc_b" in session["selected_scenarios"] + + +# ── Simulation state machine ────────────────────────────────────────────────── + + +class TestSimulationStateMachine: + @pytest.fixture(autouse=True) + def _init(self, session): + session["running_simulations"] = {} + session["completed_simulations"] = set() + session["failed_simulations"] = {} + session["simulation_progress"] = {} + + def test_unknown_simulation_status_is_pending(self, session): + from spkmc.web.state import SessionState + + assert SessionState.get_simulation_status("sim_1") == "pending" + + def test_added_simulation_is_running(self, session): + from spkmc.web.state import SessionState + + SessionState.add_running_simulation("sim_1", {"run_id": "r1"}) + assert SessionState.is_simulation_running("sim_1") is True + assert SessionState.get_simulation_status("sim_1") == "running" + + def test_completed_simulation_is_removed_from_running(self, session): + from spkmc.web.state import SessionState + + SessionState.add_running_simulation("sim_1", {"run_id": "r1"}) + SessionState.mark_simulation_completed("sim_1") + assert SessionState.is_simulation_running("sim_1") is False + assert SessionState.get_simulation_status("sim_1") == "completed" + + def test_failed_simulation_is_removed_from_running(self, session): + from spkmc.web.state import SessionState + + SessionState.add_running_simulation("sim_1", {"run_id": "r1"}) + SessionState.mark_simulation_failed("sim_1", "Out of memory") + assert SessionState.is_simulation_running("sim_1") is False + assert SessionState.get_simulation_status("sim_1") == "failed" + + def test_completed_simulation_is_not_running(self, session): + from spkmc.web.state import SessionState + + session["completed_simulations"].add("sim_1") + assert SessionState.is_simulation_running("sim_1") is False + + def test_two_simulations_transition_independently(self, session): + from spkmc.web.state import SessionState + + SessionState.add_running_simulation("sim_a", {}) + SessionState.add_running_simulation("sim_b", {}) + SessionState.mark_simulation_completed("sim_a") + assert SessionState.get_simulation_status("sim_a") == "completed" + assert SessionState.get_simulation_status("sim_b") == "running" + + def test_remove_running_is_idempotent_when_absent(self, session): + from spkmc.web.state import SessionState + + # Must not raise even if the simulation was never added + SessionState.remove_running_simulation("sim_never_added") + SessionState.remove_running_simulation("sim_never_added") + + +# ── Simulation progress ─────────────────────────────────────────────────────── + + +class TestSimulationProgress: + @pytest.fixture(autouse=True) + def _init(self, session): + session["simulation_progress"] = {} + + def test_set_and_get_progress(self, session): + from spkmc.web.state import SessionState + + SessionState.set_simulation_progress("sim_1", 25, 100) + result = SessionState.get_simulation_progress("sim_1") + assert result == {"progress": 25, "total": 100} + + def test_get_progress_returns_none_for_unknown(self, session): + from spkmc.web.state import SessionState + + assert SessionState.get_simulation_progress("unknown") is None + + def test_clear_progress_removes_entry(self, session): + from spkmc.web.state import SessionState + + SessionState.set_simulation_progress("sim_1", 50, 100) + SessionState.clear_simulation_progress("sim_1") + assert SessionState.get_simulation_progress("sim_1") is None + + def test_clear_progress_is_idempotent_for_absent_key(self, session): + from spkmc.web.state import SessionState + + # Must not raise + SessionState.clear_simulation_progress("never_tracked") + + def test_updated_progress_overwrites_previous(self, session): + from spkmc.web.state import SessionState + + SessionState.set_simulation_progress("sim_1", 10, 100) + SessionState.set_simulation_progress("sim_1", 80, 100) + result = SessionState.get_simulation_progress("sim_1") + assert result["progress"] == 80 + + +# ── Analysis state machine ──────────────────────────────────────────────────── + + +class TestAnalysisStateMachine: + @pytest.fixture(autouse=True) + def _init(self, session): + session["running_analyses"] = {} + session["completed_analyses"] = set() + session["failed_analyses"] = {} + + def test_unknown_analysis_status_is_pending(self, session): + from spkmc.web.state import SessionState + + assert SessionState.get_analysis_status("analysis_1") == "pending" + + def test_added_analysis_is_running(self, session): + from spkmc.web.state import SessionState + + SessionState.add_running_analysis("analysis_1", {"run_id": "r1"}) + assert SessionState.is_analysis_running("analysis_1") is True + assert SessionState.get_analysis_status("analysis_1") == "running" + + def test_completed_analysis_is_removed_from_running(self, session): + from spkmc.web.state import SessionState + + SessionState.add_running_analysis("analysis_1", {"run_id": "r1"}) + SessionState.mark_analysis_completed("analysis_1") + assert SessionState.is_analysis_running("analysis_1") is False + assert SessionState.get_analysis_status("analysis_1") == "completed" + + def test_failed_analysis_is_removed_from_running(self, session): + from spkmc.web.state import SessionState + + SessionState.add_running_analysis("analysis_1", {"run_id": "r1"}) + SessionState.mark_analysis_failed("analysis_1", "API key invalid") + assert SessionState.is_analysis_running("analysis_1") is False + assert SessionState.get_analysis_status("analysis_1") == "failed" + + def test_two_analyses_transition_independently(self, session): + from spkmc.web.state import SessionState + + SessionState.add_running_analysis("analysis_a", {}) + SessionState.add_running_analysis("analysis_b", {}) + SessionState.mark_analysis_completed("analysis_a") + assert SessionState.get_analysis_status("analysis_a") == "completed" + assert SessionState.get_analysis_status("analysis_b") == "running" + + +# ── Disk restoration: simulations ───────────────────────────────────────────── + + +class TestRestoreRunningSimulations: + """ + restore_running_simulations reads .spkmc_web/status/*.json files. + monkeypatch.chdir ensures Path(".spkmc_web") resolves inside tmp_path. + _is_pid_alive is patched to control alive/dead process scenarios. + """ + + @pytest.fixture(autouse=True) + def _init(self, session): + session["running_simulations"] = {} + session["completed_simulations"] = set() + session["failed_simulations"] = {} + session["simulation_progress"] = {} + + def _write_status(self, status_dir: Path, data: dict) -> Path: + f = status_dir / f"{data['run_id']}.json" + f.write_text(json.dumps(data)) + return f + + def test_alive_process_is_restored_as_running(self, session, tmp_path, monkeypatch): + import spkmc.web.state as state_module + + monkeypatch.chdir(tmp_path) + status_dir = tmp_path / ".spkmc_web" / "status" + status_dir.mkdir(parents=True) + + self._write_status( + status_dir, + { + "run_id": "sim--exp1--baseline--111", + "experiment_name": "exp1", + "scenario_label": "Baseline", + "scenario_normalized": "baseline", + "status": "running", + "pid": 12345, + "progress": 10, + "total": 100, + }, + ) + + monkeypatch.setattr(state_module, "_is_pid_alive", lambda pid: True) + + from spkmc.web.state import SessionState + + SessionState.restore_running_simulations() + + assert SessionState.is_simulation_running("sim--exp1--baseline") is True + + def test_dead_process_with_result_file_is_marked_completed( + self, session, tmp_path, monkeypatch + ): + from unittest.mock import patch + + import spkmc.web.state as state_module + + monkeypatch.chdir(tmp_path) + status_dir = tmp_path / ".spkmc_web" / "status" + status_dir.mkdir(parents=True) + + self._write_status( + status_dir, + { + "run_id": "sim--exp1--baseline--222", + "experiment_name": "exp1", + "scenario_label": "Baseline", + "scenario_normalized": "baseline", + "status": "running", + "pid": 99999, + "progress": 0, + "total": 100, + }, + ) + + # Create the expected result file + exp_dir = tmp_path / "experiments" / "exp1" + exp_dir.mkdir(parents=True) + (exp_dir / "baseline.json").touch() + + monkeypatch.setattr(state_module, "_is_pid_alive", lambda pid: False) + + # WebConfig is imported inside the function body: patch at its source module + mock_config = MagicMock() + mock_config.get_experiments_path.return_value = tmp_path / "experiments" + with patch("spkmc.web.config.WebConfig", return_value=mock_config): + from spkmc.web.state import SessionState + + SessionState.restore_running_simulations() + + assert SessionState.get_simulation_status("sim--exp1--baseline") == "completed" + + def test_dead_process_without_result_file_is_marked_failed( + self, session, tmp_path, monkeypatch + ): + from unittest.mock import patch + + import spkmc.web.state as state_module + + monkeypatch.chdir(tmp_path) + status_dir = tmp_path / ".spkmc_web" / "status" + status_dir.mkdir(parents=True) + + self._write_status( + status_dir, + { + "run_id": "sim--exp1--scenario_a--333", + "experiment_name": "exp1", + "scenario_label": "Scenario A", + "scenario_normalized": "scenario_a", + "status": "running", + "pid": 99999, + "progress": 0, + "total": 100, + }, + ) + + monkeypatch.setattr(state_module, "_is_pid_alive", lambda pid: False) + + mock_config = MagicMock() + mock_config.get_experiments_path.return_value = tmp_path / "experiments" + with patch("spkmc.web.config.WebConfig", return_value=mock_config): + from spkmc.web.state import SessionState + + SessionState.restore_running_simulations() + + assert SessionState.get_simulation_status("sim--exp1--scenario_a") == "failed" + + def test_corrupted_status_file_is_silently_skipped(self, session, tmp_path, monkeypatch): + import spkmc.web.state as state_module + + monkeypatch.chdir(tmp_path) + status_dir = tmp_path / ".spkmc_web" / "status" + status_dir.mkdir(parents=True) + (status_dir / "bad.json").write_text("{not: valid json}") + + monkeypatch.setattr(state_module, "_is_pid_alive", lambda pid: True) + + from spkmc.web.state import SessionState + + # Must not raise + SessionState.restore_running_simulations() + assert session["running_simulations"] == {} + + def test_completed_status_files_are_ignored(self, session, tmp_path, monkeypatch): + import spkmc.web.state as state_module + + monkeypatch.chdir(tmp_path) + status_dir = tmp_path / ".spkmc_web" / "status" + status_dir.mkdir(parents=True) + + self._write_status( + status_dir, + { + "run_id": "sim--exp1--baseline--444", + "experiment_name": "exp1", + "scenario_normalized": "baseline", + "status": "completed", # already done + "pid": 12345, + "progress": 100, + "total": 100, + }, + ) + + monkeypatch.setattr(state_module, "_is_pid_alive", lambda pid: True) + + from spkmc.web.state import SessionState + + SessionState.restore_running_simulations() + assert session["running_simulations"] == {} + + def test_missing_status_dir_does_not_raise(self, session, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + # .spkmc_web/status does NOT exist + + from spkmc.web.state import SessionState + + # Must return silently + SessionState.restore_running_simulations() + + +# ── Disk restoration: analyses ──────────────────────────────────────────────── + + +class TestRestoreRunningAnalyses: + @pytest.fixture(autouse=True) + def _init(self, session): + session["running_analyses"] = {} + session["completed_analyses"] = set() + session["failed_analyses"] = {} + + def _write_status(self, status_dir: Path, data: dict) -> Path: + f = status_dir / f"{data['run_id']}.json" + f.write_text(json.dumps(data)) + return f + + def test_alive_experiment_analysis_is_restored(self, session, tmp_path, monkeypatch): + import spkmc.web.state as state_module + + monkeypatch.chdir(tmp_path) + status_dir = tmp_path / ".spkmc_web" / "status" + status_dir.mkdir(parents=True) + + self._write_status( + status_dir, + { + "run_id": "exp_analysis--exp1--555", + "type": "analysis", + "analysis_type": "experiment", + "experiment_name": "exp1", + "scenario_normalized": "", + "status": "running", + "pid": 12345, + }, + ) + + monkeypatch.setattr(state_module, "_is_pid_alive", lambda pid: True) + + from spkmc.web.state import SessionState + + SessionState.restore_running_analyses() + assert SessionState.is_analysis_running("exp_analysis--exp1") is True + + def test_non_analysis_type_status_files_are_ignored(self, session, tmp_path, monkeypatch): + import spkmc.web.state as state_module + + monkeypatch.chdir(tmp_path) + status_dir = tmp_path / ".spkmc_web" / "status" + status_dir.mkdir(parents=True) + + # Write a simulation status file (type is absent / not "analysis") + self._write_status( + status_dir, + { + "run_id": "sim--exp1--baseline--666", + "experiment_name": "exp1", + "scenario_normalized": "baseline", + "status": "running", + "pid": 12345, + }, + ) + + monkeypatch.setattr(state_module, "_is_pid_alive", lambda pid: True) + + from spkmc.web.state import SessionState + + SessionState.restore_running_analyses() + assert session["running_analyses"] == {} + + def test_dead_analysis_with_result_file_is_marked_completed( + self, session, tmp_path, monkeypatch + ): + from unittest.mock import patch + + import spkmc.web.state as state_module + + monkeypatch.chdir(tmp_path) + status_dir = tmp_path / ".spkmc_web" / "status" + status_dir.mkdir(parents=True) + + self._write_status( + status_dir, + { + "run_id": "exp_analysis--exp1--777", + "type": "analysis", + "analysis_type": "experiment", + "experiment_name": "exp1", + "scenario_normalized": "", + "status": "running", + "pid": 99999, + }, + ) + + # Create the expected analysis output file + exp_dir = tmp_path / "experiments" / "exp1" + exp_dir.mkdir(parents=True) + (exp_dir / "analysis.md").touch() + + monkeypatch.setattr(state_module, "_is_pid_alive", lambda pid: False) + + mock_config = MagicMock() + mock_config.get_experiments_path.return_value = tmp_path / "experiments" + with patch("spkmc.web.config.WebConfig", return_value=mock_config): + from spkmc.web.state import SessionState + + SessionState.restore_running_analyses() + + assert SessionState.get_analysis_status("exp_analysis--exp1") == "completed"