From 3150311b7cf8abcdb4f162d1ae45c135a1230e6d Mon Sep 17 00:00:00 2001 From: renardeinside Date: Sun, 1 Mar 2026 16:15:01 +0100 Subject: [PATCH 1/5] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor:=20rename=20c?= =?UTF-8?q?rates/agent=20to=20crates/tracing,=20flux=20to=20collector=20(#?= =?UTF-8?q?130)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../memories/healthcheck-redirect-fix.md | 2 +- .../memory/memories/logging-refactor.md | 18 ++-- .github/workflows/ci.yml | 2 +- .github/workflows/release.yml | 4 +- Cargo.lock | 40 ++++---- Cargo.toml | 4 +- crates/cli/src/dev/logs.rs | 4 +- crates/cli/src/lib.rs | 22 ++--- crates/cli/src/{flux => tracing_cmd}/mod.rs | 0 crates/cli/src/{flux => tracing_cmd}/start.rs | 18 ++-- crates/cli/src/{flux => tracing_cmd}/stop.rs | 17 ++-- crates/common/src/lib.rs | 44 ++++----- crates/common/src/storage.rs | 12 +-- crates/core/Cargo.toml | 2 +- crates/core/src/{flux => collector}/mod.rs | 95 ++++++++++--------- crates/core/src/dev/backend.rs | 8 +- crates/core/src/dev/embedded_db.rs | 8 +- crates/core/src/dev/frontend.rs | 10 +- crates/core/src/dev/logging.rs | 2 +- crates/core/src/dev/mod.rs | 2 +- crates/core/src/dev/otel.rs | 23 +++-- crates/core/src/dev/server.rs | 18 ++-- crates/core/src/lib.rs | 8 +- crates/core/src/ops/dev.rs | 8 +- .../core/src/{agent.rs => tracing_binary.rs} | 2 +- crates/core/src/tracing_init.rs | 2 +- crates/{agent => tracing}/Cargo.toml | 2 +- crates/{agent => tracing}/src/lib.rs | 0 crates/{agent => tracing}/src/main.rs | 4 +- crates/{agent => tracing}/src/server.rs | 14 +-- docs/content/docs/reference/cli.mdx | 20 ++-- scripts/build_agent.py | 2 +- src/apx/assets/entrypoint.ts | 6 +- 33 files changed, 218 insertions(+), 205 deletions(-) rename crates/cli/src/{flux => tracing_cmd}/mod.rs (100%) rename crates/cli/src/{flux => tracing_cmd}/start.rs (57%) rename crates/cli/src/{flux => tracing_cmd}/stop.rs (54%) rename crates/core/src/{flux => collector}/mod.rs (59%) rename crates/core/src/{agent.rs => tracing_binary.rs} (95%) rename crates/{agent => tracing}/Cargo.toml (96%) rename crates/{agent => tracing}/src/lib.rs (100%) rename crates/{agent => tracing}/src/main.rs (91%) rename crates/{agent => tracing}/src/server.rs (97%) diff --git a/.claude/skills/memory/memories/healthcheck-redirect-fix.md b/.claude/skills/memory/memories/healthcheck-redirect-fix.md index e79c3944..b51d7607 100644 --- a/.claude/skills/memory/memories/healthcheck-redirect-fix.md +++ b/.claude/skills/memory/memories/healthcheck-redirect-fix.md @@ -20,7 +20,7 @@ Fixed a critical bug where newly generated projects failed to start because the 1. **Redirect policy** (`crates/core/src/dev/process.rs`): Added `.redirect(reqwest::redirect::Policy::none())` to `HEALTH_CLIENT`. A 302 is a valid HTTP response proving the server is listening -- no need to follow it. -2. **Log path canonicalization** (`crates/core/src/dev/process.rs`): `ProcessManager::new()` now canonicalizes `app_dir` so that `forward_log_to_flux()` uses the same path as `StartupLogStreamer`. On macOS, `/tmp/foo` canonicalizes to `/private/tmp/foo`, and the mismatch caused startup logs to appear empty. +2. **Log path canonicalization** (`crates/core/src/dev/process.rs`): `ProcessManager::new()` now canonicalizes `app_dir` so that `forward_log_to_collector()` uses the same path as `StartupLogStreamer`. On macOS, `/tmp/foo` canonicalizes to `/private/tmp/foo`, and the mismatch caused startup logs to appear empty. 3. **Module extraction**: Moved `wait_for_healthy_with_logs()` from `ops/dev.rs` to new `ops/healthcheck.rs` for better separation of concerns. diff --git a/.claude/skills/memory/memories/logging-refactor.md b/.claude/skills/memory/memories/logging-refactor.md index 49222346..e4f35546 100644 --- a/.claude/skills/memory/memories/logging-refactor.md +++ b/.claude/skills/memory/memories/logging-refactor.md @@ -1,14 +1,14 @@ --- name: logging-refactor created: 2026-02-21 -tags: [logging, refactor, apx_common, format, severity, flux, agent] +tags: [logging, refactor, apx_common, format, severity, collector, agent] --- # Logging System Refactoring ## Summary -Centralized all log formatting, timestamp handling, severity parsing, and skip-filtering into `apx_common::format`. Fixed stderr severity mapping (uvicorn INFO was stored as ERROR), added agent version checking in flux daemon, and removed dead code. +Centralized all log formatting, timestamp handling, severity parsing, and skip-filtering into `apx_common::format`. Fixed stderr severity mapping (uvicorn INFO was stored as ERROR), added agent version checking in collector daemon, and removed dead code. ## Context @@ -16,7 +16,7 @@ The APX logging system had grown organically across 6+ crates with inconsistent - 5 different timestamp formats (UTC vs Local, with/without date) - All uvicorn stderr was tagged as ERROR (severity 17) — even normal `INFO: Uvicorn running...` - Duplicate `should_skip_log()` implementations with different filter sets -- No agent version check — old flux daemon kept running after apx upgrade +- No agent version check — old collector daemon kept running after apx upgrade - Dead code: `APX_COLLECT_LOGS` env var, `installed_version()` function ### Key decisions @@ -24,7 +24,7 @@ The APX logging system had grown organically across 6+ crates with inconsistent 1. **All user-facing timestamps use Local timezone** — `format_timestamp()` always converts to local. The old `format_log_line` in process.rs used UTC; now uses Local. 2. **Single `format_startup_log` uses full timestamp** (`YYYY-MM-DD HH:MM:SS.mmm`), not short (`HH:MM:SS.mmm`). User explicitly requested this. 3. **Stderr severity parsing** via `parse_python_severity()` — matches patterns like `INFO /`, `WARNING:`, `ERROR /` etc. Defaults to `"INFO"` since most uvicorn stderr is informational. -4. **Version check via lock file** — `FluxLock` now has `version: Option` (with `#[serde(default)]` for backward compat). `ensure_running()` compares lock version to `apx_common::VERSION`. +4. **Version check via lock file** — `CollectorLock` now has `version: Option` (with `#[serde(default)]` for backward compat). `ensure_running()` compares lock version to `apx_common::VERSION`. 5. **Unified skip filtering** — `should_skip_log_message(message: &str)` is the raw-string entry point; `should_skip_log(&LogRecord)` wraps it with severity-based _core filtering. Password patterns (`WITH PASSWORD`, `PASSWORD '`) now filtered in both paths. ## Diagram @@ -65,16 +65,16 @@ graph TD - `crates/common/src/format.rs` — **NEW** centralized formatting module - `crates/common/src/storage.rs` — unified `should_skip_log` + `should_skip_log_message` -- `crates/common/src/lib.rs` — `VERSION` const, `FluxLock.version` field, `pub mod format` +- `crates/common/src/lib.rs` — `VERSION` const, `CollectorLock.version` field, `pub mod format` - `crates/core/src/dev/process.rs` — uses `parse_python_severity` for stderr, `format_process_log_line` for timestamps - `crates/core/src/dev/otel.rs` — removed duplicate `severity_to_number` and `should_skip_log` - `crates/core/src/ops/logs.rs` — removed local format fns, imports from `apx_common::format` - `crates/core/src/ops/startup_logs.rs` — removed local format fns, imports from `apx_common::format` -- `crates/core/src/flux/mod.rs` — `ensure_running()` checks version, restarts on mismatch +- `crates/core/src/collector/mod.rs` — `ensure_running()` checks version, restarts on mismatch - `crates/cli/src/dev/logs.rs` — updated imports to `apx_common::format` -- `crates/agent/src/main.rs` — tracing uses `APX_LOG` env, stderr writer, file/line info -- `crates/agent/src/server.rs` — replaced `eprintln!` with `info!()` -- `crates/core/src/agent.rs` — removed dead `installed_version()` +- `crates/tracing/src/main.rs` — tracing uses `APX_LOG` env, stderr writer, file/line info +- `crates/tracing/src/server.rs` — replaced `eprintln!` with `info!()` +- `crates/core/src/tracing_binary.rs` — removed dead `installed_version()` - `crates/core/src/ops/dev.rs` — removed dead `APX_COLLECT_LOGS` env var ## Notes diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 338e31c7..2e145f5a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -78,7 +78,7 @@ jobs: uses: actions/cache@v5 with: path: .bins/agent - key: apx-agent-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('crates/agent/**', 'crates/common/**', 'Cargo.lock') }} + key: apx-agent-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('crates/tracing/**', 'crates/common/**', 'Cargo.lock') }} - name: Build apx-agent if: steps.cache-agent.outputs.cache-hit != 'true' diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index ab7d4df6..c001bd32 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -85,7 +85,7 @@ jobs: uses: actions/cache@v5 with: path: .bins/agent - key: apx-agent-${{ runner.os }}-${{ matrix.target }}-${{ hashFiles('crates/agent/**', 'crates/common/**', 'Cargo.lock') }} + key: apx-agent-${{ runner.os }}-${{ matrix.target }}-${{ hashFiles('crates/tracing/**', 'crates/common/**', 'Cargo.lock') }} # Map matrix target to Rust target for agent build - name: Build apx-agent (Linux x86_64) @@ -223,7 +223,7 @@ jobs: uses: actions/cache@v5 with: path: .bins/agent - key: apx-agent-${{ runner.os }}-${{ matrix.rust_target }}-${{ hashFiles('crates/agent/**', 'crates/common/**', 'Cargo.lock') }} + key: apx-agent-${{ runner.os }}-${{ matrix.rust_target }}-${{ hashFiles('crates/tracing/**', 'crates/common/**', 'Cargo.lock') }} - name: Build apx-agent (native) if: steps.cache-agent.outputs.cache-hit != 'true' && !matrix.cross diff --git a/Cargo.lock b/Cargo.lock index ed781e66..21d6c9a8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -135,25 +135,6 @@ version = "1.0.100" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" -[[package]] -name = "apx-agent" -version = "0.3.8" -dependencies = [ - "apx-common", - "apx-db", - "axum", - "chrono", - "clap", - "hex", - "opentelemetry-proto 0.31.0", - "prost 0.14.3", - "serde", - "serde_json", - "tokio", - "tracing", - "tracing-subscriber", -] - [[package]] name = "apx-bin" version = "0.3.8" @@ -206,10 +187,10 @@ dependencies = [ name = "apx-core" version = "0.3.8" dependencies = [ - "apx-agent", "apx-common", "apx-databricks-sdk", "apx-db", + "apx-tracing", "axum", "biome_css_parser", "biome_css_syntax", @@ -317,6 +298,25 @@ dependencies = [ "toml 0.8.23", ] +[[package]] +name = "apx-tracing" +version = "0.3.8" +dependencies = [ + "apx-common", + "apx-db", + "axum", + "chrono", + "clap", + "hex", + "opentelemetry-proto 0.31.0", + "prost 0.14.3", + "serde", + "serde_json", + "tokio", + "tracing", + "tracing-subscriber", +] + [[package]] name = "arbitrary" version = "1.4.2" diff --git a/Cargo.toml b/Cargo.toml index 1bd49829..fe00b06f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["crates/common", "crates/agent", "crates/studio", "crates/core", "crates/mcp", "crates/cli", "crates/databricks_sdk", "crates/apx", "crates/db"] +members = ["crates/common", "crates/tracing", "crates/studio", "crates/core", "crates/mcp", "crates/cli", "crates/databricks_sdk", "crates/apx", "crates/db"] resolver = "2" [workspace.package] @@ -16,7 +16,7 @@ categories = ["development-tools"] [workspace.dependencies] # Internal crates apx-common = { path = "crates/common" } -apx-agent = { path = "crates/agent" } +apx-tracing = { path = "crates/tracing" } apx-core = { path = "crates/core" } apx-mcp = { path = "crates/mcp" } apx-cli = { path = "crates/cli" } diff --git a/crates/cli/src/dev/logs.rs b/crates/cli/src/dev/logs.rs index 4547ca45..a9c136c5 100644 --- a/crates/cli/src/dev/logs.rs +++ b/crates/cli/src/dev/logs.rs @@ -1,6 +1,6 @@ -//! Log viewer for APX dev server using flux SQLite storage. +//! Log viewer for APX dev server using the OTEL collector SQLite storage. //! -//! Reads logs from ~/.apx/logs/db which is maintained by flux. +//! Reads logs from ~/.apx/logs/db which is maintained by the collector. use clap::Args; use std::path::PathBuf; diff --git a/crates/cli/src/lib.rs b/crates/cli/src/lib.rs index 14e73b3d..a273aeb6 100644 --- a/crates/cli/src/lib.rs +++ b/crates/cli/src/lib.rs @@ -11,12 +11,12 @@ pub(crate) mod common; pub(crate) mod components; pub(crate) mod dev; pub(crate) mod feedback; -pub(crate) mod flux; pub(crate) mod frontend; pub(crate) mod info; /// Project initialization wizard and template rendering. pub mod init; pub(crate) mod skill; +pub(crate) mod tracing_cmd; pub(crate) mod upgrade; use clap::{CommandFactory, Parser, Subcommand}; @@ -52,9 +52,9 @@ enum Commands { /// 🚀 Development server commands #[command(subcommand)] Dev(DevCommands), - /// 📊 Flux OTEL collector commands + /// 📊 Tracing OTEL collector commands #[command(subcommand)] - Flux(FluxCommands), + Tracing(TracingCommands), /// 🧠 Skill commands (Claude Code integration) #[command(subcommand)] Skill(SkillCommands), @@ -111,11 +111,11 @@ enum SkillCommands { } #[derive(Subcommand)] -enum FluxCommands { - /// Start the flux OTEL collector daemon - Start(flux::start::StartArgs), - /// Stop the flux OTEL collector daemon - Stop(flux::stop::StopArgs), +enum TracingCommands { + /// Start the tracing OTEL collector daemon + Start(tracing_cmd::start::StartArgs), + /// Stop the tracing OTEL collector daemon + Stop(tracing_cmd::stop::StopArgs), } /// Standard Unix exit code for processes terminated by SIGINT (128 + signal number 2). @@ -202,9 +202,9 @@ async fn run_command(args: Vec) -> i32 { DevCommands::Apply(args) => dev::apply::run(args).await, DevCommands::InternalRunServer(args) => dev::__internal_run_server::run(args).await, }, - Some(Commands::Flux(flux_cmd)) => match flux_cmd { - FluxCommands::Start(args) => flux::start::run(args).await, - FluxCommands::Stop(args) => flux::stop::run(args).await, + Some(Commands::Tracing(cmd)) => match cmd { + TracingCommands::Start(args) => tracing_cmd::start::run(args).await, + TracingCommands::Stop(args) => tracing_cmd::stop::run(args).await, }, Some(Commands::Skill(skill_cmd)) => match skill_cmd { SkillCommands::Install(args) => skill::install::run(args).await, diff --git a/crates/cli/src/flux/mod.rs b/crates/cli/src/tracing_cmd/mod.rs similarity index 100% rename from crates/cli/src/flux/mod.rs rename to crates/cli/src/tracing_cmd/mod.rs diff --git a/crates/cli/src/flux/start.rs b/crates/cli/src/tracing_cmd/start.rs similarity index 57% rename from crates/cli/src/flux/start.rs rename to crates/cli/src/tracing_cmd/start.rs index 4fb8ca61..35160aee 100644 --- a/crates/cli/src/flux/start.rs +++ b/crates/cli/src/tracing_cmd/start.rs @@ -1,11 +1,11 @@ -//! Start the flux OTEL collector daemon. +//! Start the tracing OTEL collector daemon. use clap::Args; use std::time::Instant; use crate::run_cli_async_helper; +use apx_core::collector; use apx_core::common::{format_elapsed_ms, spinner}; -use apx_core::flux; #[derive(Args, Debug, Clone)] pub struct StartArgs {} @@ -16,23 +16,23 @@ pub async fn run(_args: StartArgs) -> i32 { async fn run_inner() -> Result<(), String> { // Check if already running - if flux::is_running() { + if collector::is_running() { println!( - "✅ Flux already running at http://127.0.0.1:{}\n", - flux::FLUX_PORT + "✅ Tracing collector already running at http://127.0.0.1:{}\n", + collector::COLLECTOR_PORT ); return Ok(()); } let start_time = Instant::now(); - let start_spinner = spinner("Starting flux daemon..."); + let start_spinner = spinner("Starting tracing collector..."); - flux::start()?; + collector::start()?; start_spinner.finish_and_clear(); println!( - "✅ Flux started at http://127.0.0.1:{} in {}\n", - flux::FLUX_PORT, + "✅ Tracing collector started at http://127.0.0.1:{} in {}\n", + collector::COLLECTOR_PORT, format_elapsed_ms(start_time) ); Ok(()) diff --git a/crates/cli/src/flux/stop.rs b/crates/cli/src/tracing_cmd/stop.rs similarity index 54% rename from crates/cli/src/flux/stop.rs rename to crates/cli/src/tracing_cmd/stop.rs index 1f8ab690..9179d649 100644 --- a/crates/cli/src/flux/stop.rs +++ b/crates/cli/src/tracing_cmd/stop.rs @@ -1,11 +1,11 @@ -//! Stop the flux OTEL collector daemon. +//! Stop the tracing OTEL collector daemon. use clap::Args; use std::time::Instant; use crate::run_cli_async_helper; +use apx_core::collector; use apx_core::common::{format_elapsed_ms, spinner}; -use apx_core::flux; #[derive(Args, Debug, Clone)] pub struct StopArgs {} @@ -15,17 +15,20 @@ pub async fn run(_args: StopArgs) -> i32 { } async fn run_inner() -> Result<(), String> { - if !flux::is_running() { - println!("⚠️ Flux is not running\n"); + if !collector::is_running() { + println!("⚠️ Tracing collector is not running\n"); return Ok(()); } let start_time = Instant::now(); - let stop_spinner = spinner("Stopping flux daemon..."); + let stop_spinner = spinner("Stopping tracing collector..."); - flux::stop()?; + collector::stop()?; stop_spinner.finish_and_clear(); - println!("✅ Flux stopped in {}\n", format_elapsed_ms(start_time)); + println!( + "✅ Tracing collector stopped in {}\n", + format_elapsed_ms(start_time) + ); Ok(()) } diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs index 4b1f12bc..53f106f2 100644 --- a/crates/common/src/lib.rs +++ b/crates/common/src/lib.rs @@ -1,4 +1,4 @@ -//! Shared types and utilities for APX flux system +//! Shared types and utilities for APX tracing collector //! //! This crate contains shared functionality used by both the main `apx` CLI //! and the standalone `apx-agent` binary. @@ -9,7 +9,7 @@ pub mod bundles; pub mod format; /// Network host constants for binding, client connections, and browser URLs. pub mod hosts; -/// Pure types and logic for flux OTEL log records, filtering, and aggregation. +/// Pure types and logic for OTEL log records, filtering, and aggregation. pub mod storage; use serde::{Deserialize, Serialize}; @@ -20,15 +20,15 @@ use std::time::Duration; // Re-export commonly used types pub use storage::{ - AggregatedRecord, LogAggregator, LogRecord, ServiceKind, flux_dir, get_aggregation_key, + AggregatedRecord, LogAggregator, LogRecord, ServiceKind, collector_dir, get_aggregation_key, should_skip_log, should_skip_log_message, source_label, }; /// Version of the apx-common crate, used for agent version matching. pub const VERSION: &str = env!("CARGO_PKG_VERSION"); -/// Flux port for OTLP HTTP receiver -pub const FLUX_PORT: u16 = 11111; +/// Collector port for OTLP HTTP receiver +pub const COLLECTOR_PORT: u16 = 11111; /// Lock filename const LOCK_FILENAME: &str = "agent.lock"; @@ -38,7 +38,7 @@ const LOG_FILENAME: &str = "agent.log"; /// Lock file contents. #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct FluxLock { +pub struct CollectorLock { /// OS process ID of the running agent. pub pid: u32, /// TCP port the agent listens on. @@ -50,7 +50,7 @@ pub struct FluxLock { pub version: Option, } -impl FluxLock { +impl CollectorLock { /// Create a new lock for the current process. #[must_use] pub fn new(pid: u32) -> Self { @@ -61,7 +61,7 @@ impl FluxLock { Self { pid, - port: FLUX_PORT, + port: COLLECTOR_PORT, started_at, version: Some(VERSION.to_string()), } @@ -74,7 +74,7 @@ impl FluxLock { /// /// Returns an error if the home directory cannot be determined. pub fn lock_path() -> Result { - Ok(flux_dir()?.join(LOCK_FILENAME)) + Ok(collector_dir()?.join(LOCK_FILENAME)) } /// Get the daemon log file path (`~/.apx/logs/agent.log`). @@ -83,7 +83,7 @@ pub fn lock_path() -> Result { /// /// Returns an error if the home directory cannot be determined. pub fn log_path() -> Result { - Ok(flux_dir()?.join(LOG_FILENAME)) + Ok(collector_dir()?.join(LOG_FILENAME)) } /// Read the lock file if it exists. @@ -91,17 +91,17 @@ pub fn log_path() -> Result { /// # Errors /// /// Returns an error if the lock file exists but cannot be read or parsed. -pub fn read_lock() -> Result, String> { +pub fn read_lock() -> Result, String> { let path = lock_path()?; if !path.exists() { return Ok(None); } - let contents = - fs::read_to_string(&path).map_err(|e| format!("Failed to read flux lock file: {e}"))?; + let contents = fs::read_to_string(&path) + .map_err(|e| format!("Failed to read collector lock file: {e}"))?; - let lock: FluxLock = serde_json::from_str(&contents) - .map_err(|e| format!("Failed to parse flux lock file: {e}"))?; + let lock: CollectorLock = serde_json::from_str(&contents) + .map_err(|e| format!("Failed to parse collector lock file: {e}"))?; Ok(Some(lock)) } @@ -111,7 +111,7 @@ pub fn read_lock() -> Result, String> { /// # Errors /// /// Returns an error if the lock file cannot be written. -pub fn write_lock(lock: &FluxLock) -> Result<(), String> { +pub fn write_lock(lock: &CollectorLock) -> Result<(), String> { let path = lock_path()?; // Ensure parent directory exists @@ -122,7 +122,7 @@ pub fn write_lock(lock: &FluxLock) -> Result<(), String> { let contents = serde_json::to_string_pretty(lock).map_err(|e| format!("Failed to serialize lock: {e}"))?; - fs::write(&path, contents).map_err(|e| format!("Failed to write flux lock file: {e}")) + fs::write(&path, contents).map_err(|e| format!("Failed to write collector lock file: {e}")) } /// Remove the lock file. @@ -133,20 +133,20 @@ pub fn write_lock(lock: &FluxLock) -> Result<(), String> { pub fn remove_lock() -> Result<(), String> { let path = lock_path()?; if path.exists() { - fs::remove_file(&path).map_err(|e| format!("Failed to remove flux lock file: {e}"))?; + fs::remove_file(&path).map_err(|e| format!("Failed to remove collector lock file: {e}"))?; } Ok(()) } -/// Check if flux is accepting connections at the given port. +/// Check if the collector is accepting connections at the given port. #[must_use] -pub fn is_flux_listening(port: u16) -> bool { +pub fn is_collector_listening(port: u16) -> bool { let addr = std::net::SocketAddr::from((hosts::CLIENT_HOST_OCTETS, port)); TcpStream::connect_timeout(&addr, Duration::from_millis(500)).is_ok() } -/// Check if flux is currently running by testing TCP connectivity. +/// Check if the collector is currently running by testing TCP connectivity. #[must_use] pub fn is_running() -> bool { - is_flux_listening(FLUX_PORT) + is_collector_listening(COLLECTOR_PORT) } diff --git a/crates/common/src/storage.rs b/crates/common/src/storage.rs index 98274815..65bd1df4 100644 --- a/crates/common/src/storage.rs +++ b/crates/common/src/storage.rs @@ -1,4 +1,4 @@ -//! Pure types and logic for flux OTEL logs. +//! Pure types and logic for OTEL log collector records. //! //! This module contains log record types, filtering, and aggregation logic. //! Database operations have been moved to the `apx-db` crate. @@ -6,8 +6,8 @@ use std::collections::HashMap; use std::path::PathBuf; -/// Directory for flux data (~/.apx/logs) -const FLUX_DIR: &str = ".apx/logs"; +/// Directory for collector data (~/.apx/logs) +const COLLECTOR_DIR: &str = ".apx/logs"; /// A log record to be inserted into the database. #[derive(Debug, Clone)] @@ -313,12 +313,12 @@ impl LogAggregator { } } -/// Get the flux directory path (`~/.apx/logs`). +/// Get the collector directory path (`~/.apx/logs`). /// /// # Errors /// /// Returns an error if the home directory cannot be determined. -pub fn flux_dir() -> Result { +pub fn collector_dir() -> Result { let home = dirs::home_dir().ok_or("Could not determine home directory")?; - Ok(home.join(FLUX_DIR)) + Ok(home.join(COLLECTOR_DIR)) } diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 8aa4d7a2..d8e293e4 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -15,7 +15,7 @@ workspace = true [dependencies] apx-common.workspace = true -apx-agent.workspace = true +apx-tracing.workspace = true apx-databricks-sdk.workspace = true dirs.workspace = true indicatif.workspace = true diff --git a/crates/core/src/flux/mod.rs b/crates/core/src/collector/mod.rs similarity index 59% rename from crates/core/src/flux/mod.rs rename to crates/core/src/collector/mod.rs index 72904af5..c0432ae1 100644 --- a/crates/core/src/flux/mod.rs +++ b/crates/core/src/collector/mod.rs @@ -1,4 +1,4 @@ -//! Flux: Native Rust OTEL Collector +//! Native Rust OTEL Log Collector //! //! This module provides a native OpenTelemetry log collector that replaces //! the external otelcol binary. It stores logs in SQLite, runs as a detached @@ -7,18 +7,18 @@ //! ## Usage //! //! ```ignore -//! use apx::flux; +//! use apx::collector; //! -//! // Ensure flux is running (starts if not) -//! flux::ensure_running()?; +//! // Ensure the collector is running (starts if not) +//! collector::ensure_running()?; //! -//! // Check if flux is running -//! if flux::is_running() { -//! println!("Flux is running"); +//! // Check if the collector is running +//! if collector::is_running() { +//! println!("Collector is running"); //! } //! -//! // Stop flux -//! flux::stop()?; +//! // Stop the collector +//! collector::stop()?; //! ``` use std::fs; @@ -28,15 +28,15 @@ use tracing::{debug, info, warn}; // Re-export from apx-common crate pub use apx_common::{ - FLUX_PORT, FluxLock, flux_dir, is_flux_listening, is_running, log_path, read_lock, remove_lock, - write_lock, + COLLECTOR_PORT, CollectorLock, collector_dir, is_collector_listening, is_running, log_path, + read_lock, remove_lock, write_lock, }; // ============================================================================ // Daemon management // ============================================================================ -/// Spawn flux as a detached daemon process using the apx-agent binary. +/// Spawn the collector as a detached daemon process using the apx-agent binary. fn spawn_daemon() -> Result { let log_file = log_path()?; @@ -57,9 +57,9 @@ fn spawn_daemon() -> Result { .map_err(|e| format!("Failed to clone log file handle: {e}"))?; // Get the agent binary path (installs if needed) - let agent_path = crate::agent::ensure_installed()?; + let agent_path = crate::tracing_binary::ensure_installed()?; - debug!("Spawning flux daemon: {}", agent_path.display()); + debug!("Spawning collector daemon: {}", agent_path.display()); let child = std::process::Command::new(&agent_path) .stdin(Stdio::null()) @@ -69,130 +69,131 @@ fn spawn_daemon() -> Result { .map_err(|e| format!("Failed to spawn agent: {e}"))?; let pid = child.id(); - info!("Spawned flux daemon with pid={}", pid); + info!("Spawned collector daemon with pid={}", pid); Ok(pid) } -/// Wait for flux to start accepting connections. +/// Wait for the collector to start accepting connections. fn wait_for_ready(timeout_ms: u64) -> Result<(), String> { let start = Instant::now(); let timeout = Duration::from_millis(timeout_ms); while start.elapsed() < timeout { - let addr = std::net::SocketAddr::from((apx_common::hosts::CLIENT_HOST_OCTETS, FLUX_PORT)); + let addr = + std::net::SocketAddr::from((apx_common::hosts::CLIENT_HOST_OCTETS, COLLECTOR_PORT)); if std::net::TcpStream::connect_timeout(&addr, Duration::from_millis(100)).is_ok() { return Ok(()); } std::thread::sleep(Duration::from_millis(100)); } - Err(format!("Flux did not start within {timeout_ms}ms")) + Err(format!("Collector did not start within {timeout_ms}ms")) } -/// Start flux daemon. +/// Start the collector daemon. /// -/// Spawns a new flux daemon process if one is not already running. -/// Returns an error if flux cannot be started. +/// Spawns a new collector daemon process if one is not already running. +/// Returns an error if the collector cannot be started. pub fn start() -> Result<(), String> { - // Create the flux directory if it doesn't exist - let dir = flux_dir()?; - fs::create_dir_all(&dir).map_err(|e| format!("Failed to create flux directory: {e}"))?; + // Create the collector directory if it doesn't exist + let dir = collector_dir()?; + fs::create_dir_all(&dir).map_err(|e| format!("Failed to create collector directory: {e}"))?; // Check if already running via lock file if let Some(lock) = read_lock()? { - if is_flux_listening(lock.port) { + if is_collector_listening(lock.port) { debug!( - "Flux already running (pid={}, port={})", + "Collector already running (pid={}, port={})", lock.pid, lock.port ); return Ok(()); } // Stale lock - clean up - debug!("Stale flux lock found, cleaning up"); + debug!("Stale collector lock found, cleaning up"); remove_lock()?; } // Check if something else is using the port - if is_flux_listening(FLUX_PORT) { + if is_collector_listening(COLLECTOR_PORT) { warn!( - "Port {} is in use but no valid lock file found. Assuming flux is running.", - FLUX_PORT + "Port {} is in use but no valid lock file found. Assuming collector is running.", + COLLECTOR_PORT ); return Ok(()); } // Start the daemon - info!("Starting flux daemon on port {}", FLUX_PORT); + info!("Starting collector daemon on port {}", COLLECTOR_PORT); let pid = spawn_daemon()?; // Wait for it to be ready wait_for_ready(5000)?; // Write lock file - let lock = FluxLock::new(pid); + let lock = CollectorLock::new(pid); write_lock(&lock)?; - info!("Flux daemon started successfully (pid={})", pid); + info!("Collector daemon started successfully (pid={})", pid); Ok(()) } -/// Ensure flux is running, starting it if necessary. +/// Ensure the collector is running, starting it if necessary. /// /// This is the main API for callers like `apx dev start` that need to ensure -/// flux is running before proceeding. Also checks that the running daemon +/// the collector is running before proceeding. Also checks that the running daemon /// matches the current apx version — restarts on mismatch. pub fn ensure_running() -> Result<(), String> { if is_running() { // Check version from lock file if let Some(lock) = read_lock()? { if lock.version.as_deref() == Some(apx_common::VERSION) { - debug!("Flux is already running (version matches)"); + debug!("Collector is already running (version matches)"); return Ok(()); } // Version mismatch or old lock without version — restart info!( - "Flux version mismatch (running: {:?}, expected: {}), restarting", + "Collector version mismatch (running: {:?}, expected: {}), restarting", lock.version, apx_common::VERSION ); stop()?; // Fall through to start() } else { - debug!("Flux is already running (no lock file to check version)"); + debug!("Collector is already running (no lock file to check version)"); return Ok(()); } } start() } -/// Stop flux daemon. +/// Stop the collector daemon. /// -/// Stops the running flux daemon and removes the lock file. +/// Stops the running collector daemon and removes the lock file. pub fn stop() -> Result<(), String> { let Some(lock) = read_lock()? else { - debug!("Flux is not running (no lock file)"); + debug!("Collector is not running (no lock file)"); return Ok(()); }; - if !is_flux_listening(lock.port) { - debug!("Flux is not listening, cleaning up stale lock"); + if !is_collector_listening(lock.port) { + debug!("Collector is not listening, cleaning up stale lock"); remove_lock()?; return Ok(()); } - info!("Stopping flux daemon (pid={})", lock.pid); + info!("Stopping collector daemon (pid={})", lock.pid); // Kill the process tree - if let Err(e) = crate::dev::common::kill_process_tree(lock.pid, "flux-daemon") { - warn!("Failed to kill flux process tree: {}", e); + if let Err(e) = crate::dev::common::kill_process_tree(lock.pid, "collector-daemon") { + warn!("Failed to kill collector process tree: {}", e); } // Wait a bit for the process to exit std::thread::sleep(Duration::from_millis(500)); remove_lock()?; - info!("Flux daemon stopped"); + info!("Collector daemon stopped"); Ok(()) } diff --git a/crates/core/src/dev/backend.rs b/crates/core/src/dev/backend.rs index bbdb514f..8c41157d 100644 --- a/crates/core/src/dev/backend.rs +++ b/crates/core/src/dev/backend.rs @@ -21,7 +21,7 @@ use tracing::{info, warn}; use crate::dev::common::{DevProcess, ProbeResult, http_health_probe, stop_child_tree}; use crate::dev::embedded_db::EmbeddedDb; -use crate::dev::otel::forward_log_to_flux; +use crate::dev::otel::forward_log_to_collector; use crate::dev::token; use crate::dotenv::DotenvFile; use crate::external::uv::UvTool; @@ -264,7 +264,7 @@ impl Backend { // -- private: log forwarding -- - /// Spawn tasks to read stdout/stderr, prefix with source, and forward to flux. + /// Spawn tasks to read stdout/stderr, prefix with source, and forward to the collector. fn attach_log_forwarders(&self, child: &mut Child) { let service_name = format!("{}_app", self.cfg.app_slug); let app_path = self.cfg.app_dir.display().to_string(); @@ -279,7 +279,7 @@ impl Backend { "{}", apx_common::format::format_process_log_line("app", &line) ); - forward_log_to_flux(&line, "INFO", &svc, &path).await; + forward_log_to_collector(&line, "INFO", &svc, &path).await; } }); } @@ -293,7 +293,7 @@ impl Backend { apx_common::format::format_process_log_line("app", &line) ); let severity = apx_common::format::parse_python_severity(&line); - forward_log_to_flux(&line, severity, &service_name, &app_path).await; + forward_log_to_collector(&line, severity, &service_name, &app_path).await; } }); } diff --git a/crates/core/src/dev/embedded_db.rs b/crates/core/src/dev/embedded_db.rs index 94b06e75..22b01ac3 100644 --- a/crates/core/src/dev/embedded_db.rs +++ b/crates/core/src/dev/embedded_db.rs @@ -18,7 +18,7 @@ use tokio::time::{Duration, timeout}; use tracing::{debug, warn}; use crate::dev::common::DevProcess; -use crate::dev::otel::forward_log_to_flux; +use crate::dev::otel::forward_log_to_collector; use crate::dev::token; use crate::external::ExternalTool; use crate::external::bun::Bun; @@ -139,7 +139,7 @@ impl EmbeddedDb { .spawn() .map_err(|err| format!("Failed to start embedded database: {err}"))?; - // Forward stdout/stderr to flux with "db" source prefix + // Forward stdout/stderr to the collector with "db" source prefix let service_name = format!("{app_slug}_db"); let app_path = app_dir.display().to_string(); @@ -154,7 +154,7 @@ impl EmbeddedDb { "{}", apx_common::format::format_process_log_line("db", &line) ); - forward_log_to_flux(&line, "INFO", &svc, &path).await; + forward_log_to_collector(&line, "INFO", &svc, &path).await; } }); } @@ -169,7 +169,7 @@ impl EmbeddedDb { apx_common::format::format_process_log_line("db", &line) ); let severity = apx_common::format::parse_python_severity(&line); - forward_log_to_flux(&line, severity, &service_name, &app_path).await; + forward_log_to_collector(&line, severity, &service_name, &app_path).await; } }); } diff --git a/crates/core/src/dev/frontend.rs b/crates/core/src/dev/frontend.rs index a95ea16a..49be28b3 100644 --- a/crates/core/src/dev/frontend.rs +++ b/crates/core/src/dev/frontend.rs @@ -63,7 +63,7 @@ impl Frontend { /// Spawn the frontend dev server (`apx frontend dev` via uv). /// /// Frontend logs are NOT piped through apx stdout/stderr — the frontend - /// process sends logs directly to flux via OTEL SDK. See entrypoint.ts. + /// process sends logs directly to the collector via OTEL SDK. See entrypoint.ts. pub async fn spawn(&self) -> Result<(), String> { let cmd = self.build_command().await?; let child = cmd.spawn().map_err(String::from)?; @@ -97,10 +97,14 @@ impl Frontend { .env("APX_APP_NAME", &cfg.app_slug) .env("APX_APP_PATH", cfg.app_dir.display().to_string()) .env("APX_FRONTEND_PORT", cfg.frontend_port.to_string()) - // OpenTelemetry configuration — frontend sends logs directly to flux + // OpenTelemetry configuration — frontend sends logs directly to the collector .env( "OTEL_EXPORTER_OTLP_ENDPOINT", - format!("http://{}:{}", CLIENT_HOST, crate::flux::FLUX_PORT), + format!( + "http://{}:{}", + CLIENT_HOST, + crate::collector::COLLECTOR_PORT + ), ) .env(apx_common::hosts::ENV_FRONTEND_HOST, CLIENT_HOST) .env("OTEL_SERVICE_NAME", format!("{}_ui", cfg.app_slug)); diff --git a/crates/core/src/dev/logging.rs b/crates/core/src/dev/logging.rs index c312eb93..1df894f8 100644 --- a/crates/core/src/dev/logging.rs +++ b/crates/core/src/dev/logging.rs @@ -1,4 +1,4 @@ -//! Logging types for browser log forwarding to flux. +//! Logging types for browser log forwarding to the OTEL collector. use serde::Deserialize; diff --git a/crates/core/src/dev/mod.rs b/crates/core/src/dev/mod.rs index 177c0c88..c5f53f55 100644 --- a/crates/core/src/dev/mod.rs +++ b/crates/core/src/dev/mod.rs @@ -6,7 +6,7 @@ pub mod common; pub(crate) mod embedded_db; pub(crate) mod frontend; pub mod logging; -/// OpenTelemetry log forwarding to the flux collector. +/// OpenTelemetry log forwarding to the OTEL collector. pub mod otel; /// Subprocess management for backend and frontend processes. pub mod process; diff --git a/crates/core/src/dev/otel.rs b/crates/core/src/dev/otel.rs index 68a7f6f9..8931ae7e 100644 --- a/crates/core/src/dev/otel.rs +++ b/crates/core/src/dev/otel.rs @@ -1,7 +1,7 @@ -//! OTEL utilities for sending logs to flux. +//! OTEL utilities for sending logs to the collector. //! //! This module provides shared functionality for building and sending OTLP log payloads -//! to the flux collector. Used by both subprocess log forwarding and browser log forwarding. +//! to the OTEL collector. Used by both subprocess log forwarding and browser log forwarding. use std::path::Path; use std::sync::LazyLock; @@ -10,11 +10,11 @@ use std::time::Duration; use apx_common::format::severity_to_number; use apx_common::hosts::CLIENT_HOST; -use crate::flux::FLUX_PORT; +use crate::collector::COLLECTOR_PORT; -/// Shared HTTP client for forwarding logs to flux. +/// Shared HTTP client for forwarding logs to the collector. /// Reused across all calls to avoid creating a new client per log line. -static FLUX_CLIENT: LazyLock = LazyLock::new(|| { +static COLLECTOR_CLIENT: LazyLock = LazyLock::new(|| { reqwest::Client::builder() .timeout(Duration::from_secs(1)) .pool_max_idle_per_host(2) @@ -78,9 +78,14 @@ pub fn build_otlp_log_payload_from_ms( ) } -/// Forward a log line to flux via OTLP HTTP. +/// Forward a log line to the collector via OTLP HTTP. /// This is fire-and-forget; errors are silently ignored to avoid log loops. -pub async fn forward_log_to_flux(message: &str, level: &str, service_name: &str, app_path: &str) { +pub async fn forward_log_to_collector( + message: &str, + level: &str, + service_name: &str, + app_path: &str, +) { // Skip noisy internal logs if apx_common::should_skip_log_message(message) { return; @@ -88,9 +93,9 @@ pub async fn forward_log_to_flux(message: &str, level: &str, service_name: &str, let timestamp_ns = chrono::Utc::now().timestamp_nanos_opt().unwrap_or(0); let payload = build_otlp_log_payload(message, level, timestamp_ns, service_name, app_path); - let endpoint = format!("http://{CLIENT_HOST}:{FLUX_PORT}/v1/logs"); + let endpoint = format!("http://{CLIENT_HOST}:{COLLECTOR_PORT}/v1/logs"); - let _ = FLUX_CLIENT + let _ = COLLECTOR_CLIENT .post(&endpoint) .header("Content-Type", "application/json") .json(&payload) diff --git a/crates/core/src/dev/server.rs b/crates/core/src/dev/server.rs index 18a21161..74c8917d 100644 --- a/crates/core/src/dev/server.rs +++ b/crates/core/src/dev/server.rs @@ -1,4 +1,4 @@ -//! APX dev server with flux-based logging. +//! APX dev server with OTEL collector-based logging. use axum::Json; use axum::Router; @@ -16,6 +16,7 @@ use tracing::{debug, info, warn}; use apx_databricks_sdk::DatabricksClient; use crate::api_generator::start_openapi_watcher; +use crate::collector; use crate::dev::common::{Shutdown, lock_path, remove_lock}; use crate::dev::logging::BrowserLogPayload; use crate::dev::otel::build_otlp_log_payload_from_ms; @@ -23,7 +24,6 @@ use crate::dev::process::ProcessManager; use crate::dev::proxy; use crate::dev::watcher::{PollingWatcher, spawn_polling_watcher}; use crate::dotenv::DotenvFile; -use crate::flux; /// Shared application state for the dev server. #[derive(Clone)] @@ -31,7 +31,7 @@ struct AppState { /// Broadcast sender for shutdown signals - the single authority for shutdown coordination. shutdown_tx: broadcast::Sender, process_manager: Arc, - /// HTTP client for forwarding browser logs to flux + /// HTTP client for forwarding browser logs to the collector http_client: reqwest::Client, /// App directory path for resource attributes app_dir: PathBuf, @@ -75,10 +75,10 @@ pub async fn run_server(config: ServerConfig) -> Result<(), String> { db_port, dev_token, } = config; - // Ensure flux is running for log collection - if let Err(e) = flux::ensure_running() { + // Ensure the collector is running for log collection + if let Err(e) = collector::ensure_running() { warn!( - "Failed to start flux: {}. Logging may not work correctly.", + "Failed to start collector: {}. Logging may not work correctly.", e ); } @@ -418,7 +418,7 @@ async fn browser_logs( message.push_str(&stack); } - // Forward to flux via OTLP HTTP using shared otel module + // Forward to the collector via OTLP HTTP using shared otel module let otlp_payload = build_otlp_log_payload_from_ms( &message, &payload.level, @@ -430,7 +430,7 @@ async fn browser_logs( let endpoint = format!( "http://{}:{}/v1/logs", apx_common::hosts::CLIENT_HOST, - flux::FLUX_PORT + collector::COLLECTOR_PORT ); let result = state .http_client @@ -441,7 +441,7 @@ async fn browser_logs( .await; if let Err(e) = result { - debug!("Failed to forward browser log to flux: {}", e); + debug!("Failed to forward browser log to collector: {}", e); } StatusCode::OK diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index 4703b5f3..fa812d13 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -5,12 +5,12 @@ #![deny(clippy::print_stdout)] -/// Agent integration utilities. -pub mod agent; /// OpenAPI spec generation and TypeScript client codegen. pub mod api_generator; /// Global application directory state. pub mod app_state; +/// OTEL log collector integration. +pub mod collector; /// Common types, project metadata, and CLI utilities. pub mod common; /// UI component registry operations (search, add, CSS updates). @@ -27,8 +27,6 @@ pub mod download; pub mod external; /// User feedback issue creation (GitHub). pub mod feedback; -/// Flux log collector integration. -pub mod flux; /// Frontend build and scaffolding utilities. pub mod frontend; /// Python interop (OpenAPI generation, SDK version detection). @@ -41,6 +39,8 @@ pub mod py_edit; pub mod resources; /// Full-text search indexes (component search, SDK docs). pub mod search; +/// Tracing collector binary management. +pub mod tracing_binary; /// Tracing / logging initialization. pub mod tracing_init; diff --git a/crates/core/src/ops/dev.rs b/crates/core/src/ops/dev.rs index f1f14d27..786149ea 100644 --- a/crates/core/src/ops/dev.rs +++ b/crates/core/src/ops/dev.rs @@ -4,6 +4,7 @@ use std::process::Stdio; use std::time::{Duration, Instant}; use crate::app_state::set_app_dir; +use crate::collector; use crate::common::{ OutputMode, emit, ensure_dir, format_elapsed_ms, read_project_metadata, run_preflight_checks, spinner_for_mode, @@ -17,7 +18,6 @@ use crate::dev::common::{ use crate::dev::server::{ServerConfig, run_server}; use crate::dev::token; use crate::external::uv::ApxTool; -use crate::flux; use crate::ops::healthcheck::wait_for_healthy_with_logs; use crate::registry::Registry; use apx_common::hosts::{BIND_HOST, BROWSER_HOST}; @@ -163,7 +163,7 @@ pub struct PreparedServer { pub command_display: String, } -/// Run preflight checks, start flux, allocate a stable port. +/// Run preflight checks, start the collector, allocate a stable port. /// Returns a `PreparedServer` ready for any launch mode. pub async fn prepare_server_launch( app_dir: &Path, @@ -175,8 +175,8 @@ pub async fn prepare_server_launch( emit(mode, "🚀 Starting dev server..."); - if let Err(e) = flux::ensure_running() { - debug!("Failed to start flux: {e}. Logs may not be collected."); + if let Err(e) = collector::ensure_running() { + debug!("Failed to start collector: {e}. Logs may not be collected."); } let mut registry = Registry::load()?; diff --git a/crates/core/src/agent.rs b/crates/core/src/tracing_binary.rs similarity index 95% rename from crates/core/src/agent.rs rename to crates/core/src/tracing_binary.rs index e2d85ecc..1f562466 100644 --- a/crates/core/src/agent.rs +++ b/crates/core/src/tracing_binary.rs @@ -1,4 +1,4 @@ -//! Agent binary management module. +//! Tracing collector binary management module. //! //! The agent binary is embedded in the apx binary via `include_bytes!` and //! extracted to `~/.apx/apx-agent` on first use. Version management is handled diff --git a/crates/core/src/tracing_init.rs b/crates/core/src/tracing_init.rs index 46f8bae9..a87e765c 100644 --- a/crates/core/src/tracing_init.rs +++ b/crates/core/src/tracing_init.rs @@ -156,7 +156,7 @@ fn init_tracing_with_otel( let endpoint = format!( "http://{}:{}/v1/logs", apx_common::hosts::CLIENT_HOST, - crate::flux::FLUX_PORT + crate::collector::COLLECTOR_PORT ); let exporter = opentelemetry_otlp::LogExporter::builder() diff --git a/crates/agent/Cargo.toml b/crates/tracing/Cargo.toml similarity index 96% rename from crates/agent/Cargo.toml rename to crates/tracing/Cargo.toml index 97c1ccc2..0edc3ef2 100644 --- a/crates/agent/Cargo.toml +++ b/crates/tracing/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "apx-agent" +name = "apx-tracing" version = "0.3.8" edition.workspace = true rust-version.workspace = true diff --git a/crates/agent/src/lib.rs b/crates/tracing/src/lib.rs similarity index 100% rename from crates/agent/src/lib.rs rename to crates/tracing/src/lib.rs diff --git a/crates/agent/src/main.rs b/crates/tracing/src/main.rs similarity index 91% rename from crates/agent/src/main.rs rename to crates/tracing/src/main.rs index 7729c70d..03eeb150 100644 --- a/crates/agent/src/main.rs +++ b/crates/tracing/src/main.rs @@ -25,7 +25,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_env("APX_LOG") - .unwrap_or_else(|_| "apx_agent=info".into()), + .unwrap_or_else(|_| "apx_tracing=info".into()), ) .with( tracing_subscriber::fmt::layer() @@ -39,7 +39,7 @@ async fn main() { let _args = Args::parse(); // Run server (default behavior regardless of subcommand) - if let Err(e) = apx_agent::run_server().await { + if let Err(e) = apx_tracing::run_server().await { eprintln!("Error: {e}"); std::process::exit(1); } diff --git a/crates/agent/src/server.rs b/crates/tracing/src/server.rs similarity index 97% rename from crates/agent/src/server.rs rename to crates/tracing/src/server.rs index 118749f6..3f43cfd2 100644 --- a/crates/agent/src/server.rs +++ b/crates/tracing/src/server.rs @@ -1,9 +1,9 @@ -//! OTLP HTTP receiver server for flux. +//! OTLP HTTP receiver server for the tracing collector. //! //! This module implements an Axum HTTP server that receives OpenTelemetry logs //! via OTLP HTTP protocol, supporting both JSON and Protobuf content types. -use apx_common::{FLUX_PORT, LogRecord}; +use apx_common::{COLLECTOR_PORT, LogRecord}; use apx_db::LogsDb; use axum::{ Router, @@ -25,7 +25,7 @@ struct AppState { storage: LogsDb, } -/// Run the flux server (entry point for `apx-agent`). +/// Run the collector server (entry point for `apx-agent`). /// /// This function initializes storage, starts the cleanup scheduler, /// and runs the HTTP server. It blocks forever (or until error). @@ -35,7 +35,7 @@ struct AppState { /// Returns an error if storage initialization fails or the HTTP server /// cannot bind to the configured address. pub async fn run_server() -> Result<(), String> { - info!("Flux daemon starting..."); + info!("Collector daemon starting..."); // Open storage let storage = LogsDb::open() @@ -76,7 +76,7 @@ async fn run_cleanup_loop(storage: LogsDb) { } } -/// Start the flux HTTP server with the given storage. +/// Start the collector HTTP server with the given storage. async fn run_http_server(storage: LogsDb) -> Result<(), String> { let state = AppState { storage }; @@ -85,8 +85,8 @@ async fn run_http_server(storage: LogsDb) -> Result<(), String> { .route("/health", get(health_check)) .with_state(state); - let addr = format!("{}:{FLUX_PORT}", apx_common::hosts::BIND_HOST); - info!("Starting flux OTLP receiver on {}", addr); + let addr = format!("{}:{COLLECTOR_PORT}", apx_common::hosts::BIND_HOST); + info!("Starting OTLP collector on {}", addr); let listener = TcpListener::bind(&addr) .await diff --git a/docs/content/docs/reference/cli.mdx b/docs/content/docs/reference/cli.mdx index 388fd1fe..8dab6427 100644 --- a/docs/content/docs/reference/cli.mdx +++ b/docs/content/docs/reference/cli.mdx @@ -40,7 +40,7 @@ apx | `skill` | Install apx skill files | | `mcp` | Start the MCP server | | `bun` | Run a command using bun | -| `flux` | Flux OTEL collector commands | +| `tracing` | Tracing OTEL collector commands | | `upgrade` | Self-update to the latest version | --- @@ -421,31 +421,31 @@ apx bun --- -## flux +## tracing OpenTelemetry collector commands for observability. - **Note:** You typically don't need to use flux commands directly. When you run + **Note:** You typically don't need to use tracing commands directly. When you run `apx dev start`, apx automatically starts the log collector and keeps it - running alongside your development servers. The flux commands are provided for + running alongside your development servers. The tracing commands are provided for advanced use cases where you need manual control over the collector. -### flux start +### tracing start -Start the Flux OTEL collector daemon. +Start the tracing OTEL collector daemon. ```bash -apx flux start +apx tracing start ``` -### flux stop +### tracing stop -Stop the Flux OTEL collector daemon. +Stop the tracing OTEL collector daemon. ```bash -apx flux stop +apx tracing stop ``` --- diff --git a/scripts/build_agent.py b/scripts/build_agent.py index 46c68128..0027bf39 100644 --- a/scripts/build_agent.py +++ b/scripts/build_agent.py @@ -153,7 +153,7 @@ def build_target(target: Target, output_dir: Path, release: bool = True) -> None build_cmd, "build", "-p", - "apx-agent", + "apx-tracing", "--target", target.rust_target, ] diff --git a/src/apx/assets/entrypoint.ts b/src/apx/assets/entrypoint.ts index 2abce0c4..bf367aff 100644 --- a/src/apx/assets/entrypoint.ts +++ b/src/apx/assets/entrypoint.ts @@ -26,7 +26,7 @@ const appName = process.env.APX_APP_NAME!; // ============================================================================ // OpenTelemetry Logging Setup // ============================================================================ -// Logs are sent directly to flux via OTLP HTTP, NOT piped through apx stdout. +// Logs are sent directly to the collector via OTLP HTTP, NOT piped through apx stdout. // This ensures proper service attribution and avoids log interleaving issues. // ============================================================================ @@ -98,14 +98,14 @@ function emitLog( function log(message: string) { // Write to stdout for local visibility process.stdout.write(message + "\n"); - // Send to flux + // Send to collector emitLog("INFO", message); } function logError(message: string) { // Write to stderr for local visibility process.stderr.write(message + "\n"); - // Send to flux + // Send to collector emitLog("ERROR", message); } From 564368156bed7e928bff9de1e1d0093ebd91df93 Mon Sep 17 00:00:00 2001 From: renardeinside Date: Sun, 1 Mar 2026 17:18:03 +0100 Subject: [PATCH 2/5] =?UTF-8?q?=E2=9C=A8=20feat:=20add=20serving=20endpoin?= =?UTF-8?q?ts=20SDK=20+=20agent=20crate=20with=20rig=20completions?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.lock | 155 +++++++++ Cargo.toml | 6 +- crates/agent/Cargo.toml | 28 ++ crates/agent/src/client.rs | 132 ++++++++ crates/agent/src/error.rs | 21 ++ crates/agent/src/lib.rs | 15 + crates/agent/src/model.rs | 109 ++++++ crates/databricks_sdk/Cargo.toml | 4 + crates/databricks_sdk/src/api/mod.rs | 2 + .../src/api/serving_endpoints.rs | 318 ++++++++++++++++++ crates/databricks_sdk/src/client.rs | 90 ++++- crates/databricks_sdk/src/lib.rs | 3 + 12 files changed, 881 insertions(+), 2 deletions(-) create mode 100644 crates/agent/Cargo.toml create mode 100644 crates/agent/src/client.rs create mode 100644 crates/agent/src/error.rs create mode 100644 crates/agent/src/lib.rs create mode 100644 crates/agent/src/model.rs create mode 100644 crates/databricks_sdk/src/api/serving_endpoints.rs diff --git a/Cargo.lock b/Cargo.lock index 21d6c9a8..f947afd2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -135,6 +135,21 @@ version = "1.0.100" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" +[[package]] +name = "apx-agent" +version = "0.3.8" +dependencies = [ + "apx-databricks-sdk", + "reqwest 0.13.1", + "rig-core", + "serde", + "serde_json", + "thiserror 2.0.18", + "tokio", + "tracing", + "wiremock", +] + [[package]] name = "apx-bin" version = "0.3.8" @@ -326,6 +341,12 @@ dependencies = [ "derive_arbitrary", ] +[[package]] +name = "as-any" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0f477b951e452a0b6b4a10b53ccd569042d1d01729b519e02074a9c0958a063" + [[package]] name = "ascii" version = "1.1.0" @@ -353,6 +374,28 @@ dependencies = [ "syn 2.0.114", ] +[[package]] +name = "async-stream" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.114", +] + [[package]] name = "async-trait" version = "0.1.89" @@ -1814,6 +1857,17 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "eventsource-stream" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab" +dependencies = [ + "futures-core", + "nom", + "pin-project-lite", +] + [[package]] name = "fallible-iterator" version = "0.2.0" @@ -2048,6 +2102,12 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.31" @@ -3344,6 +3404,22 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "miniz_oxide" version = "0.8.9" @@ -3387,6 +3463,15 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "nanoid" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ffa00dec017b5b1a8b7cf5e2c008bfda1aa7e0697ac1508b491fdf2622fb4d8" +dependencies = [ + "rand 0.8.5", +] + [[package]] name = "ndk" version = "0.9.0" @@ -3429,6 +3514,16 @@ version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72ef4a56884ca558e5ddb05a1d1e7e1bfd9a68d9ed024c21704cc98872dae1bb" +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "notify" version = "8.2.0" @@ -3932,6 +4027,15 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "ordered-float" +version = "5.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f4779c6901a562440c3786d08192c6fbda7c1c2060edd10006b05ee35d10f2d" +dependencies = [ + "num-traits", +] + [[package]] name = "oxc_resolver" version = "1.12.0" @@ -4892,6 +4996,7 @@ dependencies = [ "js-sys", "log", "mime", + "mime_guess", "percent-encoding", "pin-project-lite", "quinn", @@ -4920,6 +5025,38 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "194d8e591e405d1eecf28819740abed6d719d1a2db87fc0bcdedee9a26d55560" +[[package]] +name = "rig-core" +version = "0.31.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "437fa2a15825caf2505411bbe55b05c8eb122e03934938b38f9ecaa1d6ded7c8" +dependencies = [ + "as-any", + "async-stream", + "base64 0.22.1", + "bytes", + "eventsource-stream", + "fastrand", + "futures", + "futures-timer", + "glob", + "http", + "mime", + "mime_guess", + "nanoid", + "ordered-float", + "pin-project-lite", + "reqwest 0.13.1", + "schemars 1.2.0", + "serde", + "serde_json", + "thiserror 2.0.18", + "tokio", + "tracing", + "tracing-futures", + "url", +] + [[package]] name = "ring" version = "0.17.14" @@ -6944,6 +7081,18 @@ dependencies = [ "valuable", ] +[[package]] +name = "tracing-futures" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97d095ae15e245a057c8e8451bab9b3ee1e1f68e9ba2b4fbc18d0ac5237835f2" +dependencies = [ + "futures", + "futures-task", + "pin-project", + "tracing", +] + [[package]] name = "tracing-log" version = "0.2.0" @@ -7095,6 +7244,12 @@ dependencies = [ "unic-common", ] +[[package]] +name = "unicase" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142" + [[package]] name = "unicode-bidi" version = "0.3.18" diff --git a/Cargo.toml b/Cargo.toml index fe00b06f..a1373752 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["crates/common", "crates/tracing", "crates/studio", "crates/core", "crates/mcp", "crates/cli", "crates/databricks_sdk", "crates/apx", "crates/db"] +members = ["crates/common", "crates/tracing", "crates/studio", "crates/core", "crates/mcp", "crates/cli", "crates/databricks_sdk", "crates/apx", "crates/db", "crates/agent"] resolver = "2" [workspace.package] @@ -21,6 +21,7 @@ apx-core = { path = "crates/core" } apx-mcp = { path = "crates/mcp" } apx-cli = { path = "crates/cli" } apx-databricks-sdk = { path = "crates/databricks_sdk" } +apx-agent = { path = "crates/agent" } apx-db = { path = "crates/db" } # Serialization @@ -39,6 +40,9 @@ tokio-stream = "0.1.17" futures-util = "0.3.31" rmcp = { version = "0.15", features = ["server", "transport-io", "schemars"] } +# AI / LLM +rig-core = { version = "0.31", default-features = false, features = ["reqwest-rustls"] } + # Web server axum = { version = "0.8.8", features = ["ws"] } reqwest = { version = "0.13.1", default-features = false, features = ["json", "stream", "rustls", "http2", "charset", "system-proxy"] } diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml new file mode 100644 index 00000000..98b0fc53 --- /dev/null +++ b/crates/agent/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "apx-agent" +version = "0.3.8" +edition.workspace = true +rust-version.workspace = true +description.workspace = true +repository.workspace = true +license.workspace = true +readme.workspace = true +keywords.workspace = true +categories.workspace = true + +[lints] +workspace = true + +[dependencies] +apx-databricks-sdk.workspace = true +rig-core.workspace = true +reqwest.workspace = true +serde.workspace = true +thiserror.workspace = true +tracing.workspace = true + +[dev-dependencies] +apx-databricks-sdk = { workspace = true, features = ["testing"] } +tokio.workspace = true +wiremock.workspace = true +serde_json.workspace = true diff --git a/crates/agent/src/client.rs b/crates/agent/src/client.rs new file mode 100644 index 00000000..11eef7f1 --- /dev/null +++ b/crates/agent/src/client.rs @@ -0,0 +1,132 @@ +use apx_databricks_sdk::DatabricksClient; +use rig::providers::openai; +use tracing::debug; + +use crate::error::Result; +use crate::model::{ModelRef, chat_models}; + +/// Core agent client wrapping a [`DatabricksClient`] for model discovery +/// and rig-based completions. +#[derive(Debug, Clone)] +pub struct AgentClient { + databricks: DatabricksClient, +} + +impl AgentClient { + /// Create an `AgentClient` from an existing [`DatabricksClient`]. + #[must_use] + pub const fn new(databricks: DatabricksClient) -> Self { + Self { databricks } + } + + /// Create an `AgentClient` by resolving a Databricks CLI profile. + /// + /// # Errors + /// + /// Returns an error if the profile cannot be resolved. + pub async fn from_profile(profile: &str) -> Result { + let databricks = DatabricksClient::new(profile).await?; + Ok(Self::new(databricks)) + } + + /// List chat-capable models in the workspace. + /// + /// Fetches all serving endpoints and filters to those that are ready + /// and serve `llm/v1/chat` (or have no task specified). + /// + /// # Errors + /// + /// Returns an error if the serving endpoints cannot be listed. + pub async fn list_models(&self) -> Result> { + let endpoints = self.databricks.serving_endpoints().list().await?; + debug!(count = endpoints.len(), "Fetched serving endpoints"); + Ok(chat_models(&endpoints)) + } + + /// Build a rig [`CompletionsClient`](openai::CompletionsClient) with a fresh + /// token and the proper base URL for this workspace. + /// + /// The base URL is `{host}/serving-endpoints` and the endpoint name is used + /// as the model ID. The returned client is short-lived — rebuild it when + /// the token might have expired. + /// + /// # Errors + /// + /// Returns an error if the token cannot be acquired or the client fails to build. + pub async fn completions_client(&self) -> Result { + let token = self.databricks.access_token().await?; + let base_url = format!("{}/serving-endpoints", self.databricks.host()); + debug!(%base_url, "Building rig CompletionsClient"); + + let client = openai::CompletionsClient::builder() + .api_key(&token) + .base_url(&base_url) + .build() + .map_err(|e| crate::error::AgentError::Completion(e.to_string()))?; + + Ok(client) + } + + /// Returns a reference to the underlying [`DatabricksClient`]. + #[must_use] + pub const fn databricks(&self) -> &DatabricksClient { + &self.databricks + } +} + +#[cfg(test)] +#[allow(clippy::unwrap_used)] +mod tests { + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + use super::*; + + fn mock_list_response() -> serde_json::Value { + serde_json::json!({ + "endpoints": [ + { + "name": "chat-model", + "state": { "ready": "READY" }, + "task": "llm/v1/chat" + }, + { + "name": "embed-model", + "state": { "ready": "READY" }, + "task": "llm/v1/embeddings" + }, + { + "name": "not-ready-model", + "state": { "ready": "NOT_READY" }, + "task": "llm/v1/chat" + } + ] + }) + } + + #[tokio::test] + async fn list_models_returns_ready_chat_endpoints() { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/api/2.0/serving-endpoints")) + .respond_with(ResponseTemplate::new(200).set_body_json(mock_list_response())) + .mount(&server) + .await; + + let sdk = DatabricksClient::with_static_token(&server.uri(), "test-token"); + let client = AgentClient::new(sdk); + let models = client.list_models().await.unwrap(); + + assert_eq!(models.len(), 1); + assert_eq!(models[0].name, "chat-model"); + } + + #[tokio::test] + async fn completions_client_builds_successfully() { + let sdk = DatabricksClient::with_static_token("https://test.databricks.com", "test-token"); + let client = AgentClient::new(sdk); + let result = client.completions_client().await; + assert!(result.is_ok()); + } +} diff --git a/crates/agent/src/error.rs b/crates/agent/src/error.rs new file mode 100644 index 00000000..925f9eef --- /dev/null +++ b/crates/agent/src/error.rs @@ -0,0 +1,21 @@ +use apx_databricks_sdk::DatabricksError; + +/// Errors returned by the agent crate. +#[derive(Debug, thiserror::Error)] +pub enum AgentError { + /// An error from the Databricks SDK layer. + #[error(transparent)] + Sdk(#[from] DatabricksError), + /// An HTTP transport error. + #[error("HTTP error: {0}")] + Http(#[from] reqwest::Error), + /// The requested model was not found in the workspace. + #[error("model not found: {0}")] + ModelNotFound(String), + /// A completion request failed. + #[error("completion error: {0}")] + Completion(String), +} + +/// Convenience alias for `Result`. +pub type Result = std::result::Result; diff --git a/crates/agent/src/lib.rs b/crates/agent/src/lib.rs new file mode 100644 index 00000000..de7c5b24 --- /dev/null +++ b/crates/agent/src/lib.rs @@ -0,0 +1,15 @@ +//! APX Agent — local agent powered by Databricks-hosted Foundation Models. +//! +//! This crate provides model discovery and a rig-based completions client +//! for Databricks serving endpoints. + +/// Agent client for model discovery and completions. +pub mod client; +/// Error types for the agent crate. +pub mod error; +/// Model reference types and filtering utilities. +pub mod model; + +pub use client::AgentClient; +pub use error::{AgentError, Result}; +pub use model::{ModelRef, chat_models}; diff --git a/crates/agent/src/model.rs b/crates/agent/src/model.rs new file mode 100644 index 00000000..6cd8d3c2 --- /dev/null +++ b/crates/agent/src/model.rs @@ -0,0 +1,109 @@ +use apx_databricks_sdk::ServingEndpoint; + +/// A reference to a model available in the workspace. +#[derive(Debug, Clone)] +pub struct ModelRef { + /// Serving endpoint name (used as the model ID in completions requests). + pub name: String, + /// Task type reported by the endpoint (e.g. `"llm/v1/chat"`). + pub task: Option, +} + +impl ModelRef { + /// Create a `ModelRef` from a [`ServingEndpoint`]. + #[must_use] + pub fn from_endpoint(endpoint: &ServingEndpoint) -> Self { + Self { + name: endpoint.name.clone(), + task: endpoint.task.clone(), + } + } + + /// Returns `true` if the endpoint serves an `llm/v1/chat` task + /// (or has no task specified, which is common for custom endpoints). + #[must_use] + pub fn is_chat_capable(&self) -> bool { + matches!(self.task.as_deref(), Some("llm/v1/chat") | None) + } +} + +/// Filter a slice of serving endpoints to those that are READY and chat-capable. +#[must_use] +pub fn chat_models(endpoints: &[ServingEndpoint]) -> Vec { + endpoints + .iter() + .filter(|ep| ep.is_ready()) + .map(ModelRef::from_endpoint) + .filter(|m| m.is_chat_capable()) + .collect() +} + +#[cfg(test)] +#[allow(clippy::unwrap_used)] +mod tests { + use apx_databricks_sdk::{EndpointReadyState, EndpointState, ServingEndpoint}; + + use super::*; + + fn make_endpoint(name: &str, ready: EndpointReadyState, task: Option<&str>) -> ServingEndpoint { + ServingEndpoint { + name: name.to_string(), + creator: None, + state: Some(EndpointState { + ready: Some(ready), + config_update: None, + }), + task: task.map(ToString::to_string), + config: None, + } + } + + #[test] + fn is_chat_capable_with_chat_task() { + let m = ModelRef { + name: "ep".to_string(), + task: Some("llm/v1/chat".to_string()), + }; + assert!(m.is_chat_capable()); + } + + #[test] + fn is_chat_capable_with_no_task() { + let m = ModelRef { + name: "ep".to_string(), + task: None, + }; + assert!(m.is_chat_capable()); + } + + #[test] + fn is_chat_capable_with_embeddings_task() { + let m = ModelRef { + name: "ep".to_string(), + task: Some("llm/v1/embeddings".to_string()), + }; + assert!(!m.is_chat_capable()); + } + + #[test] + fn chat_models_filters_correctly() { + let endpoints = vec![ + make_endpoint("chat-ready", EndpointReadyState::Ready, Some("llm/v1/chat")), + make_endpoint( + "chat-not-ready", + EndpointReadyState::NotReady, + Some("llm/v1/chat"), + ), + make_endpoint( + "embed-ready", + EndpointReadyState::Ready, + Some("llm/v1/embeddings"), + ), + make_endpoint("no-task-ready", EndpointReadyState::Ready, None), + ]; + + let models = chat_models(&endpoints); + let names: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect(); + assert_eq!(names, vec!["chat-ready", "no-task-ready"]); + } +} diff --git a/crates/databricks_sdk/Cargo.toml b/crates/databricks_sdk/Cargo.toml index 0504801a..56f3f88c 100644 --- a/crates/databricks_sdk/Cargo.toml +++ b/crates/databricks_sdk/Cargo.toml @@ -10,6 +10,10 @@ readme.workspace = true keywords.workspace = true categories.workspace = true +[features] +default = [] +testing = [] + [lints] workspace = true diff --git a/crates/databricks_sdk/src/api/mod.rs b/crates/databricks_sdk/src/api/mod.rs index 1d3b559a..fce7ffb3 100644 --- a/crates/databricks_sdk/src/api/mod.rs +++ b/crates/databricks_sdk/src/api/mod.rs @@ -2,3 +2,5 @@ pub mod apps; /// SCIM current-user (`/Me`) endpoint. pub mod current_user; +/// Serving Endpoints REST API. +pub mod serving_endpoints; diff --git a/crates/databricks_sdk/src/api/serving_endpoints.rs b/crates/databricks_sdk/src/api/serving_endpoints.rs new file mode 100644 index 00000000..55100947 --- /dev/null +++ b/crates/databricks_sdk/src/api/serving_endpoints.rs @@ -0,0 +1,318 @@ +use serde::Deserialize; + +use crate::client::DatabricksClient; +use crate::error::Result; + +// --------------------------------------------------------------------------- +// Types +// --------------------------------------------------------------------------- + +/// Private envelope for deserializing the list response. +#[derive(Deserialize)] +struct ListResponse { + #[serde(default)] + endpoints: Vec, +} + +/// Readiness state of a serving endpoint. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum EndpointReadyState { + /// The endpoint is ready to serve requests. + Ready, + /// The endpoint is not yet ready. + NotReady, + /// An unrecognized state (forward-compatible). + #[serde(other)] + Unknown, +} + +/// Config-update state of a serving endpoint. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum EndpointConfigUpdateState { + /// No config update in progress. + NotUpdating, + /// A config update is in progress. + InProgress, + /// A config update was canceled. + UpdateCanceled, + /// A config update failed. + UpdateFailed, + /// An unrecognized state (forward-compatible). + #[serde(other)] + Unknown, +} + +/// Nested state object of a serving endpoint. +#[derive(Debug, Clone, Copy, Deserialize)] +pub struct EndpointState { + /// Whether the endpoint is ready to serve. + #[serde(default)] + pub ready: Option, + /// Current config-update status. + #[serde(default)] + pub config_update: Option, +} + +/// An external model reference within a served entity. +#[derive(Debug, Clone, Deserialize)] +pub struct ExternalModel { + /// Model name (e.g. `"gpt-4"`). + pub name: String, + /// Provider name (e.g. `"openai"`). + #[serde(default)] + pub provider: Option, +} + +/// A single served entity (model) within a serving endpoint config. +#[derive(Debug, Clone, Deserialize)] +pub struct ServedEntity { + /// Name of the served entity. + #[serde(default)] + pub name: Option, + /// External model reference, if this entity wraps an external model. + #[serde(default)] + pub external_model: Option, +} + +/// Configuration of a serving endpoint, containing the served entities. +#[derive(Debug, Clone, Deserialize)] +pub struct EndpointConfig { + /// Current served entities (preferred). + #[serde(default)] + pub served_entities: Vec, + /// Deprecated alias for `served_entities`. + #[serde(default)] + pub served_models: Vec, +} + +/// A Databricks serving endpoint. +#[derive(Debug, Clone, Deserialize)] +pub struct ServingEndpoint { + /// Endpoint name (unique within the workspace). + pub name: String, + /// User who created the endpoint. + #[serde(default)] + pub creator: Option, + /// Current endpoint state (readiness, config update status). + #[serde(default)] + pub state: Option, + /// Task type (e.g. `"llm/v1/chat"`, `"llm/v1/completions"`). + #[serde(default)] + pub task: Option, + /// Endpoint configuration with served entities. + #[serde(default)] + pub config: Option, +} + +impl ServingEndpoint { + /// Returns `true` if the endpoint is in the `Ready` state. + #[must_use] + pub fn is_ready(&self) -> bool { + self.state + .as_ref() + .and_then(|s| s.ready) + .is_some_and(|r| r == EndpointReadyState::Ready) + } + + /// Returns the served entities, preferring `served_entities` over the + /// deprecated `served_models` field. + #[must_use] + pub fn entities(&self) -> &[ServedEntity] { + match self.config { + Some(ref cfg) if !cfg.served_entities.is_empty() => &cfg.served_entities, + Some(ref cfg) => &cfg.served_models, + None => &[], + } + } +} + +// --------------------------------------------------------------------------- +// API handle +// --------------------------------------------------------------------------- + +/// API handle for Databricks Serving Endpoints operations. +#[derive(Debug)] +pub struct ServingEndpointsApi<'a> { + client: &'a DatabricksClient, +} + +impl<'a> ServingEndpointsApi<'a> { + pub(crate) const fn new(client: &'a DatabricksClient) -> Self { + Self { client } + } + + /// List all serving endpoints in the workspace. + /// + /// # Errors + /// + /// Returns an error if the HTTP request fails or the response cannot be deserialized. + pub async fn list(&self) -> Result> { + let resp: ListResponse = self.client.get("/api/2.0/serving-endpoints").await?; + Ok(resp.endpoints) + } + + /// Get a single serving endpoint by name. + /// + /// # Errors + /// + /// Returns an error if the HTTP request fails or the response cannot be deserialized. + pub async fn get(&self, name: &str) -> Result { + self.client + .get(&format!("/api/2.0/serving-endpoints/{name}")) + .await + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +#[allow(clippy::unwrap_used)] +mod tests { + use super::*; + + #[test] + fn deserialize_endpoint_ready() { + let json = r#"{ + "name": "my-endpoint", + "creator": "user@example.com", + "state": { "ready": "READY", "config_update": "NOT_UPDATING" }, + "task": "llm/v1/chat", + "config": { + "served_entities": [{"name": "entity-1"}], + "served_models": [] + } + }"#; + let ep: ServingEndpoint = serde_json::from_str(json).unwrap(); + assert_eq!(ep.name, "my-endpoint"); + assert_eq!(ep.creator.as_deref(), Some("user@example.com")); + assert!(ep.is_ready()); + assert_eq!(ep.task.as_deref(), Some("llm/v1/chat")); + assert_eq!(ep.entities().len(), 1); + } + + #[test] + fn deserialize_endpoint_not_ready() { + let json = r#"{ + "name": "ep-2", + "state": { "ready": "NOT_READY" } + }"#; + let ep: ServingEndpoint = serde_json::from_str(json).unwrap(); + assert!(!ep.is_ready()); + } + + #[test] + fn deserialize_endpoint_unknown_state() { + let json = r#"{ + "name": "ep-3", + "state": { "ready": "SOME_FUTURE_STATE" } + }"#; + let ep: ServingEndpoint = serde_json::from_str(json).unwrap(); + assert!(!ep.is_ready()); + assert_eq!( + ep.state.as_ref().unwrap().ready, + Some(EndpointReadyState::Unknown) + ); + } + + #[test] + fn deserialize_endpoint_minimal() { + let json = r#"{ "name": "bare-ep" }"#; + let ep: ServingEndpoint = serde_json::from_str(json).unwrap(); + assert_eq!(ep.name, "bare-ep"); + assert!(ep.creator.is_none()); + assert!(ep.state.is_none()); + assert!(ep.task.is_none()); + assert!(ep.config.is_none()); + assert!(!ep.is_ready()); + assert!(ep.entities().is_empty()); + } + + #[test] + fn deserialize_list_response_empty() { + let json = r#"{ "endpoints": [] }"#; + let resp: ListResponse = serde_json::from_str(json).unwrap(); + assert!(resp.endpoints.is_empty()); + } + + #[test] + fn deserialize_list_response_missing_endpoints_key() { + let json = r#"{}"#; + let resp: ListResponse = serde_json::from_str(json).unwrap(); + assert!(resp.endpoints.is_empty()); + } + + #[test] + fn is_ready_true_when_ready() { + let ep = ServingEndpoint { + name: "test".to_string(), + creator: None, + state: Some(EndpointState { + ready: Some(EndpointReadyState::Ready), + config_update: None, + }), + task: None, + config: None, + }; + assert!(ep.is_ready()); + } + + #[test] + fn is_ready_false_when_not_ready() { + let ep = ServingEndpoint { + name: "test".to_string(), + creator: None, + state: Some(EndpointState { + ready: Some(EndpointReadyState::NotReady), + config_update: None, + }), + task: None, + config: None, + }; + assert!(!ep.is_ready()); + } + + #[test] + fn entities_prefers_served_entities() { + let ep = ServingEndpoint { + name: "test".to_string(), + creator: None, + state: None, + task: None, + config: Some(EndpointConfig { + served_entities: vec![ServedEntity { + name: Some("primary".to_string()), + external_model: None, + }], + served_models: vec![ServedEntity { + name: Some("fallback".to_string()), + external_model: None, + }], + }), + }; + assert_eq!(ep.entities().len(), 1); + assert_eq!(ep.entities()[0].name.as_deref(), Some("primary")); + } + + #[test] + fn entities_falls_back_to_served_models() { + let ep = ServingEndpoint { + name: "test".to_string(), + creator: None, + state: None, + task: None, + config: Some(EndpointConfig { + served_entities: vec![], + served_models: vec![ServedEntity { + name: Some("legacy".to_string()), + external_model: None, + }], + }), + }; + assert_eq!(ep.entities().len(), 1); + assert_eq!(ep.entities()[0].name.as_deref(), Some("legacy")); + } +} diff --git a/crates/databricks_sdk/src/client.rs b/crates/databricks_sdk/src/client.rs index 896db9a5..b5356c62 100644 --- a/crates/databricks_sdk/src/client.rs +++ b/crates/databricks_sdk/src/client.rs @@ -6,6 +6,7 @@ use tracing::debug; use crate::api::apps::AppsApi; use crate::api::current_user::CurrentUserApi; +use crate::api::serving_endpoints::ServingEndpointsApi; use crate::auth::{CachedToken, acquire_token}; use crate::config::{DatabricksConfig, resolve_config}; use crate::error::{DatabricksError, Result}; @@ -14,6 +15,7 @@ use crate::useragent::UserAgent; struct Inner { config: DatabricksConfig, http: reqwest::Client, + user_agent: String, cached_token: RwLock>, } @@ -72,9 +74,10 @@ impl DatabricksClient { let product_version = config.product_version.as_deref().unwrap_or("0.0.0"); let ua = UserAgent::new(product, product_version).with_auth("databricks-cli"); + let ua_string = ua.to_string(); let http = reqwest::Client::builder() - .user_agent(ua.to_string()) + .user_agent(&ua_string) .build() .unwrap_or_else(|_| reqwest::Client::new()); @@ -82,6 +85,7 @@ impl DatabricksClient { inner: Arc::new(Inner { config, http, + user_agent: ua_string, cached_token: RwLock::new(None), }), } @@ -155,6 +159,26 @@ impl DatabricksClient { &self.inner.config.profile } + /// The User-Agent header value used by this client. + #[must_use] + pub fn user_agent(&self) -> &str { + &self.inner.user_agent + } + + /// Build a `reqwest::Client` configured with the same User-Agent as this + /// SDK client. No auth headers are included — use [`access_token()`](Self::access_token) + /// separately. + /// + /// # Errors + /// + /// Returns an error if the underlying `reqwest::ClientBuilder` fails to build. + pub fn http_client(&self) -> Result { + reqwest::Client::builder() + .user_agent(&self.inner.user_agent) + .build() + .map_err(Into::into) + } + /// Access the Databricks Apps API. #[must_use] pub const fn apps(&self) -> AppsApi<'_> { @@ -166,6 +190,51 @@ impl DatabricksClient { pub const fn current_user(&self) -> CurrentUserApi<'_> { CurrentUserApi::new(self) } + + /// Access the Serving Endpoints API. + #[must_use] + pub const fn serving_endpoints(&self) -> ServingEndpointsApi<'_> { + ServingEndpointsApi::new(self) + } +} + +#[cfg(any(test, feature = "testing"))] +impl DatabricksClient { + /// Build a client with a pre-populated static token (bypasses CLI auth). + /// + /// Intended for wiremock-based tests in downstream crates. + #[must_use] + pub fn with_static_token(host: &str, token: &str) -> Self { + let config = DatabricksConfig { + profile: "test".to_string(), + host: host.to_string(), + product: Some("apx-test".to_string()), + product_version: Some("0.0.0".to_string()), + }; + + let ua = UserAgent::new("apx-test", "0.0.0").with_auth("static-token"); + let ua_string = ua.to_string(); + + let http = reqwest::Client::builder() + .user_agent(&ua_string) + .build() + .unwrap_or_else(|_| reqwest::Client::new()); + + let far_future = chrono::Utc::now() + chrono::Duration::hours(24); + let cached = CachedToken { + access_token: token.to_string(), + expires_at: far_future, + }; + + Self { + inner: Arc::new(Inner { + config, + http, + user_agent: ua_string, + cached_token: RwLock::new(Some(cached)), + }), + } + } } async fn handle_response(response: reqwest::Response) -> Result { @@ -183,3 +252,22 @@ async fn handle_response(response: reqwest::Response) -> Re let body = response.text().await?; serde_json::from_str(&body).map_err(Into::into) } + +#[cfg(test)] +#[allow(clippy::unwrap_used)] +mod tests { + use super::*; + + #[test] + fn http_client_builds_successfully() { + let client = DatabricksClient::with_static_token("https://test.databricks.com", "tok"); + assert!(client.http_client().is_ok()); + } + + #[test] + fn user_agent_contains_product() { + let client = DatabricksClient::with_static_token("https://test.databricks.com", "tok"); + assert!(client.user_agent().contains("apx-test")); + assert!(client.user_agent().contains("apx-databricks-sdk-rust")); + } +} diff --git a/crates/databricks_sdk/src/lib.rs b/crates/databricks_sdk/src/lib.rs index 109675de..2c79b262 100644 --- a/crates/databricks_sdk/src/lib.rs +++ b/crates/databricks_sdk/src/lib.rs @@ -20,6 +20,9 @@ pub mod useragent; pub use api::apps::{App, AppLogsArgs, ComputeState, LogEntry}; pub use api::current_user::{User, UserEmail, UserName}; +pub use api::serving_endpoints::{ + EndpointConfig, EndpointReadyState, EndpointState, ServedEntity, ServingEndpoint, +}; pub use client::DatabricksClient; pub use config::{DatabricksConfig, list_profile_names, resolve_config}; pub use config_parser::ConfigParser; From f2b6c1d09b0a5fec8ef2be4b1a476f2696ead1c4 Mon Sep 17 00:00:00 2001 From: renardeinside Date: Mon, 2 Mar 2026 18:09:29 +0100 Subject: [PATCH 3/5] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor:=20extract=20?= =?UTF-8?q?EnvProfile=20into=20apx-common,=20deduplicate=20resolve=5Fprofi?= =?UTF-8?q?le?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.6 --- crates/common/src/lib.rs | 3 + crates/common/src/profile.rs | 120 +++++++++++++++++++++++++++++ crates/core/src/dev/server.rs | 16 ++-- crates/mcp/src/tools/databricks.rs | 17 +--- 4 files changed, 136 insertions(+), 20 deletions(-) create mode 100644 crates/common/src/profile.rs diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs index 53f106f2..74f480f0 100644 --- a/crates/common/src/lib.rs +++ b/crates/common/src/lib.rs @@ -9,6 +9,8 @@ pub mod bundles; pub mod format; /// Network host constants for binding, client connections, and browser URLs. pub mod hosts; +/// Databricks CLI profile resolution. +pub mod profile; /// Pure types and logic for OTEL log records, filtering, and aggregation. pub mod storage; @@ -19,6 +21,7 @@ use std::path::PathBuf; use std::time::Duration; // Re-export commonly used types +pub use profile::EnvProfile; pub use storage::{ AggregatedRecord, LogAggregator, LogRecord, ServiceKind, collector_dir, get_aggregation_key, should_skip_log, should_skip_log_message, source_label, diff --git a/crates/common/src/profile.rs b/crates/common/src/profile.rs new file mode 100644 index 00000000..5d1b058b --- /dev/null +++ b/crates/common/src/profile.rs @@ -0,0 +1,120 @@ +//! Databricks CLI profile resolution. +//! +//! Consolidates the logic for resolving which Databricks CLI profile to +//! use from multiple sources (explicit flag, environment variable, dotenv). + +use std::collections::HashMap; + +/// Environment variable name for Databricks CLI profile. +const PROFILE_ENV_VAR: &str = "DATABRICKS_CONFIG_PROFILE"; + +/// Where a resolved profile came from. +enum ProfileSource<'a> { + /// `--profile` CLI flag. + Explicit(&'a str), + /// `DATABRICKS_CONFIG_PROFILE` environment variable. + EnvVar(&'a str), + /// `.env` file entry. + Dotenv(&'a str), + /// No profile specified — SDK uses DEFAULT. + Default, +} + +/// Classify which source provides the profile, in priority order. +fn classify<'a>( + cli_flag: Option<&'a str>, + env_var: Option<&'a str>, + dotenv: Option<&'a str>, +) -> ProfileSource<'a> { + if let Some(v) = non_blank(cli_flag) { + return ProfileSource::Explicit(v); + } + if let Some(v) = non_blank(env_var) { + return ProfileSource::EnvVar(v); + } + if let Some(v) = non_blank(dotenv) { + return ProfileSource::Dotenv(v); + } + ProfileSource::Default +} + +/// Resolve a Databricks profile name: explicit flag → env var → dotenv → empty. +/// +/// Pure function — all inputs are passed as arguments, no I/O. +fn resolve(cli_flag: Option<&str>, env_var: Option<&str>, dotenv: Option<&str>) -> String { + match classify(cli_flag, env_var, dotenv) { + ProfileSource::Explicit(p) | ProfileSource::EnvVar(p) | ProfileSource::Dotenv(p) => { + p.to_string() + } + ProfileSource::Default => String::new(), + } +} + +/// Return the trimmed value if non-empty, `None` otherwise. +fn non_blank(val: Option<&str>) -> Option<&str> { + val.map(str::trim).filter(|s| !s.is_empty()) +} + +/// Resolve a Databricks CLI profile from multiple sources. +/// +/// Priority: explicit CLI flag → `DATABRICKS_CONFIG_PROFILE` env var → dotenv vars → empty. +#[derive(Debug)] +pub struct EnvProfile<'a> { + dotenv_vars: &'a HashMap, +} + +impl<'a> EnvProfile<'a> { + /// Create a resolver backed by dotenv vars. + #[must_use] + pub const fn new(dotenv_vars: &'a HashMap) -> Self { + Self { dotenv_vars } + } + + /// Resolve the profile name. + /// + /// `explicit` is the value from a CLI flag (e.g. `--profile`). + #[must_use] + pub fn retrieve(&self, explicit: Option<&str>) -> String { + let env_val = std::env::var(PROFILE_ENV_VAR).ok(); + let dotenv_val = self.dotenv_vars.get(PROFILE_ENV_VAR).map(String::as_str); + resolve(explicit, env_val.as_deref(), dotenv_val) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn explicit_arg_takes_priority() { + assert_eq!( + resolve(Some("explicit"), Some("env"), Some("dotenv")), + "explicit" + ); + } + + #[test] + fn env_var_takes_priority_over_dotenv() { + assert_eq!(resolve(None, Some("env"), Some("dotenv")), "env"); + } + + #[test] + fn dotenv_used_when_no_env_var() { + assert_eq!(resolve(None, None, Some("dotenv")), "dotenv"); + } + + #[test] + fn empty_when_no_sources() { + assert_eq!(resolve(None, None, None), ""); + } + + #[test] + fn whitespace_explicit_arg_is_ignored() { + assert_eq!(resolve(Some(" "), Some("env"), None), "env"); + } + + #[test] + fn whitespace_only_falls_through_all() { + assert_eq!(resolve(Some(" "), Some(" "), Some(" ")), ""); + } +} diff --git a/crates/core/src/dev/server.rs b/crates/core/src/dev/server.rs index 74c8917d..85555359 100644 --- a/crates/core/src/dev/server.rs +++ b/crates/core/src/dev/server.rs @@ -13,6 +13,7 @@ use tokio::sync::broadcast; use tokio::time::Duration; use tracing::{debug, info, warn}; +use apx_common::EnvProfile; use apx_databricks_sdk::DatabricksClient; use crate::api_generator::start_openapi_watcher; @@ -464,9 +465,14 @@ async fn stop(headers: HeaderMap, State(state): State) -> StatusCode { /// Resolve the Databricks profile name from env var or `.env` file. pub fn resolve_databricks_profile(app_dir: &std::path::Path) -> Option { - std::env::var("DATABRICKS_CONFIG_PROFILE").ok().or_else(|| { - DotenvFile::read(&app_dir.join(".env")) - .ok() - .and_then(|d| d.get_vars().get("DATABRICKS_CONFIG_PROFILE").cloned()) - }) + let dotenv_vars = DotenvFile::read(&app_dir.join(".env")) + .ok() + .map(|d| d.get_vars()) + .unwrap_or_default(); + let profile = EnvProfile::new(&dotenv_vars).retrieve(None); + if profile.is_empty() { + None + } else { + Some(profile) + } } diff --git a/crates/mcp/src/tools/databricks.rs b/crates/mcp/src/tools/databricks.rs index 1dd499ee..60de6f7c 100644 --- a/crates/mcp/src/tools/databricks.rs +++ b/crates/mcp/src/tools/databricks.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use std::path::Path; use std::time::Duration; +use apx_common::EnvProfile; use apx_core::dotenv::DotenvFile; use apx_databricks_sdk::{AppLogsArgs, DatabricksClient, LogEntry}; use rmcp::model::{CallToolResult, ErrorData}; @@ -172,21 +173,7 @@ async fn get_or_create_client( } fn resolve_profile(args: &DatabricksAppsLogsArgs, dotenv_vars: &HashMap) -> String { - if let Some(ref p) = args.profile { - let trimmed = p.trim(); - if !trimmed.is_empty() { - return trimmed.to_string(); - } - } - - if let Some(p) = dotenv_vars.get("DATABRICKS_CONFIG_PROFILE") { - let trimmed = p.trim(); - if !trimmed.is_empty() { - return trimmed.to_string(); - } - } - - String::new() + EnvProfile::new(dotenv_vars).retrieve(args.profile.as_deref()) } #[cfg(test)] From cbec145db6c82e2f2401556acd7a5e5231790123 Mon Sep 17 00:00:00 2001 From: renardeinside Date: Mon, 2 Mar 2026 18:11:41 +0100 Subject: [PATCH 4/5] =?UTF-8?q?=E2=9C=A8=20feat:=20add=20agent=20chat,=20s?= =?UTF-8?q?ession=20persistence,=20TUI,=20and=20wiremock=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.6 --- Cargo.lock | 627 ++++++++++++++++++++++++++++- Cargo.toml | 4 + crates/agent/Cargo.toml | 5 + crates/agent/src/chat.rs | 136 +++++++ crates/agent/src/client.rs | 144 ++++++- crates/agent/src/error.rs | 3 + crates/agent/src/lib.rs | 11 +- crates/agent/src/session.rs | 42 ++ crates/agent/src/session_sqlite.rs | 357 ++++++++++++++++ crates/cli/Cargo.toml | 4 + crates/cli/src/agent/chat.rs | 98 +++++ crates/cli/src/agent/mod.rs | 13 + crates/cli/src/agent/tui.rs | 359 +++++++++++++++++ crates/cli/src/lib.rs | 7 + 14 files changed, 1799 insertions(+), 11 deletions(-) create mode 100644 crates/agent/src/chat.rs create mode 100644 crates/agent/src/session.rs create mode 100644 crates/agent/src/session_sqlite.rs create mode 100644 crates/cli/src/agent/chat.rs create mode 100644 crates/cli/src/agent/mod.rs create mode 100644 crates/cli/src/agent/tui.rs diff --git a/Cargo.lock b/Cargo.lock index f947afd2..2598f6a7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -140,10 +140,15 @@ name = "apx-agent" version = "0.3.8" dependencies = [ "apx-databricks-sdk", + "async-stream", + "dirs 5.0.1", + "futures-util", + "nanoid", "reqwest 0.13.1", "rig-core", "serde", "serde_json", + "sqlx", "thiserror 2.0.18", "tokio", "tracing", @@ -162,6 +167,7 @@ dependencies = [ name = "apx-cli" version = "0.3.8" dependencies = [ + "apx-agent", "apx-common", "apx-core", "apx-databricks-sdk", @@ -170,9 +176,12 @@ dependencies = [ "chrono", "clap", "console", + "crossterm 0.28.1", "dialoguer", + "futures-util", "indicatif", "rand 0.8.5", + "ratatui", "reqwest 0.13.1", "serde", "serde_json", @@ -439,6 +448,15 @@ dependencies = [ "num-traits", ] +[[package]] +name = "atomic" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a89cbf775b137e9b968e67227ef7f775587cde3fd31b0d8599dbd0f598a48340" +dependencies = [ + "bytemuck", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -765,6 +783,21 @@ version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cb9696fda489e25051248bad5a73bdd53f8d063dc3a7f4a71d4c6aadf6fbcb18" +[[package]] +name = "bit-set" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" + [[package]] name = "bitflags" version = "1.3.2" @@ -1226,6 +1259,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" +[[package]] +name = "convert_case" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "633458d4ef8c78b72454de2d54fd6ab2e60f9e02be22f3c6104cdc8a4e0fceb9" +dependencies = [ + "unicode-segmentation", +] + [[package]] name = "cookie" version = "0.18.1" @@ -1368,6 +1410,50 @@ version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" +[[package]] +name = "crossterm" +version = "0.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "829d955a0bb380ef178a640b91779e3987da38c9aea133b20614cfed8cdea9c6" +dependencies = [ + "bitflags 2.10.0", + "crossterm_winapi", + "futures-core", + "mio", + "parking_lot", + "rustix 0.38.44", + "signal-hook", + "signal-hook-mio", + "winapi", +] + +[[package]] +name = "crossterm" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8b9f2e4c67f833b660cdb0a3523065869fb35570177239812ed4c905aeff87b" +dependencies = [ + "bitflags 2.10.0", + "crossterm_winapi", + "derive_more 2.1.1", + "document-features", + "mio", + "parking_lot", + "rustix 1.1.3", + "signal-hook", + "signal-hook-mio", + "winapi", +] + +[[package]] +name = "crossterm_winapi" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b" +dependencies = [ + "winapi", +] + [[package]] name = "crypto-common" version = "0.1.7" @@ -1378,6 +1464,16 @@ dependencies = [ "typenum", ] +[[package]] +name = "csscolorparser" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb2a7d3066da2de787b7f032c736763eb7ae5d355f81a68bab2675a96008b0bf" +dependencies = [ + "lab", + "phf 0.11.3", +] + [[package]] name = "cssparser" version = "0.29.6" @@ -1528,6 +1624,12 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26bf8fc351c5ed29b5c2f0cbbac1b209b74f60ecd62e675a998df72c49af5204" +[[package]] +name = "deltae" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5729f5117e208430e437df2f4843f5e5952997175992d1414f94c57d61e270b4" + [[package]] name = "deranged" version = "0.5.5" @@ -1566,7 +1668,29 @@ version = "0.99.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6edb4b64a43d977b8e99788fe3a04d483834fba1215a7e02caa415b626497f7f" dependencies = [ - "convert_case", + "convert_case 0.4.0", + "proc-macro2", + "quote", + "rustc_version", + "syn 2.0.114", +] + +[[package]] +name = "derive_more" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d751e9e49156b02b44f9c1815bcb94b984cdcc4396ecc32521c739452808b134" +dependencies = [ + "derive_more-impl", +] + +[[package]] +name = "derive_more-impl" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "799a97264921d8623a957f6c3b9011f3b5492f557bbb7a5a19b7fa6d06ba8dcb" +dependencies = [ + "convert_case 0.10.0", "proc-macro2", "quote", "rustc_version", @@ -1695,6 +1819,15 @@ dependencies = [ "syn 2.0.114", ] +[[package]] +name = "document-features" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4b8a88685455ed29a21542a33abd9cb6510b6b129abadabdcef0f4c55bc8f61" +dependencies = [ + "litrs", +] + [[package]] name = "dotenvy" version = "0.15.7" @@ -1846,6 +1979,15 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "euclid" +version = "0.22.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df61bf483e837f88d5c2291dcf55c67be7e676b3a51acc48db3a7b163b91ed63" +dependencies = [ + "num-traits", +] + [[package]] name = "event-listener" version = "5.4.1" @@ -1874,6 +2016,16 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" +[[package]] +name = "fancy-regex" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b95f7c0680e4142284cf8b22c14a476e87d61b004a3a0861872b32ef7ead40a2" +dependencies = [ + "bit-set", + "regex", +] + [[package]] name = "fastrand" version = "2.3.0" @@ -1899,6 +2051,17 @@ dependencies = [ "rustc_version", ] +[[package]] +name = "filedescriptor" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e40758ed24c9b2eeb76c35fb0aebc66c626084edd827e07e1552279814c6682d" +dependencies = [ + "libc", + "thiserror 1.0.69", + "winapi", +] + [[package]] name = "filetime" version = "0.2.27" @@ -1916,6 +2079,18 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8591b0bcc8a98a64310a2fae1bb3e9b8564dd10e381e6e28010fde8e8e8568db" +[[package]] +name = "finl_unicode" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9844ddc3a6e533d62bba727eb6c28b5d360921d5175e9ff0f1e621a5c590a4d5" + +[[package]] +name = "fixedbitset" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" + [[package]] name = "flate2" version = "1.1.8" @@ -1949,6 +2124,12 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "foldhash" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" + [[package]] name = "foreign-types" version = "0.5.0" @@ -2535,7 +2716,7 @@ checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" dependencies = [ "allocator-api2", "equivalent", - "foldhash", + "foldhash 0.1.5", ] [[package]] @@ -2543,6 +2724,11 @@ name = "hashbrown" version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash 0.2.0", +] [[package]] name = "hashlink" @@ -2938,6 +3124,15 @@ dependencies = [ "web-time", ] +[[package]] +name = "indoc" +version = "2.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79cf5c93f93228cf8efb3ba362535fb11199ac548a09ce117c9b1adc3030d706" +dependencies = [ + "rustversion", +] + [[package]] name = "infer" version = "0.19.0" @@ -2976,6 +3171,19 @@ dependencies = [ "generic-array", ] +[[package]] +name = "instability" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357b7205c6cd18dd2c86ed312d1e70add149aea98e7ef72b9fdf0270e555c11d" +dependencies = [ + "darling 0.23.0", + "indoc", + "proc-macro2", + "quote", + "syn 2.0.114", +] + [[package]] name = "interpolator" version = "0.5.0" @@ -3127,6 +3335,17 @@ dependencies = [ "serde_json", ] +[[package]] +name = "kasuari" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fe90c1150662e858c7d5f945089b7517b0a80d8bf7ba4b1b5ffc984e7230a5b" +dependencies = [ + "hashbrown 0.16.1", + "portable-atomic", + "thiserror 2.0.18", +] + [[package]] name = "keyboard-types" version = "0.7.0" @@ -3170,6 +3389,12 @@ dependencies = [ "selectors", ] +[[package]] +name = "lab" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf36173d4167ed999940f804952e6b08197cae5ad5d572eb4db150ce8ad5d58f" + [[package]] name = "lazy_static" version = "1.5.0" @@ -3244,6 +3469,21 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "line-clipping" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f4de44e98ddbf09375cbf4d17714d18f39195f4f4894e8524501726fd9a8a4a" +dependencies = [ + "bitflags 2.10.0", +] + +[[package]] +name = "linux-raw-sys" +version = "0.4.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" + [[package]] name = "linux-raw-sys" version = "0.11.0" @@ -3256,6 +3496,12 @@ version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" +[[package]] +name = "litrs" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11d3d7f243d5c5a8b9bb5d6dd2b1602c0cb0b9db1621bafc7ed66e35ff9fe092" + [[package]] name = "lock_api" version = "0.4.14" @@ -3271,6 +3517,15 @@ version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +[[package]] +name = "lru" +version = "0.16.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1dc47f592c06f33f8e3aea9591776ec7c9f9e4124778ff8a3c3b87159f7e593" +dependencies = [ + "hashbrown 0.16.1", +] + [[package]] name = "lru-slab" version = "0.1.2" @@ -3304,6 +3559,16 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c41e0c4fef86961ac6d6f8a82609f55f31b05e4fce149ac5710e439df7619ba4" +[[package]] +name = "mac_address" +version = "1.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0aeb26bf5e836cc1c341c8106051b573f1766dfa05aa87f0b98be5e51b02303" +dependencies = [ + "nix", + "winapi", +] + [[package]] name = "manyhow" version = "0.11.4" @@ -3389,6 +3654,12 @@ version = "2.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" +[[package]] +name = "memmem" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a64a92489e2744ce060c349162be1c5f33c6969234104dbd99ddb5feb08b8c15" + [[package]] name = "memoffset" version = "0.9.1" @@ -3508,6 +3779,19 @@ version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" +[[package]] +name = "nix" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" +dependencies = [ + "bitflags 2.10.0", + "cfg-if", + "cfg_aliases", + "libc", + "memoffset", +] + [[package]] name = "nodrop" version = "0.1.14" @@ -3586,6 +3870,17 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050" +[[package]] +name = "num-derive" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.114", +] + [[package]] name = "num-integer" version = "0.1.46" @@ -3636,6 +3931,15 @@ dependencies = [ "syn 2.0.114", ] +[[package]] +name = "num_threads" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c7398b9c8b70908f6371f47ed36737907c87c52af34c268fed0bf0ceb92ead9" +dependencies = [ + "libc", +] + [[package]] name = "number_prefix" version = "0.4.0" @@ -4027,6 +4331,15 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "ordered-float" +version = "4.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bb71e1b3fa6ca1c61f383464aaf2bb0e2f8e772a1f01d486832464de363b951" +dependencies = [ + "num-traits", +] + [[package]] name = "ordered-float" version = "5.1.0" @@ -4826,6 +5139,91 @@ dependencies = [ "rand_core 0.9.5", ] +[[package]] +name = "ratatui" +version = "0.30.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1ce67fb8ba4446454d1c8dbaeda0557ff5e94d39d5e5ed7f10a65eb4c8266bc" +dependencies = [ + "instability", + "ratatui-core", + "ratatui-crossterm", + "ratatui-macros", + "ratatui-termwiz", + "ratatui-widgets", +] + +[[package]] +name = "ratatui-core" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ef8dea09a92caaf73bff7adb70b76162e5937524058a7e5bff37869cbbec293" +dependencies = [ + "bitflags 2.10.0", + "compact_str 0.9.0", + "hashbrown 0.16.1", + "indoc", + "itertools", + "kasuari", + "lru", + "strum", + "thiserror 2.0.18", + "unicode-segmentation", + "unicode-truncate", + "unicode-width 0.2.2", +] + +[[package]] +name = "ratatui-crossterm" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "577c9b9f652b4c121fb25c6a391dd06406d3b092ba68827e6d2f09550edc54b3" +dependencies = [ + "cfg-if", + "crossterm 0.29.0", + "instability", + "ratatui-core", +] + +[[package]] +name = "ratatui-macros" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7f1342a13e83e4bb9d0b793d0ea762be633f9582048c892ae9041ef39c936f4" +dependencies = [ + "ratatui-core", + "ratatui-widgets", +] + +[[package]] +name = "ratatui-termwiz" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f76fe0bd0ed4295f0321b1676732e2454024c15a35d01904ddb315afd3d545c" +dependencies = [ + "ratatui-core", + "termwiz", +] + +[[package]] +name = "ratatui-widgets" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7dbfa023cd4e604c2553483820c5fe8aa9d71a42eea5aa77c6e7f35756612db" +dependencies = [ + "bitflags 2.10.0", + "hashbrown 0.16.1", + "indoc", + "instability", + "itertools", + "line-clipping", + "ratatui-core", + "strum", + "time", + "unicode-segmentation", + "unicode-width 0.2.2", +] + [[package]] name = "raw-window-handle" version = "0.6.2" @@ -5044,7 +5442,7 @@ dependencies = [ "mime", "mime_guess", "nanoid", - "ordered-float", + "ordered-float 5.1.0", "pin-project-lite", "reqwest 0.13.1", "schemars 1.2.0", @@ -5229,6 +5627,19 @@ dependencies = [ "semver", ] +[[package]] +name = "rustix" +version = "0.38.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" +dependencies = [ + "bitflags 2.10.0", + "errno", + "libc", + "linux-raw-sys 0.4.15", + "windows-sys 0.59.0", +] + [[package]] name = "rustix" version = "1.1.3" @@ -5238,7 +5649,7 @@ dependencies = [ "bitflags 2.10.0", "errno", "libc", - "linux-raw-sys", + "linux-raw-sys 0.11.0", "windows-sys 0.61.2", ] @@ -5457,7 +5868,7 @@ checksum = "0c37578180969d00692904465fb7f6b3d50b9a2b952b87c23d0e2e5cb5013416" dependencies = [ "bitflags 1.3.2", "cssparser", - "derive_more", + "derive_more 0.99.20", "fxhash", "log", "phf 0.8.0", @@ -5726,6 +6137,27 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "signal-hook" +version = "0.3.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d881a16cf4426aa584979d30bd82cb33429027e42122b169753d6ef1085ed6e2" +dependencies = [ + "libc", + "signal-hook-registry", +] + +[[package]] +name = "signal-hook-mio" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b75a19a7a740b25bc7944bdee6172368f988763b744e3d4dfe753f6b4ece40cc" +dependencies = [ + "libc", + "mio", + "signal-hook", +] + [[package]] name = "signal-hook-registry" version = "1.4.8" @@ -6028,6 +6460,27 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "strum" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7695ce3845ea4b33927c055a39dc438a45b059f7c1b3d91d38d10355fb8cbca7" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "quote", + "syn 2.0.114", +] + [[package]] name = "subtle" version = "2.6.1" @@ -6560,7 +7013,7 @@ dependencies = [ "fastrand", "getrandom 0.3.4", "once_cell", - "rustix", + "rustix 1.1.3", "windows-sys 0.61.2", ] @@ -6606,6 +7059,69 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "terminfo" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4ea810f0692f9f51b382fff5893887bb4580f5fa246fde546e0b13e7fcee662" +dependencies = [ + "fnv", + "nom", + "phf 0.11.3", + "phf_codegen 0.11.3", +] + +[[package]] +name = "termios" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "411c5bf740737c7918b8b1fe232dca4dc9f8e754b8ad5e20966814001ed0ac6b" +dependencies = [ + "libc", +] + +[[package]] +name = "termwiz" +version = "0.23.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4676b37242ccbd1aabf56edb093a4827dc49086c0ffd764a5705899e0f35f8f7" +dependencies = [ + "anyhow", + "base64 0.22.1", + "bitflags 2.10.0", + "fancy-regex", + "filedescriptor", + "finl_unicode", + "fixedbitset", + "hex", + "lazy_static", + "libc", + "log", + "memmem", + "nix", + "num-derive", + "num-traits", + "ordered-float 4.6.0", + "pest", + "pest_derive", + "phf 0.11.3", + "sha2", + "signal-hook", + "siphasher 1.0.2", + "terminfo", + "termios", + "thiserror 1.0.69", + "ucd-trie", + "unicode-segmentation", + "vtparse", + "wezterm-bidi", + "wezterm-blob-leases", + "wezterm-color-types", + "wezterm-dynamic", + "wezterm-input-types", + "winapi", +] + [[package]] name = "thiserror" version = "1.0.69" @@ -6663,7 +7179,9 @@ checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c" dependencies = [ "deranged", "itoa", + "libc", "num-conv", + "num_threads", "powerfmt", "serde_core", "time-core", @@ -7295,6 +7813,17 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" +[[package]] +name = "unicode-truncate" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16b380a1238663e5f8a691f9039c73e1cdae598a30e9855f541d29b08b53e9a5" +dependencies = [ + "itertools", + "unicode-segmentation", + "unicode-width 0.2.2", +] + [[package]] name = "unicode-width" version = "0.1.14" @@ -7390,6 +7919,7 @@ version = "1.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ee48d38b119b0cd71fe4141b30f5ba9c7c5d9f4e7a3a8b4a674e4b6ef789976f" dependencies = [ + "atomic", "getrandom 0.3.4", "js-sys", "serde_core", @@ -7446,6 +7976,15 @@ dependencies = [ "libc", ] +[[package]] +name = "vtparse" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d9b2acfb050df409c972a37d3b8e08cdea3bddb0c09db9d53137e504cfabed0" +dependencies = [ + "utf8parse", +] + [[package]] name = "walkdir" version = "2.5.0" @@ -7685,6 +8224,78 @@ dependencies = [ "windows-core 0.61.2", ] +[[package]] +name = "wezterm-bidi" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0a6e355560527dd2d1cf7890652f4f09bb3433b6aadade4c9b5ed76de5f3ec" +dependencies = [ + "log", + "wezterm-dynamic", +] + +[[package]] +name = "wezterm-blob-leases" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "692daff6d93d94e29e4114544ef6d5c942a7ed998b37abdc19b17136ea428eb7" +dependencies = [ + "getrandom 0.3.4", + "mac_address", + "sha2", + "thiserror 1.0.69", + "uuid", +] + +[[package]] +name = "wezterm-color-types" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7de81ef35c9010270d63772bebef2f2d6d1f2d20a983d27505ac850b8c4b4296" +dependencies = [ + "csscolorparser", + "deltae", + "lazy_static", + "wezterm-dynamic", +] + +[[package]] +name = "wezterm-dynamic" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f2ab60e120fd6eaa68d9567f3226e876684639d22a4219b313ff69ec0ccd5ac" +dependencies = [ + "log", + "ordered-float 4.6.0", + "strsim", + "thiserror 1.0.69", + "wezterm-dynamic-derive", +] + +[[package]] +name = "wezterm-dynamic-derive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46c0cf2d539c645b448eaffec9ec494b8b19bd5077d9e58cb1ae7efece8d575b" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "wezterm-input-types" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7012add459f951456ec9d6c7e6fc340b1ce15d6fc9629f8c42853412c029e57e" +dependencies = [ + "bitflags 1.3.2", + "euclid", + "lazy_static", + "serde", + "wezterm-dynamic", +] + [[package]] name = "which" version = "7.0.3" @@ -7693,7 +8304,7 @@ checksum = "24d643ce3fd3e5b54854602a080f34fb10ab75e0b813ee32d00ca2b44fa74762" dependencies = [ "either", "env_home", - "rustix", + "rustix 1.1.3", "winsafe", ] @@ -8414,7 +9025,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32e45ad4206f6d2479085147f02bc2ef834ac85886624a23575ae137c8aa8156" dependencies = [ "libc", - "rustix", + "rustix 1.1.3", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index a1373752..c911b433 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,6 +38,7 @@ sqlx = { version = "0.8", default-features = false, features = ["runtime-tokio", tokio = { version = "1.49", features = ["rt-multi-thread", "macros", "sync", "process", "io-util", "signal", "io-std", "net"] } tokio-stream = "0.1.17" futures-util = "0.3.31" +async-stream = "0.3" rmcp = { version = "0.15", features = ["server", "transport-io", "schemars"] } # AI / LLM @@ -81,6 +82,9 @@ sha2 = "0.10" tar = "0.4" rayon = "1.11.0" similar = "2.7" +nanoid = "0.4" +ratatui = "0.30" +crossterm = { version = "0.28", features = ["event-stream"] } # System sysinfo = "0.33.1" diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index 98b0fc53..7a049eb0 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -20,6 +20,11 @@ reqwest.workspace = true serde.workspace = true thiserror.workspace = true tracing.workspace = true +sqlx.workspace = true +futures-util.workspace = true +async-stream.workspace = true +dirs.workspace = true +nanoid.workspace = true [dev-dependencies] apx-databricks-sdk = { workspace = true, features = ["testing"] } diff --git a/crates/agent/src/chat.rs b/crates/agent/src/chat.rs new file mode 100644 index 00000000..9928114f --- /dev/null +++ b/crates/agent/src/chat.rs @@ -0,0 +1,136 @@ +//! Domain types for chat conversations. + +use rig::message::Message; + +/// Role in a conversation. +#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub enum Role { + /// A message from the user. + User, + /// A message from the assistant. + Assistant, +} + +impl std::fmt::Display for Role { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::User => write!(f, "user"), + Self::Assistant => write!(f, "assistant"), + } + } +} + +impl std::str::FromStr for Role { + type Err = String; + fn from_str(s: &str) -> Result { + match s { + "user" => Ok(Self::User), + "assistant" => Ok(Self::Assistant), + other => Err(format!("unknown role: {other}")), + } + } +} + +/// A single message in a conversation. +#[derive(Debug, Clone)] +pub struct ChatMessage { + /// Who sent the message. + pub role: Role, + /// The text content of the message. + pub content: String, + /// Unix timestamp in seconds. + pub timestamp: i64, +} + +/// Events emitted during streaming chat completion. +#[derive(Debug)] +pub enum ChatEvent { + /// A text token from the model. + Token(String), + /// Stream finished with the full assembled response. + Done(String), +} + +/// Current unix timestamp in seconds. +#[must_use] +pub fn now_secs() -> i64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_secs().cast_signed()) + .unwrap_or(0) +} + +/// Convert chat history to rig `Message` objects for the completions API. +pub(crate) fn to_rig_messages(history: &[ChatMessage]) -> Vec { + history + .iter() + .map(|msg| match msg.role { + Role::User => Message::user(&msg.content), + Role::Assistant => Message::assistant(&msg.content), + }) + .collect() +} + +#[cfg(test)] +// Reason: panicking on failure is idiomatic in tests +#[allow(clippy::unwrap_used)] +mod tests { + use super::*; + + #[test] + fn role_roundtrips_through_string() { + assert_eq!("user".parse::().unwrap(), Role::User); + assert_eq!("assistant".parse::().unwrap(), Role::Assistant); + assert!("bogus".parse::().is_err()); + } + + #[test] + fn to_rig_messages_converts_user() { + let msgs = vec![ChatMessage { + role: Role::User, + content: "hello".into(), + timestamp: 0, + }]; + let rig_msgs = to_rig_messages(&msgs); + assert_eq!(rig_msgs.len(), 1); + } + + #[test] + fn to_rig_messages_converts_assistant() { + let msgs = vec![ChatMessage { + role: Role::Assistant, + content: "hi there".into(), + timestamp: 0, + }]; + let rig_msgs = to_rig_messages(&msgs); + assert_eq!(rig_msgs.len(), 1); + } + + #[test] + fn to_rig_messages_mixed_conversation() { + let msgs = vec![ + ChatMessage { + role: Role::User, + content: "q1".into(), + timestamp: 1, + }, + ChatMessage { + role: Role::Assistant, + content: "a1".into(), + timestamp: 2, + }, + ChatMessage { + role: Role::User, + content: "q2".into(), + timestamp: 3, + }, + ChatMessage { + role: Role::Assistant, + content: "a2".into(), + timestamp: 4, + }, + ]; + let rig_msgs = to_rig_messages(&msgs); + assert_eq!(rig_msgs.len(), 4); + } +} diff --git a/crates/agent/src/client.rs b/crates/agent/src/client.rs index 11eef7f1..0472786d 100644 --- a/crates/agent/src/client.rs +++ b/crates/agent/src/client.rs @@ -1,10 +1,54 @@ +use std::pin::Pin; + use apx_databricks_sdk::DatabricksClient; +use futures_util::StreamExt; +use rig::agent::MultiTurnStreamItem; +use rig::agent::StreamingError; +use rig::message::Text; +use rig::prelude::CompletionClient; use rig::providers::openai; +use rig::streaming::{StreamedAssistantContent, StreamingChat}; use tracing::debug; -use crate::error::Result; +use crate::chat::{ChatEvent, ChatMessage, to_rig_messages}; +use crate::error::{AgentError, Result}; use crate::model::{ModelRef, chat_models}; +/// Default system prompt for the chat agent. +const SYSTEM_PROMPT: &str = "\ +You are an AI assistant powered by Databricks Foundation Models. \ +You help users with questions about their Databricks workspace, \ +data engineering, and general programming tasks."; + +/// Map a rig stream item to a [`ChatEvent`], accumulating text in `full_text`. +/// +/// Returns `None` for non-text items (tool calls, reasoning, etc.) which are skipped. +fn map_stream_item( + item: std::result::Result, StreamingError>, + full_text: &mut String, +) -> Option> { + match item { + Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(Text { + text, + .. + }))) => { + full_text.push_str(&text); + Some(Ok(ChatEvent::Token(text))) + } + Ok(MultiTurnStreamItem::FinalResponse(resp)) => { + let response = resp.response(); + let text = if full_text.is_empty() && !response.is_empty() { + response.to_string() + } else { + full_text.clone() + }; + Some(Ok(ChatEvent::Done(text))) + } + Ok(_) => None, + Err(e) => Some(Err(AgentError::Completion(e.to_string()))), + } +} + /// Core agent client wrapping a [`DatabricksClient`] for model discovery /// and rig-based completions. #[derive(Debug, Clone)] @@ -62,11 +106,44 @@ impl AgentClient { .api_key(&token) .base_url(&base_url) .build() - .map_err(|e| crate::error::AgentError::Completion(e.to_string()))?; + .map_err(|e| AgentError::Completion(e.to_string()))?; Ok(client) } + /// Stream a chat completion. + /// + /// Converts `history` to rig messages, builds a streaming agent for the + /// given model, and returns a stream of [`ChatEvent`]s. + /// + /// # Errors + /// + /// Returns an error if the completions client cannot be built or the + /// streaming request fails. + pub async fn stream_chat( + &self, + model: &str, + message: &str, + history: &[ChatMessage], + ) -> Result> + Send>>> { + let client = self.completions_client().await?; + let agent = client.agent(model).preamble(SYSTEM_PROMPT).build(); + + let rig_history = to_rig_messages(history); + let mut stream = agent.stream_chat(message, rig_history).await; + + let mut full_text = String::new(); + let mapped = async_stream::stream! { + while let Some(item) = stream.next().await { + if let Some(event) = map_stream_item(item, &mut full_text) { + yield event; + } + } + }; + + Ok(Box::pin(mapped)) + } + /// Returns a reference to the underlying [`DatabricksClient`]. #[must_use] pub const fn databricks(&self) -> &DatabricksClient { @@ -129,4 +206,67 @@ mod tests { let result = client.completions_client().await; assert!(result.is_ok()); } + + #[tokio::test] + async fn list_models_empty_workspace() { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/api/2.0/serving-endpoints")) + .respond_with( + ResponseTemplate::new(200).set_body_json(serde_json::json!({"endpoints": []})), + ) + .mount(&server) + .await; + + let sdk = DatabricksClient::with_static_token(&server.uri(), "test-token"); + let client = AgentClient::new(sdk); + let models = client.list_models().await.unwrap(); + assert!(models.is_empty()); + } + + #[tokio::test] + async fn list_models_api_error_returns_sdk_error() { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/api/2.0/serving-endpoints")) + .respond_with(ResponseTemplate::new(403).set_body_string("Forbidden")) + .mount(&server) + .await; + + let sdk = DatabricksClient::with_static_token(&server.uri(), "test-token"); + let client = AgentClient::new(sdk); + let result = client.list_models().await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn list_models_excludes_embedding_and_not_ready() { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/api/2.0/serving-endpoints")) + .respond_with(ResponseTemplate::new(200).set_body_json(mock_list_response())) + .mount(&server) + .await; + + let sdk = DatabricksClient::with_static_token(&server.uri(), "test-token"); + let client = AgentClient::new(sdk); + let models = client.list_models().await.unwrap(); + + // Only "chat-model" passes: embed-model is excluded, not-ready-model is excluded + assert_eq!(models.len(), 1); + assert_eq!(models[0].name, "chat-model"); + assert!(models.iter().all(|m| m.name != "embed-model")); + assert!(models.iter().all(|m| m.name != "not-ready-model")); + } + + #[tokio::test] + async fn completions_client_base_url_contains_serving_endpoints() { + // Verify client builds without error when host is a valid URL + let sdk = DatabricksClient::with_static_token("https://my-workspace.databricks.com", "tok"); + let client = AgentClient::new(sdk); + assert!(client.completions_client().await.is_ok()); + } } diff --git a/crates/agent/src/error.rs b/crates/agent/src/error.rs index 925f9eef..5414cd4b 100644 --- a/crates/agent/src/error.rs +++ b/crates/agent/src/error.rs @@ -15,6 +15,9 @@ pub enum AgentError { /// A completion request failed. #[error("completion error: {0}")] Completion(String), + /// Session storage error. + #[error("session error: {0}")] + Session(String), } /// Convenience alias for `Result`. diff --git a/crates/agent/src/lib.rs b/crates/agent/src/lib.rs index de7c5b24..3654a797 100644 --- a/crates/agent/src/lib.rs +++ b/crates/agent/src/lib.rs @@ -1,15 +1,24 @@ //! APX Agent — local agent powered by Databricks-hosted Foundation Models. //! -//! This crate provides model discovery and a rig-based completions client +//! This crate provides model discovery, streaming chat, and session persistence //! for Databricks serving endpoints. +/// Chat domain types — roles, messages, and streaming events. +pub mod chat; /// Agent client for model discovery and completions. pub mod client; /// Error types for the agent crate. pub mod error; /// Model reference types and filtering utilities. pub mod model; +/// Session store trait and domain types. +pub mod session; +/// SQLite-backed session store. +pub mod session_sqlite; +pub use chat::{ChatEvent, ChatMessage, Role, now_secs}; pub use client::AgentClient; pub use error::{AgentError, Result}; pub use model::{ModelRef, chat_models}; +pub use session::{Session, SessionStore}; +pub use session_sqlite::SqliteSessionStore; diff --git a/crates/agent/src/session.rs b/crates/agent/src/session.rs new file mode 100644 index 00000000..8108928e --- /dev/null +++ b/crates/agent/src/session.rs @@ -0,0 +1,42 @@ +//! Persistence layer for chat sessions and message history. + +use std::future::Future; + +use crate::chat::ChatMessage; +use crate::error::Result; + +/// A chat session. +#[derive(Debug, Clone)] +pub struct Session { + /// Unique session identifier. + pub id: String, + /// Name of the model used in this session. + pub model_name: String, + /// Unix timestamp (seconds) when the session was created. + pub created_at: i64, + /// Unix timestamp (seconds) of the last activity. + pub updated_at: i64, +} + +/// Persistence layer for chat sessions and message history. +/// +/// Implementations must be `Send + Sync` for use across async tasks. +pub trait SessionStore: Send + Sync { + /// Create a new session and return it. + fn create_session(&self, model_name: &str) -> impl Future> + Send; + /// Load a session by ID. Returns `None` if not found. + fn get_session(&self, id: &str) -> impl Future>> + Send; + /// List all sessions, most recent first. + fn list_sessions(&self) -> impl Future>> + Send; + /// Append a message to a session. + fn append_message( + &self, + session_id: &str, + msg: &ChatMessage, + ) -> impl Future> + Send; + /// Load all messages for a session in chronological order. + fn load_messages( + &self, + session_id: &str, + ) -> impl Future>> + Send; +} diff --git a/crates/agent/src/session_sqlite.rs b/crates/agent/src/session_sqlite.rs new file mode 100644 index 00000000..8ae8f803 --- /dev/null +++ b/crates/agent/src/session_sqlite.rs @@ -0,0 +1,357 @@ +//! SQLite-backed session store. +//! +//! Stores chat sessions and messages at `~/.apx/agent/db`, +//! following the same pattern as [`apx_db::LogsDb`]. + +use sqlx::Row; +use sqlx::sqlite::{ + SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteSynchronous, +}; +use std::path::Path; +use tracing::debug; + +use crate::chat::{ChatMessage, Role}; +use crate::error::{AgentError, Result}; +use crate::session::{Session, SessionStore}; + +/// SQLite-backed [`SessionStore`]. +#[derive(Debug, Clone)] +pub struct SqliteSessionStore { + pool: SqlitePool, +} + +impl SqliteSessionStore { + /// Open or create the session database at the default location (`~/.apx/agent/db`). + /// + /// # Errors + /// + /// Returns an error if the home directory cannot be determined, the + /// database directory cannot be created, or the database cannot be opened. + pub async fn open() -> Result { + let home = dirs::home_dir() + .ok_or_else(|| AgentError::Session("cannot determine home directory".into()))?; + let path = home.join(".apx").join("agent").join("db"); + Self::open_at(&path).await + } + + /// Open or create the session database at a specific path. + /// + /// # Errors + /// + /// Returns an error if the directory cannot be created, the database + /// cannot be opened, or schema initialization fails. + pub async fn open_at(path: &Path) -> Result { + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent) + .map_err(|e| AgentError::Session(format!("create directory: {e}")))?; + } + + let opts = SqliteConnectOptions::new() + .filename(path) + .create_if_missing(true) + .journal_mode(SqliteJournalMode::Wal) + .synchronous(SqliteSynchronous::Normal); + + let pool = SqlitePoolOptions::new() + .max_connections(5) + .connect_with(opts) + .await + .map_err(|e| AgentError::Session(format!("open database: {e}")))?; + + let store = Self { pool }; + store.init_schema().await?; + Ok(store) + } + + /// Initialize the database schema. + async fn init_schema(&self) -> Result<()> { + for sql in [ + "CREATE TABLE IF NOT EXISTS sessions ( + id TEXT PRIMARY KEY, + model_name TEXT NOT NULL, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL + )", + "CREATE TABLE IF NOT EXISTS messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL REFERENCES sessions(id), + role TEXT NOT NULL, + content TEXT NOT NULL, + created_at INTEGER NOT NULL + )", + "CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id)", + ] { + sqlx::query(sql) + .execute(&self.pool) + .await + .map_err(|e| AgentError::Session(format!("schema init: {e}")))?; + } + debug!("Agent session schema initialized"); + Ok(()) + } +} + +use crate::chat::now_secs; + +impl SessionStore for SqliteSessionStore { + async fn create_session(&self, model_name: &str) -> Result { + let id = nanoid::nanoid!(); + let now = now_secs(); + sqlx::query( + "INSERT INTO sessions (id, model_name, created_at, updated_at) VALUES (?, ?, ?, ?)", + ) + .bind(&id) + .bind(model_name) + .bind(now) + .bind(now) + .execute(&self.pool) + .await + .map_err(|e| AgentError::Session(format!("create session: {e}")))?; + + Ok(Session { + id, + model_name: model_name.into(), + created_at: now, + updated_at: now, + }) + } + + async fn get_session(&self, id: &str) -> Result> { + let row = + sqlx::query("SELECT id, model_name, created_at, updated_at FROM sessions WHERE id = ?") + .bind(id) + .fetch_optional(&self.pool) + .await + .map_err(|e| AgentError::Session(format!("get session: {e}")))?; + + Ok(row.map(|r| Session { + id: r.get("id"), + model_name: r.get("model_name"), + created_at: r.get("created_at"), + updated_at: r.get("updated_at"), + })) + } + + async fn list_sessions(&self) -> Result> { + let rows = sqlx::query( + "SELECT id, model_name, created_at, updated_at FROM sessions ORDER BY updated_at DESC", + ) + .fetch_all(&self.pool) + .await + .map_err(|e| AgentError::Session(format!("list sessions: {e}")))?; + + Ok(rows + .iter() + .map(|r| Session { + id: r.get("id"), + model_name: r.get("model_name"), + created_at: r.get("created_at"), + updated_at: r.get("updated_at"), + }) + .collect()) + } + + async fn append_message(&self, session_id: &str, msg: &ChatMessage) -> Result<()> { + let role_str = msg.role.to_string(); + sqlx::query( + "INSERT INTO messages (session_id, role, content, created_at) VALUES (?, ?, ?, ?)", + ) + .bind(session_id) + .bind(&role_str) + .bind(&msg.content) + .bind(msg.timestamp) + .execute(&self.pool) + .await + .map_err(|e| AgentError::Session(format!("append message: {e}")))?; + + sqlx::query("UPDATE sessions SET updated_at = ? WHERE id = ?") + .bind(msg.timestamp) + .bind(session_id) + .execute(&self.pool) + .await + .map_err(|e| AgentError::Session(format!("update session: {e}")))?; + + Ok(()) + } + + async fn load_messages(&self, session_id: &str) -> Result> { + let rows = sqlx::query( + "SELECT role, content, created_at FROM messages WHERE session_id = ? ORDER BY id ASC", + ) + .bind(session_id) + .fetch_all(&self.pool) + .await + .map_err(|e| AgentError::Session(format!("load messages: {e}")))?; + + rows.iter() + .map(|r| { + let role_str: String = r.get("role"); + let role: Role = role_str + .parse() + .map_err(|e: String| AgentError::Session(e))?; + Ok(ChatMessage { + role, + content: r.get("content"), + timestamp: r.get("created_at"), + }) + }) + .collect() + } +} + +#[cfg(test)] +// Reason: panicking on failure is idiomatic in tests +#[allow(clippy::unwrap_used)] +mod tests { + use super::*; + + async fn temp_store() -> SqliteSessionStore { + let dir = std::env::temp_dir().join(format!( + "apx-agent-test-{}-{}", + std::process::id(), + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_nanos() + )); + std::fs::create_dir_all(&dir).unwrap(); + SqliteSessionStore::open_at(&dir.join("test.db")) + .await + .unwrap() + } + + #[tokio::test] + async fn create_session_generates_unique_ids() { + let store = temp_store().await; + let s1 = store.create_session("model-a").await.unwrap(); + let s2 = store.create_session("model-a").await.unwrap(); + assert_ne!(s1.id, s2.id); + assert_eq!(s1.model_name, "model-a"); + } + + #[tokio::test] + async fn get_session_returns_none_for_missing() { + let store = temp_store().await; + let result = store.get_session("nonexistent").await.unwrap(); + assert!(result.is_none()); + } + + #[tokio::test] + async fn append_and_load_messages_roundtrip() { + let store = temp_store().await; + let session = store.create_session("model-b").await.unwrap(); + + let user_msg = ChatMessage { + role: Role::User, + content: "hello".into(), + timestamp: 1000, + }; + store.append_message(&session.id, &user_msg).await.unwrap(); + + let asst_msg = ChatMessage { + role: Role::Assistant, + content: "hi there".into(), + timestamp: 1001, + }; + store.append_message(&session.id, &asst_msg).await.unwrap(); + + let messages = store.load_messages(&session.id).await.unwrap(); + assert_eq!(messages.len(), 2); + assert_eq!(messages[0].role, Role::User); + assert_eq!(messages[0].content, "hello"); + assert_eq!(messages[1].role, Role::Assistant); + assert_eq!(messages[1].content, "hi there"); + } + + #[tokio::test] + async fn list_sessions_ordered_by_recent() { + let store = temp_store().await; + let s1 = store.create_session("model-a").await.unwrap(); + let s2 = store.create_session("model-b").await.unwrap(); + + // Bump s1's updated_at to a far-future timestamp so it sorts first + let msg = ChatMessage { + role: Role::User, + content: "bump".into(), + timestamp: i64::MAX / 2, + }; + store.append_message(&s1.id, &msg).await.unwrap(); + + let sessions = store.list_sessions().await.unwrap(); + assert_eq!(sessions.len(), 2); + // s1 was bumped to far-future, so it comes first + assert_eq!(sessions[0].id, s1.id); + assert_eq!(sessions[1].id, s2.id); + } + + #[tokio::test] + async fn get_session_returns_some_for_existing() { + let store = temp_store().await; + let created = store.create_session("model-x").await.unwrap(); + let fetched = store.get_session(&created.id).await.unwrap(); + assert!(fetched.is_some()); + let fetched = fetched.unwrap(); + assert_eq!(fetched.id, created.id); + assert_eq!(fetched.model_name, "model-x"); + } + + #[tokio::test] + async fn append_message_updates_session_timestamp() { + let store = temp_store().await; + let session = store.create_session("model-y").await.unwrap(); + let original_updated = session.updated_at; + + let msg = ChatMessage { + role: Role::User, + content: "later".into(), + timestamp: original_updated + 100, + }; + store.append_message(&session.id, &msg).await.unwrap(); + + let refreshed = store.get_session(&session.id).await.unwrap().unwrap(); + assert_eq!(refreshed.updated_at, original_updated + 100); + } + + #[tokio::test] + async fn load_messages_empty_for_new_session() { + let store = temp_store().await; + let session = store.create_session("model-z").await.unwrap(); + let messages = store.load_messages(&session.id).await.unwrap(); + assert!(messages.is_empty()); + } + + #[tokio::test] + async fn load_messages_preserves_insertion_order() { + let store = temp_store().await; + let session = store.create_session("model-order").await.unwrap(); + + // Insert messages with out-of-order timestamps + let msgs = [ + ChatMessage { + role: Role::User, + content: "third-ts".into(), + timestamp: 3000, + }, + ChatMessage { + role: Role::Assistant, + content: "first-ts".into(), + timestamp: 1000, + }, + ChatMessage { + role: Role::User, + content: "second-ts".into(), + timestamp: 2000, + }, + ]; + for msg in &msgs { + store.append_message(&session.id, msg).await.unwrap(); + } + + let loaded = store.load_messages(&session.id).await.unwrap(); + assert_eq!(loaded.len(), 3); + // Order is by insertion (id ASC), not by timestamp + assert_eq!(loaded[0].content, "third-ts"); + assert_eq!(loaded[1].content, "first-ts"); + assert_eq!(loaded[2].content, "second-ts"); + } +} diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index 12a06263..272aaac2 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -14,6 +14,7 @@ categories.workspace = true workspace = true [dependencies] +apx-agent.workspace = true apx-common.workspace = true apx-core = { path = "../core" , version = "0.3.8" } apx-db.workspace = true @@ -35,6 +36,9 @@ chrono.workspace = true reqwest.workspace = true toml.workspace = true toml_edit.workspace = true +ratatui.workspace = true +crossterm.workspace = true +futures-util.workspace = true [dev-dependencies] tempfile.workspace = true diff --git a/crates/cli/src/agent/chat.rs b/crates/cli/src/agent/chat.rs new file mode 100644 index 00000000..a50577e3 --- /dev/null +++ b/crates/cli/src/agent/chat.rs @@ -0,0 +1,98 @@ +//! `apx agent chat` — interactive chat with a Databricks-hosted model. + +use std::path::PathBuf; + +use apx_agent::{AgentClient, SessionStore, SqliteSessionStore}; +use apx_common::EnvProfile; +use clap::Args; +use dialoguer::{Select, theme::ColorfulTheme}; + +use crate::run_cli_async_helper; + +/// Arguments for the `agent chat` command. +#[derive(Args)] +pub struct ChatArgs { + /// Path to the app directory (used for .env profile resolution). + #[arg(value_name = "APP_PATH")] + pub app_path: Option, + + /// Databricks CLI profile name. + #[arg(short = 'p', long = "profile")] + pub profile: Option, + + /// Serving endpoint name to use as the model. + #[arg(short = 'm', long = "model")] + pub model: Option, +} + +/// Entry point for `apx agent chat`. +pub async fn run(args: ChatArgs) -> i32 { + run_cli_async_helper(|| run_inner(args)).await +} + +/// Resolve the Databricks profile from flag, env, .env, or default. +fn resolve_profile(args: &ChatArgs) -> String { + let dotenv_vars = args + .app_path + .as_ref() + .and_then(|dir| apx_core::dotenv::DotenvFile::read(&dir.join(".env")).ok()) + .map(|d| d.get_vars()) + .unwrap_or_default(); + EnvProfile::new(&dotenv_vars).retrieve(args.profile.as_deref()) +} + +async fn run_inner(args: ChatArgs) -> Result<(), String> { + let profile = resolve_profile(&args); + let client = AgentClient::from_profile(&profile) + .await + .map_err(|e| format!("Failed to create agent client: {e}"))?; + + let models = client + .list_models() + .await + .map_err(|e| format!("Failed to list models: {e}"))?; + + if models.is_empty() { + return Err("No chat-capable models found in the workspace".into()); + } + + let model_name = select_model(&args, &models)?; + + let store = SqliteSessionStore::open() + .await + .map_err(|e| format!("Failed to open session store: {e}"))?; + + let session = store + .create_session(&model_name) + .await + .map_err(|e| format!("Failed to create session: {e}"))?; + + super::tui::run(client, store, session, &model_name).await +} + +/// Select a model from the list or validate the explicit flag. +fn select_model(args: &ChatArgs, models: &[apx_agent::ModelRef]) -> Result { + if let Some(name) = &args.model { + if models.iter().any(|m| m.name == *name) { + return Ok(name.clone()); + } + return Err(format!( + "Model '{name}' not found. Available: {}", + models + .iter() + .map(|m| m.name.as_str()) + .collect::>() + .join(", ") + )); + } + + let items: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect(); + let selection = Select::with_theme(&ColorfulTheme::default()) + .with_prompt("Select a model") + .items(&items) + .default(0) + .interact() + .map_err(|e| format!("Model selection failed: {e}"))?; + + Ok(models[selection].name.clone()) +} diff --git a/crates/cli/src/agent/mod.rs b/crates/cli/src/agent/mod.rs new file mode 100644 index 00000000..5b726611 --- /dev/null +++ b/crates/cli/src/agent/mod.rs @@ -0,0 +1,13 @@ +//! Agent commands — interactive chat with Databricks-hosted models. + +pub mod chat; +pub mod tui; + +use clap::Subcommand; + +/// Agent subcommands. +#[derive(Subcommand)] +pub enum AgentCommands { + /// Start an interactive chat session + Chat(chat::ChatArgs), +} diff --git a/crates/cli/src/agent/tui.rs b/crates/cli/src/agent/tui.rs new file mode 100644 index 00000000..c348ef1f --- /dev/null +++ b/crates/cli/src/agent/tui.rs @@ -0,0 +1,359 @@ +//! Ratatui-based TUI for interactive chat. + +use std::time::Duration; + +use apx_agent::{ + AgentClient, ChatEvent, ChatMessage, Role, Session, SessionStore, SqliteSessionStore, now_secs, +}; +use crossterm::event::{Event, EventStream, KeyCode, KeyEvent, KeyModifiers}; +use futures_util::StreamExt; +use ratatui::Frame; +use ratatui::layout::{Constraint, Layout}; +use ratatui::style::{Color, Modifier, Style}; +use ratatui::text::{Line, Span}; +use ratatui::widgets::{Block, Borders, Paragraph, Wrap}; +use tokio::sync::mpsc; + +/// A message displayed in the chat area. +struct DisplayMessage { + role: Role, + content: String, +} + +/// Application state for the TUI. +struct App { + messages: Vec, + input: String, + scroll_offset: u16, + streaming: bool, + should_quit: bool, + model_name: String, + session: Session, +} + +impl App { + fn new(session: Session, model_name: &str) -> Self { + Self { + messages: Vec::new(), + input: String::new(), + scroll_offset: 0, + streaming: false, + should_quit: false, + model_name: model_name.into(), + session, + } + } +} + +/// Launch the TUI event loop. +/// +/// # Errors +/// +/// Returns an error if terminal setup fails or a fatal I/O error occurs. +pub async fn run( + client: AgentClient, + store: SqliteSessionStore, + session: Session, + model_name: &str, +) -> Result<(), String> { + let mut terminal = ratatui::init(); + + // Install a panic hook that restores the terminal before printing the panic. + let original_hook = std::panic::take_hook(); + std::panic::set_hook(Box::new(move |info| { + ratatui::restore(); + original_hook(info); + })); + + let result = run_event_loop(&mut terminal, client, store, session, model_name).await; + + ratatui::restore(); + result +} + +/// Core event loop: reads keys + chat events + tick timer. +async fn run_event_loop( + terminal: &mut ratatui::DefaultTerminal, + client: AgentClient, + store: SqliteSessionStore, + session: Session, + model_name: &str, +) -> Result<(), String> { + let mut app = App::new(session, model_name); + let (chat_tx, mut chat_rx) = mpsc::channel::(64); + let mut event_stream = EventStream::new(); + let mut tick = tokio::time::interval(Duration::from_millis(200)); + + loop { + terminal + .draw(|f| draw(f, &app)) + .map_err(|e| format!("draw error: {e}"))?; + + tokio::select! { + // Terminal events (keyboard input) + maybe_event = event_stream.next() => { + if let Some(Ok(Event::Key(key))) = maybe_event { + handle_key(&mut app, key, &client, &store, &chat_tx).await?; + } + } + // Chat stream events + Some(event) = chat_rx.recv() => { + handle_chat_event(&mut app, &store, event).await?; + } + // Tick for cursor blink / redraw + _ = tick.tick() => {} + } + + if app.should_quit { + break; + } + } + + Ok(()) +} + +/// Handle a keyboard event. +async fn handle_key( + app: &mut App, + key: KeyEvent, + client: &AgentClient, + store: &SqliteSessionStore, + chat_tx: &mpsc::Sender, +) -> Result<(), String> { + match (key.code, key.modifiers) { + (KeyCode::Char('c'), KeyModifiers::CONTROL) => { + app.should_quit = true; + } + (KeyCode::Enter, _) if !app.streaming && !app.input.is_empty() => { + send_message(app, client, store, chat_tx).await?; + } + (KeyCode::Char(c), _) if !app.streaming => { + app.input.push(c); + } + (KeyCode::Backspace, _) if !app.streaming => { + app.input.pop(); + } + (KeyCode::Up, _) => { + app.scroll_offset = app.scroll_offset.saturating_add(1); + } + (KeyCode::Down, _) => { + app.scroll_offset = app.scroll_offset.saturating_sub(1); + } + _ => {} + } + Ok(()) +} + +/// Send the current input as a user message and start streaming the response. +async fn send_message( + app: &mut App, + client: &AgentClient, + store: &SqliteSessionStore, + chat_tx: &mpsc::Sender, +) -> Result<(), String> { + let content = std::mem::take(&mut app.input); + let now = now_secs(); + + // Save and display user message + let user_msg = ChatMessage { + role: Role::User, + content: content.clone(), + timestamp: now, + }; + store + .append_message(&app.session.id, &user_msg) + .await + .map_err(|e| format!("save message: {e}"))?; + app.messages.push(DisplayMessage { + role: Role::User, + content, + }); + + // Add empty assistant placeholder + app.messages.push(DisplayMessage { + role: Role::Assistant, + content: String::new(), + }); + app.streaming = true; + app.scroll_offset = 0; + + // Build history from stored messages for context + let history = store + .load_messages(&app.session.id) + .await + .map_err(|e| format!("load history: {e}"))?; + + // Spawn background streaming task + let client = client.clone(); + let model = app.model_name.clone(); + let message = user_msg.content.clone(); + let tx = chat_tx.clone(); + tokio::spawn(async move { + match client.stream_chat(&model, &message, &history).await { + Ok(mut stream) => { + while let Some(item) = stream.next().await { + match item { + Ok(event) => { + if tx.send(event).await.is_err() { + break; + } + } + Err(e) => { + // Send the error as a Done with error text + let _ = tx.send(ChatEvent::Done(format!("[Error: {e}]"))).await; + break; + } + } + } + } + Err(e) => { + let _ = tx.send(ChatEvent::Done(format!("[Error: {e}]"))).await; + } + } + }); + + Ok(()) +} + +/// Handle a `ChatEvent` from the streaming task. +async fn handle_chat_event( + app: &mut App, + store: &SqliteSessionStore, + event: ChatEvent, +) -> Result<(), String> { + match event { + ChatEvent::Token(text) => { + if let Some(last) = app.messages.last_mut() { + last.content.push_str(&text); + } + } + ChatEvent::Done(full_text) => { + // Ensure the display message has the full text + if let Some(last) = app.messages.last_mut() { + last.content.clone_from(&full_text); + } + app.streaming = false; + + // Persist assistant message + let asst_msg = ChatMessage { + role: Role::Assistant, + content: full_text, + timestamp: now_secs(), + }; + store + .append_message(&app.session.id, &asst_msg) + .await + .map_err(|e| format!("save assistant message: {e}"))?; + } + } + Ok(()) +} + +/// Render the TUI. +fn draw(f: &mut Frame<'_>, app: &App) { + let chunks = Layout::vertical([ + Constraint::Length(1), // status bar + Constraint::Min(5), // messages area + Constraint::Length(3), // input area + ]) + .split(f.area()); + + draw_status_bar(f, app, chunks[0]); + draw_messages(f, app, chunks[1]); + draw_input(f, app, chunks[2]); +} + +/// Draw the top status bar. +fn draw_status_bar(f: &mut Frame<'_>, app: &App, area: ratatui::layout::Rect) { + let status = Line::from(vec![ + Span::styled( + " apx agent ", + Style::default().fg(Color::Black).bg(Color::Yellow), + ), + Span::raw(" "), + Span::styled(&app.model_name, Style::default().fg(Color::Cyan)), + Span::raw(" "), + Span::styled( + format!("Session: {}", truncate_id(&app.session.id)), + Style::default().fg(Color::DarkGray), + ), + if app.streaming { + Span::styled(" streaming...", Style::default().fg(Color::Green)) + } else { + Span::raw("") + }, + ]); + f.render_widget(Paragraph::new(status), area); +} + +/// Draw the messages area. +fn draw_messages(f: &mut Frame<'_>, app: &App, area: ratatui::layout::Rect) { + let mut lines: Vec> = Vec::new(); + for msg in &app.messages { + let (prefix, style) = match msg.role { + Role::User => ( + "You: ", + Style::default() + .fg(Color::Blue) + .add_modifier(Modifier::BOLD), + ), + Role::Assistant => ("AI: ", Style::default().fg(Color::Green)), + }; + + let prefix_span = Span::styled(prefix, style); + let content_lines: Vec<&str> = msg.content.split('\n').collect(); + for (i, line) in content_lines.iter().enumerate() { + if i == 0 { + lines.push(Line::from(vec![prefix_span.clone(), Span::raw(*line)])); + } else { + lines.push(Line::from(vec![Span::raw(" "), Span::raw(*line)])); + } + } + lines.push(Line::raw("")); + } + + let content_height = lines.len().saturating_sub(area.height as usize); + let scroll = if app.scroll_offset == 0 { + content_height as u16 + } else { + content_height.saturating_sub(app.scroll_offset as usize) as u16 + }; + + let block = Block::default() + .borders(Borders::LEFT | Borders::RIGHT) + .border_style(Style::default().fg(Color::DarkGray)); + let paragraph = Paragraph::new(lines) + .block(block) + .wrap(Wrap { trim: false }) + .scroll((scroll, 0)); + f.render_widget(paragraph, area); +} + +/// Draw the input area. +fn draw_input(f: &mut Frame<'_>, app: &App, area: ratatui::layout::Rect) { + let prompt = if app.streaming { + Span::styled(" ... ", Style::default().fg(Color::DarkGray)) + } else { + Span::styled(" > ", Style::default().fg(Color::Yellow)) + }; + let text = Line::from(vec![prompt, Span::raw(&app.input)]); + let block = Block::default() + .borders(Borders::ALL) + .border_style(Style::default().fg(Color::DarkGray)) + .title("Send (Enter) | Quit (Ctrl+C)"); + let paragraph = Paragraph::new(text).block(block); + f.render_widget(paragraph, area); + + // Place cursor at end of input + if !app.streaming { + #[allow(clippy::cast_possible_truncation)] + let cursor_x = area.x + 4 + app.input.len() as u16; + let cursor_y = area.y + 1; + f.set_cursor_position((cursor_x, cursor_y)); + } +} + +/// Truncate a session ID for display. +fn truncate_id(id: &str) -> &str { + if id.len() > 8 { &id[..8] } else { id } +} diff --git a/crates/cli/src/lib.rs b/crates/cli/src/lib.rs index a273aeb6..dce9232e 100644 --- a/crates/cli/src/lib.rs +++ b/crates/cli/src/lib.rs @@ -5,6 +5,7 @@ //! and more. pub(crate) mod __generate_openapi; +pub(crate) mod agent; pub(crate) mod build; pub(crate) mod bun; pub(crate) mod common; @@ -62,6 +63,9 @@ enum Commands { Feedback(feedback::FeedbackArgs), /// ℹ️ Show environment and version info Info(info::InfoArgs), + /// 🤖 Agent commands + #[command(subcommand)] + Agent(agent::AgentCommands), /// ⬆️ Upgrade apx to the latest version Upgrade, /// Internal: generate OpenAPI schema and client @@ -209,6 +213,9 @@ async fn run_command(args: Vec) -> i32 { Some(Commands::Skill(skill_cmd)) => match skill_cmd { SkillCommands::Install(args) => skill::install::run(args).await, }, + Some(Commands::Agent(agent_cmd)) => match agent_cmd { + agent::AgentCommands::Chat(args) => agent::chat::run(args).await, + }, Some(Commands::Feedback(args)) => feedback::run(args).await, Some(Commands::Info(args)) => info::run(args).await, Some(Commands::Upgrade) => upgrade::run().await, From 3e3bd17038fb8368248a61621a5cd742f3c25450 Mon Sep 17 00:00:00 2001 From: renardeinside Date: Tue, 3 Mar 2026 11:45:25 +0100 Subject: [PATCH 5/5] =?UTF-8?q?=E2=9C=A8=20feat:=20add=20slash-command=20f?= =?UTF-8?q?ramework=20for=20agent=20TUI=20(/help,=20/model,=20/exit)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/agent/src/command.rs | 204 +++++++++++++++++++++++++ crates/agent/src/lib.rs | 3 + crates/agent/src/session.rs | 6 + crates/agent/src/session_sqlite.rs | 25 +++ crates/cli/src/agent/commands/mod.rs | 50 ++++++ crates/cli/src/agent/commands/model.rs | 38 +++++ crates/cli/src/agent/mod.rs | 1 + crates/cli/src/agent/tui.rs | 110 +++++++++++-- 8 files changed, 428 insertions(+), 9 deletions(-) create mode 100644 crates/agent/src/command.rs create mode 100644 crates/cli/src/agent/commands/mod.rs create mode 100644 crates/cli/src/agent/commands/model.rs diff --git a/crates/agent/src/command.rs b/crates/agent/src/command.rs new file mode 100644 index 00000000..c9707873 --- /dev/null +++ b/crates/agent/src/command.rs @@ -0,0 +1,204 @@ +//! Slash-command parsing for agent TUI input. +//! +//! This module provides pure parsing types with no I/O or terminal +//! dependencies. Execution lives in the CLI crate. + +use std::fmt; + +/// Normalized command name (lowercase, without the leading `/`). +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CommandName(String); + +impl CommandName { + /// Parse a command name from a raw token (e.g. `"model"` or `"EXIT"`). + /// + /// Returns `None` if the input is empty. + #[must_use] + pub fn new(raw: &str) -> Option { + let trimmed = raw.trim(); + if trimmed.is_empty() { + return None; + } + Some(Self(trimmed.to_lowercase())) + } + + /// The normalized name as a string slice. + #[must_use] + pub fn as_str(&self) -> &str { + &self.0 + } +} + +impl fmt::Display for CommandName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(&self.0) + } +} + +/// Positional arguments following a command name. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CommandArgs(Vec); + +impl CommandArgs { + /// Get the argument at position `n` (zero-based). + #[must_use] + pub fn get(&self, n: usize) -> Option<&str> { + self.0.get(n).map(String::as_str) + } + + /// Number of arguments. + #[must_use] + pub fn len(&self) -> usize { + self.0.len() + } + + /// Returns `true` if there are no arguments. + #[must_use] + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } +} + +/// Result of parsing raw user input. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ParsedInput { + /// Plain text to send to the model. + Message(String), + /// A slash command with optional arguments. + Command { + /// The normalized command name. + name: CommandName, + /// Positional arguments following the command. + args: CommandArgs, + }, +} + +/// Parse raw user input into a [`ParsedInput`]. +/// +/// Input starting with `/` followed by a non-whitespace character is +/// treated as a command; everything else is a plain message. +#[must_use] +pub fn parse_input(raw: &str) -> ParsedInput { + let trimmed = raw.trim(); + + // Must start with `/` and have at least one non-whitespace char after it + if let Some(rest) = trimmed.strip_prefix('/') + && !rest.is_empty() + && !rest.starts_with(char::is_whitespace) + { + let mut parts = rest.split_whitespace(); + if let Some(cmd) = parts.next() + && let Some(name) = CommandName::new(cmd) + { + let args: Vec = parts.map(String::from).collect(); + return ParsedInput::Command { + name, + args: CommandArgs(args), + }; + } + } + + ParsedInput::Message(trimmed.to_string()) +} + +#[cfg(test)] +#[allow(clippy::unwrap_used)] +mod tests { + use super::*; + + #[test] + fn parse_command_no_args() { + let result = parse_input("/model"); + assert_eq!( + result, + ParsedInput::Command { + name: CommandName("model".into()), + args: CommandArgs(vec![]), + } + ); + } + + #[test] + fn parse_command_with_args() { + let result = parse_input("/model foo"); + assert_eq!( + result, + ParsedInput::Command { + name: CommandName("model".into()), + args: CommandArgs(vec!["foo".into()]), + } + ); + } + + #[test] + fn parse_exit_command() { + let result = parse_input("/exit"); + assert_eq!( + result, + ParsedInput::Command { + name: CommandName("exit".into()), + args: CommandArgs(vec![]), + } + ); + } + + #[test] + fn parse_case_insensitive() { + let result = parse_input("/EXIT"); + assert_eq!( + result, + ParsedInput::Command { + name: CommandName("exit".into()), + args: CommandArgs(vec![]), + } + ); + } + + #[test] + fn parse_slash_alone_is_message() { + let result = parse_input("/"); + assert_eq!(result, ParsedInput::Message("/".into())); + } + + #[test] + fn parse_plain_text() { + let result = parse_input("hello world"); + assert_eq!(result, ParsedInput::Message("hello world".into())); + } + + #[test] + fn parse_leading_trailing_whitespace() { + let result = parse_input(" /model "); + assert_eq!( + result, + ParsedInput::Command { + name: CommandName("model".into()), + args: CommandArgs(vec![]), + } + ); + } + + #[test] + fn command_name_display() { + let name = CommandName::new("Model").unwrap(); + assert_eq!(name.to_string(), "model"); + assert_eq!(name.as_str(), "model"); + } + + #[test] + fn command_args_accessors() { + let args = CommandArgs(vec!["a".into(), "b".into()]); + assert_eq!(args.len(), 2); + assert!(!args.is_empty()); + assert_eq!(args.get(0), Some("a")); + assert_eq!(args.get(1), Some("b")); + assert_eq!(args.get(2), None); + } + + #[test] + fn empty_command_args() { + let args = CommandArgs(vec![]); + assert!(args.is_empty()); + assert_eq!(args.len(), 0); + } +} diff --git a/crates/agent/src/lib.rs b/crates/agent/src/lib.rs index 3654a797..5ed10348 100644 --- a/crates/agent/src/lib.rs +++ b/crates/agent/src/lib.rs @@ -7,6 +7,8 @@ pub mod chat; /// Agent client for model discovery and completions. pub mod client; +/// Slash-command parsing for TUI input. +pub mod command; /// Error types for the agent crate. pub mod error; /// Model reference types and filtering utilities. @@ -18,6 +20,7 @@ pub mod session_sqlite; pub use chat::{ChatEvent, ChatMessage, Role, now_secs}; pub use client::AgentClient; +pub use command::{CommandArgs, CommandName, ParsedInput, parse_input}; pub use error::{AgentError, Result}; pub use model::{ModelRef, chat_models}; pub use session::{Session, SessionStore}; diff --git a/crates/agent/src/session.rs b/crates/agent/src/session.rs index 8108928e..7a99d79f 100644 --- a/crates/agent/src/session.rs +++ b/crates/agent/src/session.rs @@ -39,4 +39,10 @@ pub trait SessionStore: Send + Sync { &self, session_id: &str, ) -> impl Future>> + Send; + /// Update the model name for an existing session. + fn update_model( + &self, + session_id: &str, + model_name: &str, + ) -> impl Future> + Send; } diff --git a/crates/agent/src/session_sqlite.rs b/crates/agent/src/session_sqlite.rs index 8ae8f803..5ff9bdc1 100644 --- a/crates/agent/src/session_sqlite.rs +++ b/crates/agent/src/session_sqlite.rs @@ -197,6 +197,18 @@ impl SessionStore for SqliteSessionStore { }) .collect() } + + async fn update_model(&self, session_id: &str, model_name: &str) -> Result<()> { + let now = now_secs(); + sqlx::query("UPDATE sessions SET model_name = ?, updated_at = ? WHERE id = ?") + .bind(model_name) + .bind(now) + .bind(session_id) + .execute(&self.pool) + .await + .map_err(|e| AgentError::Session(format!("update model: {e}")))?; + Ok(()) + } } #[cfg(test)] @@ -320,6 +332,19 @@ mod tests { assert!(messages.is_empty()); } + #[tokio::test] + async fn update_model_changes_model_name() { + let store = temp_store().await; + let session = store.create_session("old-model").await.unwrap(); + assert_eq!(session.model_name, "old-model"); + + store.update_model(&session.id, "new-model").await.unwrap(); + + let updated = store.get_session(&session.id).await.unwrap().unwrap(); + assert_eq!(updated.model_name, "new-model"); + assert!(updated.updated_at >= session.updated_at); + } + #[tokio::test] async fn load_messages_preserves_insertion_order() { let store = temp_store().await; diff --git a/crates/cli/src/agent/commands/mod.rs b/crates/cli/src/agent/commands/mod.rs new file mode 100644 index 00000000..3218d05c --- /dev/null +++ b/crates/cli/src/agent/commands/mod.rs @@ -0,0 +1,50 @@ +//! Slash-command dispatch and execution for the agent TUI. + +mod model; + +use apx_agent::{AgentClient, CommandArgs, CommandName}; + +/// Outcome of executing a slash command. +pub enum CommandOutcome { + /// User wants to quit the session. + Quit, + /// Model was changed to the given name. + ModelChanged(String), + /// Informational text to display. + Info(&'static str), + /// An error occurred while executing the command. + CommandError(String), +} + +/// Read-only references needed by command handlers. +pub struct CommandContext<'a> { + pub client: &'a AgentClient, +} + +/// Whether the given command requires suspending TUI raw mode before +/// execution (e.g. because it uses `dialoguer`). +pub fn needs_terminal_suspend(name: &CommandName) -> bool { + matches!(name.as_str(), "model") +} + +/// Dispatch a parsed command to the appropriate handler. +pub async fn dispatch( + name: &CommandName, + args: &CommandArgs, + ctx: CommandContext<'_>, +) -> CommandOutcome { + match name.as_str() { + "exit" | "quit" => CommandOutcome::Quit, + "model" => model::run(args, ctx).await, + "help" => CommandOutcome::Info(HELP_TEXT), + _ => CommandOutcome::CommandError(format!( + "Unknown command: /{name}. Type /help for available commands." + )), + } +} + +const HELP_TEXT: &str = "\ +Available commands: + /model — Switch to a different model + /help — Show this help message + /exit — Quit the chat session"; diff --git a/crates/cli/src/agent/commands/model.rs b/crates/cli/src/agent/commands/model.rs new file mode 100644 index 00000000..578430c7 --- /dev/null +++ b/crates/cli/src/agent/commands/model.rs @@ -0,0 +1,38 @@ +//! `/model` command — switch the active model mid-session. + +use apx_agent::CommandArgs; +use dialoguer::{Select, theme::ColorfulTheme}; + +use super::{CommandContext, CommandOutcome}; + +/// Run the `/model` command: list models and let the user pick one. +/// +/// Returns [`CommandOutcome::ModelChanged`] with the chosen name; the caller +/// is responsible for persisting the change and updating app state. +pub async fn run(args: &CommandArgs, ctx: CommandContext<'_>) -> CommandOutcome { + if let Some(explicit) = args.get(0) { + return CommandOutcome::ModelChanged(explicit.to_string()); + } + + let models = match ctx.client.list_models().await { + Ok(m) => m, + Err(e) => return CommandOutcome::CommandError(format!("Failed to list models: {e}")), + }; + + if models.is_empty() { + return CommandOutcome::CommandError("No chat-capable models found".into()); + } + + let items: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect(); + let selection = match Select::with_theme(&ColorfulTheme::default()) + .with_prompt("Select a model") + .items(&items) + .default(0) + .interact() + { + Ok(i) => i, + Err(e) => return CommandOutcome::CommandError(format!("Selection cancelled: {e}")), + }; + + CommandOutcome::ModelChanged(models[selection].name.clone()) +} diff --git a/crates/cli/src/agent/mod.rs b/crates/cli/src/agent/mod.rs index 5b726611..a3c96604 100644 --- a/crates/cli/src/agent/mod.rs +++ b/crates/cli/src/agent/mod.rs @@ -1,6 +1,7 @@ //! Agent commands — interactive chat with Databricks-hosted models. pub mod chat; +mod commands; pub mod tui; use clap::Subcommand; diff --git a/crates/cli/src/agent/tui.rs b/crates/cli/src/agent/tui.rs index c348ef1f..a4a181db 100644 --- a/crates/cli/src/agent/tui.rs +++ b/crates/cli/src/agent/tui.rs @@ -3,8 +3,11 @@ use std::time::Duration; use apx_agent::{ - AgentClient, ChatEvent, ChatMessage, Role, Session, SessionStore, SqliteSessionStore, now_secs, + AgentClient, ChatEvent, ChatMessage, ParsedInput, Role, Session, SessionStore, + SqliteSessionStore, now_secs, parse_input, }; + +use super::commands::{self, CommandContext, CommandOutcome}; use crossterm::event::{Event, EventStream, KeyCode, KeyEvent, KeyModifiers}; use futures_util::StreamExt; use ratatui::Frame; @@ -14,9 +17,15 @@ use ratatui::text::{Line, Span}; use ratatui::widgets::{Block, Borders, Paragraph, Wrap}; use tokio::sync::mpsc; +/// The kind of message shown in the chat area. +enum MessageKind { + Chat(Role), + Info, +} + /// A message displayed in the chat area. struct DisplayMessage { - role: Role, + kind: MessageKind, content: String, } @@ -27,6 +36,7 @@ struct App { scroll_offset: u16, streaming: bool, should_quit: bool, + needs_reinit: bool, model_name: String, session: Session, } @@ -39,6 +49,7 @@ impl App { scroll_offset: 0, streaming: false, should_quit: false, + needs_reinit: false, model_name: model_name.into(), session, } @@ -104,6 +115,12 @@ async fn run_event_loop( _ = tick.tick() => {} } + // Re-enter TUI after a command that suspended raw mode (e.g. /model). + if app.needs_reinit { + *terminal = ratatui::init(); + app.needs_reinit = false; + } + if app.should_quit { break; } @@ -125,7 +142,7 @@ async fn handle_key( app.should_quit = true; } (KeyCode::Enter, _) if !app.streaming && !app.input.is_empty() => { - send_message(app, client, store, chat_tx).await?; + handle_enter(app, client, store, chat_tx).await?; } (KeyCode::Char(c), _) if !app.streaming => { app.input.push(c); @@ -144,6 +161,23 @@ async fn handle_key( Ok(()) } +/// Route user input to either the chat stream or the command dispatcher. +async fn handle_enter( + app: &mut App, + client: &AgentClient, + store: &SqliteSessionStore, + chat_tx: &mpsc::Sender, +) -> Result<(), String> { + match parse_input(&app.input) { + ParsedInput::Message(_) => send_message(app, client, store, chat_tx).await, + ParsedInput::Command { name, args } => { + app.input.clear(); + handle_command(app, &name, &args, client, store).await; + Ok(()) + } + } +} + /// Send the current input as a user message and start streaming the response. async fn send_message( app: &mut App, @@ -165,13 +199,13 @@ async fn send_message( .await .map_err(|e| format!("save message: {e}"))?; app.messages.push(DisplayMessage { - role: Role::User, + kind: MessageKind::Chat(Role::User), content, }); // Add empty assistant placeholder app.messages.push(DisplayMessage { - role: Role::Assistant, + kind: MessageKind::Chat(Role::Assistant), content: String::new(), }); app.streaming = true; @@ -249,6 +283,63 @@ async fn handle_chat_event( Ok(()) } +/// Handle a parsed slash command. +/// +/// Commands that need interactive prompts (e.g. `/model`) suspend raw mode +/// so that `dialoguer` can function, then signal the event loop to re-init +/// the terminal. +async fn handle_command( + app: &mut App, + name: &apx_agent::CommandName, + args: &apx_agent::CommandArgs, + client: &AgentClient, + store: &SqliteSessionStore, +) { + let suspended = commands::needs_terminal_suspend(name); + if suspended { + ratatui::restore(); + } + + let ctx = CommandContext { client }; + let outcome = commands::dispatch(name, args, ctx).await; + apply_outcome(app, outcome, store).await; + + app.needs_reinit = suspended; +} + +/// Apply a [`CommandOutcome`] to application state. +/// +/// All app mutations from command results happen here — handlers stay pure. +async fn apply_outcome(app: &mut App, outcome: CommandOutcome, store: &SqliteSessionStore) { + match outcome { + CommandOutcome::Quit => { + app.should_quit = true; + } + CommandOutcome::ModelChanged(name) => { + app.model_name.clone_from(&name); + // Best-effort persist; display the change even if the DB write fails. + let _ = store.update_model(&app.session.id, &name).await; + app.messages.push(DisplayMessage { + kind: MessageKind::Info, + content: format!("Model changed to {name}"), + }); + } + CommandOutcome::Info(text) => { + app.messages.push(DisplayMessage { + kind: MessageKind::Info, + content: text.to_string(), + }); + } + CommandOutcome::CommandError(text) => { + app.messages.push(DisplayMessage { + kind: MessageKind::Info, + content: format!("Error: {text}"), + }); + } + } + app.scroll_offset = 0; +} + /// Render the TUI. fn draw(f: &mut Frame<'_>, app: &App) { let chunks = Layout::vertical([ @@ -290,14 +381,15 @@ fn draw_status_bar(f: &mut Frame<'_>, app: &App, area: ratatui::layout::Rect) { fn draw_messages(f: &mut Frame<'_>, app: &App, area: ratatui::layout::Rect) { let mut lines: Vec> = Vec::new(); for msg in &app.messages { - let (prefix, style) = match msg.role { - Role::User => ( + let (prefix, style) = match &msg.kind { + MessageKind::Chat(Role::User) => ( "You: ", Style::default() .fg(Color::Blue) .add_modifier(Modifier::BOLD), ), - Role::Assistant => ("AI: ", Style::default().fg(Color::Green)), + MessageKind::Chat(Role::Assistant) => ("AI: ", Style::default().fg(Color::Green)), + MessageKind::Info => (" > ", Style::default().fg(Color::Yellow)), }; let prefix_span = Span::styled(prefix, style); @@ -340,7 +432,7 @@ fn draw_input(f: &mut Frame<'_>, app: &App, area: ratatui::layout::Rect) { let block = Block::default() .borders(Borders::ALL) .border_style(Style::default().fg(Color::DarkGray)) - .title("Send (Enter) | Quit (Ctrl+C)"); + .title("Send (Enter) | /help | Quit (Ctrl+C)"); let paragraph = Paragraph::new(text).block(block); f.render_widget(paragraph, area);