Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 26 additions & 20 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,28 +1,34 @@
repos:
- repo: https://github.com/pre-commit/mirrors-isort
rev: v5.6.4
- repo: https://github.com/pre-commit/mirrors-isort
rev: v5.10.1
hooks:
- id: isort
args: ['--multi-line=3', '--trailing-comma', '--force-grid-wrap=0', '--use-parentheses', '--line-width=88']
- id: isort
args:
[
"--multi-line=3",
"--trailing-comma",
"--force-grid-wrap=0",
"--use-parentheses",
"--line-width=88",
]


- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.3.0
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: check-added-large-files
- id: check-yaml
- id: mixed-line-ending
args: ['--fix=lf']
- id: trailing-whitespace
- id: check-added-large-files
- id: check-yaml
- id: mixed-line-ending
args: ["--fix=lf"]

- repo: https://github.com/humitos/mirrors-autoflake.git
rev: v1.1
- repo: https://github.com/PyCQA/autoflake
rev: v2.3.1
hooks:
- id: autoflake
args: ['--in-place', '--remove-all-unused-imports']
- id: autoflake
args: ["--in-place", "--remove-all-unused-imports"]

- repo: https://github.com/ambv/black
rev: 22.10.0
- repo: https://github.com/ambv/black
rev: 25.1.0
hooks:
- id: black
language_version: python3.9
- id: black
language_version: python3.9
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
[pytest]
addopts = src/tests
asyncio_mode = auto
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def read_package_variable(key, filename="__init__.py"):
"pgvector==0.3.*",
"sqlalchemy==2.*",
"psycopg2-binary==2.9.*",
"asyncpg==0.29.*",
"greenlet>=1.0.0",
"flupy==1.*",
"deprecated==1.2.*",
]
Expand Down Expand Up @@ -75,7 +77,7 @@ def read_package_variable(key, filename="__init__.py"):
],
install_requires=REQUIRES,
extras_require={
"dev": ["pytest", "parse", "numpy", "pytest-cov"],
"dev": ["pytest", "parse", "numpy", "pytest-cov", "pytest-asyncio"],
"docs": [
"mkdocs",
"pygments",
Expand Down
8 changes: 8 additions & 0 deletions src/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Generator

import pytest
import pytest_asyncio
from parse import parse
from sqlalchemy import create_engine, text

Expand Down Expand Up @@ -103,3 +104,10 @@ def clean_db(maybe_start_pg: None) -> Generator[str, None, None]:
def client(clean_db: str) -> Generator[vecs.Client, None, None]:
client_ = vecs.create_client(clean_db)
yield client_


@pytest_asyncio.fixture
async def async_client(clean_db: str):
"""Create an async client for testing"""
client_ = await vecs.create_async_client(clean_db)
yield client_
161 changes: 161 additions & 0 deletions src/tests/test_async_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
import pytest

import vecs


@pytest.mark.asyncio
async def test_create_async_client(clean_db: str):
"""Test creating an async client"""
# Convert regular connection to async
async_db = clean_db.replace("postgresql://", "postgresql+asyncpg://")
client = await vecs.create_async_client(async_db)
assert isinstance(client, vecs.AsyncClient)
await client.disconnect()


@pytest.mark.asyncio
async def test_async_client_context_manager(clean_db: str):
"""Test async client as context manager"""
async_db = clean_db.replace("postgresql://", "postgresql+asyncpg://")
client = await vecs.create_async_client(async_db)
async with client:
assert isinstance(client, vecs.AsyncClient)
assert client.vector_version is not None


@pytest.mark.asyncio
async def test_async_collection_create_and_upsert(async_client: vecs.AsyncClient):
"""Test async collection creation and upsert"""
# Create collection
collection = await async_client.get_or_create_collection(
"test_collection", dimension=3
)
assert collection.name == "test_collection"
assert collection.dimension == 3

# Test upsert
records = [
("id1", [1.0, 2.0, 3.0], {"type": "test"}),
("id2", [4.0, 5.0, 6.0], {"type": "test"}),
]
await collection.upsert(records)

# Test collection length
length = await collection.__len__()
assert length == 2


@pytest.mark.asyncio
async def test_async_collection_query(async_client: vecs.AsyncClient):
"""Test async collection query"""
# Create collection and add data
collection = await async_client.get_or_create_collection("test_query", dimension=3)

records = [
("id1", [1.0, 2.0, 3.0], {"type": "test"}),
("id2", [4.0, 5.0, 6.0], {"type": "test"}),
("id3", [7.0, 8.0, 9.0], {"type": "test"}),
]
await collection.upsert(records)

# Test query
results = await collection.query([1.0, 2.0, 3.0], limit=2)
assert len(results) == 2
assert results[0] == "id1" # Should be closest match

# Test query with metadata
results_with_meta = await collection.query(
[1.0, 2.0, 3.0], limit=2, include_metadata=True
)
assert len(results_with_meta) == 2
assert results_with_meta[0][0] == "id1"
assert results_with_meta[0][1] == {"type": "test"}


@pytest.mark.asyncio
async def test_async_collection_fetch_and_delete(async_client: vecs.AsyncClient):
"""Test async collection fetch and delete"""
# Create collection and add data
collection = await async_client.get_or_create_collection("test_fetch", dimension=3)

records = [
("id1", [1.0, 2.0, 3.0], {"type": "test"}),
("id2", [4.0, 5.0, 6.0], {"type": "test"}),
]
await collection.upsert(records)

# Test fetch
fetched = await collection.fetch(["id1", "id2"])
assert len(fetched) == 2

# Test delete
await collection.delete(["id1"])
length = await collection.__len__()
assert length == 1

# Test fetch after delete
fetched_after_delete = await collection.fetch(["id1", "id2"])
assert len(fetched_after_delete) == 1


@pytest.mark.asyncio
async def test_async_list_collections(async_client: vecs.AsyncClient):
"""Test async list collections"""
# Create multiple collections
collection1 = await async_client.get_or_create_collection(
"collection1", dimension=3
)
collection2 = await async_client.get_or_create_collection(
"collection2", dimension=4
)

# List collections
collections = await async_client.list_collections()
collection_names = [c.name for c in collections]

assert "collection1" in collection_names
assert "collection2" in collection_names
assert len(collections) >= 2


@pytest.mark.asyncio
async def test_async_delete_collection(async_client: vecs.AsyncClient):
"""Test async delete collection"""
# Create collection
collection = await async_client.get_or_create_collection("to_delete", dimension=3)

# Add some data
records = [("id1", [1.0, 2.0, 3.0], {"type": "test"})]
await collection.upsert(records)

# Delete collection
await async_client.delete_collection("to_delete")

# Try to get deleted collection - should raise error
with pytest.raises(vecs.exc.CollectionNotFound):
await async_client.get_collection("to_delete")


@pytest.mark.asyncio
async def test_async_collection_create_index(async_client: vecs.AsyncClient):
"""Test async collection index creation"""
# Create collection and add enough data for indexing
collection = await async_client.get_or_create_collection("test_index", dimension=3)

# Add data (need enough for index to be created)
records = [
(f"id{i}", [float(i), float(i + 1), float(i + 2)], {"i": i})
for i in range(1100)
]
await collection.upsert(records)

# Create index
await collection.create_index()

# Check if index was created
index_name = await collection.index()
assert index_name is not None

# Test querying with index
results = await collection.query([1.0, 2.0, 3.0], limit=5)
assert len(results) == 5
Loading