diff --git a/Cargo.lock b/Cargo.lock index c8e8924d..737bfaeb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2030,12 +2030,11 @@ dependencies = [ [[package]] name = "nu-ansi-term" -version = "0.46.0" +version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "overload", - "winapi", + "windows-sys 0.59.0", ] [[package]] @@ -2201,12 +2200,6 @@ dependencies = [ "vcpkg", ] -[[package]] -name = "overload" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" - [[package]] name = "parking" version = "2.2.1" @@ -3896,9 +3889,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.19" +version = "0.3.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" dependencies = [ "nu-ansi-term", "sharded-slab", @@ -4143,6 +4136,9 @@ dependencies = [ "sqlx", "thiserror 2.0.12", "tokio", + "tracing", + "tracing-subscriber", + "url", "vectorize-core", ] @@ -4386,22 +4382,6 @@ dependencies = [ "web-sys", ] -[[package]] -name = "winapi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" -dependencies = [ - "winapi-i686-pc-windows-gnu", - "winapi-x86_64-pc-windows-gnu", -] - -[[package]] -name = "winapi-i686-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" - [[package]] name = "winapi-util" version = "0.1.9" @@ -4411,12 +4391,6 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "winapi-x86_64-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" - [[package]] name = "windows-core" version = "0.61.2" diff --git a/Cargo.toml b/Cargo.toml index a9099651..c327e557 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,9 @@ serde = "1.0.219" serde_json = "1.0" sqlparser = "0.51" sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "postgres", "uuid", "time"] } +tracing = "0.1" +tracing-log = "0.1" +tracing-subscriber = "0.3.20" thiserror = "2.0.12" tiktoken-rs = "0.7.0" tokio = { version = "1.0", features = ["full"] } diff --git a/extension/Makefile b/extension/Makefile index a0e97d8c..6246d05f 100644 --- a/extension/Makefile +++ b/extension/Makefile @@ -42,6 +42,7 @@ clean: setup.dependencies: install-pg_cron install-pgvector install-pgmq install-vectorscale setup.shared_preload_libraries: echo "shared_preload_libraries = 'pg_cron, vectorize'" >> ~/.pgrx/data-${PG_VERSION}/postgresql.conf + echo "cron.database_name = 'postgres'" >> ~/.pgrx/data-${PG_VERSION}/postgresql.conf setup.urls: echo "vectorize.embedding_service_url = 'http://localhost:3000/v1'" >> ~/.pgrx/data-${PG_VERSION}/postgresql.conf echo "vectorize.ollama_service_url = 'http://localhost:3001'" >> ~/.pgrx/data-${PG_VERSION}/postgresql.conf @@ -93,7 +94,7 @@ test-integration: cargo test ${TEST_NAME} -- --ignored --test-threads=1 --nocapture test-unit: - cargo test ${TEST_NAME} -- --test-threads=1 + cargo test ${TEST_NAME} -- --test-threads=1 --nocapture test-version: git fetch --tags diff --git a/extension/src/api.rs b/extension/src/api.rs index 74ba036a..c442f2d2 100644 --- a/extension/src/api.rs +++ b/extension/src/api.rs @@ -176,40 +176,6 @@ fn encode( Ok(transform(input, &model, api_key).remove(0)) } -#[allow(clippy::too_many_arguments)] -#[deprecated(since = "0.22.0", note = "Please use vectorize.table() instead")] -#[pg_extern] -fn init_rag( - agent_name: &str, - table_name: &str, - unique_record_id: &str, - // column that have data we want to be able to chat with - column: &str, - schema: default!(&str, "'public'"), - index_dist_type: default!(types::IndexDist, "'pgv_hnsw_cosine'"), - // transformer model to use in vector-search - transformer: default!(&str, "'sentence-transformers/all-MiniLM-L6-v2'"), - table_method: default!(types::TableMethod, "'join'"), - schedule: default!(&str, "'* * * * *'"), -) -> Result { - pgrx::warning!("DEPRECATED: vectorize.init_rag() will be removed in a future version. Please use vectorize.table() instead."); - // chat only supports single columns transform - let columns = vec![column.to_string()]; - let transformer_model = Model::new(transformer)?; - init_table( - agent_name, - schema, - table_name, - columns, - unique_record_id, - None, - index_dist_type.into(), - &transformer_model, - table_method.into(), - schedule, - ) -} - /// creates a table indexed with embeddings for chat completion workloads #[pg_extern] fn rag( diff --git a/extension/src/executor.rs b/extension/src/executor.rs index 2031c002..dc491167 100644 --- a/extension/src/executor.rs +++ b/extension/src/executor.rs @@ -25,7 +25,7 @@ pub fn batch_texts( return TableIterator::new(vec![record_ids].into_iter().map(|arr| (arr,))); } - let num_batches = (total_records + batch_size - 1) / batch_size; + let num_batches = total_records.div_ceil(batch_size); let mut batches = Vec::with_capacity(num_batches); diff --git a/extension/src/guc.rs b/extension/src/guc.rs index 863f8ed9..9cc1fb21 100644 --- a/extension/src/guc.rs +++ b/extension/src/guc.rs @@ -250,7 +250,6 @@ pub fn get_guc(guc: VectorizeGuc) -> Option { } } -#[allow(dead_code)] fn handle_cstr(cstr: &CStr) -> Result { if let Ok(s) = cstr.to_str() { Ok(s.to_owned()) diff --git a/extension/src/search.rs b/extension/src/search.rs index 2fc5282b..e0d306b9 100644 --- a/extension/src/search.rs +++ b/extension/src/search.rs @@ -429,7 +429,7 @@ pub fn cosine_similarity_search( num_results, where_clause, ), - TableMethod::join => query::join_table_cosine_similarity( + TableMethod::join => join_table_cosine_similarity( project, &job_params.schema, &job_params.relation, @@ -452,6 +452,52 @@ pub fn cosine_similarity_search( }) } +pub fn join_table_cosine_similarity( + project: &str, + schema: &str, + table: &str, + join_key: &str, + return_columns: &[String], + num_results: i32, + where_clause: Option, +) -> String { + let cols = &return_columns + .iter() + .map(|s| format!("t0.{s}")) + .collect::>() + .join(","); + let where_str = if let Some(w) = where_clause { + prepare_filter(&w, join_key) + } else { + "".to_string() + }; + let inner_query = format!( + " + SELECT + {join_key}, + 1 - (embeddings <=> $1::vector) AS similarity_score + FROM vectorize._embeddings_{project} + ORDER BY similarity_score DESC + " + ); + format!( + " + SELECT to_jsonb(t) as results + FROM ( + SELECT {cols}, t1.similarity_score + FROM + ( + {inner_query} + ) t1 + INNER JOIN {schema}.{table} t0 on t0.{join_key} = t1.{join_key} + {where_str} + ) t + ORDER BY t.similarity_score DESC + LIMIT {num_results}; + " + ) +} + fn single_table_cosine_similarity( project: &str, schema: &str, @@ -482,3 +528,9 @@ fn single_table_cosine_similarity( cols = return_columns.join(", "), ) } + +// transform user's where_sql into the format search query expects +fn prepare_filter(filter: &str, pkey: &str) -> String { + let wc = filter.replace(pkey, &format!("t0.{pkey}")); + format!("AND {wc}") +} diff --git a/extension/tests/util.rs b/extension/tests/util.rs index a81406bd..83b6f08d 100644 --- a/extension/tests/util.rs +++ b/extension/tests/util.rs @@ -7,7 +7,6 @@ pub mod common { use sqlx::{Pool, Postgres, Row}; use url::{ParseError, Url}; - #[allow(dead_code)] #[derive(FromRow, Debug, serde::Deserialize)] pub struct SearchResult { pub product_id: i32, @@ -16,7 +15,6 @@ pub mod common { pub similarity_score: f64, } - #[allow(dead_code)] #[derive(FromRow, Debug, Serialize)] pub struct SearchJSON { pub search_results: serde_json::Value, diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index db3487eb..8af81032 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -13,5 +13,8 @@ serde_json = { workspace = true } sqlx = { workspace = true} thiserror = { workspace = true } tokio = { workspace = true } +tracing = { workspace = true } +tracing-subscriber = { workspace = true } +url = { workspace = true } pgwire = { version = "0.30", features = ["server-api-aws-lc-rs"] } \ No newline at end of file diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 62830a56..9e6e405f 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -1,8 +1,15 @@ -use log::{error, info}; +use std::collections::HashMap; +use std::net::SocketAddr; +use std::net::ToSocketAddrs; use std::sync::Arc; +use std::time::Duration; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio::net::TcpStream; +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::RwLock; use tokio::time::timeout; +use tracing::{error, info}; +use url::Url; +use vectorize_core::types::VectorizeJob; use super::message_parser::{log_message_processing, try_parse_complete_message}; use super::protocol::{BUFFER_SIZE, ProxyConfig, WireProxyError}; @@ -129,3 +136,55 @@ where info!("Standard proxy stream closed: {total_bytes} bytes transferred"); Ok(()) } + +pub async fn start_postgres_proxy( + proxy_port: u16, + database_url: String, + job_cache: Arc>>, + db_pool: sqlx::PgPool, +) -> Result<(), Box> { + let bind_address = "0.0.0.0"; + let timeout = 30; + + let listen_addr: SocketAddr = format!("{}:{}", bind_address, proxy_port).parse()?; + + let url = Url::parse(&database_url)?; + let postgres_host = url.host_str().unwrap(); + let postgres_port = url.port().unwrap(); + + let postgres_addr: SocketAddr = format!("{postgres_host}:{postgres_port}") + .to_socket_addrs()? + .next() + .ok_or("Failed to resolve PostgreSQL host address")?; + + let config = Arc::new(ProxyConfig { + postgres_addr, + timeout: Duration::from_secs(timeout), + jobmap: job_cache, + db_pool, + prepared_statements: Arc::new(RwLock::new(HashMap::new())), + }); + + info!("Proxy listening on: {listen_addr}"); + info!("Forwarding to PostgreSQL at: {postgres_addr}"); + + let listener = TcpListener::bind(listen_addr).await?; + + loop { + match listener.accept().await { + Ok((client_stream, client_addr)) => { + info!("New proxy connection from: {client_addr}"); + + let config = Arc::clone(&config); + tokio::spawn(async move { + if let Err(e) = handle_connection_with_timeout(client_stream, config).await { + error!("Proxy connection error from {client_addr}: {e}"); + } + }); + } + Err(e) => { + error!("Failed to accept proxy connection: {e}"); + } + } + } +} diff --git a/server/Cargo.toml b/server/Cargo.toml index 45f53706..bf13d921 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -14,6 +14,9 @@ vectorize_core = { package = "vectorize-core", path = "../core" } vectorize_worker = { package = "vectorize-worker", path = "../worker" } vectorize_proxy = { package = "vectorize-proxy", path = "../proxy" } pgmq = { workspace = true } +tracing-subscriber = { workspace = true } +tracing = { workspace = true } +tracing-log = { workspace = true } actix-cors = "0.7.1" actix-http = "3.11.0" @@ -26,8 +29,6 @@ bytes = "1.10.1" chrono = {version = "0.4.41", features = ["serde"] } clap = { version = "4.0", features = ["derive"] } env = "1.0.1" -tracing = "0.1" -tracing-log = "0.1" fallible-iterator = "0.3.0" futures = "0.3.31" lazy_static = "1.5.0" @@ -46,7 +47,6 @@ thiserror = "2.0.12" tiktoken-rs = "0.7.0" tokio = { version = "1.0", features = ["full"] } tokio-postgres = "0.7" -tracing-subscriber = "0.3" url = "2.2" utoipa = { version = "4", features = ["actix_extras", "chrono", "uuid"] } utoipa-swagger-ui = { version = "7", features = ["actix-web"] } diff --git a/server/src/app_state.rs b/server/src/app_state.rs new file mode 100644 index 00000000..ed7c7bd3 --- /dev/null +++ b/server/src/app_state.rs @@ -0,0 +1,95 @@ +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; +use tracing::error; +use vectorize_core::config::Config; +use vectorize_core::types::VectorizeJob; +use vectorize_worker::WorkerHealth; + +use crate::cache; + +#[derive(Debug, thiserror::Error)] +pub enum AppStateError { + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + #[error("Database error: {0}")] + Database(#[from] sqlx::Error), + #[error("Connection timeout")] + Timeout, +} + +#[derive(Clone)] +pub struct AppState { + pub config: Config, + pub db_pool: sqlx::PgPool, + pub cache_pool: sqlx::PgPool, + /// in-memory cache of existing vectorize jobs and their metadata + pub job_cache: Arc>>, + /// worker health monitoring data + pub worker_health: Arc>, +} + +impl AppState { + pub async fn new(config: Config) -> Result> { + let db_pool = sqlx::postgres::PgPoolOptions::new() + .max_connections(config.database_pool_max) + .connect(&config.database_url) + .await?; + + let cache_pool = sqlx::postgres::PgPoolOptions::new() + .max_connections(config.database_cache_pool_max) + .connect(&config.database_url) + .await?; + + vectorize_core::init::init_project(&db_pool) + .await + .map_err(|e| format!("Failed to initialize project: {e}"))?; + + // load initial job cache + let job_cache = cache::load_initial_job_cache(&db_pool) + .await + .map_err(|e| format!("Failed to load initial job cache: {e}"))?; + let job_cache = Arc::new(RwLock::new(job_cache)); + + // listen for job change notifications + if let Err(e) = cache::setup_job_change_notifications(&db_pool).await { + tracing::warn!("Failed to setup job change notifications: {e}"); + } + Self::start_cache_sync_listener_task(&cache_pool, &job_cache).await; + + let worker_health = Arc::new(RwLock::new(WorkerHealth { + status: vectorize_worker::WorkerStatus::Starting, + last_heartbeat: std::time::SystemTime::now(), + jobs_processed: 0, + uptime: std::time::Duration::from_secs(0), + restart_count: 0, + last_error: None, + })); + + Ok(AppState { + config, + db_pool, + cache_pool, + job_cache, + worker_health, + }) + } + + async fn start_cache_sync_listener_task( + cache_pool: &sqlx::PgPool, + job_cache: &Arc>>, + ) { + let cache_pool_for_sync = cache_pool.clone(); + let jobmap_for_sync = job_cache.clone(); + + tokio::spawn(async move { + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + if let Err(e) = + cache::start_cache_sync_listener(cache_pool_for_sync, jobmap_for_sync).await + { + error!("Cache synchronization error: {e}"); + } + }); + } +} diff --git a/server/src/cache.rs b/server/src/cache.rs new file mode 100644 index 00000000..8e42f10e --- /dev/null +++ b/server/src/cache.rs @@ -0,0 +1,166 @@ +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; +use tracing::{error, info}; +use vectorize_core::types::VectorizeJob; + +/// Cache sync functions for job change notifications +pub async fn setup_job_change_notifications( + pool: &sqlx::PgPool, +) -> Result<(), Box> { + let mut tx = pool.begin().await?; + + let create_notify_function = r#" + CREATE OR REPLACE FUNCTION vectorize.notify_job_change() + RETURNS TRIGGER AS $$ + BEGIN + IF TG_OP = 'DELETE' THEN + PERFORM pg_notify('vectorize_job_changes', + json_build_object( + 'operation', TG_OP, + 'job_name', OLD.job_name + )::text + ); + RETURN OLD; + ELSE + PERFORM pg_notify('vectorize_job_changes', + json_build_object( + 'operation', TG_OP, + 'job_name', NEW.job_name + )::text + ); + RETURN NEW; + END IF; + END; + $$ LANGUAGE plpgsql; + "#; + + sqlx::query("DROP TRIGGER IF EXISTS job_change_trigger ON vectorize.job;") + .execute(&mut *tx) + .await?; + + let create_trigger = r#" + CREATE TRIGGER job_change_trigger + AFTER INSERT OR UPDATE OR DELETE ON vectorize.job + FOR EACH ROW EXECUTE FUNCTION vectorize.notify_job_change(); + "#; + + sqlx::query(create_notify_function) + .execute(&mut *tx) + .await?; + sqlx::query(create_trigger).execute(&mut *tx).await?; + + tx.commit().await?; + info!("Database trigger for job changes setup successfully"); + Ok(()) +} + +pub async fn start_cache_sync_listener( + db_pool: sqlx::PgPool, + job_cache: Arc>>, +) -> Result<(), Box> { + let mut retry_delay = std::time::Duration::from_secs(1); + let max_retry_delay = std::time::Duration::from_secs(60); + + loop { + match try_listen_for_changes(&db_pool, &job_cache).await { + Ok(_) => retry_delay = std::time::Duration::from_secs(1), + Err(e) => { + error!("Cache sync listener error: {e}. Retrying in {retry_delay:?}"); + tokio::time::sleep(retry_delay).await; + retry_delay = std::cmp::min(retry_delay * 2, max_retry_delay); + } + } + } +} + +async fn try_listen_for_changes( + db_pool: &sqlx::PgPool, + job_cache: &Arc>>, +) -> Result<(), Box> { + let mut listener = sqlx::postgres::PgListener::connect_with(db_pool).await?; + listener.listen("vectorize_job_changes").await?; + + info!("Connected and listening for vectorize job changes"); + + loop { + match listener.recv().await { + Ok(notification) => { + info!( + "Received job change notification: {}", + notification.payload() + ); + + if let Ok(payload) = + serde_json::from_str::(notification.payload()) + { + let operation = payload.get("operation").and_then(|v| v.as_str()); + let job_name = payload.get("job_name").and_then(|v| v.as_str()); + info!( + "Job change detected - Operation: {}, Job: {}", + operation.unwrap_or("unknown"), + job_name.unwrap_or("unknown") + ); + } + + if let Err(e) = refresh_job_cache(db_pool, job_cache).await { + error!("Failed to refresh job cache: {e}"); + } else { + info!("Job cache refreshed successfully"); + } + } + Err(e) => { + error!("Error receiving notification: {e}"); + return Err(e.into()); + } + } + } +} + +pub async fn refresh_job_cache( + db_pool: &sqlx::PgPool, + job_cache: &Arc>>, +) -> Result<(), Box> { + let all_jobs: Vec = sqlx::query_as( + "SELECT job_name, src_table, src_schema, src_columns, primary_key, update_time_col, model FROM vectorize.job", + ) + .fetch_all(db_pool) + .await?; + + let jobmap: HashMap = all_jobs + .into_iter() + .map(|mut item| { + let key = std::mem::take(&mut item.job_name); + (key, item) + }) + .collect(); + + { + let mut jobmap_write = job_cache.write().await; + *jobmap_write = jobmap; + info!("Updated job cache with {} jobs", jobmap_write.len()); + } + + Ok(()) +} + +pub async fn load_initial_job_cache( + pool: &sqlx::PgPool, +) -> Result, crate::app_state::AppStateError> { + let all_jobs: Vec = sqlx::query_as( + "SELECT job_name, src_table, src_schema, src_columns, primary_key, update_time_col, model FROM vectorize.job", + ) + .fetch_all(pool) + .await + .map_err(crate::app_state::AppStateError::Database)?; + + let jobmap: HashMap = all_jobs + .into_iter() + .map(|mut item| { + let key = std::mem::take(&mut item.job_name); + (key, item) + }) + .collect(); + + Ok(jobmap) +} diff --git a/server/src/lib.rs b/server/src/lib.rs index 92cd0cb3..ed541e74 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -1,3 +1,5 @@ +pub mod app_state; +pub mod cache; pub mod errors; pub mod routes; pub mod server; diff --git a/server/src/main.rs b/server/src/main.rs index 0609ba55..6b9199fb 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,98 +1,65 @@ use actix_cors::Cors; use actix_web::{App, HttpServer, middleware, web}; -use std::collections::HashMap; -use std::net::SocketAddr; -use std::net::ToSocketAddrs; -use std::sync::Arc; use std::time::Duration; -use tokio::net::TcpListener; -use tokio::sync::RwLock; -use tracing::{error, info, warn}; -use url::Url; +use tracing::error; use vectorize_core::config::Config; -use vectorize_core::init; -use vectorize_core::types::VectorizeJob; -use vectorize_proxy::{ - ProxyConfig, handle_connection_with_timeout, load_initial_job_cache, - setup_job_change_notifications, start_cache_sync_listener, -}; +use vectorize_proxy::start_postgres_proxy; +use vectorize_server::app_state::AppState; use vectorize_worker::{WorkerHealthMonitor, start_vectorize_worker_with_monitoring}; #[actix_web::main] async fn main() { - // Initialize tracing subscriber (simple default formatter) tracing_subscriber::fmt().with_target(false).init(); let cfg = Config::from_env(); - let pool = sqlx::postgres::PgPoolOptions::new() - .max_connections(cfg.database_pool_max) - .connect(&cfg.database_url) - .await - .expect("unable to connect to postgres"); - // Create a separate connection pool for cache refresher - let cache_pool = sqlx::postgres::PgPoolOptions::new() - .max_connections(cfg.database_cache_pool_max) - .connect(&cfg.database_url) - .await - .expect("unable to connect to postgres for cache refresher"); - let server_port = cfg.webserver_port; - let server_workers = cfg.num_server_workers; - init::init_project(&pool) + let app_state = AppState::new(cfg) .await - .expect("Failed to initialize project"); + .expect("Failed to initialize application state"); - // Load initial job cache and setup job change notifications - let jobcache = load_initial_job_cache(&pool) - .await - .expect("Failed to load initial job cache"); - let jobcache = Arc::new(RwLock::new(jobcache)); + // start the PostgreSQL proxy if enabled + if app_state.config.proxy_enabled { + let proxy_port = app_state.config.vectorize_proxy_port; + let database_url = app_state.config.database_url.clone(); + let job_cache = app_state.job_cache.clone(); + let db_pool = app_state.db_pool.clone(); - if let Err(e) = setup_job_change_notifications(&pool).await { - warn!("Failed to setup job change notifications: {e}"); - } - - // Start the PostgreSQL proxy if enabled - if cfg.proxy_enabled { - let proxy_pool = pool.clone(); - let proxy_cfg = cfg.clone(); - let proxy_jobcache = Arc::clone(&jobcache); - let proxy_cache_pool = cache_pool.clone(); tokio::spawn(async move { - if let Err(e) = - start_postgres_proxy(proxy_cfg, proxy_pool, proxy_jobcache, proxy_cache_pool).await + if let Err(e) = start_postgres_proxy(proxy_port, database_url, job_cache, db_pool).await { error!("Failed to start PostgreSQL proxy: {e}"); } }); } - // Start the vectorize worker with health monitoring - let worker_pool = pool.clone(); - let worker_cfg = cfg.clone(); + // start the vectorize worker with health monitoring + let worker_state = app_state.clone(); let worker_health_monitor = WorkerHealthMonitor::new(); - let worker_health_for_routes = worker_health_monitor.get_arc_clone(); tokio::spawn(async move { - if let Err(e) = - start_vectorize_worker_with_monitoring(worker_cfg, worker_pool, worker_health_monitor) - .await + if let Err(e) = start_vectorize_worker_with_monitoring( + worker_state.config.clone(), + worker_state.db_pool.clone(), + worker_health_monitor, + ) + .await { error!("Failed to start vectorize worker: {e}"); } }); + // store values before moving app_state + let server_workers = app_state.config.num_server_workers; + let server_port = app_state.config.webserver_port; + let _ = HttpServer::new(move || { let cors = Cors::permissive(); App::new() .wrap(cors) .wrap(middleware::Logger::default()) - .app_data(web::Data::new(cfg.clone())) - .app_data(web::Data::new(pool.clone())) - .app_data(web::Data::new(worker_health_for_routes.clone())) - .app_data(web::Data::new(jobcache.clone())) + .app_data(web::Data::new(app_state.clone())) .configure(vectorize_server::server::route_config) .configure(vectorize_server::routes::health::configure_health_routes) }) @@ -103,73 +70,3 @@ async fn main() { .run() .await; } - -async fn start_postgres_proxy( - cfg: Config, - pool: sqlx::PgPool, - jobmap: Arc>>, - cache_pool: sqlx::PgPool, -) -> Result<(), Box> { - let bind_address = "0.0.0.0"; - let timeout = 30; - - let listen_addr: SocketAddr = - format!("{}:{}", bind_address, cfg.vectorize_proxy_port).parse()?; - - let url = Url::parse(&cfg.database_url)?; - let postgres_host = url.host_str().unwrap(); - let postgres_port = url.port().unwrap(); - - let postgres_addr: SocketAddr = format!("{postgres_host}:{postgres_port}") - .to_socket_addrs()? - .next() - .ok_or("Failed to resolve PostgreSQL host address")?; - - let config = Arc::new(ProxyConfig { - postgres_addr, - timeout: Duration::from_secs(timeout), - jobmap: Arc::clone(&jobmap), - db_pool: pool.clone(), - prepared_statements: Arc::new(RwLock::new(HashMap::new())), - }); - - info!("Proxy listening on: {listen_addr}"); - info!("Forwarding to PostgreSQL at: {postgres_addr}"); - - // Start cache sync listener with its own connection pool - let cache_pool_for_sync = cache_pool.clone(); - let jobmap_for_sync = Arc::clone(&jobmap); - tokio::spawn(async move { - tokio::time::sleep(Duration::from_secs(1)).await; - let sync_config = Arc::new(ProxyConfig { - postgres_addr, - timeout: Duration::from_secs(timeout), - jobmap: jobmap_for_sync, - db_pool: cache_pool_for_sync, - prepared_statements: Arc::new(RwLock::new(HashMap::new())), - }); - if let Err(e) = start_cache_sync_listener(sync_config).await { - error!("Cache synchronization error: {e}"); - } - }); - - let listener = TcpListener::bind(listen_addr).await?; - - loop { - match listener.accept().await { - Ok((client_stream, client_addr)) => { - info!("New proxy connection from: {client_addr}"); - - let config = Arc::clone(&config); - tokio::spawn(async move { - if let Err(e) = handle_connection_with_timeout(client_stream, config).await { - error!("Proxy connection error from {client_addr}: {e}"); - } - }); - } - Err(e) => { - error!("Failed to accept proxy connection: {e}"); - } - } - } -} diff --git a/server/src/routes/health.rs b/server/src/routes/health.rs index 2fb0cf1e..09204582 100644 --- a/server/src/routes/health.rs +++ b/server/src/routes/health.rs @@ -1,15 +1,10 @@ +use crate::app_state::AppState; use actix_web::{HttpResponse, Result, web}; use serde_json::json; -use std::sync::Arc; use std::time::SystemTime; -use tokio::sync::RwLock; -use vectorize_worker::WorkerHealth; - -pub async fn health_check( - worker_health: web::Data>>, -) -> Result { - let health = worker_health.read().await; +pub async fn health_check(app_state: web::Data) -> Result { + let health = app_state.worker_health.read().await; let is_healthy = match &health.status { vectorize_worker::WorkerStatus::Healthy => true, vectorize_worker::WorkerStatus::Starting => { @@ -59,10 +54,8 @@ pub async fn liveness_check() -> Result { }))) } -pub async fn readiness_check( - worker_health: web::Data>>, -) -> Result { - let health = worker_health.read().await; +pub async fn readiness_check(app_state: web::Data) -> Result { + let health = app_state.worker_health.read().await; let is_ready = matches!(health.status, vectorize_worker::WorkerStatus::Healthy); let response = json!({ diff --git a/server/src/routes/search.rs b/server/src/routes/search.rs index 18c9ab4e..7b880c6d 100644 --- a/server/src/routes/search.rs +++ b/server/src/routes/search.rs @@ -1,11 +1,10 @@ +use crate::app_state::AppState; use crate::errors::ServerError; use actix_web::{HttpResponse, get, web}; use serde::{Deserialize, Serialize}; -use sqlx::{PgPool, Row, prelude::FromRow}; -use std::collections::{BTreeMap, HashMap}; +use sqlx::{Row, prelude::FromRow}; +use std::collections::BTreeMap; -use std::sync::Arc; -use tokio::sync::RwLock; use utoipa::ToSchema; use uuid::Uuid; use vectorize_core::query::{self, FilterValue}; @@ -77,8 +76,7 @@ pub struct SearchResponse { )] #[get("/search")] pub async fn search( - pool: web::Data, - jobmap: web::Data>>>, + app_state: web::Data, payload: web::Query, ) -> Result { let payload = payload.into_inner(); @@ -96,7 +94,7 @@ pub async fn search( // Try to get job info from cache first, fallback to database with write-through on miss let vectorizejob = { if let Some(job_info) = { - let job_cache = jobmap.read().await; + let job_cache = app_state.job_cache.read().await; job_cache.get(&payload.job_name).cloned() } { job_info @@ -105,8 +103,8 @@ pub async fn search( "Job not found in cache, querying database for job: {}", payload.job_name ); - let job = get_vectorize_job(&pool, &payload.job_name).await?; - let mut job_cache = jobmap.write().await; + let job = get_vectorize_job(&app_state.db_pool, &payload.job_name).await?; + let mut job_cache = app_state.job_cache.write().await; job_cache.insert(payload.job_name.clone(), job.clone()); job } @@ -156,7 +154,7 @@ pub async fn search( }; } - let results = prepared_query.fetch_all(&**pool).await?; + let results = prepared_query.fetch_all(&app_state.db_pool).await?; let json_results: Vec = results .iter() @@ -166,7 +164,10 @@ pub async fn search( Ok(HttpResponse::Ok().json(json_results)) } -async fn get_vectorize_job(pool: &PgPool, job_name: &str) -> Result { +async fn get_vectorize_job( + pool: &sqlx::PgPool, + job_name: &str, +) -> Result { // Changed return type match sqlx::query( "SELECT job_name, src_table, src_schema, src_columns, primary_key, update_time_col, model diff --git a/server/src/routes/table.rs b/server/src/routes/table.rs index e5108ee9..773b214f 100644 --- a/server/src/routes/table.rs +++ b/server/src/routes/table.rs @@ -1,10 +1,7 @@ +use crate::app_state::AppState; use crate::errors::ServerError; use actix_web::{HttpResponse, post, web}; use serde::{Deserialize, Serialize}; -use sqlx::PgPool; -use std::collections::HashMap; -use std::sync::Arc; -use tokio::sync::RwLock; use utoipa::ToSchema; use uuid::Uuid; use vectorize_core::init::{self, get_column_datatype}; @@ -27,15 +24,14 @@ pub struct JobResponse { )] #[post("/table")] pub async fn table( - dbclient: web::Data, - jobmap: web::Data>>>, + app_state: web::Data, payload: web::Json, ) -> Result { let payload = payload.into_inner(); // validate update_time_col is timestamptz let datatype = get_column_datatype( - &dbclient, + &app_state.db_pool, &payload.src_schema, &payload.src_table, &payload.update_time_col, @@ -52,11 +48,11 @@ pub async fn table( ))); } - let job_id = init::initialize_job(&dbclient, &payload).await?; + let job_id = init::initialize_job(&app_state.db_pool, &payload).await?; // Update the job cache with the new job information { - let mut job_cache = jobmap.write().await; + let mut job_cache = app_state.job_cache.write().await; job_cache.insert(payload.job_name.clone(), payload.clone()); }