diff --git a/Cargo.toml b/Cargo.toml index e2e32fd..666cbf3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,21 +14,12 @@ name = "harbr-router" path = "src/main.rs" [dependencies] -tokio = { version = "1.28", features = ["full"] } -hyper = { version = "0.14", features = ["full"] } -tower = "0.4" -tower-http = { version = "0.4", features = ["trace"] } -serde = { version = "1.0", features = ["derive"] } -serde_yaml = "0.9" -tracing = "0.1" -tracing-subscriber = { version = "0.3", features = ["env-filter"] } -metrics = "0.21" -metrics-exporter-prometheus = "0.12" -anyhow = "1.0" -futures-util = "0.3" -dashmap = "5.4" -bytes = "1.4" -warp = "0.3.7" -uuid = { version = "1.2", features = ["v4"] } -clap = { version = "4.0", features = ["derive"] } -reqwest = { version = "0.12.15", features = ["json"] } \ No newline at end of file +reqwest = { version = "0.12.18", features = ["json", "stream"] } +tokio = { version = "1.45.1", features = ["full"] } +serde = { version = "1.0.219", features = ["derive"] } +async-trait = "0.1.80" +serde_json = "1.0.109" +thiserror = "2.0.12" +anyhow = "1.0.98" +tracing = { version = "0.1.40", features = ["std"] } +warp = { version = "0.3.4" } \ No newline at end of file diff --git a/src/client.rs b/src copy/client.rs similarity index 100% rename from src/client.rs rename to src copy/client.rs diff --git a/src copy/config.rs b/src copy/config.rs new file mode 100644 index 0000000..03fe8ee --- /dev/null +++ b/src copy/config.rs @@ -0,0 +1,305 @@ +// src/config.rs +use anyhow::Result; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::fs; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ProxyConfig { + pub listen_addr: String, + pub routes: HashMap, + pub global_timeout_ms: u64, + pub max_connections: usize, + + // New TCP proxy specific configuration + #[serde(default)] + pub tcp_proxy: TcpProxyConfig, +} + +impl ProxyConfig { + pub fn new(listen_addr: &str, global_timeout_ms: u64, max_connections: 
usize) -> Self { + Self { + listen_addr: listen_addr.to_string(), + routes: HashMap::new(), + global_timeout_ms, + max_connections, + tcp_proxy: TcpProxyConfig::default(), + } + } + + pub fn with_route(mut self, name: &str, route: RouteConfig) -> Self { + self.routes.insert(name.to_string(), route); + self + } + + pub fn with_tcp_proxy(mut self, tcp_proxy: TcpProxyConfig) -> Self { + self.tcp_proxy = tcp_proxy; + self + } + + pub fn enable_tcp_proxy(mut self, enabled: bool) -> Self { + self.tcp_proxy.enabled = enabled; + self + } + + pub fn tcp_listen_addr(mut self, addr: &str) -> Self { + self.tcp_proxy.listen_addr = addr.to_string(); + self + } + + pub fn enable_udp_proxy(mut self, enabled: bool) -> Self { + self.tcp_proxy.udp_enabled = enabled; + self + } + + pub fn udp_listen_addr(mut self, addr: &str) -> Self { + self.tcp_proxy.udp_listen_addr = addr.to_string(); + self + } +} + +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +pub struct TcpProxyConfig { + #[serde(default = "default_tcp_enabled")] + pub enabled: bool, + #[serde(default = "default_tcp_listen_addr")] + pub listen_addr: String, + #[serde(default = "default_tcp_connection_pooling")] + pub connection_pooling: bool, + #[serde(default = "default_tcp_max_idle_time_secs")] + pub max_idle_time_secs: u64, + #[serde(default = "default_udp_enabled")] + pub udp_enabled: bool, + #[serde(default = "default_udp_listen_addr")] + pub udp_listen_addr: String, +} + +impl TcpProxyConfig { + pub fn new() -> Self { + Self::default() + } + + pub fn with_enabled(mut self, enabled: bool) -> Self { + self.enabled = enabled; + self + } + + pub fn with_listen_addr(mut self, addr: &str) -> Self { + self.listen_addr = addr.to_string(); + self + } + + pub fn with_connection_pooling(mut self, enabled: bool) -> Self { + self.connection_pooling = enabled; + self + } + + pub fn with_max_idle_time(mut self, secs: u64) -> Self { + self.max_idle_time_secs = secs; + self + } + + pub fn with_udp_enabled(mut self, enabled: 
bool) -> Self { + self.udp_enabled = enabled; + self + } + + pub fn with_udp_listen_addr(mut self, addr: &str) -> Self { + self.udp_listen_addr = addr.to_string(); + self + } +} + +fn default_tcp_enabled() -> bool { + false +} + +fn default_tcp_listen_addr() -> String { + "0.0.0.0:9090".to_string() +} + +fn default_tcp_connection_pooling() -> bool { + true +} + +fn default_tcp_max_idle_time_secs() -> u64 { + 60 +} + +fn default_udp_enabled() -> bool { + false +} + +fn default_udp_listen_addr() -> String { + "0.0.0.0:9090".to_string() // Same port as TCP by default +} + +#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] +pub struct RouteConfig { + pub upstream: String, + pub timeout_ms: Option<u64>, + pub retry_count: Option<u32>, + #[serde(default)] + pub priority: Option<i32>, + pub preserve_host_header: Option<bool>, + + // TCP-specific configuration + #[serde(default = "default_is_tcp")] + pub is_tcp: bool, + #[serde(default = "default_tcp_port")] + pub tcp_listen_port: Option<u16>, + + // UDP-specific configuration + #[serde(default = "default_is_udp")] + pub is_udp: Option<bool>, + #[serde(default = "default_udp_port")] + pub udp_listen_port: Option<u16>, + + // Database-specific configuration + #[serde(default = "default_db_type")] + pub db_type: Option<String>, +} + +impl RouteConfig { + pub fn new(upstream: &str) -> Self { + Self { + upstream: upstream.to_string(), + timeout_ms: None, + retry_count: None, + priority: None, + preserve_host_header: None, + is_tcp: false, + tcp_listen_port: None, + is_udp: None, + udp_listen_port: None, + db_type: None, + } + } + + pub fn with_timeout(mut self, timeout_ms: u64) -> Self { + self.timeout_ms = Some(timeout_ms); + self + } + + pub fn with_retry_count(mut self, count: u32) -> Self { + self.retry_count = Some(count); + self + } + + pub fn with_priority(mut self, priority: i32) -> Self { + self.priority = Some(priority); + self + } + + pub fn preserve_host_header(mut self, preserve: bool) -> Self { + self.preserve_host_header = Some(preserve); + 
self + } + + pub fn as_tcp(mut self, is_tcp: bool) -> Self { + self.is_tcp = is_tcp; + self + } + + pub fn with_tcp_listen_port(mut self, port: u16) -> Self { + self.tcp_listen_port = Some(port); + self + } + + pub fn as_udp(mut self, is_udp: bool) -> Self { + self.is_udp = Some(is_udp); + self + } + + pub fn with_udp_listen_port(mut self, port: u16) -> Self { + self.udp_listen_port = Some(port); + self + } + + pub fn with_db_type(mut self, db_type: &str) -> Self { + self.db_type = Some(db_type.to_string()); + self + } +} + +fn default_is_tcp() -> bool { + false +} + +fn default_tcp_port() -> Option<u16> { + None +} + +fn default_is_udp() -> Option<bool> { + Some(false) +} + +fn default_udp_port() -> Option<u16> { + None +} + +fn default_db_type() -> Option<String> { + None +} + +pub fn load_config(path: &str) -> Result<ProxyConfig> { + let content = fs::read_to_string(path)?; + let config: ProxyConfig = serde_yaml::from_str(&content)?; + Ok(config) +} + +// Helper function to detect if a route is likely a database +pub fn is_likely_database(route: &RouteConfig) -> bool { + // Check if explicitly marked as TCP + if route.is_tcp { + return true; + } + + // Check if db_type is specified + if route.db_type.is_some() { + return true; + } + + // Basic heuristics for common database port detection + if let Some(port) = extract_port(&route.upstream) { + match port { + 3306 | 33060 => true, // MySQL + 5432 => true, // PostgreSQL + 27017 | 27018 | 27019 => true, // MongoDB + 6379 => true, // Redis + 1521 => true, // Oracle + 1433 => true, // SQL Server + 9042 => true, // Cassandra + 5984 => true, // CouchDB + 8086 => true, // InfluxDB + 9200 | 9300 => true, // Elasticsearch + _ => false, + } + } else { + // Check for database prefixes in the upstream URL + let upstream = route.upstream.to_lowercase(); + upstream.starts_with("mysql://") + || upstream.starts_with("postgresql://") + || upstream.starts_with("mongodb://") + || upstream.starts_with("redis://") + || upstream.starts_with("oracle://") + || 
upstream.starts_with("sqlserver://") + || upstream.starts_with("cassandra://") + || upstream.starts_with("couchdb://") + || upstream.starts_with("influxdb://") + || upstream.starts_with("elasticsearch://") + } +} + +// Helper function to extract port from a URL +fn extract_port(url: &str) -> Option<u16> { + // Parse out protocol + let url_without_protocol = url.split("://").nth(1).unwrap_or(url); + + // Extract host:port part + let host_port = url_without_protocol.split('/').next()?; + + // Extract port + let port_str = host_port.split(':').nth(1)?; + port_str.parse::<u16>().ok() +} \ No newline at end of file diff --git a/src/config_api.rs b/src copy/config_api.rs similarity index 100% rename from src/config_api.rs rename to src copy/config_api.rs diff --git a/src/dynamic_config.rs b/src copy/dynamic_config.rs similarity index 100% rename from src/dynamic_config.rs rename to src copy/dynamic_config.rs diff --git a/src/http_proxy.rs b/src copy/http_proxy.rs similarity index 100% rename from src/http_proxy.rs rename to src copy/http_proxy.rs diff --git a/src copy/lib.rs b/src copy/lib.rs new file mode 100644 index 0000000..76402a1 --- /dev/null +++ b/src copy/lib.rs @@ -0,0 +1,261 @@ +// src/lib.rs +use anyhow::Result; +use std::sync::Arc; +use tokio::sync::{broadcast, mpsc, RwLock}; + +pub mod client; +pub mod config; +pub mod metrics; +pub mod http_proxy; +pub mod tcp_proxy; +pub mod udp_proxy; +pub mod dynamic_config; +pub mod config_api; + +/// The main Router struct that manages all proxy services +pub struct Router { + config_manager: Arc<DynamicConfigManager>, + shutdown_tx: Option<mpsc::Sender<()>>, +} + +impl Router { + /// Create a new Router with the provided configuration manager + pub fn new_with_manager(config_manager: Arc<DynamicConfigManager>) -> Self { + Router { + config_manager, + shutdown_tx: None, + } + } + + /// Create a new Router with the provided configuration + pub fn new(config: config::ProxyConfig) -> Self { + let config_manager = Arc::new(DynamicConfigManager::new(config)); + 
Self::new_with_manager(config_manager) + } + + /// Create a new Router by loading configuration from a file + pub async fn from_file(config_path: &str) -> Result { + let config_manager = DynamicConfigManager::from_file(config_path).await?; + Ok(Router::new_with_manager(Arc::new(config_manager))) + } + + /// Get the configuration manager + pub fn config_manager(&self) -> Arc { + self.config_manager.clone() + } + + /// Start the router service with all enabled proxies + pub async fn start(&mut self) -> Result<()> { + // Initialize metrics + metrics::init_metrics()?; + + // Create a shutdown channel + let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1); + self.shutdown_tx = Some(shutdown_tx); + + // Set up config change listener + let mut config_rx = self.config_manager.subscribe(); + let config_manager = self.config_manager.clone(); + + // Initial configuration + let initial_config = self.config_manager.get_config().read().await.clone(); + + // Start the HTTP proxy server + let http_config = self.config_manager.get_config().clone(); + let http_handle = tokio::spawn(async move { + if let Err(e) = http_proxy::run_server(http_config).await { + tracing::error!("HTTP Server error: {}", e); + } + }); + + // Check for database routes that should be handled as TCP + let has_db_routes = initial_config.routes.iter().any(|(_, route)| { + config::is_likely_database(route) + }); + + // Check for UDP routes + let has_udp_routes = initial_config.routes.iter().any(|(_, route)| { + route.is_udp.unwrap_or(false) + }); + + // Start TCP proxy if enabled or if database routes are detected + let tcp_handle = if initial_config.tcp_proxy.enabled || has_db_routes { + tracing::info!("TCP proxy support enabled"); + let tcp_config = self.config_manager.get_config().clone(); + + let handle = tokio::spawn(async move { + let tcp_proxy = tcp_proxy::TcpProxyServer::new(tcp_config).await; + if let Err(e) = tcp_proxy.run(&initial_config.tcp_proxy.listen_addr).await { + tracing::error!("TCP 
proxy server error: {}", e); + } + }); + Some(handle) + } else { + None + }; + + // Start UDP proxy if enabled or if UDP routes are detected + let udp_handle = if initial_config.tcp_proxy.udp_enabled || has_udp_routes { + tracing::info!("UDP proxy support enabled"); + let udp_config = self.config_manager.get_config().clone(); + + // Use the same address as TCP proxy by default + let udp_listen_addr = initial_config.tcp_proxy.udp_listen_addr.clone(); + + let handle = tokio::spawn(async move { + let udp_proxy = udp_proxy::UdpProxyServer::new(udp_config); + if let Err(e) = udp_proxy.run(&udp_listen_addr).await { + tracing::error!("UDP proxy server error: {}", e); + } + }); + Some(handle) + } else { + None + }; + + // Set up the config API + let api_handle = { + let api_routes = config_api::config_api_routes(self.config_manager.clone()); + + // Create a separate server for the API on a different port + let api_port = 8082; // Could be configurable + let api_addr = format!("0.0.0.0:{}", api_port); + + tracing::info!("Starting configuration API server on {}", api_addr); + + tokio::spawn(async move { + warp::serve(api_routes) + .run(api_addr.parse::().unwrap()) + .await; + }) + }; + + // Listen for configuration changes and handle them + let config_change_handle = tokio::spawn(async move { + tracing::info!("Starting configuration change listener"); + + loop { + tokio::select! 
{ + // Handle shutdown signal + _ = shutdown_rx.recv() => { + tracing::info!("Received shutdown signal, stopping config listener"); + break; + } + + // Handle configuration changes + result = config_rx.recv() => { + match result { + Ok(event) => { + tracing::info!("Received configuration change event: {:?}", event); + + match event { + ConfigEvent::RouteAdded(name, config) => { + tracing::info!("Route added: {}", name); + // No need to restart services, they'll pick up the change + } + ConfigEvent::RouteUpdated(name, config) => { + tracing::info!("Route updated: {}", name); + // No need to restart services, they'll pick up the change + } + ConfigEvent::RouteRemoved(name) => { + tracing::info!("Route removed: {}", name); + // No need to restart services, they'll pick up the change + } + ConfigEvent::TcpConfigUpdated(tcp_config) => { + tracing::warn!("TCP configuration updated - some changes may require restart"); + // Here we could potentially restart TCP services if needed + } + ConfigEvent::GlobalSettingsUpdated { .. 
} => { + tracing::warn!("Global settings updated - some changes may require restart"); + // Here we could potentially restart services if needed + } + ConfigEvent::FullUpdate(_) => { + tracing::warn!("Full configuration replaced - some changes may require restart"); + // Here we could potentially restart all services if needed + } + } + } + Err(e) => { + if matches!(e, tokio::sync::broadcast::error::RecvError::Lagged(_)) { + tracing::warn!("Config listener lagged and missed messages"); + } else { + tracing::error!("Error receiving config change: {}", e); + break; + } + } + } + } + } + } + }); + + // TODO: Add this + // // Optional: start file watcher if config was loaded from file + // if let Some(path) = config_manager.file_path() { + // if let Err(e) = config_manager.start_file_watcher(30).await { + // tracing::error!("Failed to start file watcher: {}", e); + // } + // } + + // Wait for Ctrl+C or other shutdown signal + tokio::signal::ctrl_c().await?; + tracing::info!("Received shutdown signal"); + + // Attempt graceful shutdown + if let Some(tx) = &self.shutdown_tx { + let _ = tx.send(()).await; + } + + Ok(()) + } + + /// Manually trigger a configuration reload from file + pub async fn reload_config(&self) -> Result<()> { + self.config_manager.reload_from_file().await + } + + /// Get the current configuration + pub async fn get_config(&self) -> config::ProxyConfig { + self.config_manager.get_config().read().await.clone() + } + + /// Update a specific route + pub async fn update_route(&self, route_name: &str, route_config: config::RouteConfig) -> Result<()> { + self.config_manager.update_route(route_name, route_config).await + } + + /// Add a new route + pub async fn add_route(&self, route_name: &str, route_config: config::RouteConfig) -> Result<()> { + self.config_manager.add_route(route_name, route_config).await + } + + /// Remove a route + pub async fn remove_route(&self, route_name: &str) -> Result<()> { + self.config_manager.remove_route(route_name).await + 
} + + /// Update TCP proxy configuration + pub async fn update_tcp_config(&self, tcp_config: config::TcpProxyConfig) -> Result<()> { + self.config_manager.update_tcp_config(tcp_config).await + } + + /// Update global settings + pub async fn update_global_settings( + &self, + listen_addr: Option, + global_timeout_ms: Option, + max_connections: Option, + ) -> Result<()> { + self.config_manager.update_global_settings(listen_addr, global_timeout_ms, max_connections).await + } + + /// Replace the entire configuration + pub async fn replace_config(&self, new_config: config::ProxyConfig) -> Result<()> { + self.config_manager.replace_config(new_config).await + } +} + +// Re-export types for easier usage +pub use config::{ProxyConfig, RouteConfig, TcpProxyConfig}; +pub use dynamic_config::{DynamicConfigManager, ConfigEvent}; +pub use client::ConfigClient; \ No newline at end of file diff --git a/src/main.rs b/src copy/main.rs similarity index 100% rename from src/main.rs rename to src copy/main.rs diff --git a/src copy/metrics.rs b/src copy/metrics.rs new file mode 100644 index 0000000..b25c75d --- /dev/null +++ b/src copy/metrics.rs @@ -0,0 +1,8 @@ +use anyhow::Result; +use metrics_exporter_prometheus::PrometheusBuilder; + +pub fn init_metrics() -> Result<()> { + let builder = PrometheusBuilder::new(); + builder.install()?; + Ok(()) +} diff --git a/src/tcp_proxy.rs b/src copy/tcp_proxy.rs similarity index 100% rename from src/tcp_proxy.rs rename to src copy/tcp_proxy.rs diff --git a/src/udp_proxy.rs b/src copy/udp_proxy.rs similarity index 100% rename from src/udp_proxy.rs rename to src copy/udp_proxy.rs diff --git a/src/builtin_plugins/mod.rs b/src/builtin_plugins/mod.rs new file mode 100644 index 0000000..799766b --- /dev/null +++ b/src/builtin_plugins/mod.rs @@ -0,0 +1,516 @@ +// src/builtin_plugins/mod.rs - Built-in plugin implementations +use crate::plugin::*; +use crate::error::{RouterError, Result}; +use async_trait::async_trait; +use serde::{Deserialize, 
Serialize}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; +use reqwest::Client; +use std::time::{Duration, Instant}; +use tokio::net::{TcpStream, UdpSocket}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + +pub mod tcp_proxy; +// pub mod http_proxy; +// pub mod udp_proxy; +// pub mod load_balancer; +// pub mod static_files; + +pub use tcp_proxy::TcpProxyPlugin; +// pub use http_proxy::HttpProxyPlugin; +// pub use udp_proxy::UdpProxyPlugin; +// pub use load_balancer::LoadBalancerPlugin; +// pub use static_files::StaticFilesPlugin; + +/// HTTP Proxy Plugin - handles HTTP/HTTPS requests +#[derive(Debug)] +pub struct HttpProxyPlugin { + name: String, + config: Option<HttpProxyConfig>, + client: Option<Client>, + metrics: Arc<RwLock<PluginMetrics>>, + health: Arc<RwLock<PluginHealth>>, + backends: Arc<RwLock<Vec<BackendInfo>>>, + circuit_breaker: Arc<RwLock<CircuitBreakerState>>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HttpProxyConfig { + pub default_backend: String, + pub timeout_ms: u64, + pub retry_count: u32, + pub preserve_host_header: bool, + pub follow_redirects: bool, + pub max_redirects: u32, + pub buffer_size: usize, + pub connection_pool_size: usize, + pub connection_timeout_ms: u64, + pub read_timeout_ms: u64, + pub write_timeout_ms: u64, + pub enable_compression: bool, + pub custom_headers: HashMap<String, String>, + pub remove_headers: Vec<String>, + pub circuit_breaker: CircuitBreakerConfig, + pub retry_strategy: RetryStrategy, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CircuitBreakerConfig { + pub enabled: bool, + pub failure_threshold: u32, + pub recovery_timeout_ms: u64, + pub half_open_max_calls: u32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum RetryStrategy { + Fixed { delay_ms: u64 }, + Exponential { base_delay_ms: u64, max_delay_ms: u64 }, + Linear { initial_delay_ms: u64, increment_ms: u64 }, +} + +#[derive(Debug, Clone)] +pub struct BackendInfo { + pub url: String, + pub healthy: bool, + pub last_check: Instant, + pub response_time_ms: u64, + pub error_count: u32, +} + +#[derive(Debug, 
Clone)] +pub enum CircuitBreakerState { + Closed, + Open { opened_at: Instant }, + HalfOpen { calls_made: u32 }, +} + +impl HttpProxyPlugin { + pub fn new() -> Self { + Self { + name: "http_proxy".to_string(), + config: None, + client: None, + metrics: Arc::new(RwLock::new(PluginMetrics { + connections_total: 0, + connections_active: 0, + bytes_sent: 0, + bytes_received: 0, + errors_total: 0, + custom_metrics: HashMap::new(), + last_updated: std::time::SystemTime::now(), + })), + health: Arc::new(RwLock::new(PluginHealth { + healthy: false, + message: "Not initialized".to_string(), + last_check: std::time::SystemTime::now(), + response_time_ms: None, + custom_health_data: HashMap::new(), + })), + backends: Arc::new(RwLock::new(Vec::new())), + circuit_breaker: Arc::new(RwLock::new(CircuitBreakerState::Closed)), + } + } + + async fn proxy_request(&self, metadata: RequestMetadata, body: Vec) -> Result { + let config = self.config.as_ref().ok_or_else(|| { + RouterError::PluginError { + plugin: self.name.clone(), + message: "Plugin not initialized".to_string(), + } + })?; + + let client = self.client.as_ref().ok_or_else(|| { + RouterError::PluginError { + plugin: self.name.clone(), + message: "HTTP client not initialized".to_string(), + } + })?; + + // Check circuit breaker + if !self.is_circuit_breaker_closed().await { + return Err(RouterError::CircuitBreakerOpen { + service: self.name.clone(), + }); + } + + let start_time = Instant::now(); + + // Build target URL + let target_url = if let Some(upstream) = &metadata.upstream_url { + upstream.clone() + } else { + format!("{}{}", config.default_backend, metadata.path) + }; + + // Parse URL and add query parameters + let mut url = reqwest::Url::parse(&target_url) + .map_err(|e| RouterError::UpstreamError { + message: format!("Invalid upstream URL: {}", e), + })?; + + // Add query parameters + for (key, value) in &metadata.query_params { + url.query_pairs_mut().append_pair(key, value); + } + + // Create request + let 
method = reqwest::Method::from_bytes(metadata.method.as_bytes()) + .map_err(|e| RouterError::RequestError { + message: format!("Invalid HTTP method: {}", e), + })?; + + let mut request_builder = client.request(method, url) + .timeout(Duration::from_millis(config.timeout_ms)) + .body(body); + + // Add headers + for (key, value) in &metadata.headers { + if !config.remove_headers.contains(key) { + request_builder = request_builder.header(key, value); + } + } + + // Add custom headers + for (key, value) in &config.custom_headers { + request_builder = request_builder.header(key, value); + } + + // Handle host header + if !config.preserve_host_header { + if let Some(host) = target_url.split("://").nth(1).and_then(|s| s.split('/').next()) { + request_builder = request_builder.header("Host", host); + } + } + + // Execute request with retries + let mut last_error = None; + for attempt in 0..=config.retry_count { + match request_builder.try_clone().unwrap().send().await { + Ok(response) => { + let status_code = response.status().as_u16(); + let response_headers: HashMap = response.headers() + .iter() + .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) + .collect(); + + let response_body = response.bytes().await + .map_err(|e| RouterError::UpstreamError { + message: format!("Failed to read response body: {}", e), + })? 
+ .to_vec(); + + // Update metrics + let duration = start_time.elapsed(); + self.update_metrics(response_body.len(), duration).await; + + // Update circuit breaker on success + self.record_success().await; + + return Ok(ProxyResponse { + status_code, + headers: response_headers, + body: response_body, + metadata: HashMap::new(), + upstream_response_time_ms: Some(duration.as_millis() as u64), + cache_headers: None, + }); + } + Err(e) => { + last_error = Some(e); + + if attempt < config.retry_count { + // Apply retry strategy + let delay = self.calculate_retry_delay(&config.retry_strategy, attempt).await; + tokio::time::sleep(delay).await; + } + } + } + } + + // Record failure + self.record_failure().await; + + Err(RouterError::UpstreamError { + message: format!("All retry attempts failed: {}", + last_error.unwrap_or_else(|| reqwest::Error::from(reqwest::ErrorKind::Request))), + }) + } + + async fn is_circuit_breaker_closed(&self) -> bool { + let state = self.circuit_breaker.read().await; + match *state { + CircuitBreakerState::Closed => true, + CircuitBreakerState::Open { opened_at } => { + let config = self.config.as_ref().unwrap(); + if opened_at.elapsed().as_millis() as u64 > config.circuit_breaker.recovery_timeout_ms { + // Transition to half-open + drop(state); + *self.circuit_breaker.write().await = CircuitBreakerState::HalfOpen { calls_made: 0 }; + true + } else { + false + } + } + CircuitBreakerState::HalfOpen { calls_made } => { + let config = self.config.as_ref().unwrap(); + calls_made < config.circuit_breaker.half_open_max_calls + } + } + } + + async fn record_success(&self) { + let mut state = self.circuit_breaker.write().await; + match *state { + CircuitBreakerState::HalfOpen { .. 
} => { + *state = CircuitBreakerState::Closed; + } + _ => {} + } + } + + async fn record_failure(&self) { + let config = self.config.as_ref().unwrap(); + if !config.circuit_breaker.enabled { + return; + } + + let mut state = self.circuit_breaker.write().await; + match *state { + CircuitBreakerState::Closed => { + // Track failures and potentially open circuit + // For simplicity, open immediately - in practice you'd track failure count + *state = CircuitBreakerState::Open { opened_at: Instant::now() }; + } + CircuitBreakerState::HalfOpen { .. } => { + *state = CircuitBreakerState::Open { opened_at: Instant::now() }; + } + _ => {} + } + } + + async fn calculate_retry_delay(&self, strategy: &RetryStrategy, attempt: u32) -> Duration { + match strategy { + RetryStrategy::Fixed { delay_ms } => Duration::from_millis(*delay_ms), + RetryStrategy::Exponential { base_delay_ms, max_delay_ms } => { + let delay = base_delay_ms * 2_u64.pow(attempt); + Duration::from_millis(delay.min(*max_delay_ms)) + } + RetryStrategy::Linear { initial_delay_ms, increment_ms } => { + Duration::from_millis(initial_delay_ms + (increment_ms * attempt as u64)) + } + } + } + + async fn update_metrics(&self, bytes_sent: usize, duration: Duration) { + let mut metrics = self.metrics.write().await; + metrics.connections_total += 1; + metrics.bytes_sent += bytes_sent as u64; + metrics.last_updated = std::time::SystemTime::now(); + + // Update custom metrics + metrics.custom_metrics.insert( + "avg_response_time_ms".to_string(), + duration.as_millis() as f64, + ); + } +} + +#[async_trait] +impl Plugin for HttpProxyPlugin { + fn info(&self) -> PluginInfo { + PluginInfo { + name: self.name.clone(), + version: "1.0.0".to_string(), + description: "HTTP reverse proxy with load balancing and circuit breaker".to_string(), + author: "Router Team".to_string(), + license: "MIT".to_string(), + repository: Some("https://github.com/router/plugins".to_string()), + min_router_version: "2.0.0".to_string(), + config_schema: 
Some(serde_json::json!({ + "type": "object", + "properties": { + "default_backend": {"type": "string"}, + "timeout_ms": {"type": "number"}, + "retry_count": {"type": "number"} + }, + "required": ["default_backend"] + })), + dependencies: vec![], + tags: vec!["http", "proxy", "load-balancer"].iter().map(|s| s.to_string()).collect(), + } + } + + fn capabilities(&self) -> PluginCapabilities { + PluginCapabilities { + can_handle_tcp: true, + can_handle_udp: false, + can_handle_unix_socket: false, + requires_dedicated_port: false, + supports_hot_reload: true, + supports_load_balancing: true, + custom_protocols: vec!["HTTP".to_string(), "HTTPS".to_string()], + port_requirements: vec![], + } + } + + async fn initialize(&mut self, config: PluginConfig, _event_sender: PluginSender) -> Result<()> { + let http_config: HttpProxyConfig = serde_json::from_value(config.config) + .map_err(|e| RouterError::ConfigError { + message: format!("Invalid HTTP proxy configuration: {}", e), + })?; + + // Create HTTP client + let client = Client::builder() + .timeout(Duration::from_millis(http_config.timeout_ms)) + .redirect(if http_config.follow_redirects { + reqwest::redirect::Policy::limited(http_config.max_redirects as usize) + } else { + reqwest::redirect::Policy::none() + }) + .build() + .map_err(|e| RouterError::PluginInitError { + plugin: self.name.clone(), + reason: format!("Failed to create HTTP client: {}", e), + })?; + + self.config = Some(http_config); + self.client = Some(client); + + // Update health status + { + let mut health = self.health.write().await; + health.healthy = true; + health.message = "HTTP proxy initialized successfully".to_string(); + health.last_check = std::time::SystemTime::now(); + } + + Ok(()) + } + + async fn start(&mut self) -> Result<()> { + // HTTP proxy doesn't need to start listeners - it handles requests via handle_event + Ok(()) + } + + async fn stop(&mut self) -> Result<()> { + // Clean up resources + self.client = None; + Ok(()) + } + + async fn 
handle_event(&mut self, event: PluginEvent) -> Result<()> { + match event { + PluginEvent::DataReceived(packet) => { + // Parse HTTP request and proxy it + // This is a simplified implementation - in practice you'd need + // a full HTTP parser and connection handler + tracing::debug!("Received HTTP data: {} bytes", packet.data.len()); + } + PluginEvent::ConnectionEstablished(context) => { + tracing::debug!("HTTP connection established: {}", context.connection_id); + } + PluginEvent::ConnectionClosed(connection_id) => { + tracing::debug!("HTTP connection closed: {}", connection_id); + } + _ => { + // Handle other events as needed + } + } + Ok(()) + } + + async fn health(&self) -> Result { + let health = self.health.read().await.clone(); + Ok(health) + } + + async fn metrics(&self) -> Result { + let metrics = self.metrics.read().await.clone(); + Ok(metrics) + } + + async fn update_config(&mut self, config: PluginConfig) -> Result<()> { + self.initialize(config, mpsc::unbounded_channel().0).await + } + + async fn handle_command(&self, command: &str, args: serde_json::Value) -> Result { + match command { + "get_backends" => { + let backends = self.backends.read().await; + Ok(serde_json::to_value(&*backends)?) + } + "health_check_backend" => { + let backend_url = args.get("url").and_then(|v| v.as_str()) + .ok_or_else(|| RouterError::InvalidApiRequest { + reason: "Missing 'url' parameter".to_string(), + })?; + + // Perform health check + let healthy = self.check_backend_health(backend_url).await?; + Ok(serde_json::json!({ "healthy": healthy })) + } + "circuit_breaker_status" => { + let state = self.circuit_breaker.read().await; + let status = match *state { + CircuitBreakerState::Closed => "closed", + CircuitBreakerState::Open { .. } => "open", + CircuitBreakerState::HalfOpen { .. 
} => "half_open", + }; + Ok(serde_json::json!({ "status": status })) + } + _ => Err(RouterError::NotSupported { + operation: format!("Command: {}", command), + }), + } + } + + async fn shutdown(&mut self) -> Result<()> { + self.stop().await + } +} + +impl HttpProxyPlugin { + async fn check_backend_health(&self, backend_url: &str) -> Result { + if let Some(client) = &self.client { + match client.get(backend_url).timeout(Duration::from_secs(5)).send().await { + Ok(response) => Ok(response.status().is_success()), + Err(_) => Ok(false), + } + } else { + Ok(false) + } + } +} + +/// Plugin factory function for HTTP proxy +#[no_mangle] +pub extern "C" fn create_http_proxy_plugin() -> *mut dyn Plugin { + Box::into_raw(Box::new(HttpProxyPlugin::new())) +} + +/// Plugin info function for HTTP proxy +#[no_mangle] +pub extern "C" fn get_http_proxy_plugin_info() -> PluginInfo { + PluginInfo { + name: "http_proxy".to_string(), + version: "1.0.0".to_string(), + description: "HTTP reverse proxy with advanced features".to_string(), + author: "Router Team".to_string(), + license: "MIT".to_string(), + repository: Some("https://github.com/Harbr-Foundation/Harbr-Router".to_string()), + min_router_version: "2.0.0".to_string(), + config_schema: Some(serde_json::json!({ + "type": "object", + "properties": { + "default_backend": {"type": "string"}, + "timeout_ms": {"type": "number", "default": 30000}, + "retry_count": {"type": "number", "default": 3} + }, + "required": ["default_backend"] + })), + dependencies: vec![], + tags: vec!["http", "proxy", "reverse-proxy"].iter().map(|s| s.to_string()).collect(), + } +} \ No newline at end of file diff --git a/src/builtin_plugins/tcp_proxy.rs b/src/builtin_plugins/tcp_proxy.rs new file mode 100644 index 0000000..48f3add --- /dev/null +++ b/src/builtin_plugins/tcp_proxy.rs @@ -0,0 +1,868 @@ +// src/builtin_plugins/tcp_proxy.rs - TCP Proxy Plugin Implementation +use crate::plugin::*; +use crate::error::{RouterError, Result}; +use 
async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter}; +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::{mpsc, RwLock, Semaphore}; +use tokio::time::timeout; +use tracing::{debug, error, info, warn}; + +/// TCP Proxy Plugin - handles TCP connections with load balancing and connection pooling +#[derive(Debug)] +pub struct TcpProxyPlugin { + name: String, + config: Option, + listener: Option, + metrics: Arc>, + health: Arc>, + connection_pool: Arc, + load_balancer: Arc, + shutdown_sender: Option>, + running: Arc>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TcpProxyConfig { + pub listen_addresses: Vec, + pub backends: Vec, + pub load_balancing_strategy: LoadBalancingStrategy, + pub connection_timeout_ms: u64, + pub read_timeout_ms: u64, + pub write_timeout_ms: u64, + pub max_connections: usize, + pub connection_pool_size: usize, + pub health_check_interval_ms: u64, + pub buffer_size: usize, + pub tcp_nodelay: bool, + pub tcp_keepalive: Option, + pub circuit_breaker: CircuitBreakerConfig, + pub retry_policy: RetryPolicy, + pub session_affinity: SessionAffinityConfig, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TcpBackend { + pub address: String, + pub weight: u32, + pub max_connections: Option, + pub health_check_port: Option, + pub metadata: HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum LoadBalancingStrategy { + RoundRobin, + WeightedRoundRobin, + LeastConnections, + IpHash, + Random, + ConsistentHash, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TcpKeepalive { + pub enabled: bool, + pub idle_secs: u64, + pub interval_secs: u64, + pub retries: u32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CircuitBreakerConfig { + pub enabled: bool, + pub 
failure_threshold: u32, + pub recovery_timeout_ms: u64, + pub half_open_max_calls: u32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RetryPolicy { + pub max_retries: u32, + pub retry_delay_ms: u64, + pub exponential_backoff: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SessionAffinityConfig { + pub enabled: bool, + pub strategy: SessionAffinityStrategy, + pub timeout_secs: u64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum SessionAffinityStrategy { + ClientIp, + Cookie(String), + Header(String), +} + +/// Connection pool for managing backend connections +#[derive(Debug)] +pub struct ConnectionPool { + pools: RwLock>>, + config: TcpProxyConfig, +} + +#[derive(Debug)] +pub struct BackendPool { + address: String, + connections: RwLock>, + semaphore: Semaphore, + health_status: RwLock, +} + +#[derive(Debug)] +pub struct PooledConnection { + stream: TcpStream, + created_at: Instant, + last_used: Instant, + connection_count: u64, +} + +#[derive(Debug, Clone)] +pub struct BackendHealth { + pub healthy: bool, + pub last_check: Instant, + pub consecutive_failures: u32, + pub response_time_ms: u64, +} + +/// Load balancer for selecting backend servers +#[derive(Debug)] +pub struct LoadBalancer { + strategy: LoadBalancingStrategy, + backends: RwLock>, + current_index: RwLock, + connection_counts: RwLock>, + session_store: RwLock>, // session_id -> backend_address +} + +impl TcpProxyPlugin { + pub fn new() -> Self { + Self { + name: "tcp_proxy".to_string(), + config: None, + listener: None, + metrics: Arc::new(RwLock::new(PluginMetrics { + connections_total: 0, + connections_active: 0, + bytes_sent: 0, + bytes_received: 0, + errors_total: 0, + custom_metrics: HashMap::new(), + last_updated: std::time::SystemTime::now(), + })), + health: Arc::new(RwLock::new(PluginHealth { + healthy: false, + message: "Not initialized".to_string(), + last_check: std::time::SystemTime::now(), + response_time_ms: None, + 
custom_health_data: HashMap::new(), + })), + connection_pool: Arc::new(ConnectionPool::new()), + load_balancer: Arc::new(LoadBalancer::new()), + shutdown_sender: None, + running: Arc::new(RwLock::new(false)), + } + } + + async fn start_listeners(&mut self) -> Result<()> { + let config = self.config.as_ref().ok_or_else(|| { + RouterError::PluginError { + plugin: self.name.clone(), + message: "Plugin not initialized".to_string(), + } + })?; + + let (shutdown_tx, mut shutdown_rx) = mpsc::channel(1); + self.shutdown_sender = Some(shutdown_tx); + + for listen_addr in &config.listen_addresses { + let addr: SocketAddr = listen_addr.parse() + .map_err(|e| RouterError::BindError { + address: listen_addr.clone(), + reason: e.to_string(), + })?; + + let listener = TcpListener::bind(addr).await + .map_err(|e| RouterError::BindError { + address: listen_addr.clone(), + reason: e.to_string(), + })?; + + info!("TCP proxy listening on {}", addr); + + // Clone necessary components for the listener task + let connection_pool = self.connection_pool.clone(); + let load_balancer = self.load_balancer.clone(); + let metrics = self.metrics.clone(); + let running = self.running.clone(); + let config_clone = config.clone(); + let mut shutdown_rx_clone = shutdown_rx.resubscribe(); + + tokio::spawn(async move { + loop { + tokio::select! 
{ + // Accept new connections + result = listener.accept() => { + match result { + Ok((stream, client_addr)) => { + debug!("Accepted TCP connection from {}", client_addr); + + // Update metrics + { + let mut metrics = metrics.write().await; + metrics.connections_total += 1; + metrics.connections_active += 1; + } + + // Handle connection in a separate task + let pool = connection_pool.clone(); + let lb = load_balancer.clone(); + let metrics_clone = metrics.clone(); + let config = config_clone.clone(); + + tokio::spawn(async move { + if let Err(e) = Self::handle_connection( + stream, + client_addr, + pool, + lb, + metrics_clone, + config + ).await { + error!("Connection handling error: {}", e); + } + }); + } + Err(e) => { + error!("Failed to accept connection: {}", e); + } + } + } + + // Handle shutdown + _ = shutdown_rx_clone.recv() => { + info!("Shutting down TCP listener on {}", addr); + break; + } + } + + // Check if we should continue running + if !*running.read().await { + break; + } + } + }); + } + + Ok(()) + } + + async fn handle_connection( + client_stream: TcpStream, + client_addr: SocketAddr, + connection_pool: Arc, + load_balancer: Arc, + metrics: Arc>, + config: TcpProxyConfig, + ) -> Result<()> { + let start_time = Instant::now(); + + // Select backend + let backend_addr = load_balancer.select_backend(Some(client_addr)).await + .ok_or_else(|| RouterError::LoadBalancingError { + message: "No healthy backend available".to_string(), + })?; + + // Get backend connection + let backend_stream = connection_pool.get_connection(&backend_addr, &config).await?; + + // Proxy data between client and backend + let result = Self::proxy_streams(client_stream, backend_stream, &config).await; + + // Update metrics + let duration = start_time.elapsed(); + { + let mut metrics = metrics.write().await; + metrics.connections_active = metrics.connections_active.saturating_sub(1); + + if result.is_ok() { + metrics.custom_metrics.insert( + 
"avg_connection_duration_ms".to_string(), + duration.as_millis() as f64, + ); + } else { + metrics.errors_total += 1; + } + } + + result + } + + async fn proxy_streams( + mut client_stream: TcpStream, + mut backend_stream: TcpStream, + config: &TcpProxyConfig, + ) -> Result<()> { + // Configure TCP options + if config.tcp_nodelay { + let _ = client_stream.set_nodelay(true); + let _ = backend_stream.set_nodelay(true); + } + + // Split streams for bidirectional proxying + let (client_read, client_write) = client_stream.split(); + let (backend_read, backend_write) = backend_stream.split(); + + let mut client_reader = BufReader::with_capacity(config.buffer_size, client_read); + let mut client_writer = BufWriter::with_capacity(config.buffer_size, client_write); + let mut backend_reader = BufReader::with_capacity(config.buffer_size, backend_read); + let mut backend_writer = BufWriter::with_capacity(config.buffer_size, backend_write); + + let client_to_backend = async { + let mut buffer = vec![0u8; config.buffer_size]; + loop { + match timeout( + Duration::from_millis(config.read_timeout_ms), + client_reader.read(&mut buffer) + ).await { + Ok(Ok(0)) => break, // Connection closed + Ok(Ok(n)) => { + if let Err(e) = timeout( + Duration::from_millis(config.write_timeout_ms), + backend_writer.write_all(&buffer[..n]) + ).await { + error!("Backend write timeout: {}", e); + break; + } + + if let Err(e) = backend_writer.flush().await { + error!("Backend flush error: {}", e); + break; + } + } + Ok(Err(e)) => { + error!("Client read error: {}", e); + break; + } + Err(_) => { + debug!("Client read timeout"); + break; + } + } + } + }; + + let backend_to_client = async { + let mut buffer = vec![0u8; config.buffer_size]; + loop { + match timeout( + Duration::from_millis(config.read_timeout_ms), + backend_reader.read(&mut buffer) + ).await { + Ok(Ok(0)) => break, // Connection closed + Ok(Ok(n)) => { + if let Err(e) = timeout( + Duration::from_millis(config.write_timeout_ms), + 
client_writer.write_all(&buffer[..n]) + ).await { + error!("Client write timeout: {}", e); + break; + } + + if let Err(e) = client_writer.flush().await { + error!("Client flush error: {}", e); + break; + } + } + Ok(Err(e)) => { + error!("Backend read error: {}", e); + break; + } + Err(_) => { + debug!("Backend read timeout"); + break; + } + } + } + }; + + // Run both directions concurrently + tokio::select! { + _ = client_to_backend => {}, + _ = backend_to_client => {}, + } + + debug!("TCP proxy connection completed"); + Ok(()) + } + + async fn start_health_checker(&self) -> Result<()> { + let config = self.config.as_ref().unwrap(); + let connection_pool = self.connection_pool.clone(); + let health_check_interval = Duration::from_millis(config.health_check_interval_ms); + let running = self.running.clone(); + + tokio::spawn(async move { + let mut interval = tokio::time::interval(health_check_interval); + + while *running.read().await { + interval.tick().await; + + // Health check all backends + connection_pool.health_check_all().await; + } + }); + + Ok(()) + } +} + +#[async_trait] +impl Plugin for TcpProxyPlugin { + fn info(&self) -> PluginInfo { + PluginInfo { + name: self.name.clone(), + version: "1.0.0".to_string(), + description: "TCP proxy with load balancing, connection pooling, and health checks".to_string(), + author: "Router Team".to_string(), + license: "MIT".to_string(), + repository: Some("https://github.com/router/plugins".to_string()), + min_router_version: "2.0.0".to_string(), + config_schema: Some(serde_json::json!({ + "type": "object", + "properties": { + "listen_addresses": { + "type": "array", + "items": {"type": "string"} + }, + "backends": { + "type": "array", + "items": { + "type": "object", + "properties": { + "address": {"type": "string"}, + "weight": {"type": "number", "default": 1} + }, + "required": ["address"] + } + } + }, + "required": ["listen_addresses", "backends"] + })), + dependencies: vec![], + tags: vec!["tcp", "proxy", 
"load-balancer", "connection-pool"].iter().map(|s| s.to_string()).collect(), + } + } + + fn capabilities(&self) -> PluginCapabilities { + PluginCapabilities { + can_handle_tcp: true, + can_handle_udp: false, + can_handle_unix_socket: true, + requires_dedicated_port: true, + supports_hot_reload: true, + supports_load_balancing: true, + custom_protocols: vec!["TCP".to_string()], + port_requirements: vec![], + } + } + + async fn initialize(&mut self, config: PluginConfig, _event_sender: PluginSender) -> Result<()> { + let tcp_config: TcpProxyConfig = serde_json::from_value(config.config) + .map_err(|e| RouterError::ConfigError { + message: format!("Invalid TCP proxy configuration: {}", e), + })?; + + // Initialize connection pool + self.connection_pool.initialize(&tcp_config).await?; + + // Initialize load balancer + self.load_balancer.initialize(&tcp_config).await?; + + self.config = Some(tcp_config); + + // Update health status + { + let mut health = self.health.write().await; + health.healthy = true; + health.message = "TCP proxy initialized successfully".to_string(); + health.last_check = std::time::SystemTime::now(); + } + + Ok(()) + } + + async fn start(&mut self) -> Result<()> { + *self.running.write().await = true; + + // Start listeners + self.start_listeners().await?; + + // Start health checker + self.start_health_checker().await?; + + info!("TCP proxy plugin started"); + Ok(()) + } + + async fn stop(&mut self) -> Result<()> { + *self.running.write().await = false; + + // Send shutdown signal + if let Some(sender) = &self.shutdown_sender { + let _ = sender.send(()).await; + } + + info!("TCP proxy plugin stopped"); + Ok(()) + } + + async fn handle_event(&mut self, event: PluginEvent) -> Result<()> { + match event { + PluginEvent::ConnectionEstablished(context) => { + debug!("TCP connection established: {}", context.connection_id); + } + PluginEvent::ConnectionClosed(connection_id) => { + debug!("TCP connection closed: {}", connection_id); + } + _ => { + // 
Handle other events as needed + } + } + Ok(()) + } + + async fn health(&self) -> Result { + let health = self.health.read().await.clone(); + Ok(health) + } + + async fn metrics(&self) -> Result { + let metrics = self.metrics.read().await.clone(); + Ok(metrics) + } + + async fn update_config(&mut self, config: PluginConfig) -> Result<()> { + self.initialize(config, mpsc::unbounded_channel().0).await + } + + async fn handle_command(&self, command: &str, args: serde_json::Value) -> Result { + match command { + "get_backends" => { + let backends = self.load_balancer.get_backends().await; + Ok(serde_json::to_value(backends)?) + } + "get_pool_stats" => { + let stats = self.connection_pool.get_statistics().await; + Ok(serde_json::to_value(stats)?) + } + "drain_backend" => { + let backend_addr = args.get("address").and_then(|v| v.as_str()) + .ok_or_else(|| RouterError::InvalidApiRequest { + reason: "Missing 'address' parameter".to_string(), + })?; + + self.load_balancer.drain_backend(backend_addr).await?; + Ok(serde_json::json!({ "success": true })) + } + _ => Err(RouterError::NotSupported { + operation: format!("Command: {}", command), + }), + } + } + + async fn shutdown(&mut self) -> Result<()> { + self.stop().await + } +} + +// Implementation for ConnectionPool +impl ConnectionPool { + pub fn new() -> Self { + Self { + pools: RwLock::new(HashMap::new()), + config: TcpProxyConfig { + listen_addresses: vec![], + backends: vec![], + load_balancing_strategy: LoadBalancingStrategy::RoundRobin, + connection_timeout_ms: 30000, + read_timeout_ms: 30000, + write_timeout_ms: 30000, + max_connections: 1000, + connection_pool_size: 100, + health_check_interval_ms: 30000, + buffer_size: 8192, + tcp_nodelay: true, + tcp_keepalive: None, + circuit_breaker: CircuitBreakerConfig { + enabled: false, + failure_threshold: 5, + recovery_timeout_ms: 30000, + half_open_max_calls: 3, + }, + retry_policy: RetryPolicy { + max_retries: 3, + retry_delay_ms: 1000, + exponential_backoff: true, + }, 
+ session_affinity: SessionAffinityConfig { + enabled: false, + strategy: SessionAffinityStrategy::ClientIp, + timeout_secs: 3600, + }, + }, + } + } + + pub async fn initialize(&self, config: &TcpProxyConfig) -> Result<()> { + // Initialize pools for each backend + let mut pools = self.pools.write().await; + pools.clear(); + + for backend in &config.backends { + let pool = Arc::new(BackendPool { + address: backend.address.clone(), + connections: RwLock::new(Vec::new()), + semaphore: Semaphore::new(config.connection_pool_size), + health_status: RwLock::new(BackendHealth { + healthy: true, + last_check: Instant::now(), + consecutive_failures: 0, + response_time_ms: 0, + }), + }); + + pools.insert(backend.address.clone(), pool); + } + + Ok(()) + } + + pub async fn get_connection(&self, backend_addr: &str, config: &TcpProxyConfig) -> Result { + let pools = self.pools.read().await; + let pool = pools.get(backend_addr) + .ok_or_else(|| RouterError::BackendConnectionError { + backend: backend_addr.to_string(), + })?; + + // Try to get a connection from the pool + { + let mut connections = pool.connections.write().await; + if let Some(mut pooled_conn) = connections.pop() { + pooled_conn.last_used = Instant::now(); + pooled_conn.connection_count += 1; + return Ok(pooled_conn.stream); + } + } + + // No pooled connection available, create a new one + let addr: SocketAddr = backend_addr.parse() + .map_err(|e| RouterError::BackendConnectionError { + backend: format!("Invalid address {}: {}", backend_addr, e), + })?; + + let stream = timeout( + Duration::from_millis(config.connection_timeout_ms), + TcpStream::connect(addr) + ).await + .map_err(|_| RouterError::NetworkTimeout { + operation: format!("connect to {}", backend_addr), + })? 
+ .map_err(|e| RouterError::BackendConnectionError { + backend: format!("Failed to connect to {}: {}", backend_addr, e), + })?; + + Ok(stream) + } + + pub async fn health_check_all(&self) { + let pools = self.pools.read().await; + for (addr, pool) in pools.iter() { + self.health_check_backend(addr, pool).await; + } + } + + async fn health_check_backend(&self, addr: &str, pool: &BackendPool) { + let start_time = Instant::now(); + + let health_result = match timeout( + Duration::from_secs(5), + TcpStream::connect(addr) + ).await { + Ok(Ok(_)) => { + let response_time = start_time.elapsed().as_millis() as u64; + Ok(response_time) + } + Ok(Err(e)) => Err(e.to_string()), + Err(_) => Err("Health check timeout".to_string()), + }; + + let mut health = pool.health_status.write().await; + match health_result { + Ok(response_time) => { + health.healthy = true; + health.consecutive_failures = 0; + health.response_time_ms = response_time; + } + Err(_) => { + health.consecutive_failures += 1; + if health.consecutive_failures >= 3 { + health.healthy = false; + } + } + } + health.last_check = Instant::now(); + } + + pub async fn get_statistics(&self) -> HashMap { + let pools = self.pools.read().await; + let mut stats = HashMap::new(); + + for (addr, pool) in pools.iter() { + let connections = pool.connections.read().await; + let health = pool.health_status.read().await; + + stats.insert(addr.clone(), serde_json::json!({ + "connections_count": connections.len(), + "healthy": health.healthy, + "consecutive_failures": health.consecutive_failures, + "response_time_ms": health.response_time_ms, + })); + } + + stats + } +} + +// Implementation for LoadBalancer +impl LoadBalancer { + pub fn new() -> Self { + Self { + strategy: LoadBalancingStrategy::RoundRobin, + backends: RwLock::new(Vec::new()), + current_index: RwLock::new(0), + connection_counts: RwLock::new(HashMap::new()), + session_store: RwLock::new(HashMap::new()), + } + } + + pub async fn initialize(&self, config: 
&TcpProxyConfig) -> Result<()> { + self.strategy = config.load_balancing_strategy.clone(); + + let mut backends = self.backends.write().await; + *backends = config.backends.clone(); + + Ok(()) + } + + pub async fn select_backend(&self, client_addr: Option) -> Option { + let backends = self.backends.read().await; + if backends.is_empty() { + return None; + } + + match self.strategy { + LoadBalancingStrategy::RoundRobin => { + let mut index = self.current_index.write().await; + let backend = &backends[*index % backends.len()]; + *index = (*index + 1) % backends.len(); + Some(backend.address.clone()) + } + LoadBalancingStrategy::Random => { + use rand::Rng; + let mut rng = rand::thread_rng(); + let index = rng.gen_range(0..backends.len()); + Some(backends[index].address.clone()) + } + LoadBalancingStrategy::IpHash => { + if let Some(client) = client_addr { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + let mut hasher = DefaultHasher::new(); + client.ip().hash(&mut hasher); + let hash = hasher.finish() as usize; + + let index = hash % backends.len(); + Some(backends[index].address.clone()) + } else { + // Fallback to round robin + let mut index = self.current_index.write().await; + let backend = &backends[*index % backends.len()]; + *index = (*index + 1) % backends.len(); + Some(backend.address.clone()) + } + } + // Add other load balancing strategies as needed + _ => { + // Default to round robin + let mut index = self.current_index.write().await; + let backend = &backends[*index % backends.len()]; + *index = (*index + 1) % backends.len(); + Some(backend.address.clone()) + } + } + } + + pub async fn get_backends(&self) -> Vec { + self.backends.read().await.clone() + } + + pub async fn drain_backend(&self, backend_addr: &str) -> Result<()> { + // Implementation would mark backend as draining + // and prevent new connections while allowing existing ones to finish + warn!("Backend draining not fully implemented: {}", backend_addr); 
+ Ok(()) + } +} + +/// Plugin factory function for TCP proxy +#[no_mangle] +pub extern "C" fn create_tcp_proxy_plugin() -> *mut dyn Plugin { + Box::into_raw(Box::new(TcpProxyPlugin::new())) +} + +/// Plugin info function for TCP proxy +#[no_mangle] +pub extern "C" fn get_tcp_proxy_plugin_info() -> PluginInfo { + PluginInfo { + name: "tcp_proxy".to_string(), + version: "1.0.0".to_string(), + description: "TCP proxy with advanced load balancing and connection pooling".to_string(), + author: "Router Team".to_string(), + license: "MIT".to_string(), + repository: Some("https://github.com/router/plugins".to_string()), + min_router_version: "2.0.0".to_string(), + config_schema: Some(serde_json::json!({ + "type": "object", + "properties": { + "listen_addresses": { + "type": "array", + "items": {"type": "string"} + }, + "backends": { + "type": "array", + "items": { + "type": "object", + "properties": { + "address": {"type": "string"}, + "weight": {"type": "number", "default": 1} + }, + "required": ["address"] + } + } + }, + "required": ["listen_addresses", "backends"] + })), + dependencies: vec![], + tags: vec!["tcp", "proxy", "load-balancer"].iter().map(|s| s.to_string()).collect(), + } +} \ No newline at end of file diff --git a/src/config.rs b/src/config.rs index 03fe8ee..769f434 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,305 +1,752 @@ -// src/config.rs -use anyhow::Result; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; -use std::fs; - -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct ProxyConfig { - pub listen_addr: String, - pub routes: HashMap, - pub global_timeout_ms: u64, - pub max_connections: usize, - - // New TCP proxy specific configuration - #[serde(default)] - pub tcp_proxy: TcpProxyConfig, -} - -impl ProxyConfig { - pub fn new(listen_addr: &str, global_timeout_ms: u64, max_connections: usize) -> Self { - Self { - listen_addr: listen_addr.to_string(), - routes: HashMap::new(), - global_timeout_ms, - max_connections, 
- tcp_proxy: TcpProxyConfig::default(), - } - } - - pub fn with_route(mut self, name: &str, route: RouteConfig) -> Self { - self.routes.insert(name.to_string(), route); - self - } - - pub fn with_tcp_proxy(mut self, tcp_proxy: TcpProxyConfig) -> Self { - self.tcp_proxy = tcp_proxy; - self - } - - pub fn enable_tcp_proxy(mut self, enabled: bool) -> Self { - self.tcp_proxy.enabled = enabled; - self - } - - pub fn tcp_listen_addr(mut self, addr: &str) -> Self { - self.tcp_proxy.listen_addr = addr.to_string(); - self - } - - pub fn enable_udp_proxy(mut self, enabled: bool) -> Self { - self.tcp_proxy.udp_enabled = enabled; - self - } - - pub fn udp_listen_addr(mut self, addr: &str) -> Self { - self.tcp_proxy.udp_listen_addr = addr.to_string(); - self - } -} - -#[derive(Debug, Serialize, Deserialize, Clone, Default)] -pub struct TcpProxyConfig { - #[serde(default = "default_tcp_enabled")] - pub enabled: bool, - #[serde(default = "default_tcp_listen_addr")] - pub listen_addr: String, - #[serde(default = "default_tcp_connection_pooling")] - pub connection_pooling: bool, - #[serde(default = "default_tcp_max_idle_time_secs")] - pub max_idle_time_secs: u64, - #[serde(default = "default_udp_enabled")] - pub udp_enabled: bool, - #[serde(default = "default_udp_listen_addr")] - pub udp_listen_addr: String, -} - -impl TcpProxyConfig { - pub fn new() -> Self { - Self::default() - } - - pub fn with_enabled(mut self, enabled: bool) -> Self { - self.enabled = enabled; - self - } - - pub fn with_listen_addr(mut self, addr: &str) -> Self { - self.listen_addr = addr.to_string(); - self - } - - pub fn with_connection_pooling(mut self, enabled: bool) -> Self { - self.connection_pooling = enabled; - self - } - - pub fn with_max_idle_time(mut self, secs: u64) -> Self { - self.max_idle_time_secs = secs; - self - } - - pub fn with_udp_enabled(mut self, enabled: bool) -> Self { - self.udp_enabled = enabled; - self - } - - pub fn with_udp_listen_addr(mut self, addr: &str) -> Self { - 
self.udp_listen_addr = addr.to_string(); - self - } -} - -fn default_tcp_enabled() -> bool { - false -} - -fn default_tcp_listen_addr() -> String { - "0.0.0.0:9090".to_string() -} - -fn default_tcp_connection_pooling() -> bool { - true -} - -fn default_tcp_max_idle_time_secs() -> u64 { - 60 -} - -fn default_udp_enabled() -> bool { - false -} - -fn default_udp_listen_addr() -> String { - "0.0.0.0:9090".to_string() // Same port as TCP by default -} - -#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] -pub struct RouteConfig { - pub upstream: String, - pub timeout_ms: Option, - pub retry_count: Option, - #[serde(default)] - pub priority: Option, - pub preserve_host_header: Option, - - // TCP-specific configuration - #[serde(default = "default_is_tcp")] - pub is_tcp: bool, - #[serde(default = "default_tcp_port")] - pub tcp_listen_port: Option, - - // UDP-specific configuration - #[serde(default = "default_is_udp")] - pub is_udp: Option, - #[serde(default = "default_udp_port")] - pub udp_listen_port: Option, - - // Database-specific configuration - #[serde(default = "default_db_type")] - pub db_type: Option, -} - -impl RouteConfig { - pub fn new(upstream: &str) -> Self { - Self { - upstream: upstream.to_string(), - timeout_ms: None, - retry_count: None, - priority: None, - preserve_host_header: None, - is_tcp: false, - tcp_listen_port: None, - is_udp: None, - udp_listen_port: None, - db_type: None, - } - } - - pub fn with_timeout(mut self, timeout_ms: u64) -> Self { - self.timeout_ms = Some(timeout_ms); - self - } - - pub fn with_retry_count(mut self, count: u32) -> Self { - self.retry_count = Some(count); - self - } - - pub fn with_priority(mut self, priority: i32) -> Self { - self.priority = Some(priority); - self - } - - pub fn preserve_host_header(mut self, preserve: bool) -> Self { - self.preserve_host_header = Some(preserve); - self - } - - pub fn as_tcp(mut self, is_tcp: bool) -> Self { - self.is_tcp = is_tcp; - self - } - - pub fn 
with_tcp_listen_port(mut self, port: u16) -> Self { - self.tcp_listen_port = Some(port); - self - } - - pub fn as_udp(mut self, is_udp: bool) -> Self { - self.is_udp = Some(is_udp); - self - } - - pub fn with_udp_listen_port(mut self, port: u16) -> Self { - self.udp_listen_port = Some(port); - self - } - - pub fn with_db_type(mut self, db_type: &str) -> Self { - self.db_type = Some(db_type.to_string()); - self - } -} - -fn default_is_tcp() -> bool { - false -} - -fn default_tcp_port() -> Option { - None -} - -fn default_is_udp() -> Option { - Some(false) -} - -fn default_udp_port() -> Option { - None -} - -fn default_db_type() -> Option { - None -} - -pub fn load_config(path: &str) -> Result { - let content = fs::read_to_string(path)?; - let config: ProxyConfig = serde_yaml::from_str(&content)?; - Ok(config) -} - -// Helper function to detect if a route is likely a database -pub fn is_likely_database(route: &RouteConfig) -> bool { - // Check if explicitly marked as TCP - if route.is_tcp { - return true; - } - - // Check if db_type is specified - if route.db_type.is_some() { - return true; - } - - // Basic heuristics for common database port detection - if let Some(port) = extract_port(&route.upstream) { - match port { - 3306 | 33060 => true, // MySQL - 5432 => true, // PostgreSQL - 27017 | 27018 | 27019 => true, // MongoDB - 6379 => true, // Redis - 1521 => true, // Oracle - 1433 => true, // SQL Server - 9042 => true, // Cassandra - 5984 => true, // CouchDB - 8086 => true, // InfluxDB - 9200 | 9300 => true, // Elasticsearch - _ => false, - } - } else { - // Check for database prefixes in the upstream URL - let upstream = route.upstream.to_lowercase(); - upstream.starts_with("mysql://") - || upstream.starts_with("postgresql://") - || upstream.starts_with("mongodb://") - || upstream.starts_with("redis://") - || upstream.starts_with("oracle://") - || upstream.starts_with("sqlserver://") - || upstream.starts_with("cassandra://") - || upstream.starts_with("couchdb://") 
- || upstream.starts_with("influxdb://") - || upstream.starts_with("elasticsearch://") - } -} - -// Helper function to extract port from a URL -fn extract_port(url: &str) -> Option { - // Parse out protocol - let url_without_protocol = url.split("://").nth(1).unwrap_or(url); - - // Extract host:port part - let host_port = url_without_protocol.split('/').next()?; - - // Extract port - let port_str = host_port.split(':').nth(1)?; - port_str.parse::().ok() +// src/config.rs - JSON configuration structure +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::path::Path; +use anyhow::{Context, Result}; + +/// Main router configuration - now JSON-based with plugins +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RouterConfig { + pub version: String, + pub server: ServerConfig, + pub plugins: PluginSystemConfig, + pub proxies: Vec, + pub domains: HashMap, + pub load_balancing: HashMap, + pub middleware: Vec, + pub monitoring: MonitoringConfig, + pub security: SecurityConfig, + pub logging: LoggingConfig, + + // Legacy compatibility - these will be converted to plugin instances + #[serde(default)] + pub legacy_routes: HashMap, + #[serde(default)] + pub legacy_tcp_proxy: Option, +} + +impl RouterConfig { + /// Load configuration from JSON file + pub async fn load_from_file>(path: P) -> Result { + let content = tokio::fs::read_to_string(path.as_ref()).await + .with_context(|| format!("Failed to read config file: {}", path.as_ref().display()))?; + + let mut config: RouterConfig = serde_json::from_str(&content) + .with_context(|| "Failed to parse JSON configuration")?; + + // Convert legacy configuration to plugin instances + config.convert_legacy_config().await?; + + // Validate configuration + config.validate().await?; + + Ok(config) + } + + /// Save configuration to JSON file + pub async fn save_to_file>(&self, path: P) -> Result<()> { + let json = serde_json::to_string_pretty(self) + .context("Failed to serialize configuration")?; + + 
tokio::fs::write(path.as_ref(), json).await + .with_context(|| format!("Failed to write config file: {}", path.as_ref().display()))?; + + Ok(()) + } + + /// Convert legacy YAML-style configuration to plugin instances + async fn convert_legacy_config(&mut self) -> Result<()> { + // Convert legacy HTTP routes to HTTP plugin instances + if !self.legacy_routes.is_empty() { + let http_plugin = ProxyInstanceConfig { + name: "legacy_http_proxy".to_string(), + plugin_type: "http_proxy".to_string(), + enabled: true, + priority: 0, + ports: vec![8080], // Default port + bind_addresses: vec!["0.0.0.0".to_string()], + domains: self.legacy_routes.keys().cloned().collect(), + plugin_config: serde_json::to_value(&self.legacy_routes)?, + middleware: vec![], + load_balancing: None, + health_check: Some(HealthCheckConfig { + enabled: true, + interval_seconds: 30, + timeout_seconds: 10, + path: Some("/health".to_string()), + expected_status: Some(200), + custom_check: None, + }), + circuit_breaker: None, + rate_limiting: None, + ssl_config: None, + }; + + self.proxies.push(http_plugin); + self.legacy_routes.clear(); + } + + // Convert legacy TCP proxy to TCP plugin instance + if let Some(tcp_config) = &self.legacy_tcp_proxy { + let tcp_plugin = ProxyInstanceConfig { + name: "legacy_tcp_proxy".to_string(), + plugin_type: "tcp_proxy".to_string(), + enabled: tcp_config.enabled, + priority: 100, + ports: vec![tcp_config.listen_port], + bind_addresses: vec![tcp_config.listen_addr.clone()], + domains: vec![], + plugin_config: serde_json::to_value(tcp_config)?, + middleware: vec![], + load_balancing: None, + health_check: Some(HealthCheckConfig { + enabled: true, + interval_seconds: 60, + timeout_seconds: 5, + path: None, + expected_status: None, + custom_check: Some(serde_json::json!({ + "type": "tcp_connect" + })), + }), + circuit_breaker: None, + rate_limiting: None, + ssl_config: None, + }; + + self.proxies.push(tcp_plugin); + self.legacy_tcp_proxy = None; + } + + Ok(()) + } + + /// 
Validate the configuration + async fn validate(&self) -> Result<()> { + // Validate server config + if self.server.listen_addresses.is_empty() { + return Err(anyhow::anyhow!("At least one listen address must be specified")); + } + + // Validate proxy instances + let mut used_ports = std::collections::HashSet::new(); + for proxy in &self.proxies { + for port in &proxy.ports { + if used_ports.contains(port) { + return Err(anyhow::anyhow!("Port {} is used by multiple proxies", port)); + } + used_ports.insert(*port); + } + } + + // Validate domain mappings + for (domain, config) in &self.domains { + if !self.proxies.iter().any(|p| p.domains.contains(domain)) { + tracing::warn!("Domain '{}' is not handled by any proxy", domain); + } + } + + Ok(()) + } + + /// Create a default configuration + pub fn default() -> Self { + Self { + version: "2.0.0".to_string(), + server: ServerConfig::default(), + plugins: PluginSystemConfig::default(), + proxies: vec![], + domains: HashMap::new(), + load_balancing: HashMap::new(), + middleware: vec![], + monitoring: MonitoringConfig::default(), + security: SecurityConfig::default(), + logging: LoggingConfig::default(), + legacy_routes: HashMap::new(), + legacy_tcp_proxy: None, + } + } +} + +/// Server configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ServerConfig { + pub listen_addresses: Vec, + pub management_port: u16, + pub health_check_port: u16, + pub worker_threads: Option, + pub max_connections: usize, + pub connection_timeout_seconds: u64, + pub graceful_shutdown_timeout_seconds: u64, + pub enable_http2: bool, + pub enable_websockets: bool, + pub request_id_header: String, + pub server_tokens: bool, +} + +impl Default for ServerConfig { + fn default() -> Self { + Self { + listen_addresses: vec!["0.0.0.0:8080".to_string()], + management_port: 8081, + health_check_port: 8082, + worker_threads: None, + max_connections: 10000, + connection_timeout_seconds: 30, + graceful_shutdown_timeout_seconds: 30, + 
enable_http2: true,
+            enable_websockets: true,
+            request_id_header: "X-Request-ID".to_string(),
+            server_tokens: false,
+        }
+    }
+}
+
+/// Plugin system configuration
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct PluginSystemConfig {
+    pub plugin_directories: Vec<String>,
+    pub auto_reload: bool,
+    pub reload_interval_seconds: u64,
+    pub max_plugin_memory_mb: usize,
+    pub plugin_timeout_seconds: u64,
+    pub enable_plugin_isolation: bool,
+    pub allowed_plugins: Option<Vec<String>>,
+    pub blocked_plugins: Vec<String>,
+    pub require_signature: bool,
+    pub signature_key_path: Option<String>,
+    pub max_concurrent_loads: usize,
+    pub health_check_interval_seconds: u64,
+    pub metrics_collection_interval_seconds: u64,
+    pub enable_inter_plugin_communication: bool,
+}
+
+impl Default for PluginSystemConfig {
+    fn default() -> Self {
+        Self {
+            plugin_directories: vec!["./plugins".to_string()],
+            auto_reload: false,
+            reload_interval_seconds: 60,
+            max_plugin_memory_mb: 100,
+            plugin_timeout_seconds: 30,
+            enable_plugin_isolation: true,
+            allowed_plugins: None,
+            blocked_plugins: vec![],
+            require_signature: false,
+            signature_key_path: None,
+            max_concurrent_loads: 10,
+            health_check_interval_seconds: 30,
+            metrics_collection_interval_seconds: 60,
+            enable_inter_plugin_communication: true,
+        }
+    }
+}
+
+/// Proxy instance configuration - represents a loaded plugin instance
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ProxyInstanceConfig {
+    pub name: String,
+    pub plugin_type: String, // Which plugin to use
+    pub enabled: bool,
+    pub priority: i32,
+    pub ports: Vec<u16>,
+    pub bind_addresses: Vec<String>,
+    pub domains: Vec<String>,
+    pub plugin_config: serde_json::Value, // Plugin-specific configuration
+    pub middleware: Vec<String>,
+    pub load_balancing: Option<String>, // Reference to load_balancing config
+    pub health_check: Option<HealthCheckConfig>,
+    pub circuit_breaker: Option<CircuitBreakerConfig>,
+    pub rate_limiting: Option<RateLimitingConfig>,
+    pub ssl_config: Option<SslConfig>,
+}
+
+/// Domain configuration
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct 
DomainConfig {
+    pub proxy_instance: String, // Which proxy instance handles this domain
+    pub backend_service: Option<String>, // For plugin configuration
+    pub ssl_config: Option<SslConfig>,
+    pub cors_config: Option<CorsConfig>,
+    pub cache_config: Option<CacheConfig>,
+    pub custom_headers: HashMap<String, String>,
+    pub rewrite_rules: Vec<RewriteRule>,
+    pub access_control: Option<AccessControlConfig>,
+}
+
+/// Load balancing configuration
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct LoadBalancingConfig {
+    pub strategy: String, // round_robin, weighted, least_connections, ip_hash, custom
+    pub backends: Vec<BackendConfig>,
+    pub session_affinity: bool,
+    pub health_check: HealthCheckConfig,
+    pub failover: FailoverConfig,
+    pub custom_config: serde_json::Value, // Plugin-specific LB config
+}
+
+/// Backend server configuration
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct BackendConfig {
+    pub address: String,
+    pub weight: Option<u32>,
+    pub max_connections: Option<usize>,
+    pub backup: bool,
+    pub metadata: HashMap<String, String>,
+    pub ssl_config: Option<SslConfig>,
+    pub health_check_override: Option<HealthCheckConfig>,
+}
+
+/// Health check configuration
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct HealthCheckConfig {
+    pub enabled: bool,
+    pub interval_seconds: u64,
+    pub timeout_seconds: u64,
+    pub path: Option<String>, // For HTTP health checks
+    pub expected_status: Option<u16>,
+    pub custom_check: Option<serde_json::Value>, // Plugin-specific health check
+}
+
+/// Circuit breaker configuration
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct CircuitBreakerConfig {
+    pub enabled: bool,
+    pub failure_threshold: u32,
+    pub success_threshold: u32,
+    pub timeout_seconds: u64,
+    pub half_open_max_calls: u32,
+    pub metrics_window_seconds: u64,
+}
+
+/// Rate limiting configuration
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct RateLimitingConfig {
+    pub enabled: bool,
+    pub requests_per_second: u32,
+    pub burst_size: u32,
+    pub per_ip: bool,
+    pub per_domain: bool,
+    pub custom_key: Option<String>, // Custom rate limiting key
+    pub whitelist: Vec<String>,
+    pub blacklist: Vec<String>,
+}
+
+/// 
SSL/TLS configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SslConfig { + pub enabled: bool, + pub cert_path: String, + pub key_path: String, + pub ca_path: Option, + pub protocols: Vec, + pub ciphers: Vec, + pub client_cert_required: bool, + pub verify_client: bool, + pub sni_callback: Option, // Plugin callback for SNI +} + +/// CORS configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CorsConfig { + pub enabled: bool, + pub allowed_origins: Vec, + pub allowed_methods: Vec, + pub allowed_headers: Vec, + pub exposed_headers: Vec, + pub max_age_seconds: u64, + pub allow_credentials: bool, +} + +/// Cache configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CacheConfig { + pub enabled: bool, + pub ttl_seconds: u64, + pub max_size_mb: usize, + pub cache_key_headers: Vec, + pub vary_headers: Vec, + pub bypass_headers: Vec, + pub custom_cache_logic: Option, // Plugin callback +} + +/// URL rewrite rule +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RewriteRule { + pub pattern: String, + pub replacement: String, + pub flags: Vec, // redirect, last, etc. 
+ pub conditions: Vec, +} + +/// Rewrite condition +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RewriteCondition { + pub test_string: String, + pub condition_pattern: String, + pub flags: Vec, +} + +/// Access control configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AccessControlConfig { + pub allowed_ips: Vec, + pub blocked_ips: Vec, + pub allowed_countries: Vec, + pub blocked_countries: Vec, + pub custom_rules: Vec, +} + +/// Access rule +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AccessRule { + pub name: String, + pub condition: String, // Expression to evaluate + pub action: String, // allow, deny, redirect + pub value: Option, // For redirect actions +} + +/// Failover configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FailoverConfig { + pub enabled: bool, + pub max_failures: u32, + pub retry_interval_seconds: u64, + pub fallback_backend: Option, + pub custom_failover_logic: Option, // Plugin callback +} + +/// Middleware configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MiddlewareConfig { + pub name: String, + pub plugin_type: String, // Which plugin provides this middleware + pub enabled: bool, + pub order: u32, + pub config: serde_json::Value, + pub apply_to: Vec, // Which proxy instances to apply to +} + +/// Monitoring configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MonitoringConfig { + pub metrics_enabled: bool, + pub metrics_port: u16, + pub metrics_path: String, + pub prometheus_enabled: bool, + pub custom_metrics: Vec, + pub tracing_enabled: bool, + pub tracing_endpoint: Option, + pub log_sampling_rate: f64, + pub alert_rules: Vec, +} + +impl Default for MonitoringConfig { + fn default() -> Self { + Self { + metrics_enabled: true, + metrics_port: 9090, + metrics_path: "/metrics".to_string(), + prometheus_enabled: true, + custom_metrics: vec![], + tracing_enabled: false, + tracing_endpoint: None, + log_sampling_rate: 
1.0, + alert_rules: vec![], + } + } +} + +/// Custom metric configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CustomMetricConfig { + pub name: String, + pub metric_type: String, // counter, gauge, histogram + pub description: String, + pub labels: Vec, + pub plugin_callback: Option, +} + +/// Alert rule configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AlertRule { + pub name: String, + pub condition: String, // PromQL-like expression + pub threshold: f64, + pub duration_seconds: u64, + pub severity: String, + pub action: AlertAction, +} + +/// Alert action +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AlertAction { + pub action_type: String, // webhook, email, plugin + pub endpoint: Option, + pub config: serde_json::Value, +} + +/// Security configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SecurityConfig { + pub request_size_limit_mb: usize, + pub header_size_limit_kb: usize, + pub timeout_seconds: u64, + pub enable_request_logging: bool, + pub sensitive_headers: Vec, + pub security_headers: HashMap, + pub waf_enabled: bool, + pub waf_rules: Vec, + pub ddos_protection: DdosProtectionConfig, +} + +impl Default for SecurityConfig { + fn default() -> Self { + Self { + request_size_limit_mb: 100, + header_size_limit_kb: 64, + timeout_seconds: 30, + enable_request_logging: true, + sensitive_headers: vec![ + "Authorization".to_string(), + "Cookie".to_string(), + "X-API-Key".to_string(), + ], + security_headers: [ + ("X-Frame-Options".to_string(), "DENY".to_string()), + ("X-Content-Type-Options".to_string(), "nosniff".to_string()), + ("X-XSS-Protection".to_string(), "1; mode=block".to_string()), + ].iter().cloned().collect(), + waf_enabled: false, + waf_rules: vec![], + ddos_protection: DdosProtectionConfig::default(), + } + } +} + +/// WAF rule +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WafRule { + pub name: String, + pub pattern: String, + pub action: String, // 
block, log, redirect + pub severity: String, + pub enabled: bool, +} + +/// DDoS protection configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DdosProtectionConfig { + pub enabled: bool, + pub requests_per_second_per_ip: u32, + pub burst_size_per_ip: u32, + pub ban_duration_seconds: u64, + pub whitelist: Vec, + pub custom_logic: Option, // Plugin callback +} + +impl Default for DdosProtectionConfig { + fn default() -> Self { + Self { + enabled: true, + requests_per_second_per_ip: 100, + burst_size_per_ip: 200, + ban_duration_seconds: 300, + whitelist: vec![], + custom_logic: None, + } + } +} + +/// Logging configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LoggingConfig { + pub level: String, + pub format: String, // json, text + pub output: Vec, + pub access_log: AccessLogConfig, + pub error_log: ErrorLogConfig, + pub audit_log: Option, +} + +impl Default for LoggingConfig { + fn default() -> Self { + Self { + level: "info".to_string(), + format: "json".to_string(), + output: vec![LogOutput { + output_type: "stdout".to_string(), + config: serde_json::json!({}), + }], + access_log: AccessLogConfig::default(), + error_log: ErrorLogConfig::default(), + audit_log: None, + } + } +} + +/// Log output configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LogOutput { + pub output_type: String, // stdout, file, syslog, plugin + pub config: serde_json::Value, +} + +/// Access log configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AccessLogConfig { + pub enabled: bool, + pub format: String, + pub output: Vec, + pub fields: Vec, + pub exclude_paths: Vec, +} + +impl Default for AccessLogConfig { + fn default() -> Self { + Self { + enabled: true, + format: "combined".to_string(), + output: vec![LogOutput { + output_type: "file".to_string(), + config: serde_json::json!({ + "path": "./logs/access.log", + "rotation": "daily" + }), + }], + fields: vec![ + "timestamp".to_string(), + 
"client_ip".to_string(), + "method".to_string(), + "path".to_string(), + "status".to_string(), + "response_time".to_string(), + ], + exclude_paths: vec!["/health".to_string(), "/metrics".to_string()], + } + } +} + +/// Error log configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ErrorLogConfig { + pub enabled: bool, + pub output: Vec, + pub include_stack_trace: bool, + pub min_level: String, +} + +impl Default for ErrorLogConfig { + fn default() -> Self { + Self { + enabled: true, + output: vec![LogOutput { + output_type: "file".to_string(), + config: serde_json::json!({ + "path": "./logs/error.log", + "rotation": "daily" + }), + }], + include_stack_trace: true, + min_level: "warn".to_string(), + } + } +} + +/// Audit log configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AuditLogConfig { + pub enabled: bool, + pub output: Vec, + pub events: Vec, // Which events to audit + pub include_request_body: bool, + pub include_response_body: bool, +} + +// Legacy configuration types for backward compatibility +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LegacyRouteConfig { + pub upstream: String, + pub timeout_ms: Option, + pub retry_count: Option, + pub priority: Option, + pub preserve_host_header: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LegacyTcpProxyConfig { + pub enabled: bool, + pub listen_addr: String, + pub listen_port: u16, + pub connection_pooling: bool, + pub max_idle_time_secs: u64, + pub udp_enabled: bool, + pub udp_listen_addr: String, +} + +/// Configuration utilities +impl RouterConfig { + /// Get proxy instance by name + pub fn get_proxy_instance(&self, name: &str) -> Option<&ProxyInstanceConfig> { + self.proxies.iter().find(|p| p.name == name) + } + + /// Get domain configuration + pub fn get_domain_config(&self, domain: &str) -> Option<&DomainConfig> { + self.domains.get(domain) + } + + /// Get load balancing configuration + pub fn 
get_load_balancing_config(&self, name: &str) -> Option<&LoadBalancingConfig> { + self.load_balancing.get(name) + } + + /// Find proxy instances handling a domain + pub fn find_proxies_for_domain(&self, domain: &str) -> Vec<&ProxyInstanceConfig> { + self.proxies.iter() + .filter(|p| p.domains.iter().any(|d| self.domain_matches(d, domain))) + .collect() + } + + /// Check if domain pattern matches + fn domain_matches(&self, pattern: &str, domain: &str) -> bool { + if pattern.starts_with("*.") { + let suffix = &pattern[2..]; + domain.ends_with(suffix) + } else { + pattern == domain + } + } + + /// Merge with another configuration (for updates) + pub fn merge(&mut self, other: RouterConfig) -> Result<()> { + // Merge server config + self.server = other.server; + + // Merge plugin system config + self.plugins = other.plugins; + + // Replace proxy instances + self.proxies = other.proxies; + + // Merge domains + self.domains.extend(other.domains); + + // Merge load balancing configs + self.load_balancing.extend(other.load_balancing); + + // Replace middleware + self.middleware = other.middleware; + + // Merge monitoring + self.monitoring = other.monitoring; + + // Merge security + self.security = other.security; + + // Merge logging + self.logging = other.logging; + + Ok(()) + } } \ No newline at end of file diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..e77a244 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,668 @@ +// src/error.rs - Comprehensive error handling +use serde::{Deserialize, Serialize}; +use std::fmt; +use thiserror::Error; + +/// Main router error type +#[derive(Error, Debug, Clone, Serialize, Deserialize)] +pub enum RouterError { + // Configuration errors + #[error("Configuration error: {message}")] + ConfigError { message: String }, + + #[error("Invalid configuration file: {path}")] + InvalidConfigFile { path: String }, + + #[error("Configuration validation failed: {errors:?}")] + ConfigValidationError { errors: Vec }, + + // 
Plugin errors + #[error("Plugin error in '{plugin}': {message}")] + PluginError { plugin: String, message: String }, + + #[error("Plugin '{plugin}' not found")] + PluginNotFound { plugin: String }, + + #[error("Failed to load plugin '{plugin}': {reason}")] + PluginLoadError { plugin: String, reason: String }, + + #[error("Plugin '{plugin}' initialization failed: {reason}")] + PluginInitError { plugin: String, reason: String }, + + #[error("Plugin '{plugin}' is not compatible: {reason}")] + PluginCompatibilityError { plugin: String, reason: String }, + + // Server errors + #[error("Server error: {message}")] + ServerError { message: String }, + + #[error("Failed to bind to address '{address}': {reason}")] + BindError { address: String, reason: String }, + + #[error("Server is not running")] + ServerNotRunning, + + #[error("Server startup timeout")] + ServerStartupTimeout, + + #[error("Graceful shutdown failed: {reason}")] + ShutdownError { reason: String }, + + // Network/Connection errors + #[error("Connection error from {client}: {message}")] + ConnectionError { client: String, message: String }, + + #[error("Network timeout: {operation}")] + NetworkTimeout { operation: String }, + + #[error("DNS resolution failed for '{hostname}': {reason}")] + DnsError { hostname: String, reason: String }, + + #[error("SSL/TLS error: {message}")] + SslError { message: String }, + + // Request/Response errors + #[error("Request processing error: {message}")] + RequestError { message: String }, + + #[error("Invalid request: {reason}")] + InvalidRequest { reason: String }, + + #[error("Request timeout after {timeout_ms}ms")] + RequestTimeout { timeout_ms: u64 }, + + #[error("Request too large: {size} bytes (max: {max_size})")] + RequestTooLarge { size: usize, max_size: usize }, + + #[error("Response processing error: {message}")] + ResponseError { message: String }, + + #[error("Upstream error: {message}")] + UpstreamError { message: String }, + + #[error("Backend connection 
failed: {backend}")] + BackendConnectionError { backend: String }, + + // Routing errors + #[error("No route found for {domain}{path}")] + RouteNotFound { domain: String, path: String }, + + #[error("Routing error: {message}")] + RoutingError { message: String }, + + #[error("Load balancing error: {message}")] + LoadBalancingError { message: String }, + + // Security errors + #[error("Authentication failed: {reason}")] + AuthenticationError { reason: String }, + + #[error("Authorization denied: {reason}")] + AuthorizationError { reason: String }, + + #[error("Rate limit exceeded for {identifier}")] + RateLimitExceeded { identifier: String }, + + #[error("Security violation: {message}")] + SecurityViolation { message: String }, + + #[error("CORS error: {message}")] + CorsError { message: String }, + + // Circuit breaker errors + #[error("Circuit breaker is open for {service}")] + CircuitBreakerOpen { service: String }, + + #[error("Circuit breaker error: {message}")] + CircuitBreakerError { message: String }, + + // Health check errors + #[error("Health check failed for {service}: {reason}")] + HealthCheckFailed { service: String, reason: String }, + + #[error("Service unavailable: {service}")] + ServiceUnavailable { service: String }, + + // Resource errors + #[error("Resource exhausted: {resource}")] + ResourceExhausted { resource: String }, + + #[error("Memory limit exceeded: {current} bytes (limit: {limit})")] + MemoryLimitExceeded { current: usize, limit: usize }, + + #[error("Connection pool exhausted for {backend}")] + ConnectionPoolExhausted { backend: String }, + + // I/O errors + #[error("I/O error: {message}")] + IoError { message: String }, + + #[error("File not found: {path}")] + FileNotFound { path: String }, + + #[error("Permission denied: {path}")] + PermissionDenied { path: String }, + + // Serialization/parsing errors + #[error("JSON parsing error: {message}")] + JsonError { message: String }, + + #[error("YAML parsing error: {message}")] + 
YamlError { message: String }, + + #[error("URL parsing error: {url}")] + UrlParsingError { url: String }, + + #[error("Header parsing error: {header}")] + HeaderParsingError { header: String }, + + // Middleware errors + #[error("Middleware '{middleware}' error: {message}")] + MiddlewareError { middleware: String, message: String }, + + #[error("Middleware chain error: {message}")] + MiddlewareChainError { message: String }, + + // Management API errors + #[error("Management API error: {message}")] + ManagementApiError { message: String }, + + #[error("Invalid API request: {reason}")] + InvalidApiRequest { reason: String }, + + #[error("API authentication required")] + ApiAuthRequired, + + // Metrics/monitoring errors + #[error("Metrics collection error: {message}")] + MetricsError { message: String }, + + #[error("Monitoring error: {message}")] + MonitoringError { message: String }, + + // Generic/unknown errors + #[error("Internal error: {message}")] + InternalError { message: String }, + + #[error("Unknown error: {message}")] + UnknownError { message: String }, + + #[error("Operation not supported: {operation}")] + NotSupported { operation: String }, + + #[error("Feature not implemented: {feature}")] + NotImplemented { feature: String }, +} + +/// Result type alias for router operations +pub type Result = std::result::Result; + +impl RouterError { + /// Get the error category + pub fn category(&self) -> ErrorCategory { + match self { + RouterError::ConfigError { .. } | + RouterError::InvalidConfigFile { .. } | + RouterError::ConfigValidationError { .. } => ErrorCategory::Configuration, + + RouterError::PluginError { .. } | + RouterError::PluginNotFound { .. } | + RouterError::PluginLoadError { .. } | + RouterError::PluginInitError { .. } | + RouterError::PluginCompatibilityError { .. } => ErrorCategory::Plugin, + + RouterError::ServerError { .. } | + RouterError::BindError { .. 
} | + RouterError::ServerNotRunning | + RouterError::ServerStartupTimeout | + RouterError::ShutdownError { .. } => ErrorCategory::Server, + + RouterError::ConnectionError { .. } | + RouterError::NetworkTimeout { .. } | + RouterError::DnsError { .. } | + RouterError::SslError { .. } => ErrorCategory::Network, + + RouterError::RequestError { .. } | + RouterError::InvalidRequest { .. } | + RouterError::RequestTimeout { .. } | + RouterError::RequestTooLarge { .. } | + RouterError::ResponseError { .. } | + RouterError::UpstreamError { .. } | + RouterError::BackendConnectionError { .. } => ErrorCategory::Request, + + RouterError::RouteNotFound { .. } | + RouterError::RoutingError { .. } | + RouterError::LoadBalancingError { .. } => ErrorCategory::Routing, + + RouterError::AuthenticationError { .. } | + RouterError::AuthorizationError { .. } | + RouterError::RateLimitExceeded { .. } | + RouterError::SecurityViolation { .. } | + RouterError::CorsError { .. } => ErrorCategory::Security, + + RouterError::CircuitBreakerOpen { .. } | + RouterError::CircuitBreakerError { .. } => ErrorCategory::CircuitBreaker, + + RouterError::HealthCheckFailed { .. } | + RouterError::ServiceUnavailable { .. } => ErrorCategory::Health, + + RouterError::ResourceExhausted { .. } | + RouterError::MemoryLimitExceeded { .. } | + RouterError::ConnectionPoolExhausted { .. } => ErrorCategory::Resource, + + RouterError::IoError { .. } | + RouterError::FileNotFound { .. } | + RouterError::PermissionDenied { .. } => ErrorCategory::Io, + + RouterError::JsonError { .. } | + RouterError::YamlError { .. } | + RouterError::UrlParsingError { .. } | + RouterError::HeaderParsingError { .. } => ErrorCategory::Parsing, + + RouterError::MiddlewareError { .. } | + RouterError::MiddlewareChainError { .. } => ErrorCategory::Middleware, + + RouterError::ManagementApiError { .. } | + RouterError::InvalidApiRequest { .. } | + RouterError::ApiAuthRequired => ErrorCategory::Api, + + RouterError::MetricsError { .. 
} | + RouterError::MonitoringError { .. } => ErrorCategory::Monitoring, + + RouterError::InternalError { .. } | + RouterError::UnknownError { .. } | + RouterError::NotSupported { .. } | + RouterError::NotImplemented { .. } => ErrorCategory::Internal, + } + } + + /// Get the HTTP status code that should be returned for this error + pub fn http_status_code(&self) -> u16 { + match self { + RouterError::InvalidRequest { .. } | + RouterError::HeaderParsingError { .. } | + RouterError::UrlParsingError { .. } => 400, // Bad Request + + RouterError::AuthenticationError { .. } | + RouterError::ApiAuthRequired => 401, // Unauthorized + + RouterError::AuthorizationError { .. } => 403, // Forbidden + + RouterError::RouteNotFound { .. } | + RouterError::PluginNotFound { .. } | + RouterError::FileNotFound { .. } => 404, // Not Found + + RouterError::RequestTooLarge { .. } => 413, // Payload Too Large + + RouterError::RateLimitExceeded { .. } => 429, // Too Many Requests + + RouterError::InternalError { .. } | + RouterError::ServerError { .. } | + RouterError::PluginError { .. } | + RouterError::ConfigError { .. } => 500, // Internal Server Error + + RouterError::NotImplemented { .. } | + RouterError::NotSupported { .. } => 501, // Not Implemented + + RouterError::BackendConnectionError { .. } | + RouterError::UpstreamError { .. } => 502, // Bad Gateway + + RouterError::ServiceUnavailable { .. } | + RouterError::CircuitBreakerOpen { .. } | + RouterError::ServerNotRunning => 503, // Service Unavailable + + RouterError::RequestTimeout { .. } | + RouterError::NetworkTimeout { .. } => 504, // Gateway Timeout + + _ => 500, // Default to Internal Server Error + } + } + + /// Check if the error is recoverable + pub fn is_recoverable(&self) -> bool { + match self { + RouterError::NetworkTimeout { .. } | + RouterError::RequestTimeout { .. } | + RouterError::BackendConnectionError { .. } | + RouterError::CircuitBreakerOpen { .. } | + RouterError::ResourceExhausted { .. 
} | + RouterError::ConnectionPoolExhausted { .. } => true, + + RouterError::ConfigError { .. } | + RouterError::PluginLoadError { .. } | + RouterError::PluginInitError { .. } | + RouterError::ServerNotRunning | + RouterError::InvalidRequest { .. } | + RouterError::AuthenticationError { .. } | + RouterError::AuthorizationError { .. } | + RouterError::RouteNotFound { .. } => false, + + _ => false, // Conservative approach - assume not recoverable + } + } + + /// Check if the error should be logged + pub fn should_log(&self) -> bool { + match self { + RouterError::RouteNotFound { .. } | + RouterError::AuthenticationError { .. } | + RouterError::RateLimitExceeded { .. } => false, // These are expected and shouldn't clutter logs + + _ => true, + } + } + + /// Get the log level for this error + pub fn log_level(&self) -> LogLevel { + match self { + RouterError::ConfigError { .. } | + RouterError::PluginLoadError { .. } | + RouterError::ServerError { .. } | + RouterError::InternalError { .. } => LogLevel::Error, + + RouterError::BackendConnectionError { .. } | + RouterError::UpstreamError { .. } | + RouterError::CircuitBreakerOpen { .. } | + RouterError::HealthCheckFailed { .. } => LogLevel::Warn, + + RouterError::RequestTimeout { .. } | + RouterError::NetworkTimeout { .. } | + RouterError::RateLimitExceeded { .. } => LogLevel::Info, + + RouterError::RouteNotFound { .. } | + RouterError::InvalidRequest { .. } => LogLevel::Debug, + + _ => LogLevel::Warn, + } + } + + /// Convert to a user-friendly error message + pub fn user_message(&self) -> String { + match self { + RouterError::RouteNotFound { .. } => "The requested resource was not found".to_string(), + RouterError::ServiceUnavailable { .. } => "Service temporarily unavailable".to_string(), + RouterError::RequestTimeout { .. } => "Request timed out".to_string(), + RouterError::RateLimitExceeded { .. } => "Too many requests. Please try again later".to_string(), + RouterError::AuthenticationError { .. 
} => "Authentication required".to_string(), + RouterError::AuthorizationError { .. } => "Access denied".to_string(), + RouterError::RequestTooLarge { .. } => "Request entity too large".to_string(), + _ => "An error occurred while processing your request".to_string(), + } + } +} + +/// Error categories for classification +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum ErrorCategory { + Configuration, + Plugin, + Server, + Network, + Request, + Routing, + Security, + CircuitBreaker, + Health, + Resource, + Io, + Parsing, + Middleware, + Api, + Monitoring, + Internal, +} + +/// Log levels +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum LogLevel { + Error, + Warn, + Info, + Debug, +} + +/// Error context for debugging +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ErrorContext { + pub error: RouterError, + pub timestamp: std::time::SystemTime, + pub request_id: Option, + pub plugin_name: Option, + pub client_ip: Option, + pub user_agent: Option, + pub additional_context: std::collections::HashMap, +} + +impl ErrorContext { + pub fn new(error: RouterError) -> Self { + Self { + error, + timestamp: std::time::SystemTime::now(), + request_id: None, + plugin_name: None, + client_ip: None, + user_agent: None, + additional_context: std::collections::HashMap::new(), + } + } + + pub fn with_request_id(mut self, request_id: String) -> Self { + self.request_id = Some(request_id); + self + } + + pub fn with_plugin(mut self, plugin_name: String) -> Self { + self.plugin_name = Some(plugin_name); + self + } + + pub fn with_client_info(mut self, client_ip: String, user_agent: Option) -> Self { + self.client_ip = Some(client_ip); + self.user_agent = user_agent; + self + } + + pub fn with_context(mut self, key: String, value: String) -> Self { + self.additional_context.insert(key, value); + self + } +} + +/// Error handler trait for customizing error handling +pub trait ErrorHandler: Send + Sync { + fn handle_error(&self, context: 
&ErrorContext) -> Result<()>; +} + +/// Default error handler implementation +pub struct DefaultErrorHandler; + +impl ErrorHandler for DefaultErrorHandler { + fn handle_error(&self, context: &ErrorContext) -> Result<()> { + let level = context.error.log_level(); + let should_log = context.error.should_log(); + + if should_log { + let message = format!( + "Error in {}: {} [Category: {:?}]", + context.plugin_name.as_deref().unwrap_or("system"), + context.error, + context.error.category() + ); + + match level { + LogLevel::Error => tracing::error!("{}", message), + LogLevel::Warn => tracing::warn!("{}", message), + LogLevel::Info => tracing::info!("{}", message), + LogLevel::Debug => tracing::debug!("{}", message), + } + } + + Ok(()) + } +} + +/// Conversion implementations from common error types +impl From for RouterError { + fn from(err: std::io::Error) -> Self { + match err.kind() { + std::io::ErrorKind::NotFound => RouterError::FileNotFound { + path: "unknown".to_string(), + }, + std::io::ErrorKind::PermissionDenied => RouterError::PermissionDenied { + path: "unknown".to_string(), + }, + std::io::ErrorKind::TimedOut => RouterError::NetworkTimeout { + operation: "I/O operation".to_string(), + }, + _ => RouterError::IoError { + message: err.to_string(), + }, + } + } +} + +impl From for RouterError { + fn from(err: serde_json::Error) -> Self { + RouterError::JsonError { + message: err.to_string(), + } + } +} + +impl From for RouterError { + fn from(err: url::ParseError) -> Self { + RouterError::UrlParsingError { + url: err.to_string(), + } + } +} + +impl From for RouterError { + fn from(_: tokio::time::error::Elapsed) -> Self { + RouterError::NetworkTimeout { + operation: "async operation".to_string(), + } + } +} + +/// Utility functions for error handling +pub mod utils { + use super::*; + + /// Create a configuration error + pub fn config_error>(message: S) -> RouterError { + RouterError::ConfigError { + message: message.into(), + } + } + + /// Create a plugin 
error + pub fn plugin_error>(plugin: S, message: S) -> RouterError { + RouterError::PluginError { + plugin: plugin.into(), + message: message.into(), + } + } + + /// Create a server error + pub fn server_error>(message: S) -> RouterError { + RouterError::ServerError { + message: message.into(), + } + } + + /// Create a request error + pub fn request_error>(message: S) -> RouterError { + RouterError::RequestError { + message: message.into(), + } + } + + /// Create a routing error + pub fn routing_error>(message: S) -> RouterError { + RouterError::RoutingError { + message: message.into(), + } + } + + /// Create an internal error + pub fn internal_error>(message: S) -> RouterError { + RouterError::InternalError { + message: message.into(), + } + } + + /// Chain errors with context + pub fn chain_error(error: E, context: &str) -> RouterError { + RouterError::InternalError { + message: format!("{}: {}", context, error), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_error_categorization() { + let config_error = RouterError::ConfigError { + message: "test".to_string(), + }; + assert_eq!(config_error.category(), ErrorCategory::Configuration); + + let plugin_error = RouterError::PluginError { + plugin: "test".to_string(), + message: "test".to_string(), + }; + assert_eq!(plugin_error.category(), ErrorCategory::Plugin); + } + + #[test] + fn test_http_status_codes() { + let not_found = RouterError::RouteNotFound { + domain: "example.com".to_string(), + path: "/test".to_string(), + }; + assert_eq!(not_found.http_status_code(), 404); + + let auth_error = RouterError::AuthenticationError { + reason: "invalid token".to_string(), + }; + assert_eq!(auth_error.http_status_code(), 401); + + let rate_limit = RouterError::RateLimitExceeded { + identifier: "127.0.0.1".to_string(), + }; + assert_eq!(rate_limit.http_status_code(), 429); + } + + #[test] + fn test_error_recoverability() { + let timeout_error = RouterError::RequestTimeout { timeout_ms: 5000 }; 
+ assert!(timeout_error.is_recoverable()); + + let config_error = RouterError::ConfigError { + message: "test".to_string(), + }; + assert!(!config_error.is_recoverable()); + } + + #[test] + fn test_error_context() { + let error = RouterError::RequestError { + message: "test error".to_string(), + }; + + let context = ErrorContext::new(error) + .with_request_id("req-123".to_string()) + .with_plugin("http_proxy".to_string()) + .with_client_info("127.0.0.1".to_string(), Some("curl/7.68.0".to_string())) + .with_context("additional".to_string(), "info".to_string()); + + assert_eq!(context.request_id, Some("req-123".to_string())); + assert_eq!(context.plugin_name, Some("http_proxy".to_string())); + assert_eq!(context.client_ip, Some("127.0.0.1".to_string())); + assert_eq!(context.additional_context.get("additional"), Some(&"info".to_string())); + } +} \ No newline at end of file diff --git a/src/health.rs b/src/health.rs new file mode 100644 index 0000000..e69de29 diff --git a/src/lib.rs b/src/lib.rs index 76402a1..47c2ce3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,261 +1,535 @@ -// src/lib.rs -use anyhow::Result; -use std::sync::Arc; -use tokio::sync::{broadcast, mpsc, RwLock}; - -pub mod client; -pub mod config; -pub mod metrics; -pub mod http_proxy; -pub mod tcp_proxy; -pub mod udp_proxy; -pub mod dynamic_config; -pub mod config_api; - -/// The main Router struct that manages all proxy services -pub struct Router { - config_manager: Arc, - shutdown_tx: Option>, -} - -impl Router { - /// Create a new Router with the provided configuration manager - pub fn new_with_manager(config_manager: Arc) -> Self { - Router { - config_manager, - shutdown_tx: None, - } - } - - /// Create a new Router with the provided configuration - pub fn new(config: config::ProxyConfig) -> Self { - let config_manager = Arc::new(DynamicConfigManager::new(config)); - Self::new_with_manager(config_manager) - } - - /// Create a new Router by loading configuration from a file - pub async fn 
from_file(config_path: &str) -> Result { - let config_manager = DynamicConfigManager::from_file(config_path).await?; - Ok(Router::new_with_manager(Arc::new(config_manager))) - } - - /// Get the configuration manager - pub fn config_manager(&self) -> Arc { - self.config_manager.clone() - } - - /// Start the router service with all enabled proxies - pub async fn start(&mut self) -> Result<()> { - // Initialize metrics - metrics::init_metrics()?; - - // Create a shutdown channel - let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1); - self.shutdown_tx = Some(shutdown_tx); - - // Set up config change listener - let mut config_rx = self.config_manager.subscribe(); - let config_manager = self.config_manager.clone(); - - // Initial configuration - let initial_config = self.config_manager.get_config().read().await.clone(); - - // Start the HTTP proxy server - let http_config = self.config_manager.get_config().clone(); - let http_handle = tokio::spawn(async move { - if let Err(e) = http_proxy::run_server(http_config).await { - tracing::error!("HTTP Server error: {}", e); - } - }); - - // Check for database routes that should be handled as TCP - let has_db_routes = initial_config.routes.iter().any(|(_, route)| { - config::is_likely_database(route) - }); - - // Check for UDP routes - let has_udp_routes = initial_config.routes.iter().any(|(_, route)| { - route.is_udp.unwrap_or(false) - }); - - // Start TCP proxy if enabled or if database routes are detected - let tcp_handle = if initial_config.tcp_proxy.enabled || has_db_routes { - tracing::info!("TCP proxy support enabled"); - let tcp_config = self.config_manager.get_config().clone(); - - let handle = tokio::spawn(async move { - let tcp_proxy = tcp_proxy::TcpProxyServer::new(tcp_config).await; - if let Err(e) = tcp_proxy.run(&initial_config.tcp_proxy.listen_addr).await { - tracing::error!("TCP proxy server error: {}", e); - } - }); - Some(handle) - } else { - None - }; - - // Start UDP proxy if enabled or if UDP 
routes are detected - let udp_handle = if initial_config.tcp_proxy.udp_enabled || has_udp_routes { - tracing::info!("UDP proxy support enabled"); - let udp_config = self.config_manager.get_config().clone(); - - // Use the same address as TCP proxy by default - let udp_listen_addr = initial_config.tcp_proxy.udp_listen_addr.clone(); - - let handle = tokio::spawn(async move { - let udp_proxy = udp_proxy::UdpProxyServer::new(udp_config); - if let Err(e) = udp_proxy.run(&udp_listen_addr).await { - tracing::error!("UDP proxy server error: {}", e); - } - }); - Some(handle) - } else { - None - }; - - // Set up the config API - let api_handle = { - let api_routes = config_api::config_api_routes(self.config_manager.clone()); - - // Create a separate server for the API on a different port - let api_port = 8082; // Could be configurable - let api_addr = format!("0.0.0.0:{}", api_port); - - tracing::info!("Starting configuration API server on {}", api_addr); - - tokio::spawn(async move { - warp::serve(api_routes) - .run(api_addr.parse::().unwrap()) - .await; - }) - }; - - // Listen for configuration changes and handle them - let config_change_handle = tokio::spawn(async move { - tracing::info!("Starting configuration change listener"); - - loop { - tokio::select! 
{ - // Handle shutdown signal - _ = shutdown_rx.recv() => { - tracing::info!("Received shutdown signal, stopping config listener"); - break; - } - - // Handle configuration changes - result = config_rx.recv() => { - match result { - Ok(event) => { - tracing::info!("Received configuration change event: {:?}", event); - - match event { - ConfigEvent::RouteAdded(name, config) => { - tracing::info!("Route added: {}", name); - // No need to restart services, they'll pick up the change - } - ConfigEvent::RouteUpdated(name, config) => { - tracing::info!("Route updated: {}", name); - // No need to restart services, they'll pick up the change - } - ConfigEvent::RouteRemoved(name) => { - tracing::info!("Route removed: {}", name); - // No need to restart services, they'll pick up the change - } - ConfigEvent::TcpConfigUpdated(tcp_config) => { - tracing::warn!("TCP configuration updated - some changes may require restart"); - // Here we could potentially restart TCP services if needed - } - ConfigEvent::GlobalSettingsUpdated { .. 
} => { - tracing::warn!("Global settings updated - some changes may require restart"); - // Here we could potentially restart services if needed - } - ConfigEvent::FullUpdate(_) => { - tracing::warn!("Full configuration replaced - some changes may require restart"); - // Here we could potentially restart all services if needed - } - } - } - Err(e) => { - if matches!(e, tokio::sync::broadcast::error::RecvError::Lagged(_)) { - tracing::warn!("Config listener lagged and missed messages"); - } else { - tracing::error!("Error receiving config change: {}", e); - break; - } - } - } - } - } - } - }); - - // TODO: Add this - // // Optional: start file watcher if config was loaded from file - // if let Some(path) = config_manager.file_path() { - // if let Err(e) = config_manager.start_file_watcher(30).await { - // tracing::error!("Failed to start file watcher: {}", e); - // } - // } - - // Wait for Ctrl+C or other shutdown signal - tokio::signal::ctrl_c().await?; - tracing::info!("Received shutdown signal"); - - // Attempt graceful shutdown - if let Some(tx) = &self.shutdown_tx { - let _ = tx.send(()).await; - } - - Ok(()) - } - - /// Manually trigger a configuration reload from file - pub async fn reload_config(&self) -> Result<()> { - self.config_manager.reload_from_file().await - } - - /// Get the current configuration - pub async fn get_config(&self) -> config::ProxyConfig { - self.config_manager.get_config().read().await.clone() - } - - /// Update a specific route - pub async fn update_route(&self, route_name: &str, route_config: config::RouteConfig) -> Result<()> { - self.config_manager.update_route(route_name, route_config).await - } - - /// Add a new route - pub async fn add_route(&self, route_name: &str, route_config: config::RouteConfig) -> Result<()> { - self.config_manager.add_route(route_name, route_config).await - } - - /// Remove a route - pub async fn remove_route(&self, route_name: &str) -> Result<()> { - self.config_manager.remove_route(route_name).await - 
} - - /// Update TCP proxy configuration - pub async fn update_tcp_config(&self, tcp_config: config::TcpProxyConfig) -> Result<()> { - self.config_manager.update_tcp_config(tcp_config).await - } - - /// Update global settings - pub async fn update_global_settings( - &self, - listen_addr: Option, - global_timeout_ms: Option, - max_connections: Option, - ) -> Result<()> { - self.config_manager.update_global_settings(listen_addr, global_timeout_ms, max_connections).await - } - - /// Replace the entire configuration - pub async fn replace_config(&self, new_config: config::ProxyConfig) -> Result<()> { - self.config_manager.replace_config(new_config).await - } -} - -// Re-export types for easier usage -pub use config::{ProxyConfig, RouteConfig, TcpProxyConfig}; -pub use dynamic_config::{DynamicConfigManager, ConfigEvent}; -pub use client::ConfigClient; \ No newline at end of file +// src/lib.rs - Main library entry point for embeddable router +use anyhow::Result; +use std::sync::Arc; +use tokio::sync::{broadcast, mpsc}; + +pub mod config; +pub mod plugin; +pub mod server; +pub mod router; +pub mod middleware; +pub mod metrics; +pub mod error; +pub mod builtin_plugins; +pub mod management_api; +pub mod health; +pub mod logging; + +// Re-export main types for easy usage +pub use config::RouterConfig; +pub use plugin::{Plugin, PluginConfig, PluginInfo, PluginCapabilities, PluginEvent, RouterEvent}; +pub use server::ProxyServer; +pub use error::{RouterError, Result as RouterResult}; + +/// Main Router struct - embeddable and programmable +pub struct Router { + server: Option>, + config: Arc>, + shutdown_sender: Option>, + event_sender: Option>, + event_receiver: Option>, +} + +impl Router { + /// Create a new router with configuration + pub async fn new(config: RouterConfig) -> Result { + let server = Arc::new(ProxyServer::new(config.clone()).await?); + let (shutdown_tx, _shutdown_rx) = broadcast::channel(1); + let (event_tx, event_rx) = mpsc::unbounded_channel(); + + 
Ok(Self { + server: Some(server), + config: Arc::new(tokio::sync::RwLock::new(config)), + shutdown_sender: Some(shutdown_tx), + event_sender: Some(event_tx), + event_receiver: Some(event_rx), + }) + } + + /// Create a router from JSON configuration file + pub async fn from_config_file>(path: P) -> Result { + let config = RouterConfig::load_from_file(path).await?; + Self::new(config).await + } + + /// Create a router with default configuration + pub async fn with_defaults() -> Result { + let config = RouterConfig::default(); + Self::new(config).await + } + + /// Start the router server + pub async fn start(&mut self) -> Result<()> { + if let Some(server) = &self.server { + server.start().await?; + } else { + return Err(anyhow::anyhow!("Server not initialized")); + } + Ok(()) + } + + /// Stop the router server + pub async fn stop(&mut self) -> Result<()> { + if let Some(sender) = &self.shutdown_sender { + let _ = sender.send(()); + } + + if let Some(server) = &self.server { + server.stop().await?; + } + + Ok(()) + } + + /// Get the plugin manager for programmatic plugin management + pub fn plugin_manager(&self) -> Option> { + self.server.as_ref().map(|s| s.plugin_manager()) + } + + /// Load a plugin from a file + pub async fn load_plugin>( + &self, + path: P, + config: Option, + ) -> Result { + if let Some(pm) = self.plugin_manager() { + pm.load_plugin_from_path(path, config).await + } else { + Err(anyhow::anyhow!("Plugin manager not available")) + } + } + + /// Unload a plugin + pub async fn unload_plugin(&self, name: &str) -> Result<()> { + if let Some(pm) = self.plugin_manager() { + pm.unload_plugin(name).await + } else { + Err(anyhow::anyhow!("Plugin manager not available")) + } + } + + /// Register a plugin programmatically (for embedded usage) + pub async fn register_plugin( + &self, + name: String, + plugin: Box, + config: PluginConfig, + ) -> Result<()> { + if let Some(pm) = self.plugin_manager() { + pm.registry().register_plugin(name.clone(), plugin, 
config).await?; + pm.registry().start_plugin(&name).await?; + Ok(()) + } else { + Err(anyhow::anyhow!("Plugin manager not available")) + } + } + + /// Add a route programmatically + pub async fn add_route(&self, domain: String, proxy_instance: String) -> Result<()> { + let mut config = self.config.write().await; + + // Find the proxy instance + if !config.proxies.iter().any(|p| p.name == proxy_instance) { + return Err(anyhow::anyhow!("Proxy instance '{}' not found", proxy_instance)); + } + + // Add domain mapping + config.domains.insert(domain.clone(), config::DomainConfig { + proxy_instance, + backend_service: None, + ssl_config: None, + cors_config: None, + cache_config: None, + custom_headers: std::collections::HashMap::new(), + rewrite_rules: vec![], + access_control: None, + }); + + // Update the proxy instance to include this domain + for proxy in &mut config.proxies { + if proxy.name == proxy_instance { + if !proxy.domains.contains(&domain) { + proxy.domains.push(domain); + } + break; + } + } + + // Notify server of configuration change + if let Some(server) = &self.server { + server.reload_config(config.clone()).await?; + } + + Ok(()) + } + + /// Remove a route + pub async fn remove_route(&self, domain: &str) -> Result<()> { + let mut config = self.config.write().await; + + // Remove from domain mappings + if let Some(domain_config) = config.domains.remove(domain) { + // Remove from proxy instance + for proxy in &mut config.proxies { + if proxy.name == domain_config.proxy_instance { + proxy.domains.retain(|d| d != domain); + break; + } + } + + // Notify server of configuration change + if let Some(server) = &self.server { + server.reload_config(config.clone()).await?; + } + } + + Ok(()) + } + + /// Update configuration + pub async fn update_config(&self, new_config: RouterConfig) -> Result<()> { + { + let mut config = self.config.write().await; + *config = new_config.clone(); + } + + if let Some(server) = &self.server { + 
server.reload_config(new_config).await?; + } + + Ok(()) + } + + /// Get current configuration + pub async fn get_config(&self) -> RouterConfig { + self.config.read().await.clone() + } + + /// Get plugin information + pub async fn get_plugin_info(&self, name: &str) -> Option { + self.plugin_manager()?.get_plugin_info(name).await + } + + /// List all plugins + pub fn list_plugins(&self) -> Vec { + self.plugin_manager().map(|pm| pm.list_plugins()).unwrap_or_default() + } + + /// Get plugin health + pub async fn get_plugin_health(&self, name: &str) -> Option { + self.plugin_manager()?.get_plugin_health(name).await + } + + /// Get plugin metrics + pub async fn get_plugin_metrics(&self, name: &str) -> Option { + self.plugin_manager()?.get_plugin_metrics(name).await + } + + /// Get all metrics + pub async fn get_all_metrics(&self) -> std::collections::HashMap { + self.plugin_manager().map(|pm| + futures::executor::block_on(pm.get_all_metrics()) + ).unwrap_or_default() + } + + /// Send event to plugin + pub async fn send_event_to_plugin(&self, plugin_name: &str, event: PluginEvent) -> Result<()> { + if let Some(pm) = self.plugin_manager() { + pm.send_event_to_plugin(plugin_name, event).await + } else { + Err(anyhow::anyhow!("Plugin manager not available")) + } + } + + /// Broadcast event to all plugins + pub async fn broadcast_event(&self, event: PluginEvent) { + if let Some(pm) = self.plugin_manager() { + pm.broadcast_event(event).await; + } + } + + /// Execute plugin command + pub async fn execute_plugin_command( + &self, + plugin_name: &str, + command: &str, + args: serde_json::Value, + ) -> Result { + if let Some(pm) = self.plugin_manager() { + pm.execute_plugin_command(plugin_name, command, args).await + } else { + Err(anyhow::anyhow!("Plugin manager not available")) + } + } + + /// Enable management API + pub async fn enable_management_api(&self, port: u16) -> Result<()> { + if let Some(server) = &self.server { + server.enable_management_api(port).await + } else { + 
Err(anyhow::anyhow!("Server not available")) + } + } + + /// Get event receiver for monitoring router events + pub fn take_event_receiver(&mut self) -> Option> { + self.event_receiver.take() + } + + /// Get event sender for sending custom events + pub fn event_sender(&self) -> Option> { + self.event_sender.clone() + } + + /// Check if router is running + pub async fn is_running(&self) -> bool { + if let Some(server) = &self.server { + server.is_running().await + } else { + false + } + } + + /// Get server statistics + pub async fn get_statistics(&self) -> Option { + if let Some(server) = &self.server { + Some(server.get_statistics().await) + } else { + None + } + } + + /// Graceful shutdown with timeout + pub async fn shutdown_with_timeout(&mut self, timeout: std::time::Duration) -> Result<()> { + let shutdown_future = self.stop(); + + match tokio::time::timeout(timeout, shutdown_future).await { + Ok(result) => result, + Err(_) => { + tracing::warn!("Graceful shutdown timed out, forcing shutdown"); + // Force shutdown if needed + Ok(()) + } + } + } + + /// Wait for router to finish (blocking) + pub async fn wait(&self) -> Result<()> { + if let Some(server) = &self.server { + server.wait().await + } else { + Ok(()) + } + } +} + +impl Drop for Router { + fn drop(&mut self) { + // Attempt graceful shutdown on drop + if let Some(sender) = &self.shutdown_sender { + let _ = sender.send(()); + } + } +} + +/// Builder pattern for Router configuration +pub struct RouterBuilder { + config: RouterConfig, +} + +impl RouterBuilder { + pub fn new() -> Self { + Self { + config: RouterConfig::default(), + } + } + + /// Set listen addresses + pub fn listen_on(mut self, addresses: Vec) -> Self { + self.config.server.listen_addresses = addresses; + self + } + + /// Add plugin directory + pub fn plugin_directory>(mut self, directory: S) -> Self { + self.config.plugins.plugin_directories.push(directory.into()); + self + } + + /// Enable auto-reload + pub fn auto_reload(mut self, 
enabled: bool) -> Self { + self.config.plugins.auto_reload = enabled; + self + } + + /// Set max connections + pub fn max_connections(mut self, max: usize) -> Self { + self.config.server.max_connections = max; + self + } + + /// Enable metrics + pub fn enable_metrics(mut self, port: u16) -> Self { + self.config.monitoring.metrics_enabled = true; + self.config.monitoring.metrics_port = port; + self + } + + /// Enable health checks + pub fn enable_health_checks(mut self, port: u16) -> Self { + self.config.server.health_check_port = port; + self + } + + /// Add proxy instance + pub fn add_proxy(mut self, proxy: config::ProxyInstanceConfig) -> Self { + self.config.proxies.push(proxy); + self + } + + /// Add domain mapping + pub fn add_domain(mut self, domain: String, config: config::DomainConfig) -> Self { + self.config.domains.insert(domain, config); + self + } + + /// Build the router + pub async fn build(self) -> Result { + Router::new(self.config).await + } +} + +impl Default for RouterBuilder { + fn default() -> Self { + Self::new() + } +} + +/// Convenience functions for quick setup +pub async fn create_http_proxy( + listen_port: u16, + domain: String, + backend_url: String, +) -> Result { + let proxy_config = config::ProxyInstanceConfig { + name: "http_proxy".to_string(), + plugin_type: "http_proxy".to_string(), + enabled: true, + priority: 0, + ports: vec![listen_port], + bind_addresses: vec!["0.0.0.0".to_string()], + domains: vec![domain.clone()], + plugin_config: serde_json::json!({ + "backend_url": backend_url, + "timeout_ms": 30000, + "retry_count": 3 + }), + middleware: vec![], + load_balancing: None, + health_check: Some(config::HealthCheckConfig { + enabled: true, + interval_seconds: 30, + timeout_seconds: 10, + path: Some("/health".to_string()), + expected_status: Some(200), + custom_check: None, + }), + circuit_breaker: None, + rate_limiting: None, + ssl_config: None, + }; + + let domain_config = config::DomainConfig { + proxy_instance: 
"http_proxy".to_string(), + backend_service: Some(backend_url), + ssl_config: None, + cors_config: None, + cache_config: None, + custom_headers: std::collections::HashMap::new(), + rewrite_rules: vec![], + access_control: None, + }; + + RouterBuilder::new() + .listen_on(vec![format!("0.0.0.0:{}", listen_port)]) + .add_proxy(proxy_config) + .add_domain(domain, domain_config) + .build() + .await +} + +pub async fn create_tcp_proxy( + listen_port: u16, + backend_address: String, +) -> Result { + let proxy_config = config::ProxyInstanceConfig { + name: "tcp_proxy".to_string(), + plugin_type: "tcp_proxy".to_string(), + enabled: true, + priority: 0, + ports: vec![listen_port], + bind_addresses: vec!["0.0.0.0".to_string()], + domains: vec![], + plugin_config: serde_json::json!({ + "backend_address": backend_address, + "connection_pooling": true, + "max_idle_time_secs": 60 + }), + middleware: vec![], + load_balancing: None, + health_check: Some(config::HealthCheckConfig { + enabled: true, + interval_seconds: 60, + timeout_seconds: 5, + path: None, + expected_status: None, + custom_check: Some(serde_json::json!({ + "type": "tcp_connect" + })), + }), + circuit_breaker: None, + rate_limiting: None, + ssl_config: None, + }; + + RouterBuilder::new() + .listen_on(vec![format!("0.0.0.0:{}", listen_port)]) + .add_proxy(proxy_config) + .build() + .await +} + +/// Async-friendly router handle for embedding in other async applications +pub struct RouterHandle { + router: Arc>, + handle: tokio::task::JoinHandle>, +} + +impl RouterHandle { + pub async fn spawn(mut router: Router) -> Result { + let router_arc = Arc::new(tokio::sync::Mutex::new(router)); + let router_clone = router_arc.clone(); + + let handle = tokio::spawn(async move { + let mut router = router_clone.lock().await; + router.start().await?; + router.wait().await + }); + + Ok(Self { + router: router_arc, + handle, + }) + } + + pub async fn stop(&self) -> Result<()> { + let mut router = self.router.lock().await; + 
router.stop().await?; + self.handle.abort(); + Ok(()) + } + + pub async fn router(&self) -> tokio::sync::MutexGuard { + self.router.lock().await + } +} \ No newline at end of file diff --git a/src/logging.rs b/src/logging.rs new file mode 100644 index 0000000..e69de29 diff --git a/src/management_api.rs b/src/management_api.rs new file mode 100644 index 0000000..7fe18d8 --- /dev/null +++ b/src/management_api.rs @@ -0,0 +1,762 @@ +// src/management_api.rs - Management REST API for the router +use crate::config::RouterConfig; +use crate::plugin::manager::PluginManager; +use crate::plugin::{PluginConfig, PluginEvent, PluginInfo, PluginHealth, PluginMetrics}; +use crate::error::{RouterError, Result}; + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::convert::Infallible; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::sync::{mpsc, RwLock}; +use warp::{Filter, Rejection, Reply}; +use warp::http::StatusCode; +use tracing; + +/// Management API server +pub struct ManagementApi { + port: u16, + plugin_manager: Arc, + config: Arc>, + shutdown_sender: Option>, + authentication: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ApiAuthentication { + pub api_keys: Vec, + pub basic_auth: Option, + pub jwt_secret: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BasicAuth { + pub username: String, + pub password: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ApiResponse { + pub success: bool, + pub message: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + pub timestamp: String, +} + +impl ApiResponse { + pub fn success(message: &str, data: Option) -> Self { + Self { + success: true, + message: message.to_string(), + data, + error: None, + timestamp: chrono::Utc::now().to_rfc3339(), + } + } + + pub fn error(message: &str, error_detail: Option) -> Self { + Self { + 
success: false, + message: message.to_string(), + data: None, + error: error_detail, + timestamp: chrono::Utc::now().to_rfc3339(), + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct PluginStatusRequest { + pub enabled: bool, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct PluginConfigUpdateRequest { + pub config: serde_json::Value, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct RouteCreateRequest { + pub domain: String, + pub proxy_instance: String, + pub config: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SystemStatus { + pub uptime_seconds: u64, + pub version: String, + pub total_plugins: usize, + pub running_plugins: usize, + pub total_connections: u64, + pub memory_usage_mb: f64, + pub cpu_usage_percent: f64, +} + +impl ManagementApi { + pub fn new( + port: u16, + plugin_manager: Arc, + config: Arc>, + ) -> Self { + Self { + port, + plugin_manager, + config, + shutdown_sender: None, + authentication: None, + } + } + + pub fn with_authentication(mut self, auth: ApiAuthentication) -> Self { + self.authentication = Some(auth); + self + } + + pub async fn start(&mut self) -> Result<()> { + let (shutdown_tx, mut shutdown_rx) = mpsc::channel(1); + self.shutdown_sender = Some(shutdown_tx); + + let routes = self.create_routes(); + let addr = SocketAddr::from(([0, 0, 0, 0], self.port)); + + tracing::info!("Starting management API on port {}", self.port); + + let server = warp::serve(routes) + .bind_with_graceful_shutdown(addr, async move { + let _ = shutdown_rx.recv().await; + tracing::info!("Shutting down management API"); + }); + + tokio::spawn(server); + + Ok(()) + } + + pub async fn stop(&self) -> Result<()> { + if let Some(sender) = &self.shutdown_sender { + let _ = sender.send(()).await; + } + Ok(()) + } + + fn create_routes(&self) -> impl Filter + Clone { + let cors = warp::cors() + .allow_any_origin() + .allow_headers(vec!["content-type", "authorization", "x-api-key"]) + .allow_methods(vec!["GET", 
"POST", "PUT", "DELETE", "PATCH"]); + + // Create route filters + let system_routes = self.system_routes(); + let plugin_routes = self.plugin_routes(); + let config_routes = self.config_routes(); + let health_routes = self.health_routes(); + let metrics_routes = self.metrics_routes(); + + system_routes + .or(plugin_routes) + .or(config_routes) + .or(health_routes) + .or(metrics_routes) + .with(cors) + .with(warp::log("management_api")) + .recover(handle_rejection) + } + + fn system_routes(&self) -> impl Filter + Clone { + let plugin_manager = self.plugin_manager.clone(); + let config = self.config.clone(); + + // GET /api/system/status + let status = warp::path!("api" / "system" / "status") + .and(warp::get()) + .and_then(move || { + let pm = plugin_manager.clone(); + let cfg = config.clone(); + async move { + handle_get_system_status(pm, cfg).await + } + }); + + // POST /api/system/shutdown + let shutdown = warp::path!("api" / "system" / "shutdown") + .and(warp::post()) + .and_then(|| async { + handle_system_shutdown().await + }); + + // GET /api/system/version + let version = warp::path!("api" / "system" / "version") + .and(warp::get()) + .and_then(|| async { + Ok::<_, Rejection>(warp::reply::json(&ApiResponse::success( + "Version information", + Some(serde_json::json!({ + "version": env!("CARGO_PKG_VERSION"), + "build_date": env!("BUILD_DATE").unwrap_or("unknown"), + "git_hash": env!("GIT_HASH").unwrap_or("unknown"), + })) + ))) + }); + + status.or(shutdown).or(version) + } + + fn plugin_routes(&self) -> impl Filter + Clone { + let plugin_manager = self.plugin_manager.clone(); + + // GET /api/plugins + let list_plugins = warp::path!("api" / "plugins") + .and(warp::get()) + .and_then(move || { + let pm = plugin_manager.clone(); + async move { + handle_list_plugins(pm).await + } + }); + + let plugin_manager2 = self.plugin_manager.clone(); + + // GET /api/plugins/{name} + let get_plugin = warp::path!("api" / "plugins" / String) + .and(warp::get()) + .and_then(move 
|name: String| { + let pm = plugin_manager2.clone(); + async move { + handle_get_plugin(pm, name).await + } + }); + + let plugin_manager3 = self.plugin_manager.clone(); + + // POST /api/plugins/{name}/reload + let reload_plugin = warp::path!("api" / "plugins" / String / "reload") + .and(warp::post()) + .and_then(move |name: String| { + let pm = plugin_manager3.clone(); + async move { + handle_reload_plugin(pm, name).await + } + }); + + let plugin_manager4 = self.plugin_manager.clone(); + + // PUT /api/plugins/{name}/config + let update_plugin_config = warp::path!("api" / "plugins" / String / "config") + .and(warp::put()) + .and(warp::body::json()) + .and_then(move |name: String, req: PluginConfigUpdateRequest| { + let pm = plugin_manager4.clone(); + async move { + handle_update_plugin_config(pm, name, req).await + } + }); + + let plugin_manager5 = self.plugin_manager.clone(); + + // POST /api/plugins/{name}/command + let plugin_command = warp::path!("api" / "plugins" / String / "command") + .and(warp::post()) + .and(warp::body::json()) + .and_then(move |name: String, body: serde_json::Value| { + let pm = plugin_manager5.clone(); + async move { + handle_plugin_command(pm, name, body).await + } + }); + + list_plugins + .or(get_plugin) + .or(reload_plugin) + .or(update_plugin_config) + .or(plugin_command) + } + + fn config_routes(&self) -> impl Filter + Clone { + let config = self.config.clone(); + let plugin_manager = self.plugin_manager.clone(); + + // GET /api/config + let get_config = warp::path!("api" / "config") + .and(warp::get()) + .and_then(move || { + let cfg = config.clone(); + async move { + handle_get_config(cfg).await + } + }); + + let config2 = self.config.clone(); + + // PUT /api/config + let update_config = warp::path!("api" / "config") + .and(warp::put()) + .and(warp::body::json()) + .and_then(move |new_config: RouterConfig| { + let cfg = config2.clone(); + async move { + handle_update_config(cfg, new_config).await + } + }); + + let config3 = 
self.config.clone(); + + // POST /api/config/reload + let reload_config = warp::path!("api" / "config" / "reload") + .and(warp::post()) + .and_then(move || { + let cfg = config3.clone(); + async move { + handle_reload_config(cfg).await + } + }); + + let config4 = self.config.clone(); + + // POST /api/routes + let create_route = warp::path!("api" / "routes") + .and(warp::post()) + .and(warp::body::json()) + .and_then(move |req: RouteCreateRequest| { + let cfg = config4.clone(); + async move { + handle_create_route(cfg, req).await + } + }); + + let config5 = self.config.clone(); + + // DELETE /api/routes/{domain} + let delete_route = warp::path!("api" / "routes" / String) + .and(warp::delete()) + .and_then(move |domain: String| { + let cfg = config5.clone(); + async move { + handle_delete_route(cfg, domain).await + } + }); + + get_config + .or(update_config) + .or(reload_config) + .or(create_route) + .or(delete_route) + } + + fn health_routes(&self) -> impl Filter + Clone { + let plugin_manager = self.plugin_manager.clone(); + + // GET /api/health + let overall_health = warp::path!("api" / "health") + .and(warp::get()) + .and_then(move || { + let pm = plugin_manager.clone(); + async move { + handle_get_overall_health(pm).await + } + }); + + let plugin_manager2 = self.plugin_manager.clone(); + + // GET /api/health/plugins + let plugin_health = warp::path!("api" / "health" / "plugins") + .and(warp::get()) + .and_then(move || { + let pm = plugin_manager2.clone(); + async move { + handle_get_plugin_health(pm).await + } + }); + + overall_health.or(plugin_health) + } + + fn metrics_routes(&self) -> impl Filter + Clone { + let plugin_manager = self.plugin_manager.clone(); + + // GET /api/metrics + let get_metrics = warp::path!("api" / "metrics") + .and(warp::get()) + .and_then(move || { + let pm = plugin_manager.clone(); + async move { + handle_get_metrics(pm).await + } + }); + + let plugin_manager2 = self.plugin_manager.clone(); + + // GET /api/metrics/prometheus + let 
prometheus_metrics = warp::path!("api" / "metrics" / "prometheus") + .and(warp::get()) + .and_then(move || { + let pm = plugin_manager2.clone(); + async move { + handle_get_prometheus_metrics(pm).await + } + }); + + get_metrics.or(prometheus_metrics) + } +} + +// Handler implementations +async fn handle_get_system_status( + plugin_manager: Arc, + config: Arc>, +) -> Result { + let plugins = plugin_manager.list_plugins(); + let total_plugins = plugins.len(); + let running_plugins = total_plugins; // Simplified - in practice you'd check each plugin status + + let status = SystemStatus { + uptime_seconds: 0, // Would track actual uptime + version: env!("CARGO_PKG_VERSION").to_string(), + total_plugins, + running_plugins, + total_connections: 0, // Would get from metrics + memory_usage_mb: get_memory_usage(), + cpu_usage_percent: get_cpu_usage(), + }; + + Ok(warp::reply::json(&ApiResponse::success( + "System status retrieved", + Some(status), + ))) +} + +async fn handle_system_shutdown() -> Result { + // In a real implementation, this would trigger graceful shutdown + tracing::warn!("System shutdown requested via API"); + + Ok(warp::reply::json(&ApiResponse::<()>::success( + "Shutdown initiated", + None, + ))) +} + +async fn handle_list_plugins( + plugin_manager: Arc, +) -> Result { + let plugins = plugin_manager.list_plugins(); + let mut plugin_infos = Vec::new(); + + for plugin_name in plugins { + if let Some(info) = plugin_manager.get_plugin_info(&plugin_name).await { + plugin_infos.push(info); + } + } + + Ok(warp::reply::json(&ApiResponse::success( + "Plugins listed successfully", + Some(plugin_infos), + ))) +} + +async fn handle_get_plugin( + plugin_manager: Arc, + name: String, +) -> Result { + match plugin_manager.get_plugin_info(&name).await { + Some(info) => { + let health = plugin_manager.get_plugin_health(&name).await; + let metrics = plugin_manager.get_plugin_metrics(&name).await; + + let plugin_details = serde_json::json!({ + "info": info, + "health": 
health, + "metrics": metrics, + }); + + Ok(warp::reply::json(&ApiResponse::success( + &format!("Plugin '{}' details", name), + Some(plugin_details), + ))) + } + None => Ok(warp::reply::json(&ApiResponse::<()>::error( + &format!("Plugin '{}' not found", name), + None, + ))), + } +} + +async fn handle_reload_plugin( + plugin_manager: Arc, + name: String, +) -> Result { + match plugin_manager.reload_plugin(&name).await { + Ok(_) => Ok(warp::reply::json(&ApiResponse::<()>::success( + &format!("Plugin '{}' reloaded successfully", name), + None, + ))), + Err(e) => Ok(warp::reply::json(&ApiResponse::<()>::error( + &format!("Failed to reload plugin '{}'", name), + Some(e.to_string()), + ))), + } +} + +async fn handle_update_plugin_config( + plugin_manager: Arc, + name: String, + req: PluginConfigUpdateRequest, +) -> Result { + // Get current plugin config and update it + if let Some(mut current_config) = plugin_manager.registry().get_plugin_config(&name).await { + current_config.config = req.config; + + match plugin_manager.update_plugin_config(&name, current_config).await { + Ok(_) => Ok(warp::reply::json(&ApiResponse::<()>::success( + &format!("Plugin '{}' configuration updated", name), + None, + ))), + Err(e) => Ok(warp::reply::json(&ApiResponse::<()>::error( + &format!("Failed to update plugin '{}' configuration", name), + Some(e.to_string()), + ))), + } + } else { + Ok(warp::reply::json(&ApiResponse::<()>::error( + &format!("Plugin '{}' not found", name), + None, + ))) + } +} + +async fn handle_plugin_command( + plugin_manager: Arc, + name: String, + body: serde_json::Value, +) -> Result { + let command = body.get("command").and_then(|v| v.as_str()).unwrap_or("status"); + let args = body.get("args").cloned().unwrap_or(serde_json::json!({})); + + match plugin_manager.execute_plugin_command(&name, command, args).await { + Ok(result) => Ok(warp::reply::json(&ApiResponse::success( + &format!("Command '{}' executed on plugin '{}'", command, name), + Some(result), + ))), + 
Err(e) => Ok(warp::reply::json(&ApiResponse::<()>::error( + &format!("Failed to execute command '{}' on plugin '{}'", command, name), + Some(e.to_string()), + ))), + } +} + +async fn handle_get_config(config: Arc>) -> Result { + let cfg = config.read().await.clone(); + Ok(warp::reply::json(&ApiResponse::success( + "Configuration retrieved", + Some(cfg), + ))) +} + +async fn handle_update_config( + config: Arc>, + new_config: RouterConfig, +) -> Result { + { + let mut cfg = config.write().await; + *cfg = new_config; + } + + Ok(warp::reply::json(&ApiResponse::<()>::success( + "Configuration updated successfully", + None, + ))) +} + +async fn handle_reload_config(config: Arc>) -> Result { + // In a real implementation, this would reload from file + Ok(warp::reply::json(&ApiResponse::<()>::success( + "Configuration reloaded successfully", + None, + ))) +} + +async fn handle_create_route( + config: Arc>, + req: RouteCreateRequest, +) -> Result { + let mut cfg = config.write().await; + + let domain_config = req.config.unwrap_or(crate::config::DomainConfig { + proxy_instance: req.proxy_instance, + backend_service: None, + ssl_config: None, + cors_config: None, + cache_config: None, + custom_headers: HashMap::new(), + rewrite_rules: vec![], + access_control: None, + }); + + cfg.domains.insert(req.domain.clone(), domain_config); + + Ok(warp::reply::json(&ApiResponse::<()>::success( + &format!("Route created for domain '{}'", req.domain), + None, + ))) +} + +async fn handle_delete_route( + config: Arc>, + domain: String, +) -> Result { + let mut cfg = config.write().await; + + match cfg.domains.remove(&domain) { + Some(_) => Ok(warp::reply::json(&ApiResponse::<()>::success( + &format!("Route deleted for domain '{}'", domain), + None, + ))), + None => Ok(warp::reply::json(&ApiResponse::<()>::error( + &format!("Route not found for domain '{}'", domain), + None, + ))), + } +} + +async fn handle_get_overall_health( + plugin_manager: Arc, +) -> Result { + let plugins = 
plugin_manager.list_plugins(); + let mut all_healthy = true; + let mut health_details = HashMap::new(); + + for plugin_name in plugins { + if let Some(health) = plugin_manager.get_plugin_health(&plugin_name).await { + health_details.insert(plugin_name.clone(), health.healthy); + if !health.healthy { + all_healthy = false; + } + } + } + + let overall_health = serde_json::json!({ + "healthy": all_healthy, + "plugins": health_details, + "timestamp": chrono::Utc::now().to_rfc3339(), + }); + + Ok(warp::reply::json(&ApiResponse::success( + "Overall health status", + Some(overall_health), + ))) +} + +async fn handle_get_plugin_health( + plugin_manager: Arc, +) -> Result { + let plugins = plugin_manager.list_plugins(); + let mut health_status = HashMap::new(); + + for plugin_name in plugins { + if let Some(health) = plugin_manager.get_plugin_health(&plugin_name).await { + health_status.insert(plugin_name, health); + } + } + + Ok(warp::reply::json(&ApiResponse::success( + "Plugin health status", + Some(health_status), + ))) +} + +async fn handle_get_metrics( + plugin_manager: Arc, +) -> Result { + let metrics = plugin_manager.get_all_metrics().await; + + Ok(warp::reply::json(&ApiResponse::success( + "Metrics retrieved", + Some(metrics), + ))) +} + +async fn handle_get_prometheus_metrics( + plugin_manager: Arc, +) -> Result { + let metrics = plugin_manager.get_all_metrics().await; + + // Convert to Prometheus format + let mut prometheus_output = String::new(); + + for (plugin_name, plugin_metrics) in metrics { + prometheus_output.push_str(&format!( + "# HELP router_plugin_connections_total Total connections for plugin\n" + )); + prometheus_output.push_str(&format!( + "# TYPE router_plugin_connections_total counter\n" + )); + prometheus_output.push_str(&format!( + "router_plugin_connections_total{{plugin=\"{}\"}} {}\n", + plugin_name, plugin_metrics.connections_total + )); + + prometheus_output.push_str(&format!( + "router_plugin_connections_active{{plugin=\"{}\"}} {}\n", + 
plugin_name, plugin_metrics.connections_active + )); + + prometheus_output.push_str(&format!( + "router_plugin_bytes_sent_total{{plugin=\"{}\"}} {}\n", + plugin_name, plugin_metrics.bytes_sent + )); + + prometheus_output.push_str(&format!( + "router_plugin_bytes_received_total{{plugin=\"{}\"}} {}\n", + plugin_name, plugin_metrics.bytes_received + )); + + prometheus_output.push_str(&format!( + "router_plugin_errors_total{{plugin=\"{}\"}} {}\n", + plugin_name, plugin_metrics.errors_total + )); + } + + Ok(warp::reply::with_header( + prometheus_output, + "content-type", + "text/plain; version=0.0.4", + )) +} + +// Error handling +async fn handle_rejection(err: Rejection) -> Result { + let (code, message) = if err.is_not_found() { + (StatusCode::NOT_FOUND, "Not Found") + } else if let Some(_) = err.find::() { + (StatusCode::BAD_REQUEST, "Invalid JSON body") + } else if let Some(_) = err.find::() { + (StatusCode::METHOD_NOT_ALLOWED, "Method Not Allowed") + } else { + (StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error") + }; + + let json = warp::reply::json(&ApiResponse::<()>::error(message, None)); + Ok(warp::reply::with_status(json, code)) +} + +// System utilities +fn get_memory_usage() -> f64 { + // Simplified memory usage - in practice use proper system monitoring + #[cfg(target_os = "linux")] + { + if let Ok(process) = procfs::process::Process::myself() { + if let Ok(stat) = process.stat() { + return (stat.rss * 4096) as f64 / 1024.0 / 1024.0; // Convert to MB + } + } + } + 0.0 +} + +fn get_cpu_usage() -> f64 { + // Simplified CPU usage - in practice use proper system monitoring + 0.0 +} \ No newline at end of file diff --git a/src/metrics.rs b/src/metrics.rs index b25c75d..6f51ba9 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -1,8 +1,740 @@ -use anyhow::Result; -use metrics_exporter_prometheus::PrometheusBuilder; - -pub fn init_metrics() -> Result<()> { - let builder = PrometheusBuilder::new(); - builder.install()?; - Ok(()) -} +// src/metrics.rs - 
Comprehensive metrics collection system +use crate::config::MonitoringConfig; +use crate::error::{RouterError, Result}; +use crate::plugin::{PluginMetrics, RouterEvent}; + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::{Duration, Instant, SystemTime}; +use tokio::sync::{mpsc, RwLock}; +use tokio::time::interval; +use tracing::{debug, error, info, warn}; + +/// Main metrics collector +#[derive(Debug)] +pub struct MetricsCollector { + config: MonitoringConfig, + registry: Arc, + exporters: Vec>, + shutdown_sender: Option>, + running: Arc>, + start_time: Instant, +} + +/// Central metrics registry +#[derive(Debug)] +pub struct MetricsRegistry { + counters: RwLock>>, + gauges: RwLock>>, + histograms: RwLock>>, + custom_metrics: RwLock>, + labels: RwLock>>, +} + +/// Counter metric - monotonically increasing value +#[derive(Debug)] +pub struct Counter { + name: String, + value: RwLock, + labels: HashMap, + created_at: SystemTime, +} + +/// Gauge metric - value that can go up and down +#[derive(Debug)] +pub struct Gauge { + name: String, + value: RwLock, + labels: HashMap, + created_at: SystemTime, +} + +/// Histogram metric - distribution of values +#[derive(Debug)] +pub struct Histogram { + name: String, + buckets: Vec, + counts: RwLock>, + sum: RwLock, + count: RwLock, + labels: HashMap, + created_at: SystemTime, +} + +/// Metrics snapshot for export +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MetricsSnapshot { + pub timestamp: SystemTime, + pub counters: HashMap, + pub gauges: HashMap, + pub histograms: HashMap, + pub custom_metrics: HashMap, + pub system_metrics: SystemMetrics, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CounterValue { + pub value: u64, + pub labels: HashMap, + pub created_at: SystemTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GaugeValue { + pub value: f64, + pub labels: HashMap, + pub created_at: SystemTime, +} + 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HistogramValue { + pub buckets: Vec, + pub counts: Vec, + pub sum: f64, + pub count: u64, + pub labels: HashMap, + pub created_at: SystemTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SystemMetrics { + pub uptime_seconds: u64, + pub memory_usage_bytes: u64, + pub cpu_usage_percent: f64, + pub goroutines: u64, + pub open_file_descriptors: u64, + pub network_connections: u64, +} + +/// Trait for metrics exporters +pub trait MetricsExporter: Send + Sync + std::fmt::Debug { + fn export(&self, snapshot: &MetricsSnapshot) -> Result<()>; + fn name(&self) -> &str; + fn shutdown(&self) -> Result<()>; +} + +/// Prometheus metrics exporter +#[derive(Debug)] +pub struct PrometheusExporter { + endpoint: String, + push_gateway: Option, + job_name: String, + client: reqwest::Client, +} + +/// JSON metrics exporter +#[derive(Debug)] +pub struct JsonExporter { + output_path: std::path::PathBuf, + pretty_print: bool, +} + +/// StatsD metrics exporter +#[derive(Debug)] +pub struct StatsdExporter { + address: String, + prefix: String, + client: Option, +} + +impl MetricsCollector { + pub fn new(config: MonitoringConfig) -> Self { + Self { + config, + registry: Arc::new(MetricsRegistry::new()), + exporters: Vec::new(), + shutdown_sender: None, + running: Arc::new(RwLock::new(false)), + start_time: Instant::now(), + } + } + + pub fn with_prometheus_exporter(mut self, endpoint: String) -> Self { + let exporter = PrometheusExporter::new(endpoint); + self.exporters.push(Box::new(exporter)); + self + } + + pub fn with_json_exporter(mut self, output_path: std::path::PathBuf) -> Self { + let exporter = JsonExporter::new(output_path); + self.exporters.push(Box::new(exporter)); + self + } + + pub fn with_statsd_exporter(mut self, address: String, prefix: String) -> Self { + let exporter = StatsdExporter::new(address, prefix); + self.exporters.push(Box::new(exporter)); + self + } + + pub async fn start(&self) 
-> Result<()> { + *self.running.write().await = true; + + let (shutdown_tx, mut shutdown_rx) = mpsc::channel(1); + // self.shutdown_sender = Some(shutdown_tx); // This would need to be mutable + + // Register default metrics + self.register_default_metrics().await?; + + // Start collection loop + let registry = self.registry.clone(); + let exporters = self.exporters.iter().map(|e| e.name().to_string()).collect::>(); + let running = self.running.clone(); + let collection_interval = Duration::from_secs(self.config.metrics_collection_interval_seconds); + + tokio::spawn(async move { + let mut interval = interval(collection_interval); + + while *running.read().await { + tokio::select! { + _ = interval.tick() => { + if let Err(e) = Self::collect_and_export(®istry, &exporters).await { + error!("Metrics collection error: {}", e); + } + } + _ = shutdown_rx.recv() => { + info!("Shutting down metrics collector"); + break; + } + } + } + }); + + info!("Metrics collector started with {} exporters", self.exporters.len()); + Ok(()) + } + + pub async fn stop(&self) -> Result<()> { + *self.running.write().await = false; + + if let Some(sender) = &self.shutdown_sender { + let _ = sender.send(()).await; + } + + // Shutdown exporters + for exporter in &self.exporters { + if let Err(e) = exporter.shutdown() { + error!("Error shutting down exporter {}: {}", exporter.name(), e); + } + } + + info!("Metrics collector stopped"); + Ok(()) + } + + pub fn registry(&self) -> Arc { + self.registry.clone() + } + + pub async fn get_snapshot(&self) -> MetricsSnapshot { + self.registry.snapshot().await + } + + pub async fn handle_router_event(&self, event: RouterEvent) { + match event { + RouterEvent::PluginStarted(name) => { + self.registry.increment_counter(&format!("plugin.{}.started", name), None).await; + } + RouterEvent::PluginStopped(name) => { + self.registry.increment_counter(&format!("plugin.{}.stopped", name), None).await; + } + RouterEvent::PluginError { plugin_name, .. 
} => { + self.registry.increment_counter(&format!("plugin.{}.errors", plugin_name), None).await; + } + RouterEvent::MetricsUpdated { plugin_name, metrics } => { + self.update_plugin_metrics(&plugin_name, &metrics).await; + } + _ => {} + } + } + + async fn register_default_metrics(&self) -> Result<()> { + // System metrics + self.registry.create_gauge("system.uptime_seconds", None).await?; + self.registry.create_gauge("system.memory_usage_bytes", None).await?; + self.registry.create_gauge("system.cpu_usage_percent", None).await?; + self.registry.create_gauge("system.open_file_descriptors", None).await?; + + // Router metrics + self.registry.create_counter("router.requests_total", None).await?; + self.registry.create_counter("router.responses_total", None).await?; + self.registry.create_counter("router.errors_total", None).await?; + self.registry.create_gauge("router.active_connections", None).await?; + + // Create histogram for request duration + let buckets = vec![0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0]; + self.registry.create_histogram("router.request_duration_seconds", buckets, None).await?; + + info!("Default metrics registered"); + Ok(()) + } + + async fn update_plugin_metrics(&self, plugin_name: &str, metrics: &PluginMetrics) { + let labels = Some(vec![("plugin".to_string(), plugin_name.to_string())].into_iter().collect()); + + // Update plugin-specific metrics + self.registry.set_gauge(&format!("plugin.connections_total"), metrics.connections_total as f64, labels.clone()).await; + self.registry.set_gauge(&format!("plugin.connections_active"), metrics.connections_active as f64, labels.clone()).await; + self.registry.set_gauge(&format!("plugin.bytes_sent"), metrics.bytes_sent as f64, labels.clone()).await; + self.registry.set_gauge(&format!("plugin.bytes_received"), metrics.bytes_received as f64, labels.clone()).await; + self.registry.set_gauge(&format!("plugin.errors_total"), metrics.errors_total as f64, labels.clone()).await; + + 
// Update custom metrics + for (key, value) in &metrics.custom_metrics { + self.registry.set_gauge(&format!("plugin.custom.{}", key), *value, labels.clone()).await; + } + } + + async fn collect_and_export( + registry: &MetricsRegistry, + _exporters: &[String], + ) -> Result<()> { + // Update system metrics + registry.set_gauge("system.uptime_seconds", get_uptime_seconds(), None).await; + registry.set_gauge("system.memory_usage_bytes", get_memory_usage_bytes(), None).await; + registry.set_gauge("system.cpu_usage_percent", get_cpu_usage_percent(), None).await; + registry.set_gauge("system.open_file_descriptors", get_open_file_descriptors(), None).await; + + // Create snapshot + let snapshot = registry.snapshot().await; + + // Export to all configured exporters + // Note: In the real implementation, you'd iterate over actual exporter instances + debug!("Collected metrics: {} counters, {} gauges, {} histograms", + snapshot.counters.len(), + snapshot.gauges.len(), + snapshot.histograms.len()); + + Ok(()) + } +} + +impl MetricsRegistry { + pub fn new() -> Self { + Self { + counters: RwLock::new(HashMap::new()), + gauges: RwLock::new(HashMap::new()), + histograms: RwLock::new(HashMap::new()), + custom_metrics: RwLock::new(HashMap::new()), + labels: RwLock::new(HashMap::new()), + } + } + + pub async fn create_counter(&self, name: &str, labels: Option>) -> Result<()> { + let counter = Arc::new(Counter::new(name, labels.unwrap_or_default())); + let mut counters = self.counters.write().await; + counters.insert(name.to_string(), counter); + Ok(()) + } + + pub async fn create_gauge(&self, name: &str, labels: Option>) -> Result<()> { + let gauge = Arc::new(Gauge::new(name, labels.unwrap_or_default())); + let mut gauges = self.gauges.write().await; + gauges.insert(name.to_string(), gauge); + Ok(()) + } + + pub async fn create_histogram(&self, name: &str, buckets: Vec, labels: Option>) -> Result<()> { + let histogram = Arc::new(Histogram::new(name, buckets, 
labels.unwrap_or_default())); + let mut histograms = self.histograms.write().await; + histograms.insert(name.to_string(), histogram); + Ok(()) + } + + pub async fn increment_counter(&self, name: &str, labels: Option>) { + let counters = self.counters.read().await; + if let Some(counter) = counters.get(name) { + counter.increment().await; + } else { + // Auto-create counter if it doesn't exist + drop(counters); + let _ = self.create_counter(name, labels).await; + let counters = self.counters.read().await; + if let Some(counter) = counters.get(name) { + counter.increment().await; + } + } + } + + pub async fn add_to_counter(&self, name: &str, value: u64, _labels: Option>) { + let counters = self.counters.read().await; + if let Some(counter) = counters.get(name) { + counter.add(value).await; + } + } + + pub async fn set_gauge(&self, name: &str, value: f64, labels: Option>) { + let gauges = self.gauges.read().await; + if let Some(gauge) = gauges.get(name) { + gauge.set(value).await; + } else { + // Auto-create gauge if it doesn't exist + drop(gauges); + let _ = self.create_gauge(name, labels).await; + let gauges = self.gauges.read().await; + if let Some(gauge) = gauges.get(name) { + gauge.set(value).await; + } + } + } + + pub async fn observe_histogram(&self, name: &str, value: f64, _labels: Option>) { + let histograms = self.histograms.read().await; + if let Some(histogram) = histograms.get(name) { + histogram.observe(value).await; + } + } + + pub async fn set_custom_metric(&self, name: &str, value: serde_json::Value) { + let mut custom_metrics = self.custom_metrics.write().await; + custom_metrics.insert(name.to_string(), value); + } + + pub async fn snapshot(&self) -> MetricsSnapshot { + let counters = self.counters.read().await; + let gauges = self.gauges.read().await; + let histograms = self.histograms.read().await; + let custom_metrics = self.custom_metrics.read().await; + + let mut counter_values = HashMap::new(); + for (name, counter) in counters.iter() { + 
counter_values.insert(name.clone(), counter.snapshot().await); + } + + let mut gauge_values = HashMap::new(); + for (name, gauge) in gauges.iter() { + gauge_values.insert(name.clone(), gauge.snapshot().await); + } + + let mut histogram_values = HashMap::new(); + for (name, histogram) in histograms.iter() { + histogram_values.insert(name.clone(), histogram.snapshot().await); + } + + MetricsSnapshot { + timestamp: SystemTime::now(), + counters: counter_values, + gauges: gauge_values, + histograms: histogram_values, + custom_metrics: custom_metrics.clone(), + system_metrics: SystemMetrics { + uptime_seconds: get_uptime_seconds() as u64, + memory_usage_bytes: get_memory_usage_bytes() as u64, + cpu_usage_percent: get_cpu_usage_percent(), + goroutines: 0, // Not applicable to Rust + open_file_descriptors: get_open_file_descriptors() as u64, + network_connections: get_network_connections(), + }, + } + } +} + +impl Counter { + fn new(name: &str, labels: HashMap) -> Self { + Self { + name: name.to_string(), + value: RwLock::new(0), + labels, + created_at: SystemTime::now(), + } + } + + pub async fn increment(&self) { + let mut value = self.value.write().await; + *value += 1; + } + + pub async fn add(&self, amount: u64) { + let mut value = self.value.write().await; + *value += amount; + } + + pub async fn get(&self) -> u64 { + *self.value.read().await + } + + pub async fn snapshot(&self) -> CounterValue { + CounterValue { + value: self.get().await, + labels: self.labels.clone(), + created_at: self.created_at, + } + } +} + +impl Gauge { + fn new(name: &str, labels: HashMap) -> Self { + Self { + name: name.to_string(), + value: RwLock::new(0.0), + labels, + created_at: SystemTime::now(), + } + } + + pub async fn set(&self, value: f64) { + let mut val = self.value.write().await; + *val = value; + } + + pub async fn add(&self, amount: f64) { + let mut val = self.value.write().await; + *val += amount; + } + + pub async fn get(&self) -> f64 { + *self.value.read().await + } + + pub 
async fn snapshot(&self) -> GaugeValue { + GaugeValue { + value: self.get().await, + labels: self.labels.clone(), + created_at: self.created_at, + } + } +} + +impl Histogram { + fn new(name: &str, buckets: Vec, labels: HashMap) -> Self { + let counts = vec![0u64; buckets.len() + 1]; // +1 for +Inf bucket + Self { + name: name.to_string(), + buckets, + counts: RwLock::new(counts), + sum: RwLock::new(0.0), + count: RwLock::new(0), + labels, + created_at: SystemTime::now(), + } + } + + pub async fn observe(&self, value: f64) { + let mut counts = self.counts.write().await; + let mut sum = self.sum.write().await; + let mut count = self.count.write().await; + + *sum += value; + *count += 1; + + // Find the appropriate bucket + for (i, &bucket) in self.buckets.iter().enumerate() { + if value <= bucket { + counts[i] += 1; + break; + } + } + // Also increment the +Inf bucket + if let Some(last) = counts.last_mut() { + *last += 1; + } + } + + pub async fn snapshot(&self) -> HistogramValue { + HistogramValue { + buckets: self.buckets.clone(), + counts: self.counts.read().await.clone(), + sum: *self.sum.read().await, + count: *self.count.read().await, + labels: self.labels.clone(), + created_at: self.created_at, + } + } +} + +// Exporter implementations +impl PrometheusExporter { + pub fn new(endpoint: String) -> Self { + Self { + endpoint, + push_gateway: None, + job_name: "router".to_string(), + client: reqwest::Client::new(), + } + } + + pub fn with_push_gateway(mut self, gateway: String, job_name: String) -> Self { + self.push_gateway = Some(gateway); + self.job_name = job_name; + self + } +} + +impl MetricsExporter for PrometheusExporter { + fn export(&self, snapshot: &MetricsSnapshot) -> Result<()> { + let mut output = String::new(); + + // Export counters + for (name, counter) in &snapshot.counters { + output.push_str(&format!("# TYPE {} counter\n", name)); + let labels = format_labels(&counter.labels); + output.push_str(&format!("{}{} {}\n", name, labels, 
counter.value)); + } + + // Export gauges + for (name, gauge) in &snapshot.gauges { + output.push_str(&format!("# TYPE {} gauge\n", name)); + let labels = format_labels(&gauge.labels); + output.push_str(&format!("{}{} {}\n", name, labels, gauge.value)); + } + + // Export histograms + for (name, histogram) in &snapshot.histograms { + output.push_str(&format!("# TYPE {} histogram\n", name)); + let labels = format_labels(&histogram.labels); + + // Bucket counts + for (i, &bucket) in histogram.buckets.iter().enumerate() { + let mut bucket_labels = histogram.labels.clone(); + bucket_labels.insert("le".to_string(), bucket.to_string()); + let bucket_labels_str = format_labels(&bucket_labels); + output.push_str(&format!("{}_bucket{} {}\n", name, bucket_labels_str, histogram.counts[i])); + } + + // +Inf bucket + let mut inf_labels = histogram.labels.clone(); + inf_labels.insert("le".to_string(), "+Inf".to_string()); + let inf_labels_str = format_labels(&inf_labels); + if let Some(&inf_count) = histogram.counts.last() { + output.push_str(&format!("{}_bucket{} {}\n", name, inf_labels_str, inf_count)); + } + + // Sum and count + output.push_str(&format!("{}_sum{} {}\n", name, labels, histogram.sum)); + output.push_str(&format!("{}_count{} {}\n", name, labels, histogram.count)); + } + + debug!("Generated Prometheus metrics: {} bytes", output.len()); + + // If push gateway is configured, push metrics + if let Some(_gateway) = &self.push_gateway { + // Implementation would push to Prometheus Push Gateway + debug!("Would push metrics to push gateway"); + } + + Ok(()) + } + + fn name(&self) -> &str { + "prometheus" + } + + fn shutdown(&self) -> Result<()> { + Ok(()) + } +} + +impl JsonExporter { + pub fn new(output_path: std::path::PathBuf) -> Self { + Self { + output_path, + pretty_print: true, + } + } +} + +impl MetricsExporter for JsonExporter { + fn export(&self, snapshot: &MetricsSnapshot) -> Result<()> { + let json = if self.pretty_print { + 
// Utility functions

/// Render a label set in Prometheus exposition form, e.g. `{a="1",b="2"}`.
///
/// Returns an empty string for an empty label set (so metric names with no
/// labels get no braces). Labels are emitted in sorted key order so output is
/// deterministic — the previous implementation iterated the `HashMap`
/// directly, so exposition text could reorder between scrapes of identical
/// state.
///
/// NOTE(review): label values are not escaped; a value containing `"` or `\`
/// would produce invalid exposition text. All current call sites use plain
/// identifiers, but confirm before accepting arbitrary label values.
fn format_labels(labels: &HashMap<String, String>) -> String {
    if labels.is_empty() {
        return String::new();
    }

    let mut pairs: Vec<_> = labels.iter().collect();
    pairs.sort_by(|a, b| a.0.cmp(b.0)); // deterministic output order

    let body = pairs
        .iter()
        .map(|(k, v)| format!("{}=\"{}\"", k, v))
        .collect::<Vec<_>>()
        .join(",");

    format!("{{{}}}", body)
}
b/src/middleware.rs new file mode 100644 index 0000000..e69de29 diff --git a/src/plugin/loader.rs b/src/plugin/loader.rs new file mode 100644 index 0000000..13c7df8 --- /dev/null +++ b/src/plugin/loader.rs @@ -0,0 +1,507 @@ +// src/plugin/loader.rs - Plugin dynamic loader using libloading +use super::*; +use anyhow::{Context, Result}; +use libloading::{Library, Symbol}; +use std::collections::HashMap; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use tokio::sync::RwLock; + +/// Plugin loader configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LoaderConfig { + pub plugin_directories: Vec, + pub auto_reload: bool, + pub reload_interval_seconds: u64, + pub max_plugin_memory_mb: usize, + pub plugin_timeout_seconds: u64, + pub allowed_plugins: Option>, // If None, all plugins allowed + pub blocked_plugins: Vec, + pub require_signature: bool, + pub signature_key_path: Option, +} + +impl Default for LoaderConfig { + fn default() -> Self { + Self { + plugin_directories: vec!["./plugins".to_string()], + auto_reload: false, + reload_interval_seconds: 60, + max_plugin_memory_mb: 100, + plugin_timeout_seconds: 30, + allowed_plugins: None, + blocked_plugins: Vec::new(), + require_signature: false, + signature_key_path: None, + } + } +} + +/// Plugin load information +#[derive(Debug, Clone)] +pub struct LoadedPluginInfo { + pub info: PluginInfo, + pub path: PathBuf, + pub loaded_at: std::time::SystemTime, + pub file_size: u64, + pub file_hash: String, + pub library_handle: String, // Internal reference +} + +/// Plugin loader for dynamic loading from shared libraries +pub struct PluginLoader { + config: LoaderConfig, + loaded_libraries: Arc>>, + loaded_plugins: Arc>>, + file_watchers: Arc>>>, +} + +impl PluginLoader { + pub fn new(config: LoaderConfig) -> Self { + Self { + config, + loaded_libraries: Arc::new(RwLock::new(HashMap::new())), + loaded_plugins: Arc::new(RwLock::new(HashMap::new())), + file_watchers: 
Arc::new(RwLock::new(HashMap::new())), + } + } + + /// Load a plugin from a shared library file + pub async fn load_plugin>(&self, path: P) -> Result<(String, Box)> { + let path = path.as_ref(); + + // Validate plugin file + self.validate_plugin_file(path).await?; + + // Load the library + let lib = unsafe { + Library::new(path) + .with_context(|| format!("Failed to load library: {}", path.display()))? + }; + + // Get plugin info first + let get_info: Symbol = unsafe { + lib.get(b"get_plugin_info") + .with_context(|| "Plugin must export 'get_plugin_info' function")? + }; + + let plugin_info = unsafe { get_info() }; + + // Validate plugin info + self.validate_plugin_info(&plugin_info)?; + + // Check if plugin is allowed + if !self.is_plugin_allowed(&plugin_info.name) { + return Err(anyhow::anyhow!("Plugin '{}' is not allowed", plugin_info.name)); + } + + // Get the plugin factory function + let create_plugin: Symbol = unsafe { + lib.get(b"create_plugin") + .with_context(|| "Plugin must export 'create_plugin' function")? 
+ }; + + // Create the plugin instance + let plugin_ptr = unsafe { create_plugin() }; + let plugin = unsafe { Box::from_raw(plugin_ptr) }; + + // Calculate file hash for change detection + let file_hash = self.calculate_file_hash(path).await?; + + // Store library to prevent unloading + let library_handle = format!("{}_{}", plugin_info.name, plugin_info.version); + { + let mut libraries = self.loaded_libraries.write().await; + libraries.insert(library_handle.clone(), lib); + } + + // Store plugin info + let loaded_info = LoadedPluginInfo { + info: plugin_info.clone(), + path: path.to_path_buf(), + loaded_at: std::time::SystemTime::now(), + file_size: std::fs::metadata(path)?.len(), + file_hash, + library_handle, + }; + + { + let mut loaded_plugins = self.loaded_plugins.write().await; + loaded_plugins.insert(plugin_info.name.clone(), loaded_info); + } + + tracing::info!( + "Loaded plugin '{}' v{} from {}", + plugin_info.name, + plugin_info.version, + path.display() + ); + + Ok((plugin_info.name, plugin)) + } + + /// Unload a plugin + pub async fn unload_plugin(&self, plugin_name: &str) -> Result<()> { + let loaded_info = { + let mut loaded_plugins = self.loaded_plugins.write().await; + loaded_plugins.remove(plugin_name) + .ok_or_else(|| anyhow::anyhow!("Plugin '{}' is not loaded", plugin_name))? 
+ }; + + // Remove library + { + let mut libraries = self.loaded_libraries.write().await; + libraries.remove(&loaded_info.library_handle); + } + + // Stop file watcher if exists + { + let mut watchers = self.file_watchers.write().await; + if let Some(handle) = watchers.remove(&loaded_info.path) { + handle.abort(); + } + } + + tracing::info!("Unloaded plugin '{}'", plugin_name); + Ok(()) + } + + /// Load all plugins from configured directories + pub async fn load_all_plugins(&self) -> Result>> { + let mut loaded_plugins = HashMap::new(); + + for directory in &self.config.plugin_directories { + let dir_path = Path::new(directory); + + if !dir_path.exists() { + tracing::warn!("Plugin directory does not exist: {}", directory); + continue; + } + + let mut entries = tokio::fs::read_dir(dir_path).await + .with_context(|| format!("Failed to read plugin directory: {}", directory))?; + + while let Some(entry) = entries.next_entry().await? { + let path = entry.path(); + + // Check if it's a shared library + if self.is_plugin_file(&path) { + match self.load_plugin(&path).await { + Ok((name, plugin)) => { + loaded_plugins.insert(name, plugin); + } + Err(e) => { + tracing::error!("Failed to load plugin from {}: {}", path.display(), e); + } + } + } + } + } + + if self.config.auto_reload { + self.start_auto_reload().await?; + } + + tracing::info!("Loaded {} plugins", loaded_plugins.len()); + Ok(loaded_plugins) + } + + /// Start auto-reload watcher for plugin files + pub async fn start_auto_reload(&self) -> Result<()> { + if !self.config.auto_reload { + return Ok(()); + } + + let loaded_plugins = self.loaded_plugins.read().await.clone(); + let reload_interval = self.config.reload_interval_seconds; + + for (_name, loaded_info) in loaded_plugins { + let path = loaded_info.path.clone(); + let current_hash = loaded_info.file_hash.clone(); + let loader = self.clone_for_reload(); + + let handle = tokio::spawn(async move { + let mut interval = tokio::time::interval( + 
std::time::Duration::from_secs(reload_interval) + ); + + let mut last_hash = current_hash; + + loop { + interval.tick().await; + + // Check if file has changed + match loader.calculate_file_hash(&path).await { + Ok(new_hash) => { + if new_hash != last_hash { + tracing::info!("Plugin file changed, reloading: {}", path.display()); + + // Reload plugin (this would need to be implemented) + // For now, just log the change + last_hash = new_hash; + } + } + Err(e) => { + tracing::error!("Error checking plugin file {}: {}", path.display(), e); + break; + } + } + } + }); + + self.file_watchers.write().await.insert(path, handle); + } + + tracing::info!("Started auto-reload for {} plugins", loaded_plugins.len()); + Ok(()) + } + + /// Stop auto-reload watchers + pub async fn stop_auto_reload(&self) { + let mut watchers = self.file_watchers.write().await; + for (_path, handle) in watchers.drain() { + handle.abort(); + } + tracing::info!("Stopped auto-reload watchers"); + } + + /// Get information about loaded plugins + pub async fn get_loaded_plugins(&self) -> HashMap { + self.loaded_plugins.read().await.clone() + } + + /// Check if a plugin file is valid + async fn validate_plugin_file>(&self, path: P) -> Result<()> { + let path = path.as_ref(); + + // Check file exists + if !path.exists() { + return Err(anyhow::anyhow!("Plugin file does not exist: {}", path.display())); + } + + // Check file extension + if !self.is_plugin_file(path) { + return Err(anyhow::anyhow!("Invalid plugin file extension: {}", path.display())); + } + + // Check file size + let metadata = std::fs::metadata(path) + .with_context(|| format!("Failed to get metadata for: {}", path.display()))?; + + let max_size = (self.config.max_plugin_memory_mb * 1024 * 1024) as u64; + if metadata.len() > max_size { + return Err(anyhow::anyhow!( + "Plugin file too large: {} bytes (max: {} bytes)", + metadata.len(), + max_size + )); + } + + // Check file permissions + if metadata.permissions().readonly() { + 
tracing::warn!("Plugin file is read-only: {}", path.display()); + } + + // TODO: Check digital signature if required + if self.config.require_signature { + self.verify_plugin_signature(path).await?; + } + + Ok(()) + } + + /// Validate plugin information + fn validate_plugin_info(&self, info: &PluginInfo) -> Result<()> { + if info.name.is_empty() { + return Err(anyhow::anyhow!("Plugin name cannot be empty")); + } + + if info.version.is_empty() { + return Err(anyhow::anyhow!("Plugin version cannot be empty")); + } + + // Validate version format (basic semver check) + if !info.version.chars().any(|c| c.is_ascii_digit()) { + return Err(anyhow::anyhow!("Invalid version format: {}", info.version)); + } + + // Check minimum router version compatibility + // TODO: Implement version comparison + + Ok(()) + } + + /// Check if plugin is allowed to load + fn is_plugin_allowed(&self, plugin_name: &str) -> bool { + // Check blocked list first + if self.config.blocked_plugins.contains(&plugin_name.to_string()) { + return false; + } + + // Check allowed list if specified + if let Some(ref allowed) = self.config.allowed_plugins { + allowed.contains(&plugin_name.to_string()) + } else { + true // All plugins allowed if no whitelist + } + } + + /// Check if file is a plugin file based on extension + fn is_plugin_file>(&self, path: P) -> bool { + path.as_ref() + .extension() + .and_then(|ext| ext.to_str()) + .map(|ext| ext == "so" || ext == "dylib" || ext == "dll") + .unwrap_or(false) + } + + /// Calculate file hash for change detection + async fn calculate_file_hash>(&self, path: P) -> Result { + use sha2::{Sha256, Digest}; + + let content = tokio::fs::read(path.as_ref()).await + .with_context(|| format!("Failed to read file: {}", path.as_ref().display()))?; + + let mut hasher = Sha256::new(); + hasher.update(&content); + let hash = hasher.finalize(); + + Ok(format!("{:x}", hash)) + } + + /// Verify plugin digital signature + async fn verify_plugin_signature>(&self, _path: P) -> 
Result<()> { + // TODO: Implement digital signature verification + // This would typically involve: + // 1. Reading the signature from the plugin file or separate .sig file + // 2. Verifying the signature against the plugin content + // 3. Checking against a trusted public key + + tracing::warn!("Plugin signature verification not implemented"); + Ok(()) + } + + /// Create a clone for reload operations (simplified) + fn clone_for_reload(&self) -> PluginLoader { + PluginLoader { + config: self.config.clone(), + loaded_libraries: self.loaded_libraries.clone(), + loaded_plugins: self.loaded_plugins.clone(), + file_watchers: self.file_watchers.clone(), + } + } + + /// Get statistics about loaded plugins + pub async fn get_statistics(&self) -> PluginLoaderStatistics { + let loaded_plugins = self.loaded_plugins.read().await; + let libraries = self.loaded_libraries.read().await; + let watchers = self.file_watchers.read().await; + + let total_file_size: u64 = loaded_plugins.values() + .map(|info| info.file_size) + .sum(); + + PluginLoaderStatistics { + total_plugins: loaded_plugins.len(), + total_libraries: libraries.len(), + active_watchers: watchers.len(), + total_file_size_bytes: total_file_size, + directories_watched: self.config.plugin_directories.len(), + auto_reload_enabled: self.config.auto_reload, + } + } + + /// Cleanup and shutdown + pub async fn shutdown(&self) -> Result<()> { + tracing::info!("Shutting down plugin loader..."); + + // Stop auto-reload + self.stop_auto_reload().await; + + // Unload all libraries + { + let mut libraries = self.loaded_libraries.write().await; + libraries.clear(); + } + + // Clear loaded plugins info + { + let mut loaded_plugins = self.loaded_plugins.write().await; + loaded_plugins.clear(); + } + + tracing::info!("Plugin loader shut down"); + Ok(()) + } +} + +/// Statistics about the plugin loader +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PluginLoaderStatistics { + pub total_plugins: usize, + pub total_libraries: 
usize, + pub active_watchers: usize, + pub total_file_size_bytes: u64, + pub directories_watched: usize, + pub auto_reload_enabled: bool, +} + +/// Plugin scanner for discovering plugins +pub struct PluginScanner { + directories: Vec, +} + +impl PluginScanner { + pub fn new(directories: Vec) -> Self { + Self { + directories: directories.into_iter().map(PathBuf::from).collect(), + } + } + + /// Scan for plugin files + pub async fn scan(&self) -> Result> { + let mut plugin_files = Vec::new(); + + for directory in &self.directories { + if !directory.exists() { + continue; + } + + let mut entries = tokio::fs::read_dir(directory).await?; + + while let Some(entry) = entries.next_entry().await? { + let path = entry.path(); + + if self.is_plugin_file(&path) { + plugin_files.push(path); + } + } + } + + Ok(plugin_files) + } + + /// Get plugin metadata without loading + pub async fn get_plugin_metadata>(&self, path: P) -> Result { + let path = path.as_ref(); + + // This is a simplified version - in reality you'd need to load the library + // temporarily just to get the metadata + let lib = unsafe { Library::new(path)? }; + + let get_info: Symbol = unsafe { + lib.get(b"get_plugin_info")? 
+ }; + + let info = unsafe { get_info() }; + Ok(info) + } + + fn is_plugin_file>(&self, path: P) -> bool { + path.as_ref() + .extension() + .and_then(|ext| ext.to_str()) + .map(|ext| ext == "so" || ext == "dylib" || ext == "dll") + .unwrap_or(false) + } +} \ No newline at end of file diff --git a/src/plugin/manager.rs b/src/plugin/manager.rs new file mode 100644 index 0000000..056a35a --- /dev/null +++ b/src/plugin/manager.rs @@ -0,0 +1,563 @@ +// src/plugin/manager.rs - Plugin manager coordinating loader and registry +use super::*; +use super::loader::{PluginLoader, LoaderConfig}; +use super::registry::PluginRegistry; +use anyhow::{Context, Result}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::{mpsc, RwLock}; + +/// Plugin manager configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PluginManagerConfig { + pub loader: LoaderConfig, + pub enable_inter_plugin_communication: bool, + pub max_concurrent_loads: usize, + pub plugin_startup_timeout_seconds: u64, + pub health_check_interval_seconds: u64, + pub metrics_collection_interval_seconds: u64, + pub enable_plugin_isolation: bool, + pub default_plugin_config: serde_json::Value, +} + +impl Default for PluginManagerConfig { + fn default() -> Self { + Self { + loader: LoaderConfig::default(), + enable_inter_plugin_communication: true, + max_concurrent_loads: 10, + plugin_startup_timeout_seconds: 30, + health_check_interval_seconds: 30, + metrics_collection_interval_seconds: 60, + enable_plugin_isolation: true, + default_plugin_config: serde_json::json!({}), + } + } +} + +/// Main plugin manager that coordinates everything +pub struct PluginManager { + config: PluginManagerConfig, + loader: Arc, + registry: Arc, + router_event_sender: mpsc::UnboundedSender, + router_event_receiver: Arc>>>, + health_monitor: Arc, + metrics_collector: Arc, + event_bus: Arc, + running: Arc>, +} + +impl PluginManager { + pub fn new(config: PluginManagerConfig) -> Self { + let (router_tx, 
router_rx) = mpsc::unbounded_channel(); + + let loader = Arc::new(PluginLoader::new(config.loader.clone())); + let registry = Arc::new(PluginRegistry::new(router_tx.clone())); + let health_monitor = Arc::new(HealthMonitor::new( + config.health_check_interval_seconds + )); + let metrics_collector = Arc::new(MetricsCollector::new( + config.metrics_collection_interval_seconds + )); + let event_bus = Arc::new(EventBus::new()); + + Self { + config, + loader, + registry, + router_event_sender: router_tx, + router_event_receiver: Arc::new(RwLock::new(Some(router_rx))), + health_monitor, + metrics_collector, + event_bus, + running: Arc::new(RwLock::new(false)), + } + } + + /// Start the plugin manager + pub async fn start(&self) -> Result<()> { + *self.running.write().await = true; + + tracing::info!("Starting plugin manager..."); + + // Start event processing + self.start_event_processor().await?; + + // Start health monitoring + self.health_monitor.start(self.registry.clone()).await; + + // Start metrics collection + self.metrics_collector.start(self.registry.clone()).await; + + // Load all plugins from directories + self.load_all_plugins().await?; + + tracing::info!("Plugin manager started successfully"); + Ok(()) + } + + /// Stop the plugin manager + pub async fn stop(&self) -> Result<()> { + *self.running.write().await = false; + + tracing::info!("Stopping plugin manager..."); + + // Stop all plugins + self.registry.shutdown_all().await?; + + // Stop health monitor + self.health_monitor.stop().await; + + // Stop metrics collector + self.metrics_collector.stop().await; + + // Shutdown loader + self.loader.shutdown().await?; + + tracing::info!("Plugin manager stopped"); + Ok(()) + } + + /// Load and register all plugins from configured directories + pub async fn load_all_plugins(&self) -> Result<()> { + tracing::info!("Loading all plugins..."); + + let loaded_plugins = self.loader.load_all_plugins().await?; + + // Register each loaded plugin + for (name, plugin) in 
loaded_plugins { + let config = self.create_default_plugin_config(&name).await; + + match self.registry.register_plugin(name.clone(), plugin, config).await { + Ok(_) => { + // Start the plugin + if let Err(e) = self.registry.start_plugin(&name).await { + tracing::error!("Failed to start plugin '{}': {}", name, e); + } + } + Err(e) => { + tracing::error!("Failed to register plugin '{}': {}", name, e); + } + } + } + + tracing::info!("Finished loading plugins"); + Ok(()) + } + + /// Load a specific plugin by path + pub async fn load_plugin_from_path>( + &self, + path: P, + config: Option, + ) -> Result { + let (name, plugin) = self.loader.load_plugin(path).await?; + + let plugin_config = config.unwrap_or_else(|| { + futures::executor::block_on(self.create_default_plugin_config(&name)) + }); + + self.registry.register_plugin(name.clone(), plugin, plugin_config).await?; + self.registry.start_plugin(&name).await?; + + tracing::info!("Successfully loaded and started plugin: {}", name); + Ok(name) + } + + /// Unload a plugin + pub async fn unload_plugin(&self, name: &str) -> Result<()> { + // Stop and unregister from registry + self.registry.stop_plugin(name).await?; + self.registry.unregister_plugin(name).await?; + + // Unload from loader + self.loader.unload_plugin(name).await?; + + tracing::info!("Successfully unloaded plugin: {}", name); + Ok(()) + } + + /// Reload a plugin + pub async fn reload_plugin(&self, name: &str) -> Result<()> { + tracing::info!("Reloading plugin: {}", name); + + // Get current config before unloading + let current_config = self.registry.get_plugin_config(name).await; + + // Unload the plugin + self.unload_plugin(name).await?; + + // Find the plugin file and reload it + let loaded_plugins = self.loader.get_loaded_plugins().await; + if let Some(loaded_info) = loaded_plugins.get(name) { + let path = loaded_info.path.clone(); + self.load_plugin_from_path(path, current_config).await?; + } else { + return Err(anyhow::anyhow!("Cannot find plugin file 
for '{}'", name)); + } + + tracing::info!("Successfully reloaded plugin: {}", name); + Ok(()) + } + + /// Get plugin registry + pub fn registry(&self) -> Arc { + self.registry.clone() + } + + /// Get plugin loader + pub fn loader(&self) -> Arc { + self.loader.clone() + } + + /// Route a request to appropriate plugin(s) + pub async fn route_request( + &self, + domain: &str, + path: &str, + metadata: HashMap, + ) -> Result> { + let plugins = self.registry.find_plugins_for_domain(domain, path).await; + + if plugins.is_empty() { + return Err(anyhow::anyhow!("No plugin found for domain: {}", domain)); + } + + Ok(plugins) + } + + /// Send event to plugin + pub async fn send_event_to_plugin( + &self, + plugin_name: &str, + event: PluginEvent, + ) -> Result<()> { + self.registry.send_event_to_plugin(plugin_name, event).await + } + + /// Broadcast event to all plugins + pub async fn broadcast_event(&self, event: PluginEvent) { + self.registry.broadcast_event(event).await; + } + + /// Get plugin health status + pub async fn get_plugin_health(&self, name: &str) -> Option { + self.registry.get_plugin_health(name).await + } + + /// Get plugin metrics + pub async fn get_plugin_metrics(&self, name: &str) -> Option { + self.registry.get_plugin_metrics(name).await + } + + /// Get all plugin metrics + pub async fn get_all_metrics(&self) -> HashMap { + self.registry.get_aggregated_metrics().await + } + + /// List all loaded plugins + pub fn list_plugins(&self) -> Vec { + self.registry.list_plugins() + } + + /// Get plugin information + pub async fn get_plugin_info(&self, name: &str) -> Option { + self.registry.get_plugin_info(name).await + } + + /// Update plugin configuration + pub async fn update_plugin_config( + &self, + name: &str, + config: PluginConfig, + ) -> Result<()> { + self.registry.update_plugin_config(name, config).await + } + + /// Execute plugin command + pub async fn execute_plugin_command( + &self, + plugin_name: &str, + command: &str, + args: serde_json::Value, + ) 
-> Result { + if let Some(wrapper) = self.registry.get_plugin(plugin_name) { + let plugin = wrapper.read().await; + plugin.plugin.handle_command(command, args).await + } else { + Err(anyhow::anyhow!("Plugin '{}' not found", plugin_name)) + } + } + + /// Start event processor + async fn start_event_processor(&self) -> Result<()> { + let mut receiver = { + let mut rx_guard = self.router_event_receiver.write().await; + rx_guard.take().ok_or_else(|| anyhow::anyhow!("Event processor already started"))? + }; + + let running = self.running.clone(); + let event_bus = self.event_bus.clone(); + + tokio::spawn(async move { + while *running.read().await { + tokio::select! { + Some(event) = receiver.recv() => { + if let Err(e) = Self::handle_router_event(event, event_bus.clone()).await { + tracing::error!("Error handling router event: {}", e); + } + } + _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => { + // Periodic check + } + } + } + }); + + Ok(()) + } + + /// Handle router events + async fn handle_router_event( + event: RouterEvent, + event_bus: Arc, + ) -> Result<()> { + match event { + RouterEvent::PluginStarted(name) => { + tracing::info!("Plugin started: {}", name); + event_bus.publish("plugin.started", serde_json::json!({ "name": name })).await; + } + RouterEvent::PluginStopped(name) => { + tracing::info!("Plugin stopped: {}", name); + event_bus.publish("plugin.stopped", serde_json::json!({ "name": name })).await; + } + RouterEvent::PluginError { plugin_name, error } => { + tracing::error!("Plugin '{}' error: {}", plugin_name, error); + event_bus.publish("plugin.error", serde_json::json!({ + "plugin": plugin_name, + "error": error + })).await; + } + RouterEvent::MetricsUpdated { plugin_name, metrics } => { + event_bus.publish("plugin.metrics", serde_json::json!({ + "plugin": plugin_name, + "metrics": metrics + })).await; + } + RouterEvent::HealthUpdated { plugin_name, health } => { + event_bus.publish("plugin.health", serde_json::json!({ + "plugin": 
plugin_name, + "health": health + })).await; + } + RouterEvent::CustomEvent { plugin_name, event } => { + event_bus.publish(&format!("plugin.{}.custom", plugin_name), event).await; + } + RouterEvent::LogMessage { plugin_name, level, message } => { + match level.as_str() { + "error" => tracing::error!("[{}] {}", plugin_name, message), + "warn" => tracing::warn!("[{}] {}", plugin_name, message), + "info" => tracing::info!("[{}] {}", plugin_name, message), + "debug" => tracing::debug!("[{}] {}", plugin_name, message), + _ => tracing::info!("[{}] {}", plugin_name, message), + } + } + } + + Ok(()) + } + + /// Create default plugin configuration + async fn create_default_plugin_config(&self, plugin_name: &str) -> PluginConfig { + PluginConfig { + name: plugin_name.to_string(), + enabled: true, + priority: 0, + bind_ports: Vec::new(), + bind_addresses: vec!["0.0.0.0".to_string()], + domains: Vec::new(), + config: self.config.default_plugin_config.clone(), + middleware_chain: Vec::new(), + load_balancing: None, + health_check: Some(HealthCheckConfig { + enabled: true, + interval_seconds: self.config.health_check_interval_seconds, + timeout_seconds: 10, + custom_check: None, + }), + } + } + + /// Get manager statistics + pub async fn get_statistics(&self) -> PluginManagerStatistics { + let loader_stats = self.loader.get_statistics().await; + let plugins = self.registry.list_plugins(); + let metrics = self.registry.get_aggregated_metrics().await; + + PluginManagerStatistics { + total_plugins: plugins.len(), + running_plugins: plugins.len(), // Simplified + loader_statistics: loader_stats, + total_connections: metrics.values().map(|m| m.connections_active).sum(), + total_requests: metrics.values().map(|m| m.connections_total).sum(), + total_errors: metrics.values().map(|m| m.errors_total).sum(), + uptime_seconds: 0, // Would need to track start time + } + } +} + +/// Health monitor for plugins +pub struct HealthMonitor { + interval_seconds: u64, + running: Arc>, + handle: 
Arc>>>, +} + +impl HealthMonitor { + pub fn new(interval_seconds: u64) -> Self { + Self { + interval_seconds, + running: Arc::new(RwLock::new(false)), + handle: Arc::new(RwLock::new(None)), + } + } + + pub async fn start(&self, registry: Arc) { + *self.running.write().await = true; + + let running = self.running.clone(); + let interval = self.interval_seconds; + + let handle = tokio::spawn(async move { + let mut interval_timer = tokio::time::interval( + std::time::Duration::from_secs(interval) + ); + + while *running.read().await { + interval_timer.tick().await; + + let plugins = registry.list_plugins(); + for plugin_name in plugins { + if let Some(health) = registry.get_plugin_health(&plugin_name).await { + if !health.healthy { + tracing::warn!("Plugin '{}' is unhealthy: {}", plugin_name, health.message); + } + } + } + } + }); + + *self.handle.write().await = Some(handle); + } + + pub async fn stop(&self) { + *self.running.write().await = false; + + if let Some(handle) = self.handle.write().await.take() { + handle.abort(); + } + } +} + +/// Metrics collector for plugins +pub struct MetricsCollector { + interval_seconds: u64, + running: Arc>, + handle: Arc>>>, +} + +impl MetricsCollector { + pub fn new(interval_seconds: u64) -> Self { + Self { + interval_seconds, + running: Arc::new(RwLock::new(false)), + handle: Arc::new(RwLock::new(None)), + } + } + + pub async fn start(&self, registry: Arc) { + *self.running.write().await = true; + + let running = self.running.clone(); + let interval = self.interval_seconds; + + let handle = tokio::spawn(async move { + let mut interval_timer = tokio::time::interval( + std::time::Duration::from_secs(interval) + ); + + while *running.read().await { + interval_timer.tick().await; + + let _metrics = registry.get_aggregated_metrics().await; + // Could store metrics in a time series database here + } + }); + + *self.handle.write().await = Some(handle); + } + + pub async fn stop(&self) { + *self.running.write().await = false; + + if 
let Some(handle) = self.handle.write().await.take() { + handle.abort(); + } + } +} + +/// Event bus for inter-plugin communication +pub struct EventBus { + subscribers: Arc>>>>, +} + +impl EventBus { + pub fn new() -> Self { + Self { + subscribers: Arc::new(RwLock::new(HashMap::new())), + } + } + + pub async fn subscribe(&self, event_type: &str) -> mpsc::UnboundedReceiver { + let (tx, rx) = mpsc::unbounded_channel(); + + let mut subscribers = self.subscribers.write().await; + subscribers + .entry(event_type.to_string()) + .or_insert_with(Vec::new) + .push(tx); + + rx + } + + pub async fn publish(&self, event_type: &str, data: serde_json::Value) { + let subscribers = self.subscribers.read().await; + + if let Some(senders) = subscribers.get(event_type) { + for sender in senders { + let _ = sender.send(data.clone()); + } + } + } +} + +/// Plugin manager statistics +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PluginManagerStatistics { + pub total_plugins: usize, + pub running_plugins: usize, + pub loader_statistics: super::loader::PluginLoaderStatistics, + pub total_connections: u64, + pub total_requests: u64, + pub total_errors: u64, + pub uptime_seconds: u64, +} + +// Extension to registry for additional methods +impl PluginRegistry { + pub async fn get_plugin_config(&self, name: &str) -> Option { + let configs = self.plugin_configs.read().await; + configs.get(name).cloned() + } +} \ No newline at end of file diff --git a/src/plugin/mod.rs b/src/plugin/mod.rs new file mode 100644 index 0000000..c86865b --- /dev/null +++ b/src/plugin/mod.rs @@ -0,0 +1,405 @@ +// src/plugin/mod.rs - Generic plugin interface +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::fmt::Debug; +use std::net::SocketAddr; +use std::sync::Arc; +use anyhow::Result; +use tokio::net::{TcpListener, UdpSocket}; +use tokio::sync::mpsc; + +pub mod registry; +pub mod loader; +pub mod manager; + +/// Generic connection context - 
plugins define what this means +#[derive(Debug, Clone)] +pub struct ConnectionContext { + pub connection_id: String, + pub client_addr: SocketAddr, + pub server_addr: SocketAddr, + pub metadata: HashMap, + pub created_at: std::time::SystemTime, +} + +/// Generic data packet - completely opaque to the router +#[derive(Debug, Clone)] +pub struct DataPacket { + pub data: Vec, + pub metadata: HashMap, + pub context: ConnectionContext, +} + +/// Plugin capabilities - what the plugin can handle +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PluginCapabilities { + pub can_handle_tcp: bool, + pub can_handle_udp: bool, + pub can_handle_unix_socket: bool, + pub requires_dedicated_port: bool, + pub supports_hot_reload: bool, + pub supports_load_balancing: bool, + pub custom_protocols: Vec, + pub port_requirements: Vec, // Specific ports this plugin needs +} + +/// Plugin configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PluginConfig { + pub name: String, + pub enabled: bool, + pub priority: i32, + pub bind_ports: Vec, + pub bind_addresses: Vec, + pub domains: Vec, // Domains this plugin should handle + pub config: serde_json::Value, // Plugin-specific configuration + pub middleware_chain: Vec, + pub load_balancing: Option, + pub health_check: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LoadBalancingConfig { + pub strategy: String, // Plugin defines what strategies it supports + pub backends: Vec, + pub session_affinity: bool, + pub health_check_path: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BackendConfig { + pub address: String, + pub weight: Option, + pub metadata: HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HealthCheckConfig { + pub enabled: bool, + pub interval_seconds: u64, + pub timeout_seconds: u64, + pub custom_check: Option, +} + +/// Plugin metrics +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PluginMetrics { + pub 
connections_total: u64, + pub connections_active: u64, + pub bytes_sent: u64, + pub bytes_received: u64, + pub errors_total: u64, + pub custom_metrics: HashMap, + pub last_updated: std::time::SystemTime, +} + +/// Plugin health status +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PluginHealth { + pub healthy: bool, + pub message: String, + pub last_check: std::time::SystemTime, + pub response_time_ms: Option, + pub custom_health_data: HashMap, +} + +/// Events that plugins can emit or receive +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum PluginEvent { + ConnectionEstablished(ConnectionContext), + ConnectionClosed(String), // connection_id + DataReceived(DataPacket), + ConfigurationChanged(serde_json::Value), + HealthCheckRequested, + MetricsRequested, + ShutdownRequested, + Custom { + event_type: String, + data: serde_json::Value, + }, +} + +/// Communication channel for plugins +pub type PluginSender = mpsc::UnboundedSender; +pub type PluginReceiver = mpsc::UnboundedReceiver; + +/// Main plugin trait - completely generic +#[async_trait] +pub trait Plugin: Send + Sync + Debug { + /// Plugin identification + fn info(&self) -> &PluginInfo; + + /// What this plugin can do + fn capabilities(&self) -> &PluginCapabilities; + + /// Initialize the plugin + async fn initialize( + &mut self, + config: PluginConfig, + event_sender: PluginSender, + ) -> Result<()>; + + /// Start the plugin - it manages its own listeners/connections + async fn start(&mut self) -> Result<()>; + + /// Stop the plugin + async fn stop(&mut self) -> Result<()>; + + /// Handle events from the router or other plugins + async fn handle_event(&mut self, event: PluginEvent) -> Result<()>; + + /// Get current health status + async fn health(&self) -> Result; + + /// Get current metrics + async fn metrics(&self) -> Result; + + /// Update configuration at runtime + async fn update_config(&mut self, config: PluginConfig) -> Result<()>; + + /// Custom command handler for management 
+ async fn handle_command( + &self, + command: &str, + args: serde_json::Value, + ) -> Result; + + /// Graceful shutdown + async fn shutdown(&mut self) -> Result<()>; +} + +/// Plugin metadata +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PluginInfo { + pub name: String, + pub version: String, + pub description: String, + pub author: String, + pub license: String, + pub repository: Option, + pub min_router_version: String, + pub config_schema: Option, + pub dependencies: Vec, + pub tags: Vec, +} + +/// Plugin factory function signature +pub type PluginFactory = unsafe extern "C" fn() -> *mut dyn Plugin; + +/// Plugin registration function signature +pub type PluginRegisterFn = unsafe extern "C" fn() -> PluginInfo; + +/// Domain matcher trait - plugins can implement custom domain matching +pub trait DomainMatcher: Send + Sync + Debug { + fn matches(&self, domain: &str, path: &str) -> bool; + fn priority(&self) -> i32; +} + +/// Simple domain matcher implementation +#[derive(Debug, Clone)] +pub struct SimpleDomainMatcher { + pub patterns: Vec, + pub priority: i32, +} + +impl DomainMatcher for SimpleDomainMatcher { + fn matches(&self, domain: &str, _path: &str) -> bool { + self.patterns.iter().any(|pattern| { + if pattern.starts_with("*.") { + let suffix = &pattern[2..]; + domain.ends_with(suffix) + } else { + pattern == domain + } + }) + } + + fn priority(&self) -> i32 { + self.priority + } +} + +/// Plugin context - shared state between router and plugin +#[derive(Debug)] +pub struct PluginContext { + pub plugin_name: String, + pub config: PluginConfig, + pub event_sender: PluginSender, + pub router_sender: mpsc::UnboundedSender, + pub metrics: Arc>, + pub health: Arc>, +} + +/// Events that plugins can send to the router +#[derive(Debug, Clone)] +pub enum RouterEvent { + PluginStarted(String), + PluginStopped(String), + PluginError { plugin_name: String, error: String }, + MetricsUpdated { plugin_name: String, metrics: PluginMetrics }, + 
HealthUpdated { plugin_name: String, health: PluginHealth }, + CustomEvent { plugin_name: String, event: serde_json::Value }, + LogMessage { plugin_name: String, level: String, message: String }, +} + +/// Plugin wrapper that handles the lifecycle +pub struct PluginWrapper { + plugin: Box, + context: PluginContext, + event_receiver: PluginReceiver, + running: Arc>, +} + +impl PluginWrapper { + pub fn new(plugin: Box, context: PluginContext) -> (Self, PluginReceiver) { + let (tx, rx) = mpsc::unbounded_channel(); + let context = PluginContext { + event_sender: tx, + ..context + }; + + let (_event_tx, event_rx) = mpsc::unbounded_channel(); + + ( + Self { + plugin, + context, + event_receiver: event_rx, + running: Arc::new(tokio::sync::RwLock::new(false)), + }, + rx, + ) + } + + pub async fn start(&mut self) -> Result<()> { + *self.running.write().await = true; + + // Initialize the plugin + self.plugin.initialize( + self.context.config.clone(), + self.context.event_sender.clone(), + ).await?; + + // Start the plugin + self.plugin.start().await?; + + // Notify router + let _ = self.context.router_sender.send(RouterEvent::PluginStarted( + self.context.plugin_name.clone() + )); + + Ok(()) + } + + pub async fn stop(&mut self) -> Result<()> { + *self.running.write().await = false; + + self.plugin.stop().await?; + + let _ = self.context.router_sender.send(RouterEvent::PluginStopped( + self.context.plugin_name.clone() + )); + + Ok(()) + } + + pub async fn run_event_loop(&mut self) -> Result<()> { + while *self.running.read().await { + tokio::select! 
{ + Some(event) = self.event_receiver.recv() => { + if let Err(e) = self.plugin.handle_event(event).await { + let _ = self.context.router_sender.send(RouterEvent::PluginError { + plugin_name: self.context.plugin_name.clone(), + error: e.to_string(), + }); + } + } + _ = tokio::time::sleep(std::time::Duration::from_secs(30)) => { + // Periodic health check and metrics update + if let Ok(health) = self.plugin.health().await { + *self.context.health.write().await = health.clone(); + let _ = self.context.router_sender.send(RouterEvent::HealthUpdated { + plugin_name: self.context.plugin_name.clone(), + health, + }); + } + + if let Ok(metrics) = self.plugin.metrics().await { + *self.context.metrics.write().await = metrics.clone(); + let _ = self.context.router_sender.send(RouterEvent::MetricsUpdated { + plugin_name: self.context.plugin_name.clone(), + metrics, + }); + } + } + } + } + + Ok(()) + } + + pub fn info(&self) -> &PluginInfo { + self.plugin.info() + } + + pub fn capabilities(&self) -> &PluginCapabilities { + self.plugin.capabilities() + } +} + +/// Utility functions for plugins +pub mod utils { + use super::*; + + /// Extract domain from various headers + pub fn extract_domain(headers: &HashMap) -> Option { + headers.get("host") + .or_else(|| headers.get("Host")) + .or_else(|| headers.get("SERVER_NAME")) + .map(|host| { + // Remove port if present + host.split(':').next().unwrap_or(host).to_string() + }) + } + + /// Parse query parameters from URL + pub fn parse_query_params(query: &str) -> HashMap { + let mut params = HashMap::new(); + + for pair in query.split('&') { + if let Some(eq_pos) = pair.find('=') { + let key = &pair[..eq_pos]; + let value = &pair[eq_pos + 1..]; + params.insert( + urlencoding::decode(key).unwrap_or_default().into_owned(), + urlencoding::decode(value).unwrap_or_default().into_owned(), + ); + } + } + + params + } + + /// Generate unique connection ID + pub fn generate_connection_id() -> String { + use std::time::{SystemTime, UNIX_EPOCH}; 
+ let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap(); + format!("conn-{}-{}", now.as_secs(), now.subsec_nanos()) + } + + /// Create connection context + pub fn create_connection_context( + client_addr: SocketAddr, + server_addr: SocketAddr, + metadata: HashMap, + ) -> ConnectionContext { + ConnectionContext { + connection_id: generate_connection_id(), + client_addr, + server_addr, + metadata, + created_at: SystemTime::now(), + } + } +} \ No newline at end of file diff --git a/src/plugin/registry.rs b/src/plugin/registry.rs new file mode 100644 index 0000000..cdd746c --- /dev/null +++ b/src/plugin/registry.rs @@ -0,0 +1,450 @@ +// src/plugin/registry.rs - Plugin registry for managing loaded plugins +use super::*; +use anyhow::{Context, Result}; +use dashmap::DashMap; +use std::sync::Arc; +use tokio::sync::{mpsc, RwLock}; + +/// Plugin registry manages all loaded plugins +pub struct PluginRegistry { + pub plugins: Arc>>>, + pub plugin_configs: Arc>>, + pub domain_mappings: Arc>>>, // domain -> plugin names + pub port_mappings: Arc>>, // port -> plugin name + pub event_router: Arc, + pub router_event_sender: mpsc::UnboundedSender, + pub metrics_aggregator: Arc, +} + +impl PluginRegistry { + pub fn new(router_event_sender: mpsc::UnboundedSender) -> Self { + Self { + plugins: Arc::new(DashMap::new()), + plugin_configs: Arc::new(RwLock::new(HashMap::new())), + domain_mappings: Arc::new(RwLock::new(HashMap::new())), + port_mappings: Arc::new(RwLock::new(HashMap::new())), + event_router: Arc::new(EventRouter::new()), + router_event_sender, + metrics_aggregator: Arc::new(MetricsAggregator::new()), + } + } + + /// Register a new plugin + pub async fn register_plugin( + &self, + name: String, + plugin: Box, + config: PluginConfig, + ) -> Result<()> { + // Check if plugin already exists + if self.plugins.contains_key(&name) { + return Err(anyhow::anyhow!("Plugin '{}' is already registered", name)); + } + + // Validate configuration + 
self.validate_plugin_config(&name, &config).await?; + + // Check port conflicts + for port in &config.bind_ports { + if let Some(existing_plugin) = self.port_mappings.read().await.get(port) { + return Err(anyhow::anyhow!( + "Port {} is already in use by plugin '{}'", + port, + existing_plugin + )); + } + } + + // Create plugin context + let context = PluginContext { + plugin_name: name.clone(), + config: config.clone(), + event_sender: mpsc::unbounded_channel().0, // Placeholder, will be replaced + router_sender: self.router_event_sender.clone(), + metrics: Arc::new(RwLock::new(PluginMetrics { + connections_total: 0, + connections_active: 0, + bytes_sent: 0, + bytes_received: 0, + errors_total: 0, + custom_metrics: HashMap::new(), + last_updated: std::time::SystemTime::now(), + })), + health: Arc::new(RwLock::new(PluginHealth { + healthy: false, + message: "Initializing".to_string(), + last_check: std::time::SystemTime::now(), + response_time_ms: None, + custom_health_data: HashMap::new(), + })), + }; + + // Create plugin wrapper + let (wrapper, _event_rx) = PluginWrapper::new(plugin, context); + let wrapper = Arc::new(RwLock::new(wrapper)); + + // Register domain mappings + { + let mut domain_mappings = self.domain_mappings.write().await; + for domain in &config.domains { + domain_mappings + .entry(domain.clone()) + .or_insert_with(Vec::new) + .push(name.clone()); + } + } + + // Register port mappings + { + let mut port_mappings = self.port_mappings.write().await; + for port in &config.bind_ports { + port_mappings.insert(*port, name.clone()); + } + } + + // Store plugin configuration + { + let mut configs = self.plugin_configs.write().await; + configs.insert(name.clone(), config); + } + + // Add to registry + self.plugins.insert(name.clone(), wrapper); + + tracing::info!("Plugin '{}' registered successfully", name); + Ok(()) + } + + /// Unregister a plugin + pub async fn unregister_plugin(&self, name: &str) -> Result<()> { + let wrapper = self.plugins.get(name) + 
.ok_or_else(|| anyhow::anyhow!("Plugin '{}' not found", name))? + .clone(); + + // Stop the plugin + { + let mut plugin = wrapper.write().await; + plugin.stop().await?; + } + + // Remove from registries + self.plugins.remove(name); + + // Clean up domain mappings + { + let mut domain_mappings = self.domain_mappings.write().await; + domain_mappings.retain(|_domain, plugins| { + plugins.retain(|plugin_name| plugin_name != name); + !plugins.is_empty() + }); + } + + // Clean up port mappings + { + let mut port_mappings = self.port_mappings.write().await; + port_mappings.retain(|_port, plugin_name| plugin_name != name); + } + + // Remove configuration + { + let mut configs = self.plugin_configs.write().await; + configs.remove(name); + } + + tracing::info!("Plugin '{}' unregistered successfully", name); + Ok(()) + } + + /// Start a plugin + pub async fn start_plugin(&self, name: &str) -> Result<()> { + let wrapper = self.plugins.get(name) + .ok_or_else(|| anyhow::anyhow!("Plugin '{}' not found", name))? + .clone(); + + let mut plugin = wrapper.write().await; + plugin.start().await?; + + // Start event loop + let wrapper_clone = wrapper.clone(); + tokio::spawn(async move { + let mut plugin = wrapper_clone.write().await; + if let Err(e) = plugin.run_event_loop().await { + tracing::error!("Plugin event loop error: {}", e); + } + }); + + tracing::info!("Plugin '{}' started successfully", name); + Ok(()) + } + + /// Stop a plugin + pub async fn stop_plugin(&self, name: &str) -> Result<()> { + let wrapper = self.plugins.get(name) + .ok_or_else(|| anyhow::anyhow!("Plugin '{}' not found", name))? 
+ .clone(); + + let mut plugin = wrapper.write().await; + plugin.stop().await?; + + tracing::info!("Plugin '{}' stopped successfully", name); + Ok(()) + } + + /// Get plugin by name + pub fn get_plugin(&self, name: &str) -> Option>> { + self.plugins.get(name).map(|p| p.clone()) + } + + /// Find plugins that can handle a domain + pub async fn find_plugins_for_domain(&self, domain: &str, path: &str) -> Vec { + let domain_mappings = self.domain_mappings.read().await; + let mut matching_plugins = Vec::new(); + + // Exact match first + if let Some(plugins) = domain_mappings.get(domain) { + matching_plugins.extend(plugins.clone()); + } + + // Wildcard matches + for (pattern, plugins) in domain_mappings.iter() { + if pattern.starts_with("*.") { + let suffix = &pattern[2..]; + if domain.ends_with(suffix) { + matching_plugins.extend(plugins.clone()); + } + } + } + + // Sort by priority (would need to get plugin configs for this) + matching_plugins.sort(); + matching_plugins.dedup(); + + matching_plugins + } + + /// Find plugin that handles a specific port + pub async fn find_plugin_for_port(&self, port: u16) -> Option { + self.port_mappings.read().await.get(&port).cloned() + } + + /// List all plugins + pub fn list_plugins(&self) -> Vec { + self.plugins.iter().map(|entry| entry.key().clone()).collect() + } + + /// Get plugin info + pub async fn get_plugin_info(&self, name: &str) -> Option { + if let Some(wrapper) = self.plugins.get(name) { + let plugin = wrapper.read().await; + Some(plugin.info()) + } else { + None + } + } + + /// Get plugin capabilities + pub async fn get_plugin_capabilities(&self, name: &str) -> Option { + if let Some(wrapper) = self.plugins.get(name) { + let plugin = wrapper.read().await; + Some(plugin.capabilities()) + } else { + None + } + } + + /// Get plugin health + pub async fn get_plugin_health(&self, name: &str) -> Option { + if let Some(wrapper) = self.plugins.get(name) { + let plugin = wrapper.read().await; + plugin.plugin.health().await.ok() + 
} else { + None + } + } + + /// Get plugin metrics + pub async fn get_plugin_metrics(&self, name: &str) -> Option { + if let Some(wrapper) = self.plugins.get(name) { + let plugin = wrapper.read().await; + plugin.plugin.metrics().await.ok() + } else { + None + } + } + + /// Update plugin configuration + pub async fn update_plugin_config(&self, name: &str, config: PluginConfig) -> Result<()> { + let wrapper = self.plugins.get(name) + .ok_or_else(|| anyhow::anyhow!("Plugin '{}' not found", name))? + .clone(); + + // Validate new configuration + self.validate_plugin_config(name, &config).await?; + + // Update plugin + { + let mut plugin = wrapper.write().await; + plugin.plugin.update_config(config.clone()).await?; + } + + // Update stored configuration + { + let mut configs = self.plugin_configs.write().await; + configs.insert(name.to_string(), config); + } + + tracing::info!("Plugin '{}' configuration updated", name); + Ok(()) + } + + /// Send event to specific plugin + pub async fn send_event_to_plugin(&self, plugin_name: &str, event: PluginEvent) -> Result<()> { + if let Some(wrapper) = self.plugins.get(plugin_name) { + let plugin = wrapper.read().await; + plugin.context.event_sender.send(event) + .map_err(|e| anyhow::anyhow!("Failed to send event: {}", e))?; + Ok(()) + } else { + Err(anyhow::anyhow!("Plugin '{}' not found", plugin_name)) + } + } + + /// Broadcast event to all plugins + pub async fn broadcast_event(&self, event: PluginEvent) { + for plugin_entry in self.plugins.iter() { + let wrapper = plugin_entry.value().clone(); + let plugin = wrapper.read().await; + let _ = plugin.context.event_sender.send(event.clone()); + } + } + + /// Get aggregated metrics from all plugins + pub async fn get_aggregated_metrics(&self) -> HashMap { + let mut all_metrics = HashMap::new(); + + for plugin_entry in self.plugins.iter() { + let name = plugin_entry.key().clone(); + let wrapper = plugin_entry.value().clone(); + let plugin = wrapper.read().await; + + if let Ok(metrics) 
= plugin.plugin.metrics().await { + all_metrics.insert(name, metrics); + } + } + + all_metrics + } + + /// Validate plugin configuration + async fn validate_plugin_config(&self, name: &str, config: &PluginConfig) -> Result<()> { + // Basic validation + if config.name != name { + return Err(anyhow::anyhow!("Plugin name mismatch")); + } + + if config.bind_ports.is_empty() && config.domains.is_empty() { + return Err(anyhow::anyhow!("Plugin must bind to at least one port or handle at least one domain")); + } + + // Validate port ranges + for port in &config.bind_ports { + if *port == 0 || *port > 65535 { + return Err(anyhow::anyhow!("Invalid port: {}", port)); + } + } + + // Validate domain patterns + for domain in &config.domains { + if domain.is_empty() { + return Err(anyhow::anyhow!("Empty domain pattern")); + } + } + + Ok(()) + } + + /// Graceful shutdown of all plugins + pub async fn shutdown_all(&self) -> Result<()> { + tracing::info!("Shutting down all plugins..."); + + let plugin_names: Vec = self.plugins.iter() + .map(|entry| entry.key().clone()) + .collect(); + + for name in plugin_names { + if let Err(e) = self.stop_plugin(&name).await { + tracing::error!("Error stopping plugin '{}': {}", name, e); + } + } + + tracing::info!("All plugins shut down"); + Ok(()) + } +} + +/// Event router for inter-plugin communication +pub struct EventRouter { + subscribers: Arc>>>, // event_type -> plugin_names +} + +impl EventRouter { + pub fn new() -> Self { + Self { + subscribers: Arc::new(RwLock::new(HashMap::new())), + } + } + + pub async fn subscribe(&self, plugin_name: String, event_type: String) { + let mut subscribers = self.subscribers.write().await; + subscribers + .entry(event_type) + .or_insert_with(Vec::new) + .push(plugin_name); + } + + pub async fn unsubscribe(&self, plugin_name: &str, event_type: &str) { + let mut subscribers = self.subscribers.write().await; + if let Some(plugin_list) = subscribers.get_mut(event_type) { + plugin_list.retain(|name| name != 
plugin_name); + } + } + + pub async fn get_subscribers(&self, event_type: &str) -> Vec { + let subscribers = self.subscribers.read().await; + subscribers.get(event_type).cloned().unwrap_or_default() + } +} + +/// Metrics aggregator +pub struct MetricsAggregator { + // Could store historical metrics, calculate averages, etc. +} + +impl MetricsAggregator { + pub fn new() -> Self { + Self {} + } + + pub async fn aggregate_metrics(&self, plugin_metrics: &HashMap) -> serde_json::Value { + let mut total_connections = 0u64; + let mut total_bytes_sent = 0u64; + let mut total_bytes_received = 0u64; + let mut total_errors = 0u64; + + for (_plugin_name, metrics) in plugin_metrics { + total_connections += metrics.connections_total; + total_bytes_sent += metrics.bytes_sent; + total_bytes_received += metrics.bytes_received; + total_errors += metrics.errors_total; + } + + serde_json::json!({ + "total_connections": total_connections, + "total_bytes_sent": total_bytes_sent, + "total_bytes_received": total_bytes_received, + "total_errors": total_errors, + "plugin_count": plugin_metrics.len(), + "plugins": plugin_metrics + }) + } +} \ No newline at end of file diff --git a/src/router.rs b/src/router.rs new file mode 100644 index 0000000..06a9cc3 --- /dev/null +++ b/src/router.rs @@ -0,0 +1,680 @@ +// src/router.rs - Request routing logic +use crate::config::{DomainConfig, ProxyInstanceConfig}; +use crate::plugin::PluginCapabilities; +use anyhow::Result; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; +use tracing::{debug, warn}; + +/// Request router that determines which plugin should handle a request +pub struct RequestRouter { + domain_mappings: Arc>>, + proxy_instances: Arc>>, + port_mappings: Arc>>, // port -> proxy instance name + protocol_mappings: Arc>>>, // protocol -> plugin names + routing_cache: Arc>>>, // cache for performance +} + +#[derive(Debug, Clone)] +pub struct RoutingDecision { + pub plugin_names: Vec, + pub proxy_instance: String, 
+ pub domain_config: Option, + pub routing_metadata: HashMap, + pub priority: i32, + pub load_balancing_config: Option, +} + +impl RequestRouter { + /// Create a new request router + pub fn new( + domains: HashMap, + proxies: Vec, + ) -> Self { + let mut port_mappings = HashMap::new(); + let mut proxy_instances = HashMap::new(); + + // Build proxy instance mappings + for proxy in proxies { + // Map ports to proxy instances + for port in &proxy.ports { + port_mappings.insert(*port, proxy.name.clone()); + } + + proxy_instances.insert(proxy.name.clone(), proxy); + } + + Self { + domain_mappings: Arc::new(RwLock::new(domains)), + proxy_instances: Arc::new(RwLock::new(proxy_instances)), + port_mappings: Arc::new(RwLock::new(port_mappings)), + protocol_mappings: Arc::new(RwLock::new(HashMap::new())), + routing_cache: Arc::new(RwLock::new(HashMap::new())), + } + } + + /// Find plugins that can handle a request for a specific domain and path + pub async fn find_plugins_for_request( + &self, + domain: &str, + path: &str, + protocol: &str, + ) -> Vec { + // Try cache first + let cache_key = format!("{}:{}:{}", domain, path, protocol); + { + let cache = self.routing_cache.read().await; + if let Some(cached_plugins) = cache.get(&cache_key) { + debug!("Cache hit for routing decision: {}", cache_key); + return cached_plugins.clone(); + } + } + + let mut matching_plugins = Vec::new(); + + // Find domain configuration + let domain_config = self.find_domain_config(domain).await; + + if let Some(domain_cfg) = &domain_config { + // Get the proxy instance for this domain + let proxy_instances = self.proxy_instances.read().await; + if let Some(proxy) = proxy_instances.get(&domain_cfg.proxy_instance) { + // Check if proxy is enabled and handles this protocol + if proxy.enabled { + matching_plugins.push(proxy.plugin_type.clone()); + } + } + } else { + // No specific domain config, try to find plugins by protocol + let protocol_mappings = self.protocol_mappings.read().await; + if let 
Some(plugins) = protocol_mappings.get(protocol) { + matching_plugins.extend(plugins.clone()); + } + } + + // Sort by priority + matching_plugins.sort(); + matching_plugins.dedup(); + + // Cache the result + { + let mut cache = self.routing_cache.write().await; + cache.insert(cache_key, matching_plugins.clone()); + + // Limit cache size + if cache.len() > 10000 { + cache.clear(); + } + } + + debug!("Found {} plugins for {}:{} ({})", matching_plugins.len(), domain, path, protocol); + matching_plugins + } + + /// Find plugins that can handle a specific protocol + pub async fn find_plugins_for_protocol(&self, protocol: &str) -> Vec { + let protocol_mappings = self.protocol_mappings.read().await; + protocol_mappings.get(protocol).cloned().unwrap_or_default() + } + + /// Find plugins that can handle a specific port + pub async fn find_plugins_for_port(&self, port: u16) -> Vec { + let port_mappings = self.port_mappings.read().await; + let proxy_instances = self.proxy_instances.read().await; + + if let Some(proxy_name) = port_mappings.get(&port) { + if let Some(proxy) = proxy_instances.get(proxy_name) { + if proxy.enabled { + return vec![proxy.plugin_type.clone()]; + } + } + } + + Vec::new() + } + + /// Get routing decision with full context + pub async fn get_routing_decision( + &self, + domain: &str, + path: &str, + protocol: &str, + port: Option, + ) -> Option { + let plugins = self.find_plugins_for_request(domain, path, protocol).await; + + if plugins.is_empty() { + return None; + } + + // Find domain config + let domain_config = self.find_domain_config(domain).await; + let proxy_instance_name = domain_config.as_ref() + .map(|dc| dc.proxy_instance.clone()) + .unwrap_or_else(|| "default".to_string()); + + // Get proxy instance details + let proxy_instances = self.proxy_instances.read().await; + let proxy_instance = proxy_instances.get(&proxy_instance_name); + + let mut routing_metadata = HashMap::new(); + routing_metadata.insert("domain".to_string(), 
domain.to_string()); + routing_metadata.insert("path".to_string(), path.to_string()); + routing_metadata.insert("protocol".to_string(), protocol.to_string()); + + if let Some(port) = port { + routing_metadata.insert("port".to_string(), port.to_string()); + } + + Some(RoutingDecision { + plugin_names: plugins, + proxy_instance: proxy_instance_name, + domain_config: domain_config.clone(), + routing_metadata, + priority: proxy_instance.map(|p| p.priority).unwrap_or(0), + load_balancing_config: proxy_instance.and_then(|p| p.load_balancing.clone()), + }) + } + + /// Find domain configuration (including wildcard matching) + async fn find_domain_config(&self, domain: &str) -> Option { + let domain_mappings = self.domain_mappings.read().await; + + // Try exact match first + if let Some(config) = domain_mappings.get(domain) { + return Some(config.clone()); + } + + // Try wildcard matching + for (pattern, config) in domain_mappings.iter() { + if self.domain_matches_pattern(pattern, domain) { + return Some(config.clone()); + } + } + + None + } + + /// Check if domain matches pattern (supports wildcards) + fn domain_matches_pattern(&self, pattern: &str, domain: &str) -> bool { + if pattern == "*" { + return true; // Match all + } + + if pattern.starts_with("*.") { + let suffix = &pattern[2..]; + return domain.ends_with(suffix); + } + + if pattern.ends_with(".*") { + let prefix = &pattern[..pattern.len() - 2]; + return domain.starts_with(prefix); + } + + // Support regex patterns + if pattern.starts_with("~") { + let regex_pattern = &pattern[1..]; + if let Ok(regex) = regex::Regex::new(regex_pattern) { + return regex.is_match(domain); + } + } + + pattern == domain + } + + /// Register plugin capabilities for routing + pub async fn register_plugin_capabilities( + &self, + plugin_name: String, + capabilities: PluginCapabilities, + ) { + let mut protocol_mappings = self.protocol_mappings.write().await; + + // Register standard protocol capabilities + if capabilities.can_handle_tcp 
// ---------------------------------------------------------------------------
// src/router.rs (fragment): tail of the primary `impl RequestRouter` block,
// statistics/validation types, an advanced-routing `impl`, and unit tests.
//
// NOTE(review): recovered from a mangled diff. Generic type parameters that
// the mangling stripped (e.g. `HashMap<String, DomainConfig>`) were
// reconstructed from surrounding usage — confirm against the original file.
// ---------------------------------------------------------------------------

    /// Record which protocols (TCP / UDP / custom) a plugin can serve so that
    /// protocol-based routing can find it by name.
    ///
    /// NOTE(review): the opening of this method lies before the visible chunk;
    /// the signature below is reconstructed from its body — confirm against
    /// the original (the capabilities type name in particular).
    pub async fn register_plugin_capabilities(
        &self,
        plugin_name: &str,
        capabilities: crate::plugin::PluginCapabilities,
    ) {
        let mut protocol_mappings = self.protocol_mappings.write().await;

        if capabilities.can_handle_tcp {
            protocol_mappings
                .entry("TCP".to_string())
                .or_default()
                .push(plugin_name.to_string());
        }

        if capabilities.can_handle_udp {
            protocol_mappings
                .entry("UDP".to_string())
                .or_default()
                .push(plugin_name.to_string());
        }

        // Custom protocols are normalized to upper case; lookups must use the
        // same normalization to match.
        for protocol in capabilities.custom_protocols {
            protocol_mappings
                .entry(protocol.to_uppercase())
                .or_default()
                .push(plugin_name.to_string());
        }

        debug!("Registered plugin '{}' capabilities", plugin_name);
    }

    /// Remove a plugin from every protocol mapping and invalidate the
    /// routing cache so stale decisions cannot be served.
    pub async fn unregister_plugin_capabilities(&self, plugin_name: &str) {
        let mut protocol_mappings = self.protocol_mappings.write().await;

        for plugins in protocol_mappings.values_mut() {
            plugins.retain(|name| name != plugin_name);
        }

        // Clear cache to force re-evaluation.
        {
            let mut cache = self.routing_cache.write().await;
            cache.clear();
        }

        debug!("Unregistered plugin '{}' capabilities", plugin_name);
    }

    /// Replace the entire routing configuration (domain mappings plus proxy
    /// instances) and rebuild the derived port mappings from scratch.
    pub async fn update_config(
        &self,
        domains: HashMap<String, DomainConfig>,
        proxies: Vec<ProxyInstanceConfig>,
    ) {
        // Swap in the new domain mappings wholesale.
        {
            let mut domain_mappings = self.domain_mappings.write().await;
            *domain_mappings = domains;
        }

        // Rebuild proxy-instance and port mappings atomically under both
        // write locks so readers never observe a half-updated pair.
        {
            let mut proxy_instances = self.proxy_instances.write().await;
            let mut port_mappings = self.port_mappings.write().await;

            proxy_instances.clear();
            port_mappings.clear();

            for proxy in proxies {
                // Map each listening port back to its proxy instance.
                for port in &proxy.ports {
                    port_mappings.insert(*port, proxy.name.clone());
                }

                proxy_instances.insert(proxy.name.clone(), proxy);
            }
        }

        // Clear cache to force re-evaluation of all routing decisions.
        {
            let mut cache = self.routing_cache.write().await;
            cache.clear();
        }

        debug!("Updated router configuration");
    }

    /// Add (or overwrite) a single domain mapping, evicting only the cache
    /// entries keyed under that domain (`"<domain>:…"` keys).
    pub async fn add_domain(&self, domain: String, config: DomainConfig) {
        let mut domain_mappings = self.domain_mappings.write().await;
        domain_mappings.insert(domain.clone(), config);

        // Evict only the affected entries; hoist the prefix so the closure
        // does not allocate per key.
        {
            let mut cache = self.routing_cache.write().await;
            let prefix = format!("{}:", domain);
            cache.retain(|key, _| !key.starts_with(&prefix));
        }

        debug!("Added domain mapping: {}", domain);
    }

    /// Remove a domain mapping and evict its cache entries.
    pub async fn remove_domain(&self, domain: &str) {
        let mut domain_mappings = self.domain_mappings.write().await;
        domain_mappings.remove(domain);

        {
            let mut cache = self.routing_cache.write().await;
            let prefix = format!("{}:", domain);
            cache.retain(|key, _| !key.starts_with(&prefix));
        }

        debug!("Removed domain mapping: {}", domain);
    }

    /// Register a proxy instance, claiming its ports and clearing the cache.
    pub async fn add_proxy_instance(&self, proxy: ProxyInstanceConfig) {
        let proxy_name = proxy.name.clone();

        // Claim the instance's ports.
        {
            let mut port_mappings = self.port_mappings.write().await;
            for port in &proxy.ports {
                port_mappings.insert(*port, proxy_name.clone());
            }
        }

        // Store the instance itself.
        {
            let mut proxy_instances = self.proxy_instances.write().await;
            proxy_instances.insert(proxy_name.clone(), proxy);
        }

        // A new instance can change any routing decision — clear everything.
        {
            let mut cache = self.routing_cache.write().await;
            cache.clear();
        }

        debug!("Added proxy instance: {}", proxy_name);
    }

    /// Remove a proxy instance, releasing its ports and clearing the cache.
    pub async fn remove_proxy_instance(&self, proxy_name: &str) {
        // Remove from proxy instances first so we know which ports to free.
        let removed_proxy = {
            let mut proxy_instances = self.proxy_instances.write().await;
            proxy_instances.remove(proxy_name)
        };

        // Release the port mappings the removed instance held.
        if let Some(proxy) = removed_proxy {
            let mut port_mappings = self.port_mappings.write().await;
            for port in &proxy.ports {
                port_mappings.remove(port);
            }
        }

        // Clear cache.
        {
            let mut cache = self.routing_cache.write().await;
            cache.clear();
        }

        debug!("Removed proxy instance: {}", proxy_name);
    }

    /// Snapshot aggregate routing statistics (sizes of all mapping tables
    /// plus enabled/disabled proxy counts).
    pub async fn get_statistics(&self) -> RoutingStatistics {
        let domain_mappings = self.domain_mappings.read().await;
        let proxy_instances = self.proxy_instances.read().await;
        let port_mappings = self.port_mappings.read().await;
        let protocol_mappings = self.protocol_mappings.read().await;
        let cache = self.routing_cache.read().await;

        RoutingStatistics {
            total_domains: domain_mappings.len(),
            total_proxy_instances: proxy_instances.len(),
            total_port_mappings: port_mappings.len(),
            total_protocol_mappings: protocol_mappings.len(),
            cache_size: cache.len(),
            enabled_proxies: proxy_instances.values().filter(|p| p.enabled).count(),
            disabled_proxies: proxy_instances.values().filter(|p| !p.enabled).count(),
        }
    }

    /// Validate the routing configuration, returning every problem found
    /// (dangling proxy references and port conflicts) rather than failing
    /// fast on the first one.
    pub async fn validate_configuration(&self) -> Vec<RoutingValidationError> {
        let mut errors = Vec::new();

        let domain_mappings = self.domain_mappings.read().await;
        let proxy_instances = self.proxy_instances.read().await;

        // Every domain config must reference an existing proxy instance.
        for (domain, config) in domain_mappings.iter() {
            if !proxy_instances.contains_key(&config.proxy_instance) {
                errors.push(RoutingValidationError {
                    error_type: "missing_proxy_instance".to_string(),
                    domain: Some(domain.clone()),
                    proxy_instance: Some(config.proxy_instance.clone()),
                    message: format!(
                        "Domain '{}' references non-existent proxy instance '{}'",
                        domain, config.proxy_instance
                    ),
                });
            }
        }

        // No two proxy instances may claim the same port.
        let mut port_usage = HashMap::new();
        for (name, proxy) in proxy_instances.iter() {
            for port in &proxy.ports {
                if let Some(existing) = port_usage.insert(*port, name.clone()) {
                    errors.push(RoutingValidationError {
                        error_type: "port_conflict".to_string(),
                        domain: None,
                        proxy_instance: Some(name.clone()),
                        message: format!(
                            "Port {} is used by both '{}' and '{}'",
                            port, existing, name
                        ),
                    });
                }
            }
        }

        errors
    }

    /// Drop every cached routing decision.
    pub async fn clear_cache(&self) {
        let mut cache = self.routing_cache.write().await;
        cache.clear();
        debug!("Cleared routing cache");
    }

    /// Snapshot cache statistics. The memory figure is a rough estimate
    /// (~100 bytes per entry), not a measurement.
    pub async fn get_cache_statistics(&self) -> CacheStatistics {
        let cache = self.routing_cache.read().await;

        CacheStatistics {
            total_entries: cache.len(),
            memory_usage_estimate: cache.len() * 100, // Rough estimate
        }
    }
}

/// Aggregate routing statistics snapshot.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RoutingStatistics {
    pub total_domains: usize,
    pub total_proxy_instances: usize,
    pub total_port_mappings: usize,
    pub total_protocol_mappings: usize,
    pub cache_size: usize,
    pub enabled_proxies: usize,
    pub disabled_proxies: usize,
}

/// A single problem found by [`RequestRouter::validate_configuration`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RoutingValidationError {
    /// Machine-readable kind, e.g. "missing_proxy_instance" or "port_conflict".
    pub error_type: String,
    /// Offending domain, when the error is domain-scoped.
    pub domain: Option<String>,
    /// Offending proxy instance, when applicable.
    pub proxy_instance: Option<String>,
    /// Human-readable description.
    pub message: String,
}

/// Routing-cache statistics snapshot.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CacheStatistics {
    pub total_entries: usize,
    /// Rough estimate (entries × 100 bytes), not a measurement.
    pub memory_usage_estimate: usize,
}

/// Advanced routing features layered on top of the base decision logic.
impl RequestRouter {
    /// Route a request, annotating the decision with load-balancing metadata
    /// when the matched route carries a load-balancing configuration.
    pub async fn route_with_load_balancing(
        &self,
        domain: &str,
        path: &str,
        protocol: &str,
        client_info: &ClientInfo,
    ) -> Option<RoutingDecision> {
        let mut decision = self
            .get_routing_decision(domain, path, protocol, None)
            .await?;

        // Apply load balancing logic if configured. The actual backend choice
        // is deferred to the consumer of the metadata.
        if let Some(lb_config) = &decision.load_balancing_config {
            decision
                .routing_metadata
                .insert("load_balancing_strategy".to_string(), lb_config.clone());
            decision
                .routing_metadata
                .insert("client_id".to_string(), client_info.id.clone());
        }

        Some(decision)
    }

    /// Route a request, tagging the decision with session-affinity metadata
    /// so downstream components can pin the session to a backend.
    pub async fn route_with_session_affinity(
        &self,
        domain: &str,
        path: &str,
        protocol: &str,
        session_id: &str,
    ) -> Option<RoutingDecision> {
        let mut decision = self
            .get_routing_decision(domain, path, protocol, None)
            .await?;

        decision
            .routing_metadata
            .insert("session_id".to_string(), session_id.to_string());
        decision
            .routing_metadata
            .insert("session_affinity".to_string(), "enabled".to_string());

        Some(decision)
    }

    /// Route a request, assigning the user to an A/B-test variant and
    /// recording it in the decision metadata.
    pub async fn route_with_ab_testing(
        &self,
        domain: &str,
        path: &str,
        protocol: &str,
        user_id: &str,
        ab_test_config: &ABTestConfig,
    ) -> Option<RoutingDecision> {
        let mut decision = self
            .get_routing_decision(domain, path, protocol, None)
            .await?;

        let variant = self.determine_ab_variant(user_id, ab_test_config);
        decision
            .routing_metadata
            .insert("ab_test_variant".to_string(), variant);

        Some(decision)
    }

    /// Deterministically map a user id onto a variant by hashing it into a
    /// 0..=99 bucket and walking the cumulative percentage table.
    ///
    /// Falls back to the first variant (or "default") when the percentages
    /// do not cover the bucket — i.e. when they sum to less than 100.
    fn determine_ab_variant(&self, user_id: &str, config: &ABTestConfig) -> String {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        user_id.hash(&mut hasher);
        let hash = hasher.finish();

        // Same user id always lands in the same bucket, so assignment is sticky.
        let variant_index = (hash % 100) as u8;

        let mut cumulative_percentage = 0;
        for (variant, percentage) in &config.variants {
            cumulative_percentage += percentage;
            if variant_index < cumulative_percentage {
                return variant.clone();
            }
        }

        config
            .variants
            .first()
            .map(|(v, _)| v.clone())
            .unwrap_or_else(|| "default".to_string())
    }
}

/// Client information used to enrich routing decisions.
#[derive(Debug, Clone)]
pub struct ClientInfo {
    pub id: String,
    pub ip_address: std::net::IpAddr,
    pub user_agent: Option<String>,
    pub headers: HashMap<String, String>,
}

/// A/B testing configuration.
#[derive(Debug, Clone)]
pub struct ABTestConfig {
    pub test_name: String,
    /// (variant_name, percentage) pairs; percentages are expected to sum to 100.
    pub variants: Vec<(String, u8)>,
    pub enabled: bool,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{DomainConfig, ProxyInstanceConfig};

    #[tokio::test]
    async fn test_domain_matching() {
        let router = RequestRouter::new(HashMap::new(), Vec::new());

        // Exact match
        assert!(router.domain_matches_pattern("example.com", "example.com"));
        assert!(!router.domain_matches_pattern("example.com", "test.com"));

        // Wildcard match — does not cover the apex domain itself
        assert!(router.domain_matches_pattern("*.example.com", "api.example.com"));
        assert!(router.domain_matches_pattern("*.example.com", "www.example.com"));
        assert!(!router.domain_matches_pattern("*.example.com", "example.com"));

        // Catch-all
        assert!(router.domain_matches_pattern("*", "anything.com"));
    }

    #[tokio::test]
    async fn test_routing_decision() {
        let mut domains = HashMap::new();
        domains.insert(
            "example.com".to_string(),
            DomainConfig {
                proxy_instance: "http_proxy".to_string(),
                backend_service: Some("http://backend:8080".to_string()),
                ssl_config: None,
                cors_config: None,
                cache_config: None,
                custom_headers: HashMap::new(),
                rewrite_rules: vec![],
                access_control: None,
            },
        );

        let proxies = vec![ProxyInstanceConfig {
            name: "http_proxy".to_string(),
            plugin_type: "http_proxy".to_string(),
            enabled: true,
            priority: 0,
            ports: vec![8080],
            bind_addresses: vec!["0.0.0.0".to_string()],
            domains: vec!["example.com".to_string()],
            plugin_config: serde_json::json!({}),
            middleware: vec![],
            load_balancing: None,
            health_check: None,
            circuit_breaker: None,
            rate_limiting: None,
            ssl_config: None,
        }];

        let router = RequestRouter::new(domains, proxies);

        let plugins = router
            .find_plugins_for_request("example.com", "/api", "HTTP")
            .await;
        assert_eq!(plugins, vec!["http_proxy"]);

        let decision = router
            .get_routing_decision("example.com", "/api", "HTTP", None)
            .await;
        assert!(decision.is_some());

        let decision = decision.unwrap();
        assert_eq!(decision.proxy_instance, "http_proxy");
        assert_eq!(decision.plugin_names, vec!["http_proxy"]);
    }
}
\ No newline at end of file diff --git a/src/server.rs b/src/server.rs new file mode 100644 index 0000000..e636abf --- /dev/null +++ b/src/server.rs @@ -0,0 +1,733 @@ +// src/server.rs - Main server implementation +use crate::config::RouterConfig; +use crate::plugin::manager::{PluginManager, PluginManagerConfig}; +use crate::plugin::{PluginEvent, RouterEvent}; +use crate::router::RequestRouter; +use crate::middleware::MiddlewareStack; +use crate::metrics::MetricsCollector; +use crate::health::HealthChecker; +use crate::management_api::ManagementApi; +use crate::error::{RouterError, Result}; + +use anyhow::Context; +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::net::TcpListener; +use tokio::sync::{broadcast, mpsc, RwLock}; +use tokio::time::timeout; +use tracing::{debug, error, info, warn}; + +/// Main proxy server that coordinates all components +pub struct ProxyServer { + config: Arc>, + plugin_manager: Arc, + request_router: Arc, + middleware_stack: Arc, + metrics_collector: Arc, + health_checker: Arc, + management_api: Option>, + + // Server state + running: Arc>, + start_time: Instant, + shutdown_sender: Option>, + event_router: Arc, + + // Statistics + statistics: Arc>, +} + +#[derive(Debug, Clone, Default)] +pub struct ServerStatistics { + pub start_time: std::time::SystemTime, + pub uptime_seconds: u64, + pub total_connections: u64, + pub active_connections: u64, + pub total_requests: u64, + pub successful_requests: u64, + pub failed_requests: u64, + pub bytes_sent: u64, + pub bytes_received: u64, + pub plugin_count: usize, + pub average_response_time_ms: f64, + pub last_error: Option, + pub memory_usage_bytes: u64, +} + +impl ProxyServer { + /// Create a new proxy server with the given configuration + pub async fn new(config: RouterConfig) -> Result { + info!("Initializing proxy server..."); + + // Initialize plugin manager + let plugin_manager_config = PluginManagerConfig { + 
loader: crate::plugin::loader::LoaderConfig { + plugin_directories: config.plugins.plugin_directories.clone(), + auto_reload: config.plugins.auto_reload, + reload_interval_seconds: config.plugins.reload_interval_seconds, + max_plugin_memory_mb: config.plugins.max_plugin_memory_mb, + plugin_timeout_seconds: config.plugins.plugin_timeout_seconds, + allowed_plugins: config.plugins.allowed_plugins.clone(), + blocked_plugins: config.plugins.blocked_plugins.clone(), + require_signature: config.plugins.require_signature, + signature_key_path: config.plugins.signature_key_path.clone(), + }, + enable_inter_plugin_communication: config.plugins.enable_inter_plugin_communication, + max_concurrent_loads: config.plugins.max_concurrent_loads, + plugin_startup_timeout_seconds: config.plugins.plugin_timeout_seconds, + health_check_interval_seconds: config.plugins.health_check_interval_seconds, + metrics_collection_interval_seconds: config.plugins.metrics_collection_interval_seconds, + enable_plugin_isolation: config.plugins.enable_plugin_isolation, + default_plugin_config: serde_json::json!({}), + }; + + let plugin_manager = Arc::new(PluginManager::new(plugin_manager_config)); + + // Initialize request router + let request_router = Arc::new(RequestRouter::new( + config.domains.clone(), + config.proxies.clone(), + )); + + // Initialize middleware stack + let middleware_stack = Arc::new(MiddlewareStack::new(config.middleware.clone())); + + // Initialize metrics collector + let metrics_collector = Arc::new(MetricsCollector::new( + config.monitoring.clone(), + )); + + // Initialize health checker + let health_checker = Arc::new(HealthChecker::new( + config.server.health_check_port, + )); + + // Create event router + let event_router = Arc::new(EventRouter::new()); + + // Create shutdown channel + let (shutdown_tx, _) = broadcast::channel(1); + + let server = Self { + config: Arc::new(RwLock::new(config)), + plugin_manager, + request_router, + middleware_stack, + metrics_collector, + 
health_checker, + management_api: None, + running: Arc::new(RwLock::new(false)), + start_time: Instant::now(), + shutdown_sender: Some(shutdown_tx), + event_router, + statistics: Arc::new(RwLock::new(ServerStatistics { + start_time: std::time::SystemTime::now(), + ..Default::default() + })), + }; + + info!("Proxy server initialized successfully"); + Ok(server) + } + + /// Start the proxy server + pub async fn start(&self) -> Result<()> { + info!("Starting proxy server..."); + + // Mark as running + *self.running.write().await = true; + + // Start plugin manager + self.plugin_manager.start().await + .context("Failed to start plugin manager")?; + + // Load built-in plugins first + self.load_builtin_plugins().await?; + + // Start metrics collector + self.metrics_collector.start().await + .context("Failed to start metrics collector")?; + + // Start health checker + self.health_checker.start(self.plugin_manager.clone()).await + .context("Failed to start health checker")?; + + // Start event processing + self.start_event_processing().await?; + + // Start listeners for each configured address + let config = self.config.read().await; + let mut listeners = Vec::new(); + + for listen_addr in &config.server.listen_addresses { + let addr: SocketAddr = listen_addr.parse() + .with_context(|| format!("Invalid listen address: {}", listen_addr))?; + + let listener = TcpListener::bind(addr).await + .with_context(|| format!("Failed to bind to {}", addr))?; + + info!("Listening on {}", addr); + listeners.push((listener, addr)); + } + + // Start connection handlers + for (listener, addr) in listeners { + let server = Arc::new(self.clone_for_handler()); + + tokio::spawn(async move { + if let Err(e) = server.handle_connections(listener, addr).await { + error!("Connection handler error on {}: {}", addr, e); + } + }); + } + + info!("Proxy server started successfully"); + Ok(()) + } + + /// Stop the proxy server + pub async fn stop(&self) -> Result<()> { + info!("Stopping proxy server..."); 
+ + // Mark as not running + *self.running.write().await = false; + + // Send shutdown signal + if let Some(sender) = &self.shutdown_sender { + let _ = sender.send(()); + } + + // Stop components in reverse order + if let Some(api) = &self.management_api { + api.stop().await?; + } + + self.health_checker.stop().await?; + self.metrics_collector.stop().await?; + self.plugin_manager.stop().await?; + + info!("Proxy server stopped"); + Ok(()) + } + + /// Handle incoming connections + async fn handle_connections(&self, listener: TcpListener, addr: SocketAddr) -> Result<()> { + let mut shutdown_rx = self.shutdown_sender.as_ref().unwrap().subscribe(); + + loop { + tokio::select! { + // Handle new connections + connection = listener.accept() => { + match connection { + Ok((stream, client_addr)) => { + debug!("Accepted connection from {} on {}", client_addr, addr); + + // Update statistics + { + let mut stats = self.statistics.write().await; + stats.total_connections += 1; + stats.active_connections += 1; + } + + // Spawn connection handler + let server = Arc::new(self.clone_for_handler()); + tokio::spawn(async move { + if let Err(e) = server.handle_single_connection(stream, client_addr).await { + error!("Connection error from {}: {}", client_addr, e); + } + + // Update statistics + { + let mut stats = server.statistics.write().await; + stats.active_connections = stats.active_connections.saturating_sub(1); + } + }); + } + Err(e) => { + error!("Failed to accept connection on {}: {}", addr, e); + } + } + } + + // Handle shutdown + _ = shutdown_rx.recv() => { + info!("Shutting down listener on {}", addr); + break; + } + } + } + + Ok(()) + } + + /// Handle a single connection + async fn handle_single_connection( + &self, + stream: tokio::net::TcpStream, + client_addr: SocketAddr, + ) -> Result<()> { + let start_time = Instant::now(); + + // Create connection context + let connection_context = crate::plugin::ConnectionContext { + connection_id: 
crate::plugin::utils::generate_connection_id(), + client_addr, + server_addr: stream.local_addr().unwrap_or_else(|_| "0.0.0.0:0".parse().unwrap()), + metadata: HashMap::new(), + created_at: std::time::SystemTime::now(), + }; + + // Try to determine which plugin should handle this connection + // This is where the routing logic comes in + + // For now, we'll implement a simple protocol detection + let mut buffer = [0u8; 1024]; + let mut stream = stream; + + // Peek at the first few bytes to determine protocol + stream.readable().await?; + let n = stream.try_read(&mut buffer)?; + + if n == 0 { + return Ok(()); // Connection closed immediately + } + + // Determine protocol and route to appropriate plugin + let protocol_hint = self.detect_protocol(&buffer[..n])?; + + match protocol_hint { + ProtocolHint::Http => { + self.handle_http_connection(stream, client_addr, &buffer[..n]).await?; + } + ProtocolHint::Tcp => { + self.handle_tcp_connection(stream, client_addr, connection_context).await?; + } + ProtocolHint::Unknown => { + // Try to find a plugin that can handle unknown protocols + self.handle_unknown_connection(stream, client_addr, &buffer[..n]).await?; + } + } + + // Update statistics + let duration = start_time.elapsed(); + { + let mut stats = self.statistics.write().await; + stats.total_requests += 1; + stats.successful_requests += 1; + + // Update average response time (simple moving average) + let duration_ms = duration.as_millis() as f64; + if stats.average_response_time_ms == 0.0 { + stats.average_response_time_ms = duration_ms; + } else { + stats.average_response_time_ms = (stats.average_response_time_ms * 0.9) + (duration_ms * 0.1); + } + } + + Ok(()) + } + + /// Handle HTTP connection + async fn handle_http_connection( + &self, + stream: tokio::net::TcpStream, + client_addr: SocketAddr, + initial_data: &[u8], + ) -> Result<()> { + // Parse HTTP request from initial data + let request_info = self.parse_http_request(initial_data)?; + + // Route to 
appropriate plugin based on domain + let plugins = self.request_router.find_plugins_for_request( + &request_info.domain, + &request_info.path, + "HTTP", + ).await; + + if plugins.is_empty() { + warn!("No plugin found for HTTP request to {}{}", request_info.domain, request_info.path); + return self.send_http_error(stream, 404, "Not Found").await; + } + + // Use the first matching plugin + let plugin_name = &plugins[0]; + + // Create request metadata + let metadata = crate::plugin::RequestMetadata { + domain: request_info.domain.clone(), + path: request_info.path.clone(), + method: request_info.method.clone(), + headers: request_info.headers.clone(), + query_params: crate::plugin::utils::parse_query_params(&request_info.query), + client_addr, + protocol: crate::plugin::ProtocolType::Http, + route_name: Some(plugin_name.clone()), + upstream_url: None, + request_id: crate::plugin::utils::generate_connection_id(), + timestamp: std::time::SystemTime::now(), + }; + + // Send event to plugin + let event = PluginEvent::DataReceived(crate::plugin::DataPacket { + data: initial_data.to_vec(), + metadata: HashMap::new(), + context: crate::plugin::ConnectionContext { + connection_id: metadata.request_id.clone(), + client_addr, + server_addr: stream.local_addr().unwrap_or_else(|_| "0.0.0.0:0".parse().unwrap()), + metadata: HashMap::new(), + created_at: std::time::SystemTime::now(), + }, + }); + + self.plugin_manager.send_event_to_plugin(plugin_name, event).await?; + + Ok(()) + } + + /// Handle TCP connection + async fn handle_tcp_connection( + &self, + stream: tokio::net::TcpStream, + client_addr: SocketAddr, + context: crate::plugin::ConnectionContext, + ) -> Result<()> { + // Find TCP plugins + let plugins = self.request_router.find_plugins_for_protocol("TCP").await; + + if plugins.is_empty() { + warn!("No TCP plugin available for connection from {}", client_addr); + return Ok(()); + } + + // Use the first TCP plugin + let plugin_name = &plugins[0]; + + // Create connection 
event + let event = PluginEvent::ConnectionEstablished(context); + + self.plugin_manager.send_event_to_plugin(plugin_name, event).await?; + + Ok(()) + } + + /// Handle unknown protocol connection + async fn handle_unknown_connection( + &self, + stream: tokio::net::TcpStream, + client_addr: SocketAddr, + initial_data: &[u8], + ) -> Result<()> { + // Try to find a plugin that accepts unknown protocols + let plugins = self.plugin_manager.list_plugins(); + + for plugin_name in plugins { + if let Some(capabilities) = self.plugin_manager.registry().get_plugin_capabilities(&plugin_name).await { + if capabilities.custom_protocols.contains(&"unknown".to_string()) { + let event = PluginEvent::DataReceived(crate::plugin::DataPacket { + data: initial_data.to_vec(), + metadata: HashMap::new(), + context: crate::plugin::ConnectionContext { + connection_id: crate::plugin::utils::generate_connection_id(), + client_addr, + server_addr: stream.local_addr().unwrap_or_else(|_| "0.0.0.0:0".parse().unwrap()), + metadata: HashMap::new(), + created_at: std::time::SystemTime::now(), + }, + }); + + self.plugin_manager.send_event_to_plugin(&plugin_name, event).await?; + return Ok(()); + } + } + } + + debug!("No plugin found for unknown protocol from {}", client_addr); + Ok(()) + } + + /// Detect protocol from initial data + fn detect_protocol(&self, data: &[u8]) -> Result { + if data.is_empty() { + return Ok(ProtocolHint::Unknown); + } + + // Check for HTTP + if data.starts_with(b"GET ") || + data.starts_with(b"POST ") || + data.starts_with(b"PUT ") || + data.starts_with(b"DELETE ") || + data.starts_with(b"HEAD ") || + data.starts_with(b"OPTIONS ") || + data.starts_with(b"PATCH ") { + return Ok(ProtocolHint::Http); + } + + // Check for TLS + if data.len() >= 6 && data[0] == 0x16 && data[1] == 0x03 { + return Ok(ProtocolHint::Http); // Assume HTTPS + } + + // Default to TCP for other protocols + Ok(ProtocolHint::Tcp) + } + + /// Parse HTTP request from data + fn parse_http_request(&self, 
data: &[u8]) -> Result { + let request_str = std::str::from_utf8(data) + .map_err(|_| RouterError::RequestError("Invalid UTF-8 in request".to_string()))?; + + let lines: Vec<&str> = request_str.split("\r\n").collect(); + if lines.is_empty() { + return Err(RouterError::RequestError("Empty request".to_string())); + } + + // Parse request line + let request_line_parts: Vec<&str> = lines[0].split_whitespace().collect(); + if request_line_parts.len() < 3 { + return Err(RouterError::RequestError("Invalid request line".to_string())); + } + + let method = request_line_parts[0].to_string(); + let uri = request_line_parts[1]; + + // Parse URI + let (path, query) = if let Some(query_pos) = uri.find('?') { + (uri[..query_pos].to_string(), uri[query_pos + 1..].to_string()) + } else { + (uri.to_string(), String::new()) + }; + + // Parse headers + let mut headers = HashMap::new(); + let mut domain = String::new(); + + for line in lines.iter().skip(1) { + if line.is_empty() { + break; + } + + if let Some(colon_pos) = line.find(':') { + let key = line[..colon_pos].trim().to_string(); + let value = line[colon_pos + 1..].trim().to_string(); + + if key.to_lowercase() == "host" { + domain = value.split(':').next().unwrap_or(&value).to_string(); + } + + headers.insert(key, value); + } + } + + Ok(HttpRequestInfo { + method, + path, + query, + domain, + headers, + }) + } + + /// Send HTTP error response + async fn send_http_error(&self, mut stream: tokio::net::TcpStream, status: u16, message: &str) -> Result<()> { + use tokio::io::AsyncWriteExt; + + let response = format!( + "HTTP/1.1 {} {}\r\nContent-Length: {}\r\nContent-Type: text/plain\r\n\r\n{}", + status, message, message.len(), message + ); + + stream.write_all(response.as_bytes()).await?; + stream.flush().await?; + + Ok(()) + } + + /// Load built-in plugins + async fn load_builtin_plugins(&self) -> Result<()> { + info!("Loading built-in plugins..."); + + // Load HTTP proxy plugin + let http_plugin = 
Box::new(crate::builtin_plugins::HttpProxyPlugin::new()); + let http_config = crate::plugin::PluginConfig { + name: "http_proxy".to_string(), + enabled: true, + priority: 0, + bind_ports: vec![], + bind_addresses: vec![], + domains: vec![], + config: serde_json::json!({}), + middleware_chain: vec![], + load_balancing: None, + health_check: None, + }; + + self.plugin_manager.registry().register_plugin( + "http_proxy".to_string(), + http_plugin, + http_config, + ).await?; + + // Load TCP proxy plugin + let tcp_plugin = Box::new(crate::builtin_plugins::TcpProxyPlugin::new()); + let tcp_config = crate::plugin::PluginConfig { + name: "tcp_proxy".to_string(), + enabled: true, + priority: 0, + bind_ports: vec![], + bind_addresses: vec![], + domains: vec![], + config: serde_json::json!({}), + middleware_chain: vec![], + load_balancing: None, + health_check: None, + }; + + self.plugin_manager.registry().register_plugin( + "tcp_proxy".to_string(), + tcp_plugin, + tcp_config, + ).await?; + + info!("Built-in plugins loaded successfully"); + Ok(()) + } + + /// Start event processing + async fn start_event_processing(&self) -> Result<()> { + let event_router = self.event_router.clone(); + let plugin_manager = self.plugin_manager.clone(); + + tokio::spawn(async move { + // Event processing loop would go here + // This would handle inter-plugin communication + }); + + Ok(()) + } + + /// Clone for use in connection handlers + fn clone_for_handler(&self) -> Self { + Self { + config: self.config.clone(), + plugin_manager: self.plugin_manager.clone(), + request_router: self.request_router.clone(), + middleware_stack: self.middleware_stack.clone(), + metrics_collector: self.metrics_collector.clone(), + health_checker: self.health_checker.clone(), + management_api: self.management_api.clone(), + running: self.running.clone(), + start_time: self.start_time, + shutdown_sender: self.shutdown_sender.clone(), + event_router: self.event_router.clone(), + statistics: self.statistics.clone(), + 
} + } + + /// Get plugin manager + pub fn plugin_manager(&self) -> Arc { + self.plugin_manager.clone() + } + + /// Reload configuration + pub async fn reload_config(&self, new_config: RouterConfig) -> Result<()> { + info!("Reloading configuration..."); + + { + let mut config = self.config.write().await; + *config = new_config.clone(); + } + + // Update router + self.request_router.update_config( + new_config.domains.clone(), + new_config.proxies.clone(), + ).await; + + // Update middleware stack + self.middleware_stack.update_config(new_config.middleware.clone()).await; + + info!("Configuration reloaded successfully"); + Ok(()) + } + + /// Enable management API + pub async fn enable_management_api(&self, port: u16) -> Result<()> { + if self.management_api.is_some() { + return Ok(()); // Already enabled + } + + let api = Arc::new(ManagementApi::new( + port, + self.plugin_manager.clone(), + self.config.clone(), + )); + + api.start().await?; + + // Store reference (this would need to be mutable in real implementation) + // self.management_api = Some(api); + + Ok(()) + } + + /// Check if server is running + pub async fn is_running(&self) -> bool { + *self.running.read().await + } + + /// Get server statistics + pub async fn get_statistics(&self) -> ServerStatistics { + let mut stats = self.statistics.read().await.clone(); + stats.uptime_seconds = self.start_time.elapsed().as_secs(); + stats.plugin_count = self.plugin_manager.list_plugins().len(); + + // Get memory usage (simplified) + #[cfg(target_os = "linux")] + { + if let Ok(info) = procfs::process::Process::myself() { + if let Ok(stat) = info.stat() { + stats.memory_usage_bytes = stat.rss * 4096; // Convert pages to bytes + } + } + } + + stats + } + + /// Wait for server to finish + pub async fn wait(&self) -> Result<()> { + let mut shutdown_rx = self.shutdown_sender.as_ref().unwrap().subscribe(); + let _ = shutdown_rx.recv().await; + Ok(()) + } +} + +#[derive(Debug)] +enum ProtocolHint { + Http, + Tcp, + Unknown, 
+} + +#[derive(Debug)] +struct HttpRequestInfo { + method: String, + path: String, + query: String, + domain: String, + headers: HashMap, +} + +/// Event router for inter-plugin communication +pub struct EventRouter { + // Event routing implementation would go here +} + +impl EventRouter { + pub fn new() -> Self { + Self {} + } +} \ No newline at end of file