55from typing import Any
66
77import clickhouse_connect
8+ from clickhouse_connect .driver import Client
89
9- from ..api import DBCaseConfig , VectorDB
10+ from .. import IndexType
11+ from ..api import VectorDB
12+ from .config import ClickhouseConfigDict , ClickhouseIndexConfig
1013
1114log = logging .getLogger (__name__ )
1215
@@ -17,8 +20,8 @@ class Clickhouse(VectorDB):
1720 def __init__ (
1821 self ,
1922 dim : int ,
20- db_config : dict ,
21- db_case_config : DBCaseConfig ,
23+ db_config : ClickhouseConfigDict ,
24+ db_case_config : ClickhouseIndexConfig ,
2225 collection_name : str = "CHVectorCollection" ,
2326 drop_old : bool = False ,
2427 ** kwargs ,
@@ -28,84 +31,130 @@ def __init__(
2831 self .table_name = collection_name
2932 self .dim = dim
3033
34+ self .index_param = self .case_config .index_param ()
35+ self .search_param = self .case_config .search_param ()
36+ self .session_param = self .case_config .session_param ()
37+
3138 self ._index_name = "clickhouse_index"
3239 self ._primary_field = "id"
3340 self ._vector_field = "embedding"
3441
3542 # construct basic units
36- self .conn = clickhouse_connect .get_client (
37- host = self .db_config ["host" ],
38- port = self .db_config ["port" ],
39- username = self .db_config ["user" ],
40- password = self .db_config ["password" ],
41- database = self .db_config ["dbname" ],
42- )
43+ self .conn = self ._create_connection (** self .db_config , settings = self .session_param )
4344
4445 if drop_old :
4546 log .info (f"Clickhouse client drop table : { self .table_name } " )
4647 self ._drop_table ()
4748 self ._create_table (dim )
49+ if self .case_config .create_index_before_load :
50+ self ._create_index ()
4851
4952 self .conn .close ()
5053 self .conn = None
5154
5255 @contextmanager
53- def init (self ):
56+ def init (self ) -> None :
5457 """
5558 Examples:
5659 >>> with self.init():
5760 >>> self.insert_embeddings()
5861 >>> self.search_embedding()
5962 """
6063
61- self .conn = clickhouse_connect .get_client (
62- host = self .db_config ["host" ],
63- port = self .db_config ["port" ],
64- username = self .db_config ["user" ],
65- password = self .db_config ["password" ],
66- database = self .db_config ["dbname" ],
67- )
64+ self .conn = self ._create_connection (** self .db_config , settings = self .session_param )
6865
6966 try :
7067 yield
7168 finally :
7269 self .conn .close ()
7370 self .conn = None
7471
72+ def _create_connection (self , settings : dict | None , ** kwargs ) -> Client :
73+ return clickhouse_connect .get_client (** self .db_config , settings = settings )
74+
75+ def _drop_index (self ):
76+ assert self .conn is not None , "Connection is not initialized"
77+ try :
78+ self .conn .command (
79+ f'ALTER TABLE { self .db_config ["database" ]} .{ self .table_name } DROP INDEX { self ._index_name } '
80+ )
81+ except Exception as e :
82+ log .warning (f"Failed to drop index on table { self .db_config ['database' ]} .{ self .table_name } : { e } " )
83+ raise e from None
84+
7585 def _drop_table (self ):
7686 assert self .conn is not None , "Connection is not initialized"
7787
78- self .conn .command (f'DROP TABLE IF EXISTS { self .db_config ["dbname" ]} .{ self .table_name } ' )
88+ try :
89+ self .conn .command (f'DROP TABLE IF EXISTS { self .db_config ["database" ]} .{ self .table_name } ' )
90+ except Exception as e :
91+ log .warning (f"Failed to drop table { self .db_config ['database' ]} .{ self .table_name } : { e } " )
92+ raise e from None
93+
94+ def _perfomance_tuning (self ):
95+ self .conn .command ("SET materialize_skip_indexes_on_insert = 1" )
96+
97+ def _create_index (self ):
98+ assert self .conn is not None , "Connection is not initialized"
99+ try :
100+ if self .index_param ["index_type" ] == IndexType .HNSW .value :
101+ if (
102+ self .index_param ["quantization" ]
103+ and self .index_param ["params" ]["M" ]
104+ and self .index_param ["params" ]["efConstruction" ]
105+ ):
106+ query = f"""
107+ ALTER TABLE { self .db_config ["database" ]} .{ self .table_name }
108+ ADD INDEX { self ._index_name } { self ._vector_field }
109+ TYPE vector_similarity('hnsw', '{ self .index_param ["metric_type" ]} ',
110+ '{ self .index_param ["quantization" ]} ',
111+ { self .index_param ["params" ]["M" ]} , { self .index_param ["params" ]["efConstruction" ]} )
112+ GRANULARITY { self .index_param ["granularity" ]}
113+ """
114+ else :
115+ query = f"""
116+ ALTER TABLE { self .db_config ["database" ]} .{ self .table_name }
117+ ADD INDEX { self ._index_name } { self ._vector_field }
118+ TYPE vector_similarity('hnsw', '{ self .index_param ["metric_type" ]} ')
119+ GRANULARITY { self .index_param ["granularity" ]}
120+ """
121+ self .conn .command (cmd = query )
122+ else :
123+ log .warning ("HNSW is only avaliable method in clickhouse now" )
124+ except Exception as e :
125+ log .warning (f"Failed to create Clickhouse vector index on table: { self .table_name } error: { e } " )
126+ raise e from None
79127
80128 def _create_table (self , dim : int ):
81129 assert self .conn is not None , "Connection is not initialized"
82130
83131 try :
84132 # create table
85133 self .conn .command (
86- f'CREATE TABLE IF NOT EXISTS { self .db_config ["dbname" ]} .{ self .table_name } \
87- (id UInt32, embedding Array(Float64)) ENGINE = MergeTree() ORDER BY id;'
134+ f'CREATE TABLE IF NOT EXISTS { self .db_config ["database" ]} .{ self .table_name } '
135+ f"({ self ._primary_field } UInt32, "
136+ f'{ self ._vector_field } Array({ self .index_param ["vector_data_type" ]} ) CODEC(NONE), '
137+ f"CONSTRAINT same_length CHECK length(embedding) = { dim } ) "
138+ f"ENGINE = MergeTree() "
139+ f"ORDER BY { self ._primary_field } "
88140 )
89141
90142 except Exception as e :
91143 log .warning (f"Failed to create Clickhouse table: { self .table_name } error: { e } " )
92144 raise e from None
93145
94- def ready_to_load (self ):
95- pass
96-
97146 def optimize (self , data_size : int | None = None ):
98147 pass
99148
100- def ready_to_search (self ):
149+ def _post_insert (self ):
101150 pass
102151
103152 def insert_embeddings (
104153 self ,
105154 embeddings : list [list [float ]],
106155 metadata : list [int ],
107156 ** kwargs : Any ,
108- ) -> tuple [ int , Exception ] :
157+ ) -> ( int , Exception ) :
109158 assert self .conn is not None , "Connection is not initialized"
110159
111160 try :
@@ -116,7 +165,7 @@ def insert_embeddings(
116165 table = self .table_name ,
117166 data = items ,
118167 column_names = ["id" , "embedding" ],
119- column_type_names = ["UInt32" , " Array(Float64)" ],
168+ column_type_names = ["UInt32" , f' Array({ self . index_param [ "vector_data_type" ] } )' ],
120169 column_oriented = True ,
121170 )
122171 return len (metadata ), None
@@ -132,25 +181,52 @@ def search_embedding(
132181 timeout : int | None = None ,
133182 ) -> list [int ]:
134183 assert self .conn is not None , "Connection is not initialized"
135-
136- index_param = self .case_config .index_param () # noqa: F841
137- search_param = self .case_config .search_param ()
138-
139- if filters :
140- gt = filters .get ("id" )
141- filter_sql = (
142- f'SELECT id, { search_param ["metric_type" ]} (embedding,{ query } ) AS score ' # noqa: S608
143- f'FROM { self .db_config ["dbname" ]} .{ self .table_name } '
144- f"WHERE id > { gt } "
145- f"ORDER BY score LIMIT { k } ;"
146- )
147- result = self .conn .query (filter_sql ).result_rows
184+ parameters = {
185+ "primary_field" : self ._primary_field ,
186+ "vector_field" : self ._vector_field ,
187+ "schema" : self .db_config ["database" ],
188+ "table" : self .table_name ,
189+ "gt" : filters .get ("id" ),
190+ "k" : k ,
191+ "metric_type" : self .search_param ["metric_type" ],
192+ "query" : query ,
193+ }
194+ if self .case_config .metric_type == "COSINE" :
195+ if filters :
196+ result = self .conn .query (
197+ "SELECT {primary_field:Identifier}, {vector_field:Identifier} "
198+ "FROM {schema:Identifier}.{table:Identifier} "
199+ "WHERE {primary_field:Identifier} > {gt:UInt32} "
200+ "ORDER BY cosineDistance(embedding,{query:Array(Float64)}) "
201+ "LIMIT {k:UInt32}" ,
202+ parameters = parameters ,
203+ ).result_rows
204+ return [int (row [0 ]) for row in result ]
205+
206+ result = self .conn .query (
207+ "SELECT {primary_field:Identifier}, {vector_field:Identifier} "
208+ "FROM {schema:Identifier}.{table:Identifier} "
209+ "ORDER BY cosineDistance(embedding,{query:Array(Float64)}) "
210+ "LIMIT {k:UInt32}" ,
211+ parameters = parameters ,
212+ ).result_rows
148213 return [int (row [0 ]) for row in result ]
149- else : # noqa: RET505
150- select_sql = (
151- f'SELECT id, { search_param ["metric_type" ]} (embedding,{ query } ) AS score ' # noqa: S608
152- f'FROM { self .db_config ["dbname" ]} .{ self .table_name } '
153- f"ORDER BY score LIMIT { k } ;"
154- )
155- result = self .conn .query (select_sql ).result_rows
214+ if filters :
215+ result = self .conn .query (
216+ "SELECT {primary_field:Identifier}, {vector_field:Identifier} "
217+ "FROM {schema:Identifier}.{table:Identifier} "
218+ "WHERE {primary_field:Identifier} > {gt:UInt32} "
219+ "ORDER BY L2Distance(embedding,{query:Array(Float64)}) "
220+ "LIMIT {k:UInt32}" ,
221+ parameters = parameters ,
222+ ).result_rows
156223 return [int (row [0 ]) for row in result ]
224+
225+ result = self .conn .query (
226+ "SELECT {primary_field:Identifier}, {vector_field:Identifier} "
227+ "FROM {schema:Identifier}.{table:Identifier} "
228+ "ORDER BY L2Distance(embedding,{query:Array(Float64)}) "
229+ "LIMIT {k:UInt32}" ,
230+ parameters = parameters ,
231+ ).result_rows
232+ return [int (row [0 ]) for row in result ]
0 commit comments