Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 8 additions & 34 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ serde = "1.0.219"
serde_json = "1.0"
sqlparser = "0.51"
sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "postgres", "uuid", "time"] }
tracing = "0.1"
tracing-log = "0.1"
tracing-subscriber = "0.3.20"
thiserror = "2.0.12"
tiktoken-rs = "0.7.0"
tokio = { version = "1.0", features = ["full"] }
Expand Down
29 changes: 27 additions & 2 deletions extension/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ use vectorize_core::types::{JobParams, Model};

use anyhow::Result;
use pgrx::prelude::*;
use std::collections::BTreeMap;
use vectorize_core::query::{FilterValue, FilterValueType};

#[pg_extern]
fn chunk_table(
Expand Down Expand Up @@ -122,13 +124,25 @@ fn search(
num_results: default!(i32, 10),
where_sql: default!(Option<String>, "NULL"),
) -> Result<TableIterator<'static, (name!(search_results, pgrx::JsonB),)>> {
let filters = where_sql
.map(|s| {
BTreeMap::from_iter(vec![(
"where_clause".to_string(),
FilterValue {
operator: vectorize_core::query::FilterOperator::Equal,
value: FilterValueType::String(s),
},
)])
})
.unwrap_or_default();

let search_results = search::search(
&job_name,
&query,
api_key,
return_columns,
num_results,
where_sql,
&filters,
)?;
Ok(TableIterator::new(search_results.into_iter().map(|r| (r,))))
}
Expand All @@ -145,13 +159,24 @@ fn hybrid_search(
num_results: default!(i32, 10),
where_sql: default!(Option<String>, "NULL"),
) -> Result<TableIterator<'static, (name!(search_results, pgrx::JsonB),)>> {
let parsed_filters = where_sql
.map(|s| {
BTreeMap::from_iter(vec![(
"where_clause".to_string(),
FilterValue {
operator: vectorize_core::query::FilterOperator::Equal,
value: FilterValueType::String(s),
},
)])
})
.unwrap_or_default();
let search_results = search::hybrid_search(
&job_name,
&query,
api_key,
return_columns,
num_results,
where_sql,
&parsed_filters,
)?;
Ok(TableIterator::new(search_results.into_iter().map(|r| (r,))))
}
Expand Down
10 changes: 9 additions & 1 deletion extension/src/chat/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::util::get_vectorize_meta_spi;
use anyhow::{anyhow, Result};
use handlebars::Handlebars;
use pgrx::prelude::*;
use std::collections::BTreeMap;
use vectorize_core::guc::ModelGucConfig;
use vectorize_core::transformers::providers::ollama::OllamaProvider;
use vectorize_core::transformers::providers::openai::OpenAIProvider;
Expand Down Expand Up @@ -57,7 +58,14 @@ pub fn call_chat(
let pk = job_params.primary_key;
let columns = vec![pk.clone(), content_column.clone()];

let raw_search = search::search(job_name, query, api_key.clone(), columns, num_context, None)?;
let raw_search = search::search(
job_name,
query,
api_key.clone(),
columns,
num_context,
&BTreeMap::new(),
)?;

let mut search_results: Vec<ContextualSearch> = Vec::new();
for s in raw_search {
Expand Down
2 changes: 1 addition & 1 deletion extension/src/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub fn batch_texts(
return TableIterator::new(vec![record_ids].into_iter().map(|arr| (arr,)));
}

let num_batches = (total_records + batch_size - 1) / batch_size;
let num_batches = total_records.div_ceil(batch_size);

let mut batches = Vec::with_capacity(num_batches);

Expand Down
1 change: 0 additions & 1 deletion extension/src/guc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,6 @@ pub fn get_guc(guc: VectorizeGuc) -> Option<String> {
}
}

#[allow(dead_code)]
fn handle_cstr(cstr: &CStr) -> Result<String> {
if let Ok(s) = cstr.to_str() {
Ok(s.to_owned())
Expand Down
40 changes: 24 additions & 16 deletions extension/src/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ use anyhow::{Context, Result};
use pgrx::prelude::*;
use pgrx::JsonB;
use serde_json::Value;
use std::collections::BTreeMap;
use std::collections::HashMap;
use vectorize_core::guc::VectorizeGuc;
use vectorize_core::query;
use vectorize_core::query::{create_event_trigger, create_trigger_handler};
use vectorize_core::query::{create_event_trigger, create_trigger_handler, FilterValue};
use vectorize_core::transformers::providers::get_provider;
use vectorize_core::transformers::providers::ollama::check_model_host;
use vectorize_core::types::{self, Model, ModelSource, TableMethod, VectorizeMeta};
Expand Down Expand Up @@ -282,7 +283,7 @@ pub fn hybrid_search(
api_key: Option<String>,
return_columns: Vec<String>,
num_results: i32,
where_clause: Option<String>,
filters: &BTreeMap<String, FilterValue>,
) -> Result<Vec<JsonB>> {
let semantic_weight: i32 = guc::SEMANTIC_WEIGHT.get();

Expand All @@ -296,7 +297,7 @@ pub fn hybrid_search(
api_key,
return_columns,
num_results * 2,
where_clause,
filters,
)?;

// Use a HashMap with serde_json::Value as the key
Expand Down Expand Up @@ -374,7 +375,7 @@ pub fn search(
api_key: Option<String>,
return_columns: Vec<String>,
num_results: i32,
where_clause: Option<String>,
filters: &BTreeMap<String, FilterValue>,
) -> Result<Vec<JsonB>> {
let project_meta: VectorizeMeta = util::get_vectorize_meta_spi(job_name)?;
let proj_params: types::JobParams = serde_json::from_value(
Expand Down Expand Up @@ -402,7 +403,7 @@ pub fn search(
&return_columns,
num_results,
&embeddings[0],
where_clause,
filters,
)
}
}
Expand All @@ -414,7 +415,7 @@ pub fn cosine_similarity_search(
return_columns: &[String],
num_results: i32,
embeddings: &[f64],
where_clause: Option<String>,
filters: &BTreeMap<String, FilterValue>,
) -> Result<Vec<JsonB>> {
let schema = job_params.schema.clone();
let table = job_params.relation.clone();
Expand All @@ -427,7 +428,7 @@ pub fn cosine_similarity_search(
&table,
return_columns,
num_results,
where_clause,
filters,
),
TableMethod::join => query::join_table_cosine_similarity(
project,
Expand All @@ -436,11 +437,14 @@ pub fn cosine_similarity_search(
&job_params.primary_key,
return_columns,
num_results,
where_clause,
filters,
),
};
Spi::connect(|client| {
let mut results: Vec<JsonB> = Vec::new();

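// NOTE: only the embedding vector ($1) is bound below. The `$2`, `$3`, ...
// placeholders that `single_table_cosine_similarity` emits for each entry in
// `filters` are NOT bound yet, so a non-empty `filters` map produces a query
// whose parameters are left unbound.
// TODO: bind every filter value after the embedding, iterating `filters` in the
// same BTreeMap order that was used to number the placeholders.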
let tup_table = client.select(&query, None, &[embeddings.into()])?;
for row in tup_table {
match row["results"].value()? {
Expand All @@ -458,13 +462,18 @@ fn single_table_cosine_similarity(
table: &str,
return_columns: &[String],
num_results: i32,
where_clause: Option<String>,
filters: &BTreeMap<String, FilterValue>,
) -> String {
let where_str = if let Some(w) = where_clause {
format!("AND {}", w)
} else {
"".to_string()
};
let mut bind_value_counter: i16 = 2; // Start at $2 since $1 is the vector
let mut where_filter = format!("WHERE {project}_updated_at is NOT NULL");

for (column, filter_value) in filters.iter() {
let operator = filter_value.operator.to_sql();
let filt = format!(" AND \"{column}\" {operator} ${bind_value_counter}");
where_filter.push_str(&filt);
bind_value_counter += 1;
}

format!(
"
SELECT to_jsonb(t) as results
Expand All @@ -473,8 +482,7 @@ fn single_table_cosine_similarity(
1 - ({project}_embeddings <=> $1::vector) AS similarity_score,
{cols}
FROM {schema}.{table}
WHERE {project}_updated_at is NOT NULL
{where_str}
{where_filter}
ORDER BY similarity_score DESC
LIMIT {num_results}
) t
Expand Down
2 changes: 0 additions & 2 deletions extension/tests/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ pub mod common {
use sqlx::{Pool, Postgres, Row};
use url::{ParseError, Url};

#[allow(dead_code)]
#[derive(FromRow, Debug, serde::Deserialize)]
pub struct SearchResult {
pub product_id: i32,
Expand All @@ -16,7 +15,6 @@ pub mod common {
pub similarity_score: f64,
}

#[allow(dead_code)]
#[derive(FromRow, Debug, Serialize)]
pub struct SearchJSON {
pub search_results: serde_json::Value,
Expand Down
3 changes: 3 additions & 0 deletions proxy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,8 @@ serde_json = { workspace = true }
sqlx = { workspace = true}
thiserror = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
url = { workspace = true }

pgwire = { version = "0.30", features = ["server-api-aws-lc-rs"] }
Loading
Loading