diff --git a/mindsdb/integrations/handlers/valkey_handler/README.md b/mindsdb/integrations/handlers/valkey_handler/README.md new file mode 100644 index 00000000000..1a17b7b89a6 --- /dev/null +++ b/mindsdb/integrations/handlers/valkey_handler/README.md @@ -0,0 +1,143 @@ +# Valkey Handler + +This is the implementation of the Valkey handler for MindsDB, providing vector store capabilities using the [Valkey Search](https://valkey.io/) module. + +## Prerequisites + +- **Valkey Server** 9.0+ with the Search module enabled (e.g., `valkey/valkey-bundle` Docker image) +- **Python packages**: `valkey-glide>=2.4.0`, `numpy>=1.21.0` + +## Installation + +Install the handler dependencies: + +```bash +pip install valkey-glide numpy +``` + +Or install via MindsDB extras: + +```bash +pip install mindsdb[valkey] +``` + +## Connection + +Create a Valkey vector store connection in MindsDB: + +```sql +CREATE DATABASE my_valkey +WITH ENGINE = 'valkey', +PARAMETERS = { + "host": "localhost", -- Valkey server hostname (default: localhost) + "port": 6379, -- Valkey server port (default: 6379) + "password": "", -- Authentication password (optional) + "db": 0, -- Database number 0-15 (default: 0) + "vector_dimension": 384, -- Default embedding dimension (default: 384) + "distance_metric": "COSINE", -- COSINE, L2, or IP (default: COSINE) + "prefix": "doc:" -- Key prefix for document hashes (default: "doc:") +}; +``` + +## Usage + +### Create a Table (Vector Index) + +```sql +CREATE TABLE my_valkey.my_collection +(SELECT * FROM my_model + WHERE content = 'sample text'); +``` + +Or use the knowledge base pattern: + +```sql +CREATE KNOWLEDGE BASE my_kb +USING + VECTOR STORE = my_valkey, + MODEL = my_embedding_model; +``` + +### Insert Data + +```sql +INSERT INTO my_valkey.my_collection (id, content, embeddings, metadata) +VALUES ('doc1', 'Hello world', '[0.1, 0.2, ...]', '{"source": "web"}'); +``` + +### Vector Similarity Search (KNN) + +```sql +SELECT id, content, distance +FROM 
my_valkey.my_collection +WHERE search_vector = (SELECT embeddings FROM my_model WHERE content = 'query text') +LIMIT 5; +``` + +### Select by ID + +```sql +SELECT id, content, metadata +FROM my_valkey.my_collection +WHERE id = 'doc1'; +``` + +### Delete Documents + +```sql +DELETE FROM my_valkey.my_collection +WHERE id = 'doc1'; + +DELETE FROM my_valkey.my_collection +WHERE id IN ('doc1', 'doc2', 'doc3'); +``` + +### Drop Table (Index) + +```sql +DROP TABLE my_valkey.my_collection; +``` + +## Configuration Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `host` | str | `localhost` | Valkey server hostname | +| `port` | int | `6379` | Valkey server port | +| `password` | str | `None` | Authentication password | +| `db` | int | `0` | Database number (0-15) | +| `vector_dimension` | int | `384` | Default embedding dimension for new indexes | +| `distance_metric` | str | `COSINE` | Distance metric: `COSINE`, `L2`, or `IP` | +| `prefix` | str | `doc:` | Key prefix for document hash keys | + +## Distance Metrics + +| Metric | Description | Use Case | +|--------|-------------|----------| +| `COSINE` | Cosine distance (1 - cosine similarity) | Text embeddings, normalized vectors | +| `L2` | Euclidean distance | Image embeddings, spatial data | +| `IP` | Inner product (negative dot product) | Pre-normalized vectors, recommendation | + +## Architecture + +The handler uses: +- **valkey-glide** (async client) with a synchronous wrapper (event loop `run_until_complete`) +- **HASH-based storage**: Documents are stored as Valkey hashes with the key pattern `{prefix}{table}:{id}` +- **FT.CREATE**: Creates vector indexes (HNSW by default, or FLAT via `index_algorithm`) with configurable dimensions and distance metrics +- **FT.SEARCH**: Executes KNN vector similarity queries with optional pre-filtering + +## Running Tests + +```bash +# Unit tests (no Valkey required) +pytest mindsdb/integrations/handlers/valkey_handler/tests/test_valkey_handler.py -v -k "Unit" + +# Integration
tests (requires running Valkey with Search module) +VALKEY_HOST=localhost VALKEY_PORT=6379 pytest mindsdb/integrations/handlers/valkey_handler/tests/test_valkey_handler.py -v -k "Integration" +``` + +### Running Valkey for Tests + +```bash +docker run -d --name valkey-test -p 6379:6379 valkey/valkey-bundle:9.1 +``` diff --git a/mindsdb/integrations/handlers/valkey_handler/__about__.py b/mindsdb/integrations/handlers/valkey_handler/__about__.py new file mode 100644 index 00000000000..85e1ecf5d2e --- /dev/null +++ b/mindsdb/integrations/handlers/valkey_handler/__about__.py @@ -0,0 +1,9 @@ +__title__ = "MindsDB Valkey handler" +__package_name__ = "mindsdb_valkey_handler" +__version__ = "0.0.1" +__description__ = "MindsDB handler for Valkey Vector Store (via valkey-glide)" +__author__ = "Daria Korenieva" +__github__ = "https://github.com/mindsdb/mindsdb" +__pypi__ = "https://pypi.org/project/mindsdb/" +__license__ = "MIT" +__copyright__ = "Copyright 2026 - mindsdb" diff --git a/mindsdb/integrations/handlers/valkey_handler/__init__.py b/mindsdb/integrations/handlers/valkey_handler/__init__.py new file mode 100644 index 00000000000..d2d6227d161 --- /dev/null +++ b/mindsdb/integrations/handlers/valkey_handler/__init__.py @@ -0,0 +1,33 @@ +from mindsdb.integrations.libs.const import HANDLER_SUPPORT_LEVEL, HANDLER_TYPE + +from .__about__ import __description__ as description +from .__about__ import __version__ as version +from .connection_args import connection_args, connection_args_example + +try: + from .valkey_handler import ValkeyHandler as Handler + + import_error = None +except Exception as e: + Handler = None + import_error = e + +title = "Valkey" +name = "valkey" +type = HANDLER_TYPE.DATA +support_level = HANDLER_SUPPORT_LEVEL.COMMUNITY +icon_path = "icon.svg" + +__all__ = [ + "Handler", + "version", + "name", + "type", + "title", + "description", + "support_level", + "connection_args", + "connection_args_example", + "import_error", + "icon_path", +] diff --git 
a/mindsdb/integrations/handlers/valkey_handler/connection_args.py b/mindsdb/integrations/handlers/valkey_handler/connection_args.py new file mode 100644 index 00000000000..d995b46cfa3 --- /dev/null +++ b/mindsdb/integrations/handlers/valkey_handler/connection_args.py @@ -0,0 +1,64 @@ +from collections import OrderedDict + +from mindsdb.integrations.libs.const import HANDLER_CONNECTION_ARG_TYPE as ARG_TYPE + +connection_args = OrderedDict( + host={ + "type": ARG_TYPE.STR, + "description": "Valkey server hostname", + "required": False, + }, + port={ + "type": ARG_TYPE.INT, + "description": "Valkey server port", + "required": False, + }, + password={ + "type": ARG_TYPE.PWD, + "description": "Valkey authentication password", + "required": False, + "secret": True, + }, + db={ + "type": ARG_TYPE.INT, + "description": "Valkey database number (0-15)", + "required": False, + }, + vector_dimension={ + "type": ARG_TYPE.INT, + "description": "Default vector dimension for new indexes", + "required": False, + }, + distance_metric={ + "type": ARG_TYPE.STR, + "description": "Distance metric: COSINE, L2, or IP", + "required": False, + }, + index_algorithm={ + "type": ARG_TYPE.STR, + "description": "Vector index algorithm: HNSW (default) or FLAT", + "required": False, + }, + prefix={ + "type": ARG_TYPE.STR, + "description": "Key prefix for document hashes", + "required": False, + }, + use_tls={ + "type": ARG_TYPE.BOOL, + "description": "Enable TLS/SSL connection (required for AWS ElastiCache, MemoryDB)", + "required": False, + }, + request_timeout={ + "type": ARG_TYPE.INT, + "description": "Request timeout in milliseconds (default: 5000)", + "required": False, + }, +) + +connection_args_example = OrderedDict( + host="localhost", + port=6379, + vector_dimension=384, + distance_metric="COSINE", +) diff --git a/mindsdb/integrations/handlers/valkey_handler/icon.svg b/mindsdb/integrations/handlers/valkey_handler/icon.svg new file mode 100644 index 00000000000..81d8d3b7806 --- /dev/null 
+++ b/mindsdb/integrations/handlers/valkey_handler/icon.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/mindsdb/integrations/handlers/valkey_handler/requirements.txt b/mindsdb/integrations/handlers/valkey_handler/requirements.txt new file mode 100644 index 00000000000..88e74013f2b --- /dev/null +++ b/mindsdb/integrations/handlers/valkey_handler/requirements.txt @@ -0,0 +1,2 @@ +valkey-glide>=2.4.0 +numpy>=1.21.0,<3 diff --git a/mindsdb/integrations/handlers/valkey_handler/tests/__init__.py b/mindsdb/integrations/handlers/valkey_handler/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/mindsdb/integrations/handlers/valkey_handler/tests/test_valkey_handler.py b/mindsdb/integrations/handlers/valkey_handler/tests/test_valkey_handler.py new file mode 100644 index 00000000000..bdb165ff887 --- /dev/null +++ b/mindsdb/integrations/handlers/valkey_handler/tests/test_valkey_handler.py @@ -0,0 +1,946 @@ +""" +Unit and integration tests for the Valkey vector store handler. + +Unit tests: Always run (no external dependencies, uses mocks). +Integration tests: Require a running Valkey instance with Search module. + - Set VALKEY_HOST and VALKEY_PORT environment variables to configure. + - Tests skip gracefully if Valkey is unavailable. 
+""" + +from __future__ import annotations + +import os +import struct +import time +import uuid +from unittest.mock import MagicMock, patch + +import numpy as np +import pandas as pd +import pytest + +from mindsdb.integrations.handlers.valkey_handler.valkey_handler import ( + DEFAULT_DB, + DEFAULT_DISTANCE_METRIC, + DEFAULT_HOST, + DEFAULT_PORT, + DEFAULT_PREFIX, + DEFAULT_VECTOR_DIMENSION, + VECTOR_FIELD_NAME, + ValkeyHandler, +) +from mindsdb.integrations.libs.response import RESPONSE_TYPE +from mindsdb.integrations.libs.vectordatabase_handler import ( + FilterCondition, + FilterOperator, +) + +from glide import RequestError + + +def _wait_for_indexing(handler, table_name: str, expected_count: int, timeout: float = 5.0): + """Poll FT.INFO until the index reports the expected number of documents. + + Falls back to a short sleep if FT.INFO is unavailable or returns + unexpected data. + """ + from glide import ft as _ft + + deadline = time.time() + timeout + while time.time() < deadline: + try: + info = handler._run(_ft.info(handler._client, table_name)) + # info is a mapping; num_docs may be bytes or str + num_docs = info.get(b"num_docs") or info.get("num_docs") + if num_docs is not None: + count = int(num_docs.decode() if isinstance(num_docs, bytes) else num_docs) + if count >= expected_count: + return + except Exception: + pass + time.sleep(0.1) + # Final fallback — give indexing a moment + time.sleep(0.3) + + +# ============================================================================= +# Unit Tests +# ============================================================================= + + +class TestValkeyHandlerUnit: + """Unit tests that do not require a running Valkey instance.""" + + def _make_handler(self, connection_data=None): + """Create a handler instance with mock storage.""" + return ValkeyHandler( + "test", + connection_data=connection_data or {}, + handler_storage=MagicMock(), + ) + + def test_validate_connection_defaults(self): + """Handler uses 
correct defaults when no connection_data provided.""" + h = self._make_handler() + assert h._host == DEFAULT_HOST + assert h._port == DEFAULT_PORT + assert h._db == DEFAULT_DB + assert h._password is None + assert h._vector_dimension == DEFAULT_VECTOR_DIMENSION + assert h._distance_metric == DEFAULT_DISTANCE_METRIC + assert h._prefix == DEFAULT_PREFIX + + def test_validate_connection_custom(self): + """Handler correctly stores custom connection parameters.""" + h = self._make_handler( + { + "host": "valkey.io", + "port": 6380, + "password": "secret", + "db": 2, + "vector_dimension": 768, + "distance_metric": "L2", + "prefix": "vec:", + } + ) + assert h._host == "valkey.io" + assert h._port == 6380 + assert h._password == "secret" + assert h._db == 2 + assert h._vector_dimension == 768 + assert h._distance_metric == "L2" + assert h._prefix == "vec:" + + def test_parse_doc_fields_basic(self): + """Correctly parses raw Valkey hash fields into Python types.""" + h = self._make_handler() + fields = { + b"id": b"doc1", + b"content": b"hello world", + b"embeddings": struct.pack("4f", 0.1, 0.2, 0.3, 0.4), + b"metadata": b'{"source": "web"}', + } + row = h._parse_doc_fields(fields, include_score=False) + assert row["id"] == "doc1" + assert row["content"] == "hello world" + assert len(row["embeddings"]) == 4 + assert abs(row["embeddings"][0] - 0.1) < 1e-6 + assert row["metadata"] == {"source": "web"} + assert "distance" not in row + + def test_parse_doc_fields_with_score(self): + """Correctly extracts distance score when present.""" + h = self._make_handler() + fields = { + b"id": b"doc1", + b"content": b"hello", + b"embeddings": struct.pack("4f", 0.1, 0.2, 0.3, 0.4), + b"metadata": b"{}", + b"__embeddings_score": b"0.123", + } + row = h._parse_doc_fields(fields, include_score=True) + assert abs(row["distance"] - 0.123) < 1e-6 + + def test_parse_doc_fields_empty_metadata(self): + """Handles empty or invalid metadata gracefully.""" + h = self._make_handler() + fields = { + 
b"id": b"doc1", + b"content": b"", + b"embeddings": b"", + b"metadata": b"not-json", + } + row = h._parse_doc_fields(fields, include_score=False) + assert row["metadata"] == {} + assert row["embeddings"] == [] + + def test_build_filter_expression_empty(self): + """Returns '*' when no filters provided.""" + h = self._make_handler() + expr = h._build_filter_expression([], []) + assert expr == "*" + + def test_build_filter_expression_id_equal(self): + """Builds correct TAG filter for single ID equality.""" + h = self._make_handler() + cond = FilterCondition("id", FilterOperator.EQUAL, "doc1") + expr = h._build_filter_expression([cond], []) + assert "@id:{doc1}" in expr + + def test_build_filter_expression_id_in(self): + """Builds correct TAG filter for ID IN list.""" + h = self._make_handler() + cond = FilterCondition("id", FilterOperator.IN, ["d1", "d2"]) + expr = h._build_filter_expression([cond], []) + assert "@id:{d1|d2}" in expr + + def test_build_filter_expression_id_not_equal(self): + """Builds correct negation filter for ID != value.""" + h = self._make_handler() + cond = FilterCondition("id", FilterOperator.NOT_EQUAL, "doc1") + expr = h._build_filter_expression([cond], []) + assert "-@id:{doc1}" in expr + + def test_escape_tag_special_chars(self): + """Escapes special characters in TAG values.""" + h = self._make_handler() + result = h._escape_tag("hello world!@#") + assert "\\ " in result or "\\!" 
in result + # Spaces and special chars should be escaped + assert "hello" in result + assert "world" in result + + @patch("mindsdb.integrations.handlers.valkey_handler.valkey_handler.ft") + def test_get_tables_returns_dataframe(self, mock_ft): + """get_tables returns a DataFrame with table_name column.""" + h = self._make_handler() + h.is_connected = True + h._client = MagicMock() + + async def mock_list(client): + return [b"idx1", b"idx2"] + + mock_ft.list = mock_list + + result = h.get_tables() + assert result.resp_type == RESPONSE_TYPE.TABLE + assert list(result.data_frame["table_name"]) == ["idx1", "idx2"] + + @patch("mindsdb.integrations.handlers.valkey_handler.valkey_handler.ft") + def test_select_builds_knn_query(self, mock_ft): + """select with search_vector builds correct KNN query string.""" + h = self._make_handler({"vector_dimension": 4}) + h.is_connected = True + h._client = MagicMock() + + search_result = [ + 1, + { + b"doc:t:d1": { + b"id": b"d1", + b"content": b"hi", + b"embeddings": struct.pack("4f", 0.1, 0.2, 0.3, 0.4), + b"metadata": b"{}", + b"__embeddings_score": b"0.05", + } + }, + ] + + captured_args = {} + + async def mock_search(client, index_name, query, options=None): + captured_args["query"] = query + captured_args["options"] = options + return search_result + + mock_ft.search = mock_search + + cond = FilterCondition("search_vector", FilterOperator.EQUAL, [0.1, 0.2, 0.3, 0.4]) + result = h.select("t", columns=["id", "content", "distance"], conditions=[cond], limit=5) + + assert "KNN 5" in captured_args["query"] + assert f"@{VECTOR_FIELD_NAME} $query_vec" in captured_args["query"] + assert len(result) == 1 + assert result.iloc[0]["id"] == "d1" + + def test_insert_serializes_embeddings(self): + """insert correctly serializes embeddings to float32 bytes.""" + h = self._make_handler() + h.is_connected = True + h._client = MagicMock() + + captured_calls = [] + + async def mock_hset(key, field_map): + captured_calls.append((key, field_map)) + 
return 1 + + h._client.hset = mock_hset + + df = pd.DataFrame( + { + "id": ["doc1"], + "content": ["hello world"], + "embeddings": [[0.1, 0.2, 0.3, 0.4]], + "metadata": [{"source": "web"}], + } + ) + h.insert("table", df) + + assert len(captured_calls) == 1 + key, field_map = captured_calls[0] + assert key == "doc:table:doc1" + expected_bytes = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32).tobytes() + assert field_map["embeddings"] == expected_bytes + assert field_map["content"] == "hello world" + assert field_map["metadata"] == '{"source": "web"}' + + def test_delete_by_id(self): + """delete correctly builds keys and calls unlink.""" + h = self._make_handler() + h.is_connected = True + h._client = MagicMock() + + captured_keys = [] + + async def mock_unlink(keys): + captured_keys.extend(keys) + return len(keys) + + h._client.unlink = mock_unlink + + conditions = [FilterCondition("id", FilterOperator.IN, ["d1", "d2"])] + h.delete("t", conditions) + + assert "doc:t:d1" in captured_keys + assert "doc:t:d2" in captured_keys + + @patch("mindsdb.integrations.handlers.valkey_handler.valkey_handler.ft") + def test_drop_table_calls_dropindex(self, mock_ft): + """drop_table calls ft.dropindex and cleans up keys.""" + h = self._make_handler() + h.is_connected = True + h._client = MagicMock() + + dropindex_called = [] + + async def mock_dropindex(client, index_name): + dropindex_called.append(index_name) + return "OK" + + mock_ft.dropindex = mock_dropindex + + # Mock scan to return no keys (empty cleanup) + async def mock_scan(cursor, match=None, count=None): + return [b"0", []] + + h._client.scan = mock_scan + + h.drop_table("my_index") + assert "my_index" in dropindex_called + + def test_build_filter_expression_id_not_in(self): + """Builds correct negation filter for ID NOT_IN list.""" + h = self._make_handler() + cond = FilterCondition("id", FilterOperator.NOT_IN, ["d1", "d2"]) + expr = h._build_filter_expression([cond], []) + assert "-@id:{d1|d2}" in expr + + def 
test_build_filter_expression_metadata_equal(self): + """Builds correct phrase search for metadata EQUAL condition.""" + h = self._make_handler() + cond = FilterCondition("metadata.source", FilterOperator.EQUAL, "web") + expr = h._build_filter_expression([], [cond]) + assert '@metadata:("web")' in expr + + def test_build_filter_expression_metadata_not_equal(self): + """Builds correct negation phrase search for metadata NOT_EQUAL.""" + h = self._make_handler() + cond = FilterCondition("metadata.source", FilterOperator.NOT_EQUAL, "web") + expr = h._build_filter_expression([], [cond]) + assert '-@metadata:("web")' in expr + + def test_build_filter_expression_metadata_escapes_quotes(self): + """Metadata filter properly escapes special characters in values.""" + h = self._make_handler() + cond = FilterCondition("metadata.desc", FilterOperator.EQUAL, 'say "hello"') + expr = h._build_filter_expression([], [cond]) + assert '\\"hello\\"' in expr + assert "@metadata:(" in expr + + @patch("mindsdb.integrations.handlers.valkey_handler.valkey_handler.ft") + def test_select_id_not_equal_falls_through_to_search(self, mock_ft): + """NOT_EQUAL on id falls through to FT.SEARCH instead of returning empty.""" + h = self._make_handler({"vector_dimension": 4}) + h.is_connected = True + h._client = MagicMock() + + search_called = {} + + async def mock_search(client, index_name, query, options=None): + search_called["query"] = query + return [0, {}] + + mock_ft.search = mock_search + + cond = FilterCondition("id", FilterOperator.NOT_EQUAL, "doc1") + h.select("t", columns=["id"], conditions=[cond]) + + # Should have called ft.search with negation filter, not returned empty + assert "query" in search_called + assert "-@id:{doc1}" in search_called["query"] + + @patch("mindsdb.integrations.handlers.valkey_handler.valkey_handler.ft") + def test_select_id_not_in_falls_through_to_search(self, mock_ft): + """NOT_IN on id falls through to FT.SEARCH instead of returning empty.""" + h = 
self._make_handler({"vector_dimension": 4}) + h.is_connected = True + h._client = MagicMock() + + search_called = {} + + async def mock_search(client, index_name, query, options=None): + search_called["query"] = query + return [0, {}] + + mock_ft.search = mock_search + + cond = FilterCondition("id", FilterOperator.NOT_IN, ["doc1", "doc2"]) + h.select("t", columns=["id"], conditions=[cond]) + + assert "query" in search_called + assert "-@id:{doc1|doc2}" in search_called["query"] + + def test_connect_logs_error_on_failure(self): + """connect logs error when connection fails.""" + h = self._make_handler({"host": "nonexistent.invalid", "port": 9999}) + with patch("mindsdb.integrations.handlers.valkey_handler.valkey_handler.logger") as mock_logger: + with pytest.raises(Exception): + h.connect() + mock_logger.error.assert_called_once() + # Lazy formatting: args are (format_str, host, port, exception) + call_args = mock_logger.error.call_args[0] + assert "nonexistent.invalid" in str(call_args) + + def test_disconnect_logs_debug_on_error(self): + """disconnect logs at debug level when close raises.""" + h = self._make_handler() + h.is_connected = True + h._client = MagicMock() + + async def mock_close(): + raise RuntimeError("close failed") + + h._client.close = mock_close + + with patch("mindsdb.integrations.handlers.valkey_handler.valkey_handler.logger") as mock_logger: + h.disconnect() + mock_logger.debug.assert_called_once() + call_args = mock_logger.debug.call_args[0] + assert "close failed" in str(call_args) + + assert h.is_connected is False + assert h._client is None + + def test_check_connection_failure(self): + """check_connection returns failure status on connection error.""" + h = self._make_handler({"host": "nonexistent.invalid", "port": 9999}) + status = h.check_connection() + assert status.success is False + assert status.error_message is not None + + def test_delete_no_conditions_raises(self): + """delete raises exception when no conditions provided.""" + h 
= self._make_handler() + h.is_connected = True + h._client = MagicMock() + with pytest.raises(Exception, match="Delete requires at least one condition"): + h.delete("table", conditions=None) + + def test_parse_search_result_empty(self): + """_parse_search_result returns empty DataFrame on zero results.""" + h = self._make_handler() + result = [0, {}] + df = h._parse_search_result(result, columns=["id", "content"], include_score=False) + assert len(df) == 0 + assert "id" in df.columns + + @patch("mindsdb.integrations.handlers.valkey_handler.valkey_handler.ft") + def test_select_metadata_filter_uses_search(self, mock_ft): + """Select with metadata filter uses FT.SEARCH with correct expression.""" + h = self._make_handler({"vector_dimension": 4}) + h.is_connected = True + h._client = MagicMock() + + search_called = {} + + async def mock_search(client, index_name, query, options=None): + search_called["query"] = query + return [0, {}] + + mock_ft.search = mock_search + + cond = FilterCondition("metadata.source", FilterOperator.EQUAL, "web") + h.select("t", columns=["id"], conditions=[cond]) + + assert "query" in search_called + assert '@metadata:("web")' in search_called["query"] + + @patch("mindsdb.integrations.handlers.valkey_handler.valkey_handler.ft") + def test_get_columns_nonexistent_table(self, mock_ft): + """get_columns returns error response for non-existent table.""" + h = self._make_handler() + h.is_connected = True + h._client = MagicMock() + + async def mock_info(client, index_name): + raise RequestError("Unknown Index name: not found") + + mock_ft.info = mock_info + + result = h.get_columns("no_such_table") + assert result.resp_type == RESPONSE_TYPE.ERROR + assert "does not exist" in result.error_message + + def test_insert_handles_none_content_and_metadata(self): + """insert sets empty defaults when content is None and metadata is not a dict.""" + h = self._make_handler() + h.is_connected = True + h._client = MagicMock() + + captured_calls = [] + + async 
def mock_hset(key, field_map): + captured_calls.append((key, field_map)) + return 1 + + h._client.hset = mock_hset + + df = pd.DataFrame( + { + "id": ["doc1"], + "content": [None], + "embeddings": [[0.1, 0.2, 0.3, 0.4]], + "metadata": ["not_a_dict"], + } + ) + h.insert("table", df) + + assert len(captured_calls) == 1 + _, field_map = captured_calls[0] + assert field_map["content"] == "" + assert field_map["metadata"] == "{}" + + def test_scan_all_docs_with_offset_and_limit(self): + """_scan_all_docs respects offset and limit parameters.""" + h = self._make_handler() + h.is_connected = True + h._client = MagicMock() + + # Mock scan to return 5 keys + async def mock_scan(cursor, match=None, count=None): + if cursor == b"0": + return [b"0", [f"doc:t:doc{i}".encode() for i in range(5)]] + return [b"0", []] + + h._client.scan = mock_scan + + # Mock client.exec to return hgetall results for each key in the batch + async def mock_exec(batch, **kwargs): + # Return one hgetall result per command added to the batch. + # The batch contains hgetall for selected_keys (after offset/limit). 
+ # With offset=1, limit=2 from 5 keys, selected_keys = [doc1, doc2] + results = [] + for i in range(1, 3): # doc1, doc2 + results.append({ + b"id": f"doc{i}".encode(), + b"content": b"text", + b"embeddings": np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32).tobytes(), + b"metadata": b"{}", + }) + return results + + h._client.exec = mock_exec + + # Request offset=1, limit=2 + df = h._scan_all_docs("t", columns=["id"], offset=1, limit=2) + assert len(df) == 2 + assert df.iloc[0]["id"] == "doc1" + assert df.iloc[1]["id"] == "doc2" + + +# ============================================================================= +# Integration Tests +# ============================================================================= + +VALKEY_HOST = os.environ.get("VALKEY_HOST", "localhost") +VALKEY_PORT = int(os.environ.get("VALKEY_PORT", "6379")) +VALKEY_PASSWORD = os.environ.get("VALKEY_PASSWORD", None) +VECTOR_DIM = 4 # Small dimension for fast tests + + +@pytest.fixture(scope="class") +def handler(): + """Create handler and skip if Valkey not available.""" + h = ValkeyHandler( + "test_valkey", + connection_data={ + "host": VALKEY_HOST, + "port": VALKEY_PORT, + "password": VALKEY_PASSWORD, + "vector_dimension": VECTOR_DIM, + "distance_metric": "COSINE", + }, + handler_storage=MagicMock(), + ) + status = h.check_connection() + if not status.success: + pytest.skip(f"Valkey not available: {status.error_message}") + yield h + h.disconnect() + + +@pytest.fixture +def unique_table(): + """Generate unique table name to avoid test collisions.""" + return f"test_{uuid.uuid4().hex[:8]}" + + +def _make_test_df(num_docs=3): + """Create a test DataFrame with sample documents.""" + embeddings = [np.random.rand(VECTOR_DIM).tolist() for _ in range(num_docs)] + return pd.DataFrame( + { + "id": [f"doc{i}" for i in range(1, num_docs + 1)], + "content": [f"content_{i}" for i in range(1, num_docs + 1)], + "embeddings": embeddings, + "metadata": [{"src": f"source_{i}"} for i in range(1, num_docs + 1)], + 
} + ) + + +class TestValkeyHandlerIntegration: + """Integration tests requiring a running Valkey instance with Search module.""" + + def test_check_connection(self, handler): + """Verify connection to Valkey is successful.""" + status = handler.check_connection() + assert status.success is True + + def test_create_table(self, handler, unique_table): + """Create an index and verify it appears in table list.""" + try: + handler.create_table(unique_table) + tables = handler.get_tables() + assert unique_table in tables.data_frame["table_name"].values + finally: + handler.drop_table(unique_table, if_exists=True) + + def test_create_table_if_not_exists(self, handler, unique_table): + """Creating an existing table with if_not_exists=True does not raise.""" + try: + handler.create_table(unique_table) + # Should not raise + handler.create_table(unique_table, if_not_exists=True) + finally: + handler.drop_table(unique_table, if_exists=True) + + def test_insert_and_select_all(self, handler, unique_table): + """Insert documents and select all to verify they are stored.""" + try: + handler.create_table(unique_table) + df = pd.DataFrame( + { + "id": ["doc1", "doc2", "doc3"], + "content": ["hello", "world", "foo"], + "embeddings": [ + [0.1, 0.2, 0.3, 0.4], + [0.5, 0.6, 0.7, 0.8], + [0.9, 0.1, 0.2, 0.3], + ], + "metadata": [{"src": "a"}, {"src": "b"}, {"src": "c"}], + } + ) + handler.insert(unique_table, df) + # Allow indexing time + _wait_for_indexing(handler, unique_table, 3) + + result = handler.select(unique_table, columns=["id", "content"]) + assert len(result) == 3 + assert set(result["id"].tolist()) == {"doc1", "doc2", "doc3"} + finally: + handler.drop_table(unique_table, if_exists=True) + + def test_select_knn(self, handler, unique_table): + """KNN vector search returns nearest neighbors sorted by distance.""" + try: + handler.create_table(unique_table) + df = pd.DataFrame( + { + "id": ["doc1", "doc2", "doc3"], + "content": ["hello", "world", "foo"], + "embeddings": [ + 
[0.1, 0.2, 0.3, 0.4], + [0.5, 0.6, 0.7, 0.8], + [0.9, 0.1, 0.2, 0.3], + ], + "metadata": [{"src": "a"}, {"src": "b"}, {"src": "c"}], + } + ) + handler.insert(unique_table, df) + _wait_for_indexing(handler, unique_table, 3) + + result = handler.select( + unique_table, + columns=["id", "content", "distance"], + conditions=[ + FilterCondition( + "search_vector", + FilterOperator.EQUAL, + [0.1, 0.2, 0.3, 0.4], + ) + ], + limit=2, + ) + assert len(result) == 2 + assert "distance" in result.columns + # doc1 should be closest (identical vector) + assert result.iloc[0]["id"] == "doc1" + assert result.iloc[0]["distance"] <= result.iloc[1]["distance"] + finally: + handler.drop_table(unique_table, if_exists=True) + + def test_select_by_id(self, handler, unique_table): + """Select a document by its ID.""" + try: + handler.create_table(unique_table) + df = pd.DataFrame( + { + "id": ["doc1", "doc2"], + "content": ["hello", "world"], + "embeddings": [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]], + "metadata": [{}, {}], + } + ) + handler.insert(unique_table, df) + _wait_for_indexing(handler, unique_table, 2) + + result = handler.select( + unique_table, + columns=["id", "content"], + conditions=[FilterCondition("id", FilterOperator.EQUAL, "doc2")], + ) + assert len(result) == 1 + assert result.iloc[0]["id"] == "doc2" + assert result.iloc[0]["content"] == "world" + finally: + handler.drop_table(unique_table, if_exists=True) + + def test_select_by_id_in(self, handler, unique_table): + """Select multiple documents by ID list.""" + try: + handler.create_table(unique_table) + df = pd.DataFrame( + { + "id": ["doc1", "doc2", "doc3"], + "content": ["a", "b", "c"], + "embeddings": [ + [0.1, 0.2, 0.3, 0.4], + [0.5, 0.6, 0.7, 0.8], + [0.9, 0.1, 0.2, 0.3], + ], + "metadata": [{}, {}, {}], + } + ) + handler.insert(unique_table, df) + _wait_for_indexing(handler, unique_table, 3) + + result = handler.select( + unique_table, + columns=["id"], + conditions=[FilterCondition("id", FilterOperator.IN, 
["doc1", "doc3"])], + ) + assert len(result) == 2 + assert set(result["id"].tolist()) == {"doc1", "doc3"} + finally: + handler.drop_table(unique_table, if_exists=True) + + def test_select_with_offset_limit(self, handler, unique_table): + """Select with offset and limit restricts results.""" + try: + handler.create_table(unique_table) + df = pd.DataFrame( + { + "id": [f"doc{i}" for i in range(5)], + "content": [f"c{i}" for i in range(5)], + "embeddings": [np.random.rand(VECTOR_DIM).tolist() for _ in range(5)], + "metadata": [{} for _ in range(5)], + } + ) + handler.insert(unique_table, df) + _wait_for_indexing(handler, unique_table, 5) + + result = handler.select( + unique_table, + columns=["id"], + offset=1, + limit=2, + ) + assert len(result) <= 2 + finally: + handler.drop_table(unique_table, if_exists=True) + + def test_delete_by_id(self, handler, unique_table): + """Delete a single document by ID.""" + try: + handler.create_table(unique_table) + df = pd.DataFrame( + { + "id": ["doc1", "doc2", "doc3"], + "content": ["a", "b", "c"], + "embeddings": [ + [0.1, 0.2, 0.3, 0.4], + [0.5, 0.6, 0.7, 0.8], + [0.9, 0.1, 0.2, 0.3], + ], + "metadata": [{}, {}, {}], + } + ) + handler.insert(unique_table, df) + _wait_for_indexing(handler, unique_table, 3) + + handler.delete( + unique_table, + [FilterCondition("id", FilterOperator.EQUAL, "doc2")], + ) + # Deletion is synchronous (UNLINK), brief pause for index update + time.sleep(0.2) + + result = handler.select(unique_table, columns=["id"]) + assert "doc2" not in result["id"].tolist() + assert len(result) == 2 + finally: + handler.drop_table(unique_table, if_exists=True) + + def test_delete_by_id_in(self, handler, unique_table): + """Delete multiple documents by ID list.""" + try: + handler.create_table(unique_table) + df = pd.DataFrame( + { + "id": ["doc1", "doc2", "doc3"], + "content": ["a", "b", "c"], + "embeddings": [ + [0.1, 0.2, 0.3, 0.4], + [0.5, 0.6, 0.7, 0.8], + [0.9, 0.1, 0.2, 0.3], + ], + "metadata": [{}, {}, {}], + 
def test_get_columns(self, handler, unique_table):
    """The schema reported for a freshly created table lists every standard column."""
    expected_columns = ("id", "content", "embeddings", "metadata", "distance")
    try:
        handler.create_table(unique_table)
        response = handler.get_columns(unique_table)
        reported = response.data_frame["COLUMN_NAME"].tolist()
        for column_name in expected_columns:
            assert column_name in reported
    finally:
        handler.drop_table(unique_table, if_exists=True)
def test_large_vectors(self, unique_table):
    """Test with high-dimensional vectors (768 dims).

    Note: Uses its own handler instance because it needs a different
    vector_dimension than the class-level fixture.
    """
    dim = 768
    h = ValkeyHandler(
        "test_large",
        connection_data={
            "host": VALKEY_HOST,
            "port": VALKEY_PORT,
            "password": VALKEY_PASSWORD,
            "vector_dimension": dim,
            "distance_metric": "COSINE",
        },
        handler_storage=MagicMock(),
    )

    # Probe connectivity BEFORE entering the cleanup-guarded section. The
    # cleanup calls drop_table(), which reconnects; if the skip were raised
    # inside that try block, the reconnect failure in `finally` would
    # replace the skip with a connection error.
    status = h.check_connection()
    if not status.success:
        h.disconnect()
        pytest.skip(f"Valkey not available: {status.error_message}")

    try:
        h.create_table(unique_table)
        vec = np.random.rand(dim).tolist()
        df = pd.DataFrame(
            {
                "id": ["bigvec1"],
                "content": ["large vector test"],
                "embeddings": [vec],
                "metadata": [{}],
            }
        )
        h.insert(unique_table, df)
        _wait_for_indexing(h, unique_table, 1)

        result = h.select(
            unique_table,
            columns=["id", "embeddings"],
            conditions=[FilterCondition("search_vector", FilterOperator.EQUAL, vec)],
            limit=1,
        )
        assert len(result) == 1
        assert result.iloc[0]["id"] == "bigvec1"
        assert len(result.iloc[0]["embeddings"]) == dim
    finally:
        # Always release the connection, even when dropping the table fails.
        try:
            h.drop_table(unique_table, if_exists=True)
        finally:
            h.disconnect()
GlideClientConfiguration, + NodeAddress, + ServerCredentials, + ft, + FtCreateOptions, + FtSearchOptions, + FtSearchLimit, + DataType, + VectorField, + TextField, + TagField, + VectorAlgorithm, + VectorFieldAttributesFlat, + VectorFieldAttributesHnsw, + DistanceMetricType, + VectorType, + RequestError, +) + +from mindsdb.integrations.libs.response import RESPONSE_TYPE +from mindsdb.integrations.libs.response import HandlerResponse as Response +from mindsdb.integrations.libs.response import HandlerStatusResponse as StatusResponse +from mindsdb.integrations.libs.vectordatabase_handler import ( + FilterCondition, + FilterOperator, + TableField, + VectorStoreHandler, +) +from mindsdb.utilities import log + +logger = log.getLogger(__name__) + +# Constants +DEFAULT_HOST = "localhost" +DEFAULT_PORT = 6379 +DEFAULT_DB = 0 +DEFAULT_VECTOR_DIMENSION = 384 +DEFAULT_DISTANCE_METRIC = "COSINE" +DEFAULT_INDEX_ALGORITHM = "HNSW" +DEFAULT_PREFIX = "doc:" +VECTOR_FIELD_NAME = "embeddings" +SCORE_FIELD_NAME = f"__{VECTOR_FIELD_NAME}_score" # "__embeddings_score" +ID_FIELD_NAME = "id" +CONTENT_FIELD_NAME = "content" +METADATA_FIELD_NAME = "metadata" + +# Safety limits for SCAN operations +_MAX_SCAN_ITERATIONS = 100_000 +_OP_BATCH_SIZE = 100 +# Maximum documents returned by FT.SEARCH for metadata-based delete +_DELETE_SEARCH_LIMIT = 10_000 +# Default request timeout in milliseconds (Glide default is 250ms which is +# too short for vector search on larger indexes) +_DEFAULT_REQUEST_TIMEOUT_MS = 5000 + + +class ValkeyHandler(VectorStoreHandler): + """MindsDB handler for Valkey Vector Store using valkey-glide client.""" + + name = "valkey" + + def __init__(self, name: str, **kwargs): + super().__init__(name) + self.handler_storage = kwargs.get("handler_storage") + connection_data = kwargs.get("connection_data", {}) + + self._host = connection_data.get("host", DEFAULT_HOST) + self._port = int(connection_data.get("port", DEFAULT_PORT)) + self._password = connection_data.get("password", 
None) + self._db = int(connection_data.get("db", DEFAULT_DB)) + self._vector_dimension = int(connection_data.get("vector_dimension", DEFAULT_VECTOR_DIMENSION)) + self._distance_metric = connection_data.get("distance_metric", DEFAULT_DISTANCE_METRIC).upper() + self._index_algorithm = connection_data.get("index_algorithm", DEFAULT_INDEX_ALGORITHM).upper() + self._prefix = connection_data.get("prefix", DEFAULT_PREFIX) + self._use_tls = bool(connection_data.get("use_tls", False)) + self._request_timeout = int(connection_data.get("request_timeout", _DEFAULT_REQUEST_TIMEOUT_MS)) + + self._client: GlideClient | None = None + self.is_connected = False + self._loop: asyncio.AbstractEventLoop | None = None + self._executor: concurrent.futures.ThreadPoolExecutor | None = None + + def _get_loop(self) -> asyncio.AbstractEventLoop: + """Get or create an event loop for running async operations.""" + if self._loop is None or self._loop.is_closed(): + self._loop = asyncio.new_event_loop() + return self._loop + + def _run(self, coro): + """Execute an async coroutine synchronously. + + If called from within a running event loop (e.g. MindsDB async internals), + offloads execution to a dedicated thread with a persistent event loop to + avoid RuntimeError and ensure the Glide client always operates on the + same loop it was created on. + """ + try: + asyncio.get_running_loop() + # Already inside an event loop — offload to background thread with + # a persistent loop so the Glide client stays on a single loop. + if self._executor is None: + self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + loop = self._get_loop() + return self._executor.submit(loop.run_until_complete, coro).result() + except RuntimeError: + # No running loop — use our own. 
+ loop = self._get_loop() + return loop.run_until_complete(coro) + + def _get_distance_metric(self) -> DistanceMetricType: + """Map string distance metric to DistanceMetricType enum.""" + mapping = { + "COSINE": DistanceMetricType.COSINE, + "L2": DistanceMetricType.L2, + "IP": DistanceMetricType.IP, + } + return mapping.get(self._distance_metric, DistanceMetricType.COSINE) + + def connect(self) -> GlideClient: + """Connect to Valkey server and return the client instance.""" + if self.is_connected and self._client is not None: + return self._client + + try: + credentials = ServerCredentials(password=self._password) if self._password else None + config = GlideClientConfiguration( + addresses=[NodeAddress(host=self._host, port=self._port)], + credentials=credentials, + database_id=self._db, + client_name="mindsdb_valkey_handler", + use_tls=self._use_tls, + request_timeout=self._request_timeout, + ) + self._client = self._run(GlideClient.create(config)) + self.is_connected = True + return self._client + except Exception as e: + logger.error("Error connecting to Valkey at %s:%s: %s", self._host, self._port, e) + self.is_connected = False + raise + + def disconnect(self): + """Disconnect from Valkey server.""" + if self.is_connected and self._client is not None: + try: + self._run(self._client.close()) + except Exception as e: + logger.debug("Error during Valkey disconnect: %s", e) + self._client = None + self.is_connected = False + if self._executor is not None: + self._executor.shutdown(wait=False) + self._executor = None + if self._loop is not None and not self._loop.is_closed(): + self._loop.close() + self._loop = None + + def check_connection(self) -> StatusResponse: + """Check connectivity to Valkey server.""" + response = StatusResponse(False) + need_to_close = not self.is_connected + try: + client = self.connect() + self._run(client.ping()) + response.success = True + except Exception as e: + logger.error("Error connecting to Valkey: %s", e) + 
def create_table(self, table_name: str, if_not_exists: bool = True):
    """Create the vector index that backs a table.

    The index covers hashes whose keys start with ``{prefix}{table_name}:``
    and has text fields for content and metadata, a tag field for the id and
    one FLOAT32 vector field (FLAT when configured, HNSW otherwise).

    Args:
        table_name: Name of the index to create.
        if_not_exists: When True, an "already exists" error from the server
            is ignored; otherwise it is re-raised.
    """
    self.connect()

    metric = self._get_distance_metric()
    if self._index_algorithm == "FLAT":
        algorithm = VectorAlgorithm.FLAT
        vector_attributes = VectorFieldAttributesFlat(
            dimensions=self._vector_dimension,
            distance_metric=metric,
            type=VectorType.FLOAT32,
        )
    else:
        algorithm = VectorAlgorithm.HNSW
        vector_attributes = VectorFieldAttributesHnsw(
            dimensions=self._vector_dimension,
            distance_metric=metric,
            type=VectorType.FLOAT32,
        )

    index_fields = [
        TextField(CONTENT_FIELD_NAME),
        TagField(ID_FIELD_NAME),
        VectorField(VECTOR_FIELD_NAME, algorithm, vector_attributes),
        TextField(METADATA_FIELD_NAME),
    ]
    create_options = FtCreateOptions(
        data_type=DataType.HASH,
        prefixes=[f"{self._prefix}{table_name}:"],
    )

    try:
        self._run(ft.create(self._client, table_name, index_fields, create_options))
    except RequestError as err:
        index_present = "already exists" in str(err).lower()
        if not (index_present and if_not_exists):
            raise
def insert(self, table_name: str, data: pd.DataFrame):
    """Insert documents with embeddings into the vector store.

    Each row is written with HSET to the hash ``{prefix}{table_name}:{id}``,
    so re-inserting an existing id overwrites its fields.

    All rows are serialized before any command is issued. A row that cannot
    be serialized (e.g. malformed embeddings) therefore fails the call
    before anything is written, and no HSET coroutine is left un-awaited.

    Args:
        table_name: Name of the target index.
        data: DataFrame with columns: id, content, embeddings, metadata.
            Embeddings may be a numeric sequence or a JSON-encoded list;
            metadata may be a dict or a JSON-encoded object.

    Raises:
        Exception: If one or more HSET commands fail. The message reports
            the failure count and up to ten of the affected ids.
    """
    self.connect()

    # (doc_id, key, field_map) for every row, built up-front.
    pending: list[tuple[str, str, dict]] = []
    for _, row in data.iterrows():
        doc_id = str(row[TableField.ID.value])
        key = f"{self._prefix}{table_name}:{doc_id}"

        # Serialize embeddings to float32 bytes; accept a JSON string too.
        embeddings = row[TableField.EMBEDDINGS.value]
        if isinstance(embeddings, str):
            embeddings = json.loads(embeddings)
        emb_bytes = np.array(embeddings, dtype=np.float32).tobytes()

        content = row.get(TableField.CONTENT.value)
        has_content = content is not None and pd.notna(content)

        # Metadata is stored as a JSON object string. A JSON-encoded string
        # is parsed first so it is not silently replaced by "{}".
        metadata = row.get(TableField.METADATA.value)
        if isinstance(metadata, str):
            try:
                metadata = json.loads(metadata)
            except ValueError:
                metadata = None

        field_map = {
            ID_FIELD_NAME: doc_id,
            VECTOR_FIELD_NAME: emb_bytes,
            CONTENT_FIELD_NAME: str(content) if has_content else "",
            METADATA_FIELD_NAME: json.dumps(metadata) if isinstance(metadata, dict) else "{}",
        }
        pending.append((doc_id, key, field_map))

    async def _write_all() -> list[str]:
        """Issue HSETs in chunks of _OP_BATCH_SIZE; return ids that failed."""
        failed_ids: list[str] = []
        for start in range(0, len(pending), _OP_BATCH_SIZE):
            chunk = pending[start : start + _OP_BATCH_SIZE]
            # Coroutines are created only here, immediately before gather.
            outcomes = await asyncio.gather(
                *(self._client.hset(key, fields) for _, key, fields in chunk),
                return_exceptions=True,
            )
            for (chunk_doc_id, _, _), outcome in zip(chunk, outcomes):
                if isinstance(outcome, Exception):
                    logger.error(
                        "Error inserting document %s into %s: %s",
                        chunk_doc_id,
                        table_name,
                        outcome,
                    )
                    failed_ids.append(chunk_doc_id)
        return failed_ids

    failed = self._run(_write_all())
    if failed:
        raise Exception(
            f"Failed to insert {len(failed)}/{len(data)} documents into {table_name}: {failed[:10]}"
        )
+ """ + self.connect() + + # Separate conditions by type + search_vector = None + id_filters: list[FilterCondition] = [] + metadata_filters: list[FilterCondition] = [] + + if conditions: + for cond in conditions: + if cond.column == TableField.SEARCH_VECTOR.value: + search_vector = cond.value + elif cond.column == TableField.ID.value: + id_filters.append(cond) + elif cond.column.startswith(TableField.METADATA.value): + metadata_filters.append(cond) + + # Case A: KNN Vector Search + if search_vector is not None: + k = limit if limit else 10 + query_vec = np.array(search_vector, dtype=np.float32).tobytes() + + filter_expr = self._build_filter_expression(id_filters, metadata_filters) + query_str = f"{filter_expr}=>[KNN {k} @{VECTOR_FIELD_NAME} $query_vec]" + + options = FtSearchOptions( + params={"query_vec": query_vec}, + limit=FtSearchLimit(offset or 0, k), + dialect=2, + ) + result = self._run(ft.search(self._client, table_name, query_str, options)) + return self._parse_search_result(result, columns, include_score=True) + + # Case B: ID-only lookup (direct hash access for EQUAL/IN only) + if id_filters and not metadata_filters: + # Check if any filter uses NOT_EQUAL or NOT_IN — these cannot be + # resolved via direct hash lookup and need the FT.SEARCH path. + has_negation = any(cond.op in (FilterOperator.NOT_EQUAL, FilterOperator.NOT_IN) for cond in id_filters) + if not has_negation: + all_ids: list[str] = [] + for cond in id_filters: + if cond.op == FilterOperator.EQUAL: + all_ids.append(str(cond.value)) + elif cond.op == FilterOperator.IN: + all_ids.extend(str(v) for v in cond.value) + + # Use GLIDE Batch (pipeline) for batched HGETALL — sends all + # commands in a single network round-trip per chunk. 
def delete(self, table_name: str, conditions: list[FilterCondition] | None = None):
    """Delete documents from the vector store.

    Ids are collected from every supported condition and the union is
    removed with a single UNLINK. Supported conditions are ``id`` with
    EQUAL / IN / NOT_EQUAL / NOT_IN, and columns starting with ``metadata``.
    Negated id conditions and metadata conditions are resolved through
    FT.SEARCH, capped at ``_DELETE_SEARCH_LIMIT`` matches per condition.

    NOTE(review): conditions are combined as a union, not an intersection,
    and conditions on other columns are ignored. Confirm this matches what
    callers expect.

    Args:
        table_name: Name of the index.
        conditions: Conditions specifying which documents to delete.

    Raises:
        Exception: If no conditions are given.
        ValueError: If a condition cannot be turned into a search filter
            (it would otherwise degrade to the match-all query ``*``).
    """
    self.connect()

    if not conditions:
        raise Exception("Delete requires at least one condition")

    def _search_ids(id_conds: list, meta_conds: list) -> list[str]:
        """Run FT.SEARCH for the given conditions and return matching doc ids."""
        filter_expr = self._build_filter_expression(id_conds, meta_conds)
        if filter_expr == "*":
            # Never let an unsupported operator turn a targeted delete into
            # a match-all search.
            raise ValueError(
                f"Unsupported delete condition for table {table_name}: no search filter could be built"
            )
        options = FtSearchOptions(limit=FtSearchLimit(0, _DELETE_SEARCH_LIMIT), dialect=2)
        result = self._run(ft.search(self._client, table_name, filter_expr, options))
        if result[0] > _DELETE_SEARCH_LIMIT:
            logger.warning(
                "Delete matched %d docs but only removing %d (limit: %d) in table %s",
                result[0],
                _DELETE_SEARCH_LIMIT,
                _DELETE_SEARCH_LIMIT,
                table_name,
            )
        return self._extract_doc_ids_from_search_result(result, table_name)

    ids_to_delete: list[str] = []

    for cond in conditions:
        if cond.column == TableField.ID.value:
            if cond.op == FilterOperator.EQUAL:
                ids_to_delete.append(str(cond.value))
            elif cond.op == FilterOperator.IN:
                ids_to_delete.extend(str(v) for v in cond.value)
            elif cond.op in (FilterOperator.NOT_EQUAL, FilterOperator.NOT_IN):
                # Negation cannot be resolved by key; search for the matches.
                ids_to_delete.extend(_search_ids([cond], []))
        elif cond.column.startswith(TableField.METADATA.value):
            ids_to_delete.extend(_search_ids([], [cond]))

    if ids_to_delete:
        # Several conditions may name the same id; send each key once.
        unique_ids = list(dict.fromkeys(ids_to_delete))
        keys = [f"{self._prefix}{table_name}:{doc_id}" for doc_id in unique_ids]
        self._run(self._client.unlink(keys))
def _scan_all_docs(
    self,
    table_name: str,
    columns: list[str] | None,
    offset: int | None,
    limit: int | None,
) -> pd.DataFrame:
    """Scan all documents for a table using SCAN + HGETALL.

    Valkey Search does not support '*' as a match-all query, so this method
    iterates over all hash keys matching the table prefix and returns their
    fields as a DataFrame.

    Note: SCAN does not guarantee key ordering and may return the same key
    more than once. Keys are de-duplicated here (first-seen order), but
    offset-based pagination may still return inconsistent results across
    calls if keys are being added or removed concurrently.

    Args:
        table_name: Name of the index/table.
        columns: Columns to include in output.
        offset: Number of rows to skip (defaults to 0).
        limit: Maximum number of rows to return (defaults to 100).

    Returns:
        DataFrame with document data; an empty frame with the requested
        (or schema) columns when nothing matches.
    """
    prefix = f"{self._prefix}{table_name}:"
    start = offset or 0
    end = start + (limit or 100)

    # A dict is used as an insertion-ordered set so a key reported twice by
    # SCAN yields one row and does not inflate the early-stop count.
    unique_keys: dict[str, None] = {}
    cursor = b"0"
    iterations = 0
    while True:
        result = self._run(self._client.scan(cursor, match=f"{prefix}*", count=1000))
        cursor = result[0]
        for raw_key in result[1] or ():
            unique_keys[raw_key.decode() if isinstance(raw_key, bytes) else raw_key] = None
        iterations += 1
        # Stop early once we have enough keys or hit the safety limit.
        if len(unique_keys) >= end:
            break
        if cursor == b"0" or iterations >= _MAX_SCAN_ITERATIONS:
            break

    selected_keys = list(unique_keys)[start:end]

    if not selected_keys:
        return pd.DataFrame(columns=columns or [c["name"] for c in self.SCHEMA])

    async def _batch_scan_hgetall():
        """Fetch the selected hashes, one pipelined batch per chunk of keys."""
        results = []
        for i in range(0, len(selected_keys), _OP_BATCH_SIZE):
            chunk = selected_keys[i : i + _OP_BATCH_SIZE]
            batch = Batch(is_atomic=False)
            for key in chunk:
                batch.hgetall(key)
            batch_results = await self._client.exec(batch, raise_on_error=False)
            results.extend(batch_results)
        return results

    all_fields = self._run(_batch_scan_hgetall())
    # Falsy entries (e.g. a key removed between SCAN and HGETALL) are skipped.
    rows = [self._parse_doc_fields(fields, include_score=False) for fields in all_fields if fields]

    df = pd.DataFrame(rows) if rows else pd.DataFrame(columns=columns or [c["name"] for c in self.SCHEMA])
    if columns and not df.empty:
        available = [c for c in columns if c in df.columns]
        df = df[available]
    return df
def _parse_doc_fields(self, fields: dict, include_score: bool = False) -> dict:
    """Turn one raw Valkey hash into a row dict of Python values.

    Args:
        fields: Raw field mapping from Valkey (keys and values may be bytes).
        include_score: When True and the score field is present, add the
            distance as a float (0.0 if it cannot be parsed).

    Returns:
        Dict with id, content, embeddings, metadata and, optionally, distance.
    """
    # Normalise field names to str; values are left untouched.
    by_name = {}
    for raw_key, raw_value in fields.items():
        field_name = raw_key.decode() if isinstance(raw_key, bytes) else raw_key
        by_name[field_name] = raw_value

    parsed = {}
    parsed[TableField.ID.value] = self._decode_value(by_name.get(ID_FIELD_NAME, b""))
    parsed[TableField.CONTENT.value] = self._decode_value(by_name.get(CONTENT_FIELD_NAME, b""))

    # Embeddings are stored as packed float32 bytes.
    vector_blob = by_name.get(VECTOR_FIELD_NAME, b"")
    has_vector = isinstance(vector_blob, bytes) and len(vector_blob) > 0
    parsed[TableField.EMBEDDINGS.value] = (
        np.frombuffer(vector_blob, dtype=np.float32).tolist() if has_vector else []
    )

    # Metadata is stored as a JSON string; anything unparsable becomes {}.
    meta_text = self._decode_value(by_name.get(METADATA_FIELD_NAME, b"{}"))
    try:
        meta_value = json.loads(meta_text) if meta_text else {}
    except (json.JSONDecodeError, TypeError):
        meta_value = {}
    parsed[TableField.METADATA.value] = meta_value

    # The score field is only read when the caller asks for it.
    if include_score and SCORE_FIELD_NAME in by_name:
        score_text = self._decode_value(by_name[SCORE_FIELD_NAME])
        try:
            distance = float(score_text)
        except (ValueError, TypeError):
            distance = 0.0
        parsed[TableField.DISTANCE.value] = distance

    return parsed
_decode_value(self, val) -> str: + """Decode a bytes value to string.""" + if isinstance(val, bytes): + return val.decode("utf-8", errors="replace") + return str(val) if val is not None else "" + + def _extract_doc_ids_from_search_result(self, result, table_name: str) -> list[str]: + """Extract document IDs from an FT.SEARCH result. + + Args: + result: Raw result from ft.search [total_count, {doc_key: {fields}}]. + table_name: Table name used to strip the key prefix. + + Returns: + List of document ID strings. + """ + ids: list[str] = [] + if result[0] > 0 and len(result) > 1: + prefix_str = f"{self._prefix}{table_name}:" + for doc_key in result[1].keys(): + key_str = doc_key.decode() if isinstance(doc_key, bytes) else doc_key + if key_str.startswith(prefix_str): + ids.append(key_str[len(prefix_str) :]) + else: + ids.append(key_str.split(":", 2)[-1]) + return ids + + def _build_filter_expression( + self, + id_filters: list[FilterCondition], + metadata_filters: list[FilterCondition], + ) -> str: + """Build Valkey Search query filter string from FilterCondition objects. + + Args: + id_filters: Conditions on the id field. + metadata_filters: Conditions on metadata fields. + + Returns: + Filter expression string for ft.search query. + """ + parts = [] + + for cond in id_filters: + if cond.op == FilterOperator.EQUAL: + parts.append(f"@{ID_FIELD_NAME}:{{{self._escape_tag(str(cond.value))}}}") + elif cond.op == FilterOperator.IN: + escaped = "|".join(self._escape_tag(str(v)) for v in cond.value) + parts.append(f"@{ID_FIELD_NAME}:{{{escaped}}}") + elif cond.op == FilterOperator.NOT_EQUAL: + parts.append(f"-@{ID_FIELD_NAME}:{{{self._escape_tag(str(cond.value))}}}") + elif cond.op == FilterOperator.NOT_IN: + escaped = "|".join(self._escape_tag(str(v)) for v in cond.value) + parts.append(f"-@{ID_FIELD_NAME}:{{{escaped}}}") + + for cond in metadata_filters: + # Metadata is stored as a flat JSON string in a TextField. 
+ # Full-text search on a TextField cannot reliably filter by JSON + # sub-keys. We support basic substring matching for simple cases, + # but this is a best-effort approach with known limitations. + # For precise metadata filtering, consider using dedicated fields. + if cond.op == FilterOperator.EQUAL: + escaped_value = self._escape_phrase(str(cond.value)) + parts.append(f'@{METADATA_FIELD_NAME}:("{escaped_value}")') + elif cond.op == FilterOperator.NOT_EQUAL: + escaped_value = self._escape_phrase(str(cond.value)) + parts.append(f'-@{METADATA_FIELD_NAME}:("{escaped_value}")') + + if not parts: + return "*" + return " ".join(parts) + + def _escape_phrase(self, value: str) -> str: + """Escape special characters for FT.SEARCH phrase queries in TextField. + + All characters that have special meaning in FT.SEARCH query syntax + are escaped with a backslash to prevent query injection. + + Args: + value: Raw string value. + + Returns: + Escaped string safe for use inside double-quoted phrase queries. + """ + # Characters with special meaning in FT.SEARCH query syntax + special = r',.<>{}[]"\';:!@#$%^&*()-+=~/| ' + result = [] + for ch in value: + if ch in special: + result.append(f"\\{ch}") + else: + result.append(ch) + return "".join(result) + + def _escape_tag(self, value: str) -> str: + """Escape special characters for Valkey Search TAG field queries. + + Args: + value: Raw tag value. + + Returns: + Escaped string safe for use in TAG filter expressions. + """ + # | is the TAG union operator and must be escaped + special = r',.<>{}[]"\';:!@#$%^&*()-+=~/| ' + result = [] + for ch in str(value): + if ch in special: + result.append(f"\\{ch}") + else: + result.append(ch) + return "".join(result)