@@ -64,8 +64,14 @@ def __init__(  # noqa: PLR0915
             "user": db_config.get("user_name", "root"),
             "password": db_config.get("password", ""),
         }
-        # Add sslmode if specified, otherwise default to disable for local dev
+        # Add SSL configuration if specified
         conn_params["sslmode"] = db_config.get("sslmode", "disable")
+        if db_config.get("sslrootcert"):
+            conn_params["sslrootcert"] = db_config["sslrootcert"]
+        if db_config.get("sslcert"):
+            conn_params["sslcert"] = db_config["sslcert"]
+        if db_config.get("sslkey"):
+            conn_params["sslkey"] = db_config["sslkey"]
 
         self.connect_config = conn_params
         self.pool_size = db_config.get("pool_size", 100)
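For context on the new keys: `sslmode`, `sslrootcert`, `sslcert`, and `sslkey` are standard libpq connection keywords, so the resulting `conn_params` dict can be passed straight to `psycopg.connect(**conn_params)`. A minimal sketch with hypothetical certificate paths, assuming a cluster provisioned with `cockroach cert`:

```python
import psycopg

# Hypothetical TLS setup; the paths depend on how the cluster's certs were
# generated. sslmode="verify-full" makes libpq validate the server cert
# against the CA bundle (sslrootcert) and check that the hostname matches.
conn = psycopg.connect(
    host="localhost",
    port=26257,
    dbname="defaultdb",
    user="root",
    sslmode="verify-full",
    sslrootcert="/certs/ca.crt",       # cluster CA certificate
    sslcert="/certs/client.root.crt",  # client certificate
    sslkey="/certs/client.root.key",   # client private key
)
```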
@@ -117,8 +123,7 @@ def __init__(  # noqa: PLR0915
         cursor = conn.cursor()
 
         if drop_old:
-            if self.case_config is not None:
-                self._drop_index()  # Use SQLAlchemy
+            # DROP TABLE CASCADE will automatically drop indexes, no need to drop separately
             self._drop_table(cursor, conn)
             self._create_table(cursor, conn, dim)
             if self.case_config is not None and self.case_config.create_index_before_load:
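The replacement comment relies on a standard PostgreSQL/CockroachDB guarantee: indexes belong to their table, so `DROP TABLE ... CASCADE` removes them along with any dependent objects. A quick sanity check, assuming a scratch database and the `conn_params` built above:

```python
# Sketch: confirm DROP TABLE removes the table's indexes as well (scratch DB).
with psycopg.connect(**conn_params) as conn:
    conn.autocommit = True
    with conn.cursor() as cur:
        cur.execute("CREATE TABLE t (id INT PRIMARY KEY, v INT)")
        cur.execute("CREATE INDEX t_v_idx ON t (v)")
        cur.execute("DROP TABLE t CASCADE")
        # Both t and t_v_idx are gone; "SHOW INDEXES FROM t" would now error.
```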
@@ -139,6 +144,12 @@ def _create_connection(**kwargs) -> tuple[Connection, Cursor]:
     def _create_connection_pool(self) -> ConnectionPool:
         """Create connection pool with production settings."""
         # Build connection info without 'options' parameter (not supported by psycopg_pool)
+        # Include vector_search_beam_size in connection options for performance
+        beam_size = 32
+        if self.case_config is not None:
+            search_param = self.case_config.search_param()
+            beam_size = search_param.get("vector_search_beam_size", 32)
+
         conninfo = (
             f"host={self.connect_config['host']} "
             f"port={self.connect_config['port']} "
@@ -147,23 +158,23 @@ def _create_connection_pool(self) -> ConnectionPool:
             f"password={self.connect_config['password']} "
         )
 
-        # Add sslmode if present
+        # Add SSL configuration if present
         if "sslmode" in self.connect_config:
             conninfo += f" sslmode={self.connect_config['sslmode']}"
+        if "sslrootcert" in self.connect_config:
+            conninfo += f" sslrootcert={self.connect_config['sslrootcert']}"
+        if "sslcert" in self.connect_config:
+            conninfo += f" sslcert={self.connect_config['sslcert']}"
+        if "sslkey" in self.connect_config:
+            conninfo += f" sslkey={self.connect_config['sslkey']}"
 
-        # Add statement timeout for long-running vector index operations
-        conninfo += " options='-c statement_timeout=600s'"
+        # Add all settings in connection options to avoid per-connection overhead
+        conninfo += f" options='-c statement_timeout=600s -c vector_search_beam_size={beam_size}'"
 
-        # Configure each connection with vector support and search parameters
+        # Configure each connection with vector support (lightweight operation)
         def configure_connection(conn: Connection) -> None:
             register_vector(conn)
-            # Set vector_search_beam_size on every connection for index usage
-            if self.case_config is not None:
-                search_param = self.case_config.search_param()
-                beam_size = search_param.get("vector_search_beam_size", 32)
-                with conn.cursor() as cur:
-                    cur.execute(f"SET vector_search_beam_size = {beam_size}")
-                conn.commit()
+            # No need to set beam_size here - it's in connection options
 
         return ConnectionPool(
             conninfo=conninfo,
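The end result, with hypothetical values, is a single conninfo string like the one below. libpq's `options` keyword forwards `-c name=value` pairs as session defaults, so every connection the pool opens already has the statement timeout and beam size in effect, and the old per-connection `SET` round-trip disappears:

```python
beam_size = 32  # or whatever case_config.search_param() returned
conninfo = (
    "host=localhost port=26257 dbname=defaultdb user=root password= "
    "sslmode=disable"
    f" options='-c statement_timeout=600s -c vector_search_beam_size={beam_size}'"
)
# psycopg_pool.ConnectionPool(conninfo=conninfo, configure=configure_connection, ...)
# then hands out connections with both settings already applied.
```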
@@ -211,9 +222,50 @@ def init(self) -> Generator[None, None, None]:
             self.conn = None
             self.pool = None
 
+    def _cancel_running_schema_jobs(self):
+        """Cancel any running schema change jobs for this table.
+        CockroachDB-specific: Running CREATE INDEX jobs block DROP TABLE."""
+        import psycopg
+
+        try:
+            conn = psycopg.connect(**self.connect_config)
+            conn.autocommit = True
+            cursor = conn.cursor()
+
+            # Find running schema change jobs for our table
+            cursor.execute(
+                """
+                SELECT job_id
+                FROM [SHOW JOBS]
+                WHERE status IN ('running', 'pending')
+                AND job_type = 'NEW SCHEMA CHANGE'
+                AND description LIKE %s
+                """,
+                (f"%{self.table_name}%",),
+            )
+            jobs = cursor.fetchall()
+
+            for job in jobs:
+                job_id = job[0]
+                log.warning(f"{self.name} canceling schema job {job_id} before dropping table")
+                try:
+                    cursor.execute(f"CANCEL JOB {job_id}")
+                    log.info(f"Canceled job {job_id}")
+                except Exception as e:
+                    log.warning(f"Failed to cancel job {job_id}: {e}")
+
+            cursor.close()
+            conn.close()
+        except Exception as e:
+            log.warning(f"Failed to check/cancel running jobs: {e}")
+
     @db_retry(max_attempts=3, initial_delay=0.5, backoff_factor=2.0)
     def _drop_table(self, cursor: Cursor, conn: Connection):
-        """Drop table with retry logic."""
+        """Drop table with retry logic.
+        Note: CockroachDB-specific - must cancel running schema jobs first."""
+        # Cancel any running schema change jobs that would block DROP
+        self._cancel_running_schema_jobs()
+
         log.info(f"{self.name} dropping table: {self.table_name}")
         cursor.execute(
             sql.SQL("DROP TABLE IF EXISTS {table_name} CASCADE").format(
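CockroachDB lets `SHOW` output be queried like a table via the `[SHOW JOBS]` subquery syntax, which is what makes the targeted filter above possible. The same check can be run standalone; a sketch with a hypothetical table name and conninfo:

```python
import psycopg

# Hypothetical standalone check: list unfinished schema-change jobs for a table.
with psycopg.connect("host=localhost port=26257 dbname=defaultdb user=root") as conn:
    with conn.cursor() as cur:
        cur.execute(
            """
            SELECT job_id, status, description
            FROM [SHOW JOBS]
            WHERE status IN ('running', 'pending')
              AND job_type = 'NEW SCHEMA CHANGE'
              AND description LIKE %s
            """,
            ("%vec_table%",),
        )
        for job_id, status, description in cur.fetchall():
            print(job_id, status, description)
```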
@@ -223,7 +275,8 @@ def _drop_table(self, cursor: Cursor, conn: Connection):
         conn.commit()
 
     def _drop_index(self):
-        """Drop CockroachDB vector index if it exists (DDL with autocommit)."""
+        """Drop CockroachDB vector index if it exists (DDL with autocommit).
+        Note: This is typically not needed as DROP TABLE CASCADE handles it."""
        log.info(f"{self.name} dropping index: {self._index_name}")
         conn = psycopg.connect(**self.connect_config)
         conn.autocommit = True
@@ -397,25 +450,46 @@ def optimize(self, data_size: int | None = None):
         start_time = time.time()
         connection_closed = False
 
-        # Try to create index
+        # Try to create index - use a connection without statement_timeout
         try:
-            with self.pool.connection() as conn:
+            # Create connection without statement_timeout for long-running index creation
+            import psycopg
+
+            conninfo_no_timeout = (
+                f"host={self.connect_config['host']} "
+                f"port={self.connect_config['port']} "
+                f"dbname={self.connect_config['dbname']} "
+                f"user={self.connect_config['user']} "
+                f"password={self.connect_config['password']} "
+            )
+            if "sslmode" in self.connect_config:
+                conninfo_no_timeout += f" sslmode={self.connect_config['sslmode']}"
+            if "sslrootcert" in self.connect_config:
+                conninfo_no_timeout += f" sslrootcert={self.connect_config['sslrootcert']}"
+            if "sslcert" in self.connect_config:
+                conninfo_no_timeout += f" sslcert={self.connect_config['sslcert']}"
+            if "sslkey" in self.connect_config:
+                conninfo_no_timeout += f" sslkey={self.connect_config['sslkey']}"
+
+            with psycopg.connect(conninfo_no_timeout, autocommit=True) as conn:
                 register_vector(conn)
-                conn.autocommit = True
-                cursor = conn.cursor()
-                try:
+                with conn.cursor() as cursor:
                     cursor.execute(sql_str)
                     elapsed = time.time() - start_time
                     log.info(f"{self.name} index created successfully in {elapsed:.1f}s")
                     return  # Success!
-                finally:
-                    cursor.close()
         except Exception as e:
             elapsed = time.time() - start_time
-            # Check if this is the expected 30s timeout on multi-node clusters
-            if "server closed the connection" in str(e) or "connection" in str(e).lower():
-                log.warning(f"Connection closed after {elapsed:.1f}s during index creation: {e}")
-                log.info("This is expected on multi-node clusters - checking if index was created...")
+            error_msg = str(e)
+            # Check for timeout or connection issues
+            if (
+                "server closed the connection" in error_msg
+                or "statement timeout" in error_msg.lower()
+                or "query execution canceled" in error_msg.lower()
+                or "connection" in error_msg.lower()
+            ):
+                log.warning(f"Timeout/connection issue after {elapsed:.1f}s during index creation: {e}")
+                log.info("This is expected on large datasets - checking if index is being created in background...")
                 connection_closed = True
             else:
                 # Unexpected error, re-raise
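The `connection_closed` flag presumably feeds a verification step outside this hunk. One way such a check could work is to poll the job queue until the index build finishes; a sketch under that assumption, with a hypothetical helper name:

```python
import time

def wait_for_index_build(cursor, table_name: str, poll_secs: float = 10.0) -> bool:
    """Hypothetical helper: poll CockroachDB's job queue until the most
    recent schema-change job touching table_name finishes."""
    while True:
        cursor.execute(
            """
            SELECT status
            FROM [SHOW JOBS]
            WHERE job_type = 'NEW SCHEMA CHANGE'
              AND description LIKE %s
            ORDER BY created DESC
            LIMIT 1
            """,
            (f"%{table_name}%",),
        )
        row = cursor.fetchone()
        if row is None:
            return False  # no job found for this table
        if row[0] == "succeeded":
            return True
        if row[0] in ("failed", "canceled"):
            return False
        time.sleep(poll_secs)  # still running or pending
```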
@@ -495,8 +569,7 @@ def search_embedding(
         **kwargs: Any,
     ) -> list[int]:
         """Search for k nearest neighbors using vector index."""
-        assert self.conn is not None, "Connection is not initialized"
-        assert self.cursor is not None, "Cursor is not initialized"
+        assert self.pool is not None, "Connection pool is not initialized"
 
         # Use default L2 distance if no case_config provided
         if self.case_config is not None:
@@ -520,5 +593,16 @@ def search_embedding(
             metric_op=sql.SQL(metric_op),
         )
 
-        result = self.cursor.execute(full_sql, (q, k), prepare=True, binary=True)
-        return [int(i[0]) for i in result.fetchall()]
+        # Get a connection from the pool for this query (enables true concurrency)
+        # Pool returns already-configured connections with vector support
+        with self.pool.connection() as conn, conn.cursor() as cursor:
+            try:
+                result = cursor.execute(full_sql, (q, k), prepare=True, binary=True)
+                return [int(i[0]) for i in result.fetchall()]
+            except Exception as e:
+                # If transaction is aborted, rollback and retry
+                if "transaction is aborted" in str(e).lower():
+                    conn.rollback()
+                    result = cursor.execute(full_sql, (q, k), prepare=True, binary=True)
+                    return [int(i[0]) for i in result.fetchall()]
+                raise
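Because each call now checks a connection out of the pool only for the duration of one query, several threads can issue searches against the same client at once. A sketch, assuming `client` is an instance of this class and that `search_embedding` accepts the query vector and `k` as shown:

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np

# Hypothetical concurrency driver: 8 parallel searches over one shared client.
queries = [np.random.rand(128).tolist() for _ in range(8)]
with ThreadPoolExecutor(max_workers=8) as ex:
    results = list(ex.map(lambda q: client.search_embedding(q, k=10), queries))
```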