Skip to content

Commit 7f83936

Browse files
authored
Add HNSW support for Clickhouse client (#500)
* feat: add hnsw support * refactor: minor fixes * feat: reformat code * fix: remove sql injections, reformat code
1 parent e42845f commit 7f83936

File tree

2 files changed

+162
-57
lines changed

2 files changed

+162
-57
lines changed

vectordb_bench/backend/clients/clickhouse/clickhouse.py

Lines changed: 123 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@
55
from typing import Any
66

77
import clickhouse_connect
8+
from clickhouse_connect.driver import Client
89

9-
from ..api import DBCaseConfig, VectorDB
10+
from .. import IndexType
11+
from ..api import VectorDB
12+
from .config import ClickhouseConfigDict, ClickhouseIndexConfig
1013

1114
log = logging.getLogger(__name__)
1215

@@ -17,8 +20,8 @@ class Clickhouse(VectorDB):
1720
def __init__(
1821
self,
1922
dim: int,
20-
db_config: dict,
21-
db_case_config: DBCaseConfig,
23+
db_config: ClickhouseConfigDict,
24+
db_case_config: ClickhouseIndexConfig,
2225
collection_name: str = "CHVectorCollection",
2326
drop_old: bool = False,
2427
**kwargs,
@@ -28,84 +31,130 @@ def __init__(
2831
self.table_name = collection_name
2932
self.dim = dim
3033

34+
self.index_param = self.case_config.index_param()
35+
self.search_param = self.case_config.search_param()
36+
self.session_param = self.case_config.session_param()
37+
3138
self._index_name = "clickhouse_index"
3239
self._primary_field = "id"
3340
self._vector_field = "embedding"
3441

3542
# construct basic units
36-
self.conn = clickhouse_connect.get_client(
37-
host=self.db_config["host"],
38-
port=self.db_config["port"],
39-
username=self.db_config["user"],
40-
password=self.db_config["password"],
41-
database=self.db_config["dbname"],
42-
)
43+
self.conn = self._create_connection(**self.db_config, settings=self.session_param)
4344

4445
if drop_old:
4546
log.info(f"Clickhouse client drop table : {self.table_name}")
4647
self._drop_table()
4748
self._create_table(dim)
49+
if self.case_config.create_index_before_load:
50+
self._create_index()
4851

4952
self.conn.close()
5053
self.conn = None
5154

5255
@contextmanager
53-
def init(self):
56+
def init(self) -> None:
5457
"""
5558
Examples:
5659
>>> with self.init():
5760
>>> self.insert_embeddings()
5861
>>> self.search_embedding()
5962
"""
6063

61-
self.conn = clickhouse_connect.get_client(
62-
host=self.db_config["host"],
63-
port=self.db_config["port"],
64-
username=self.db_config["user"],
65-
password=self.db_config["password"],
66-
database=self.db_config["dbname"],
67-
)
64+
self.conn = self._create_connection(**self.db_config, settings=self.session_param)
6865

6966
try:
7067
yield
7168
finally:
7269
self.conn.close()
7370
self.conn = None
7471

72+
def _create_connection(self, settings: dict | None, **kwargs) -> Client:
73+
return clickhouse_connect.get_client(**self.db_config, settings=settings)
74+
75+
def _drop_index(self):
76+
assert self.conn is not None, "Connection is not initialized"
77+
try:
78+
self.conn.command(
79+
f'ALTER TABLE {self.db_config["database"]}.{self.table_name} DROP INDEX {self._index_name}'
80+
)
81+
except Exception as e:
82+
log.warning(f"Failed to drop index on table {self.db_config['database']}.{self.table_name}: {e}")
83+
raise e from None
84+
7585
def _drop_table(self):
7686
assert self.conn is not None, "Connection is not initialized"
7787

78-
self.conn.command(f'DROP TABLE IF EXISTS {self.db_config["dbname"]}.{self.table_name}')
88+
try:
89+
self.conn.command(f'DROP TABLE IF EXISTS {self.db_config["database"]}.{self.table_name}')
90+
except Exception as e:
91+
log.warning(f"Failed to drop table {self.db_config['database']}.{self.table_name}: {e}")
92+
raise e from None
93+
94+
def _perfomance_tuning(self):
95+
self.conn.command("SET materialize_skip_indexes_on_insert = 1")
96+
97+
def _create_index(self):
98+
assert self.conn is not None, "Connection is not initialized"
99+
try:
100+
if self.index_param["index_type"] == IndexType.HNSW.value:
101+
if (
102+
self.index_param["quantization"]
103+
and self.index_param["params"]["M"]
104+
and self.index_param["params"]["efConstruction"]
105+
):
106+
query = f"""
107+
ALTER TABLE {self.db_config["database"]}.{self.table_name}
108+
ADD INDEX {self._index_name} {self._vector_field}
109+
TYPE vector_similarity('hnsw', '{self.index_param["metric_type"]}',
110+
'{self.index_param["quantization"]}',
111+
{self.index_param["params"]["M"]}, {self.index_param["params"]["efConstruction"]})
112+
GRANULARITY {self.index_param["granularity"]}
113+
"""
114+
else:
115+
query = f"""
116+
ALTER TABLE {self.db_config["database"]}.{self.table_name}
117+
ADD INDEX {self._index_name} {self._vector_field}
118+
TYPE vector_similarity('hnsw', '{self.index_param["metric_type"]}')
119+
GRANULARITY {self.index_param["granularity"]}
120+
"""
121+
self.conn.command(cmd=query)
122+
else:
123+
log.warning("HNSW is only avaliable method in clickhouse now")
124+
except Exception as e:
125+
log.warning(f"Failed to create Clickhouse vector index on table: {self.table_name} error: {e}")
126+
raise e from None
79127

80128
def _create_table(self, dim: int):
81129
assert self.conn is not None, "Connection is not initialized"
82130

83131
try:
84132
# create table
85133
self.conn.command(
86-
f'CREATE TABLE IF NOT EXISTS {self.db_config["dbname"]}.{self.table_name} \
87-
(id UInt32, embedding Array(Float64)) ENGINE = MergeTree() ORDER BY id;'
134+
f'CREATE TABLE IF NOT EXISTS {self.db_config["database"]}.{self.table_name} '
135+
f"({self._primary_field} UInt32, "
136+
f'{self._vector_field} Array({self.index_param["vector_data_type"]}) CODEC(NONE), '
137+
f"CONSTRAINT same_length CHECK length(embedding) = {dim}) "
138+
f"ENGINE = MergeTree() "
139+
f"ORDER BY {self._primary_field}"
88140
)
89141

90142
except Exception as e:
91143
log.warning(f"Failed to create Clickhouse table: {self.table_name} error: {e}")
92144
raise e from None
93145

94-
def ready_to_load(self):
95-
pass
96-
97146
def optimize(self, data_size: int | None = None):
98147
pass
99148

100-
def ready_to_search(self):
149+
def _post_insert(self):
101150
pass
102151

103152
def insert_embeddings(
104153
self,
105154
embeddings: list[list[float]],
106155
metadata: list[int],
107156
**kwargs: Any,
108-
) -> tuple[int, Exception]:
157+
) -> (int, Exception):
109158
assert self.conn is not None, "Connection is not initialized"
110159

111160
try:
@@ -116,7 +165,7 @@ def insert_embeddings(
116165
table=self.table_name,
117166
data=items,
118167
column_names=["id", "embedding"],
119-
column_type_names=["UInt32", "Array(Float64)"],
168+
column_type_names=["UInt32", f'Array({self.index_param["vector_data_type"]})'],
120169
column_oriented=True,
121170
)
122171
return len(metadata), None
@@ -132,25 +181,52 @@ def search_embedding(
132181
timeout: int | None = None,
133182
) -> list[int]:
134183
assert self.conn is not None, "Connection is not initialized"
135-
136-
index_param = self.case_config.index_param() # noqa: F841
137-
search_param = self.case_config.search_param()
138-
139-
if filters:
140-
gt = filters.get("id")
141-
filter_sql = (
142-
f'SELECT id, {search_param["metric_type"]}(embedding,{query}) AS score ' # noqa: S608
143-
f'FROM {self.db_config["dbname"]}.{self.table_name} '
144-
f"WHERE id > {gt} "
145-
f"ORDER BY score LIMIT {k};"
146-
)
147-
result = self.conn.query(filter_sql).result_rows
184+
parameters = {
185+
"primary_field": self._primary_field,
186+
"vector_field": self._vector_field,
187+
"schema": self.db_config["database"],
188+
"table": self.table_name,
189+
"gt": filters.get("id"),
190+
"k": k,
191+
"metric_type": self.search_param["metric_type"],
192+
"query": query,
193+
}
194+
if self.case_config.metric_type == "COSINE":
195+
if filters:
196+
result = self.conn.query(
197+
"SELECT {primary_field:Identifier}, {vector_field:Identifier} "
198+
"FROM {schema:Identifier}.{table:Identifier} "
199+
"WHERE {primary_field:Identifier} > {gt:UInt32} "
200+
"ORDER BY cosineDistance(embedding,{query:Array(Float64)}) "
201+
"LIMIT {k:UInt32}",
202+
parameters=parameters,
203+
).result_rows
204+
return [int(row[0]) for row in result]
205+
206+
result = self.conn.query(
207+
"SELECT {primary_field:Identifier}, {vector_field:Identifier} "
208+
"FROM {schema:Identifier}.{table:Identifier} "
209+
"ORDER BY cosineDistance(embedding,{query:Array(Float64)}) "
210+
"LIMIT {k:UInt32}",
211+
parameters=parameters,
212+
).result_rows
148213
return [int(row[0]) for row in result]
149-
else: # noqa: RET505
150-
select_sql = (
151-
f'SELECT id, {search_param["metric_type"]}(embedding,{query}) AS score ' # noqa: S608
152-
f'FROM {self.db_config["dbname"]}.{self.table_name} '
153-
f"ORDER BY score LIMIT {k};"
154-
)
155-
result = self.conn.query(select_sql).result_rows
214+
if filters:
215+
result = self.conn.query(
216+
"SELECT {primary_field:Identifier}, {vector_field:Identifier} "
217+
"FROM {schema:Identifier}.{table:Identifier} "
218+
"WHERE {primary_field:Identifier} > {gt:UInt32} "
219+
"ORDER BY L2Distance(embedding,{query:Array(Float64)}) "
220+
"LIMIT {k:UInt32}",
221+
parameters=parameters,
222+
).result_rows
156223
return [int(row[0]) for row in result]
224+
225+
result = self.conn.query(
226+
"SELECT {primary_field:Identifier}, {vector_field:Identifier} "
227+
"FROM {schema:Identifier}.{table:Identifier} "
228+
"ORDER BY L2Distance(embedding,{query:Array(Float64)}) "
229+
"LIMIT {k:UInt32}",
230+
parameters=parameters,
231+
).result_rows
232+
return [int(row[0]) for row in result]
Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,46 @@
1+
from abc import abstractmethod
2+
from typing import TypedDict
3+
14
from pydantic import BaseModel, SecretStr
25

36
from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
47

58

9+
class ClickhouseConfigDict(TypedDict):
10+
user: str
11+
password: str
12+
host: str
13+
port: int
14+
database: str
15+
secure: bool
16+
17+
618
class ClickhouseConfig(DBConfig):
719
user_name: str = "clickhouse"
820
password: SecretStr
921
host: str = "localhost"
1022
port: int = 8123
1123
db_name: str = "default"
24+
secure: bool = False
1225

13-
def to_dict(self) -> dict:
26+
def to_dict(self) -> ClickhouseConfigDict:
1427
pwd_str = self.password.get_secret_value()
1528
return {
1629
"host": self.host,
1730
"port": self.port,
18-
"dbname": self.db_name,
31+
"database": self.db_name,
1932
"user": self.user_name,
2033
"password": pwd_str,
34+
"secure": self.secure,
2135
}
2236

2337

24-
class ClickhouseIndexConfig(BaseModel):
38+
class ClickhouseIndexConfig(BaseModel, DBCaseConfig):
2539

2640
metric_type: MetricType | None = None
41+
vector_data_type: str | None = "Float32" # Data type of vectors. Can be Float32 or Float64 or BFloat16
42+
create_index_before_load: bool = True
43+
create_index_after_load: bool = False
2744

2845
def parse_metric(self) -> str:
2946
if not self.metric_type:
@@ -35,26 +52,38 @@ def parse_metric_str(self) -> str:
3552
return "L2Distance"
3653
if self.metric_type == MetricType.COSINE:
3754
return "cosineDistance"
38-
msg = f"Not Support for {self.metric_type}"
39-
raise RuntimeError(msg)
40-
return None
55+
return "cosineDistance"
56+
57+
@abstractmethod
58+
def session_param(self):
59+
pass
4160

4261

43-
class ClickhouseHNSWConfig(ClickhouseIndexConfig, DBCaseConfig):
44-
M: int | None
45-
efConstruction: int | None
62+
class ClickhouseHNSWConfig(ClickhouseIndexConfig):
63+
M: int | None # Default in clickhouse in 32
64+
efConstruction: int | None # Default in clickhouse in 128
4665
ef: int | None = None
4766
index: IndexType = IndexType.HNSW
67+
quantization: str | None = "bf16" # Default is bf16. Possible values are f64, f32, f16, bf16, or i8
68+
granularity: int | None = 10_000_000 # Size of the index granules. By default, in CH it's equal 10.000.000
4869

4970
def index_param(self) -> dict:
5071
return {
72+
"vector_data_type": self.vector_data_type,
5173
"metric_type": self.parse_metric_str(),
5274
"index_type": self.index.value,
75+
"quantization": self.quantization,
76+
"granularity": self.granularity,
5377
"params": {"M": self.M, "efConstruction": self.efConstruction},
5478
}
5579

5680
def search_param(self) -> dict:
5781
return {
58-
"met˝ric_type": self.parse_metric_str(),
82+
"metric_type": self.parse_metric_str(),
5983
"params": {"ef": self.ef},
6084
}
85+
86+
def session_param(self) -> dict:
87+
return {
88+
"allow_experimental_vector_similarity_index": 1,
89+
}

0 commit comments

Comments
 (0)