diff --git a/demos/autocomplete-01-basics.gif b/demos/autocomplete-01-basics.gif new file mode 100644 index 00000000..c42e4564 Binary files /dev/null and b/demos/autocomplete-01-basics.gif differ diff --git a/demos/autocomplete-02-select-from.gif b/demos/autocomplete-02-select-from.gif new file mode 100644 index 00000000..78c0d90a Binary files /dev/null and b/demos/autocomplete-02-select-from.gif differ diff --git a/demos/autocomplete-03-joins.gif b/demos/autocomplete-03-joins.gif new file mode 100644 index 00000000..fb15aeb2 Binary files /dev/null and b/demos/autocomplete-03-joins.gif differ diff --git a/demos/autocomplete-04-clauses.gif b/demos/autocomplete-04-clauses.gif new file mode 100644 index 00000000..06f604c0 Binary files /dev/null and b/demos/autocomplete-04-clauses.gif differ diff --git a/pyproject.toml b/pyproject.toml index 2fcd844a..7247a59a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "pyperclip>=1.8.2", "keyring>=24.0.0", "docker>=7.0.0", + "sqlparse>=0.5.0", ] dynamic = ["version"] diff --git a/sqlit/app.py b/sqlit/app.py index 5d7de511..41c8205f 100644 --- a/sqlit/app.py +++ b/sqlit/app.py @@ -17,7 +17,9 @@ from textual.lazy import Lazy from textual.screen import ModalScreen from textual.timer import Timer -from textual.widgets import Static, TextArea, Tree +from textual.widgets import Static, Tree + +from .widgets import QueryTextArea from textual.worker import Worker from .config import ( @@ -51,6 +53,7 @@ TreeFilterInput, VimMode, ) +from .idle_scheduler import init_idle_scheduler, on_user_activity, IdleScheduler class SSMSTUI( @@ -208,6 +211,17 @@ class SSMSTUI( padding: 0 1; } + #idle-scheduler-bar { + height: 1; + background: $primary-darken-3; + padding: 0 1; + display: none; + } + + #idle-scheduler-bar.visible { + display: block; + } + #sidebar, #query-area, #results-area { @@ -260,7 +274,6 @@ class SSMSTUI( Binding("z", "collapse_tree", "Collapse", show=False), Binding("j", "tree_cursor_down", "Down", show=False), Binding("k", "tree_cursor_up", "Up", show=False), - Binding("u", "use_database", "Use as default", show=False), Binding("v", "view_cell", "View cell", show=False), Binding("u", "edit_cell", "Update cell", show=False), Binding("h", "results_cursor_left", "Left", show=False), @@ -297,6 +310,7 @@ def __init__( self._startup_connection = startup_connection self._startup_connect_config: ConnectionConfig | None = None self._debug_mode = os.environ.get("SQLIT_DEBUG") == "1" + self._debug_idle_scheduler = os.environ.get("SQLIT_DEBUG_IDLE_SCHEDULER") == "1" self._startup_profile = os.environ.get("SQLIT_PROFILE_STARTUP") == "1" self._startup_mark = self._parse_startup_mark(os.environ.get("SQLIT_STARTUP_MARK")) self._startup_init_time = time.perf_counter() @@ -355,7 +369,9 @@ def __init__( self._state_machine = UIStateMachine() self._session_factory: Any | None = None self._last_query_table: dict | None = None - # Omarchy theme sync state + self._query_target_database: str | None = None # Target DB for auto-generated queries + # Idle scheduler for background work + self._idle_scheduler: IdleScheduler | None = None if mock_profile: self._session_factory = self._create_mock_session_factory(mock_profile) @@ -387,8 +403,8 @@ def object_tree(self) -> Tree: return self.query_one("#object-tree", Tree) @property - def query_input(self) -> TextArea: - return self.query_one("#query-input", TextArea) + def query_input(self) -> QueryTextArea: + return self.query_one("#query-input", QueryTextArea) @property def results_table(self) -> SqlitDataTable: @@ -416,6 +432,10 @@ def results_area(self) -> Any: def status_bar(self) -> Static: return self.query_one("#status-bar", Static) + @property + def idle_scheduler_bar(self) -> Static: + return self.query_one("#idle-scheduler-bar", Static) + @property def autocomplete_dropdown(self) -> Any: from .widgets import AutocompleteDropdown @@ -509,7 +529,7 @@ def compose(self) -> ComposeResult: with Vertical(id="main-panel"): with Container(id="query-area"): - yield TextArea( + yield QueryTextArea( "", language="sql", id="query-input", @@ -521,6 +541,7 @@ def compose(self) -> ComposeResult: yield ResultsFilterInput(id="results-filter") yield Lazy(SqlitDataTable(id="results-table", zebra_stripes=True, show_header=False)) + yield Static("", id="idle-scheduler-bar") yield Static("Not connected", id="status-bar") yield ContextFooter() @@ -531,6 +552,15 @@ def on_mount(self) -> None: self._startup_stamp("on_mount_start") self._restart_argv = self._compute_restart_argv() + # Initialize and start idle scheduler + self._idle_scheduler = init_idle_scheduler(self) + self._idle_scheduler.start() + + # Show idle scheduler debug bar if enabled + if self._debug_idle_scheduler: + self.idle_scheduler_bar.add_class("visible") + self._idle_scheduler_bar_timer = self.set_interval(0.1, self._update_idle_scheduler_bar) + self._theme_manager.register_builtin_themes() self._theme_manager.register_textarea_themes() diff --git a/sqlit/cli.py b/sqlit/cli.py index 4564c585..2324798c 100644 --- a/sqlit/cli.py +++ b/sqlit/cli.py @@ -161,6 +161,11 @@ def main() -> int: action="store_true", help="Show startup timing in the status bar.", ) + parser.add_argument( + "--debug-idle-scheduler", + action="store_true", + help="Show idle scheduler status in the status bar.", + ) subparsers = parser.add_subparsers(dest="command", help="Available commands") @@ -284,6 +289,10 @@ def main() -> int: os.environ["SQLIT_DEBUG"] = "1" else: os.environ.pop("SQLIT_DEBUG", None) + if args.debug_idle_scheduler: + os.environ["SQLIT_DEBUG_IDLE_SCHEDULER"] = "1" + else: + os.environ.pop("SQLIT_DEBUG_IDLE_SCHEDULER", None) if args.profile_startup or args.debug: os.environ["SQLIT_STARTUP_MARK"] = str(startup_mark) else: diff --git a/sqlit/db/__init__.py b/sqlit/db/__init__.py index 4a4eba06..29cee140 100644 --- a/sqlit/db/__init__.py +++ b/sqlit/db/__init__.py @@ -88,6 +88,7 @@ def get_all_schemas() -> Any: "SchemaField": ("sqlit.db.schema", "SchemaField"), "SelectOption": ("sqlit.db.schema", "SelectOption"), # Adapters (through sqlit.db.adapters, which itself lazy-loads) + "AzureSQLAdapter": ("sqlit.db.adapters", "AzureSQLAdapter"), "CockroachDBAdapter": ("sqlit.db.adapters", "CockroachDBAdapter"), "DuckDBAdapter": ("sqlit.db.adapters", "DuckDBAdapter"), "FirebirdAdapter": ("sqlit.db.adapters", "FirebirdAdapter"), diff --git a/sqlit/db/adapters/base.py b/sqlit/db/adapters/base.py index a4604f9c..381403b0 100644 --- a/sqlit/db/adapters/base.py +++ b/sqlit/db/adapters/base.py @@ -182,12 +182,37 @@ def supports_multiple_databases(self) -> bool: """Whether this database type supports multiple databases.""" pass + @property + def supports_cross_database_queries(self) -> bool: + """Whether this database supports cross-database queries. + + When True, queries can reference tables in other databases using + fully qualified names (e.g., [db].[schema].[table] in SQL Server). + + When False, each database is isolated and a specific database must + be selected before querying. Connection creation will require a + database to be specified. + + Defaults to True. Override in subclasses for databases like PostgreSQL + where each database is isolated. + """ + return True + @property @abstractmethod def supports_stored_procedures(self) -> bool: """Whether this database type supports stored procedures.""" pass + @property + def system_databases(self) -> frozenset[str]: + """Set of system database names to exclude from user listings. + + Override in subclasses for database-specific system databases. + Returns lowercase names for case-insensitive comparison. + """ + return frozenset() + @property def default_schema(self) -> str: """The default schema for this database type. @@ -250,6 +275,10 @@ def validate_config(self, config: ConnectionConfig) -> None: """Validate provider-specific config values.""" return None + def detect_capabilities(self, conn: Any, config: ConnectionConfig) -> None: + """Detect runtime capabilities after establishing a connection.""" + return None + def get_auth_type(self, config: ConnectionConfig) -> Any | None: """Return the provider-specific auth type, if applicable.""" return None @@ -567,6 +596,10 @@ class MySQLBaseAdapter(CursorBasedAdapter): def supports_multiple_databases(self) -> bool: return True + @property + def system_databases(self) -> frozenset[str]: + return frozenset({"mysql", "information_schema", "performance_schema", "sys"}) + @property def supports_stored_procedures(self) -> bool: return True @@ -821,6 +854,15 @@ def supports_multiple_databases(self) -> bool: def supports_stored_procedures(self) -> bool: return True + @property + def system_databases(self) -> frozenset[str]: + return frozenset({"template0", "template1"}) + + @property + def supports_cross_database_queries(self) -> bool: + """PostgreSQL databases are isolated; cross-database queries not supported.""" + return False + @property def default_schema(self) -> str: return "public" diff --git a/sqlit/db/adapters/clickhouse.py b/sqlit/db/adapters/clickhouse.py index 8f845681..edf17fda 100644 --- a/sqlit/db/adapters/clickhouse.py +++ b/sqlit/db/adapters/clickhouse.py @@ -61,6 +61,10 @@ def driver_import_names(self) -> tuple[str, ...]: def supports_multiple_databases(self) -> bool: return True + @property + def system_databases(self) -> frozenset[str]: + return frozenset({"system", "information_schema", "INFORMATION_SCHEMA"}) + @property def supports_stored_procedures(self) -> bool: # ClickHouse doesn't have traditional stored procedures diff --git a/sqlit/db/adapters/d1.py b/sqlit/db/adapters/d1.py index b1e45aef..c01da365 100644 --- a/sqlit/db/adapters/d1.py +++ b/sqlit/db/adapters/d1.py @@ -60,6 +60,11 @@ def supports_stored_procedures(self) -> bool: """D1 is SQLite-based and does not support stored procedures.""" return False + @property + def supports_cross_database_queries(self) -> bool: + """D1 databases are isolated; cross-database queries not supported.""" + return False + def connect(self, config: ConnectionConfig) -> D1Connection: """Establishes a 'connection' to D1 by preparing authenticated session.""" requests = import_driver_module( diff --git a/sqlit/db/adapters/mssql.py b/sqlit/db/adapters/mssql.py index 5ff45c04..42ed9f98 100644 --- a/sqlit/db/adapters/mssql.py +++ b/sqlit/db/adapters/mssql.py @@ -13,6 +13,9 @@ class SQLServerAdapter(DatabaseAdapter): """Adapter for Microsoft SQL Server using the mssql-python driver.""" + def __init__(self) -> None: + self._supports_cross_database_queries_override: bool | None = None + @classmethod def badge_label(cls) -> str: return "MSSQL" @@ -43,10 +46,20 @@ def driver_import_names(self) -> tuple[str, ...]: def supports_multiple_databases(self) -> bool: return True + @property + def supports_cross_database_queries(self) -> bool: + if self._supports_cross_database_queries_override is not None: + return self._supports_cross_database_queries_override + return True + @property def supports_stored_procedures(self) -> bool: return True + @property + def system_databases(self) -> frozenset[str]: + return frozenset({"master", "tempdb", "model", "msdb"}) + def build_connection_string(self, config: ConnectionConfig) -> str: return self._build_connection_string(config) @@ -81,7 +94,7 @@ def driver_setup_kind(self) -> str | None: @classmethod def docker_image_patterns(cls) -> tuple[str, ...]: - return ("mcr.microsoft.com/mssql", "mcr.microsoft.com/azure-sql-edge") + return ("mcr.microsoft.com/mssql",) @classmethod def docker_env_vars(cls) -> dict[str, tuple[str, ...]]: @@ -113,6 +126,19 @@ def normalize_config(self, config: ConnectionConfig) -> ConnectionConfig: return config + def detect_capabilities(self, conn: Any, config: ConnectionConfig) -> None: + """Detect Azure SQL variants that don't support cross-database queries.""" + try: + cursor = conn.cursor() + cursor.execute("SELECT CAST(SERVERPROPERTY('EngineEdition') AS int)") + row = cursor.fetchone() + if row: + engine_edition = int(row[0]) + if engine_edition in (5, 6): + self._supports_cross_database_queries_override = False + except Exception: + pass + def _build_connection_string(self, config: ConnectionConfig) -> str: """Build mssql-python connection string from config. @@ -159,9 +185,7 @@ def connect(self, config: ConnectionConfig) -> Any: ) conn_str = self._build_connection_string(config) - conn = mssql_python.connect(conn_str) - - return conn + return mssql_python.connect(conn_str) def get_databases(self, conn: Any) -> list[str]: """Get list of databases from SQL Server.""" @@ -169,174 +193,112 @@ def get_databases(self, conn: Any) -> list[str]: cursor.execute("SELECT name FROM sys.databases ORDER BY name") return [row[0] for row in cursor.fetchall()] - def get_tables(self, conn: Any, database: str | None = None) -> list[TableInfo]: - """Get list of tables with schema from SQL Server.""" + def _get_cursor_for_database(self, conn: Any, database: str | None) -> Any: + """Get a cursor for the specified database using USE statement.""" cursor = conn.cursor() if database: - cursor.execute( - f"SELECT TABLE_SCHEMA, TABLE_NAME FROM [{database}].INFORMATION_SCHEMA.TABLES " - f"WHERE TABLE_TYPE = 'BASE TABLE' ORDER BY TABLE_SCHEMA, TABLE_NAME" - ) - else: - cursor.execute( - "SELECT TABLE_SCHEMA, TABLE_NAME FROM INFORMATION_SCHEMA.TABLES " - "WHERE TABLE_TYPE = 'BASE TABLE' ORDER BY TABLE_SCHEMA, TABLE_NAME" - ) + cursor.execute(f"USE [{database}]") + return cursor + + def get_tables(self, conn: Any, database: str | None = None) -> list[TableInfo]: + """Get list of tables with schema from SQL Server.""" + cursor = self._get_cursor_for_database(conn, database) + cursor.execute( + "SELECT TABLE_SCHEMA, TABLE_NAME FROM INFORMATION_SCHEMA.TABLES " + "WHERE TABLE_TYPE = 'BASE TABLE' ORDER BY TABLE_SCHEMA, TABLE_NAME" + ) return [(row[0], row[1]) for row in cursor.fetchall()] def get_views(self, conn: Any, database: str | None = None) -> list[TableInfo]: """Get list of views with schema from SQL Server.""" - cursor = conn.cursor() - if database: - cursor.execute( - f"SELECT TABLE_SCHEMA, TABLE_NAME FROM [{database}].INFORMATION_SCHEMA.VIEWS " - f"ORDER BY TABLE_SCHEMA, TABLE_NAME" - ) - else: - cursor.execute( - "SELECT TABLE_SCHEMA, TABLE_NAME FROM INFORMATION_SCHEMA.VIEWS " "ORDER BY TABLE_SCHEMA, TABLE_NAME" - ) + cursor = self._get_cursor_for_database(conn, database) + cursor.execute( + "SELECT TABLE_SCHEMA, TABLE_NAME FROM INFORMATION_SCHEMA.VIEWS " + "ORDER BY TABLE_SCHEMA, TABLE_NAME" + ) return [(row[0], row[1]) for row in cursor.fetchall()] def get_columns( self, conn: Any, table: str, database: str | None = None, schema: str | None = None ) -> list[ColumnInfo]: """Get columns for a table from SQL Server.""" - cursor = conn.cursor() + cursor = self._get_cursor_for_database(conn, database) schema = schema or "dbo" # Get primary key columns - if database: - cursor.execute( - f"SELECT kcu.COLUMN_NAME " - f"FROM [{database}].INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc " - f"JOIN [{database}].INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu " - f" ON tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME " - f" AND tc.TABLE_SCHEMA = kcu.TABLE_SCHEMA " - f"WHERE tc.CONSTRAINT_TYPE = 'PRIMARY KEY' " - f"AND tc.TABLE_SCHEMA = ? AND tc.TABLE_NAME = ?", - (schema, table), - ) - else: - cursor.execute( - "SELECT kcu.COLUMN_NAME " - "FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc " - "JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu " - " ON tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME " - " AND tc.TABLE_SCHEMA = kcu.TABLE_SCHEMA " - "WHERE tc.CONSTRAINT_TYPE = 'PRIMARY KEY' " - "AND tc.TABLE_SCHEMA = ? AND tc.TABLE_NAME = ?", - (schema, table), - ) + cursor.execute( + "SELECT kcu.COLUMN_NAME " + "FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc " + "JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu " + " ON tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME " + " AND tc.TABLE_SCHEMA = kcu.TABLE_SCHEMA " + "WHERE tc.CONSTRAINT_TYPE = 'PRIMARY KEY' " + "AND tc.TABLE_SCHEMA = ? AND tc.TABLE_NAME = ?", + (schema, table), + ) pk_columns = {row[0] for row in cursor.fetchall()} # Get all columns - if database: - cursor.execute( - f"SELECT COLUMN_NAME, DATA_TYPE FROM [{database}].INFORMATION_SCHEMA.COLUMNS " - f"WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? ORDER BY ORDINAL_POSITION", - (schema, table), - ) - else: - cursor.execute( - "SELECT COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS " - "WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? ORDER BY ORDINAL_POSITION", - (schema, table), - ) + cursor.execute( + "SELECT COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS " + "WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? ORDER BY ORDINAL_POSITION", + (schema, table), + ) return [ColumnInfo(name=row[0], data_type=row[1], is_primary_key=row[0] in pk_columns) for row in cursor.fetchall()] def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: """Get stored procedures from SQL Server.""" - cursor = conn.cursor() - if database: - cursor.execute( - f"SELECT ROUTINE_NAME FROM [{database}].INFORMATION_SCHEMA.ROUTINES " - f"WHERE ROUTINE_TYPE = 'PROCEDURE' ORDER BY ROUTINE_NAME" - ) - else: - cursor.execute( - "SELECT ROUTINE_NAME FROM INFORMATION_SCHEMA.ROUTINES " - "WHERE ROUTINE_TYPE = 'PROCEDURE' ORDER BY ROUTINE_NAME" - ) + cursor = self._get_cursor_for_database(conn, database) + cursor.execute( + "SELECT ROUTINE_NAME FROM INFORMATION_SCHEMA.ROUTINES " + "WHERE ROUTINE_TYPE = 'PROCEDURE' ORDER BY ROUTINE_NAME" + ) return [row[0] for row in cursor.fetchall()] def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]: """Get indexes from SQL Server.""" - cursor = conn.cursor() - if database: - cursor.execute( - f"SELECT i.name, t.name, i.is_unique " - f"FROM [{database}].sys.indexes i " - f"JOIN [{database}].sys.tables t ON i.object_id = t.object_id " - f"WHERE i.name IS NOT NULL AND i.type > 0 AND i.is_primary_key = 0 " - f"ORDER BY t.name, i.name" - ) - else: - cursor.execute( - "SELECT i.name, t.name, i.is_unique " - "FROM sys.indexes i " - "JOIN sys.tables t ON i.object_id = t.object_id " - "WHERE i.name IS NOT NULL AND i.type > 0 AND i.is_primary_key = 0 " - "ORDER BY t.name, i.name" - ) + cursor = self._get_cursor_for_database(conn, database) + cursor.execute( + "SELECT i.name, t.name, i.is_unique " + "FROM sys.indexes i " + "JOIN sys.tables t ON i.object_id = t.object_id " + "WHERE i.name IS NOT NULL AND i.type > 0 AND i.is_primary_key = 0 " + "ORDER BY t.name, i.name" + ) return [IndexInfo(name=row[0], table_name=row[1], is_unique=row[2]) for row in cursor.fetchall()] def get_triggers(self, conn: Any, database: str | None = None) -> list[TriggerInfo]: """Get triggers from SQL Server.""" - cursor = conn.cursor() - if database: - cursor.execute( - f"SELECT tr.name, OBJECT_NAME(tr.parent_id, DB_ID('{database}')) " - f"FROM [{database}].sys.triggers tr " - f"WHERE tr.is_ms_shipped = 0 AND tr.parent_id > 0 " - f"ORDER BY OBJECT_NAME(tr.parent_id, DB_ID('{database}')), tr.name" - ) - else: - cursor.execute( - "SELECT tr.name, OBJECT_NAME(tr.parent_id) " - "FROM sys.triggers tr " - "WHERE tr.is_ms_shipped = 0 AND tr.parent_id > 0 " - "ORDER BY OBJECT_NAME(tr.parent_id), tr.name" - ) + cursor = self._get_cursor_for_database(conn, database) + cursor.execute( + "SELECT tr.name, OBJECT_NAME(tr.parent_id) " + "FROM sys.triggers tr " + "WHERE tr.is_ms_shipped = 0 AND tr.parent_id > 0 " + "ORDER BY OBJECT_NAME(tr.parent_id), tr.name" + ) return [TriggerInfo(name=row[0], table_name=row[1] or "") for row in cursor.fetchall()] def get_sequences(self, conn: Any, database: str | None = None) -> list[SequenceInfo]: """Get sequences from SQL Server (2012+).""" - cursor = conn.cursor() - if database: - cursor.execute(f"SELECT name FROM [{database}].sys.sequences ORDER BY name") - else: - cursor.execute("SELECT name FROM sys.sequences ORDER BY name") + cursor = self._get_cursor_for_database(conn, database) + cursor.execute("SELECT name FROM sys.sequences ORDER BY name") return [SequenceInfo(name=row[0]) for row in cursor.fetchall()] def get_index_definition( self, conn: Any, index_name: str, table_name: str, database: str | None = None ) -> dict[str, Any]: """Get detailed information about a SQL Server index.""" - cursor = conn.cursor() - # Get index info and columns - if database: - cursor.execute( - f"SELECT i.is_unique, i.type_desc, c.name " - f"FROM [{database}].sys.indexes i " - f"JOIN [{database}].sys.index_columns ic ON i.object_id = ic.object_id AND i.index_id = ic.index_id " - f"JOIN [{database}].sys.columns c ON ic.object_id = c.object_id AND ic.column_id = c.column_id " - f"JOIN [{database}].sys.tables t ON i.object_id = t.object_id " - f"WHERE i.name = ? AND t.name = ? " - f"ORDER BY ic.key_ordinal", - (index_name, table_name), - ) - else: - cursor.execute( - "SELECT i.is_unique, i.type_desc, c.name " - "FROM sys.indexes i " - "JOIN sys.index_columns ic ON i.object_id = ic.object_id AND i.index_id = ic.index_id " - "JOIN sys.columns c ON ic.object_id = c.object_id AND ic.column_id = c.column_id " - "JOIN sys.tables t ON i.object_id = t.object_id " - "WHERE i.name = ? AND t.name = ? " - "ORDER BY ic.key_ordinal", - (index_name, table_name), - ) + cursor = self._get_cursor_for_database(conn, database) + cursor.execute( + "SELECT i.is_unique, i.type_desc, c.name " + "FROM sys.indexes i " + "JOIN sys.index_columns ic ON i.object_id = ic.object_id AND i.index_id = ic.index_id " + "JOIN sys.columns c ON ic.object_id = c.object_id AND ic.column_id = c.column_id " + "JOIN sys.tables t ON i.object_id = t.object_id " + "WHERE i.name = ? AND t.name = ? " + "ORDER BY ic.key_ordinal", + (index_name, table_name), + ) rows = cursor.fetchall() is_unique = rows[0][0] if rows else False index_type = rows[0][1] if rows else "NONCLUSTERED" @@ -355,28 +317,16 @@ def get_trigger_definition( self, conn: Any, trigger_name: str, table_name: str, database: str | None = None ) -> dict[str, Any]: """Get detailed information about a SQL Server trigger.""" - cursor = conn.cursor() - # Get trigger definition using OBJECT_DEFINITION - if database: - cursor.execute( - f"SELECT OBJECT_DEFINITION(tr.object_id), " - f" CASE WHEN tr.is_instead_of_trigger = 1 THEN 'INSTEAD OF' " - f" ELSE 'AFTER' END as timing " - f"FROM [{database}].sys.triggers tr " - f"JOIN [{database}].sys.tables t ON tr.parent_id = t.object_id " - f"WHERE tr.name = ? AND t.name = ?", - (trigger_name, table_name), - ) - else: - cursor.execute( - "SELECT OBJECT_DEFINITION(tr.object_id), " - " CASE WHEN tr.is_instead_of_trigger = 1 THEN 'INSTEAD OF' " - " ELSE 'AFTER' END as timing " - "FROM sys.triggers tr " - "JOIN sys.tables t ON tr.parent_id = t.object_id " - "WHERE tr.name = ? AND t.name = ?", - (trigger_name, table_name), - ) + cursor = self._get_cursor_for_database(conn, database) + cursor.execute( + "SELECT OBJECT_DEFINITION(tr.object_id), " + " CASE WHEN tr.is_instead_of_trigger = 1 THEN 'INSTEAD OF' " + " ELSE 'AFTER' END as timing " + "FROM sys.triggers tr " + "JOIN sys.tables t ON tr.parent_id = t.object_id " + "WHERE tr.name = ? AND t.name = ?", + (trigger_name, table_name), + ) row = cursor.fetchone() if row: definition = row[0] @@ -412,22 +362,13 @@ def get_sequence_definition( self, conn: Any, sequence_name: str, database: str | None = None ) -> dict[str, Any]: """Get detailed information about a SQL Server sequence.""" - cursor = conn.cursor() - # Cast sql_variant columns to BIGINT to avoid driver type conversion errors - if database: - cursor.execute( - f"SELECT CAST(start_value AS BIGINT), CAST(increment AS BIGINT), " - f"CAST(minimum_value AS BIGINT), CAST(maximum_value AS BIGINT), is_cycling " - f"FROM [{database}].sys.sequences WHERE name = ?", - (sequence_name,), - ) - else: - cursor.execute( - "SELECT CAST(start_value AS BIGINT), CAST(increment AS BIGINT), " - "CAST(minimum_value AS BIGINT), CAST(maximum_value AS BIGINT), is_cycling " - "FROM sys.sequences WHERE name = ?", - (sequence_name,), - ) + cursor = self._get_cursor_for_database(conn, database) + cursor.execute( + "SELECT CAST(start_value AS BIGINT), CAST(increment AS BIGINT), " + "CAST(minimum_value AS BIGINT), CAST(maximum_value AS BIGINT), is_cycling " + "FROM sys.sequences WHERE name = ?", + (sequence_name,), + ) row = cursor.fetchone() if row: return { @@ -456,10 +397,13 @@ def quote_identifier(self, name: str) -> str: return f"[{escaped}]" def build_select_query(self, table: str, limit: int, database: str | None = None, schema: str | None = None) -> str: - """Build SELECT TOP query for SQL Server.""" + """Build SELECT TOP query for SQL Server. + + Note: Does not include database prefix as Azure SQL Database doesn't + support cross-database references. The caller should ensure the + connection is to the correct database. + """ schema = schema or "dbo" - if database: - return f"SELECT TOP {limit} * FROM [{database}].[{schema}].[{table}]" return f"SELECT TOP {limit} * FROM [{schema}].[{table}]" def execute_query(self, conn: Any, query: str, max_rows: int | None = None) -> tuple[list[str], list[tuple], bool]: diff --git a/sqlit/db/adapters/snowflake.py b/sqlit/db/adapters/snowflake.py index ca96657e..61eda6b2 100644 --- a/sqlit/db/adapters/snowflake.py +++ b/sqlit/db/adapters/snowflake.py @@ -37,6 +37,10 @@ def driver_import_names(self) -> tuple[str, ...]: def supports_multiple_databases(self) -> bool: return True + @property + def system_databases(self) -> frozenset[str]: + return frozenset({"SNOWFLAKE", "snowflake"}) + @property def supports_stored_procedures(self) -> bool: return True diff --git a/sqlit/db/providers.py b/sqlit/db/providers.py index 14dc6c5a..d94ff2aa 100644 --- a/sqlit/db/providers.py +++ b/sqlit/db/providers.py @@ -262,3 +262,30 @@ def requires_auth(db_type: str) -> bool: """Check if this database type requires authentication.""" spec = PROVIDERS.get(db_type) return spec.schema.requires_auth if spec else True + + +def requires_database_selection(db_type: str) -> bool: + """Check if this database type requires a database to be specified. + + Returns True for databases that don't support cross-database queries + (e.g., PostgreSQL, CockroachDB, D1) where each database is isolated. + """ + try: + adapter = get_adapter_class(db_type)() + return not adapter.supports_cross_database_queries + except (ValueError, ImportError): + return False + + +def validate_database_required(db_type: str, database: str | None) -> None: + """Validate that a database is specified when required. + + Raises ValueError if the database type requires a database selection + but none is provided. + """ + if requires_database_selection(db_type) and not database: + display_name = get_display_name(db_type) + raise ValueError( + f"{display_name} requires a database to be specified. " + f"Each database is isolated and cross-database queries are not supported." + ) diff --git a/sqlit/idle_scheduler.py b/sqlit/idle_scheduler.py new file mode 100644 index 00000000..c18507f0 --- /dev/null +++ b/sqlit/idle_scheduler.py @@ -0,0 +1,310 @@ +"""Idle scheduler - execute work when the user isn't interacting. + +Inspired by browser's requestIdleCallback API. Queues up work and executes it +during idle periods to avoid UI hiccups. +""" + +from __future__ import annotations + +import asyncio +import time +from collections import deque +from dataclasses import dataclass, field +from enum import Enum, auto +from typing import Any, Callable, Coroutine, TYPE_CHECKING + +if TYPE_CHECKING: + from textual.app import App + + +class Priority(Enum): + """Job priority levels.""" + LOW = auto() # Can wait indefinitely + NORMAL = auto() # Should run soon-ish + HIGH = auto() # Run at next idle opportunity + + +@dataclass +class IdleJob: + """A unit of work to execute during idle time.""" + callback: Callable[[], Any] | Callable[[], Coroutine[Any, Any, Any]] + priority: Priority = Priority.NORMAL + is_async: bool = False + name: str = "" + created_at: float = field(default_factory=time.time) + + def __lt__(self, other: "IdleJob") -> bool: + # Higher priority first, then older jobs first + if self.priority != other.priority: + return self.priority.value > other.priority.value + return self.created_at < other.created_at + + +class IdleScheduler: + """Schedules work to run during user idle periods. + + Usage: + scheduler = IdleScheduler(app) + scheduler.start() + + # Queue work + scheduler.request_idle_callback(lambda: print("doing work")) + + # Or with async + scheduler.request_idle_callback(async_func, is_async=True) + + # Call this on any user interaction + scheduler.on_user_activity() + """ + + def __init__( + self, + app: "App[Any]", + idle_threshold_ms: float = 500, # Consider idle after 500ms of no activity + max_work_chunk_ms: float = 16, # Max time to work before checking for activity (~1 frame) + check_interval_ms: float = 150, # How often to check if we should work + max_queue_size: int = 1000, # Prevent unbounded growth + ) -> None: + self.app = app + self.idle_threshold_ms = idle_threshold_ms + self.max_work_chunk_ms = max_work_chunk_ms + self.check_interval_ms = check_interval_ms + self.max_queue_size = max_queue_size + + self._queue: deque[IdleJob] = deque() + self._last_activity_time: float = time.time() + self._running = False + self._timer: Any = None + self._paused = False + + # Stats for debugging + self._jobs_completed = 0 + self._jobs_dropped = 0 + self._total_work_time_ms = 0 + + @property + def is_idle(self) -> bool: + """Check if user is considered idle.""" + elapsed_ms = (time.time() - self._last_activity_time) * 1000 + return elapsed_ms >= self.idle_threshold_ms + + @property + def time_until_idle_ms(self) -> float: + """Time remaining until user is considered idle.""" + elapsed_ms = (time.time() - self._last_activity_time) * 1000 + return max(0, self.idle_threshold_ms - elapsed_ms) + + @property + def pending_jobs(self) -> int: + """Number of jobs waiting to be executed.""" + return len(self._queue) + + def on_user_activity(self) -> None: + """Call this whenever the user interacts with the app. + + Should be hooked into key presses, mouse events, etc. + """ + self._last_activity_time = time.time() + + def request_idle_callback( + self, + callback: Callable[[], Any] | Callable[[], Coroutine[Any, Any, Any]], + priority: Priority = Priority.NORMAL, + is_async: bool = False, + name: str = "", + ) -> bool: + """Queue a callback to run during idle time. + + Args: + callback: Function to execute (sync or async) + priority: Job priority (HIGH runs first) + is_async: True if callback is an async function + name: Optional name for debugging + + Returns: + True if queued, False if queue is full + """ + if len(self._queue) >= self.max_queue_size: + self._jobs_dropped += 1 + return False + + job = IdleJob( + callback=callback, + priority=priority, + is_async=is_async, + name=name, + ) + + # Insert maintaining priority order + # For simplicity, just append and sort when executing + self._queue.append(job) + return True + + def cancel_all(self, name: str | None = None) -> int: + """Cancel pending jobs. + + Args: + name: If provided, only cancel jobs with this name. + If None, cancel all jobs. + + Returns: + Number of jobs cancelled + """ + if name is None: + count = len(self._queue) + self._queue.clear() + return count + + original_len = len(self._queue) + self._queue = deque(job for job in self._queue if job.name != name) + return original_len - len(self._queue) + + def start(self) -> None: + """Start the idle scheduler.""" + if self._running: + return + self._running = True + self._schedule_check() + + def stop(self) -> None: + """Stop the idle scheduler.""" + self._running = False + if self._timer: + self._timer.stop() + self._timer = None + + def pause(self) -> None: + """Temporarily pause processing (queue still accepts jobs).""" + self._paused = True + + def resume(self) -> None: + """Resume processing after pause.""" + self._paused = False + + def _schedule_check(self) -> None: + """Schedule the next idle check.""" + if not self._running: + return + + # Use Textual's timer + delay = self.check_interval_ms / 1000 + self._timer = self.app.set_timer(delay, self._check_and_work) + + def _check_and_work(self) -> None: + """Check if idle and do work if so.""" + if not self._running or self._paused: + self._schedule_check() + return + + if not self._queue: + self._schedule_check() + return + + if not self.is_idle: + # User is active, wait until they're idle + self._schedule_check() + return + + # We're idle! Do some work + self._do_work_chunk() + + # Schedule next check + self._schedule_check() + + def _do_work_chunk(self) -> None: + """Execute jobs for up to max_work_chunk_ms.""" + start_time = time.time() + max_time = self.max_work_chunk_ms / 1000 + + # Sort queue by priority (do this lazily) + sorted_jobs = sorted(self._queue) + self._queue = deque(sorted_jobs) + + while self._queue: + # Check if we've exceeded our time budget + elapsed = time.time() - start_time + if elapsed >= max_time: + break + + # Check if user became active + if not self.is_idle: + break + + # Execute next job + job = self._queue.popleft() + try: + if job.is_async: + # Schedule async job to run + self.app.call_later(self._run_async_job, job) + else: + job.callback() + self._jobs_completed += 1 + except Exception as e: + # Log but don't crash + self.app.log.error(f"IdleScheduler job failed: {job.name or 'unnamed'}: {e}") + + # Track stats + self._total_work_time_ms += (time.time() - start_time) * 1000 + + # Refresh status bar if debug mode is on + if hasattr(self.app, "_debug_idle_scheduler") and self.app._debug_idle_scheduler: + if hasattr(self.app, "_update_status_bar"): + self.app._update_status_bar() + + async def _run_async_job(self, job: IdleJob) -> None: + """Run an async job.""" + try: + await job.callback() # type: ignore + self._jobs_completed += 1 + except Exception as e: + self.app.log.error(f"IdleScheduler async job failed: {job.name or 'unnamed'}: {e}") + + def get_stats(self) -> dict[str, Any]: + """Get scheduler statistics for debugging.""" + return { + "pending_jobs": len(self._queue), + "jobs_completed": self._jobs_completed, + "jobs_dropped": self._jobs_dropped, + "total_work_time_ms": round(self._total_work_time_ms, 2), + "is_idle": self.is_idle, + "time_until_idle_ms": round(self.time_until_idle_ms, 2), + "is_running": self._running, + "is_paused": self._paused, + } + + +# Convenience function for simple usage +_global_scheduler: IdleScheduler | None = None + + +def get_idle_scheduler() -> IdleScheduler | None: + """Get the global idle scheduler instance.""" + return _global_scheduler + + +def init_idle_scheduler(app: "App[Any]", **kwargs: Any) -> IdleScheduler: + """Initialize the global idle scheduler.""" + global _global_scheduler + _global_scheduler = IdleScheduler(app, **kwargs) + return _global_scheduler + + +def request_idle_callback( + callback: Callable[[], Any], + priority: Priority = Priority.NORMAL, + is_async: bool = False, + name: str = "", +) -> bool: + """Queue a callback to run during idle time (uses global scheduler). + + Returns False if no scheduler is initialized or queue is full. + """ + if _global_scheduler is None: + return False + return _global_scheduler.request_idle_callback(callback, priority, is_async, name) + + +def on_user_activity() -> None: + """Signal user activity (uses global scheduler).""" + if _global_scheduler: + _global_scheduler.on_user_activity() diff --git a/sqlit/services/session.py b/sqlit/services/session.py index 39962721..bf862b32 100644 --- a/sqlit/services/session.py +++ b/sqlit/services/session.py @@ -106,6 +106,10 @@ def create( # Get adapter and connect adapter = get_adapter_fn(config.db_type) connection = adapter.connect(connect_config) + try: + adapter.detect_capabilities(connection, config) + except Exception: + pass return cls(connection, adapter, config, tunnel) @@ -160,6 +164,57 @@ def executor(self) -> DatabaseExecutor: self._executor = DatabaseExecutor(self) return self._executor + def switch_database(self, database: str) -> None: + """Switch to a different database without recreating the session. + + This is used for databases like PostgreSQL that don't support + cross-database queries. It closes the current connection and + opens a new one to the specified database, reusing the SSH tunnel. + + Args: + database: The database name to switch to. + + Raises: + RuntimeError: If the session has been closed. + Any database-specific connection errors. + """ + if self._closed: + raise RuntimeError("Cannot switch database on closed session") + + # Create new config with the database + new_config = replace(self._config, database=database) + + # Determine connection config (use tunnel if present) + if self._tunnel: + # Reuse tunnel - get local bind address + local_host, local_port = self._tunnel.local_bind_address + connect_config = replace(new_config, server=local_host, port=str(local_port)) + else: + connect_config = new_config + + # Close old connection (but keep tunnel) + if self._connection is not None: + try: + close_fn = getattr(self._connection, "close", None) + if callable(close_fn): + close_fn() + except Exception: + pass + + # Open new connection + self._connection = self._adapter.connect(connect_config) + self._config = new_config + + @connection.setter + def connection(self, value: Any) -> None: + """Set the raw database connection object.""" + self._connection = value + + @config.setter + def config(self, value: ConnectionConfig) -> None: + """Set the connection configuration.""" + self._config = value + def close(self) -> None: """Close the session and release all resources. diff --git a/sqlit/sql_completion/__init__.py b/sqlit/sql_completion/__init__.py new file mode 100644 index 00000000..609b4a8b --- /dev/null +++ b/sqlit/sql_completion/__init__.py @@ -0,0 +1,118 @@ +"""SQL completion engine. + +Provides intelligent SQL autocompletion with: +- Context-aware suggestions (tables after FROM, columns after SELECT, etc.) +- Alias recognition (FROM users u -> u.id suggests users columns) +- Fuzzy matching +- SQL keywords and common functions +- Statement-specific handling (INSERT, UPDATE, DELETE) +""" + +from .completion import get_completions, get_context +from .core import ( + RESERVED_WORDS, + SQL_FUNCTIONS, + SQL_KEYWORDS, + SQL_OPERATORS, + Suggestion, + SuggestionType, + TableRef, + build_alias_map, + extract_cte_names, + extract_table_refs, + find_context_keyword, + find_current_clause, + find_last_keyword, + fuzzy_match, + get_all_functions, + get_all_keywords, + get_current_word, + is_inside_string, + remove_comments, + remove_string_literals, +) +from .alter_table import ALTER_OPERATIONS, get_alter_table_completions, get_alter_table_context +from .create_index import get_create_index_completions +from .create_table import ( + SQL_CONSTRAINTS, + SQL_DATA_TYPES, + SQL_TABLE_CONSTRAINTS, + get_create_table_completions, + get_create_table_context, +) +from .create_view import get_create_view_completions +from .delete import extract_delete_table_refs, get_delete_context +from .drop import DROP_OBJECTS, get_drop_completions, get_drop_context +from .insert import get_insert_context +from .truncate import get_truncate_completions +from .update import get_update_context + +# Backwards compatibility aliases for private functions +_is_inside_string = is_inside_string +_get_last_token_info = None # Internal to completion.py +_remove_string_literals = remove_string_literals +_remove_comments = remove_comments +_find_context_keyword = find_context_keyword +_find_last_keyword = find_last_keyword +_find_current_clause = find_current_clause +_get_current_word = get_current_word +_build_alias_map = build_alias_map + +__all__ = [ + # Main API + "get_completions", + "get_context", + # Types + "Suggestion", + "SuggestionType", + "TableRef", + # Constants + "SQL_KEYWORDS", + "SQL_FUNCTIONS", + "SQL_OPERATORS", + "SQL_DATA_TYPES", + "SQL_CONSTRAINTS", + "SQL_TABLE_CONSTRAINTS", + "ALTER_OPERATIONS", + "DROP_OBJECTS", + "RESERVED_WORDS", + # Utilities + "fuzzy_match", + "extract_table_refs", + "extract_cte_names", + "get_all_keywords", + "get_all_functions", + # Statement-specific - DML + "get_insert_context", + "get_update_context", + "get_delete_context", + "extract_delete_table_refs", + # Statement-specific - DDL + "get_create_table_context", + "get_create_table_completions", + "get_alter_table_context", + "get_alter_table_completions", + "get_drop_context", + "get_drop_completions", + "get_create_index_completions", + "get_create_view_completions", + "get_truncate_completions", + # Helper functions (public) + "is_inside_string", + "remove_string_literals", + "remove_comments", + "find_context_keyword", + "find_last_keyword", + "find_current_clause", + "get_current_word", + "build_alias_map", + # Backwards compatibility (private names) + "_is_inside_string", + "_remove_string_literals", + "_remove_comments", + "_find_context_keyword", + "_find_last_keyword", + "_find_current_clause", + "_get_current_word", + "_build_alias_map", +] diff --git a/sqlit/sql_completion/alter_table.py b/sqlit/sql_completion/alter_table.py new file mode 100644 index 00000000..fd24c504 --- /dev/null +++ b/sqlit/sql_completion/alter_table.py @@ -0,0 +1,189 @@ +"""ALTER TABLE statement context detection.""" + +from __future__ import annotations + +import re + +from .core import Suggestion, SuggestionType +from .create_table import SQL_CONSTRAINTS, SQL_DATA_TYPES + +# ALTER TABLE operations +ALTER_OPERATIONS = [ + "ADD", + "ADD COLUMN", + "DROP", + "DROP COLUMN", + "ALTER", + "ALTER COLUMN", + "MODIFY", + "MODIFY COLUMN", + "RENAME", + "RENAME COLUMN", + "RENAME TO", + "ADD CONSTRAINT", + "DROP CONSTRAINT", + "ADD PRIMARY KEY", + "DROP PRIMARY KEY", + "ADD FOREIGN KEY", + "ADD INDEX", + "DROP INDEX", + "ADD UNIQUE", + "SET DEFAULT", + "DROP DEFAULT", + "SET NOT NULL", + "DROP NOT NULL", +] + + +def get_alter_table_context(before_cursor: str) -> list[Suggestion] | None: + """Detect ALTER TABLE-specific context and return suggestions. + + Handles: + - ALTER TABLE → table names + - ALTER TABLE name → operations (ADD, DROP, etc.) + - ALTER TABLE name ADD → column name (user types) then data type + - ALTER TABLE name DROP COLUMN → existing columns + - ALTER TABLE name ALTER COLUMN → existing columns + - ALTER TABLE name ADD FOREIGN KEY ... REFERENCES → tables + + Args: + before_cursor: SQL text before cursor position + + Returns: + List of suggestions if in ALTER TABLE context, None otherwise + """ + # Check for ALTER TABLE pattern + alter_match = re.search(r"\bALTER\s+TABLE\s+", before_cursor, re.IGNORECASE) + if not alter_match: + return None + + after_alter_table = before_cursor[alter_match.end():] + + # If no table name yet, suggest tables + if not after_alter_table.strip() or re.match(r"^\w*$", after_alter_table.strip()): + return [Suggestion(type=SuggestionType.TABLE)] + + # Extract table name + table_match = re.match(r"^(\w+)\s*", after_alter_table) + if not table_match: + return None + + table_name = table_match.group(1) + after_table = after_alter_table[table_match.end():] + + # If nothing after table name, suggest operations + if not after_table.strip() or re.match(r"^\w*$", after_table.strip()): + return [Suggestion(type=SuggestionType.KEYWORD)] # Will return ALTER_OPERATIONS + + # Check for DROP COLUMN → suggest columns + if re.search(r"\bDROP\s+(?:COLUMN\s+)?\w*$", after_table, re.IGNORECASE): + return [Suggestion(type=SuggestionType.ALIAS_COLUMN, table_scope=table_name)] + + # Check for ALTER/MODIFY COLUMN → suggest columns + if re.search(r"\b(?:ALTER|MODIFY)\s+(?:COLUMN\s+)?\w*$", after_table, re.IGNORECASE): + return [Suggestion(type=SuggestionType.ALIAS_COLUMN, table_scope=table_name)] + + # Check for RENAME COLUMN → suggest columns + if re.search(r"\bRENAME\s+(?:COLUMN\s+)?\w*$", after_table, re.IGNORECASE): + return [Suggestion(type=SuggestionType.ALIAS_COLUMN, table_scope=table_name)] + + # Check for ADD COLUMN name → suggest data types + if re.search(r"\bADD\s+(?:COLUMN\s+)?\w+\s+\w*$", after_table, re.IGNORECASE): + return [Suggestion(type=SuggestionType.KEYWORD)] # Will return SQL_DATA_TYPES + + # Check for REFERENCES → suggest tables + if re.search(r"\bREFERENCES\s+\w*$", after_table, re.IGNORECASE): + return [Suggestion(type=SuggestionType.TABLE)] + + # Check for REFERENCES table ( → suggest columns + ref_match = re.search(r"\bREFERENCES\s+(\w+)\s*\(\s*\w*$", after_table, re.IGNORECASE) + if ref_match: + ref_table = ref_match.group(1) + return [Suggestion(type=SuggestionType.ALIAS_COLUMN, table_scope=ref_table)] + + return None + + +def get_alter_table_completions(before_cursor: str, tables: list[str], columns: dict[str, list[str]]) -> list[str] | None: + """Get completions specific to ALTER TABLE context. + + Args: + before_cursor: SQL text before cursor position + tables: List of available table names + columns: Dict mapping table names to column lists + + Returns: + List of completions, or None if not in ALTER TABLE context + """ + alter_match = re.search(r"\bALTER\s+TABLE\s+", before_cursor, re.IGNORECASE) + if not alter_match: + return None + + after_alter_table = before_cursor[alter_match.end():] + + # If no table name yet or still typing table name, suggest tables + # Only suggest tables if there's no trailing whitespace (user still typing) + if not after_alter_table.strip(): + return tables + # Check if it's just a partial table name (no whitespace after the word) + if re.match(r"^\w+$", after_alter_table): + return tables + + # Extract table name + table_match = re.match(r"^(\w+)\s*", after_alter_table) + if not table_match: + return None + + table_name = table_match.group(1).lower() + after_table = after_alter_table[table_match.end():] + + # Check for DROP COLUMN → suggest columns (must check before generic pattern) + # Pattern requires whitespace after DROP or COLUMN to avoid matching "DROP" as partial + if re.search(r"\bDROP\s+COLUMN\s+\w*$", after_table, re.IGNORECASE): + if table_name in columns: + return columns[table_name] + return [] + + # DROP without COLUMN - only match if there's whitespace after DROP + if re.search(r"\bDROP\s+(?!COLUMN)(?!CONSTRAINT)(?!PRIMARY)(?!INDEX)\w*$", after_table, re.IGNORECASE): + if table_name in columns: + return columns[table_name] + return [] + + # If nothing after table name or just typing operation, suggest operations + if not after_table.strip() or re.match(r"^\w*$", after_table.strip()): + return ALTER_OPERATIONS + + # Check for ALTER/MODIFY COLUMN → suggest columns + if re.search(r"\b(?:ALTER|MODIFY)\s+(?:COLUMN\s+)?\w*$", after_table, re.IGNORECASE): + if table_name in columns: + return columns[table_name] + return [] + + # Check for RENAME COLUMN → suggest columns + if re.search(r"\bRENAME\s+(?:COLUMN\s+)?\w*$", after_table, re.IGNORECASE): + if table_name in columns: + return columns[table_name] + return [] + + # Check for ADD COLUMN name → suggest data types + if re.search(r"\bADD\s+(?:COLUMN\s+)?\w+\s+\w*$", after_table, re.IGNORECASE): + return SQL_DATA_TYPES + + # Check for data type followed by space → suggest constraints + if re.search(r"\b(?:" + "|".join(SQL_DATA_TYPES) + r")(?:\s*\([^)]*\))?\s+\w*$", after_table, re.IGNORECASE): + return SQL_CONSTRAINTS + + # Check for REFERENCES → suggest tables + if re.search(r"\bREFERENCES\s+\w*$", after_table, re.IGNORECASE): + return tables + + # Check for REFERENCES table ( → suggest columns + ref_match = re.search(r"\bREFERENCES\s+(\w+)\s*\(\s*\w*$", after_table, re.IGNORECASE) + if ref_match: + ref_table = ref_match.group(1).lower() + if ref_table in columns: + return columns[ref_table] + return [] + + return None diff --git a/sqlit/sql_completion/completion.py b/sqlit/sql_completion/completion.py new file mode 100644 index 00000000..5897126a --- /dev/null +++ b/sqlit/sql_completion/completion.py @@ -0,0 +1,380 @@ +"""Main SQL completion engine. + +Orchestrates context detection and completion generation. +""" + +from __future__ import annotations + +import re + +from .core import ( + SQL_OPERATORS, + Suggestion, + SuggestionType, + build_alias_map, + extract_cte_names, + extract_table_refs, + find_context_keyword, + find_current_clause, + fuzzy_match, + get_all_functions, + get_all_keywords, + get_current_word, + get_last_token_info, + is_inside_string, + remove_comments, + remove_string_literals, +) + +# Special keywords for SELECT clause (before FROM) +SELECT_CLAUSE_KEYWORDS = ["*", "DISTINCT", "TOP", "ALL"] +from .alter_table import get_alter_table_completions, get_alter_table_context +from .create_index import get_create_index_completions +from .create_table import get_create_table_completions, get_create_table_context +from .create_view import get_create_view_completions +from .delete import extract_delete_table_refs, get_delete_context +from .drop import get_drop_completions, get_drop_context +from .insert import get_insert_context +from .truncate import get_truncate_completions +from .update import get_update_context + + +def get_context(sql: str, cursor_pos: int) -> list[Suggestion]: + """Determine what type of suggestions to provide based on cursor position. + + Uses statement-specific handlers and sqlparse for accurate context detection. + + Args: + sql: The full SQL text + cursor_pos: Position of cursor in the text + + Returns: + List of Suggestion objects indicating what to suggest + """ + before_cursor = sql[:cursor_pos] + + # Don't suggest anything if we're inside a string literal + if is_inside_string(before_cursor): + return [] + + # Don't suggest anything after a statement terminator (semicolon) + if before_cursor.rstrip().endswith(";"): + return [] + + # Check for table.column pattern (alias or table prefix) + dot_match = re.search(r"(\w+)\.\w*$", before_cursor) + if dot_match: + prefix = dot_match.group(1) + return [Suggestion(type=SuggestionType.ALIAS_COLUMN, table_scope=prefix)] + + # Try statement-specific handlers + for handler in [get_insert_context, get_update_context, get_delete_context]: + result = handler(before_cursor) + if result is not None: + return result + + # Use sqlparse to detect operators - only when cursor is after whitespace + if before_cursor and before_cursor[-1] in " \t\n": + token_value, token_type = get_last_token_info(before_cursor.rstrip()) + + if token_type: + # After a comparison operator -> suggest columns/values + if "Comparison" in token_type: + return [Suggestion(type=SuggestionType.COLUMN)] + + # After a Name or ) in WHERE context -> suggest operators + if token_type == "Token.Name" or (token_type == "Token.Punctuation" and token_value == ")"): + clean_sql = remove_string_literals(before_cursor) + clean_sql = remove_comments(clean_sql) + clause = find_current_clause(clean_sql) + if clause in ("where", "having", "on"): + return [Suggestion(type=SuggestionType.OPERATOR)] + + # Fall back to keyword-based context detection + clean_sql = remove_string_literals(before_cursor) + clean_sql = remove_comments(clean_sql) + context_keyword = find_context_keyword(clean_sql) + + if context_keyword in ("from", "join", "inner", "left", "right", "outer", "cross", "full", "into", "update", "table"): + return [Suggestion(type=SuggestionType.TABLE)] + + # DISTINCT should suggest columns (same as SELECT) + if context_keyword == "distinct": + return [Suggestion(type=SuggestionType.COLUMN)] + + # CASE/WHEN/THEN/ELSE expressions should suggest columns + if context_keyword in ("when", "then", "else"): + return [Suggestion(type=SuggestionType.COLUMN)] + + if context_keyword in ("select", "where", "and", "or", "on", "having", "set"): + return [Suggestion(type=SuggestionType.COLUMN)] + + if context_keyword in ("order", "group"): + if re.search(r"\b(ORDER|GROUP)\s+BY\s+\w*$", clean_sql, re.IGNORECASE): + return [Suggestion(type=SuggestionType.COLUMN)] + return [Suggestion(type=SuggestionType.KEYWORD)] + + if context_keyword in ("exec", "execute", "call"): + return [Suggestion(type=SuggestionType.PROCEDURE)] + + if context_keyword == ",": + clause = find_current_clause(clean_sql) + if clause == "select": + return [Suggestion(type=SuggestionType.COLUMN)] + if clause in ("from", "join"): + return [Suggestion(type=SuggestionType.TABLE)] + if clause == "set": + # After comma in SET clause, suggest columns + return [Suggestion(type=SuggestionType.COLUMN)] + return [Suggestion(type=SuggestionType.COLUMN)] + + if context_keyword in ("by",): + return [Suggestion(type=SuggestionType.COLUMN)] + + # Default: suggest keywords + return [Suggestion(type=SuggestionType.KEYWORD)] + + +def get_completions( + sql: str, + cursor_pos: int, + tables: list[str], + columns: dict[str, list[str]], + procedures: list[str] | None = None, + include_keywords: bool = True, + include_functions: bool = True, +) -> list[str]: + """Get completion suggestions for the given SQL and cursor position. + + Args: + sql: The full SQL text + cursor_pos: Position of cursor in the text + tables: List of available table names + columns: Dict mapping table names to column lists + procedures: Optional list of stored procedure names + include_keywords: Whether to include SQL keywords + include_functions: Whether to include SQL functions + + Returns: + List of completion suggestions + """ + before_cursor = sql[:cursor_pos] + current_word = get_current_word(sql, cursor_pos) + + # Don't suggest if inside string literal + if is_inside_string(before_cursor): + return [] + + # Don't suggest if there's no SQL content yet (just whitespace) + if not before_cursor.strip(): + return [] + + # Try DDL-specific handlers first (they return completions directly) + for ddl_handler in [ + get_create_table_completions, + get_alter_table_completions, + get_create_index_completions, + get_create_view_completions, + ]: + result = ddl_handler(before_cursor, tables, columns) + if result is not None: + return fuzzy_match(current_word, result) + + # Try DROP handler (doesn't need columns) + drop_result = get_drop_completions(before_cursor, tables, procedures) + if drop_result is not None: + return fuzzy_match(current_word, drop_result) + + # Try TRUNCATE handler (only needs tables) + truncate_result = get_truncate_completions(before_cursor, tables) + if truncate_result is not None: + return fuzzy_match(current_word, truncate_result) + + # Handle specific patterns that need targeted suggestions + clean_before = remove_string_literals(before_cursor) + clean_before = remove_comments(clean_before) + + # After UNION/INTERSECT/EXCEPT [ALL] → suggest SELECT + if re.search(r"\b(UNION|INTERSECT|EXCEPT)(\s+ALL)?\s+\w*$", clean_before, re.IGNORECASE): + return fuzzy_match(current_word, ["SELECT", "ALL"]) + + # After JOIN table_name [alias] → suggest ON keyword + # But NOT for CROSS JOIN or NATURAL JOIN (they don't use ON) + if re.search(r"\bJOIN\s+\w+(\s+(?:AS\s+)?\w+)?\s+\w*$", clean_before, re.IGNORECASE): + if not re.search(r"\bJOIN\s+\w*$", clean_before, re.IGNORECASE): + # Check if it's CROSS JOIN or NATURAL JOIN + if re.search(r"\b(CROSS|NATURAL)\s+JOIN\b", clean_before, re.IGNORECASE): + # CROSS/NATURAL JOIN don't use ON - suggest common follow-ups + return fuzzy_match(current_word, ["WHERE", "ORDER", "GROUP", "LIMIT", "UNION"]) + else: + return fuzzy_match(current_word, ["ON", "USING"]) + + # CAST(col AS → suggest data types + if re.search(r"\bCAST\s*\([^)]+\s+AS\s+\w*$", clean_before, re.IGNORECASE): + from .create_table import SQL_DATA_TYPES + return fuzzy_match(current_word, SQL_DATA_TYPES) + + # RETURNING clause → suggest columns from the target table + # Works for INSERT, UPDATE, DELETE with RETURNING + returning_match = re.search(r"\bRETURNING\s+(?:\w+\s*,\s*)*\w*$", clean_before, re.IGNORECASE) + if returning_match: + # Extract table from INSERT INTO, UPDATE, or DELETE FROM + table_match = re.search( + r"\b(?:INSERT\s+INTO|UPDATE|DELETE\s+FROM)\s+(\w+)", + clean_before, + re.IGNORECASE, + ) + if table_match: + table_name = table_match.group(1).lower() + if table_name in columns: + return fuzzy_match(current_word, columns[table_name]) + + # Inside function parens → suggest columns + # Includes aggregates, string functions, date functions, etc. + func_match = re.search( + r"\b(COUNT|SUM|AVG|MAX|MIN|COALESCE|NULLIF|ISNULL|IFNULL|NVL|NVL2|" + r"GROUP_CONCAT|STRING_AGG|ARRAY_AGG|CAST|" + r"TRIM|LTRIM|RTRIM|UPPER|LOWER|LENGTH|LEN|SUBSTR|SUBSTRING|REPLACE|" + r"CONCAT|LEFT|RIGHT|LPAD|RPAD|REVERSE|" + r"ABS|ROUND|CEIL|CEILING|FLOOR|SIGN|SQRT|POWER|MOD|" + r"DATE|YEAR|MONTH|DAY|HOUR|MINUTE|SECOND|" + r"TO_CHAR|TO_DATE|TO_NUMBER|FORMAT)\s*\(\s*\w*$", + clean_before, + re.IGNORECASE, + ) + if func_match: + # Build column list from tables in the full SQL + table_refs = extract_table_refs(sql) + table_refs.extend(extract_delete_table_refs(sql)) + result_cols = [] + for ref in table_refs: + table_key = ref.name.lower() + if table_key in columns: + result_cols.extend(columns[table_key]) + if result_cols: + return fuzzy_match(current_word, result_cols) + + # Schema.table prefix → suggest tables after schema name + # Pattern: FROM/JOIN schema. or schema.partial + if re.search(r"\b(FROM|JOIN)\s+\w+\.\w*$", clean_before, re.IGNORECASE): + return fuzzy_match(current_word, tables) + + # ANY/ALL/SOME ( → suggest SELECT for subquery + if re.search(r"\b(ANY|ALL|SOME)\s*\(\s*\w*$", clean_before, re.IGNORECASE): + return fuzzy_match(current_word, ["SELECT"]) + + # IN ( or NOT IN ( → suggest SELECT for subquery + if re.search(r"\bNOT\s+IN\s*\(\s*\w*$", clean_before, re.IGNORECASE): + return fuzzy_match(current_word, ["SELECT"]) + if re.search(r"\bIN\s*\(\s*\w*$", clean_before, re.IGNORECASE): + return fuzzy_match(current_word, ["SELECT"]) + + # EXISTS ( or NOT EXISTS ( → suggest SELECT for subquery + if re.search(r"\b(NOT\s+)?EXISTS\s*\(\s*\w*$", clean_before, re.IGNORECASE): + return fuzzy_match(current_word, ["SELECT"]) + + # GROUPING SETS/CUBE/ROLLUP ( → suggest columns + if re.search(r"\b(GROUPING\s+SETS|CUBE|ROLLUP)\s*\(\s*\w*$", clean_before, re.IGNORECASE): + table_refs = extract_table_refs(sql) + result_cols = [] + for ref in table_refs: + table_key = ref.name.lower() + if table_key in columns: + result_cols.extend(columns[table_key]) + if result_cols: + return fuzzy_match(current_word, result_cols) + + # ORDER BY column → suggest ASC, DESC, NULLS + # Pattern: ORDER BY col or ORDER BY col1, col2 + if re.search(r"\bORDER\s+BY\s+(?:\w+\s*,\s*)*\w+\s+\w*$", clean_before, re.IGNORECASE): + # Make sure we're not right after ORDER BY (that should suggest columns) + if not re.search(r"\bORDER\s+BY\s+\w*$", clean_before, re.IGNORECASE): + # Check if we just typed ASC/DESC and need more options + if re.search(r"\b(ASC|DESC)\s+\w*$", clean_before, re.IGNORECASE): + return fuzzy_match(current_word, ["NULLS", ",", "LIMIT", "OFFSET", "FETCH"]) + return fuzzy_match(current_word, ["ASC", "DESC", "NULLS", ",", "LIMIT"]) + + # NULLS → suggest FIRST, LAST + if re.search(r"\bNULLS\s+\w*$", clean_before, re.IGNORECASE): + return fuzzy_match(current_word, ["FIRST", "LAST"]) + + # CASE [column] → suggest WHEN + # Matches "CASE " or "CASE col " but not "CASE WHEN" + if re.search(r"\bCASE\s+(?!WHEN\b)\w*\s*$", clean_before, re.IGNORECASE): + return fuzzy_match(current_word, ["WHEN"]) + + # OVER ( → suggest PARTITION BY, ORDER BY + if re.search(r"\bOVER\s*\(\s*\w*$", clean_before, re.IGNORECASE): + return fuzzy_match(current_word, ["PARTITION", "ORDER", "ROWS", "RANGE"]) + + # Fall back to context-based completion + suggestions = get_context(sql, cursor_pos) + if not suggestions: + return [] + + # Build table alias map from all sources + table_refs = extract_table_refs(sql) + table_refs.extend(extract_delete_table_refs(sql)) + alias_map = build_alias_map(table_refs, tables) + + cte_names = extract_cte_names(sql) + + results: list[str] = [] + + for suggestion in suggestions: + if suggestion.type == SuggestionType.TABLE: + results.extend(tables) + results.extend(cte_names) + + elif suggestion.type == SuggestionType.COLUMN: + # Check if we're in SELECT clause (before FROM) to add special keywords + clause = find_current_clause(clean_before) + if clause == "select": + results.extend(SELECT_CLAUSE_KEYWORDS) + + for ref in table_refs: + table_key = ref.name.lower() + if table_key in columns: + results.extend(columns[table_key]) + + # Only add table names if NOT in SELECT clause (tables go after FROM, not SELECT) + if clause != "select": + results.extend(tables) + + if include_functions: + results.extend(get_all_functions()) + + elif suggestion.type == SuggestionType.ALIAS_COLUMN: + scope = suggestion.table_scope + if scope: + scope_lower = scope.lower() + + if scope_lower in alias_map: + table_name = alias_map[scope_lower] + if table_name.lower() in columns: + results.extend(columns[table_name.lower()]) + elif scope_lower in columns: + results.extend(columns[scope_lower]) + + elif suggestion.type == SuggestionType.PROCEDURE: + if procedures: + results.extend(procedures) + + elif suggestion.type == SuggestionType.KEYWORD: + if include_keywords: + results.extend(get_all_keywords()) + if include_functions: + results.extend(get_all_functions()) + + elif suggestion.type == SuggestionType.OPERATOR: + results.extend(SQL_OPERATORS) + + # Remove duplicates while preserving order + seen: set[str] = set() + unique_results: list[str] = [] + for r in results: + if r.lower() not in seen: + seen.add(r.lower()) + unique_results.append(r) + + return fuzzy_match(current_word, unique_results) diff --git a/sqlit/sql_completion/core.py b/sqlit/sql_completion/core.py new file mode 100644 index 00000000..85f47e96 --- /dev/null +++ b/sqlit/sql_completion/core.py @@ -0,0 +1,670 @@ +"""Core SQL completion utilities. + +Shared logic for fuzzy matching, table extraction, keywords, and helper functions. +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from enum import Enum, auto +from typing import NamedTuple + + +class SuggestionType(Enum): + """Types of SQL completion suggestions.""" + + TABLE = auto() + COLUMN = auto() + KEYWORD = auto() + FUNCTION = auto() + SCHEMA = auto() + DATABASE = auto() + PROCEDURE = auto() + ALIAS_COLUMN = auto() # Column for a specific table/alias + OPERATOR = auto() # Comparison operators (=, <, >, etc.) + + +class Suggestion(NamedTuple): + """A completion suggestion with type and optional scope.""" + + type: SuggestionType + table_scope: str | None = None # For ALIAS_COLUMN, which table + + +@dataclass +class TableRef: + """A table reference with optional alias.""" + + name: str + alias: str | None = None + schema: str | None = None + + +# SQL Comparison operators and condition keywords +SQL_OPERATORS = [ + "=", + "!=", + "<>", + "<", + ">", + "<=", + ">=", + "IS NULL", + "IS NOT NULL", + "IN", + "NOT IN", + "LIKE", + "NOT LIKE", + "ILIKE", + "NOT ILIKE", + "BETWEEN", + "NOT BETWEEN", + "EXISTS", + "NOT EXISTS", +] + +# SQL Keywords grouped by category +SQL_KEYWORDS = { + "dml": [ + "SELECT", + "FROM", + "WHERE", + "JOIN", + "LEFT", + "RIGHT", + "INNER", + "OUTER", + "CROSS", + "FULL", + "ON", + "AND", + "OR", + "NOT", + "IN", + "EXISTS", + "BETWEEN", + "LIKE", + "IS", + "NULL", + "ORDER", + "BY", + "ASC", + "DESC", + "GROUP", + "HAVING", + "LIMIT", + "OFFSET", + "TOP", + "DISTINCT", + "AS", + "UNION", + "INTERSECT", + "EXCEPT", + "ALL", + "INSERT", + "INTO", + "VALUES", + "UPDATE", + "SET", + "DELETE", + "MERGE", + "USING", + "MATCHED", + ], + "ddl": [ + "CREATE", + "ALTER", + "DROP", + "TRUNCATE", + "INDEX", + "VIEW", + "TABLE", + "DATABASE", + "SCHEMA", + "CONSTRAINT", + "PRIMARY", + "KEY", + "FOREIGN", + "REFERENCES", + "UNIQUE", + "CHECK", + "DEFAULT", + ], + "control": [ + "CASE", + "WHEN", + "THEN", + "ELSE", + "END", + "IF", + "BEGIN", + "COMMIT", + "ROLLBACK", + "TRANSACTION", + ], + "types": [ + "INT", + "INTEGER", + "BIGINT", + "SMALLINT", + "TINYINT", + "DECIMAL", + "NUMERIC", + "FLOAT", + "REAL", + "DOUBLE", + "VARCHAR", + "CHAR", + "TEXT", + "NVARCHAR", + "NCHAR", + "DATE", + "TIME", + "DATETIME", + "TIMESTAMP", + "BOOLEAN", + "BIT", + "BLOB", + "CLOB", + "UUID", + "JSON", + "XML", + ], +} + +# Common SQL functions +SQL_FUNCTIONS = { + "aggregate": [ + "COUNT", + "SUM", + "AVG", + "MIN", + "MAX", + "GROUP_CONCAT", + "STRING_AGG", + "ARRAY_AGG", + "LISTAGG", + ], + "string": [ + "CONCAT", + "SUBSTRING", + "SUBSTR", + "LEFT", + "RIGHT", + "TRIM", + "LTRIM", + "RTRIM", + "UPPER", + "LOWER", + "LENGTH", + "LEN", + "CHARINDEX", + "POSITION", + "REPLACE", + "REVERSE", + "SPLIT_PART", + "STUFF", + ], + "numeric": [ + "ABS", + "ROUND", + "FLOOR", + "CEILING", + "CEIL", + "POWER", + "SQRT", + "MOD", + "SIGN", + "RAND", + "RANDOM", + ], + "datetime": [ + "NOW", + "CURRENT_DATE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "GETDATE", + "GETUTCDATE", + "SYSDATETIME", + "DATEADD", + "DATEDIFF", + "DATEPART", + "YEAR", + "MONTH", + "DAY", + "HOUR", + "MINUTE", + "SECOND", + "EXTRACT", + "DATE_TRUNC", + "TO_DATE", + "TO_CHAR", + "FORMAT", + ], + "conversion": [ + "CAST", + "CONVERT", + "TRY_CAST", + "TRY_CONVERT", + "PARSE", + "TRY_PARSE", + ], + "null_handling": [ + "COALESCE", + "NULLIF", + "ISNULL", + "IFNULL", + "NVL", + "NVL2", + ], + "conditional": [ + "IIF", + "CHOOSE", + "DECODE", + ], + "window": [ + "ROW_NUMBER", + "RANK", + "DENSE_RANK", + "NTILE", + "LAG", + "LEAD", + "FIRST_VALUE", + "LAST_VALUE", + "OVER", + "PARTITION", + ], +} + +# Reserved words that cannot be aliases +RESERVED_WORDS = { + "select", + "from", + "where", + "join", + "inner", + "outer", + "left", + "right", + "cross", + "full", + "on", + "and", + "or", + "not", + "in", + "as", + "order", + "by", + "group", + "having", + "union", + "intersect", + "except", + "limit", + "offset", + "insert", + "into", + "values", + "update", + "set", + "delete", + "create", + "alter", + "drop", + "table", + "index", + "view", + "case", + "when", + "then", + "else", + "end", + "null", + "is", + "like", + "between", + "exists", + "distinct", + "all", + "top", + "with", + "asc", + "desc", + "natural", + "using", +} + + +def get_all_keywords() -> list[str]: + """Get all SQL keywords as a flat list.""" + keywords = [] + for category in SQL_KEYWORDS.values(): + keywords.extend(category) + return list(set(keywords)) + + +def get_all_functions() -> list[str]: + """Get all SQL functions as a flat list.""" + functions = [] + for category in SQL_FUNCTIONS.values(): + functions.extend(category) + return list(set(functions)) + + +def fuzzy_match(text: str, candidates: list[str], max_results: int = 50) -> list[str]: + """Fuzzy match text against candidates. + + Matches if all characters in text appear in candidate in order. + E.g., 'djmi' matches 'django_migrations' + + Args: + text: The text to match + candidates: List of candidate strings + max_results: Maximum number of results to return + + Returns: + List of matching candidates, sorted by match quality + """ + if not text: + return candidates[:max_results] + + text_lower = text.lower() + results: list[tuple[int, int, str]] = [] + + for candidate in candidates: + c_lower = candidate.lower() + + # First check prefix match (higher priority) + if c_lower.startswith(text_lower): + # Score: 0 for exact prefix, length for sorting + results.append((0, len(candidate), candidate)) + continue + + # Fuzzy match: all chars must appear in order + idx = 0 + matched = True + first_match_pos = -1 + + for char in text_lower: + idx = c_lower.find(char, idx) + if idx == -1: + matched = False + break + if first_match_pos == -1: + first_match_pos = idx + idx += 1 + + if matched: + # Score: 1 for fuzzy, then by first match position, then length + results.append((1, first_match_pos * 100 + len(candidate), candidate)) + + # Sort by score tuple and return candidates + results.sort(key=lambda x: (x[0], x[1])) + return [r[2] for r in results[:max_results]] + + +def extract_table_refs(sql: str) -> list[TableRef]: + """Extract table references and aliases from SQL. + + Handles patterns like: + - FROM users + - FROM users u + - FROM users AS u + - JOIN orders o ON ... + - FROM schema.users u + - FROM "quoted_table" (PostgreSQL) + - FROM [bracketed_table] (SQL Server) + - FROM `backtick_table` (MySQL) + - UPDATE users u SET ... + - DELETE FROM users u WHERE ... + + Args: + sql: The SQL text to parse + + Returns: + List of TableRef objects with name, alias, and optional schema + """ + refs: list[TableRef] = [] + + # Pattern to match quoted identifiers: "name", [name], `name`, or unquoted name + ident = r'(?:"([^"]+)"|`([^`]+)`|\[([^\]]+)\]|(\w+))' + + # Pattern for FROM/JOIN + from_join_pattern = ( + r"(?:FROM|JOIN)\s+" + + r"(?:" + ident + r"\.)?" # optional schema + + ident # table name (required) + + r"(?:\s+(?:AS\s+)?(\w+))?" # optional alias + ) + + for match in re.finditer(from_join_pattern, sql, re.IGNORECASE): + groups = match.groups() + schema = next((g for g in groups[0:4] if g is not None), None) + table = next((g for g in groups[4:8] if g is not None), None) + alias = groups[8] + + if alias and alias.lower() in RESERVED_WORDS: + alias = None + + if table: + refs.append(TableRef(name=table, alias=alias, schema=schema)) + + # Pattern for UPDATE table [alias] SET + update_pattern = ( + r"\bUPDATE\s+" + + r"(?:" + ident + r"\.)?" # optional schema + + ident # table name (required) + + r"(?:\s+(?:AS\s+)?(\w+))?" # optional alias + + r"(?=\s+SET\b)" # followed by SET (lookahead) + ) + + for match in re.finditer(update_pattern, sql, re.IGNORECASE): + groups = match.groups() + schema = next((g for g in groups[0:4] if g is not None), None) + table = next((g for g in groups[4:8] if g is not None), None) + alias = groups[8] + + if alias and alias.lower() in RESERVED_WORDS: + alias = None + + if table: + refs.append(TableRef(name=table, alias=alias, schema=schema)) + + return refs + + +def extract_cte_names(sql: str) -> list[str]: + """Extract CTE (Common Table Expression) names from WITH clause. + + Args: + sql: The SQL text to parse + + Returns: + List of CTE names + """ + ctes: list[str] = [] + + pattern = r"\bWITH\s+(.+?)(?=\s+SELECT\b)" + + match = re.search(pattern, sql, re.IGNORECASE | re.DOTALL) + if match: + with_clause = match.group(1) + cte_pattern = r"(\w+)\s+AS\s*\(" + for cte_match in re.finditer(cte_pattern, with_clause, re.IGNORECASE): + ctes.append(cte_match.group(1)) + + return ctes + + +def is_inside_string(sql: str) -> bool: + """Check if the cursor position is inside an unclosed string literal. + + Args: + sql: The SQL text up to cursor position + + Returns: + True if inside a string literal, False otherwise + """ + in_single_quote = False + in_double_quote = False + i = 0 + + while i < len(sql): + char = sql[i] + + if char == "'" and not in_double_quote: + if i + 1 < len(sql) and sql[i + 1] == "'": + i += 2 + continue + in_single_quote = not in_single_quote + elif char == '"' and not in_single_quote: + if i + 1 < len(sql) and sql[i + 1] == '"': + i += 2 + continue + in_double_quote = not in_double_quote + + i += 1 + + return in_single_quote or in_double_quote + + +def get_last_token_info(sql: str) -> tuple[str | None, str | None]: + """Get the last meaningful token and its type using sqlparse. + + Args: + sql: The SQL text to analyze + + Returns: + Tuple of (token_value, token_type_string) + """ + try: + import sqlparse + + parsed = sqlparse.parse(sql) + if not parsed: + return None, None + + tokens = [t for t in parsed[0].flatten() if not t.is_whitespace] + if not tokens: + return None, None + + last = tokens[-1] + ttype = str(last.ttype) if last.ttype else None + return last.value, ttype + except Exception: + return None, None + + +def remove_string_literals(sql: str) -> str: + """Remove string literals from SQL to avoid false matches.""" + result = re.sub(r"'[^']*'", "''", sql) + result = re.sub(r'"[^"]*"', '""', result) + return result + + +def remove_comments(sql: str) -> str: + """Remove SQL comments.""" + result = re.sub(r"--[^\n]*", "", sql) + result = re.sub(r"/\*.*?\*/", "", result, flags=re.DOTALL) + return result + + +def find_context_keyword(sql: str) -> str: + """Find the SQL keyword that provides context for completion. + + This looks for the keyword BEFORE the current partial word being typed. + """ + original_sql = sql + sql = sql.rstrip() + + if sql.endswith(","): + return "," + + ends_with_space = len(original_sql) > len(sql) or (sql and sql[-1] in ",()") + + tokens = re.findall(r"\w+|,", sql) + + if not tokens: + return "" + + if ends_with_space or len(tokens) == 1: + return tokens[-1].lower() + else: + if len(tokens) >= 2: + return tokens[-2].lower() + return tokens[-1].lower() + + +def find_last_keyword(sql: str) -> str: + """Find the last significant SQL keyword or punctuation.""" + sql = sql.rstrip() + + if sql.endswith(","): + return "," + + match = re.search(r"(\w+|\,)\s*$", sql) + if match: + word = match.group(1).lower() + return word + + return "" + + +def find_current_clause(sql: str) -> str: + """Determine which clause the cursor is in. + + Looks for the most recent main SQL clause keyword. + """ + sql_upper = sql.upper() + + clauses = ["SELECT", "FROM", "WHERE", "GROUP BY", "HAVING", "ORDER BY", "ON", "SET"] + join_pattern = r"\b(INNER\s+JOIN|LEFT\s+JOIN|RIGHT\s+JOIN|FULL\s+JOIN|CROSS\s+JOIN|JOIN)\b" + + last_clause = "" + last_pos = -1 + + for clause in clauses: + pattern = r"\b" + clause + r"\b" + for match in re.finditer(pattern, sql_upper): + if match.start() > last_pos: + last_pos = match.start() + last_clause = clause.split()[0].lower() + + for match in re.finditer(join_pattern, sql_upper): + if match.start() > last_pos: + last_pos = match.start() + last_clause = "join" + + return last_clause + + +def get_current_word(sql: str, cursor_pos: int) -> str: + """Get the word currently being typed at cursor position.""" + before_cursor = sql[:cursor_pos] + + if "." in before_cursor: + dot_match = re.search(r"\.(\w*)$", before_cursor) + if dot_match: + return dot_match.group(1) + + match = re.search(r"(\w*)$", before_cursor) + if match: + return match.group(1) + return "" + + +def build_alias_map(refs: list[TableRef], known_tables: list[str]) -> dict[str, str]: + """Build a map of alias -> table name. + + Only includes aliases for tables that exist in known_tables. + """ + known_lower = {t.lower() for t in known_tables} + alias_map: dict[str, str] = {} + + for ref in refs: + if ref.alias and ref.name.lower() in known_lower: + alias_map[ref.alias.lower()] = ref.name + + return alias_map diff --git a/sqlit/sql_completion/create_index.py b/sqlit/sql_completion/create_index.py new file mode 100644 index 00000000..3ab81ea0 --- /dev/null +++ b/sqlit/sql_completion/create_index.py @@ -0,0 +1,53 @@ +"""CREATE INDEX statement context detection.""" + +from __future__ import annotations + +import re + +from .core import Suggestion, SuggestionType + + +def get_create_index_completions( + before_cursor: str, tables: list[str], columns: dict[str, list[str]] +) -> list[str] | None: + """Get completions specific to CREATE INDEX context. + + Handles: + - CREATE [UNIQUE] INDEX name ON → tables + - CREATE INDEX name ON table ( → columns + - CREATE INDEX name ON table (col1, → more columns + + Args: + before_cursor: SQL text before cursor position + tables: List of available table names + columns: Dict mapping table names to column lists + + Returns: + List of completions, or None if not in CREATE INDEX context + """ + # Check for CREATE INDEX pattern (with optional UNIQUE) + if not re.search(r"\bCREATE\s+(?:UNIQUE\s+)?INDEX\b", before_cursor, re.IGNORECASE): + return None + + # Check for ON table ( → suggest columns + # Pattern: ON table_name ( with optional columns already listed + table_paren_match = re.search( + r"\bON\s+(\w+)\s*\(\s*(?:[\w\s,]*,\s*)?\w*$", + before_cursor, + re.IGNORECASE, + ) + if table_paren_match: + table_name = table_paren_match.group(1).lower() + if table_name in columns: + return columns[table_name] + return [] + + # Check for ON → suggest tables + if re.search(r"\bON\s+\w*$", before_cursor, re.IGNORECASE): + return tables + + # After CREATE INDEX name, suggest ON keyword + if re.search(r"\bCREATE\s+(?:UNIQUE\s+)?INDEX\s+\w+\s+\w*$", before_cursor, re.IGNORECASE): + return ["ON"] + + return None diff --git a/sqlit/sql_completion/create_table.py b/sqlit/sql_completion/create_table.py new file mode 100644 index 00000000..027fccfc --- /dev/null +++ b/sqlit/sql_completion/create_table.py @@ -0,0 +1,217 @@ +"""CREATE TABLE statement context detection.""" + +from __future__ import annotations + +import re + +from .core import Suggestion, SuggestionType + +# SQL Data types for column definitions +SQL_DATA_TYPES = [ + # Numeric + "INT", + "INTEGER", + "BIGINT", + "SMALLINT", + "TINYINT", + "DECIMAL", + "NUMERIC", + "FLOAT", + "REAL", + "DOUBLE", + "DOUBLE PRECISION", + "MONEY", + "SMALLMONEY", + # String + "VARCHAR", + "CHAR", + "TEXT", + "NVARCHAR", + "NCHAR", + "NTEXT", + # Binary + "BINARY", + "VARBINARY", + "BLOB", + "BYTEA", + # Date/Time + "DATE", + "TIME", + "DATETIME", + "DATETIME2", + "DATETIMEOFFSET", + "SMALLDATETIME", + "TIMESTAMP", + "TIMESTAMPTZ", + "INTERVAL", + # Boolean + "BOOLEAN", + "BOOL", + "BIT", + # Other + "UUID", + "UNIQUEIDENTIFIER", + "JSON", + "JSONB", + "XML", + "CLOB", + "SERIAL", + "BIGSERIAL", + "SMALLSERIAL", + "IDENTITY", +] + +# Column constraints +SQL_CONSTRAINTS = [ + "PRIMARY KEY", + "NOT NULL", + "NULL", + "UNIQUE", + "DEFAULT", + "CHECK", + "REFERENCES", + "AUTO_INCREMENT", + "AUTOINCREMENT", + "GENERATED", +] + +# Table-level constraints +SQL_TABLE_CONSTRAINTS = [ + "PRIMARY KEY", + "FOREIGN KEY", + "UNIQUE", + "CHECK", + "CONSTRAINT", + "INDEX", +] + + +class SuggestionTypeExtended: + """Extended suggestion types for DDL.""" + DATA_TYPE = "DATA_TYPE" + CONSTRAINT = "CONSTRAINT" + TABLE_CONSTRAINT = "TABLE_CONSTRAINT" + + +def get_create_table_context(before_cursor: str) -> list[Suggestion] | None: + """Detect CREATE TABLE-specific context and return suggestions. + + Handles: + - CREATE TABLE name ( → nothing (user defines column name) + - CREATE TABLE name (col → data types + - CREATE TABLE name (col TYPE → constraints + - CREATE TABLE name (col TYPE, → nothing (new column name) + - FOREIGN KEY (col) REFERENCES → table names + - REFERENCES table ( → column names + + Args: + before_cursor: SQL text before cursor position + + Returns: + List of suggestions if in CREATE TABLE context, None otherwise + """ + # Check if we're in a CREATE TABLE statement + if not re.search(r"\bCREATE\s+TABLE\b", before_cursor, re.IGNORECASE): + return None + + # Check if we're inside the column definition parentheses + create_match = re.search( + r"\bCREATE\s+TABLE\s+\w+\s*\((.*)$", + before_cursor, + re.IGNORECASE | re.DOTALL, + ) + if not create_match: + return None + + inside_parens = create_match.group(1) + + # Check for REFERENCES table ( → suggest columns + ref_table_match = re.search( + r"\bREFERENCES\s+(\w+)\s*\(\s*\w*$", + inside_parens, + re.IGNORECASE, + ) + if ref_table_match: + table_name = ref_table_match.group(1) + return [Suggestion(type=SuggestionType.ALIAS_COLUMN, table_scope=table_name)] + + # Check for FOREIGN KEY ... REFERENCES → suggest tables + if re.search(r"\bREFERENCES\s+\w*$", inside_parens, re.IGNORECASE): + return [Suggestion(type=SuggestionType.TABLE)] + + # Check if we just typed a data type and need constraints + # Pattern: column_name TYPE (possibly with size) and cursor after space + if re.search(r"\b(?:" + "|".join(SQL_DATA_TYPES) + r")(?:\s*\([^)]*\))?\s+\w*$", inside_parens, re.IGNORECASE): + # After a data type, suggest constraints + return [Suggestion(type=SuggestionType.KEYWORD)] # Will be handled specially + + # Check if we're right after a column name (no type yet) + # Pattern: comma or opening paren, then word, then space + if re.search(r"(?:,|\()\s*\w+\s+\w*$", inside_parens, re.IGNORECASE): + # After column name, suggest data types + return [Suggestion(type=SuggestionType.KEYWORD)] # Will be handled specially + + return None + + +def get_create_table_completions(before_cursor: str, tables: list[str], columns: dict[str, list[str]]) -> list[str] | None: + """Get completions specific to CREATE TABLE context. + + Args: + before_cursor: SQL text before cursor position + tables: List of available table names + columns: Dict mapping table names to column lists + + Returns: + List of completions, or None if not in CREATE TABLE context + """ + # Check if we're in a CREATE TABLE statement + if not re.search(r"\bCREATE\s+TABLE\b", before_cursor, re.IGNORECASE): + return None + + # Check if we're inside the column definition parentheses + create_match = re.search( + r"\bCREATE\s+TABLE\s+\w+\s*\((.*)$", + before_cursor, + re.IGNORECASE | re.DOTALL, + ) + if not create_match: + return None + + inside_parens = create_match.group(1) + + # Check for REFERENCES table ( → suggest columns from that table + ref_table_match = re.search( + r"\bREFERENCES\s+(\w+)\s*\(\s*\w*$", + inside_parens, + re.IGNORECASE, + ) + if ref_table_match: + table_name = ref_table_match.group(1).lower() + if table_name in columns: + return columns[table_name] + return [] + + # Check for FOREIGN KEY ... REFERENCES → suggest tables + if re.search(r"\bREFERENCES\s+\w*$", inside_parens, re.IGNORECASE): + return tables + + # Check if after a data type → suggest constraints + if re.search(r"\b(?:" + "|".join(SQL_DATA_TYPES) + r")(?:\s*\([^)]*\))?\s+\w*$", inside_parens, re.IGNORECASE): + return SQL_CONSTRAINTS + + # Check if right after column name → suggest data types + # Match start of content or after comma, then column name, then space + if re.search(r"(?:^|,)\s*\w+\s+\w*$", inside_parens, re.IGNORECASE): + return SQL_DATA_TYPES + + # At start of new column definition (empty or after comma with space) + if not inside_parens.strip() or re.search(r",\s*$", inside_parens): + # User needs to type column name, no suggestions + return [] + + # Check for table-level constraint context + if re.search(r",\s*(?:PRIMARY|FOREIGN|UNIQUE|CHECK|CONSTRAINT)\s+\w*$", inside_parens, re.IGNORECASE): + return SQL_TABLE_CONSTRAINTS + + return None diff --git a/sqlit/sql_completion/create_view.py b/sqlit/sql_completion/create_view.py new file mode 100644 index 00000000..607a9c25 --- /dev/null +++ b/sqlit/sql_completion/create_view.py @@ -0,0 +1,52 @@ +"""CREATE VIEW statement context detection.""" + +from __future__ import annotations + +import re + +from .core import Suggestion, SuggestionType + + +def get_create_view_completions( + before_cursor: str, tables: list[str], columns: dict[str, list[str]] +) -> list[str] | None: + """Get completions specific to CREATE VIEW context. + + Handles: + - CREATE [OR REPLACE] VIEW name AS → SELECT keyword + - After AS SELECT, delegates to normal SELECT handling + + Args: + before_cursor: SQL text before cursor position + tables: List of available table names + columns: Dict mapping table names to column lists + + Returns: + List of completions, or None if not in CREATE VIEW context + Returns None after AS SELECT to let normal SELECT handling take over + """ + # Check for CREATE VIEW pattern (with optional OR REPLACE) + create_view_match = re.search( + r"\bCREATE\s+(?:OR\s+REPLACE\s+)?VIEW\b", + before_cursor, + re.IGNORECASE, + ) + if not create_view_match: + return None + + after_create_view = before_cursor[create_view_match.end():] + + # Check if we're after AS SELECT - let normal SELECT handling take over + if re.search(r"\bAS\s+SELECT\b", after_create_view, re.IGNORECASE): + return None + + # Check for AS → suggest SELECT + if re.search(r"\bAS\s+\w*$", after_create_view, re.IGNORECASE): + return ["SELECT"] + + # After view name, suggest AS + if re.search(r"^\s+\w+\s+\w*$", after_create_view): + return ["AS"] + + # Still typing view name or haven't started + return None diff --git a/sqlit/sql_completion/delete.py b/sqlit/sql_completion/delete.py new file mode 100644 index 00000000..28e4dc13 --- /dev/null +++ b/sqlit/sql_completion/delete.py @@ -0,0 +1,85 @@ +"""DELETE statement context detection.""" + +from __future__ import annotations + +import re + +from .core import Suggestion, SuggestionType, RESERVED_WORDS, TableRef + + +def get_delete_context(before_cursor: str) -> list[Suggestion] | None: + """Detect DELETE-specific context and return suggestions. + + Handles: + - DELETE FROM table WHERE → columns for that table (only right after WHERE or AND/OR) + + Args: + before_cursor: SQL text before cursor position + + Returns: + List of suggestions if in DELETE WHERE context, None otherwise + """ + # Don't handle if we're inside a subquery (has unclosed parenthesis after DELETE) + delete_pos = before_cursor.upper().find("DELETE") + if delete_pos != -1: + after_delete = before_cursor[delete_pos:] + open_parens = after_delete.count("(") - after_delete.count(")") + if open_parens > 0: + # Inside a subquery, let normal context detection handle it + return None + + # Pattern: DELETE FROM table [alias] WHERE ... + delete_match = re.search( + r"\bDELETE\s+FROM\s+(\w+)(?:\s+(\w+))?\s+WHERE\b", + before_cursor, + re.IGNORECASE, + ) + if delete_match: + table_name = delete_match.group(1) + # Only suggest columns right after WHERE, AND, or OR (not after a column name) + # This allows operator detection to work after column names + if re.search(r"\b(WHERE|AND|OR)\s+\w*$", before_cursor, re.IGNORECASE): + return [Suggestion(type=SuggestionType.ALIAS_COLUMN, table_scope=table_name)] + + return None + + +def extract_delete_table_refs(sql: str) -> list[TableRef]: + """Extract table references from DELETE statements. + + Handles: + - DELETE FROM users + - DELETE FROM users u WHERE ... + - DELETE u FROM users u JOIN ... (SQL Server) + + Args: + sql: The SQL text to parse + + Returns: + List of TableRef objects + """ + refs: list[TableRef] = [] + + # Pattern: DELETE FROM table [alias] + ident = r'(?:"([^"]+)"|`([^`]+)`|\[([^\]]+)\]|(\w+))' + + delete_from_pattern = ( + r"\bDELETE\s+FROM\s+" + + r"(?:" + ident + r"\.)?" # optional schema + + ident # table name (required) + + r"(?:\s+(?:AS\s+)?(\w+))?" # optional alias + ) + + for match in re.finditer(delete_from_pattern, sql, re.IGNORECASE): + groups = match.groups() + schema = next((g for g in groups[0:4] if g is not None), None) + table = next((g for g in groups[4:8] if g is not None), None) + alias = groups[8] + + if alias and alias.lower() in RESERVED_WORDS: + alias = None + + if table: + refs.append(TableRef(name=table, alias=alias, schema=schema)) + + return refs diff --git a/sqlit/sql_completion/drop.py b/sqlit/sql_completion/drop.py new file mode 100644 index 00000000..a57fa8d2 --- /dev/null +++ b/sqlit/sql_completion/drop.py @@ -0,0 +1,108 @@ +"""DROP statement context detection.""" + +from __future__ import annotations + +import re + +from .core import Suggestion, SuggestionType + +# Objects that can be dropped +DROP_OBJECTS = [ + "TABLE", + "VIEW", + "INDEX", + "DATABASE", + "SCHEMA", + "PROCEDURE", + "FUNCTION", + "TRIGGER", + "SEQUENCE", + "TYPE", + "CONSTRAINT", +] + + +def get_drop_context(before_cursor: str) -> list[Suggestion] | None: + """Detect DROP-specific context and return suggestions. + + Handles: + - DROP → object types (TABLE, VIEW, INDEX, etc.) + - DROP TABLE → table names + - DROP VIEW → view names (we'll suggest tables as we don't track views separately) + - DROP INDEX → index names + - DROP TABLE IF EXISTS → table names + + Args: + before_cursor: SQL text before cursor position + + Returns: + List of suggestions if in DROP context, None otherwise + """ + # Check for DROP pattern + drop_match = re.search(r"\bDROP\s+", before_cursor, re.IGNORECASE) + if not drop_match: + return None + + after_drop = before_cursor[drop_match.end():] + + # If nothing after DROP, suggest object types + if not after_drop.strip() or re.match(r"^\w*$", after_drop.strip()): + return [Suggestion(type=SuggestionType.KEYWORD)] # Will return DROP_OBJECTS + + # Check for DROP TABLE [IF EXISTS] → suggest tables + if re.search(r"\bTABLE\s+(?:IF\s+EXISTS\s+)?\w*$", after_drop, re.IGNORECASE): + return [Suggestion(type=SuggestionType.TABLE)] + + # Check for DROP VIEW [IF EXISTS] → suggest tables/views + if re.search(r"\bVIEW\s+(?:IF\s+EXISTS\s+)?\w*$", after_drop, re.IGNORECASE): + return [Suggestion(type=SuggestionType.TABLE)] # Views treated as tables + + # Check for DROP INDEX [IF EXISTS] → we don't track indexes, so suggest nothing special + if re.search(r"\bINDEX\s+(?:IF\s+EXISTS\s+)?\w*$", after_drop, re.IGNORECASE): + return [Suggestion(type=SuggestionType.KEYWORD)] # Could track indexes in future + + # Check for DROP DATABASE/SCHEMA → suggest nothing (not tracked) + if re.search(r"\b(?:DATABASE|SCHEMA)\s+(?:IF\s+EXISTS\s+)?\w*$", after_drop, re.IGNORECASE): + return [Suggestion(type=SuggestionType.KEYWORD)] + + # Check for DROP PROCEDURE/FUNCTION → suggest procedures + if re.search(r"\b(?:PROCEDURE|FUNCTION)\s+(?:IF\s+EXISTS\s+)?\w*$", after_drop, re.IGNORECASE): + return [Suggestion(type=SuggestionType.PROCEDURE)] + + return None + + +def get_drop_completions(before_cursor: str, tables: list[str], procedures: list[str] | None = None) -> list[str] | None: + """Get completions specific to DROP context. + + Args: + before_cursor: SQL text before cursor position + tables: List of available table names + procedures: List of stored procedure names + + Returns: + List of completions, or None if not in DROP context + """ + drop_match = re.search(r"\bDROP\s+", before_cursor, re.IGNORECASE) + if not drop_match: + return None + + after_drop = before_cursor[drop_match.end():] + + # Check for DROP TABLE [IF EXISTS] → suggest tables (must check before generic pattern) + if re.search(r"^TABLE\s+(?:IF\s+EXISTS\s+)?\w*$", after_drop, re.IGNORECASE): + return tables + + # Check for DROP VIEW [IF EXISTS] → suggest tables (views mixed with tables in our schema) + if re.search(r"^VIEW\s+(?:IF\s+EXISTS\s+)?\w*$", after_drop, re.IGNORECASE): + return tables + + # Check for DROP PROCEDURE/FUNCTION → suggest procedures + if re.search(r"^(?:PROCEDURE|FUNCTION)\s+(?:IF\s+EXISTS\s+)?\w*$", after_drop, re.IGNORECASE): + return procedures or [] + + # If nothing after DROP or just typing object type, suggest object types + if not after_drop.strip() or re.match(r"^\w*$", after_drop.strip()): + return DROP_OBJECTS + + return None diff --git a/sqlit/sql_completion/insert.py b/sqlit/sql_completion/insert.py new file mode 100644 index 00000000..2f2c431a --- /dev/null +++ b/sqlit/sql_completion/insert.py @@ -0,0 +1,35 @@ +"""INSERT statement context detection.""" + +from __future__ import annotations + +import re + +from .core import Suggestion, SuggestionType + + +def get_insert_context(before_cursor: str) -> list[Suggestion] | None: + """Detect INSERT-specific context and return suggestions. + + Handles: + - INSERT INTO table ( → columns for that table + - INSERT INTO table (col1, → more columns + + Args: + before_cursor: SQL text before cursor position + + Returns: + List of suggestions if in INSERT context, None otherwise + """ + # Pattern: INSERT INTO table_name ( with optional columns and commas + insert_match = re.search( + r"\bINSERT\s+INTO\s+(\w+)\s*\([^)]*$", + before_cursor, + re.IGNORECASE, + ) + if insert_match: + # Check we're not inside VALUES clause + if not re.search(r"\bVALUES\s*\(", before_cursor, re.IGNORECASE): + table_name = insert_match.group(1) + return [Suggestion(type=SuggestionType.ALIAS_COLUMN, table_scope=table_name)] + + return None diff --git a/sqlit/sql_completion/truncate.py b/sqlit/sql_completion/truncate.py new file mode 100644 index 00000000..bfcace91 --- /dev/null +++ b/sqlit/sql_completion/truncate.py @@ -0,0 +1,41 @@ +"""TRUNCATE TABLE statement context detection.""" + +from __future__ import annotations + +import re + +from .core import Suggestion, SuggestionType + + +def get_truncate_completions( + before_cursor: str, tables: list[str] +) -> list[str] | None: + """Get completions specific to TRUNCATE context. + + Handles: + - TRUNCATE → TABLE keyword or tables directly + - TRUNCATE TABLE → table names + + Args: + before_cursor: SQL text before cursor position + tables: List of available table names + + Returns: + List of completions, or None if not in TRUNCATE context + """ + # Check for TRUNCATE pattern + truncate_match = re.search(r"\bTRUNCATE\s+", before_cursor, re.IGNORECASE) + if not truncate_match: + return None + + after_truncate = before_cursor[truncate_match.end():] + + # Check for TRUNCATE TABLE → suggest tables + if re.search(r"^TABLE\s+\w*$", after_truncate, re.IGNORECASE): + return tables + + # Just TRUNCATE or partial word → suggest TABLE keyword and tables + if not after_truncate.strip() or re.match(r"^\w*$", after_truncate.strip()): + return ["TABLE"] + tables + + return None diff --git a/sqlit/sql_completion/update.py b/sqlit/sql_completion/update.py new file mode 100644 index 00000000..4936cb54 --- /dev/null +++ b/sqlit/sql_completion/update.py @@ -0,0 +1,38 @@ +"""UPDATE statement context detection.""" + +from __future__ import annotations + +import re + +from .core import Suggestion, SuggestionType + + +def get_update_context(before_cursor: str) -> list[Suggestion] | None: + """Detect UPDATE-specific context and return suggestions. + + Handles: + - UPDATE table SET → columns for that table + - UPDATE table SET col = value, → more columns + + Args: + before_cursor: SQL text before cursor position + + Returns: + List of suggestions if in UPDATE SET context, None otherwise + """ + # Pattern: UPDATE table [alias] SET ... (not after WHERE or FROM) + update_set_match = re.search( + r"\bUPDATE\s+(\w+)(?:\s+\w+)?\s+SET\b", + before_cursor, + re.IGNORECASE, + ) + if update_set_match: + # Check we're not in WHERE or FROM clause (SQL Server UPDATE...FROM...JOIN syntax) + if not re.search(r"\b(WHERE|FROM)\b", before_cursor[update_set_match.end():], re.IGNORECASE): + table_name = update_set_match.group(1) + # Check if we're after SET (not just typing "SET") + set_pos = before_cursor.upper().rfind("SET") + if set_pos != -1 and len(before_cursor) > set_pos + 3: + return [Suggestion(type=SuggestionType.ALIAS_COLUMN, table_scope=table_name)] + + return None diff --git a/sqlit/state_machine.py b/sqlit/state_machine.py index c8ea501f..f558fd74 100644 --- a/sqlit/state_machine.py +++ b/sqlit/state_machine.py @@ -607,16 +607,14 @@ class TreeOnDatabaseState(State): help_category = "Explorer" def _setup_actions(self) -> None: - self.allows("use_database", key="u", label="Use as default", help="Set as default database") + pass # Expanding a database now sets it as active automatically def get_display_bindings(self, app: SSMSTUI) -> tuple[list[DisplayBinding], list[DisplayBinding]]: left: list[DisplayBinding] = [] seen: set[str] = set() - left.append(DisplayBinding(key="enter", label="Expand", action="toggle_node")) + left.append(DisplayBinding(key="enter", label="Use database", action="toggle_node")) seen.add("toggle_node") - left.append(DisplayBinding(key="u", label="Use as default", action="use_database")) - seen.add("use_database") left.append(DisplayBinding(key="f", label="Refresh", action="refresh_tree")) seen.add("refresh_tree") diff --git a/sqlit/ui/mixins/autocomplete.py b/sqlit/ui/mixins/autocomplete.py index d751a8ce..e0107c32 100644 --- a/sqlit/ui/mixins/autocomplete.py +++ b/sqlit/ui/mixins/autocomplete.py @@ -2,6 +2,7 @@ from __future__ import annotations +import re from typing import Any from textual.timer import Timer @@ -9,6 +10,16 @@ from textual.worker import Worker from ..protocols import AppProtocol +from ...sql_completion import ( + SQL_OPERATORS, + SuggestionType, + extract_table_refs, + fuzzy_match, + get_all_functions, + get_all_keywords, + get_completions, + get_context, +) SPINNER_FRAMES = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"] @@ -20,6 +31,10 @@ class AutocompleteMixin: _schema_spinner_timer: Timer | None = None _schema_cache: dict[str, Any] = {} _table_metadata: dict[str, tuple[str, str, str | None]] = {} + _autocomplete_debounce_timer: Timer | None = None + # Shared cache for raw DB objects - used by both tree and autocomplete + # Structure: {db_name: {"tables": [(schema, name), ...], "views": [...], "procedures": [...]}} + _db_object_cache: dict[str, dict[str, list[Any]]] = {} def _run_db_call(self: AppProtocol, fn: Any, *args: Any, **kwargs: Any) -> Any: session = getattr(self, "_session", None) @@ -27,62 +42,83 @@ def _run_db_call(self: AppProtocol, fn: Any, *args: Any, **kwargs: Any) -> Any: return session.executor.submit(fn, *args, **kwargs).result() return fn(*args, **kwargs) - def _get_word_before_cursor(self, text: str, cursor_pos: int) -> tuple[str, str]: - """Get the current word being typed and the context keyword before it.""" - if cursor_pos <= 0 or cursor_pos > len(text): - return "", "" - + def _get_current_word(self, text: str, cursor_pos: int) -> str: + """Get the word currently being typed at cursor position.""" before_cursor = text[:cursor_pos] - word_start = cursor_pos - while word_start > 0 and before_cursor[word_start - 1] not in " \t\n,()[]": - word_start -= 1 - current_word = before_cursor[word_start:cursor_pos] - - if "." in current_word: - parts = current_word.rsplit(".", 1) - table_name = parts[0].strip("[]") - return parts[1] if len(parts) > 1 else "", f"column:{table_name}" - - context_text = before_cursor[:word_start].upper().strip() - - table_keywords = ["FROM", "JOIN", "INTO", "UPDATE", "TABLE"] - for kw in table_keywords: - if context_text.endswith(kw): - return current_word, "table" - - if context_text.endswith("EXEC") or context_text.endswith("EXECUTE"): - return current_word, "procedure" - - if context_text.endswith("SELECT") or context_text.endswith(","): - return current_word, "column_or_table" - - return current_word, "" - - def _get_autocomplete_suggestions(self: AppProtocol, word: str, context: str) -> list[str]: - """Get autocomplete suggestions based on context.""" - suggestions = [] - - if context == "table": - suggestions = self._schema_cache["tables"] + self._schema_cache["views"] - elif context == "procedure": - suggestions = self._schema_cache["procedures"] - elif context.startswith("column:"): - table_name = context.split(":", 1)[1].lower() - if table_name not in self._schema_cache["columns"]: - self._load_columns_for_table(table_name) - suggestions = self._schema_cache["columns"].get(table_name, []) - elif context == "column_or_table": - all_columns = [] - for cols in self._schema_cache["columns"].values(): - all_columns.extend(cols) - suggestions = list(set(all_columns)) + self._schema_cache["tables"] - - if word: - word_lower = word.lower() - suggestions = [s for s in suggestions if s.lower().startswith(word_lower)] + # Handle table.column case - get just the part after dot + if "." in before_cursor: + dot_match = re.search(r"\.(\w*)$", before_cursor) + if dot_match: + return dot_match.group(1) + + # Get word before cursor + match = re.search(r"(\w*)$", before_cursor) + if match: + return match.group(1) + return "" + + def _build_alias_map(self: AppProtocol, text: str) -> dict[str, str]: + """Build a map of alias -> table name from the SQL text.""" + table_refs = extract_table_refs(text) + known_tables = set(t.lower() for t in self._schema_cache.get("tables", [])) + known_tables.update(t.lower() for t in self._schema_cache.get("views", [])) + + alias_map: dict[str, str] = {} + for ref in table_refs: + if ref.alias and ref.name.lower() in known_tables: + alias_map[ref.alias.lower()] = ref.name + return alias_map + + def _get_autocomplete_suggestions(self: AppProtocol, text: str, cursor_pos: int) -> list[str]: + """Get autocomplete suggestions using the SQL completion engine.""" + # Build schema data for get_completions + tables = self._schema_cache.get("tables", []) + self._schema_cache.get("views", []) + columns = self._schema_cache.get("columns", {}) + procedures = self._schema_cache.get("procedures", []) + + # First check if we need to lazy-load columns before calling get_completions + suggestions = get_context(text, cursor_pos) + if suggestions: + alias_map = self._build_alias_map(text) + table_refs = extract_table_refs(text) + loading = getattr(self, "_columns_loading", set()) + + for suggestion in suggestions: + if suggestion.type == SuggestionType.COLUMN: + # Check if any tables need column loading + for ref in table_refs: + table_key = ref.name.lower() + if table_key not in columns and table_key not in loading: + self._load_columns_for_table(table_key) + return ["Loading..."] + elif table_key in loading: + return ["Loading..."] + + elif suggestion.type == SuggestionType.ALIAS_COLUMN: + scope = suggestion.table_scope + if scope: + scope_lower = scope.lower() + table_key = alias_map.get(scope_lower, scope_lower) + + if table_key not in columns and table_key not in loading: + self._load_columns_for_table(table_key) + return ["Loading..."] + elif table_key in loading: + return ["Loading..."] + + # Now call get_completions with all available data + results = get_completions( + text, + cursor_pos, + tables, + columns, + procedures, + include_keywords=True, + include_functions=True, + ) - return suggestions[:50] + return results def _load_columns_for_table(self: AppProtocol, table_name: str) -> None: """Lazy load columns for a specific table (async via worker).""" @@ -109,8 +145,11 @@ def work() -> None: column_names = [] else: try: + db_arg = database + if hasattr(self, "_get_metadata_db_arg"): + db_arg = self._get_metadata_db_arg(database) columns = self._run_db_call( - adapter.get_columns, connection, actual_table_name, database, schema_name + adapter.get_columns, connection, actual_table_name, db_arg, schema_name ) column_names = [c.name for c in columns] except Exception: @@ -133,6 +172,60 @@ def _on_autocomplete_columns_loaded( self._schema_cache["columns"][table_name] = column_names self._schema_cache["columns"][actual_table_name.lower()] = column_names + # Refresh autocomplete if visible (replaces "Loading..." with actual columns) + if self._autocomplete_visible: + text = self.query_input.text + cursor_loc = self.query_input.cursor_location + cursor_pos = self._location_to_offset(text, cursor_loc) + current_word = self._get_current_word(text, cursor_pos) + suggestions = self._get_autocomplete_suggestions(text, cursor_pos) + if suggestions: + self._show_autocomplete(suggestions, current_word) + else: + self._hide_autocomplete() + + def _has_tables_needing_columns(self: AppProtocol, text: str) -> bool: + """Check if there are tables in the query that need column loading.""" + if not text.strip(): + return False + + table_refs = extract_table_refs(text) + columns_cache = self._schema_cache.get("columns", {}) + loading = getattr(self, "_columns_loading", set()) + + for ref in table_refs: + table_key = ref.name.lower() + if table_key in columns_cache or table_key in loading: + continue + if table_key in self._table_metadata: + return True + return False + + def _preload_columns_for_query(self: AppProtocol) -> None: + """Preload columns for all tables found in the current query (runs during idle).""" + if not self.current_connection or not self.current_adapter: + return + + text = self.query_input.text + if not text.strip(): + return + + # Extract table references from the query + table_refs = extract_table_refs(text) + columns_cache = self._schema_cache.get("columns", {}) + loading = getattr(self, "_columns_loading", set()) + + for ref in table_refs: + table_key = ref.name.lower() + # Skip if already loaded or currently loading + if table_key in columns_cache or table_key in loading: + continue + # Skip if not a known table + if table_key not in self._table_metadata: + continue + # Queue column loading + self._load_columns_for_table(table_key) + def _show_autocomplete(self: AppProtocol, suggestions: list[str], filter_text: str) -> None: """Show the autocomplete dropdown with suggestions.""" @@ -209,15 +302,27 @@ def _offset_to_location(self, text: str, offset: int) -> tuple[int, int]: def on_text_area_changed(self: AppProtocol, event: TextArea.Changed) -> None: """Handle text changes in the query editor for autocomplete.""" from ...widgets import VimMode + from ...idle_scheduler import on_user_activity + + # Track user activity for idle scheduler + on_user_activity() if event.text_area.id != "query-input": return + # Mark that text just changed so selection_changed knows to ignore cursor movement + self._text_just_changed = True + if self._autocomplete_just_applied: self._autocomplete_just_applied = False self._hide_autocomplete() return + # Suppress autocomplete after Enter dismisses dropdown (newline shouldn't re-trigger) + if getattr(self, "_suppress_autocomplete_on_newline", False): + self._suppress_autocomplete_on_newline = False + return + if self.vim_mode != VimMode.INSERT: self._hide_autocomplete() return @@ -225,25 +330,55 @@ def on_text_area_changed(self: AppProtocol, event: TextArea.Changed) -> None: if not self.current_connection: return - text = event.text_area.text - cursor_loc = event.text_area.cursor_location + # Cancel any pending debounce timer + if self._autocomplete_debounce_timer is not None: + self._autocomplete_debounce_timer.stop() + self._autocomplete_debounce_timer = None + + # Debounce: wait 100ms before triggering autocomplete + self._autocomplete_debounce_timer = self.set_timer( + 0.1, lambda: self._trigger_autocomplete(event.text_area) + ) + + def _trigger_autocomplete(self: AppProtocol, text_area: TextArea) -> None: + """Actually trigger autocomplete after debounce delay.""" + from ...idle_scheduler import get_idle_scheduler, Priority + + self._autocomplete_debounce_timer = None + + text = text_area.text + cursor_loc = text_area.cursor_location cursor_pos = self._location_to_offset(text, cursor_loc) - word, context = self._get_word_before_cursor(text, cursor_pos) + # Get current word for display purposes + current_word = self._get_current_word(text, cursor_pos) - if context: - is_column_context = context.startswith("column:") - if is_column_context or len(word) >= 1: - suggestions = self._get_autocomplete_suggestions(word, context) - if suggestions: - self._show_autocomplete(suggestions, word) - else: - self._hide_autocomplete() - else: - self._hide_autocomplete() + # Get suggestions using the SQL completion engine + suggestions = self._get_autocomplete_suggestions(text, cursor_pos) + + if suggestions: + self._show_autocomplete(suggestions, current_word) else: self._hide_autocomplete() + # Queue column preloading for tables in the query (runs during idle) + # Only queue if there are actually tables that need column loading + scheduler = get_idle_scheduler() + if scheduler and self._has_tables_needing_columns(text): + # Cancel any previous preload job - we'll queue a fresh one + scheduler.cancel_all(name="preload-columns") + scheduler.request_idle_callback( + self._preload_columns_for_query, + priority=Priority.LOW, + name="preload-columns", + ) + + def on_descendant_blur(self: AppProtocol, event: Any) -> None: + """Handle blur events - don't hide autocomplete on window focus loss.""" + # Only hide if focus moves to another widget within the app (not window blur) + # We want autocomplete to stay visible when user moves mouse to another window + pass + def action_autocomplete_next(self: AppProtocol) -> None: """Move to next autocomplete suggestion.""" if self._autocomplete_visible: @@ -258,9 +393,29 @@ def action_autocomplete_close(self: AppProtocol) -> None: """Close autocomplete dropdown without exiting insert mode.""" self._hide_autocomplete() + def on_text_area_selection_changed(self: AppProtocol, event: Any) -> None: + """Hide autocomplete when cursor moves without text change.""" + if not self._autocomplete_visible: + return + + if getattr(event, "text_area", None) and getattr(event.text_area, "id", None) != "query-input": + return + + # If text just changed, this cursor movement is from typing - ignore it + if getattr(self, "_text_just_changed", False): + self._text_just_changed = False + return + + # Cursor moved without text change (arrow keys, click, etc.) - hide autocomplete + self._hide_autocomplete() + def on_key(self: AppProtocol, event: Any) -> None: """Handle key events for autocomplete navigation.""" from ...widgets import VimMode + from ...idle_scheduler import on_user_activity + + # Track user activity for idle scheduler + on_user_activity() # Handle autocomplete navigation if not self._autocomplete_visible: @@ -276,18 +431,19 @@ def on_key(self: AppProtocol, event: Any) -> None: dropdown.move_selection(-1) event.prevent_default() event.stop() - elif event.key == "tab": + elif event.key in ("tab", "enter"): if self.vim_mode == VimMode.INSERT and dropdown.filtered_items: self._apply_autocomplete() event.prevent_default() event.stop() elif event.key == "escape": - self._hide_autocomplete() + # Hide autocomplete AND exit insert mode (go to normal mode) + self.action_exit_insert_mode() event.prevent_default() event.stop() def _load_schema_cache(self: AppProtocol) -> None: - """Load database schema for autocomplete asynchronously.""" + """Load database schema for autocomplete using threaded workers.""" if not self.current_connection or not self.current_config or not self.current_adapter: return @@ -303,17 +459,352 @@ def _load_schema_cache(self: AppProtocol) -> None: "procedures": [], } self._table_metadata = {} + self._columns_loading = set() # Clear any in-progress column loads + self._db_object_cache = {} # Clear shared object cache # Start schema indexing spinner self._start_schema_spinner() - # Run schema loading in background thread - self._schema_worker = self.run_worker( - self._load_schema_cache_async(), - name="schema_cache_loading", - exclusive=True, + # Load schema directly using threaded workers (no idle scheduler needed) + self._load_schema_directly() + + def _load_schema_directly(self: AppProtocol) -> None: + """Load schema using threaded workers - runs immediately without idle scheduler.""" + adapter = self.current_adapter + connection = self.current_connection + config = self.current_config + + if not adapter or not connection or not config: + self._stop_schema_spinner() + return + + # Track pending database loads + self._schema_pending_dbs: list[str | None] = [] + self._schema_total_jobs = 0 + self._schema_completed_jobs = 0 + + if adapter.supports_multiple_databases: + db = None + if hasattr(self, "_get_effective_database"): + db = self._get_effective_database() + if db: + # Single database specified - load immediately + self._on_databases_loaded([db]) + elif adapter.supports_cross_database_queries: + # Need to fetch database list - offload to thread + def work() -> None: + try: + all_dbs = adapter.get_databases(connection) + system_dbs = {s.lower() for s in adapter.system_databases} + databases = [d for d in all_dbs if d.lower() not in system_dbs] + self.call_from_thread(self._on_databases_loaded, databases) + except Exception as e: + self.call_from_thread(self._on_databases_error, e) + + self.run_worker(work, thread=True, name="get-databases") + else: + self._stop_schema_spinner() + else: + # No multiple databases - just proceed with None + self._on_databases_loaded([None]) + + def _load_schema_via_idle_scheduler(self: AppProtocol, scheduler: Any) -> None: + """Load schema using idle scheduler for smoother UI.""" + from ...idle_scheduler import Priority + + adapter = self.current_adapter + connection = self.current_connection + config = self.current_config + + if not adapter or not connection or not config: + self._stop_schema_spinner() + return + + # Track pending database loads + self._schema_pending_dbs: list[str | None] = [] + self._schema_total_jobs = 0 + self._schema_completed_jobs = 0 + # Store scheduler reference for use in callbacks + self._schema_scheduler = scheduler + + def get_databases_job() -> None: + """First job: dispatch thread to get list of databases.""" + if adapter.supports_multiple_databases: + db = None + if hasattr(self, "_get_effective_database"): + db = self._get_effective_database() + if db: + # Single database specified - no need for DB call + self._on_databases_loaded([db]) + elif adapter.supports_cross_database_queries: + # Need to fetch database list - offload to thread + def work() -> None: + try: + all_dbs = adapter.get_databases(connection) + system_dbs = {s.lower() for s in adapter.system_databases} + databases = [d for d in all_dbs if d.lower() not in system_dbs] + self.call_from_thread(self._on_databases_loaded, databases) + except Exception as e: + self.call_from_thread(self._on_databases_error, e) + + self.run_worker(work, thread=True, name="get-databases") + else: + self._stop_schema_spinner() + else: + # No multiple databases - just proceed with None + self._on_databases_loaded([None]) + + # Queue the first job with high priority + scheduler.request_idle_callback( + get_databases_job, + priority=Priority.HIGH, + name="schema-load", ) + def _on_databases_loaded(self: AppProtocol, databases: list) -> None: + """Handle databases list loaded - spawn threaded workers for each database.""" + adapter = self.current_adapter + + if not adapter: + self._stop_schema_spinner() + return + + self._schema_pending_dbs = databases + self._schema_total_jobs = len(databases) * 3 # tables, views, procedures per db + + # Spawn workers directly - they're threaded so won't block + for database in databases: + self._load_tables_job(database) + self._load_views_job(database) + if adapter.supports_stored_procedures: + self._load_procedures_job(database) + else: + self._schema_completed_jobs += 1 # Skip procedures + + def _on_databases_error(self: AppProtocol, error: Exception) -> None: + """Handle error getting databases list.""" + self.log.error(f"Error getting databases: {error}") + self._stop_schema_spinner() + + def _load_tables_job(self: AppProtocol, database: str | None) -> None: + """Idle job: load tables for a single database (dispatches to thread).""" + adapter = self.current_adapter + connection = self.current_connection + + if not adapter or not connection: + self._schema_job_complete() + return + + cache_key = database or "__default__" + + # Check shared cache first (may have been populated by tree expansion) + if cache_key in self._db_object_cache and "tables" in self._db_object_cache[cache_key]: + self._process_tables_result(self._db_object_cache[cache_key]["tables"], database, cache_key) + return + + # Offload DB call to thread + def work() -> None: + try: + db_arg = database + if hasattr(self, "_get_metadata_db_arg"): + db_arg = self._get_metadata_db_arg(database) + tables = adapter.get_tables(connection, db_arg) + # Store in shared cache and process on main thread + self.call_from_thread(self._on_tables_loaded, tables, database, cache_key) + except Exception as e: + self.call_from_thread(self._on_tables_error, e, database) + + self.run_worker(work, thread=True, name=f"load-tables-{cache_key}") + + def _on_tables_loaded(self: AppProtocol, tables: list, database: str | None, cache_key: str) -> None: + """Handle tables loaded from thread.""" + if cache_key not in self._db_object_cache: + self._db_object_cache[cache_key] = {} + self._db_object_cache[cache_key]["tables"] = tables + self._process_tables_result(tables, database, cache_key) + + def _on_tables_error(self: AppProtocol, error: Exception, database: str | None) -> None: + """Handle tables load error from thread.""" + self.log.error(f"Error loading tables for {database}: {error}") + self._schema_job_complete() + + def _process_tables_result(self: AppProtocol, tables: list, database: str | None, cache_key: str) -> None: + """Process tables result on main thread.""" + adapter = self.current_adapter + if not adapter: + self._schema_job_complete() + return + + try: + single_db = len(getattr(self, "_schema_pending_dbs", [None])) == 1 + + for schema_name, table_name in tables: + if single_db: + self._schema_cache["tables"].append(table_name) + else: + quoted_db = adapter.quote_identifier(database) if database else "" + quoted_schema = adapter.quote_identifier(schema_name) + quoted_table = adapter.quote_identifier(table_name) + if database: + full_name = f"{quoted_db}.{quoted_schema}.{quoted_table}" + else: + full_name = f"{quoted_schema}.{quoted_table}" + self._schema_cache["tables"].append(full_name) + + # Store metadata for column loading + display_name = adapter.format_table_name(schema_name, table_name) + self._table_metadata[display_name.lower()] = (schema_name, table_name, database) + self._table_metadata[table_name.lower()] = (schema_name, table_name, database) + if database: + self._table_metadata[f"{database}.{table_name}".lower()] = (schema_name, table_name, database) + if not single_db: + self._table_metadata[full_name.lower()] = (schema_name, table_name, database) + + except Exception as e: + self.log.error(f"Error processing tables for {database}: {e}") + + self._schema_job_complete() + + def _load_views_job(self: AppProtocol, database: str | None) -> None: + """Idle job: load views for a single database (dispatches to thread).""" + adapter = self.current_adapter + connection = self.current_connection + + if not adapter or not connection: + self._schema_job_complete() + return + + cache_key = database or "__default__" + + # Check shared cache first (may have been populated by tree expansion) + if cache_key in self._db_object_cache and "views" in self._db_object_cache[cache_key]: + self._process_views_result(self._db_object_cache[cache_key]["views"], database, cache_key) + return + + # Offload DB call to thread + def work() -> None: + try: + db_arg = database + if hasattr(self, "_get_metadata_db_arg"): + db_arg = self._get_metadata_db_arg(database) + views = adapter.get_views(connection, db_arg) + self.call_from_thread(self._on_views_loaded, views, database, cache_key) + except Exception as e: + self.call_from_thread(self._on_views_error, e, database) + + self.run_worker(work, thread=True, name=f"load-views-{cache_key}") + + def _on_views_loaded(self: AppProtocol, views: list, database: str | None, cache_key: str) -> None: + """Handle views loaded from thread.""" + if cache_key not in self._db_object_cache: + self._db_object_cache[cache_key] = {} + self._db_object_cache[cache_key]["views"] = views + self._process_views_result(views, database, cache_key) + + def _on_views_error(self: AppProtocol, error: Exception, database: str | None) -> None: + """Handle views load error from thread.""" + self.log.error(f"Error loading views for {database}: {error}") + self._schema_job_complete() + + def _process_views_result(self: AppProtocol, views: list, database: str | None, cache_key: str) -> None: + """Process views result on main thread.""" + adapter = self.current_adapter + if not adapter: + self._schema_job_complete() + return + + try: + single_db = len(getattr(self, "_schema_pending_dbs", [None])) == 1 + + for schema_name, view_name in views: + if single_db: + self._schema_cache["views"].append(view_name) + else: + quoted_db = adapter.quote_identifier(database) if database else "" + quoted_schema = adapter.quote_identifier(schema_name) + quoted_view = adapter.quote_identifier(view_name) + if database: + full_name = f"{quoted_db}.{quoted_schema}.{quoted_view}" + else: + full_name = f"{quoted_schema}.{quoted_view}" + self._schema_cache["views"].append(full_name) + + # Store metadata for column loading + display_name = adapter.format_table_name(schema_name, view_name) + self._table_metadata[display_name.lower()] = (schema_name, view_name, database) + self._table_metadata[view_name.lower()] = (schema_name, view_name, database) + if database: + self._table_metadata[f"{database}.{view_name}".lower()] = (schema_name, view_name, database) + if not single_db: + self._table_metadata[full_name.lower()] = (schema_name, view_name, database) + + except Exception as e: + self.log.error(f"Error processing views for {database}: {e}") + + self._schema_job_complete() + + def _load_procedures_job(self: AppProtocol, database: str | None) -> None: + """Idle job: load procedures for a single database (dispatches to thread).""" + adapter = self.current_adapter + connection = self.current_connection + + if not adapter or not connection: + self._schema_job_complete() + return + + cache_key = database or "__default__" + + # Check shared cache first (may have been populated by tree expansion) + if cache_key in self._db_object_cache and "procedures" in self._db_object_cache[cache_key]: + self._process_procedures_result(self._db_object_cache[cache_key]["procedures"], cache_key) + return + + # Offload DB call to thread + def work() -> None: + try: + db_arg = database + if hasattr(self, "_get_metadata_db_arg"): + db_arg = self._get_metadata_db_arg(database) + procedures = adapter.get_procedures(connection, db_arg) + self.call_from_thread(self._on_procedures_loaded, procedures, database, cache_key) + except Exception as e: + self.call_from_thread(self._on_procedures_error, e, database) + + self.run_worker(work, thread=True, name=f"load-procedures-{cache_key}") + + def _on_procedures_loaded(self: AppProtocol, procedures: list, database: str | None, cache_key: str) -> None: + """Handle procedures loaded from thread.""" + if cache_key not in self._db_object_cache: + self._db_object_cache[cache_key] = {} + self._db_object_cache[cache_key]["procedures"] = procedures + self._process_procedures_result(procedures, cache_key) + + def _on_procedures_error(self: AppProtocol, error: Exception, database: str | None) -> None: + """Handle procedures load error from thread.""" + self.log.error(f"Error loading procedures for {database}: {error}") + self._schema_job_complete() + + def _process_procedures_result(self: AppProtocol, procedures: list, cache_key: str) -> None: + """Process procedures result on main thread.""" + try: + self._schema_cache["procedures"].extend(procedures) + except Exception as e: + self.log.error(f"Error processing procedures: {e}") + + self._schema_job_complete() + + def _schema_job_complete(self: AppProtocol) -> None: + """Called when a schema loading job completes.""" + self._schema_completed_jobs = getattr(self, "_schema_completed_jobs", 0) + 1 + total = getattr(self, "_schema_total_jobs", 1) + + if self._schema_completed_jobs >= total: + # All jobs done - deduplicate and finalize + self._schema_cache["tables"] = list(dict.fromkeys(self._schema_cache["tables"])) + self._schema_cache["views"] = list(dict.fromkeys(self._schema_cache["views"])) + self._schema_cache["procedures"] = list(dict.fromkeys(self._schema_cache["procedures"])) + self._stop_schema_spinner() + def _start_schema_spinner(self: AppProtocol) -> None: """Start the schema indexing spinner animation.""" self._schema_indexing = True @@ -348,9 +839,10 @@ def action_cancel_schema_indexing(self: AppProtocol) -> None: self.notify("Schema indexing cancelled") async def _load_schema_cache_async(self: AppProtocol) -> None: - """Load database schema asynchronously in a worker thread. + """Load database schema asynchronously. - Only loads tables, views, and procedures. Columns are loaded lazily. + Only loads tables, views, and procedures. + Columns are lazy-loaded on demand when user types `table.` """ import asyncio @@ -377,49 +869,86 @@ async def run_db_call(fn: Any, *args: Any, **kwargs: Any) -> Any: return await asyncio.to_thread(fn, *args, **kwargs) try: - # Get database list in thread databases: list[str | None] if adapter.supports_multiple_databases: - db = config.database - if db and db.lower() not in ("", "master"): + db = None + if hasattr(self, "_get_effective_database"): + db = self._get_effective_database() + if db: + # Active/default database is set - only load that one databases = [db] - else: + elif adapter.supports_cross_database_queries: + # No default database - load all non-system databases all_dbs = await run_db_call(adapter.get_databases, connection) - system_dbs = {"master", "tempdb", "model", "msdb"} + system_dbs = {s.lower() for s in adapter.system_databases} databases = [d for d in all_dbs if d.lower() not in system_dbs] + else: + databases = [] else: databases = [None] for database in databases: try: - # Get tables in thread (NO columns - lazy loaded) - tables = await run_db_call(adapter.get_tables, connection, database) + # Get tables + db_arg = database + if hasattr(self, "_get_metadata_db_arg"): + db_arg = self._get_metadata_db_arg(database) + tables = await run_db_call(adapter.get_tables, connection, db_arg) for schema_name, table_name in tables: + # Use simple name if we have a default database, full qualifier otherwise + if len(databases) == 1: + # Single database - use simple table name + schema_cache["tables"].append(table_name) + else: + # Multiple databases - use full qualifier [db].[schema].[table] + quoted_db = adapter.quote_identifier(database) if database else "" + quoted_schema = adapter.quote_identifier(schema_name) + quoted_table = adapter.quote_identifier(table_name) + if database: + full_name = f"{quoted_db}.{quoted_schema}.{quoted_table}" + else: + full_name = f"{quoted_schema}.{quoted_table}" + schema_cache["tables"].append(full_name) + # Keep metadata for column loading (multiple keys for flexible lookup) display_name = adapter.format_table_name(schema_name, table_name) - schema_cache["tables"].append(display_name) - # Store metadata for lazy column loading table_metadata[display_name.lower()] = (schema_name, table_name, database) table_metadata[table_name.lower()] = (schema_name, table_name, database) if database: - full_name = f"{adapter.quote_identifier(database)}.{adapter.quote_identifier(display_name)}" - schema_cache["tables"].append(full_name) - table_metadata[full_name.lower()] = (schema_name, table_name, database) + table_metadata[f"{database}.{table_name}".lower()] = (schema_name, table_name, database) + # Also store with full quoted name for [db].[schema].[table] lookups + if len(databases) > 1: + table_metadata[full_name.lower()] = (schema_name, table_name, database) - # Get views in thread (NO columns - lazy loaded) - views = await run_db_call(adapter.get_views, connection, database) + # Get views + views = await run_db_call(adapter.get_views, connection, db_arg) for schema_name, view_name in views: + # Use simple name if we have a default database, full qualifier otherwise + if len(databases) == 1: + # Single database - use simple view name + schema_cache["views"].append(view_name) + else: + # Multiple databases - use full qualifier [db].[schema].[view] + quoted_db = adapter.quote_identifier(database) if database else "" + quoted_schema = adapter.quote_identifier(schema_name) + quoted_view = adapter.quote_identifier(view_name) + if database: + full_name = f"{quoted_db}.{quoted_schema}.{quoted_view}" + else: + full_name = f"{quoted_schema}.{quoted_view}" + schema_cache["views"].append(full_name) + # Keep metadata for column loading (multiple keys for flexible lookup) display_name = adapter.format_table_name(schema_name, view_name) - schema_cache["views"].append(display_name) - # Store metadata for lazy column loading table_metadata[display_name.lower()] = (schema_name, view_name, database) table_metadata[view_name.lower()] = (schema_name, view_name, database) if database: - full_name = f"{adapter.quote_identifier(database)}.{adapter.quote_identifier(display_name)}" - schema_cache["views"].append(full_name) - table_metadata[full_name.lower()] = (schema_name, view_name, database) + table_metadata[f"{database}.{view_name}".lower()] = (schema_name, view_name, database) + # Also store with full quoted name for [db].[schema].[view] lookups + if len(databases) > 1: + table_metadata[full_name.lower()] = (schema_name, view_name, database) + # Get procedures if adapter.supports_stored_procedures: - procedures = await run_db_call(adapter.get_procedures, connection, database) + procedures = await run_db_call(adapter.get_procedures, connection, db_arg) schema_cache["procedures"].extend(procedures) except Exception: @@ -430,7 +959,7 @@ async def run_db_call(fn: Any, *args: Any, **kwargs: Any) -> Any: schema_cache["views"] = list(dict.fromkeys(schema_cache["views"])) schema_cache["procedures"] = list(dict.fromkeys(schema_cache["procedures"])) - # Update cache (we're back on main thread after await) + # Update cache - columns will be lazy-loaded when needed self._update_schema_cache(schema_cache, table_metadata) except Exception as e: diff --git a/sqlit/ui/mixins/connection.py b/sqlit/ui/mixins/connection.py index 2c25daa4..bc9fed3b 100644 --- a/sqlit/ui/mixins/connection.py +++ b/sqlit/ui/mixins/connection.py @@ -18,7 +18,7 @@ def _needs_db_password(config: ConnectionConfig) -> bool: """Check if the connection needs a database password prompt. Returns True if password is None (not set) and the database type uses passwords. - Note: Empty string "" means explicitly set to empty (no prompt needed). + An empty string password ("") is considered explicitly set and does not need a prompt. """ from ...db.providers import is_file_based @@ -26,7 +26,7 @@ def _needs_db_password(config: ConnectionConfig) -> bool: if is_file_based(config.db_type): return False - # Check if password is not set (None means prompt needed) + # Only prompt if password is None (not set), not for empty string (explicitly empty) return config.password is None @@ -34,7 +34,7 @@ def _needs_ssh_password(config: ConnectionConfig) -> bool: """Check if the connection needs an SSH password prompt. Returns True if SSH is enabled with password auth and password is None (not set). - Note: Empty string "" means explicitly set to empty (no prompt needed). + An empty string password ("") is considered explicitly set and does not need a prompt. """ if not config.ssh_enabled: return False @@ -42,6 +42,7 @@ def _needs_ssh_password(config: ConnectionConfig) -> bool: if config.ssh_auth_type != "password": return False + # Only prompt if password is None (not set), not for empty string (explicitly empty) return config.ssh_password is None @@ -214,12 +215,18 @@ def on_success(session: ConnectionSession) -> None: self.current_ssh_tunnel = session.tunnel is_saved = any(c.name == config.name for c in self.connections) self._direct_connection_config = None if is_saved else config + self._active_database = None + reconnected = False self.refresh_tree() self.call_after_refresh(self._select_connected_node) - self._load_schema_cache() + if not reconnected: + self._load_schema_cache() self._update_status_bar() self._update_section_labels() + # Update database labels to show star on active database + if hasattr(self, "_update_database_labels"): + self.call_after_refresh(self._update_database_labels) if self.current_adapter: warnings = self.current_adapter.get_post_connect_warnings(config) for message in warnings: @@ -267,6 +274,8 @@ def _disconnect_silent(self: AppProtocol) -> None: self.current_adapter = None self.current_ssh_tunnel = None self._direct_connection_config = None + self._active_database = None + self._clear_query_target_database() self.refresh_tree() self._update_section_labels() @@ -287,6 +296,29 @@ def action_disconnect(self: AppProtocol) -> None: self.status_bar.update("Disconnected") self.notify("Disconnected") + def _get_effective_database(self: AppProtocol) -> str | None: + """Return the active database for the current connection context.""" + if not self.current_adapter or not self.current_config: + return None + if self.current_adapter.supports_cross_database_queries: + db_name = getattr(self, "_active_database", None) or self.current_config.database + return db_name or None + db_name = self.current_config.database + return db_name or None + + def _get_metadata_db_arg(self: AppProtocol, database: str | None) -> str | None: + """Return database arg for metadata calls when cross-db queries are supported.""" + if not database or not self.current_adapter: + return None + if self.current_adapter.supports_cross_database_queries: + return database + return None + + def _clear_query_target_database(self: AppProtocol) -> None: + """Clear any pending per-query database override.""" + if hasattr(self, "_query_target_database"): + self._query_target_database = None + def action_new_connection(self: AppProtocol) -> None: from ..screens import ConnectionScreen diff --git a/sqlit/ui/mixins/query.py b/sqlit/ui/mixins/query.py index 7edccab1..241e6fa2 100644 --- a/sqlit/ui/mixins/query.py +++ b/sqlit/ui/mixins/query.py @@ -137,6 +137,21 @@ async def _run_query_async(self: AppProtocol, query: str, keep_insert_mode: bool self._stop_query_spinner() return + # If we have a target database from clicking a table in the tree, + # use that database for the query execution (needed for Azure SQL) + target_db = getattr(self, "_query_target_database", None) + if target_db and target_db != config.database: + config = adapter.apply_database_override(config, target_db) + # Clear target database after use - it's only for the auto-generated query + self._query_target_database = None + + # Apply active database to query execution (from USE statement or 'u' key) + active_db = None + if hasattr(self, "_get_effective_database"): + active_db = self._get_effective_database() + if active_db and active_db != config.database and not target_db: + config = adapter.apply_database_override(config, active_db) + # Handle USE database statements db_name = parse_use_statement(query) if db_name is not None: diff --git a/sqlit/ui/mixins/tree.py b/sqlit/ui/mixins/tree.py index d51599c1..5f807eeb 100644 --- a/sqlit/ui/mixins/tree.py +++ b/sqlit/ui/mixins/tree.py @@ -165,7 +165,10 @@ def get_conn_label(config: Any, connected: Any = False) -> str: try: if adapter.supports_multiple_databases: specific_db = self.current_config.database - if specific_db and specific_db.lower() not in ("", "master"): + # Show a single database view when a specific database was configured. + # Otherwise, show the Databases folder to browse all databases. + show_single_db = specific_db and specific_db.lower() not in ("", "master") + if show_single_db: self._add_database_object_nodes(active_node, specific_db) active_node.expand() else: @@ -173,10 +176,12 @@ def get_conn_label(config: Any, connected: Any = False) -> str: dbs_node.data = FolderNode(folder_type="databases") databases = self._run_db_call(adapter.get_databases, self.current_connection) - default_db = self.current_config.database if self.current_config else None + active_db = None + if hasattr(self, "_get_effective_database"): + active_db = self._get_effective_database() for db_name in databases: - # Show default database with star and green text - if default_db and db_name.lower() == default_db.lower(): + # Show active database with star and green text + if active_db and db_name.lower() == active_db.lower(): db_label = f"[#4ADE80]* {escape_markup(db_name)}[/]" else: db_label = escape_markup(db_name) @@ -295,6 +300,10 @@ def on_tree_node_expanded(self: AppProtocol, event: Tree.NodeExpanded) -> None: data = node.data + # When a database node is expanded, ensure we're connected to it + if self._get_node_kind(node) == "database": + self._ensure_database_connection(data.name) + # Skip if already has children (not just loading placeholder) children = list(node.children) if children: @@ -315,6 +324,10 @@ def on_tree_node_expanded(self: AppProtocol, event: Tree.NodeExpanded) -> None: # Handle table/view column expansion if self._get_node_kind(node) in ("table", "view"): + # Ensure we're connected to the right database before loading + target_db = data.database + if target_db and not self._ensure_database_connection(target_db): + return # Switch failed self._loading_nodes.add(node_path) loading_node = node.add_leaf("[dim italic]Loading...[/]") loading_node.data = LoadingNode() @@ -323,6 +336,10 @@ def on_tree_node_expanded(self: AppProtocol, event: Tree.NodeExpanded) -> None: # Handle folder expansion (database can be None for single-db adapters) if self._get_node_kind(node) == "folder": + # Ensure we're connected to the right database before loading + target_db = data.database + if target_db and not self._ensure_database_connection(target_db): + return # Switch failed self._loading_nodes.add(node_path) loading_node = node.add_leaf("[dim italic]Loading...[/]") loading_node.data = LoadingNode() @@ -343,7 +360,10 @@ def work() -> None: else: adapter = self._session.adapter conn = self._session.connection - columns = self._run_db_call(adapter.get_columns, conn, obj_name, db_name, schema_name) + db_arg = db_name + if hasattr(self, "_get_metadata_db_arg"): + db_arg = self._get_metadata_db_arg(db_name) + columns = self._run_db_call(adapter.get_columns, conn, obj_name, db_arg, schema_name) # Update UI from worker thread self.call_from_thread(self._on_columns_loaded, node, db_name, schema_name, obj_name, columns) @@ -378,6 +398,7 @@ def _load_folder_async(self: AppProtocol, node: Any, data: FolderNode) -> None: """Spawn worker to load folder contents (tables/views/indexes/triggers/sequences/procedures).""" folder_type = data.folder_type db_name = data.database + cache_key = db_name or "__default__" def work() -> None: """Run in worker thread.""" @@ -387,16 +408,40 @@ def work() -> None: else: adapter = self._session.adapter conn = self._session.connection + db_arg = db_name + if hasattr(self, "_get_metadata_db_arg"): + db_arg = self._get_metadata_db_arg(db_name) + + # Check shared cache first for tables/views/procedures + obj_cache = getattr(self, "_db_object_cache", {}) if folder_type == "tables": - items = [("table", s, t) for s, t in self._run_db_call(adapter.get_tables, conn, db_name)] + if cache_key in obj_cache and "tables" in obj_cache[cache_key]: + raw_data = obj_cache[cache_key]["tables"] + else: + raw_data = self._run_db_call(adapter.get_tables, conn, db_arg) + # Store in shared cache + if cache_key not in obj_cache: + obj_cache[cache_key] = {} + obj_cache[cache_key]["tables"] = raw_data + self._db_object_cache = obj_cache + items = [("table", s, t) for s, t in raw_data] elif folder_type == "views": - items = [("view", s, v) for s, v in self._run_db_call(adapter.get_views, conn, db_name)] + if cache_key in obj_cache and "views" in obj_cache[cache_key]: + raw_data = obj_cache[cache_key]["views"] + else: + raw_data = self._run_db_call(adapter.get_views, conn, db_arg) + # Store in shared cache + if cache_key not in obj_cache: + obj_cache[cache_key] = {} + obj_cache[cache_key]["views"] = raw_data + self._db_object_cache = obj_cache + items = [("view", s, v) for s, v in raw_data] elif folder_type == "indexes": if adapter.supports_indexes: items = [ ("index", i.name, i.table_name) - for i in self._run_db_call(adapter.get_indexes, conn, db_name) + for i in self._run_db_call(adapter.get_indexes, conn, db_arg) ] else: items = [] @@ -404,7 +449,7 @@ def work() -> None: if adapter.supports_triggers: items = [ ("trigger", t.name, t.table_name) - for t in self._run_db_call(adapter.get_triggers, conn, db_name) + for t in self._run_db_call(adapter.get_triggers, conn, db_arg) ] else: items = [] @@ -412,16 +457,22 @@ def work() -> None: if adapter.supports_sequences: items = [ ("sequence", s.name, "") - for s in self._run_db_call(adapter.get_sequences, conn, db_name) + for s in self._run_db_call(adapter.get_sequences, conn, db_arg) ] else: items = [] elif folder_type == "procedures": if adapter.supports_stored_procedures: - items = [ - ("procedure", "", p) - for p in self._run_db_call(adapter.get_procedures, conn, db_name) - ] + if cache_key in obj_cache and "procedures" in obj_cache[cache_key]: + raw_data = obj_cache[cache_key]["procedures"] + else: + raw_data = self._run_db_call(adapter.get_procedures, conn, db_arg) + # Store in shared cache + if cache_key not in obj_cache: + obj_cache[cache_key] = {} + obj_cache[cache_key]["procedures"] = raw_data + self._db_object_cache = obj_cache + items = [("procedure", "", p) for p in raw_data] else: items = [] else: @@ -430,7 +481,11 @@ def work() -> None: # Update UI from worker thread self.call_from_thread(self._on_folder_loaded, node, db_name, folder_type, items) except Exception as e: - self.call_from_thread(self._on_tree_load_error, node, f"Error loading: {e}") + # If we have a target database, try reconnecting as fallback (handles Azure SQL etc.) + if db_name: + self.call_from_thread(self._fallback_reconnect_and_retry, node, data, db_name, e) + else: + self.call_from_thread(self._on_tree_load_error, node, f"Error loading: {e}") self.run_worker(work, name=f"load-folder-{folder_type}", thread=True, exclusive=False) @@ -536,6 +591,32 @@ def _on_tree_load_error(self: AppProtocol, node: Any, error_message: str) -> Non self.notify(escape_markup(error_message), severity="error") + def _fallback_reconnect_and_retry( + self: AppProtocol, node: Any, data: FolderNode, db_name: str, original_error: Exception + ) -> None: + """Try reconnecting to database and retry loading. Show original error if this also fails.""" + node_path = self._get_node_path(node) + self._loading_nodes.discard(node_path) + + # Remove loading placeholder + for child in list(node.children): + if self._get_node_kind(child) == "loading": + child.remove() + + # Try to reconnect + try: + self._reconnect_to_database(db_name) + except Exception: + # Reconnect failed - show original error + self.notify(escape_markup(f"Error loading: {original_error}"), severity="error") + return + + # Reconnect succeeded - retry loading + self._loading_nodes.add(node_path) + loading_node = node.add_leaf("[dim italic]Loading...[/]") + loading_node.data = LoadingNode() + self._load_folder_async(node, data) + def on_tree_node_selected(self: AppProtocol, event: Tree.NodeSelected) -> None: """Handle tree node selection (double-click/enter).""" # Ignore selection events when tree filter is active - the filter captures @@ -563,6 +644,11 @@ def on_tree_node_highlighted(self: AppProtocol, event: Tree.NodeHighlighted) -> def action_refresh_tree(self: AppProtocol) -> None: """Refresh the explorer.""" + # Clear shared object cache so fresh data is fetched + self._db_object_cache = {} + # Clear column cache too so columns are re-fetched + if hasattr(self, "_schema_cache") and "columns" in self._schema_cache: + self._schema_cache["columns"] = {} self.refresh_tree() self.notify("Refreshed") @@ -604,8 +690,11 @@ def action_select_table(self: AppProtocol) -> None: if self._get_node_kind(node) in ("table", "view"): # Store table info for edit_cell action try: + db_arg = data.database + if hasattr(self, "_get_metadata_db_arg"): + db_arg = self._get_metadata_db_arg(data.database) columns = self._session.adapter.get_columns( - self._session.connection, data.name, data.database, data.schema + self._session.connection, data.name, db_arg, data.schema ) self._last_query_table = { "database": data.database, @@ -617,6 +706,8 @@ def action_select_table(self: AppProtocol) -> None: self._last_query_table = None self.query_input.text = self.current_adapter.build_select_query(data.name, 100, data.database, data.schema) + # Set target database for query execution (needed for Azure SQL) + self._query_target_database = data.database self.action_execute_query() return @@ -641,8 +732,11 @@ def _show_index_info(self: AppProtocol, data: IndexNode) -> None: return try: + db_arg = data.database + if hasattr(self, "_get_metadata_db_arg"): + db_arg = self._get_metadata_db_arg(data.database) info = self._session.adapter.get_index_definition( - self._session.connection, data.name, data.table_name, data.database + self._session.connection, data.name, data.table_name, db_arg ) self._display_object_info("Index", info) except Exception as e: @@ -654,8 +748,11 @@ def _show_trigger_info(self: AppProtocol, data: TriggerNode) -> None: return try: + db_arg = data.database + if hasattr(self, "_get_metadata_db_arg"): + db_arg = self._get_metadata_db_arg(data.database) info = self._session.adapter.get_trigger_definition( - self._session.connection, data.name, data.table_name, data.database + self._session.connection, data.name, data.table_name, db_arg ) self._display_object_info("Trigger", info) except Exception as e: @@ -667,8 +764,11 @@ def _show_sequence_info(self: AppProtocol, data: SequenceNode) -> None: return try: + db_arg = data.database + if hasattr(self, "_get_metadata_db_arg"): + db_arg = self._get_metadata_db_arg(data.database) info = self._session.adapter.get_sequence_definition( - self._session.connection, data.name, data.database + self._session.connection, data.name, db_arg ) self._display_object_info("Sequence", info) except Exception as e: @@ -707,32 +807,140 @@ def _display_object_info(self: AppProtocol, object_type: str, info: dict) -> Non if definition: self.query_input.text = f"/*\n{definition}\n*/" - def set_default_database(self: AppProtocol, db_name: str) -> None: - """Set the default database for the current connection. + def _ensure_database_connection(self: AppProtocol, target_db: str) -> bool: + """Ensure we're connected to the target database, switching if needed. + + For adapters that don't support cross-database queries (PostgreSQL, etc.), + this will switch the connection if we're not already connected to the + target database. + + Args: + target_db: The database name we need to be connected to. + + Returns: + True if we're connected to the target database (or adapter supports + cross-db queries), False if switch failed. + """ + if not self.current_adapter or not self.current_config: + return False + + # For cross-db adapters, try USE approach first (no reconnection needed). + # Note: While MSSQL generally supports cross-database queries, some variants + # like Azure SQL have restrictions. If USE fails, we fall back to reconnection. + if self.current_adapter.supports_cross_database_queries: + current_active = getattr(self, "_active_database", None) + if not current_active or current_active.lower() != target_db.lower(): + try: + self.set_default_database(target_db) + except Exception: + # USE approach failed - fall back to reconnection + self._reconnect_to_database(target_db) + return True + + # For non-cross-db adapters, check if already connected to target database + current_db = self.current_config.database + if current_db and current_db.lower() == target_db.lower(): + return True + + # Need to reconnect - set_default_database handles this + self.set_default_database(target_db) + + # Verify switch succeeded + return ( + self.current_config.database is not None + and self.current_config.database.lower() == target_db.lower() + ) + + def _reconnect_to_database(self: AppProtocol, db_name: str) -> None: + """Reconnect to a different database without re-rendering the tree. + + Used for PostgreSQL and other databases that don't support cross-database + queries. Creates a new connection to the specified database while keeping + the tree structure intact. + """ + if not self._session: + return + + if hasattr(self, "_clear_query_target_database"): + self._clear_query_target_database() + + try: + self._session.switch_database(db_name) + + # Update app state to match session + self.current_config = self._session.config + self.current_connection = self._session.connection + + # Update UI + self.notify(f"Switched to database: {db_name}") + self._update_status_bar() + self._update_database_labels() + + # Clear caches and reload schema for autocomplete + self._db_object_cache = {} + self._load_schema_cache() + + except Exception as e: + self.notify(f"Failed to connect to {db_name}: {e}", severity="error") + + def set_default_database(self: AppProtocol, db_name: str | None) -> None: + """Set or clear the active database for the current connection. This is the shared function used by both the USE query handler and the explorer 'Use as default' action. + For databases that support cross-database queries (SQL Server, MySQL, etc.), + this just sets _active_database so queries use the right context. + + For databases that don't support cross-database queries (PostgreSQL, etc.), + this will reconnect to the selected database since each connection is + database-specific. + Args: - db_name: The database name to set as default. + db_name: The database name to set as active, or None to clear. """ - from dataclasses import replace - - if not self.current_config: + if not self.current_config or not self.current_adapter: self.notify("Not connected", severity="error") return - self.current_config = replace(self.current_config, database=db_name) - self.notify(f"Switched to database: {db_name}") + if hasattr(self, "_clear_query_target_database"): + self._clear_query_target_database() + + # Check if adapter supports cross-database queries + if not self.current_adapter.supports_cross_database_queries and db_name: + # For PostgreSQL, CockroachDB, etc. - need to reconnect to the new database + # Check if we're already connected to this database + current_db = self.current_config.database + if current_db and current_db.lower() == db_name.lower(): + # Already connected to this database, just update UI + self._active_database = db_name + self._update_status_bar() + self._update_database_labels() + return + + # Reconnect to the new database without re-rendering the tree + self._reconnect_to_database(db_name) + return + + # For databases that support cross-database queries, just update the active database + self._active_database = db_name + if db_name: + self.notify(f"Switched to database: {db_name}") + else: + self.notify("Cleared default database") self._update_status_bar() self._update_database_labels() + # Reload schema cache for autocomplete with new database context + self._load_schema_cache() def _update_database_labels(self: AppProtocol) -> None: - """Update database node labels to show the default database with a star.""" - if not self.current_config: + """Update database node labels to show the active database with a star.""" + if not self.current_config or not self.current_adapter: return - default_db = self.current_config.database + active_db = None + if hasattr(self, "_get_effective_database"): + active_db = self._get_effective_database() # Find the Databases folder and update labels for conn_node in self.object_tree.root.children: @@ -750,7 +958,8 @@ def _update_database_labels(self: AppProtocol) -> None: for db_node in child.children: if self._get_node_kind(db_node) == "database": db_name = db_node.data.name - if default_db and db_name.lower() == default_db.lower(): + is_active = active_db and db_name.lower() == active_db.lower() + if is_active: db_node.set_label(f"[#4ADE80]* {escape_markup(db_name)}[/]") else: db_node.set_label(escape_markup(db_name)) @@ -758,15 +967,23 @@ def _update_database_labels(self: AppProtocol) -> None: break def action_use_database(self: AppProtocol) -> None: - """Set the selected database as the default for the current connection.""" + """Toggle the selected database as the default for the current connection.""" node = self.object_tree.cursor_node if not node or self._get_node_kind(node) != "database": return - if not self.current_connection: + if not self.current_connection or not self.current_config: self.notify("Not connected", severity="error") return db_name = node.data.name - self.set_default_database(db_name) + current_active = None + if hasattr(self, "_get_effective_database"): + current_active = self._get_effective_database() + + # Toggle: if already active, clear it; otherwise set it + if current_active and current_active.lower() == db_name.lower(): + self.set_default_database(None) + else: + self.set_default_database(db_name) diff --git a/sqlit/ui/mixins/ui_navigation.py b/sqlit/ui/mixins/ui_navigation.py index 11a00973..4bcafd40 100644 --- a/sqlit/ui/mixins/ui_navigation.py +++ b/sqlit/ui/mixins/ui_navigation.py @@ -276,6 +276,10 @@ def _update_status_bar(self: AppProtocol) -> None: launch_str = f"[dim]Launched in {launch_ms:.0f}ms[/]" if show_launch else "" launch_plain = f"Launched in {launch_ms:.0f}ms" if show_launch else "" + # Combine right-side content + right_str = launch_str + right_plain = launch_plain + if notification: # Normal/warning notifications on right side import re @@ -300,7 +304,7 @@ def _update_status_bar(self: AppProtocol) -> None: status.update(f"{left_content}{' ' * gap}{notif_str}") else: status.update(f"{left_content} {notif_str}") - elif launch_str: + elif right_str: import re left_plain = re.sub(r"\[.*?\]", "", left_content) @@ -309,14 +313,55 @@ def _update_status_bar(self: AppProtocol) -> None: except Exception: total_width = 80 - gap = total_width - len(left_plain) - len(launch_plain) + gap = total_width - len(left_plain) - len(right_plain) if gap > 2: - status.update(f"{left_content}{' ' * gap}{launch_str}") + status.update(f"{left_content}{' ' * gap}{right_str}") else: - status.update(f"{left_content} {launch_str}") + status.update(f"{left_content} {right_str}") else: status.update(left_content) + def _update_idle_scheduler_bar(self: AppProtocol) -> None: + """Update the idle scheduler debug bar.""" + if not getattr(self, "_debug_idle_scheduler", False): + return + + try: + bar = self.idle_scheduler_bar + except Exception: + return + + from ...idle_scheduler import get_idle_scheduler + + scheduler = get_idle_scheduler() + if not scheduler: + bar.update("[dim]Idle Scheduler: Not initialized[/]") + return + + pending = scheduler.pending_jobs + is_idle = scheduler.is_idle + completed = scheduler._jobs_completed + work_time = scheduler._total_work_time_ms + + if pending > 0 and is_idle: + status = "[bold cyan]⚡ WORKING[/]" + details = f"[bold]{pending}[/] jobs pending" + elif pending > 0 and not is_idle: + status = "[yellow]⏸ POSTPONED[/]" + details = f"[bold]{pending}[/] jobs waiting for you to stop" + elif is_idle: + status = "[dim]💤 IDLE[/]" + details = "waiting for work" + else: + status = "[dim]👆 USER ACTIVE[/]" + details = "no pending work" + + bar.update( + f"{status} │ {details} │ " + f"[dim]{completed} completed[/] │ " + f"[dim]{work_time:.0f}ms worked[/]" + ) + def notify( self: AppProtocol, message: str, @@ -360,19 +405,19 @@ def notify( def _show_error_in_results(self: AppProtocol, message: str, timestamp: str) -> None: """Display error message in the results table.""" - import textwrap + import re error_text = f"[{timestamp}] {message}" if timestamp else message - # Wrap to table width (minus some padding), minimum 40 chars - wrap_width = max(40, self.results_table.size.width - 4) - wrapped = textwrap.fill(error_text, width=wrap_width) + # Replace newlines and collapse multiple whitespace to single space + # DataTable cells only show one line, so we flatten the error + error_text = re.sub(r"\s+", " ", error_text).strip() self._last_result_columns = ["Error"] - self._last_result_rows = [(wrapped,)] + self._last_result_rows = [(error_text,)] self._last_result_row_count = 1 - self._replace_results_table(["Error"], [(wrapped,)]) # type: ignore[attr-defined] + self._replace_results_table(["Error"], [(error_text,)]) # type: ignore[attr-defined] self._update_footer_bindings() def action_toggle_explorer(self: AppProtocol) -> None: diff --git a/sqlit/ui/screens/confirm.py b/sqlit/ui/screens/confirm.py index 95a99368..0854e9fd 100644 --- a/sqlit/ui/screens/confirm.py +++ b/sqlit/ui/screens/confirm.py @@ -17,7 +17,7 @@ class ConfirmScreen(ModalScreen): BINDINGS = [ Binding("y", "yes", "Yes", show=False), Binding("n", "no", "No", show=False), - Binding("escape", "cancel", "Cancel", show=False), + Binding("escape", "cancel", "Cancel", show=False, priority=True), Binding("enter", "select_option", "Select", show=False), ] diff --git a/sqlit/ui/screens/connection.py b/sqlit/ui/screens/connection.py index 0a708ea8..03bf16b5 100644 --- a/sqlit/ui/screens/connection.py +++ b/sqlit/ui/screens/connection.py @@ -65,7 +65,7 @@ class ConnectionScreen(ModalScreen): _INSTALL_SPINNER_FRAMES = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"] BINDINGS = [ - Binding("escape", "cancel", "Cancel"), + Binding("escape", "cancel", "Cancel", priority=True), Binding("ctrl+s", "save", "Save", priority=True), Binding("ctrl+t", "test_connection", "Test", priority=True), Binding("ctrl+i", "install_driver", "Install driver", show=False, priority=True), diff --git a/sqlit/ui/screens/leader_menu.py b/sqlit/ui/screens/leader_menu.py index f341bed1..2b9cb8d7 100644 --- a/sqlit/ui/screens/leader_menu.py +++ b/sqlit/ui/screens/leader_menu.py @@ -84,6 +84,12 @@ def compose(self) -> ComposeResult: def action_dismiss(self) -> None: # type: ignore[override] self.dismiss(None) + def on_key(self, event: Any) -> None: + """Handle key events - explicit ESC handling.""" + if event.key == "escape": + self.dismiss(None) + event.stop() + def _run_and_dismiss(self, action_name: str) -> None: """Run an app action and dismiss the menu.""" self.dismiss(action_name) diff --git a/sqlit/ui/screens/loading.py b/sqlit/ui/screens/loading.py index 15e9d17a..5784f229 100644 --- a/sqlit/ui/screens/loading.py +++ b/sqlit/ui/screens/loading.py @@ -16,7 +16,7 @@ class LoadingScreen(ModalScreen[None]): """Screen to display a loading message with a spinner.""" BINDINGS = [ - Binding("escape", "cancel", "Cancel", show=False), + Binding("escape", "cancel", "Cancel", show=False, priority=True), ] def __init__(self, message: str, *, on_cancel: Callable[[], None] | None = None): diff --git a/sqlit/ui/screens/package_setup.py b/sqlit/ui/screens/package_setup.py index ee797d63..7e612d6f 100644 --- a/sqlit/ui/screens/package_setup.py +++ b/sqlit/ui/screens/package_setup.py @@ -21,7 +21,7 @@ class PackageSetupScreen(ModalScreen): BINDINGS = [ Binding("i", "install", "Install"), Binding("y", "yank", "Yank"), - Binding("escape", "cancel", "Cancel"), + Binding("escape", "cancel", "Cancel", priority=True), ] CSS = """ diff --git a/sqlit/ui/screens/password_input.py b/sqlit/ui/screens/password_input.py index ff743452..2bcfbd65 100644 --- a/sqlit/ui/screens/password_input.py +++ b/sqlit/ui/screens/password_input.py @@ -18,7 +18,7 @@ class PasswordInputScreen(ModalScreen): """ BINDINGS = [ - Binding("escape", "cancel", "Cancel"), + Binding("escape", "cancel", "Cancel", priority=True), Binding("enter", "submit", "Submit", show=False), ] diff --git a/sqlit/ui/screens/query_history.py b/sqlit/ui/screens/query_history.py index ef27959d..54f930fc 100644 --- a/sqlit/ui/screens/query_history.py +++ b/sqlit/ui/screens/query_history.py @@ -19,7 +19,7 @@ class QueryHistoryScreen(ModalScreen): """Modal screen for query history selection.""" BINDINGS = [ - Binding("escape", "cancel", "Cancel"), + Binding("escape", "cancel", "Cancel", priority=True), Binding("q", "cancel", "Cancel"), Binding("enter", "select", "Select"), Binding("d", "delete", "Delete"), diff --git a/sqlit/ui/screens/theme.py b/sqlit/ui/screens/theme.py index f8940a8b..197ace44 100644 --- a/sqlit/ui/screens/theme.py +++ b/sqlit/ui/screens/theme.py @@ -162,7 +162,7 @@ class ThemeScreen(ModalScreen[str | None]): """Modal screen for theme selection with live preview.""" BINDINGS = [ - Binding("escape", "cancel", "Cancel"), + Binding("escape", "cancel", "Cancel", priority=True), Binding("enter", "select_option", "Select"), Binding("n", "new_theme", "New"), Binding("e", "edit_theme", "Edit"), diff --git a/sqlit/widgets.py b/sqlit/widgets.py index e5a14448..cd504057 100644 --- a/sqlit/widgets.py +++ b/sqlit/widgets.py @@ -6,9 +6,9 @@ from typing import TYPE_CHECKING, Any from textual.app import ComposeResult -from textual.containers import Container, Horizontal +from textual.containers import Container, Horizontal, VerticalScroll from textual.strip import Strip -from textual.widgets import Static +from textual.widgets import Static, TextArea from textual_fastdatatable import DataTable as FastDataTable if TYPE_CHECKING: @@ -18,6 +18,23 @@ from textual.widget import Widget +class QueryTextArea(TextArea): + """TextArea that defers Enter key to app when autocomplete is visible.""" + + def _on_key(self, event: Key) -> None: + """Intercept Enter key when autocomplete is visible.""" + if event.key == "enter": + # Check if autocomplete is visible on the app + app = self.app + if getattr(app, "_autocomplete_visible", False): + # Hide autocomplete and suppress re-triggering from the newline + if hasattr(app, "_hide_autocomplete"): + app._hide_autocomplete() + app._suppress_autocomplete_on_newline = True + # For all other keys, use default TextArea behavior + super()._on_key(event) + + class SqlitDataTable(FastDataTable): """FastDataTable with correct header behavior when show_header is False.""" @@ -296,21 +313,23 @@ def is_visible(self) -> bool: ResultsFilterInput = FilterInput -class AutocompleteDropdown(Static): - """Dropdown widget for SQL autocomplete suggestions.""" +class AutocompleteDropdown(VerticalScroll): + """Dropdown widget for SQL autocomplete suggestions with scrollbar.""" DEFAULT_CSS = """ AutocompleteDropdown { layer: autocomplete; width: auto; - min-width: 20; - max-width: 50; + min-width: 25; + max-width: 80; height: auto; - max-height: 10; + max-height: 12; background: $surface; border: round $border; padding: 0; display: none; + scrollbar-size: 1 1; + constrain: inside inside; } AutocompleteDropdown.visible { @@ -318,6 +337,8 @@ class AutocompleteDropdown(Static): } AutocompleteDropdown .autocomplete-item { + width: 100%; + height: 1; padding: 0 1; } @@ -328,7 +349,7 @@ class AutocompleteDropdown(Static): """ def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__("", *args, **kwargs) + super().__init__(*args, **kwargs) self.items: list[str] = [] self.filtered_items: list[str] = [] self.selected_index: int = 0 @@ -342,17 +363,36 @@ def set_items(self, items: list[str], filter_text: str = "") -> None: if self.filter_text: self.filtered_items = [item for item in items if item.lower().startswith(self.filter_text)] else: - self.filtered_items = items[:20] + self.filtered_items = items[:50] # Show more items with scrolling self.selected_index = 0 self._rebuild() + # Reset scroll to top + self.scroll_to(y=0, animate=False) def move_selection(self, delta: int) -> None: """Move selection up or down.""" if not self.filtered_items: return + old_index = self.selected_index self.selected_index = (self.selected_index + delta) % len(self.filtered_items) - self._rebuild() + self._update_selection(old_index, self.selected_index) + self._scroll_to_selected() + + def _update_selection(self, old_index: int, new_index: int) -> None: + """Update selection by toggling CSS classes (fast).""" + children = list(self.children) + if old_index < len(children): + children[old_index].remove_class("selected") + if new_index < len(children): + children[new_index].add_class("selected") + + def _scroll_to_selected(self) -> None: + """Scroll to ensure selected item is visible.""" + if not self.filtered_items: + return + # Each item is 1 line high, scroll to show selected + self.scroll_to(y=max(0, self.selected_index - 5), animate=False) def get_selected(self) -> str | None: """Get the currently selected item.""" @@ -361,26 +401,29 @@ def get_selected(self) -> str | None: return None def _rebuild(self) -> None: - """Rebuild the dropdown content.""" + """Rebuild the dropdown content (only called when items change).""" + # Remove all existing children + self.remove_children() + if not self.filtered_items: - self.update("[dim]No matches[/]") + self.mount(Static("[dim]No matches[/]")) return - lines = [] - for i, item in enumerate(self.filtered_items[:10]): + # Create item widgets + for i, item in enumerate(self.filtered_items): + label = Static(f" {item} ", classes="autocomplete-item") if i == self.selected_index: - lines.append(f"[reverse] {item} [/]") - else: - lines.append(f" {item} ") - self.update("\n".join(lines)) + label.add_class("selected") + self.mount(label) def show(self) -> None: """Show the dropdown.""" self.add_class("visible") def hide(self) -> None: - """Hide the dropdown.""" + """Hide the dropdown and reset selection.""" self.remove_class("visible") + self.selected_index = 0 @property def is_visible(self) -> bool: diff --git a/tests/.env.example b/tests/.env.example index d12791ba..5ac37e1d 100644 --- a/tests/.env.example +++ b/tests/.env.example @@ -17,3 +17,10 @@ # Local Docker settings (defaults shown) # TURSO_HOST=localhost # TURSO_PORT=8081 + +# Azure SQL Database settings +# Used for testing Azure SQL compatibility (no cross-database query support) +# AZURE_SQL_SERVER=your-server.database.windows.net +# AZURE_SQL_DATABASE=your-database +# AZURE_SQL_USER=your-username +# AZURE_SQL_PASSWORD=your-password diff --git a/tests/test_adapter_system_databases.py b/tests/test_adapter_system_databases.py new file mode 100644 index 00000000..6f8195f6 --- /dev/null +++ b/tests/test_adapter_system_databases.py @@ -0,0 +1,229 @@ +"""Unit tests for adapter system_databases property.""" + +import pytest + + +class TestSystemDatabasesProperty: + """Test that each adapter correctly defines system_databases.""" + + def test_base_adapter_returns_empty(self): + """Base DatabaseAdapter should return empty frozenset.""" + from sqlit.db.adapters.base import DatabaseAdapter + + # Create a minimal concrete implementation to test the base class + class ConcreteAdapter(DatabaseAdapter): + @property + def name(self) -> str: + return "Test" + + @property + def supports_multiple_databases(self) -> bool: + return True + + @property + def supports_stored_procedures(self) -> bool: + return False + + def connect(self, config): + pass + + def get_databases(self, conn): + return [] + + def get_tables(self, conn, database=None): + return [] + + def get_views(self, conn, database=None): + return [] + + def get_columns(self, conn, table, database=None, schema=None): + return [] + + def get_procedures(self, conn, database=None): + return [] + + def get_indexes(self, conn, database=None): + return [] + + def get_triggers(self, conn, database=None): + return [] + + def get_sequences(self, conn, database=None): + return [] + + def quote_identifier(self, name): + return f'"{name}"' + + def build_select_query(self, table, limit, database=None, schema=None): + return f"SELECT * FROM {table} LIMIT {limit}" + + def execute_query(self, conn, query, max_rows=None): + return [], [], False + + def execute_non_query(self, conn, query): + return 0 + + adapter = ConcreteAdapter() + assert adapter.system_databases == frozenset() + + def test_mssql_system_databases(self): + """SQL Server adapter should exclude master, tempdb, model, msdb.""" + from sqlit.db.adapters.mssql import SQLServerAdapter + + adapter = SQLServerAdapter() + expected = frozenset({"master", "tempdb", "model", "msdb"}) + assert adapter.system_databases == expected + + def test_postgresql_system_databases(self): + """PostgreSQL adapter should exclude template0, template1.""" + from sqlit.db.adapters.postgresql import PostgreSQLAdapter + + adapter = PostgreSQLAdapter() + expected = frozenset({"template0", "template1"}) + assert adapter.system_databases == expected + + def test_cockroachdb_inherits_postgres_system_databases(self): + """CockroachDB inherits PostgreSQL's system_databases.""" + from sqlit.db.adapters.cockroachdb import CockroachDBAdapter + + adapter = CockroachDBAdapter() + expected = frozenset({"template0", "template1"}) + assert adapter.system_databases == expected + + def test_mysql_system_databases(self): + """MySQL adapter should exclude mysql, information_schema, performance_schema, sys.""" + from sqlit.db.adapters.mysql import MySQLAdapter + + adapter = MySQLAdapter() + expected = frozenset({"mysql", "information_schema", "performance_schema", "sys"}) + assert adapter.system_databases == expected + + def test_mariadb_inherits_mysql_system_databases(self): + """MariaDB inherits MySQL's system_databases.""" + from sqlit.db.adapters.mariadb import MariaDBAdapter + + adapter = MariaDBAdapter() + expected = frozenset({"mysql", "information_schema", "performance_schema", "sys"}) + assert adapter.system_databases == expected + + def test_clickhouse_system_databases(self): + """ClickHouse adapter should exclude system, information_schema.""" + from sqlit.db.adapters.clickhouse import ClickHouseAdapter + + adapter = ClickHouseAdapter() + assert "system" in adapter.system_databases + assert "information_schema" in adapter.system_databases or "INFORMATION_SCHEMA" in adapter.system_databases + + def test_snowflake_system_databases(self): + """Snowflake adapter should exclude SNOWFLAKE metadata database.""" + from sqlit.db.adapters.snowflake import SnowflakeAdapter + + adapter = SnowflakeAdapter() + # Case-insensitive: either SNOWFLAKE or snowflake should be present + lowercase_dbs = {s.lower() for s in adapter.system_databases} + assert "snowflake" in lowercase_dbs + + def test_sqlite_no_system_databases(self): + """SQLite (single-file) should return empty frozenset.""" + from sqlit.db.adapters.sqlite import SQLiteAdapter + + adapter = SQLiteAdapter() + assert adapter.system_databases == frozenset() + # Also verify it doesn't support multiple databases + assert adapter.supports_multiple_databases is False + + def test_duckdb_no_system_databases(self): + """DuckDB (single-file) should return empty frozenset.""" + from sqlit.db.adapters.duckdb import DuckDBAdapter + + adapter = DuckDBAdapter() + assert adapter.system_databases == frozenset() + assert adapter.supports_multiple_databases is False + + def test_turso_no_system_databases(self): + """Turso (SQLite-based) should return empty frozenset.""" + from sqlit.db.adapters.turso import TursoAdapter + + adapter = TursoAdapter() + assert adapter.system_databases == frozenset() + assert adapter.supports_multiple_databases is False + + def test_oracle_no_system_databases(self): + """Oracle (single-database with schemas) should return empty frozenset.""" + from sqlit.db.adapters.oracle import OracleAdapter + + adapter = OracleAdapter() + assert adapter.system_databases == frozenset() + assert adapter.supports_multiple_databases is False + + +class TestSystemDatabasesFiltering: + """Test that system_databases filtering works correctly.""" + + def test_lowercase_comparison(self): + """System databases should be compared case-insensitively.""" + from sqlit.db.adapters.mssql import SQLServerAdapter + + adapter = SQLServerAdapter() + system_dbs = {s.lower() for s in adapter.system_databases} + + # These should all be filtered out + all_dbs = ["master", "MASTER", "Master", "tempdb", "TEMPDB", "userdb"] + filtered = [d for d in all_dbs if d.lower() not in system_dbs] + + assert "userdb" in filtered + assert len(filtered) == 1 + + def test_filtering_preserves_user_databases(self): + """Filtering should keep all non-system databases.""" + from sqlit.db.adapters.postgresql import PostgreSQLAdapter + + adapter = PostgreSQLAdapter() + system_dbs = {s.lower() for s in adapter.system_databases} + + all_dbs = ["postgres", "myapp", "template0", "template1", "analytics"] + filtered = [d for d in all_dbs if d.lower() not in system_dbs] + + assert "postgres" in filtered + assert "myapp" in filtered + assert "analytics" in filtered + assert "template0" not in filtered + assert "template1" not in filtered + assert len(filtered) == 3 + + def test_empty_system_databases_filters_nothing(self): + """Empty system_databases should not filter any databases.""" + from sqlit.db.adapters.sqlite import SQLiteAdapter + + adapter = SQLiteAdapter() + system_dbs = {s.lower() for s in adapter.system_databases} + + all_dbs = ["db1", "db2", "system", "master"] + filtered = [d for d in all_dbs if d.lower() not in system_dbs] + + assert filtered == all_dbs + + +class TestSystemDatabasesInterface: + """Test that system_databases property has correct interface.""" + + def test_returns_frozenset(self): + """system_databases should return a frozenset (immutable).""" + from sqlit.db.adapters.mssql import SQLServerAdapter + from sqlit.db.adapters.postgresql import PostgreSQLAdapter + + for AdapterClass in [SQLServerAdapter, PostgreSQLAdapter]: + adapter = AdapterClass() + assert isinstance(adapter.system_databases, frozenset) + + def test_property_is_idempotent(self): + """Multiple calls to system_databases should return same object.""" + from sqlit.db.adapters.mssql import SQLServerAdapter + + adapter = SQLServerAdapter() + first_call = adapter.system_databases + second_call = adapter.system_databases + + assert first_call == second_call + # frozensets are immutable so this is safe + assert first_call is not None diff --git a/tests/test_autocomplete_database_modes.py b/tests/test_autocomplete_database_modes.py new file mode 100644 index 00000000..9f51b6bb --- /dev/null +++ b/tests/test_autocomplete_database_modes.py @@ -0,0 +1,282 @@ +"""Test autocomplete behavior with different database selection modes.""" + +import pytest +from unittest.mock import MagicMock, PropertyMock, patch + + +class TestAutocompleteDatabaseModes: + """Test how autocomplete handles various database configurations.""" + + def test_single_user_db_after_system_filter_shows_unqualified(self): + """ + Scenario: Server has 4 databases, 3 are system (filtered), 1 is user db. + No database is selected (_active_database=None, config.database=None). + + After filtering, only 1 database remains -> should use unqualified names. + """ + from sqlit.db.adapters.mssql import SQLServerAdapter + + adapter = SQLServerAdapter() + + # Simulate: no database selected + _active_database = None + config_database = None + + # Server returns all databases + all_dbs_from_server = ["master", "tempdb", "model", "msdb", "UserAppDb"] + + # Filter system databases (this is what autocomplete does) + system_dbs = {s.lower() for s in adapter.system_databases} + filtered_dbs = [d for d in all_dbs_from_server if d.lower() not in system_dbs] + + # Only UserAppDb should remain + assert filtered_dbs == ["UserAppDb"] + + # Determine single_db mode + db = _active_database or config_database + if db: + databases = [db] + else: + databases = filtered_dbs # Use filtered list + + single_db = len(databases) == 1 + + # With only 1 database after filtering, should be single_db mode + assert single_db is True, "Should be single_db mode when only 1 user database exists" + + # Build schema cache + schema_cache = {"tables": [], "views": [], "columns": {}, "procedures": []} + tables = [("dbo", "Users"), ("dbo", "Orders")] + + for schema_name, table_name in tables: + if single_db: + schema_cache["tables"].append(table_name) + else: + quoted_db = adapter.quote_identifier(databases[0]) + quoted_schema = adapter.quote_identifier(schema_name) + quoted_table = adapter.quote_identifier(table_name) + full_name = f"{quoted_db}.{quoted_schema}.{quoted_table}" + schema_cache["tables"].append(full_name) + + # Should have UNQUALIFIED names since only 1 db + assert schema_cache["tables"] == ["Users", "Orders"] + assert "[UserAppDb]" not in str(schema_cache["tables"]) + + def test_multiple_user_dbs_after_system_filter_shows_qualified(self): + """ + Scenario: Server has 5 databases, 3 are system (filtered), 2 are user dbs. + No database is selected. + + After filtering, 2 databases remain -> should use fully qualified names. + """ + from sqlit.db.adapters.mssql import SQLServerAdapter + + adapter = SQLServerAdapter() + + # Simulate: no database selected + _active_database = None + config_database = None + + # Server returns all databases + all_dbs_from_server = ["master", "tempdb", "model", "msdb", "AppDb1", "AppDb2"] + + # Filter system databases + system_dbs = {s.lower() for s in adapter.system_databases} + filtered_dbs = [d for d in all_dbs_from_server if d.lower() not in system_dbs] + + # Two user databases should remain + assert filtered_dbs == ["AppDb1", "AppDb2"] + + # Determine single_db mode + db = _active_database or config_database + if db: + databases = [db] + else: + databases = filtered_dbs + + single_db = len(databases) == 1 + + # With 2 databases, should NOT be single_db mode + assert single_db is False, "Should NOT be single_db mode with multiple user databases" + + # Build schema cache + schema_cache = {"tables": [], "views": [], "columns": {}, "procedures": []} + tables_by_db = { + "AppDb1": [("dbo", "Users")], + "AppDb2": [("dbo", "Products")], + } + + for database in databases: + for schema_name, table_name in tables_by_db[database]: + if single_db: + schema_cache["tables"].append(table_name) + else: + quoted_db = adapter.quote_identifier(database) + quoted_schema = adapter.quote_identifier(schema_name) + quoted_table = adapter.quote_identifier(table_name) + full_name = f"{quoted_db}.{quoted_schema}.{quoted_table}" + schema_cache["tables"].append(full_name) + + # Should have QUALIFIED names + assert "[AppDb1].[dbo].[Users]" in schema_cache["tables"] + assert "[AppDb2].[dbo].[Products]" in schema_cache["tables"] + + def test_no_user_dbs_after_filter_shows_empty(self): + """ + Scenario: Server only has system databases, all get filtered. + No database is selected. + + After filtering, 0 databases remain -> empty autocomplete. + """ + from sqlit.db.adapters.mssql import SQLServerAdapter + + adapter = SQLServerAdapter() + + # Server returns only system databases + all_dbs_from_server = ["master", "tempdb", "model", "msdb"] + + # Filter system databases + system_dbs = {s.lower() for s in adapter.system_databases} + filtered_dbs = [d for d in all_dbs_from_server if d.lower() not in system_dbs] + + # No databases should remain + assert filtered_dbs == [] + + # Schema cache should be empty + schema_cache = {"tables": [], "views": [], "columns": {}, "procedures": []} + + assert schema_cache["tables"] == [] + + def test_selected_db_overrides_filter_logic(self): + """ + Scenario: Multiple user databases exist, but one is explicitly selected. + + Should use single_db mode with unqualified names for the selected db. + """ + from sqlit.db.adapters.mssql import SQLServerAdapter + + adapter = SQLServerAdapter() + + # User selected a specific database + _active_database = "AppDb1" + config_database = None + + # Server has multiple user databases + all_dbs_from_server = ["master", "tempdb", "AppDb1", "AppDb2", "AppDb3"] + + # When a database is selected, we DON'T load all databases + db = _active_database or config_database + if db: + databases = [db] # Only the selected one + else: + system_dbs = {s.lower() for s in adapter.system_databases} + databases = [d for d in all_dbs_from_server if d.lower() not in system_dbs] + + single_db = len(databases) == 1 + + # Should be single_db mode because we selected one + assert single_db is True + assert databases == ["AppDb1"] + + # Build schema cache - should be unqualified + schema_cache = {"tables": []} + tables = [("dbo", "Users"), ("dbo", "Orders")] + + for schema_name, table_name in tables: + if single_db: + schema_cache["tables"].append(table_name) + + assert schema_cache["tables"] == ["Users", "Orders"] + + def test_select_then_unselect_database(self): + """ + Scenario: + 1. User selects a database -> unqualified names + 2. User unselects (clears) database -> qualified names (if multiple dbs) + """ + from sqlit.db.adapters.mssql import SQLServerAdapter + + adapter = SQLServerAdapter() + + all_user_dbs = ["AppDb1", "AppDb2"] + + # Phase 1: Database selected + _active_database = "AppDb1" + db = _active_database or None + databases = [db] if db else all_user_dbs + single_db = len(databases) == 1 + + assert single_db is True + + schema_cache = {"tables": []} + for table_name in ["Users", "Orders"]: + if single_db: + schema_cache["tables"].append(table_name) + + assert schema_cache["tables"] == ["Users", "Orders"] + + # Phase 2: Database unselected + _active_database = None + db = _active_database or None + databases = [db] if db else all_user_dbs + single_db = len(databases) == 1 + + assert single_db is False + + # Reload cache with qualified names + schema_cache = {"tables": []} + tables_by_db = {"AppDb1": ["Users"], "AppDb2": ["Products"]} + + for database in databases: + for table_name in tables_by_db[database]: + if single_db: + schema_cache["tables"].append(table_name) + else: + full_name = f"[{database}].[dbo].[{table_name}]" + schema_cache["tables"].append(full_name) + + assert "[AppDb1].[dbo].[Users]" in schema_cache["tables"] + assert "[AppDb2].[dbo].[Products]" in schema_cache["tables"] + + +class TestSystemDatabaseFiltering: + """Test system database filtering across different adapters.""" + + def test_mssql_filters_system_databases(self): + """MSSQL should filter master, tempdb, model, msdb.""" + from sqlit.db.adapters.mssql import SQLServerAdapter + + adapter = SQLServerAdapter() + all_dbs = ["master", "MASTER", "tempdb", "model", "msdb", "UserDb"] + + system_dbs = {s.lower() for s in adapter.system_databases} + filtered = [d for d in all_dbs if d.lower() not in system_dbs] + + assert filtered == ["UserDb"] + + def test_postgres_filters_template_databases(self): + """PostgreSQL should filter template0, template1.""" + from sqlit.db.adapters.postgresql import PostgreSQLAdapter + + adapter = PostgreSQLAdapter() + all_dbs = ["postgres", "template0", "template1", "myapp"] + + system_dbs = {s.lower() for s in adapter.system_databases} + filtered = [d for d in all_dbs if d.lower() not in system_dbs] + + assert "postgres" in filtered + assert "myapp" in filtered + assert "template0" not in filtered + assert "template1" not in filtered + + def test_mysql_filters_system_databases(self): + """MySQL should filter mysql, information_schema, performance_schema, sys.""" + from sqlit.db.adapters.mysql import MySQLAdapter + + adapter = MySQLAdapter() + all_dbs = ["mysql", "information_schema", "performance_schema", "sys", "myapp"] + + system_dbs = {s.lower() for s in adapter.system_databases} + filtered = [d for d in all_dbs if d.lower() not in system_dbs] + + assert filtered == ["myapp"] diff --git a/tests/test_cross_database_queries.py b/tests/test_cross_database_queries.py new file mode 100644 index 00000000..3285368e --- /dev/null +++ b/tests/test_cross_database_queries.py @@ -0,0 +1,197 @@ +"""Tests for supports_cross_database_queries property and validation.""" + +import pytest + + +class TestCrossDatabaseQueriesProperty: + """Test the supports_cross_database_queries adapter property.""" + + def test_base_adapter_defaults_to_true(self): + """Base DatabaseAdapter should default to True (supports cross-db queries).""" + from sqlit.db.adapters.base import DatabaseAdapter + + # Create a minimal concrete implementation for testing + class TestAdapter(DatabaseAdapter): + @property + def name(self): + return "Test" + + @property + def supports_multiple_databases(self): + return True + + @property + def supports_stored_procedures(self): + return False + + def connect(self, config): + pass + + def get_databases(self, conn): + return [] + + def get_tables(self, conn, database=None): + return [] + + def get_views(self, conn, database=None): + return [] + + def get_columns(self, conn, table, database=None, schema=None): + return [] + + def get_procedures(self, conn, database=None): + return [] + + def get_indexes(self, conn, database=None): + return [] + + def get_triggers(self, conn, database=None): + return [] + + def get_sequences(self, conn, database=None): + return [] + + def quote_identifier(self, name): + return f'"{name}"' + + def build_select_query(self, table, limit, database=None, schema=None): + return f"SELECT * FROM {table} LIMIT {limit}" + + def execute_query(self, conn, query, max_rows=None): + return [], [], False + + def execute_non_query(self, conn, query): + return 0 + + adapter = TestAdapter() + assert adapter.supports_cross_database_queries is True + + def test_mssql_supports_cross_database_queries(self): + """MSSQL adapter should support cross-database queries.""" + from sqlit.db.adapters.mssql import SQLServerAdapter + + adapter = SQLServerAdapter() + assert adapter.supports_cross_database_queries is True + + def test_mysql_supports_cross_database_queries(self): + """MySQL adapter should support cross-database queries.""" + from sqlit.db.adapters.mysql import MySQLAdapter + + adapter = MySQLAdapter() + assert adapter.supports_cross_database_queries is True + + def test_postgresql_does_not_support_cross_database_queries(self): + """PostgreSQL adapter should NOT support cross-database queries.""" + from sqlit.db.adapters.postgresql import PostgreSQLAdapter + + adapter = PostgreSQLAdapter() + assert adapter.supports_cross_database_queries is False + + def test_cockroachdb_does_not_support_cross_database_queries(self): + """CockroachDB adapter (extends PostgresBaseAdapter) should NOT support cross-db queries.""" + from sqlit.db.adapters.cockroachdb import CockroachDBAdapter + + adapter = CockroachDBAdapter() + assert adapter.supports_cross_database_queries is False + + def test_d1_does_not_support_cross_database_queries(self): + """D1 adapter should NOT support cross-database queries.""" + from sqlit.db.adapters.d1 import D1Adapter + + adapter = D1Adapter() + assert adapter.supports_cross_database_queries is False + + def test_clickhouse_supports_cross_database_queries(self): + """ClickHouse adapter should support cross-database queries.""" + from sqlit.db.adapters.clickhouse import ClickHouseAdapter + + adapter = ClickHouseAdapter() + assert adapter.supports_cross_database_queries is True + + def test_snowflake_supports_cross_database_queries(self): + """Snowflake adapter should support cross-database queries.""" + from sqlit.db.adapters.snowflake import SnowflakeAdapter + + adapter = SnowflakeAdapter() + assert adapter.supports_cross_database_queries is True + + def test_mariadb_supports_cross_database_queries(self): + """MariaDB adapter should support cross-database queries.""" + from sqlit.db.adapters.mariadb import MariaDBAdapter + + adapter = MariaDBAdapter() + assert adapter.supports_cross_database_queries is True + + +class TestRequiresDatabaseSelection: + """Test the requires_database_selection helper function.""" + + def test_postgresql_requires_database_selection(self): + """PostgreSQL should require database selection.""" + from sqlit.db.providers import requires_database_selection + + assert requires_database_selection("postgresql") is True + + def test_cockroachdb_requires_database_selection(self): + """CockroachDB should require database selection.""" + from sqlit.db.providers import requires_database_selection + + assert requires_database_selection("cockroachdb") is True + + def test_d1_requires_database_selection(self): + """D1 should require database selection.""" + from sqlit.db.providers import requires_database_selection + + assert requires_database_selection("d1") is True + + def test_mssql_does_not_require_database_selection(self): + """MSSQL should NOT require database selection.""" + from sqlit.db.providers import requires_database_selection + + assert requires_database_selection("mssql") is False + + def test_mysql_does_not_require_database_selection(self): + """MySQL should NOT require database selection.""" + from sqlit.db.providers import requires_database_selection + + assert requires_database_selection("mysql") is False + + def test_clickhouse_does_not_require_database_selection(self): + """ClickHouse should NOT require database selection.""" + from sqlit.db.providers import requires_database_selection + + assert requires_database_selection("clickhouse") is False + + def test_unknown_db_type_returns_false(self): + """Unknown database type should return False (fail open).""" + from sqlit.db.providers import requires_database_selection + + assert requires_database_selection("unknown_db_type") is False + + +class TestValidateDatabaseRequired: + """Test the validate_database_required helper function.""" + + def test_validate_database_required_raises_when_needed(self): + """validate_database_required raises for databases that need it.""" + from sqlit.db.providers import validate_database_required + + # PostgreSQL requires database but validation doesn't block - user selects in explorer + # This is now a no-op for UI flow, but still useful for programmatic validation + with pytest.raises(ValueError): + validate_database_required("postgresql", None) + + def test_validate_database_required_passes_with_database(self): + """validate_database_required passes when database is provided.""" + from sqlit.db.providers import validate_database_required + + # Should not raise + validate_database_required("postgresql", "mydb") + + def test_validate_database_required_passes_for_cross_db_adapters(self): + """validate_database_required passes for adapters supporting cross-db queries.""" + from sqlit.db.providers import validate_database_required + + # Should not raise - MSSQL supports cross-database queries + validate_database_required("mssql", None) + validate_database_required("mssql", "") diff --git a/tests/ui/explorer/test_tree_expansion.py b/tests/ui/explorer/test_tree_expansion.py index dc4b66b0..93c66557 100644 --- a/tests/ui/explorer/test_tree_expansion.py +++ b/tests/ui/explorer/test_tree_expansion.py @@ -51,6 +51,7 @@ class MockAdapter: def __init__(self, tables: list[tuple[str, str]], default_schema: str = "public"): self._tables = tables self._default_schema = default_schema + self.supports_cross_database_queries = True # Default to True for tests @property def default_schema(self) -> str: @@ -63,6 +64,14 @@ def get_views(self, conn, database=None) -> list[tuple[str, str]]: return [] +class MockConfig: + """Mock connection config.""" + + def __init__(self, database: str | None = None, name: str = "test_connection"): + self.database = database + self.name = name + + class MockSession: def __init__(self, adapter): self.adapter = adapter @@ -86,10 +95,17 @@ def _create_mixin_with_adapter(self, tables: list[tuple[str, str]], default_sche mixin._session = MockSession(adapter) mixin._loading_nodes = set() mixin._expanded_paths = set() + mixin._active_database = "mydb" mixin.current_connection = MagicMock() mixin.current_adapter = adapter + mixin.current_config = MockConfig(database="mydb") mixin.object_tree = MockTree() mixin.call_later = lambda fn: None + # Mock methods called by set_default_database + mixin.notify = MagicMock() + mixin._update_status_bar = MagicMock() + mixin._update_database_labels = MagicMock() + mixin._load_schema_cache = MagicMock() return mixin, adapter def test_expand_folder_triggers_async_load(self): diff --git a/tests/unit/sql_completion/__init__.py b/tests/unit/sql_completion/__init__.py new file mode 100644 index 00000000..d580783e --- /dev/null +++ b/tests/unit/sql_completion/__init__.py @@ -0,0 +1 @@ +"""SQL completion tests.""" diff --git a/tests/unit/sql_completion/test_alter_table.py b/tests/unit/sql_completion/test_alter_table.py new file mode 100644 index 00000000..e8b840b0 --- /dev/null +++ b/tests/unit/sql_completion/test_alter_table.py @@ -0,0 +1,196 @@ +"""Tests for ALTER TABLE statement autocomplete suggestions.""" + +import pytest + +from sqlit.sql_completion import ( + ALTER_OPERATIONS, + SQL_CONSTRAINTS, + SQL_DATA_TYPES, + get_completions, +) + + +class TestAlterTableStatements: + """Tests for ALTER TABLE autocomplete suggestions.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "orders", "products"], + "columns": { + "users": ["id", "name", "email", "created_at"], + "orders": ["id", "user_id", "total", "status"], + "products": ["id", "name", "price", "category"], + }, + "procedures": [], + } + + def test_alter_table_suggests_tables(self, schema): + """ALTER TABLE should suggest table names.""" + sql = "ALTER TABLE " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "users" in completions + assert "orders" in completions + assert "products" in completions + + def test_alter_table_partial_table(self, schema): + """Typing partial table name should filter.""" + sql = "ALTER TABLE us" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "users" in completions + assert "orders" not in completions + + def test_alter_table_after_table_suggests_operations(self, schema): + """After table name, should suggest ALTER operations.""" + sql = "ALTER TABLE users " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "ADD" in completions + assert "DROP" in completions + assert "ALTER" in completions + assert "MODIFY" in completions + assert "RENAME" in completions + + def test_alter_table_partial_operation(self, schema): + """Typing partial operation should filter.""" + sql = "ALTER TABLE users AD" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "ADD" in completions + assert "ADD COLUMN" in completions + assert "DROP" not in completions + + def test_alter_table_drop_column_suggests_columns(self, schema): + """DROP COLUMN should suggest existing columns.""" + sql = "ALTER TABLE users DROP COLUMN " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + assert "name" in completions + assert "email" in completions + + def test_alter_table_drop_without_column_keyword(self, schema): + """DROP (without COLUMN) should suggest existing columns.""" + sql = "ALTER TABLE users DROP " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + assert "name" in completions + + def test_alter_table_alter_column_suggests_columns(self, schema): + """ALTER COLUMN should suggest existing columns.""" + sql = "ALTER TABLE orders ALTER COLUMN " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + assert "user_id" in completions + assert "total" in completions + + def test_alter_table_modify_column_suggests_columns(self, schema): + """MODIFY COLUMN should suggest existing columns.""" + sql = "ALTER TABLE products MODIFY COLUMN " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + assert "name" in completions + assert "price" in completions + + def test_alter_table_rename_column_suggests_columns(self, schema): + """RENAME COLUMN should suggest existing columns.""" + sql = "ALTER TABLE users RENAME COLUMN " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + assert "name" in completions + assert "email" in completions + + def test_alter_table_add_column_type(self, schema): + """ADD COLUMN name should suggest data types.""" + sql = "ALTER TABLE users ADD COLUMN age " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "INT" in completions + assert "VARCHAR" in completions + assert "TEXT" in completions + + def test_alter_table_add_without_column_keyword(self, schema): + """ADD name (without COLUMN) should suggest data types.""" + sql = "ALTER TABLE users ADD age " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "INT" in completions + assert "VARCHAR" in completions + + def test_alter_table_add_column_constraint(self, schema): + """After data type, should suggest constraints.""" + sql = "ALTER TABLE users ADD COLUMN age INT " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "NOT NULL" in completions + assert "DEFAULT" in completions + assert "UNIQUE" in completions + + def test_alter_table_references_suggests_tables(self, schema): + """REFERENCES should suggest table names.""" + sql = "ALTER TABLE orders ADD COLUMN product_id INT REFERENCES " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "products" in completions + assert "users" in completions + + def test_alter_table_references_table_paren_suggests_columns(self, schema): + """REFERENCES table( should suggest columns from that table.""" + sql = "ALTER TABLE orders ADD COLUMN product_id INT REFERENCES products(" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + assert "name" in completions + assert "price" in completions + + def test_alter_table_partial_column_filter(self, schema): + """Typing partial column name should filter.""" + sql = "ALTER TABLE users DROP COLUMN em" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "email" in completions + assert "id" not in completions + + def test_alter_table_unknown_table(self, schema): + """ALTER TABLE with unknown table should return empty columns.""" + sql = "ALTER TABLE unknown_table DROP COLUMN " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + # Should return empty list since table doesn't exist + assert completions == [] + + +class TestAlterOperations: + """Tests for ALTER_OPERATIONS list.""" + + def test_alter_operations_not_empty(self): + """Should have ALTER operations defined.""" + assert len(ALTER_OPERATIONS) > 10 + assert "ADD" in ALTER_OPERATIONS + assert "DROP" in ALTER_OPERATIONS + assert "ALTER" in ALTER_OPERATIONS + assert "MODIFY" in ALTER_OPERATIONS + assert "RENAME" in ALTER_OPERATIONS diff --git a/tests/unit/sql_completion/test_core.py b/tests/unit/sql_completion/test_core.py new file mode 100644 index 00000000..bbf38c74 --- /dev/null +++ b/tests/unit/sql_completion/test_core.py @@ -0,0 +1,441 @@ +"""Tests for core SQL completion utilities.""" + +import pytest + +from sqlit.sql_completion import ( + RESERVED_WORDS, + SQL_FUNCTIONS, + SQL_KEYWORDS, + TableRef, + build_alias_map, + extract_cte_names, + extract_table_refs, + find_context_keyword, + find_current_clause, + find_last_keyword, + fuzzy_match, + get_all_functions, + get_all_keywords, + get_completions, + get_current_word, + is_inside_string, + remove_comments, + remove_string_literals, +) + + +class TestFuzzyMatch: + """Tests for fuzzy matching algorithm.""" + + def test_exact_prefix_match(self): + """Exact prefix matches should come first.""" + candidates = ["users", "user_logs", "orders", "user_settings"] + result = fuzzy_match("user", candidates) + assert result[0].startswith("user") + assert result[1].startswith("user") + + def test_fuzzy_match_subsequence(self): + """Fuzzy match should find subsequence matches.""" + candidates = ["django_migrations", "django_session", "orders"] + result = fuzzy_match("djmi", candidates) + assert "django_migrations" in result + + def test_fuzzy_match_case_insensitive(self): + """Fuzzy match should be case insensitive.""" + candidates = ["UserSettings", "user_logs", "USERS"] + result = fuzzy_match("user", candidates) + assert len(result) == 3 + + def test_empty_text_returns_all(self): + """Empty text should return all candidates up to max.""" + candidates = ["a", "b", "c", "d", "e"] + result = fuzzy_match("", candidates, max_results=3) + assert len(result) == 3 + + def test_no_match_returns_empty(self): + """No match should return empty list.""" + candidates = ["users", "orders", "products"] + result = fuzzy_match("xyz", candidates) + assert result == [] + + def test_max_results_limit(self): + """Should respect max_results limit.""" + candidates = [f"user_{i}" for i in range(100)] + result = fuzzy_match("user", candidates, max_results=10) + assert len(result) == 10 + + def test_prefix_match_before_fuzzy(self): + """Prefix matches should come before fuzzy matches.""" + candidates = ["ab_cd", "abcd", "a_b_c_d"] + result = fuzzy_match("ab", candidates) + assert result[0] in ["ab_cd", "abcd"] + assert result[1] in ["ab_cd", "abcd"] + + +class TestExtractTableRefs: + """Tests for table reference extraction.""" + + def test_simple_from(self): + """Extract simple FROM table.""" + sql = "SELECT * FROM users" + refs = extract_table_refs(sql) + assert len(refs) == 1 + assert refs[0].name == "users" + assert refs[0].alias is None + + def test_from_with_alias(self): + """Extract FROM table with alias.""" + sql = "SELECT * FROM users u" + refs = extract_table_refs(sql) + assert len(refs) == 1 + assert refs[0].name == "users" + assert refs[0].alias == "u" + + def test_from_with_as_alias(self): + """Extract FROM table with AS alias.""" + sql = "SELECT * FROM users AS u" + refs = extract_table_refs(sql) + assert len(refs) == 1 + assert refs[0].name == "users" + assert refs[0].alias == "u" + + def test_multiple_joins(self): + """Extract multiple tables with JOINs.""" + sql = """ + SELECT * FROM users u + JOIN orders o ON u.id = o.user_id + LEFT JOIN products p ON o.product_id = p.id + """ + refs = extract_table_refs(sql) + assert len(refs) == 3 + assert refs[0].name == "users" + assert refs[0].alias == "u" + assert refs[1].name == "orders" + assert refs[1].alias == "o" + assert refs[2].name == "products" + assert refs[2].alias == "p" + + def test_schema_qualified_table(self): + """Extract schema.table reference.""" + sql = "SELECT * FROM dbo.users u" + refs = extract_table_refs(sql) + assert len(refs) == 1 + assert refs[0].name == "users" + assert refs[0].alias == "u" + assert refs[0].schema == "dbo" + + def test_bracketed_identifiers(self): + """Extract tables with bracket quoting.""" + sql = "SELECT * FROM [dbo].[User Settings] u" + refs = extract_table_refs(sql) + assert len(refs) == 1 + assert refs[0].schema == "dbo" + + def test_reserved_word_not_alias(self): + """Reserved words should not be detected as aliases.""" + sql = "SELECT * FROM users WHERE id = 1" + refs = extract_table_refs(sql) + assert len(refs) == 1 + assert refs[0].name == "users" + assert refs[0].alias is None + + def test_inner_join(self): + """Extract INNER JOIN table.""" + sql = "SELECT * FROM users u INNER JOIN orders o ON u.id = o.user_id" + refs = extract_table_refs(sql) + assert len(refs) == 2 + + def test_right_join(self): + """Extract RIGHT JOIN table.""" + sql = "SELECT * FROM users RIGHT JOIN orders o ON users.id = o.user_id" + refs = extract_table_refs(sql) + assert len(refs) == 2 + assert refs[1].alias == "o" + + def test_case_insensitive(self): + """FROM and JOIN should be case insensitive.""" + sql = "select * from Users u join Orders o on u.id = o.user_id" + refs = extract_table_refs(sql) + assert len(refs) == 2 + + def test_double_quoted_table(self): + """Extract double-quoted table (PostgreSQL style).""" + sql = 'SELECT * FROM "books" WHERE id = 1' + refs = extract_table_refs(sql) + assert len(refs) == 1 + assert refs[0].name == "books" + + def test_double_quoted_table_with_alias(self): + """Extract double-quoted table with alias.""" + sql = 'SELECT * FROM "user_accounts" ua WHERE ua.id = 1' + refs = extract_table_refs(sql) + assert len(refs) == 1 + assert refs[0].name == "user_accounts" + assert refs[0].alias == "ua" + + def test_backtick_quoted_table(self): + """Extract backtick-quoted table (MySQL style).""" + sql = "SELECT * FROM `orders` WHERE id = 1" + refs = extract_table_refs(sql) + assert len(refs) == 1 + assert refs[0].name == "orders" + + def test_mixed_quoting_styles(self): + """Handle mixed quoting styles in same query.""" + sql = 'SELECT * FROM "users" u JOIN `orders` o ON u.id = o.user_id' + refs = extract_table_refs(sql) + assert len(refs) == 2 + assert refs[0].name == "users" + assert refs[1].name == "orders" + + def test_quoted_schema_and_table(self): + """Extract quoted schema.table reference.""" + sql = 'SELECT * FROM "public"."users"' + refs = extract_table_refs(sql) + assert len(refs) == 1 + assert refs[0].schema == "public" + assert refs[0].name == "users" + + def test_spaces_in_quoted_identifier(self): + """Extract table name with spaces (quoted).""" + sql = 'SELECT * FROM "user accounts" ua' + refs = extract_table_refs(sql) + assert len(refs) == 1 + assert refs[0].name == "user accounts" + assert refs[0].alias == "ua" + + +class TestExtractCTENames: + """Tests for CTE name extraction.""" + + def test_simple_cte(self): + """Extract simple CTE name.""" + sql = """ + WITH active_users AS ( + SELECT * FROM users WHERE active = 1 + ) + SELECT * FROM active_users + """ + ctes = extract_cte_names(sql) + assert ctes == ["active_users"] + + def test_multiple_ctes(self): + """Extract multiple CTE names.""" + sql = """ + WITH + active_users AS (SELECT * FROM users WHERE active = 1), + recent_orders AS (SELECT * FROM orders WHERE date > '2024-01-01') + SELECT * FROM active_users au JOIN recent_orders ro ON au.id = ro.user_id + """ + ctes = extract_cte_names(sql) + assert "active_users" in ctes + assert "recent_orders" in ctes + + def test_no_cte(self): + """No CTE should return empty list.""" + sql = "SELECT * FROM users" + ctes = extract_cte_names(sql) + assert ctes == [] + + def test_recursive_cte(self): + """Extract recursive CTE name.""" + sql = """ + WITH RECURSIVE employee_tree AS ( + SELECT id, name, manager_id FROM employees WHERE manager_id IS NULL + UNION ALL + SELECT e.id, e.name, e.manager_id FROM employees e + JOIN employee_tree et ON e.manager_id = et.id + ) + SELECT * FROM employee_tree + """ + ctes = extract_cte_names(sql) + assert "employee_tree" in ctes + + +class TestHelperFunctions: + """Tests for helper functions.""" + + def test_remove_string_literals(self): + """String literals should be replaced.""" + sql = "SELECT * FROM users WHERE name = 'John' AND bio LIKE '%test%'" + result = remove_string_literals(sql) + assert "John" not in result + assert "test" not in result + + def test_remove_comments_single_line(self): + """Single line comments should be removed.""" + sql = "SELECT * FROM users -- this is a comment\nWHERE id = 1" + result = remove_comments(sql) + assert "this is a comment" not in result + + def test_remove_comments_multi_line(self): + """Multi-line comments should be removed.""" + sql = "SELECT * /* comment */ FROM users" + result = remove_comments(sql) + assert "comment" not in result + + def test_find_last_keyword(self): + """Should find the last keyword.""" + assert find_last_keyword("SELECT * FROM ") == "from" + assert find_last_keyword("SELECT * FROM users WHERE ") == "where" + assert find_last_keyword("SELECT id, ") == "," + + def test_find_current_clause(self): + """Should find the current clause.""" + assert find_current_clause("SELECT id, name") == "select" + assert find_current_clause("SELECT * FROM users WHERE") == "where" + assert find_current_clause("SELECT * FROM users u JOIN orders") == "join" + + def test_find_context_keyword(self): + """Should find the context keyword before the current word.""" + assert find_context_keyword("SELECT * FROM us") == "from" + assert find_context_keyword("SELECT * FROM users WHERE na") == "where" + assert find_context_keyword("SELECT * FROM ") == "from" + assert find_context_keyword("SELECT * FROM users WHERE ") == "where" + assert find_context_keyword("SELECT id, ") == "," + assert find_context_keyword("SELECT id,") == "," + + def test_get_current_word(self): + """Should extract word being typed.""" + assert get_current_word("SELECT us", 9) == "us" + assert get_current_word("SELECT * FROM ", 14) == "" + assert get_current_word("SELECT u.na", 11) == "na" + + def test_build_alias_map(self): + """Should build correct alias map.""" + refs = [ + TableRef(name="users", alias="u"), + TableRef(name="orders", alias="o"), + TableRef(name="unknown", alias="x"), + ] + known_tables = ["users", "orders", "products"] + alias_map = build_alias_map(refs, known_tables) + assert alias_map == {"u": "users", "o": "orders"} + + +class TestInsideString: + """Tests for string literal detection.""" + + def test_inside_single_quote_string(self): + """Should detect cursor inside single-quoted string.""" + assert is_inside_string("SELECT * FROM users WHERE name = '") is True + assert is_inside_string("SELECT * FROM users WHERE name = 'John") is True + + def test_inside_double_quote_string(self): + """Should detect cursor inside double-quoted string.""" + assert is_inside_string('SELECT * FROM users WHERE name = "') is True + assert is_inside_string('SELECT * FROM users WHERE name = "John') is True + + def test_outside_closed_string(self): + """Should detect cursor outside closed string.""" + assert is_inside_string("SELECT * FROM users WHERE name = 'John'") is False + assert is_inside_string("SELECT * FROM users WHERE name = 'John' AND ") is False + + def test_escaped_quotes(self): + """Should handle escaped quotes (SQL style '').""" + assert is_inside_string("SELECT * FROM users WHERE name = 'O''Brien") is True + assert is_inside_string("SELECT * FROM users WHERE name = 'O''Brien'") is False + + +class TestSQLKeywordsAndFunctions: + """Tests for SQL keywords and functions.""" + + def test_keywords_not_empty(self): + """Should have keywords defined.""" + keywords = get_all_keywords() + assert len(keywords) > 50 + assert "SELECT" in keywords + assert "FROM" in keywords + assert "WHERE" in keywords + + def test_functions_not_empty(self): + """Should have functions defined.""" + functions = get_all_functions() + assert len(functions) > 30 + assert "COUNT" in functions + assert "SUM" in functions + assert "COALESCE" in functions + + def test_reserved_words_lowercase(self): + """Reserved words should be lowercase.""" + for word in RESERVED_WORDS: + assert word == word.lower() + + def test_keywords_categories(self): + """Should have multiple keyword categories.""" + assert "dml" in SQL_KEYWORDS + assert "ddl" in SQL_KEYWORDS + assert "control" in SQL_KEYWORDS + + def test_functions_categories(self): + """Should have multiple function categories.""" + assert "aggregate" in SQL_FUNCTIONS + assert "string" in SQL_FUNCTIONS + assert "datetime" in SQL_FUNCTIONS + + +class TestEdgeCases: + """Tests for edge cases and potential issues.""" + + def test_empty_sql(self): + """Should return nothing for empty SQL (no context yet).""" + completions = get_completions("", 0, ["users"], {"users": ["id"]}) + assert len(completions) == 0 + + def test_cursor_at_start(self): + """Should return nothing when cursor is at start (no context yet).""" + sql = "SELECT * FROM users" + completions = get_completions(sql, 0, ["users"], {"users": ["id"]}) + assert len(completions) == 0 + + def test_cursor_in_middle(self): + """Should handle cursor in middle of query.""" + sql = "SELECT * FROM users WHERE id = 1" + cursor_pos = len("SELECT * FROM ") + completions = get_completions(sql, cursor_pos, ["users"], {"users": ["id"]}) + assert isinstance(completions, list) + + def test_unknown_alias(self): + """Should handle unknown alias gracefully.""" + sql = "SELECT * FROM users WHERE x." + completions = get_completions(sql, len(sql), ["users"], {"users": ["id"]}) + assert isinstance(completions, list) + + def test_special_characters_in_word(self): + """Should handle special characters.""" + sql = "SELECT * FROM [us" + completions = get_completions(sql, len(sql), ["users"], {"users": ["id"]}) + assert isinstance(completions, list) + + def test_multiline_query(self): + """Should handle multiline queries.""" + sql = """SELECT + u.id, + u.name + FROM + users u + WHERE + u.""" + completions = get_completions(sql, len(sql), ["users"], {"users": ["id", "name"]}) + assert "id" in completions + assert "name" in completions + + def test_case_insensitive_alias_lookup(self): + """Alias lookup should be case insensitive.""" + sql = "SELECT * FROM Users U WHERE U." + completions = get_completions(sql, len(sql), ["Users"], {"users": ["id", "name"]}) + assert "id" in completions + + def test_string_literal_not_confused(self): + """String literals should not confuse context detection.""" + sql = "SELECT * FROM users WHERE name = 'FROM orders' AND " + completions = get_completions(sql, len(sql), ["users", "orders"], {"users": ["id"], "orders": ["id"]}) + assert isinstance(completions, list) + + def test_comment_not_confused(self): + """Comments should not confuse context detection.""" + sql = """SELECT * FROM users + -- FROM orders + WHERE """ + completions = get_completions(sql, len(sql), ["users", "orders"], {"users": ["id"], "orders": ["id"]}) + assert isinstance(completions, list) diff --git a/tests/unit/sql_completion/test_create_index.py b/tests/unit/sql_completion/test_create_index.py new file mode 100644 index 00000000..32408279 --- /dev/null +++ b/tests/unit/sql_completion/test_create_index.py @@ -0,0 +1,116 @@ +"""Tests for CREATE INDEX statement autocomplete suggestions.""" + +import pytest + +from sqlit.sql_completion import get_completions + + +class TestCreateIndexStatements: + """Tests for CREATE INDEX autocomplete suggestions.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "orders", "products"], + "columns": { + "users": ["id", "name", "email", "created_at"], + "orders": ["id", "user_id", "total", "status"], + "products": ["id", "name", "price", "category"], + }, + "procedures": [], + } + + def test_create_index_on_suggests_tables(self, schema): + """CREATE INDEX name ON should suggest table names.""" + sql = "CREATE INDEX idx_user_email ON " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "users" in completions + assert "orders" in completions + assert "products" in completions + + def test_create_index_on_partial_table(self, schema): + """Typing partial table name should filter.""" + sql = "CREATE INDEX idx_user_email ON us" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "users" in completions + assert "orders" not in completions + + def test_create_unique_index_on_suggests_tables(self, schema): + """CREATE UNIQUE INDEX name ON should suggest table names.""" + sql = "CREATE UNIQUE INDEX idx_user_email ON " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "users" in completions + assert "orders" in completions + + def test_create_index_table_paren_suggests_columns(self, schema): + """CREATE INDEX name ON table ( should suggest columns.""" + sql = "CREATE INDEX idx_user_email ON users (" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + assert "name" in completions + assert "email" in completions + assert "created_at" in completions + + def test_create_index_partial_column(self, schema): + """Typing partial column name should filter.""" + sql = "CREATE INDEX idx_user_email ON users (em" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "email" in completions + assert "id" not in completions + + def test_create_index_second_column(self, schema): + """After first column and comma, should suggest more columns.""" + sql = "CREATE INDEX idx_composite ON users (name, " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + assert "email" in completions + + def test_create_index_after_name_suggests_on(self, schema): + """After index name, should suggest ON keyword.""" + sql = "CREATE INDEX idx_user_email " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "ON" in completions + + def test_create_unique_index_after_name_suggests_on(self, schema): + """After UNIQUE INDEX name, should suggest ON keyword.""" + sql = "CREATE UNIQUE INDEX idx_user_email " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "ON" in completions + + def test_create_index_unknown_table(self, schema): + """CREATE INDEX on unknown table should return empty columns.""" + sql = "CREATE INDEX idx_test ON unknown_table (" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert completions == [] + + def test_create_index_different_table(self, schema): + """CREATE INDEX on orders should suggest orders columns.""" + sql = "CREATE INDEX idx_order_status ON orders (" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + assert "user_id" in completions + assert "total" in completions + assert "status" in completions + # Should not have users columns + assert "email" not in completions diff --git a/tests/unit/sql_completion/test_create_table.py b/tests/unit/sql_completion/test_create_table.py new file mode 100644 index 00000000..f80cf583 --- /dev/null +++ b/tests/unit/sql_completion/test_create_table.py @@ -0,0 +1,164 @@ +"""Tests for CREATE TABLE statement autocomplete suggestions.""" + +import pytest + +from sqlit.sql_completion import ( + SQL_CONSTRAINTS, + SQL_DATA_TYPES, + get_completions, +) + + +class TestCreateTableStatements: + """Tests for CREATE TABLE autocomplete suggestions.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "orders", "products"], + "columns": { + "users": ["id", "name", "email"], + "orders": ["id", "user_id", "total"], + "products": ["id", "name", "price"], + }, + "procedures": [], + } + + def test_create_table_column_name_no_suggestions(self, schema): + """After CREATE TABLE name (, should not suggest (user types column name).""" + sql = "CREATE TABLE new_table (" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + # No suggestions - user needs to type column name + assert completions == [] + + def test_create_table_after_column_name_suggests_types(self, schema): + """After column name, should suggest data types.""" + sql = "CREATE TABLE new_table (id " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "INT" in completions + assert "VARCHAR" in completions + assert "TEXT" in completions + assert "BOOLEAN" in completions + + def test_create_table_partial_type(self, schema): + """Typing partial data type should filter.""" + sql = "CREATE TABLE new_table (id INT" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "INT" in completions + assert "INTEGER" in completions + + def test_create_table_after_type_suggests_constraints(self, schema): + """After data type, should suggest constraints.""" + sql = "CREATE TABLE new_table (id INT " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "PRIMARY KEY" in completions + assert "NOT NULL" in completions + assert "UNIQUE" in completions + assert "DEFAULT" in completions + + def test_create_table_after_type_with_size_suggests_constraints(self, schema): + """After data type with size, should suggest constraints.""" + sql = "CREATE TABLE new_table (name VARCHAR(255) " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "NOT NULL" in completions + assert "DEFAULT" in completions + + def test_create_table_partial_constraint(self, schema): + """Typing partial constraint should filter.""" + sql = "CREATE TABLE new_table (id INT NOT" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "NOT NULL" in completions + + def test_create_table_second_column_no_suggestions(self, schema): + """After comma, should not suggest (user types column name).""" + sql = "CREATE TABLE new_table (id INT PRIMARY KEY, " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert completions == [] + + def test_create_table_second_column_type(self, schema): + """Second column should also get type suggestions.""" + sql = "CREATE TABLE new_table (id INT, name " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "VARCHAR" in completions + assert "TEXT" in completions + + def test_create_table_references_suggests_tables(self, schema): + """REFERENCES should suggest table names.""" + sql = "CREATE TABLE new_table (user_id INT REFERENCES " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "users" in completions + assert "orders" in completions + + def test_create_table_references_table_paren_suggests_columns(self, schema): + """REFERENCES table( should suggest columns from that table.""" + sql = "CREATE TABLE new_table (user_id INT REFERENCES users(" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + assert "name" in completions + assert "email" in completions + + def test_create_table_references_partial_column(self, schema): + """Typing partial column in REFERENCES should filter.""" + sql = "CREATE TABLE new_table (user_id INT REFERENCES users(i" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + + def test_create_table_foreign_key_references_suggests_tables(self, schema): + """FOREIGN KEY ... REFERENCES should suggest tables.""" + sql = "CREATE TABLE new_table (id INT, FOREIGN KEY (user_id) REFERENCES " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "users" in completions + + def test_create_table_multiline(self, schema): + """Should work with multiline CREATE TABLE.""" + sql = """CREATE TABLE new_table ( + id INT PRIMARY KEY, + name VARCHAR(255) NOT NULL, + user_id INT REFERENCES """ + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "users" in completions + + +class TestDataTypesAndConstraints: + """Tests for data type and constraint lists.""" + + def test_data_types_not_empty(self): + """Should have data types defined.""" + assert len(SQL_DATA_TYPES) > 20 + assert "INT" in SQL_DATA_TYPES + assert "VARCHAR" in SQL_DATA_TYPES + assert "BOOLEAN" in SQL_DATA_TYPES + + def test_constraints_not_empty(self): + """Should have constraints defined.""" + assert len(SQL_CONSTRAINTS) > 5 + assert "PRIMARY KEY" in SQL_CONSTRAINTS + assert "NOT NULL" in SQL_CONSTRAINTS + assert "UNIQUE" in SQL_CONSTRAINTS diff --git a/tests/unit/sql_completion/test_create_view.py b/tests/unit/sql_completion/test_create_view.py new file mode 100644 index 00000000..296a190b --- /dev/null +++ b/tests/unit/sql_completion/test_create_view.py @@ -0,0 +1,93 @@ +"""Tests for CREATE VIEW statement autocomplete suggestions.""" + +import pytest + +from sqlit.sql_completion import get_completions + + +class TestCreateViewStatements: + """Tests for CREATE VIEW autocomplete suggestions.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "orders", "products"], + "columns": { + "users": ["id", "name", "email"], + "orders": ["id", "user_id", "total"], + "products": ["id", "name", "price"], + }, + "procedures": [], + } + + def test_create_view_as_suggests_select(self, schema): + """CREATE VIEW name AS should suggest SELECT.""" + sql = "CREATE VIEW user_emails AS " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "SELECT" in completions + + def test_create_view_as_partial_select(self, schema): + """Typing partial SELECT should filter.""" + sql = "CREATE VIEW user_emails AS SEL" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "SELECT" in completions + + def test_create_or_replace_view_as_suggests_select(self, schema): + """CREATE OR REPLACE VIEW name AS should suggest SELECT.""" + sql = "CREATE OR REPLACE VIEW user_emails AS " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "SELECT" in completions + + def test_create_view_after_name_suggests_as(self, schema): + """After view name, should suggest AS.""" + sql = "CREATE VIEW user_emails " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "AS" in completions + + def test_create_view_select_from_suggests_tables(self, schema): + """CREATE VIEW ... AS SELECT ... FROM should suggest tables.""" + sql = "CREATE VIEW user_emails AS SELECT * FROM " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "users" in completions + assert "orders" in completions + + def test_create_view_select_suggests_columns(self, schema): + """CREATE VIEW ... AS SELECT should suggest columns.""" + sql = "CREATE VIEW user_emails AS SELECT " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + # Should fall through to normal SELECT handling + # Which suggests special SELECT keywords and functions (not tables) + assert "*" in completions # SELECT clause special keyword + assert "DISTINCT" in completions # SELECT clause special keyword + + def test_create_view_select_from_table_where(self, schema): + """CREATE VIEW with full SELECT should work normally.""" + sql = "CREATE VIEW active_users AS SELECT * FROM users WHERE " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + # Should suggest columns from users table + assert "id" in completions + assert "name" in completions + assert "email" in completions + + def test_create_or_replace_view_after_name_suggests_as(self, schema): + """After OR REPLACE VIEW name, should suggest AS.""" + sql = "CREATE OR REPLACE VIEW user_emails " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "AS" in completions diff --git a/tests/unit/sql_completion/test_delete.py b/tests/unit/sql_completion/test_delete.py new file mode 100644 index 00000000..19e7b0d1 --- /dev/null +++ b/tests/unit/sql_completion/test_delete.py @@ -0,0 +1,177 @@ +"""Tests for DELETE statement autocomplete suggestions.""" + +import pytest + +from sqlit.sql_completion import ( + SuggestionType, + extract_delete_table_refs, + get_completions, + get_context, +) + + +class TestDeleteStatements: + """Tests for DELETE statement autocomplete suggestions.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "orders", "products"], + "columns": { + "users": ["id", "name", "email", "created_at", "status"], + "orders": ["id", "user_id", "total", "status"], + "products": ["id", "name", "price", "category"], + }, + "procedures": [], + } + + def test_delete_from_suggests_tables(self, schema): + """After DELETE FROM, should suggest table names.""" + sql = "DELETE FROM " + suggestions = get_context(sql, len(sql)) + assert any(s.type == SuggestionType.TABLE for s in suggestions) + + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "users" in completions + assert "orders" in completions + + def test_delete_from_partial_table(self, schema): + """Typing partial table name after DELETE FROM.""" + sql = "DELETE FROM us" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "users" in completions + + def test_delete_where_suggests_columns(self, schema): + """After DELETE FROM table WHERE, should suggest columns.""" + sql = "DELETE FROM users WHERE " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + assert "name" in completions + assert "status" in completions + + def test_delete_where_partial_column(self, schema): + """Typing partial column in WHERE clause.""" + sql = "DELETE FROM users WHERE st" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "status" in completions + + def test_delete_where_with_alias_dot(self, schema): + """DELETE with alias should suggest columns via alias.""" + sql = "DELETE FROM users u WHERE u." + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + assert "name" in completions + assert "status" in completions + + def test_delete_where_operator_after_column(self, schema): + """After column name in DELETE WHERE, should suggest operators.""" + sql = "DELETE FROM users WHERE id " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "=" in completions + assert "IN" in completions + assert "IS NULL" in completions + + def test_delete_where_and_suggests_columns(self, schema): + """After AND in DELETE WHERE, should suggest columns.""" + sql = "DELETE FROM users WHERE id = 1 AND " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "name" in completions + assert "status" in completions + + def test_delete_no_suggestions_in_string(self, schema): + """Inside string literal, should NOT suggest anything.""" + sql = "DELETE FROM users WHERE name = '" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert completions == [] + + def test_delete_subquery_from_suggests_tables(self, schema): + """DELETE with subquery FROM should suggest tables.""" + sql = "DELETE FROM users WHERE id IN (SELECT user_id FROM " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "orders" in completions + + def test_delete_join_suggests_tables(self, schema): + """DELETE with JOIN should suggest tables (MySQL/SQL Server style).""" + sql = "DELETE u FROM users u JOIN " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "orders" in completions + + def test_delete_using_suggests_tables(self, schema): + """DELETE USING should suggest tables (PostgreSQL style).""" + sql = "DELETE FROM users USING " + suggestions = get_context(sql, len(sql)) + # USING should trigger table suggestions (falls back to keyword detection) + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + # Should have some completions + assert len(completions) > 0 + + +class TestExtractDeleteTableRefs: + """Tests for DELETE table reference extraction.""" + + def test_simple_delete(self): + """Extract table from simple DELETE.""" + sql = "DELETE FROM users WHERE id = 1" + refs = extract_delete_table_refs(sql) + assert len(refs) == 1 + assert refs[0].name == "users" + + def test_delete_with_alias(self): + """Extract table with alias from DELETE.""" + sql = "DELETE FROM users u WHERE u.id = 1" + refs = extract_delete_table_refs(sql) + assert len(refs) == 1 + assert refs[0].name == "users" + assert refs[0].alias == "u" + + def test_delete_with_schema(self): + """Extract schema-qualified table from DELETE.""" + sql = "DELETE FROM dbo.users WHERE id = 1" + refs = extract_delete_table_refs(sql) + assert len(refs) == 1 + assert refs[0].schema == "dbo" + assert refs[0].name == "users" + + def test_delete_quoted_table(self): + """Extract quoted table from DELETE.""" + sql = 'DELETE FROM "user_accounts" WHERE id = 1' + refs = extract_delete_table_refs(sql) + assert len(refs) == 1 + assert refs[0].name == "user_accounts" + + def test_delete_no_where(self): + """Extract table from DELETE without WHERE.""" + sql = "DELETE FROM temp_data" + refs = extract_delete_table_refs(sql) + assert len(refs) == 1 + assert refs[0].name == "temp_data" + + def test_delete_reserved_word_not_alias(self): + """Reserved words should not be captured as aliases.""" + sql = "DELETE FROM users WHERE id = 1" + refs = extract_delete_table_refs(sql) + assert len(refs) == 1 + assert refs[0].alias is None diff --git a/tests/unit/sql_completion/test_drop.py b/tests/unit/sql_completion/test_drop.py new file mode 100644 index 00000000..ea53d8a8 --- /dev/null +++ b/tests/unit/sql_completion/test_drop.py @@ -0,0 +1,167 @@ +"""Tests for DROP statement autocomplete suggestions.""" + +import pytest + +from sqlit.sql_completion import ( + DROP_OBJECTS, + get_completions, +) + + +class TestDropStatements: + """Tests for DROP autocomplete suggestions.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "orders", "products"], + "columns": { + "users": ["id", "name", "email"], + "orders": ["id", "user_id", "total"], + "products": ["id", "name", "price"], + }, + "procedures": ["get_user", "update_order", "calculate_total"], + } + + def test_drop_suggests_object_types(self, schema): + """DROP should suggest object types.""" + sql = "DROP " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "TABLE" in completions + assert "VIEW" in completions + assert "INDEX" in completions + assert "DATABASE" in completions + assert "PROCEDURE" in completions + + def test_drop_partial_object_type(self, schema): + """Typing partial object type should filter.""" + sql = "DROP TAB" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "TABLE" in completions + assert "VIEW" not in completions + + def test_drop_table_suggests_tables(self, schema): + """DROP TABLE should suggest table names.""" + sql = "DROP TABLE " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "users" in completions + assert "orders" in completions + assert "products" in completions + + def test_drop_table_partial_name(self, schema): + """Typing partial table name should filter.""" + sql = "DROP TABLE us" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "users" in completions + assert "orders" not in completions + + def test_drop_table_if_exists_suggests_tables(self, schema): + """DROP TABLE IF EXISTS should suggest table names.""" + sql = "DROP TABLE IF EXISTS " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "users" in completions + assert "orders" in completions + + def test_drop_table_if_exists_partial(self, schema): + """DROP TABLE IF EXISTS with partial name should filter.""" + sql = "DROP TABLE IF EXISTS ord" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "orders" in completions + assert "users" not in completions + + def test_drop_view_suggests_tables(self, schema): + """DROP VIEW should suggest tables (views mixed with tables).""" + sql = "DROP VIEW " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "users" in completions + assert "orders" in completions + + def test_drop_view_if_exists(self, schema): + """DROP VIEW IF EXISTS should suggest tables.""" + sql = "DROP VIEW IF EXISTS " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "users" in completions + + def test_drop_procedure_suggests_procedures(self, schema): + """DROP PROCEDURE should suggest procedure names.""" + sql = "DROP PROCEDURE " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "get_user" in completions + assert "update_order" in completions + assert "calculate_total" in completions + + def test_drop_procedure_partial_name(self, schema): + """Typing partial procedure name should filter.""" + sql = "DROP PROCEDURE get" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "get_user" in completions + assert "update_order" not in completions + + def test_drop_procedure_if_exists(self, schema): + """DROP PROCEDURE IF EXISTS should suggest procedures.""" + sql = "DROP PROCEDURE IF EXISTS " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "get_user" in completions + assert "update_order" in completions + + def test_drop_function_suggests_procedures(self, schema): + """DROP FUNCTION should suggest procedures (treated same as procedures).""" + sql = "DROP FUNCTION " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "get_user" in completions + assert "calculate_total" in completions + + def test_drop_function_if_exists(self, schema): + """DROP FUNCTION IF EXISTS should suggest procedures.""" + sql = "DROP FUNCTION IF EXISTS " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "get_user" in completions + + def test_drop_no_procedures(self, schema): + """DROP PROCEDURE with no procedures should return empty.""" + sql = "DROP PROCEDURE " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], None + ) + assert completions == [] + + +class TestDropObjects: + """Tests for DROP_OBJECTS list.""" + + def test_drop_objects_not_empty(self): + """Should have DROP object types defined.""" + assert len(DROP_OBJECTS) > 5 + assert "TABLE" in DROP_OBJECTS + assert "VIEW" in DROP_OBJECTS + assert "INDEX" in DROP_OBJECTS + assert "DATABASE" in DROP_OBJECTS + assert "PROCEDURE" in DROP_OBJECTS + assert "FUNCTION" in DROP_OBJECTS diff --git a/tests/unit/sql_completion/test_edge_cases.py b/tests/unit/sql_completion/test_edge_cases.py new file mode 100644 index 00000000..ac228243 --- /dev/null +++ b/tests/unit/sql_completion/test_edge_cases.py @@ -0,0 +1,1103 @@ +"""Tests for SQL edge cases and advanced patterns.""" + +import pytest + +from sqlit.sql_completion import get_completions + + +class TestSelectDistinct: + """Tests for SELECT DISTINCT autocomplete.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "orders", "products"], + "columns": { + "users": ["id", "name", "email"], + "orders": ["id", "user_id", "total"], + "products": ["id", "name", "price"], + }, + "procedures": [], + } + + def test_select_distinct_suggests_columns(self, schema): + """SELECT DISTINCT should suggest special keywords and functions.""" + sql = "SELECT DISTINCT " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + # Should suggest * and functions (not tables - those go after FROM) + assert "*" in completions + # Check for any aggregate function (they're all present but order varies) + has_function = any(f in completions for f in ["COUNT", "SUM", "AVG", "MIN", "MAX"]) + assert has_function + + def test_select_distinct_from_table(self, schema): + """SELECT DISTINCT with FROM should suggest columns.""" + sql = "SELECT DISTINCT FROM users" + # Cursor after DISTINCT and space + cursor_pos = len("SELECT DISTINCT ") + completions = get_completions( + sql, cursor_pos, schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + assert "name" in completions + + def test_select_distinct_partial(self, schema): + """SELECT DISTINCT with partial column should filter.""" + sql = "SELECT DISTINCT na FROM users" + cursor_pos = len("SELECT DISTINCT na") + completions = get_completions( + sql, cursor_pos, schema["tables"], schema["columns"], schema["procedures"] + ) + assert "name" in completions + + +class TestCaseWhen: + """Tests for CASE WHEN expression autocomplete.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "orders"], + "columns": { + "users": ["id", "name", "status", "active"], + "orders": ["id", "user_id", "total", "status"], + }, + "procedures": [], + } + + def test_case_when_suggests_columns(self, schema): + """CASE WHEN should suggest columns.""" + sql = "SELECT CASE WHEN FROM users" + cursor_pos = len("SELECT CASE WHEN ") + completions = get_completions( + sql, cursor_pos, schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + assert "status" in completions + + def test_case_when_then_suggests_values(self, schema): + """CASE WHEN condition THEN should suggest columns/values.""" + sql = "SELECT CASE WHEN status = 1 THEN FROM users" + cursor_pos = len("SELECT CASE WHEN status = 1 THEN ") + completions = get_completions( + sql, cursor_pos, schema["tables"], schema["columns"], schema["procedures"] + ) + # THEN can be followed by a value or column + assert len(completions) > 0 + + def test_case_when_else_suggests_values(self, schema): + """CASE WHEN ... ELSE should suggest columns/values.""" + sql = "SELECT CASE WHEN status = 1 THEN 'Active' ELSE FROM users" + cursor_pos = len("SELECT CASE WHEN status = 1 THEN 'Active' ELSE ") + completions = get_completions( + sql, cursor_pos, schema["tables"], schema["columns"], schema["procedures"] + ) + assert len(completions) > 0 + + def test_case_when_in_where(self, schema): + """CASE WHEN in WHERE clause should suggest columns.""" + sql = "SELECT * FROM users WHERE CASE WHEN " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + assert "status" in completions + + +class TestWindowFunctions: + """Tests for window function (OVER clause) autocomplete.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["employees", "departments"], + "columns": { + "employees": ["id", "name", "dept_id", "salary", "hire_date"], + "departments": ["id", "name", "budget"], + }, + "procedures": [], + } + + def test_over_partition_by_suggests_columns(self, schema): + """OVER (PARTITION BY should suggest columns.""" + sql = "SELECT ROW_NUMBER() OVER (PARTITION BY FROM employees" + cursor_pos = len("SELECT ROW_NUMBER() OVER (PARTITION BY ") + completions = get_completions( + sql, cursor_pos, schema["tables"], schema["columns"], schema["procedures"] + ) + assert "dept_id" in completions + assert "name" in completions + + def test_over_order_by_suggests_columns(self, schema): + """OVER (ORDER BY should suggest columns.""" + sql = "SELECT ROW_NUMBER() OVER (ORDER BY FROM employees" + cursor_pos = len("SELECT ROW_NUMBER() OVER (ORDER BY ") + completions = get_completions( + sql, cursor_pos, schema["tables"], schema["columns"], schema["procedures"] + ) + assert "salary" in completions + assert "hire_date" in completions + + def test_over_partition_by_order_by(self, schema): + """OVER (PARTITION BY x ORDER BY should suggest columns.""" + sql = "SELECT ROW_NUMBER() OVER (PARTITION BY dept_id ORDER BY FROM employees" + cursor_pos = len("SELECT ROW_NUMBER() OVER (PARTITION BY dept_id ORDER BY ") + completions = get_completions( + sql, cursor_pos, schema["tables"], schema["columns"], schema["procedures"] + ) + assert "salary" in completions + + def test_over_with_join(self, schema): + """Window function with JOIN should suggest columns from both tables.""" + sql = "SELECT ROW_NUMBER() OVER (PARTITION BY FROM employees e JOIN departments d ON e.dept_id = d.id" + cursor_pos = len("SELECT ROW_NUMBER() OVER (PARTITION BY ") + completions = get_completions( + sql, cursor_pos, schema["tables"], schema["columns"], schema["procedures"] + ) + # Should have columns from employees + assert "dept_id" in completions or "salary" in completions + + +class TestDerivedTableAliases: + """Tests for derived table (subquery) alias autocomplete.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "orders"], + "columns": { + "users": ["id", "name", "email"], + "orders": ["id", "user_id", "total"], + }, + "procedures": [], + } + + def test_derived_table_alias_dot(self, schema): + """Alias for derived table should suggest columns.""" + sql = "SELECT u. FROM (SELECT id, name FROM users) AS u" + cursor_pos = len("SELECT u.") + completions = get_completions( + sql, cursor_pos, schema["tables"], schema["columns"], schema["procedures"] + ) + # This is tricky - would need to parse the subquery + # For now, at minimum shouldn't error + assert isinstance(completions, list) + + def test_derived_table_where_alias(self, schema): + """WHERE clause with derived table alias.""" + sql = "SELECT * FROM (SELECT id, name FROM users) AS u WHERE u." + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert isinstance(completions, list) + + +class TestJoinOnKeyword: + """Tests for JOIN ... ON keyword suggestion.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "orders", "products"], + "columns": { + "users": ["id", "name", "email"], + "orders": ["id", "user_id", "total"], + "products": ["id", "name", "price"], + }, + "procedures": [], + } + + def test_join_suggests_on(self, schema): + """After JOIN table, should suggest ON keyword.""" + sql = "SELECT * FROM users JOIN orders " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "ON" in completions + + def test_left_join_suggests_on(self, schema): + """After LEFT JOIN table, should suggest ON.""" + sql = "SELECT * FROM users LEFT JOIN orders " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "ON" in completions + + def test_inner_join_suggests_on(self, schema): + """After INNER JOIN table, should suggest ON.""" + sql = "SELECT * FROM users INNER JOIN orders " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "ON" in completions + + def test_join_with_alias_suggests_on(self, schema): + """After JOIN table alias, should suggest ON.""" + sql = "SELECT * FROM users u JOIN orders o " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "ON" in completions + + def test_join_on_partial(self, schema): + """Typing partial ON should filter.""" + sql = "SELECT * FROM users JOIN orders O" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "ON" in completions + + +class TestUnionContext: + """Tests for UNION/INTERSECT/EXCEPT autocomplete.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "admins", "guests"], + "columns": { + "users": ["id", "name", "email"], + "admins": ["id", "name", "role"], + "guests": ["id", "name", "expires"], + }, + "procedures": [], + } + + def test_union_suggests_select(self, schema): + """After UNION, should suggest SELECT.""" + sql = "SELECT * FROM users UNION " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "SELECT" in completions + + def test_union_all_suggests_select(self, schema): + """After UNION ALL, should suggest SELECT.""" + sql = "SELECT * FROM users UNION ALL " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "SELECT" in completions + + def test_intersect_suggests_select(self, schema): + """After INTERSECT, should suggest SELECT.""" + sql = "SELECT * FROM users INTERSECT " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "SELECT" in completions + + def test_except_suggests_select(self, schema): + """After EXCEPT, should suggest SELECT.""" + sql = "SELECT * FROM users EXCEPT " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "SELECT" in completions + + def test_union_partial_select(self, schema): + """Typing partial SELECT after UNION should filter.""" + sql = "SELECT * FROM users UNION SEL" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "SELECT" in completions + + +class TestBetweenContext: + """Tests for BETWEEN clause autocomplete.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "products"], + "columns": { + "users": ["id", "name", "age", "created_at"], + "products": ["id", "name", "price", "stock"], + }, + "procedures": [], + } + + def test_between_suggests_columns(self, schema): + """After BETWEEN, should suggest columns/values.""" + sql = "SELECT * FROM users WHERE age BETWEEN " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + # Should suggest columns (for comparing with another column) or allow typing value + assert "id" in completions or len(completions) > 0 + + def test_between_and_suggests_columns(self, schema): + """After BETWEEN x AND, should suggest columns/values.""" + sql = "SELECT * FROM users WHERE age BETWEEN 18 AND " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + # AND in BETWEEN context should suggest columns + assert "id" in completions or "age" in completions + + def test_between_with_columns(self, schema): + """BETWEEN with column references.""" + sql = "SELECT * FROM products WHERE price BETWEEN min_price AND " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert len(completions) > 0 + + +class TestComplexSubqueries: + """Tests for complex subquery scenarios.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "orders", "products"], + "columns": { + "users": ["id", "name", "email"], + "orders": ["id", "user_id", "product_id", "total"], + "products": ["id", "name", "price"], + }, + "procedures": [], + } + + def test_nested_subquery_from(self, schema): + """Nested subquery FROM should suggest tables.""" + sql = "SELECT * FROM (SELECT * FROM (SELECT * FROM " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "users" in completions + assert "orders" in completions + + def test_correlated_subquery_where(self, schema): + """Correlated subquery in WHERE.""" + sql = "SELECT * FROM users u WHERE EXISTS (SELECT 1 FROM orders o WHERE o.user_id = u." + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + # Should suggest users columns for u. + assert "id" in completions + + def test_subquery_in_select_list(self, schema): + """Subquery in SELECT list.""" + sql = "SELECT id, (SELECT COUNT(*) FROM orders WHERE user_id = users." + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + + +class TestAggregateFunctions: + """Tests for aggregate function argument autocomplete.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "orders", "products"], + "columns": { + "users": ["id", "name", "email", "age"], + "orders": ["id", "user_id", "total", "quantity"], + "products": ["id", "name", "price", "stock"], + }, + "procedures": [], + } + + def test_count_suggests_columns(self, schema): + """COUNT( should suggest columns from tables in query.""" + sql = "SELECT COUNT( FROM users" + cursor_pos = len("SELECT COUNT(") + completions = get_completions( + sql, cursor_pos, schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + assert "name" in completions + + def test_sum_suggests_columns(self, schema): + """SUM( should suggest columns.""" + sql = "SELECT SUM( FROM orders" + cursor_pos = len("SELECT SUM(") + completions = get_completions( + sql, cursor_pos, schema["tables"], schema["columns"], schema["procedures"] + ) + assert "total" in completions + assert "quantity" in completions + + def test_avg_suggests_columns(self, schema): + """AVG( should suggest columns.""" + sql = "SELECT AVG( FROM products" + cursor_pos = len("SELECT AVG(") + completions = get_completions( + sql, cursor_pos, schema["tables"], schema["columns"], schema["procedures"] + ) + assert "price" in completions + assert "stock" in completions + + def test_max_suggests_columns(self, schema): + """MAX( should suggest columns.""" + sql = "SELECT MAX( FROM users" + cursor_pos = len("SELECT MAX(") + completions = get_completions( + sql, cursor_pos, schema["tables"], schema["columns"], schema["procedures"] + ) + assert "age" in completions + + def test_min_suggests_columns(self, schema): + """MIN( should suggest columns.""" + sql = "SELECT MIN( FROM orders" + cursor_pos = len("SELECT MIN(") + completions = get_completions( + sql, cursor_pos, schema["tables"], schema["columns"], schema["procedures"] + ) + assert "total" in completions + + def test_count_with_alias(self, schema): + """COUNT( with table alias should suggest columns.""" + sql = "SELECT COUNT(u. FROM users u" + cursor_pos = len("SELECT COUNT(u.") + completions = get_completions( + sql, cursor_pos, schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + assert "name" in completions + + def test_aggregate_in_having(self, schema): + """Aggregate in HAVING should suggest columns.""" + sql = "SELECT dept FROM users GROUP BY dept HAVING COUNT( " + cursor_pos = len("SELECT dept FROM users GROUP BY dept HAVING COUNT(") + completions = get_completions( + sql, cursor_pos, schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + + +class TestCastExpression: + """Tests for CAST expression autocomplete.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "orders"], + "columns": { + "users": ["id", "name", "age"], + "orders": ["id", "total", "created_at"], + }, + "procedures": [], + } + + def test_cast_as_suggests_types(self, schema): + """CAST(col AS should suggest data types.""" + sql = "SELECT CAST(id AS FROM users" + cursor_pos = len("SELECT CAST(id AS ") + completions = get_completions( + sql, cursor_pos, schema["tables"], schema["columns"], schema["procedures"] + ) + assert "INT" in completions or "INTEGER" in completions + assert "VARCHAR" in completions + + def test_cast_column_suggests_columns(self, schema): + """CAST( should suggest columns.""" + sql = "SELECT CAST( FROM users" + cursor_pos = len("SELECT CAST(") + completions = get_completions( + sql, cursor_pos, schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + assert "name" in completions + + def test_convert_type_suggests_types(self, schema): + """CONVERT with type argument should suggest types (SQL Server style).""" + sql = "SELECT CONVERT( " + cursor_pos = len("SELECT CONVERT(") + completions = get_completions( + sql, cursor_pos, schema["tables"], schema["columns"], schema["procedures"] + ) + # CONVERT first arg is usually type in SQL Server + assert "INT" in completions or "VARCHAR" in completions or len(completions) > 0 + + +class TestCrossJoin: + """Tests for CROSS JOIN (should not suggest ON).""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "orders", "products"], + "columns": { + "users": ["id", "name"], + "orders": ["id", "user_id"], + "products": ["id", "name"], + }, + "procedures": [], + } + + def test_cross_join_no_on(self, schema): + """After CROSS JOIN table, should NOT suggest ON.""" + sql = "SELECT * FROM users CROSS JOIN orders " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + # CROSS JOIN doesn't use ON - should suggest WHERE, ORDER BY, etc. + assert "ON" not in completions + + def test_cross_join_suggests_where(self, schema): + """After CROSS JOIN table, should suggest WHERE.""" + sql = "SELECT * FROM users CROSS JOIN orders " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "WHERE" in completions or "ORDER" in completions or len(completions) > 0 + + def test_natural_join_no_on(self, schema): + """After NATURAL JOIN table, should NOT suggest ON.""" + sql = "SELECT * FROM users NATURAL JOIN orders " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + # NATURAL JOIN doesn't use ON + assert "ON" not in completions + + +class TestSchemaPrefix: + """Tests for schema.table prefix autocomplete.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "orders", "products"], + "columns": { + "users": ["id", "name"], + "orders": ["id", "user_id"], + "products": ["id", "name"], + }, + "procedures": [], + } + + def test_schema_dot_suggests_tables(self, schema): + """After schema., should suggest tables.""" + sql = "SELECT * FROM public." + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "users" in completions + assert "orders" in completions + + def test_schema_dot_partial_table(self, schema): + """After schema. with partial table, should filter.""" + sql = "SELECT * FROM public.us" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "users" in completions + + def test_schema_dot_in_join(self, schema): + """Schema prefix in JOIN should suggest tables.""" + sql = "SELECT * FROM users JOIN dbo." + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "orders" in completions + + +class TestInClause: + """Tests for IN clause autocomplete.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "orders", "products"], + "columns": { + "users": ["id", "name", "email", "status"], + "orders": ["id", "user_id", "total"], + "products": ["id", "name", "price"], + }, + "procedures": [], + } + + def test_in_suggests_select(self, schema): + """WHERE col IN ( should suggest SELECT for subquery.""" + sql = "SELECT * FROM users WHERE id IN (" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "SELECT" in completions + + def test_in_with_partial_select(self, schema): + """WHERE col IN (SEL should filter to SELECT.""" + sql = "SELECT * FROM users WHERE id IN (SEL" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "SELECT" in completions + + def test_not_in_suggests_select(self, schema): + """WHERE col NOT IN ( should suggest SELECT.""" + sql = "SELECT * FROM users WHERE id NOT IN (" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "SELECT" in completions + + def test_in_subquery_select_columns(self, schema): + """IN (SELECT should suggest columns.""" + sql = "SELECT * FROM users WHERE id IN (SELECT " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + # After SELECT in subquery, should suggest columns/tables + assert len(completions) > 0 + + +class TestExistsClause: + """Tests for EXISTS clause autocomplete.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "orders", "products"], + "columns": { + "users": ["id", "name", "email"], + "orders": ["id", "user_id", "total"], + "products": ["id", "name", "price"], + }, + "procedures": [], + } + + def test_exists_suggests_select(self, schema): + """WHERE EXISTS ( should suggest SELECT.""" + sql = "SELECT * FROM users WHERE EXISTS (" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "SELECT" in completions + + def test_not_exists_suggests_select(self, schema): + """WHERE NOT EXISTS ( should suggest SELECT.""" + sql = "SELECT * FROM users WHERE NOT EXISTS (" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "SELECT" in completions + + def test_exists_partial_select(self, schema): + """EXISTS (SEL should filter to SELECT.""" + sql = "SELECT * FROM users WHERE EXISTS (SEL" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "SELECT" in completions + + def test_exists_subquery_from(self, schema): + """EXISTS (SELECT 1 FROM should suggest tables.""" + sql = "SELECT * FROM users WHERE EXISTS (SELECT 1 FROM " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "orders" in completions + + +class TestReturningClause: + """Tests for RETURNING clause autocomplete (PostgreSQL, SQLite).""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "orders", "products"], + "columns": { + "users": ["id", "name", "email", "created_at"], + "orders": ["id", "user_id", "total"], + "products": ["id", "name", "price"], + }, + "procedures": [], + } + + def test_insert_returning_suggests_columns(self, schema): + """INSERT ... RETURNING should suggest columns.""" + sql = "INSERT INTO users (name) VALUES ('test') RETURNING " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + assert "name" in completions + + def test_update_returning_suggests_columns(self, schema): + """UPDATE ... RETURNING should suggest columns.""" + sql = "UPDATE users SET name = 'test' RETURNING " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + assert "name" in completions + + def test_delete_returning_suggests_columns(self, schema): + """DELETE ... RETURNING should suggest columns.""" + sql = "DELETE FROM users WHERE id = 1 RETURNING " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + assert "name" in completions + + def test_returning_partial_column(self, schema): + """RETURNING with partial column should filter.""" + sql = "INSERT INTO users (name) VALUES ('test') RETURNING na" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "name" in completions + + def test_returning_multiple_columns(self, schema): + """RETURNING with comma should suggest more columns.""" + sql = "INSERT INTO users (name) VALUES ('test') RETURNING id, " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "name" in completions + assert "email" in completions + + +class TestNestedFunctions: + """Tests for nested function call autocomplete.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "orders"], + "columns": { + "users": ["id", "name", "email", "phone"], + "orders": ["id", "user_id", "total", "discount"], + }, + "procedures": [], + } + + def test_nested_coalesce_nullif(self, schema): + """COALESCE(NULLIF( should suggest columns.""" + sql = "SELECT COALESCE(NULLIF( FROM users" + cursor_pos = len("SELECT COALESCE(NULLIF(") + completions = get_completions( + sql, cursor_pos, schema["tables"], schema["columns"], schema["procedures"] + ) + assert "name" in completions + assert "email" in completions + + def test_nested_ifnull(self, schema): + """IFNULL(TRIM( should suggest columns.""" + sql = "SELECT IFNULL(TRIM( FROM users" + cursor_pos = len("SELECT IFNULL(TRIM(") + completions = get_completions( + sql, cursor_pos, schema["tables"], schema["columns"], schema["procedures"] + ) + assert "name" in completions + + def test_deeply_nested_functions(self, schema): + """Deeply nested functions should suggest columns.""" + sql = "SELECT COALESCE(NULLIF(TRIM( FROM users" + cursor_pos = len("SELECT COALESCE(NULLIF(TRIM(") + completions = get_completions( + sql, cursor_pos, schema["tables"], schema["columns"], schema["procedures"] + ) + assert "name" in completions + + def test_nested_in_where(self, schema): + """Nested functions in WHERE should suggest columns.""" + sql = "SELECT * FROM users WHERE COALESCE(NULLIF(" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "name" in completions or "id" in completions + + +class TestAnyAllSome: + """Tests for ANY/ALL/SOME subquery autocomplete.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "orders", "products"], + "columns": { + "users": ["id", "name", "email"], + "orders": ["id", "user_id", "total"], + "products": ["id", "name", "price"], + }, + "procedures": [], + } + + def test_any_suggests_select(self, schema): + """= ANY ( should suggest SELECT for subquery.""" + sql = "SELECT * FROM users WHERE id = ANY (" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "SELECT" in completions + + def test_all_suggests_select(self, schema): + """= ALL ( should suggest SELECT for subquery.""" + sql = "SELECT * FROM users WHERE id = ALL (" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "SELECT" in completions + + def test_some_suggests_select(self, schema): + """> SOME ( should suggest SELECT for subquery.""" + sql = "SELECT * FROM users WHERE id > SOME (" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "SELECT" in completions + + def test_not_in_any_context(self, schema): + """ANY in non-subquery context should not interfere.""" + sql = "SELECT * FROM users WHERE id = ANY (SELECT " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + # After SELECT in subquery, should suggest columns/tables + assert len(completions) > 0 + + +class TestGroupingSets: + """Tests for GROUPING SETS/CUBE/ROLLUP autocomplete.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["sales", "products"], + "columns": { + "sales": ["id", "product_id", "region", "year", "amount"], + "products": ["id", "name", "category"], + }, + "procedures": [], + } + + def test_grouping_sets_suggests_columns(self, schema): + """GROUP BY GROUPING SETS ( should suggest columns.""" + sql = "SELECT region, year, SUM(amount) FROM sales GROUP BY GROUPING SETS (" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "region" in completions + assert "year" in completions + + def test_cube_suggests_columns(self, schema): + """GROUP BY CUBE ( should suggest columns.""" + sql = "SELECT region, year, SUM(amount) FROM sales GROUP BY CUBE (" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "region" in completions + + def test_rollup_suggests_columns(self, schema): + """GROUP BY ROLLUP ( should suggest columns.""" + sql = "SELECT region, year, SUM(amount) FROM sales GROUP BY ROLLUP (" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "region" in completions + + def test_grouping_sets_partial(self, schema): + """Partial column in GROUPING SETS should filter.""" + sql = "SELECT region, year FROM sales GROUP BY GROUPING SETS (reg" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "region" in completions + + +class TestOverClause: + """Tests for OVER () window clause autocomplete.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["employees"], + "columns": { + "employees": ["id", "name", "dept_id", "salary", "hire_date"], + }, + "procedures": [], + } + + def test_over_paren_suggests_partition_order(self, schema): + """OVER ( should suggest PARTITION BY and ORDER BY.""" + sql = "SELECT ROW_NUMBER() OVER (" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "PARTITION" in completions or "ORDER" in completions + + def test_over_partial_partition(self, schema): + """OVER (PART should filter to PARTITION.""" + sql = "SELECT ROW_NUMBER() OVER (PART" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "PARTITION" in completions + + def test_over_partial_order(self, schema): + """OVER (ORD should filter to ORDER.""" + sql = "SELECT ROW_NUMBER() OVER (ORD" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "ORDER" in completions + + +class TestOrderByModifiers: + """Tests for ORDER BY modifiers (ASC/DESC/NULLS).""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "orders"], + "columns": { + "users": ["id", "name", "email", "created_at"], + "orders": ["id", "user_id", "total"], + }, + "procedures": [], + } + + def test_order_by_column_suggests_asc_desc(self, schema): + """After ORDER BY column, should suggest ASC/DESC.""" + sql = "SELECT * FROM users ORDER BY name " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "ASC" in completions + assert "DESC" in completions + + def test_order_by_suggests_nulls(self, schema): + """After ORDER BY column, should suggest NULLS.""" + sql = "SELECT * FROM users ORDER BY name " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "NULLS" in completions + + def test_nulls_suggests_first_last(self, schema): + """After NULLS, should suggest FIRST/LAST.""" + sql = "SELECT * FROM users ORDER BY name NULLS " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "FIRST" in completions + assert "LAST" in completions + + def test_order_by_asc_then_comma(self, schema): + """After ORDER BY col ASC, comma should suggest columns.""" + sql = "SELECT * FROM users ORDER BY name ASC, " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + assert "email" in completions + + def test_order_by_partial_asc(self, schema): + """Typing partial ASC should filter.""" + sql = "SELECT * FROM users ORDER BY name A" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "ASC" in completions + + +class TestCaseExpression: + """Tests for CASE expression autocomplete.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "orders"], + "columns": { + "users": ["id", "name", "status", "type"], + "orders": ["id", "user_id", "total", "status"], + }, + "procedures": [], + } + + def test_case_suggests_when(self, schema): + """CASE should suggest WHEN.""" + sql = "SELECT CASE " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "WHEN" in completions + + def test_case_column_suggests_when(self, schema): + """CASE column should suggest WHEN.""" + sql = "SELECT CASE status " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "WHEN" in completions + + def test_case_end_suggests_as(self, schema): + """CASE ... END should suggest AS for alias.""" + sql = "SELECT CASE WHEN status = 1 THEN 'Active' END " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + # After END, common to add alias or comma + assert "AS" in completions or len(completions) > 0 + + +class TestSemicolonBehavior: + """Tests for statement terminator (semicolon) behavior.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "orders", "tradition_foods"], + "columns": { + "users": ["id", "name", "email"], + "orders": ["id", "user_id", "total"], + "tradition_foods": ["id", "name", "origin"], + }, + "procedures": [], + } + + def test_after_semicolon_no_suggestions(self, schema): + """After a semicolon, autocomplete should hide (no suggestions).""" + sql = "SELECT * FROM tradition_foods;" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert completions == [], f"Expected no suggestions after semicolon, got {completions}" + + def test_after_semicolon_with_space_no_suggestions(self, schema): + """After semicolon and space, autocomplete should hide (no suggestions).""" + sql = "SELECT * FROM users; " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert completions == [], f"Expected no suggestions after semicolon, got {completions}" + + def test_after_semicolon_typing_new_statement(self, schema): + """After semicolon, typing a new keyword should show keyword completions.""" + sql = "SELECT * FROM users; SEL" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "SELECT" in completions, "Should suggest SELECT when typing new statement" diff --git a/tests/unit/sql_completion/test_insert.py b/tests/unit/sql_completion/test_insert.py new file mode 100644 index 00000000..c5909e35 --- /dev/null +++ b/tests/unit/sql_completion/test_insert.py @@ -0,0 +1,106 @@ +"""Tests for INSERT statement autocomplete suggestions.""" + +import pytest + +from sqlit.sql_completion import ( + SuggestionType, + get_completions, + get_context, +) + + +class TestInsertStatements: + """Tests for INSERT statement autocomplete suggestions.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "orders", "products"], + "columns": { + "users": ["id", "name", "email", "created_at"], + "orders": ["id", "user_id", "total", "status"], + "products": ["id", "name", "price", "category"], + }, + "procedures": [], + } + + def test_insert_into_suggests_tables(self, schema): + """After INSERT INTO, should suggest table names.""" + sql = "INSERT INTO " + suggestions = get_context(sql, len(sql)) + assert any(s.type == SuggestionType.TABLE for s in suggestions) + + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "users" in completions + assert "orders" in completions + + def test_insert_into_partial_table(self, schema): + """Typing partial table name after INSERT INTO.""" + sql = "INSERT INTO us" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "users" in completions + + def test_insert_into_table_opening_paren_suggests_columns(self, schema): + """After INSERT INTO table (, should suggest columns for that table.""" + sql = "INSERT INTO users (" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + assert "name" in completions + assert "email" in completions + + def test_insert_into_table_comma_suggests_more_columns(self, schema): + """After INSERT INTO table (col1, should suggest more columns.""" + sql = "INSERT INTO users (id, " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "name" in completions + assert "email" in completions + + def test_insert_into_partial_column(self, schema): + """Typing partial column name in INSERT column list.""" + sql = "INSERT INTO users (na" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "name" in completions + + def test_insert_values_no_column_suggestions(self, schema): + """Inside VALUES clause, should NOT suggest columns.""" + sql = "INSERT INTO users (id, name) VALUES (" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" not in completions or len(completions) == 0 + + def test_insert_values_string_literal_no_suggestions(self, schema): + """Inside string literal in VALUES, should NOT suggest anything.""" + sql = "INSERT INTO users (id, name) VALUES (1, '" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert completions == [] + + def test_insert_select_suggests_columns(self, schema): + """INSERT ... SELECT should suggest columns in SELECT clause.""" + sql = "INSERT INTO users (id, name) SELECT " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert len(completions) > 0 + + def test_insert_select_from_suggests_tables(self, schema): + """INSERT ... SELECT ... FROM should suggest tables.""" + sql = "INSERT INTO users (id, name) SELECT id, name FROM " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "orders" in completions + assert "products" in completions diff --git a/tests/unit/sql_completion/test_select.py b/tests/unit/sql_completion/test_select.py new file mode 100644 index 00000000..b49a0f15 --- /dev/null +++ b/tests/unit/sql_completion/test_select.py @@ -0,0 +1,347 @@ +"""Tests for SELECT statement autocomplete suggestions.""" + +import pytest + +from sqlit.sql_completion import ( + SuggestionType, + get_completions, + get_context, +) + + +class TestSelectContext: + """Tests for SELECT context detection.""" + + def test_after_from(self): + """After FROM should suggest tables.""" + sql = "SELECT * FROM " + suggestions = get_context(sql, len(sql)) + assert any(s.type == SuggestionType.TABLE for s in suggestions) + + def test_after_join(self): + """After JOIN should suggest tables.""" + sql = "SELECT * FROM users u JOIN " + suggestions = get_context(sql, len(sql)) + assert any(s.type == SuggestionType.TABLE for s in suggestions) + + def test_after_left_join(self): + """After LEFT JOIN should suggest tables.""" + sql = "SELECT * FROM users u LEFT JOIN " + suggestions = get_context(sql, len(sql)) + assert any(s.type == SuggestionType.TABLE for s in suggestions) + + def test_after_select(self): + """After SELECT should suggest columns.""" + sql = "SELECT " + suggestions = get_context(sql, len(sql)) + assert any(s.type == SuggestionType.COLUMN for s in suggestions) + + def test_after_where(self): + """After WHERE should suggest columns.""" + sql = "SELECT * FROM users WHERE " + suggestions = get_context(sql, len(sql)) + assert any(s.type == SuggestionType.COLUMN for s in suggestions) + + def test_after_and(self): + """After AND should suggest columns.""" + sql = "SELECT * FROM users WHERE id = 1 AND " + suggestions = get_context(sql, len(sql)) + assert any(s.type == SuggestionType.COLUMN for s in suggestions) + + def test_table_dot_pattern(self): + """table. pattern should suggest columns for that table.""" + sql = "SELECT * FROM users u WHERE u." + suggestions = get_context(sql, len(sql)) + assert any(s.type == SuggestionType.ALIAS_COLUMN for s in suggestions) + alias_suggestion = next(s for s in suggestions if s.type == SuggestionType.ALIAS_COLUMN) + assert alias_suggestion.table_scope == "u" + + def test_after_order_by(self): + """After ORDER BY should suggest columns.""" + sql = "SELECT * FROM users ORDER BY " + suggestions = get_context(sql, len(sql)) + assert any(s.type == SuggestionType.COLUMN for s in suggestions) + + def test_after_group_by(self): + """After GROUP BY should suggest columns.""" + sql = "SELECT * FROM users GROUP BY " + suggestions = get_context(sql, len(sql)) + assert any(s.type == SuggestionType.COLUMN for s in suggestions) + + def test_after_exec(self): + """After EXEC should suggest procedures.""" + sql = "EXEC " + suggestions = get_context(sql, len(sql)) + assert any(s.type == SuggestionType.PROCEDURE for s in suggestions) + + def test_comma_in_select(self): + """Comma in SELECT should suggest columns.""" + sql = "SELECT id, " + suggestions = get_context(sql, len(sql)) + assert any(s.type == SuggestionType.COLUMN for s in suggestions) + + def test_comma_in_from(self): + """Comma in FROM should suggest tables.""" + sql = "SELECT * FROM users, " + suggestions = get_context(sql, len(sql)) + assert any(s.type == SuggestionType.TABLE for s in suggestions) + + def test_start_of_query(self): + """Start of query should suggest keywords.""" + sql = "" + suggestions = get_context(sql, 0) + assert any(s.type == SuggestionType.KEYWORD for s in suggestions) + + +class TestSelectCompletions: + """Integration tests for SELECT completion flow.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "orders", "products", "order_items"], + "columns": { + "users": ["id", "name", "email", "created_at"], + "orders": ["id", "user_id", "total", "status", "created_at"], + "products": ["id", "name", "price", "category"], + "order_items": ["id", "order_id", "product_id", "quantity"], + }, + "procedures": ["sp_get_user", "sp_create_order", "sp_update_inventory"], + } + + def test_complete_table_after_from(self, schema): + """Should complete table names after FROM.""" + sql = "SELECT * FROM us" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "users" in completions + + def test_complete_table_fuzzy(self, schema): + """Should fuzzy match table names.""" + sql = "SELECT * FROM ord" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "orders" in completions + assert "order_items" in completions + + def test_complete_column_with_alias(self, schema): + """Should complete columns for aliased table.""" + sql = "SELECT * FROM users u WHERE u." + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + assert "name" in completions + assert "email" in completions + + def test_complete_column_with_alias_partial(self, schema): + """Should complete partial column for aliased table.""" + sql = "SELECT * FROM users u WHERE u.na" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "name" in completions + + def test_complete_column_direct_table(self, schema): + """Should complete columns for direct table reference.""" + sql = "SELECT * FROM users WHERE users." + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + assert "name" in completions + + def test_complete_columns_in_where(self, schema): + """Should complete columns in WHERE clause.""" + sql = "SELECT * FROM users WHERE i" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + + def test_complete_procedure_after_exec(self, schema): + """Should complete procedures after EXEC.""" + sql = "EXEC sp_" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "sp_get_user" in completions + assert "sp_create_order" in completions + + def test_complete_includes_keywords(self, schema): + """Should include keywords when appropriate.""" + sql = "SEL" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "SELECT" in completions + + def test_complete_includes_functions(self, schema): + """Should include functions in SELECT context.""" + sql = "SELECT COU" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "COUNT" in completions + + def test_complete_join_table(self, schema): + """Should complete table after JOIN.""" + sql = "SELECT * FROM users u JOIN ord" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "orders" in completions + + def test_complete_multiple_aliases(self, schema): + """Should handle multiple aliases correctly.""" + sql = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id WHERE o." + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "user_id" in completions + assert "total" in completions + assert "status" in completions + + def test_no_duplicate_completions(self, schema): + """Should not return duplicate completions.""" + sql = "SELECT " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + lower_completions = [c.lower() for c in completions] + assert len(lower_completions) == len(set(lower_completions)) + + def test_complete_with_cte(self, schema): + """Should suggest CTE names as tables.""" + sql = "WITH active_users AS (SELECT * FROM users WHERE status = 'active') SELECT * FROM act" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "active_users" in completions + + +class TestSelectClauseSuggestions: + """Tests for SELECT clause special suggestions (*, DISTINCT, TOP).""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "orders"], + "columns": { + "users": ["id", "name", "email"], + "orders": ["id", "user_id", "total"], + }, + } + + def test_star_suggested_after_select(self, schema): + """Should suggest * after SELECT.""" + sql = "SELECT " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"] + ) + assert "*" in completions + + def test_distinct_suggested_after_select(self, schema): + """Should suggest DISTINCT after SELECT.""" + sql = "SELECT " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"] + ) + assert "DISTINCT" in completions + + def test_top_suggested_after_select(self, schema): + """Should suggest TOP after SELECT (SQL Server syntax).""" + sql = "SELECT " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"] + ) + assert "TOP" in completions + + def test_star_filtered_by_prefix(self, schema): + """Should filter * when typing.""" + sql = "SELECT *" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"] + ) + # When typing *, it should match + assert "*" in completions or len(completions) == 0 # Either matches or nothing + + def test_distinct_filtered_by_prefix(self, schema): + """Should filter DISTINCT when typing D.""" + sql = "SELECT D" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"] + ) + assert "DISTINCT" in completions + + def test_top_filtered_by_prefix(self, schema): + """Should filter TOP when typing T.""" + sql = "SELECT T" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"] + ) + assert "TOP" in completions + + +class TestOperatorSuggestions: + """Tests for operator suggestions using sqlparse.""" + + def test_operator_after_column_in_where(self): + """Should suggest operators after column name in WHERE clause.""" + sql = "SELECT * FROM users WHERE id " + completions = get_completions(sql, len(sql), ["users"], {"users": ["id", "name"]}) + assert "=" in completions + assert "!=" in completions + assert "IS NULL" in completions + assert "LIKE" in completions + + def test_operator_after_column_in_having(self): + """Should suggest operators after column name in HAVING clause.""" + sql = "SELECT COUNT(*) FROM users GROUP BY status HAVING COUNT(*) " + completions = get_completions(sql, len(sql), ["users"], {"users": ["id", "status"]}) + assert ">" in completions + assert ">=" in completions + assert "<" in completions + + def test_operator_after_aliased_column(self): + """Should suggest operators after aliased column in WHERE.""" + sql = "SELECT * FROM users u WHERE u.id " + completions = get_completions(sql, len(sql), ["users"], {"users": ["id", "name"]}) + assert "=" in completions + assert "IN" in completions + assert "BETWEEN" in completions + + def test_column_after_operator(self): + """Should suggest columns after comparison operator.""" + sql = "SELECT * FROM users WHERE id = " + completions = get_completions(sql, len(sql), ["users"], {"users": ["id", "name"]}) + assert "id" in completions + + def test_no_operators_after_from(self): + """Should NOT suggest operators after FROM keyword.""" + sql = "SELECT * FROM " + completions = get_completions(sql, len(sql), ["users"], {"users": ["id"]}) + assert "=" not in completions + assert "users" in completions + + def test_no_operators_in_select(self): + """Should NOT suggest operators in SELECT clause.""" + sql = "SELECT " + completions = get_completions(sql, len(sql), ["users"], {"users": ["id", "name"]}) + assert "=" not in completions + + def test_operators_filtered_by_prefix(self): + """Should filter operators by typed prefix.""" + sql = "SELECT * FROM users WHERE id I" + completions = get_completions(sql, len(sql), ["users"], {"users": ["id"]}) + assert "IN" in completions or "IS NULL" in completions or "ILIKE" in completions + + def test_operator_on_join_condition(self): + """Should suggest operators after column in ON clause.""" + sql = "SELECT * FROM users u JOIN orders o ON u.id " + completions = get_completions(sql, len(sql), ["users", "orders"], {"users": ["id"], "orders": ["user_id"]}) + assert "=" in completions diff --git a/tests/unit/sql_completion/test_truncate.py b/tests/unit/sql_completion/test_truncate.py new file mode 100644 index 00000000..778aebb0 --- /dev/null +++ b/tests/unit/sql_completion/test_truncate.py @@ -0,0 +1,78 @@ +"""Tests for TRUNCATE TABLE statement autocomplete suggestions.""" + +import pytest + +from sqlit.sql_completion import get_completions + + +class TestTruncateStatements: + """Tests for TRUNCATE TABLE autocomplete suggestions.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "orders", "products"], + "columns": { + "users": ["id", "name", "email"], + "orders": ["id", "user_id", "total"], + "products": ["id", "name", "price"], + }, + "procedures": [], + } + + def test_truncate_suggests_table_keyword_and_tables(self, schema): + """TRUNCATE should suggest TABLE keyword and table names.""" + sql = "TRUNCATE " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "TABLE" in completions + assert "users" in completions + assert "orders" in completions + + def test_truncate_partial_table_keyword(self, schema): + """Typing partial TABLE should filter.""" + sql = "TRUNCATE TAB" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "TABLE" in completions + assert "users" not in completions # Filtered out by fuzzy match + + def test_truncate_table_suggests_tables(self, schema): + """TRUNCATE TABLE should suggest table names.""" + sql = "TRUNCATE TABLE " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "users" in completions + assert "orders" in completions + assert "products" in completions + + def test_truncate_table_partial_name(self, schema): + """Typing partial table name should filter.""" + sql = "TRUNCATE TABLE us" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "users" in completions + assert "orders" not in completions + + def test_truncate_partial_table_name_directly(self, schema): + """TRUNCATE with partial table name directly should filter.""" + sql = "TRUNCATE ord" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "orders" in completions + assert "users" not in completions + + def test_truncate_lowercase(self, schema): + """truncate (lowercase) should work the same.""" + sql = "truncate table " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "users" in completions + assert "orders" in completions diff --git a/tests/unit/sql_completion/test_update.py b/tests/unit/sql_completion/test_update.py new file mode 100644 index 00000000..e7d570c1 --- /dev/null +++ b/tests/unit/sql_completion/test_update.py @@ -0,0 +1,132 @@ +"""Tests for UPDATE statement autocomplete suggestions.""" + +import pytest + +from sqlit.sql_completion import ( + SuggestionType, + get_completions, + get_context, +) + + +class TestUpdateStatements: + """Tests for UPDATE statement autocomplete suggestions.""" + + @pytest.fixture + def schema(self): + """Sample database schema.""" + return { + "tables": ["users", "orders", "products"], + "columns": { + "users": ["id", "name", "email", "created_at", "status"], + "orders": ["id", "user_id", "total", "status"], + "products": ["id", "name", "price", "category"], + }, + "procedures": [], + } + + def test_update_suggests_tables(self, schema): + """After UPDATE, should suggest table names.""" + sql = "UPDATE " + suggestions = get_context(sql, len(sql)) + assert any(s.type == SuggestionType.TABLE for s in suggestions) + + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "users" in completions + assert "orders" in completions + + def test_update_partial_table(self, schema): + """Typing partial table name after UPDATE.""" + sql = "UPDATE us" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "users" in completions + + def test_update_set_suggests_columns(self, schema): + """After UPDATE table SET, should suggest columns for that table.""" + sql = "UPDATE users SET " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "name" in completions + assert "email" in completions + assert "status" in completions + + def test_update_set_partial_column(self, schema): + """Typing partial column name after SET.""" + sql = "UPDATE users SET na" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "name" in completions + + def test_update_set_equals_suggests_columns(self, schema): + """After SET column =, should suggest columns/values.""" + sql = "UPDATE users SET name = " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "email" in completions or len(completions) > 0 + + def test_update_set_comma_suggests_more_columns(self, schema): + """After SET col = value,, should suggest more columns.""" + sql = "UPDATE users SET name = 'John', " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "email" in completions + assert "status" in completions + + def test_update_where_suggests_columns(self, schema): + """After UPDATE ... WHERE, should suggest columns.""" + sql = "UPDATE users SET name = 'John' WHERE " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + assert "name" in completions + + def test_update_where_partial_column(self, schema): + """Typing partial column in WHERE clause.""" + sql = "UPDATE users SET name = 'John' WHERE st" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "status" in completions + + def test_update_with_alias_set(self, schema): + """UPDATE with alias should suggest columns via alias.""" + sql = "UPDATE users u SET u." + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "name" in completions + assert "email" in completions + + def test_update_with_alias_where(self, schema): + """UPDATE with alias in WHERE clause.""" + sql = "UPDATE users u SET name = 'John' WHERE u." + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "id" in completions + assert "status" in completions + + def test_update_no_suggestions_in_string(self, schema): + """Inside string literal, should NOT suggest anything.""" + sql = "UPDATE users SET name = '" + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert completions == [] + + def test_update_from_join_suggests_tables(self, schema): + """UPDATE ... FROM ... JOIN (SQL Server style) should suggest tables.""" + sql = "UPDATE u SET u.name = o.status FROM users u JOIN " + completions = get_completions( + sql, len(sql), schema["tables"], schema["columns"], schema["procedures"] + ) + assert "orders" in completions diff --git a/tests/unit/test_docker_credential_parsing.py b/tests/unit/test_docker_credential_parsing.py index 1a311e1a..ac90da9d 100644 --- a/tests/unit/test_docker_credential_parsing.py +++ b/tests/unit/test_docker_credential_parsing.py @@ -52,9 +52,6 @@ class TestImagePatternMatching: ("mcr.microsoft.com/mssql/server:2019-latest", "mssql"), ("mcr.microsoft.com/mssql/server:2022-latest", "mssql"), ("mcr.microsoft.com/mssql/server:2022-CU10-ubuntu-22.04", "mssql"), - # Azure SQL Edge (ARM64 compatible) - ("mcr.microsoft.com/azure-sql-edge", "mssql"), - ("mcr.microsoft.com/azure-sql-edge:latest", "mssql"), # ClickHouse variations ("clickhouse/clickhouse-server", "clickhouse"), ("clickhouse/clickhouse-server:latest", "clickhouse"), diff --git a/tests/unit/test_mssql_adapter.py b/tests/unit/test_mssql_adapter.py new file mode 100644 index 00000000..ea8fd19e --- /dev/null +++ b/tests/unit/test_mssql_adapter.py @@ -0,0 +1,237 @@ +"""Unit tests for MSSQL adapter - specifically testing Azure SQL compatibility. + +These tests verify that the MSSQL adapter uses USE [database] instead of +cross-database references like [Database].INFORMATION_SCHEMA.TABLES, +which are not supported in Azure SQL Database. + +Azure SQL Database has two restrictions: +1. Cross-database references like [Database].INFORMATION_SCHEMA.TABLES don't work +2. USE statement to switch databases doesn't work either + +The adapter handles both by: +1. Using USE [database] for regular SQL Server +2. Gracefully handling the USE failure for Azure SQL Database +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, call, patch + +import pytest + + +class TestMSSQLAdapterNoCrossDatabaseReferences: + """Test that MSSQL adapter avoids cross-database query syntax. + + Azure SQL Database does not support cross-database references like: + - [Database].INFORMATION_SCHEMA.TABLES + - [Database].sys.tables + + Instead, the adapter attempts USE [database] to switch context, + and gracefully handles failure (Azure SQL Database). + """ + + @pytest.fixture + def mock_mssql(self): + """Create a mock mssql_python module.""" + mock = MagicMock() + with patch.dict("sys.modules", {"mssql_python": mock}): + yield mock + + @pytest.fixture + def adapter(self, mock_mssql): + """Create an MSSQL adapter instance.""" + from sqlit.db.adapters.mssql import SQLServerAdapter + return SQLServerAdapter() + + @pytest.fixture + def mock_conn(self): + """Create a mock connection with cursor.""" + conn = MagicMock() + cursor = MagicMock() + conn.cursor.return_value = cursor + cursor.fetchall.return_value = [] + return conn + + def _get_executed_sql(self, mock_conn) -> list[str]: + """Extract all SQL statements executed on the cursor.""" + cursor = mock_conn.cursor.return_value + return [call[0][0] for call in cursor.execute.call_args_list] + + def _assert_no_cross_db_refs(self, sql_statements: list[str], database: str): + """Assert no SQL contains cross-database references.""" + patterns = [ + f"[{database}].", + f"[{database.lower()}].", + f"[{database.upper()}].", + ] + for sql in sql_statements: + for pattern in patterns: + assert pattern not in sql, ( + f"Found cross-database reference '{pattern}' in SQL: {sql}\n" + "This syntax is not supported in Azure SQL Database. " + "Use 'USE [database]' instead." + ) + + def _assert_uses_database_context(self, sql_statements: list[str], database: str): + """Assert USE [database] is called before other queries.""" + assert len(sql_statements) >= 1, "Expected at least one SQL statement" + use_stmt = sql_statements[0] + assert use_stmt == f"USE [{database}]", ( + f"Expected first statement to be 'USE [{database}]', got: {use_stmt}" + ) + + def test_get_tables_uses_context_switch(self, adapter, mock_conn): + """Test get_tables uses USE instead of cross-database reference.""" + database = "TestDB" + adapter.get_tables(mock_conn, database=database) + + sql_statements = self._get_executed_sql(mock_conn) + self._assert_uses_database_context(sql_statements, database) + self._assert_no_cross_db_refs(sql_statements, database) + + def test_get_views_uses_context_switch(self, adapter, mock_conn): + """Test get_views uses USE instead of cross-database reference.""" + database = "TestDB" + adapter.get_views(mock_conn, database=database) + + sql_statements = self._get_executed_sql(mock_conn) + self._assert_uses_database_context(sql_statements, database) + self._assert_no_cross_db_refs(sql_statements, database) + + def test_get_columns_uses_context_switch(self, adapter, mock_conn): + """Test get_columns uses USE instead of cross-database reference.""" + database = "TestDB" + adapter.get_columns(mock_conn, table="Users", database=database, schema="dbo") + + sql_statements = self._get_executed_sql(mock_conn) + self._assert_uses_database_context(sql_statements, database) + self._assert_no_cross_db_refs(sql_statements, database) + + def test_get_procedures_uses_context_switch(self, adapter, mock_conn): + """Test get_procedures uses USE instead of cross-database reference.""" + database = "TestDB" + adapter.get_procedures(mock_conn, database=database) + + sql_statements = self._get_executed_sql(mock_conn) + self._assert_uses_database_context(sql_statements, database) + self._assert_no_cross_db_refs(sql_statements, database) + + def test_get_indexes_uses_context_switch(self, adapter, mock_conn): + """Test get_indexes uses USE instead of cross-database reference.""" + database = "TestDB" + adapter.get_indexes(mock_conn, database=database) + + sql_statements = self._get_executed_sql(mock_conn) + self._assert_uses_database_context(sql_statements, database) + self._assert_no_cross_db_refs(sql_statements, database) + + def test_get_triggers_uses_context_switch(self, adapter, mock_conn): + """Test get_triggers uses USE instead of cross-database reference.""" + database = "TestDB" + adapter.get_triggers(mock_conn, database=database) + + sql_statements = self._get_executed_sql(mock_conn) + self._assert_uses_database_context(sql_statements, database) + self._assert_no_cross_db_refs(sql_statements, database) + + def test_get_sequences_uses_context_switch(self, adapter, mock_conn): + """Test get_sequences uses USE instead of cross-database reference.""" + database = "TestDB" + adapter.get_sequences(mock_conn, database=database) + + sql_statements = self._get_executed_sql(mock_conn) + self._assert_uses_database_context(sql_statements, database) + self._assert_no_cross_db_refs(sql_statements, database) + + def test_get_index_definition_uses_context_switch(self, adapter, mock_conn): + """Test get_index_definition uses USE instead of cross-database reference.""" + database = "TestDB" + mock_conn.cursor.return_value.fetchall.return_value = [ + (False, "NONCLUSTERED", "col1") + ] + adapter.get_index_definition(mock_conn, "IX_Test", "Users", database=database) + + sql_statements = self._get_executed_sql(mock_conn) + self._assert_uses_database_context(sql_statements, database) + self._assert_no_cross_db_refs(sql_statements, database) + + def test_get_trigger_definition_uses_context_switch(self, adapter, mock_conn): + """Test get_trigger_definition uses USE instead of cross-database reference.""" + database = "TestDB" + mock_conn.cursor.return_value.fetchone.return_value = ("CREATE TRIGGER...", "AFTER") + adapter.get_trigger_definition(mock_conn, "TR_Test", "Users", database=database) + + sql_statements = self._get_executed_sql(mock_conn) + self._assert_uses_database_context(sql_statements, database) + self._assert_no_cross_db_refs(sql_statements, database) + + def test_get_sequence_definition_uses_context_switch(self, adapter, mock_conn): + """Test get_sequence_definition uses USE instead of cross-database reference.""" + database = "TestDB" + mock_conn.cursor.return_value.fetchone.return_value = (1, 1, 1, 9999, False) + adapter.get_sequence_definition(mock_conn, "SEQ_Test", database=database) + + sql_statements = self._get_executed_sql(mock_conn) + self._assert_uses_database_context(sql_statements, database) + self._assert_no_cross_db_refs(sql_statements, database) + + def test_no_use_statement_when_no_database(self, adapter, mock_conn): + """Test that USE is not called when database is None.""" + adapter.get_tables(mock_conn, database=None) + + sql_statements = self._get_executed_sql(mock_conn) + assert len(sql_statements) == 1, "Expected only one SQL statement" + assert not sql_statements[0].startswith("USE"), ( + "Should not call USE when database is None" + ) + + +class TestMSSQLAdapterQueries: + """Test MSSQL adapter query correctness.""" + + @pytest.fixture + def mock_mssql(self): + mock = MagicMock() + with patch.dict("sys.modules", {"mssql_python": mock}): + yield mock + + @pytest.fixture + def adapter(self, mock_mssql): + from sqlit.db.adapters.mssql import SQLServerAdapter + return SQLServerAdapter() + + def test_get_tables_query_structure(self, adapter): + """Test get_tables executes correct query.""" + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + cursor.fetchall.return_value = [("dbo", "Users"), ("dbo", "Orders")] + + result = adapter.get_tables(mock_conn, database="TestDB") + + assert result == [("dbo", "Users"), ("dbo", "Orders")] + # Verify query uses INFORMATION_SCHEMA without database prefix + query_call = cursor.execute.call_args_list[-1] + assert "INFORMATION_SCHEMA.TABLES" in query_call[0][0] + assert "BASE TABLE" in query_call[0][0] + + def test_get_columns_returns_primary_keys(self, adapter): + """Test get_columns correctly identifies primary keys.""" + mock_conn = MagicMock() + cursor = MagicMock() + mock_conn.cursor.return_value = cursor + + # First call returns PK columns, second returns all columns + cursor.fetchall.side_effect = [ + [("id",)], # Primary key columns + [("id", "int"), ("name", "varchar"), ("email", "varchar")], # All columns + ] + + result = adapter.get_columns(mock_conn, "Users", database="TestDB", schema="dbo") + + assert len(result) == 3 + assert result[0].name == "id" + assert result[0].is_primary_key is True + assert result[1].name == "name" + assert result[1].is_primary_key is False diff --git a/uv.lock b/uv.lock index af9cab2c..be08e94e 100644 --- a/uv.lock +++ b/uv.lock @@ -2185,6 +2185,7 @@ dependencies = [ { name = "docker" }, { name = "keyring" }, { name = "pyperclip" }, + { name = "sqlparse" }, { name = "textual", extra = ["syntax"] }, { name = "textual-fastdatatable" }, ] @@ -2310,6 +2311,7 @@ requires-dist = [ { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.8.0" }, { name = "snowflake-connector-python", marker = "extra == 'all'", specifier = ">=3.7.0" }, { name = "snowflake-connector-python", marker = "extra == 'snowflake'", specifier = ">=3.7.0" }, + { name = "sqlparse", specifier = ">=0.5.0" }, { name = "sshtunnel", marker = "extra == 'all'", specifier = ">=0.4.0" }, { name = "sshtunnel", marker = "extra == 'ssh'", specifier = ">=0.4.0" }, { name = "textual", extras = ["syntax"], specifier = ">=6.10.0" }, @@ -2317,6 +2319,15 @@ requires-dist = [ ] provides-extras = ["all", "clickhouse", "cockroachdb", "d1", "dev", "duckdb", "firebird", "mariadb", "mssql", "mysql", "oracle", "postgres", "snowflake", "ssh", "test", "turso"] +[[package]] +name = "sqlparse" +version = "0.5.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/90/76/437d71068094df0726366574cf3432a4ed754217b436eb7429415cf2d480/sqlparse-0.5.5.tar.gz", hash = "sha256:e20d4a9b0b8585fdf63b10d30066c7c94c5d7a7ec47c889a2d83a3caa93ff28e", size = 120815, upload-time = "2025-12-19T07:17:45.073Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/49/4b/359f28a903c13438ef59ebeee215fb25da53066db67b305c125f1c6d2a25/sqlparse-0.5.5-py3-none-any.whl", hash = "sha256:12a08b3bf3eec877c519589833aed092e2444e68240a3577e8e26148acc7b1ba", size = 46138, upload-time = "2025-12-19T07:17:46.573Z" }, +] + [[package]] name = "sshtunnel" version = "0.4.0"