Skip to content
1 change: 1 addition & 0 deletions src/memos/configs/vec_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class QdrantVecDBConfig(BaseVecDBConfig):
host: str | None = Field(default=None, description="Host for Qdrant")
port: int | None = Field(default=None, description="Port for Qdrant")
path: str | None = Field(default=None, description="Path for Qdrant")
api_key: str | None = Field(default=None, description="Optional API key for Qdrant authentication")

@model_validator(mode="after")
def set_default_path(self):
Expand Down
12 changes: 9 additions & 3 deletions src/memos/vec_dbs/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,15 @@ def __init__(self, config: QdrantVecDBConfig):
"(e.g., via Docker: https://qdrant.tech/documentation/quickstart/)."
)

self.client = QdrantClient(
host=self.config.host, port=self.config.port, path=self.config.path
)
client_kwargs = {
"host": self.config.host,
"port": self.config.port,
"path": self.config.path,
}
if self.config.api_key:
client_kwargs["api_key"] = self.config.api_key

self.client = QdrantClient(**client_kwargs)
self.create_collection()

def create_collection(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/configs/test_vec_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_qdrant_vec_db_config():
required_fields=[
"collection_name",
],
optional_fields=["vector_dimension", "distance_metric", "host", "port", "path"],
optional_fields=["vector_dimension", "distance_metric", "host", "port", "path", "api_key"],
)

check_config_instantiation_valid(
Expand Down
28 changes: 28 additions & 0 deletions tests/vec_dbs/test_qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,31 @@ def test_get_all(vec_db):
results = vec_db.get_all()
assert len(results) == 1
assert isinstance(results[0], VecDBItem)


def test_client_receives_api_key():
from unittest.mock import patch
from memos import settings
from memos.configs.vec_db import VectorDBConfigFactory
from memos.vec_dbs.factory import VecDBFactory

api_key = "your_secure_api_key_here_change_in_production"
with patch("qdrant_client.QdrantClient") as mock_client:
cfg = VectorDBConfigFactory.model_validate(
{
"backend": "qdrant",
"config": {
"collection_name": "test_collection",
"vector_dimension": 4,
"distance_metric": "cosine",
"path": str(settings.MEMOS_DIR / "qdrant"),
"api_key": api_key,
},
}
)
_ = VecDBFactory.from_config(cfg)

# Assert that QdrantClient was called with api_key keyword argument
assert mock_client.called
kwargs = mock_client.call_args.kwargs
assert kwargs.get("api_key") == api_key