diff --git a/datastore/providers/azurecosmosdb_datastore.py b/datastore/providers/azurecosmosdb_datastore.py
index f9d3507d8..6879a19ed 100644
--- a/datastore/providers/azurecosmosdb_datastore.py
+++ b/datastore/providers/azurecosmosdb_datastore.py
@@ -27,6 +27,8 @@
 AZCOSMOS_CONNSTR = os.environ.get("AZCOSMOS_CONNSTR")
 AZCOSMOS_DATABASE_NAME = os.environ.get("AZCOSMOS_DATABASE_NAME")
 AZCOSMOS_CONTAINER_NAME = os.environ.get("AZCOSMOS_CONTAINER_NAME")
+AZCOSMOS_SIMILARITY = os.environ.get("AZCOSMOS_SIMILARITY", "COS")
+AZCOSMOS_NUM_LISTS = int(os.environ.get("AZCOSMOS_NUM_LISTS", 100))  # env values are strings; num_lists is declared as int
 assert AZCOSMOS_API is not None
 assert AZCOSMOS_CONNSTR is not None
 assert AZCOSMOS_DATABASE_NAME is not None
@@ -201,7 +203,7 @@ def __init__(self, cosmosStore: AzureCosmosDBStoreApi):
     """
 
     @staticmethod
-    async def create(num_lists, similarity) -> DataStore:
+    async def create(num_lists: int = AZCOSMOS_NUM_LISTS, similarity: str = AZCOSMOS_SIMILARITY) -> DataStore:
 
         # Create underlying data store based on the API definition.
         # Right now this only supports Mongo, but set up to support more.
@@ -211,6 +213,11 @@ async def create(num_lists, similarity) -> DataStore:
             apiStore = MongoStoreApi(mongoClient)
         else:
             raise NotImplementedError
+        if similarity not in ["COS", "L2", "IP"]:
+            raise ValueError(
+                f"Similarity {similarity} is not supported. "
+                "Supported similarity metrics are COS, L2, and IP."
+            )
 
         await apiStore.ensure(num_lists, similarity)
         store = AzureCosmosDBDataStore(apiStore)
diff --git a/docs/providers/azurecosmosdb/setup.md b/docs/providers/azurecosmosdb/setup.md
index 33d8caa22..f0df196db 100644
--- a/docs/providers/azurecosmosdb/setup.md
+++ b/docs/providers/azurecosmosdb/setup.md
@@ -15,6 +15,8 @@ Learn more about Azure Cosmos DB for MongoDB vCore [here](https://learn.microsof
 | `AZCOSMOS_CONNSTR` | Yes | The connection string to your account. | |
 | `AZCOSMOS_DATABASE_NAME` | Yes | The database where the data is stored/queried | |
 | `AZCOSMOS_CONTAINER_NAME` | Yes | The container where the data is stored/queried | |
+| `AZCOSMOS_SIMILARITY` | No | The similarity metric used by the vector database (allowed values are `COS`, `IP`, `L2`). Default value is `COS`. | |
+| `AZCOSMOS_NUM_LISTS` | No | Integer number of clusters that the inverted file (IVF) index uses to group the vector data. Default value is `100`. See [vector-search](https://learn.microsoft.com/en-us/azure/cosmos-db/mongodb/vcore/vector-search) for more information. | |
 
 ## Indexing
 On first insert, the datastore will create the collection and index if necessary on the field `embedding`. Currently hybrid search is not yet supported.
diff --git a/tests/datastore/providers/azurecosmosdb/test_azurecosmosdb_datastore.py b/tests/datastore/providers/azurecosmosdb/test_azurecosmosdb_datastore.py
index 7b238e4d5..68b697200 100644
--- a/tests/datastore/providers/azurecosmosdb/test_azurecosmosdb_datastore.py
+++ b/tests/datastore/providers/azurecosmosdb/test_azurecosmosdb_datastore.py
@@ -79,6 +79,10 @@ def queries() -> List[QueryWithEmbedding]:
 async def azurecosmosdb_datastore() -> DataStore:
     return await AzureCosmosDBDataStore.create(num_lists=num_lists, similarity=similarity)
 
+@pytest.mark.asyncio
+async def test_invalid_similarity() -> None:
+    with pytest.raises(ValueError):
+        await AzureCosmosDBDataStore.create(num_lists=num_lists, similarity="INVALID")
 
 @pytest.mark.asyncio
 async def test_upsert(