Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/vecs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from sqlalchemy import Engine

from vecs import exc
from vecs.client import Client
from vecs.collection import (
Expand All @@ -23,6 +25,6 @@
]


def create_client(connection_string: str) -> Client:
def create_client(connection_string: str = None, engine: Engine = None) -> Client:
"""Creates a client from a Postgres connection string"""
return Client(connection_string)
return Client(connection_string=connection_string, engine=engine)
21 changes: 16 additions & 5 deletions src/vecs/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import TYPE_CHECKING, List, Optional

from deprecated import deprecated
from sqlalchemy import MetaData, create_engine, text
from sqlalchemy import Engine, MetaData, create_engine, text
from sqlalchemy.orm import sessionmaker

from vecs.adapter import Adapter
Expand Down Expand Up @@ -47,17 +47,28 @@ class Client:
vx.disconnect()
"""

def __init__(self, connection_string: str):
def __init__(self, connection_string: str = None, engine: Engine = None):
"""
Initialize a Client instance.

Args:
connection_string (str): A string representing the database connection information.
connection_string (str, optional): Database connection string. Required if engine is not provided.
engine (Engine, optional): Pre-created SQLAlchemy engine. If provided, connection_string is ignored.

Returns:
None
Raises:
ValueError: If neither connection_string nor engine is provided.
"""
self.engine = create_engine(connection_string)
if engine is not None:
self.engine = engine
elif connection_string is not None:
self.engine = create_engine(connection_string)
else:
raise ValueError(
"Either a connection_string or an engine must be provided."
)

self.meta = MetaData(schema="vecs")
self.Session = sessionmaker(self.engine)

Expand Down Expand Up @@ -153,7 +164,7 @@ def get_collection(self, name: str) -> Collection:
from vecs.collection import Collection

query = text(
f"""
"""
select
relname as table_name,
atttypmod as embedding_dim
Expand Down