diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4f47334b..0fa79399 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -55,7 +55,17 @@ jobs: - name: Run unit tests run: | - pytest tests/test_validation.py tests/ui/ -v --timeout=60 + pytest tests/ -v --timeout=60 \ + --ignore=tests/test_sqlite.py \ + --ignore=tests/test_mssql.py \ + --ignore=tests/test_postgresql.py \ + --ignore=tests/test_mysql.py \ + --ignore=tests/test_oracle.py \ + --ignore=tests/test_mariadb.py \ + --ignore=tests/test_duckdb.py \ + --ignore=tests/test_cockroachdb.py \ + --ignore=tests/test_turso.py \ + --ignore=tests/test_ssh.py test-sqlite: runs-on: ubuntu-latest diff --git a/.gitignore b/.gitignore index 5cab76fb..a39071e8 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,11 @@ venv/ *.swp *.swo +# Local caches +.cache/ +.ruff_cache/ +.sqlit/ + # OS .DS_Store @@ -34,3 +39,6 @@ aur/src/ aur/pkg/ aur/*.pkg.tar.* sqlit-notifications/ + +# Integration test artifacts +tests/integration/python_packages/artifacts/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..ad9e2020 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,10 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-toml + - id: check-added-large-files + - id: check-merge-conflict diff --git a/.sqlit-config/settings.json b/.sqlit-config/settings.json index e33f3e0d..840dbf15 100644 --- a/.sqlit-config/settings.json +++ b/.sqlit-config/settings.json @@ -1,4 +1,4 @@ { "theme": "tokyo-night", "expanded_nodes": [] -} \ No newline at end of file +} diff --git a/README.md b/README.md index 2c307244..90917bc8 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ A lightweight TUI for people who just want to run some queries fast. ## Motivation -I usually do my work in the terminal, but I found myself either having to boot up massively bloated GUI's like SSMS or Vscode for the simple task of merely browsing my databases and doing some queries toward them. For the vast majority of my use cases, I never used any of the advanced features for inspection and debugging that SSMS and other feature-rich clients provide. +I usually do my work in the terminal, but I found myself either having to boot up massively bloated GUI's like SSMS or Vscode for the simple task of merely browsing my databases and doing some queries toward them. For the vast majority of my use cases, I never used any of the advanced features for inspection and debugging that SSMS and other feature-rich clients provide. I had the unfortunate situation where doing queries became a pain-point due to the massive operation it is to open SSMS and it's lack of intuitive keyboard only navigation. @@ -45,11 +45,56 @@ sqlit is a lightweight database TUI that is easy to use and beautiful to look at ## Installation +### Method 1: `pipx` (Recommended) + +This is the recommended method. It installs `sqlit-tui` in an isolated environment, so optional drivers are easy to add later. + +1. **Install pipx:** If you don't have pipx, you can install it with: + ```bash + python3 -m pip install --user pipx + python3 -m pipx ensurepath + ``` + *(You may need to restart your terminal after this step)* + +2. **Install sqlit-tui:** + ```bash + pipx install sqlit-tui + ``` + +3. **Optional drivers (only if you need them):** `sqlit` will tell you what to install when a driver is missing, but you can also pre-install them. For example: + ```bash + # PostgreSQL / Supabase / CockroachDB + pipx inject sqlit-tui psycopg2-binary + + # MySQL + pipx inject sqlit-tui mysql-connector-python + ``` + +### Method 2: `uv` (Alternative) + +`uv` is a fast, modern installer. This also keeps things isolated and makes optional drivers easy. + ```bash -pip install sqlit-tui +uv tool install sqlit-tui ``` -If you are missing Python packages for your database provider, sqlit will help you install them when you attempt to connect. If you want to pre-install requirements, see [Adapter Requirements](#adapter-requirements). +### Method 3: `pip` (Alternative) + +*(Note: To avoid dependency conflicts, installing in a virtual environment is recommended.)* + +You can install `sqlit-tui` and drivers directly with `pip` using "extras". The application will guide you if a driver is missing. +If you installed Python via a system package manager (Homebrew, apt, pacman, etc.), `pip install` may be restricted; in that case, use `pipx`, `uv`, or a virtual environment. + +```bash +# To install with PostgreSQL and MySQL support +pip install "sqlit-tui[postgres,mysql]" + +# To add a driver to an existing installation +pip install "sqlit-tui[mariadb]" + +# To install all drivers +pip install "sqlit-tui[all]" +``` ## Usage @@ -78,20 +123,29 @@ sqlit query -c "MyConnection" -q "SELECT * FROM Users" --format csv sqlit query -c "MyConnection" -f "script.sql" --format json # Create connections for different databases -sqlit connection create --name "MySqlServer" --db-type mssql --server "localhost" --auth-type sql -sqlit connection create --name "MyPostgres" --db-type postgresql --server "localhost" --username "user" --password "pass" -sqlit connection create --name "MyMySQL" --db-type mysql --server "localhost" --username "user" --password "pass" -sqlit connection create --name "MyCockroach" --db-type cockroachdb --server "localhost" --port "26257" --database "defaultdb" --username "root" -sqlit connection create --name "MyLocalDB" --db-type sqlite --file-path "/path/to/database.db" -sqlit connection create --name "MyTurso" --db-type turso --server "libsql://your-db.turso.io" --password "your-auth-token" +sqlit connections add mssql --name "MySqlServer" --server "localhost" --auth-type sql +sqlit connections add postgresql --name "MyPostgres" --server "localhost" --username "user" --password "pass" +sqlit connections add mysql --name "MyMySQL" --server "localhost" --username "user" --password "pass" +sqlit connections add cockroachdb --name "MyCockroach" --server "localhost" --port "26257" --database "defaultdb" --username "root" +sqlit connections add sqlite --name "MyLocalDB" --file-path "/path/to/database.db" +sqlit connections add turso --name "MyTurso" --server "libsql://your-db.turso.io" --password "your-auth-token" # Connect via SSH tunnel -sqlit connection create --name "RemoteDB" --db-type postgresql --server "db-host" --username "dbuser" --password "dbpass" \ +sqlit connections add postgresql --name "RemoteDB" --server "db-host" --username "dbuser" --password "dbpass" \ --ssh-enabled --ssh-host "ssh.example.com" --ssh-username "sshuser" --ssh-auth-type password --ssh-password "sshpass" +# Temporary (not saved) connection +sqlit connect sqlite --file-path "/path/to/database.db" + +# Provider-specific CLI help +sqlit connect -h +sqlit connect supabase -h +sqlit connections add -h +sqlit connections add supabase -h + # Manage connections -sqlit connection list -sqlit connection delete "MyConnection" +sqlit connections list +sqlit connections delete "MyConnection" ``` ## Keybindings @@ -105,6 +159,7 @@ sqlit connection delete "MyConnection" | `h` | Query history | | `d` | Clear query | | `n` | New query (clear all) | +| `y` | Copy query (when query editor is focused) | | `v` / `y` / `Y` / `a` | View cell / Copy cell / Copy row / Copy all | | `Ctrl+Q` | Quit | | `?` | Help | @@ -134,14 +189,16 @@ Connections and settings are stored in `~/.sqlit/`. ### How are sensitive credentials stored? -Credentials are stored in plain text in a protected directory (`~/.sqlit/`) with restricted file permissions (700/600). +Connection details are stored in `~/.sqlit/connections.json`, but passwords are stored in your OS keyring when available (macOS Keychain, Windows Credential Locker, Linux Secret Service). + +If a keyring backend isn't available, `sqlit` will ask whether to store passwords as plaintext in `~/.sqlit/` (protected permissions). If you decline, you’ll be prompted when needed. ### How does sqlit compare to Harlequin, Lazysql, etc.? sqlit is inspired by [lazygit](https://github.com/jesseduffield/lazygit) - you can just jump in and there's no need for external documentation. The keybindings are shown at the bottom of the screen and the UI is designed to be intuitive without memorizing shortcuts. Key differences: -- **No need for external documentation** - Sqlit embrace the "lazy" approach in that a user should be able to jump in and use it right away intuitively. There should be no setup instructions. If python packages are required for certain adapters, sqlit will help you install them as you need them. +- **No need for external documentation** - Sqlit embrace the "lazy" approach in that a user should be able to jump in and use it right away intuitively. There should be no setup instructions. If python packages are required for certain adapters, sqlit will help you install them as you need them. - **No CLI config required** - Just run `sqlit` and pick a connection from the UI - **Lightweight** - While Lazysql or Harlequin offer more features, I experienced that for the vast majority of cases, all I needed was a simple and fast way to connect and run queries. Sqlit is focused on doing a limited amount of things really well. @@ -155,24 +212,23 @@ sqlit is built with [Textual](https://github.com/Textualize/textual) and inspire See `CONTRIBUTING.md` for development setup, testing, CI, and CockroachDB quickstart steps. -## Adapter Requirements +### Driver Reference -Each database provider requires specific Python packages. sqlit will prompt you to install these when needed, but you can also pre-install them: +Most of the time you can just run `sqlit` and connect. If a Python driver is missing, `sqlit` will show (and often run) the right install command for your environment. -| Database | Package | Install Command | -|----------|---------|-----------------| -| SQLite | *(built-in)* | No installation needed | -| SQL Server | `pyodbc` | `pip install pyodbc` | -| PostgreSQL | `psycopg2-binary` | `pip install psycopg2-binary` | -| MySQL | `mysql-connector-python` | `pip install mysql-connector-python` | -| MariaDB | `mariadb` | `pip install mariadb` | -| Oracle | `oracledb` | `pip install oracledb` | -| DuckDB | `duckdb` | `pip install duckdb` | -| CockroachDB | `psycopg2-binary` | `pip install psycopg2-binary` | -| Supabase | `psycopg2-binary` | `pip install psycopg2-binary` | -| Turso | `libsql-client` | `pip install libsql-client` | +| Database | Driver package | `pipx` | `pip` / venv | +| :--- | :--- | :--- | :--- | +| SQLite | *(built-in)* | *(built-in)* | *(built-in)* | +| PostgreSQL / CockroachDB / Supabase | `psycopg2-binary` | `pipx inject sqlit-tui psycopg2-binary` | `python -m pip install psycopg2-binary` | +| SQL Server | `pyodbc` | `pipx inject sqlit-tui pyodbc` | `python -m pip install pyodbc` | +| MySQL | `mysql-connector-python` | `pipx inject sqlit-tui mysql-connector-python` | `python -m pip install mysql-connector-python` | +| MariaDB | `mariadb` | `pipx inject sqlit-tui mariadb` | `python -m pip install mariadb` | +| Oracle | `oracledb` | `pipx inject sqlit-tui oracledb` | `python -m pip install oracledb` | +| DuckDB | `duckdb` | `pipx inject sqlit-tui duckdb` | `python -m pip install duckdb` | +| Turso | `libsql-client` | `pipx inject sqlit-tui libsql-client` | `python -m pip install libsql-client` | +| Cloudflare D1 | `requests` | `pipx inject sqlit-tui requests` | `python -m pip install requests` | -**Note:** SQL Server also requires the ODBC driver. On first connection attempt, sqlit will detect if it's missing and help you install it. +**Note:** SQL Server also requires the platform-specific ODBC driver. On your first connection attempt, `sqlit` can help you install it if it's missing. ## License diff --git a/aur/PKGBUILD b/aur/PKGBUILD index f6d05f31..2e2497bc 100644 --- a/aur/PKGBUILD +++ b/aur/PKGBUILD @@ -13,6 +13,7 @@ depends=( 'python-pyperclip' 'python-sshtunnel' 'python-paramiko' + 'python-keyring' ) optdepends=( 'python-psycopg2: PostgreSQL and CockroachDB support' diff --git a/demo-connect.gif b/demo-connect.gif deleted file mode 100644 index fc172229..00000000 Binary files a/demo-connect.gif and /dev/null differ diff --git a/demo-history.gif b/demo-history.gif deleted file mode 100644 index 403e3a4d..00000000 Binary files a/demo-history.gif and /dev/null differ diff --git a/demo-providers.gif b/demo-providers.gif deleted file mode 100644 index 7425a2a4..00000000 Binary files a/demo-providers.gif and /dev/null differ diff --git a/demo-query.gif b/demo-query.gif deleted file mode 100644 index 54a35741..00000000 Binary files a/demo-query.gif and /dev/null differ diff --git a/demo-sqlite.gif b/demo-sqlite.gif deleted file mode 100644 index a25ede43..00000000 Binary files a/demo-sqlite.gif and /dev/null differ diff --git a/demos/demo-connect.gif b/demos/demo-connect.gif new file mode 100644 index 00000000..fc890116 Binary files /dev/null and b/demos/demo-connect.gif differ diff --git a/demos/demo-history.gif b/demos/demo-history.gif new file mode 100644 index 00000000..67ca0982 Binary files /dev/null and b/demos/demo-history.gif differ diff --git a/demos/demo-providers.gif b/demos/demo-providers.gif new file mode 100644 index 00000000..970f09e9 Binary files /dev/null and b/demos/demo-providers.gif differ diff --git a/demos/demo-query.gif b/demos/demo-query.gif new file mode 100644 index 00000000..74203c50 Binary files /dev/null and b/demos/demo-query.gif differ diff --git a/demos/demo-sqlite.gif b/demos/demo-sqlite.gif new file mode 100644 index 00000000..7334be98 Binary files /dev/null and b/demos/demo-sqlite.gif differ diff --git a/demo.db b/demos/demo.db similarity index 100% rename from demo.db rename to demos/demo.db diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 90ffc22b..301d0e09 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -156,3 +156,21 @@ services: start_period: 10s tmpfs: - /var/lib/postgresql/data + + miniflare: + build: + context: ./tests/fixtures/d1 + container_name: sqlit-test-miniflare + command: ["wrangler", "dev", "--local", "--ip", "0.0.0.0", "--port", "8787"] + ports: + - "8787:8787" + volumes: + - ./tests/fixtures/d1/wrangler.toml:/app/wrangler.toml + - ./tests/fixtures/d1/index.js:/app/index.js + working_dir: /app + healthcheck: + test: ["CMD", "node", "-e", "fetch('http://localhost:8787/').then(r=>{if(!r.ok)process.exit(1)}).catch(()=>process.exit(1))"] + interval: 5s + timeout: 5s + retries: 10 + start_period: 10s diff --git a/mypy-errors.txt b/mypy-errors.txt new file mode 100644 index 00000000..c46be3f9 --- /dev/null +++ b/mypy-errors.txt @@ -0,0 +1,208 @@ +sqlit/db/adapters/base.py:211: error: Returning Any from function declared to return "int" [no-any-return] +sqlit/db/adapters/turso.py:106: error: Returning Any from function declared to return "int" [no-any-return] +sqlit/db/adapters/sqlite.py:106: error: Returning Any from function declared to return "int" [no-any-return] +sqlit/stores/connections.py:48: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/db/adapters/mssql.py:187: error: Returning Any from function declared to return "int" [no-any-return] +sqlit/db/adapters/oracle.py:118: error: Returning Any from function declared to return "int" [no-any-return] +sqlit/config.py:31: error: Second argument of Enum() must be string, tuple, list or dict literal for mypy to determine Enum members [misc] +sqlit/config.py:109: error: Returning Any from function declared to return "DatabaseType" [no-any-return] +sqlit/config.py:109: error: "type[DatabaseType]" has no attribute "MSSQL" [attr-defined] +sqlit/mocks.py:262: error: Function "builtins.callable" is not valid as a type [valid-type] +sqlit/mocks.py:262: note: Perhaps you meant "typing.Callable" instead of "callable"? +sqlit/mocks.py:364: error: Function "builtins.callable" is not valid as a type [valid-type] +sqlit/mocks.py:364: note: Perhaps you meant "typing.Callable" instead of "callable"? +sqlit/services/query.py:104: error: Incompatible types in assignment (expression has type "NonQueryResult", variable has type "QueryResult") [assignment] +sqlit/db/__init__.py:55: error: Function is missing a return type annotation [no-untyped-def] +sqlit/db/__init__.py:61: error: Function is missing a return type annotation [no-untyped-def] +sqlit/db/__init__.py:93: error: Returning Any from function declared to return "None" [no-any-return] +sqlit/commands.py:27: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/commands.py:54: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/commands.py:144: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/commands.py:197: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/commands.py:217: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/commands.py:233: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/commands.py:291: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/widgets.py:112: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/widgets.py:164: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/ui/screens/value_view.py:87: error: Signature of "action_dismiss" incompatible with supertype "textual.screen.Screen" [override] +sqlit/ui/screens/value_view.py:87: note: Superclass: +sqlit/ui/screens/value_view.py:87: note: def action_dismiss(self, result: Any | None = ...) -> Coroutine[Any, Any, None] +sqlit/ui/screens/value_view.py:87: note: Subclass: +sqlit/ui/screens/value_view.py:87: note: def action_dismiss(self) -> None +sqlit/ui/screens/theme.py:76: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/ui/screens/query_history.py:122: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/ui/screens/query_history.py:148: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/ui/screens/help.py:51: error: Signature of "action_dismiss" incompatible with supertype "textual.screen.Screen" [override] +sqlit/ui/screens/help.py:51: note: Superclass: +sqlit/ui/screens/help.py:51: note: def action_dismiss(self, result: Any | None = ...) -> Coroutine[Any, Any, None] +sqlit/ui/screens/help.py:51: note: Subclass: +sqlit/ui/screens/help.py:51: note: def action_dismiss(self) -> None +sqlit/ui/screens/driver_setup.py:119: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/ui/screens/connection_picker.py:193: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/ui/screens/connection.py:247: error: Returning Any from function declared to return "DatabaseType" [no-any-return] +sqlit/ui/screens/connection.py:247: error: "type[DatabaseType]" has no attribute "MSSQL" [attr-defined] +sqlit/ui/screens/connection.py:251: error: No return value expected [return-value] +sqlit/ui/screens/connection.py:316: error: Incompatible types in assignment (expression has type "Select[str]", target has type "Input | OptionList") [assignment] +sqlit/ui/screens/connection.py:505: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/ui/screens/connection.py:639: error: Incompatible types in assignment (expression has type "Select[str]", target has type "Input | OptionList") [assignment] +sqlit/ui/screens/connection.py:683: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/ui/screens/connection.py:867: error: Incompatible return value type (got "set[Any | None]", expected "set[str]") [return-value] +sqlit/ui/screens/connection.py:876: error: Need type annotation for "existing" (hint: "existing: list[] = ...") [var-annotated] +sqlit/ui/screens/connection.py:992: error: "type[DatabaseType]" has no attribute "MSSQL" [attr-defined] +sqlit/ui/screens/confirm.py:61: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/ui/mixins/tree.py:69: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/ui/mixins/tree.py:77: error: No return value expected [return-value] +sqlit/ui/mixins/tree.py:83: error: "get_conn_label" does not return a value (it only ever returns None) [func-returns-value] +sqlit/ui/mixins/tree.py:88: error: "get_conn_label" does not return a value (it only ever returns None) [func-returns-value] +sqlit/ui/mixins/tree.py:121: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/ui/mixins/tree.py:136: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/ui/mixins/tree.py:156: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/ui/mixins/tree.py:171: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/ui/mixins/tree.py:235: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/ui/mixins/tree.py:256: error: Unexpected keyword argument "thread" for "run_worker" of "AppProtocol" [call-arg] +sqlit/ui/mixins/tree.py:350: error: Function "builtins.any" is not valid as a type [valid-type] +sqlit/ui/mixins/tree.py:350: note: Perhaps you meant "typing.Any" instead of "any"? +sqlit/ui/mixins/protocols.py:61: note: "run_worker" of "AppProtocol" defined here +sqlit/ui/mixins/tree.py:258: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/ui/mixins/tree.py:273: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/ui/mixins/tree.py:293: error: List comprehension has incompatible type List[tuple[str, str]]; expected List[tuple[str, str, str]] [misc] +sqlit/ui/mixins/tree.py:304: error: Unexpected keyword argument "thread" for "run_worker" of "AppProtocol" [call-arg] +sqlit/ui/mixins/tree.py:211: error: Cannot determine type of "_loading_nodes" [has-type] +sqlit/ui/mixins/tree.py:306: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/ui/mixins/tree.py:328: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/ui/mixins/tree.py:379: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/ui/mixins/tree.py:418: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/ui/mixins/autocomplete.py:135: error: Unexpected keyword argument "thread" for "run_worker" of "AppProtocol" [call-arg] +sqlit/ui/mixins/protocols.py:61: note: "run_worker" of "AppProtocol" defined here +sqlit/ui/mixins/autocomplete.py:72: error: Unsupported left operand type for + ("Collection[Any]") [operator] +sqlit/ui/mixins/autocomplete.py:74: error: Incompatible types in assignment (expression has type "Collection[Any]", variable has type "list[Any]") [assignment] +sqlit/ui/mixins/autocomplete.py:80: error: "Collection[Any]" has no attribute "get" [attr-defined] +sqlit/ui/mixins/autocomplete.py:83: error: "Collection[Any]" has no attribute "values" [attr-defined] +sqlit/ui/mixins/autocomplete.py:85: error: No overload variant of "__add__" of "list" matches argument type "Collection[Any]" [operator] +sqlit/ui/mixins/autocomplete.py:85: note: Possible overload variants: +sqlit/ui/mixins/autocomplete.py:85: note: def __add__(self, list[Any], /) -> list[Any] +sqlit/ui/mixins/autocomplete.py:85: note: def [_S] __add__(self, list[_S], /) -> list[_S | Any] +sqlit/ui/mixins/autocomplete.py:140: error: Unsupported target for indexed assignment ("Collection[Any]") [index] +sqlit/ui/mixins/autocomplete.py:142: error: Unsupported target for indexed assignment ("Collection[Any]") [index] +sqlit/ui/mixins/autocomplete.py:202: error: Returning Any from function declared to return "int" [no-any-return] +sqlit/ui/mixins/autocomplete.py:252: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/ui/mixins/autocomplete.py:288: error: Need type annotation for "_schema_cache" [var-annotated] +sqlit/ui/mixins/autocomplete.py:373: error: List item 0 has incompatible type "None"; expected "str" [list-item] +sqlit/state_machine.py:101: error: No return value expected [return-value] +sqlit/state_machine.py:463: error: Incompatible return value type (got "ConnectionConfig | bool | None", expected "bool") [return-value] +sqlit/state_machine.py:471: error: Incompatible return value type (got "ConnectionConfig | bool | None", expected "bool") [return-value] +sqlit/ui/screens/leader_menu.py:72: error: Signature of "action_dismiss" incompatible with supertype "textual.screen.Screen" [override] +sqlit/ui/screens/leader_menu.py:72: note: Superclass: +sqlit/ui/screens/leader_menu.py:72: note: def action_dismiss(self, result: Any | None = ...) -> Coroutine[Any, Any, None] +sqlit/ui/screens/leader_menu.py:72: note: Subclass: +sqlit/ui/screens/leader_menu.py:72: note: def action_dismiss(self) -> None +sqlit/ui/screens/leader_menu.py:87: error: Argument 1 to "is_allowed" of "LeaderCommand" has incompatible type "App[Any]"; expected "SSMSTUI" [arg-type] +sqlit/ui/screens/leader_menu.py:90: error: No return value expected [return-value] +sqlit/ui/mixins/results.py:31: error: Unused "type: ignore" comment [unused-ignore] +sqlit/ui/mixins/results.py:52: error: Incompatible types in assignment (expression has type "str", variable has type "Literal['cell', 'row', 'column', 'none']") [assignment] +sqlit/ui/mixins/results.py:62: error: Incompatible types in assignment (expression has type "Any | str", variable has type "Literal['cell', 'row', 'column', 'none']") [assignment] +sqlit/ui/mixins/query.py:277: error: Cannot determine type of "_schema_worker" [has-type] +sqlit/ui/mixins/query.py:278: error: Cannot determine type of "_schema_worker" [has-type] +sqlit/ui/mixins/query.py:280: error: "QueryMixin" has no attribute "_stop_schema_spinner"; maybe "_stop_query_spinner"? [attr-defined] +sqlit/ui/mixins/query.py:307: error: Too many arguments for "push_screen" of "AppProtocol" [call-arg] +sqlit/ui/mixins/query.py:312: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/ui/mixins/query.py:328: error: Item "None" of "ConnectionConfig | None" has no attribute "name" [union-attr] +sqlit/ui/mixins/connection.py:38: error: Cannot determine type of "_session" [has-type] +sqlit/ui/mixins/connection.py:39: error: Cannot determine type of "_session" [has-type] +sqlit/ui/mixins/connection.py:45: error: "ConnectionMixin" has no attribute "refresh_tree" [attr-defined] +sqlit/ui/mixins/connection.py:66: error: "ConnectionMixin" has no attribute "refresh_tree" [attr-defined] +sqlit/ui/mixins/connection.py:67: error: "ConnectionMixin" has no attribute "_load_schema_cache" [attr-defined] +sqlit/ui/mixins/connection.py:86: error: Unexpected keyword argument "thread" for "run_worker" of "AppProtocol" [call-arg] +sqlit/ui/mixins/connection.py:116: error: Too many arguments for "push_screen" of "AppProtocol" [call-arg] +sqlit/ui/mixins/connection.py:132: error: Too many arguments for "push_screen" of "AppProtocol" [call-arg] +sqlit/ui/mixins/connection.py:139: error: "ConnectionMixin" has no attribute "query_one" [attr-defined] +sqlit/ui/mixins/connection.py:254: error: Too many arguments for "push_screen" of "AppProtocol" [call-arg] +sqlit/ui/mixins/protocols.py:61: note: "run_worker" of "AppProtocol" defined here +sqlit/ui/mixins/connection.py:106: error: "ConnectionMixin" has no attribute "status_bar" [attr-defined] +sqlit/ui/mixins/connection.py:108: error: "ConnectionMixin" has no attribute "refresh_tree" [attr-defined] +sqlit/ui/mixins/connection.py:159: error: Cannot determine type of "connections" [has-type] +sqlit/ui/mixins/connection.py:165: error: "ConnectionMixin" has no attribute "refresh_tree" [attr-defined] +sqlit/ui/mixins/connection.py:196: error: Too many arguments for "push_screen" of "AppProtocol" [call-arg] +sqlit/ui/mixins/connection.py:217: error: Too many arguments for "push_screen" of "AppProtocol" [call-arg] +sqlit/ui/mixins/connection.py:231: error: "ConnectionMixin" has no attribute "refresh_tree" [attr-defined] +sqlit/ui/mixins/ui_navigation.py:21: error: "UINavigationMixin" has no attribute "screen" [attr-defined] +sqlit/ui/mixins/ui_navigation.py:22: error: "UINavigationMixin" has no attribute "screen" [attr-defined] +sqlit/ui/mixins/ui_navigation.py:23: error: "UINavigationMixin" has no attribute "screen" [attr-defined] +sqlit/ui/mixins/ui_navigation.py:26: error: "UINavigationMixin" has no attribute "screen" [attr-defined] +sqlit/ui/mixins/ui_navigation.py:28: error: "UINavigationMixin" has no attribute "screen" [attr-defined] +sqlit/ui/mixins/ui_navigation.py:30: error: "UINavigationMixin" has no attribute "screen" [attr-defined] +sqlit/ui/mixins/ui_navigation.py:35: error: "UINavigationMixin" has no attribute "query_one" [attr-defined] +sqlit/ui/mixins/ui_navigation.py:36: error: "UINavigationMixin" has no attribute "query_one" [attr-defined] +sqlit/ui/mixins/ui_navigation.py:37: error: "UINavigationMixin" has no attribute "query_one" [attr-defined] +sqlit/ui/mixins/ui_navigation.py:43: error: "UINavigationMixin" has no attribute "focused" [attr-defined] +sqlit/ui/mixins/ui_navigation.py:76: error: "UINavigationMixin" has no attribute "screen" [attr-defined] +sqlit/ui/mixins/ui_navigation.py:77: error: "UINavigationMixin" has no attribute "screen" [attr-defined] +sqlit/ui/mixins/ui_navigation.py:118: error: "UINavigationMixin" has no attribute "_hide_autocomplete" [attr-defined] +sqlit/ui/mixins/ui_navigation.py:127: error: "UINavigationMixin" has no attribute "status_bar" [attr-defined] +sqlit/ui/mixins/ui_navigation.py:198: error: "UINavigationMixin" has no attribute "size" [attr-defined] +sqlit/ui/mixins/ui_navigation.py:229: error: Cannot determine type of "_notification_timer" [has-type] +sqlit/ui/mixins/ui_navigation.py:230: error: Cannot determine type of "_notification_timer" [has-type] +sqlit/ui/mixins/ui_navigation.py:234: error: "UINavigationMixin" has no attribute "_notification_history"; maybe "_notification_timer"? [attr-defined] +sqlit/ui/mixins/ui_navigation.py:276: error: "UINavigationMixin" has no attribute "screen" [attr-defined] +sqlit/ui/mixins/ui_navigation.py:277: error: "UINavigationMixin" has no attribute "screen" [attr-defined] +sqlit/ui/mixins/ui_navigation.py:283: error: "UINavigationMixin" has no attribute "screen" [attr-defined] +sqlit/ui/mixins/ui_navigation.py:293: error: Too many arguments for "push_screen" of "AppProtocol" [call-arg] +sqlit/ui/mixins/ui_navigation.py:326: error: "UINavigationMixin" has no attribute "query_one" [attr-defined] +sqlit/ui/mixins/ui_navigation.py:330: error: "UINavigationMixin" has no attribute "_state_machine" [attr-defined] +sqlit/ui/mixins/ui_navigation.py:341: error: "UINavigationMixin" has no attribute "_state_machine" [attr-defined] +sqlit/ui/mixins/ui_navigation.py:353: error: Cannot determine type of "_leader_timer" [has-type] +sqlit/ui/mixins/ui_navigation.py:354: error: Cannot determine type of "_leader_timer" [has-type] +sqlit/ui/mixins/ui_navigation.py:364: error: "UINavigationMixin" has no attribute "set_timer" [attr-defined] +sqlit/ui/mixins/ui_navigation.py:381: error: "UINavigationMixin" has no attribute "exit" [attr-defined] +sqlit/ui/mixins/ui_navigation.py:396: error: Too many arguments for "push_screen" of "AppProtocol" [call-arg] +sqlit/ui/mixins/ui_navigation.py:428: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/ui/mixins/ui_navigation.py:439: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/app.py:179: error: "get_leader_bindings" does not return a value (it only ever returns None) [func-returns-value] +sqlit/app.py:270: error: No return value expected [return-value] +sqlit/app.py:272: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/app.py:274: error: No return value expected [return-value] +sqlit/app.py:276: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/app.py:277: error: No return value expected [return-value] +sqlit/app.py:279: error: Argument "adapter_factory" to "create" of "ConnectionSession" has incompatible type "Callable[[str], None]"; expected "Callable[[str], DatabaseAdapter] | None" [arg-type] +sqlit/app.py:280: error: Argument "tunnel_factory" to "create" of "ConnectionSession" has incompatible type "Callable[[Any], None]"; expected "Callable[[ConnectionConfig], tuple[Any, str, int]] | None" [arg-type] +sqlit/app.py:283: error: No return value expected [return-value] +sqlit/app.py:286: error: Cannot override writeable attribute with read-only property [override] +sqlit/app.py:290: error: Cannot override writeable attribute with read-only property [override] +sqlit/app.py:294: error: Cannot override writeable attribute with read-only property [override] +sqlit/app.py:299: error: No return value expected [return-value] +sqlit/app.py:303: error: No return value expected [return-value] +sqlit/app.py:307: error: No return value expected [return-value] +sqlit/app.py:311: error: No return value expected [return-value] +sqlit/app.py:321: error: No return value expected [return-value] +sqlit/app.py:323: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/app.py:325: error: Call to abstract method "push_screen" of "AppProtocol" with trivial body via super() is unsafe [safe-super] +sqlit/app.py:325: error: Too many arguments for "push_screen" of "AppProtocol" [call-arg] +sqlit/app.py:325: error: Unexpected keyword argument "wait_for_dismiss" for "push_screen" of "AppProtocol" [call-arg] +sqlit/ui/mixins/protocols.py:72: note: "push_screen" of "AppProtocol" defined here +sqlit/app.py:36: error: Cannot determine type of "_schema_worker" in base class "QueryMixin" [misc] +sqlit/app.py:36: error: Definition of "action_quit" in base class "AppProtocol" is incompatible with definition in base class "App" [misc] +sqlit/app.py:36: error: Definition of "notify" in base class "UINavigationMixin" is incompatible with definition in base class "App" [misc] +sqlit/app.py:36: error: Definition of "run_worker" in base class "AppProtocol" is incompatible with definition in base class "DOMNode" [misc] +sqlit/app.py:36: error: Definition of "set_interval" in base class "AppProtocol" is incompatible with definition in base class "MessagePump" [misc] +sqlit/app.py:262: error: "_create_mock_session_factory" of "SSMSTUI" does not return a value (it only ever returns None) [func-returns-value] +sqlit/app.py:323: error: Signature of "push_screen" incompatible with supertype "textual.app.App" [override] +sqlit/app.py:323: note: Superclass: +sqlit/app.py:323: note: @overload +sqlit/app.py:323: note: def [ScreenResultType] push_screen(self, screen: Screen[ScreenResultType] | str, callback: Callable[[ScreenResultType | None], None] | Callable[[ScreenResultType | None], Awaitable[None]] | None = ..., wait_for_dismiss: Literal[False] = ...) -> AwaitMount +sqlit/app.py:323: note: @overload +sqlit/app.py:323: note: def [ScreenResultType] push_screen(self, screen: Screen[ScreenResultType] | str, callback: Callable[[ScreenResultType | None], None] | Callable[[ScreenResultType | None], Awaitable[None]] | None = ..., wait_for_dismiss: Literal[True] = ...) -> Future[ScreenResultType] +sqlit/app.py:323: note: Subclass: +sqlit/app.py:323: note: def push_screen(self, screen: Any, callback: Any = ..., wait_for_dismiss: bool = ...) -> None +sqlit/app.py:327: error: Returning Any from function declared to return "None" [no-any-return] +sqlit/app.py:329: error: Return type "None" of "pop_screen" incompatible with return type "AwaitComplete" in supertype "textual.app.App" [override] +sqlit/app.py:331: error: Call to abstract method "pop_screen" of "AppProtocol" with trivial body via super() is unsafe [safe-super] +sqlit/app.py:333: error: Returning Any from function declared to return "None" [no-any-return] +sqlit/app.py:348: error: Need type annotation for "tree" [var-annotated] +sqlit/app.py:433: error: Function is missing a type annotation for one or more arguments [no-untyped-def] +sqlit/cli.py:131: error: Cannot instantiate abstract class "SSMSTUI" with abstract attributes "action_quit", "call_from_thread", ... and "set_interval" (4 methods suppressed) [abstract] +sqlit/cli.py:131: note: The following methods were marked implicitly abstract because they have empty function bodies: "action_quit", "call_from_thread", ... and "set_interval" (3 methods suppressed). If they are not meant to be abstract, explicitly `return` or `return None`. +sqlit/__init__.py:26: error: No return value expected [return-value] +sqlit/__init__.py:30: error: No return value expected [return-value] +sqlit/__init__.py:34: error: No return value expected [return-value] +sqlit/__init__.py:38: error: No return value expected [return-value] +Found 177 errors in 31 files (checked 64 source files) diff --git a/mypy-summary.md b/mypy-summary.md new file mode 100644 index 00000000..1301177a --- /dev/null +++ b/mypy-summary.md @@ -0,0 +1,68 @@ +# mypy Error Summary + +## Files with Most Errors +``` +12 errors - sqlit/ui/mixins/tree.py +10 errors - sqlit/ui/screens/connection.py + 7 errors - sqlit/commands.py + 6 errors - sqlit/db/adapters/ (various) + 4 errors - sqlit/app.py +``` + +## Error Types + +### 1. no-untyped-def (40 errors) +Functions missing parameter type annotations. + +**Example:** +```python +# Before +def on_button_pressed(event): + ... + +# After +def on_button_pressed(event: Button.Pressed) -> None: + ... +``` + +### 2. attr-defined (38 errors) +Missing attributes in Protocol or enum issues. + +**Common ones:** +- `DatabaseType.MSSQL` - enum member access +- Missing Protocol attributes + +### 3. return-value (20 errors) +Wrong return type annotations. + +**Example:** +```python +# Wrong +def cursor(self) -> None: + return MockCursor() # Error! + +# Fixed +def cursor(self) -> MockCursor: + return MockCursor() +``` + +### 4. no-any-return (11 errors) +Functions returning `Any` but declared to return specific type. + +**Example:** +```python +# In adapters +def get_row_count(self, cursor: Any) -> int: + return cursor.rowcount # rowcount is Any type +``` + +## Quick Fixes + +Run in Neovim: +1. `:Telescope diagnostics` - Browse all errors +2. `:lua vim.diagnostic.open_float()` - Show error under cursor +3. `gd` - Go to definition +4. `K` - Show hover info + +## Full Error List +See: mypy-errors.txt (177 errors across 31 files) diff --git a/pyproject.toml b/pyproject.toml index 712b18d0..65f3923e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,26 +27,57 @@ classifiers = [ "Topic :: Database", ] dependencies = [ - "textual[syntax]>=0.50.0", - "pyodbc>=5.0.0", + "textual[syntax]>=6.10.0", "pyperclip>=1.8.2", "sshtunnel>=0.4.0", "paramiko>=2.0.0,<4.0.0", # sshtunnel requires paramiko<4.0.0 (DSSKey removed in 4.0) + "keyring>=24.0.0", ] [project.optional-dependencies] +all = [ + "psycopg2-binary>=2.9.0", + "pyodbc>=5.0.0", + "mysql-connector-python>=8.0.0", + "mariadb>=1.1.0", + "oracledb>=2.0.0", + "duckdb>=0.9.0", + "requests>=2.0.0", + "libsql-client>=0.1.0", +] +postgres = ["psycopg2-binary>=2.9.0"] +cockroachdb = ["psycopg2-binary>=2.9.0"] +mssql = ["pyodbc>=5.0.0"] +mysql = ["mysql-connector-python>=8.0.0"] +mariadb = ["mariadb>=1.1.0"] +oracle = ["oracledb>=2.0.0"] +duckdb = ["duckdb>=0.9.0"] +d1 = ["requests>=2.0.0"] +turso = ["libsql-client>=0.1.0"] test = [ "pytest>=7.0", "pytest-timeout>=2.0", "pytest-asyncio>=0.23.0", + "pytest-cov>=4.0", ] dev = [ "pytest>=7.0", "pytest-timeout>=2.0", "pytest-asyncio>=0.23.0", + "pytest-cov>=4.0", "build", + "ruff>=0.8.0", + "mypy>=1.0", + "pre-commit>=3.0", ] +[tool.ruff] +target-version = "py310" +line-length = 120 + +[tool.ruff.lint] +select = ["E", "F", "I", "UP"] # pycodestyle, pyflakes, isort, pyupgrade + [project.scripts] sqlit = "sqlit.cli:main" @@ -73,3 +104,40 @@ markers = [ "oracle: Oracle database tests", "asyncio: async tests", ] + +[tool.coverage.run] +source = ["sqlit"] +branch = true + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "if TYPE_CHECKING:", + "if __name__ == .__main__.:", +] + +[tool.mypy] +python_version = "3.10" +warn_return_any = true +warn_unused_ignores = true +disallow_untyped_defs = true +check_untyped_defs = true +warn_redundant_casts = true +strict_optional = true +exclude = ["tests/"] +namespace_packages = true + +[[tool.mypy.overrides]] +module = [ + "mysql.connector", + "mariadb", + "oracledb", + "duckdb", + "libsql_client", + "sshtunnel" +] +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "pyodbc" +ignore_missing_imports = true diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 00000000..20d7f22a --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,10 @@ +{ + "include": ["sqlit"], + "exclude": ["tests", "**/__pycache__", ".venv", "venv"], + "pythonVersion": "3.10", + "typeCheckingMode": "basic", + "useLibraryCodeForTypes": true, + "reportMissingImports": false, + "reportMissingTypeStubs": false, + "stubPath": "" +} diff --git a/settings.template.json b/settings.template.json new file mode 100644 index 00000000..a58d9e52 --- /dev/null +++ b/settings.template.json @@ -0,0 +1,314 @@ +{ + "_note": "Copy to .sqlit/settings.json (gitignored) and run: sqlit --settings .sqlit/settings.json", + "theme": "tokyo-night", + "expanded_nodes": [], + "allow_plaintext_credentials": false, + "mock": { + "enabled": true, + "profile": "empty", + "use_default_adapters": true, + "connections": [], + "drivers": { + "missing": [], + "install_result": "real", + "pipx": "auto" + }, + "adapters": { + "supabase": { + "name": "Supabase", + "default_schema": "public", + "connect": { + "_note": "Use required_fields/allowed/auth_error_message to control auth outcomes.", + "result": "success", + "error_message": "Supabase connection failed", + "auth_error_message": "Invalid Supabase credentials", + "required_fields": [ + "supabase_region", + "supabase_project_id", + "password" + ], + "allowed": [] + }, + "schemas": { + "public": { + "tables": { + "profiles": { + "columns": [ + { "name": "id", "type": "uuid" }, + { "name": "username", "type": "text" }, + { "name": "full_name", "type": "text" }, + { "name": "avatar_url", "type": "text" }, + { "name": "created_at", "type": "timestamp with time zone" } + ], + "rows": [ + [ + "9a1b2c3d-1111-2222-3333-444455556666", + "alice", + "Alice Johnson", + "https://cdn.example.com/avatars/alice.png", + "2024-01-10T12:00:00Z" + ], + [ + "7f8e9d0c-aaaa-bbbb-cccc-ddddeeeeffff", + "bob", + "Bob Smith", + "https://cdn.example.com/avatars/bob.png", + "2024-01-12T08:30:00Z" + ] + ] + }, + "projects": { + "columns": [ + { "name": "id", "type": "uuid" }, + { "name": "name", "type": "text" }, + { "name": "owner_id", "type": "uuid" }, + { "name": "created_at", "type": "timestamp with time zone" } + ], + "rows": [ + [ + "0d9f1c2b-3333-4444-5555-666677778888", + "Launch Checklist", + "9a1b2c3d-1111-2222-3333-444455556666", + "2024-02-01T09:00:00Z" + ], + [ + "3b2a1c0d-9999-8888-7777-666655554444", + "Customer Portal", + "7f8e9d0c-aaaa-bbbb-cccc-ddddeeeeffff", + "2024-02-05T16:45:00Z" + ] + ] + }, + "todos": { + "columns": [ + { "name": "id", "type": "bigint" }, + { "name": "title", "type": "text" }, + { "name": "is_complete", "type": "boolean" }, + { "name": "user_id", "type": "uuid" }, + { "name": "inserted_at", "type": "timestamp with time zone" } + ], + "rows": [ + [ + 1, + "Set up Row Level Security", + true, + "9a1b2c3d-1111-2222-3333-444455556666", + "2024-02-02T10:15:00Z" + ], + [ + 2, + "Draft onboarding email", + false, + "7f8e9d0c-aaaa-bbbb-cccc-ddddeeeeffff", + "2024-02-03T13:20:00Z" + ] + ] + } + }, + "views": { + "active_todos": { + "columns": [ + { "name": "id", "type": "bigint" }, + { "name": "title", "type": "text" }, + { "name": "user_id", "type": "uuid" } + ], + "rows": [ + [ + 2, + "Draft onboarding email", + "7f8e9d0c-aaaa-bbbb-cccc-ddddeeeeffff" + ] + ] + } + } + }, + "auth": { + "tables": { + "users": { + "columns": [ + { "name": "id", "type": "uuid" }, + { "name": "email", "type": "text" }, + { "name": "phone", "type": "text" }, + { "name": "created_at", "type": "timestamp with time zone" }, + { "name": "last_sign_in_at", "type": "timestamp with time zone" }, + { "name": "role", "type": "text" }, + { "name": "is_anonymous", "type": "boolean" } + ], + "rows": [ + [ + "9a1b2c3d-1111-2222-3333-444455556666", + "alice@example.com", + "+155555501", + "2024-01-10T12:00:00Z", + "2024-02-10T12:45:00Z", + "authenticated", + false + ], + [ + "7f8e9d0c-aaaa-bbbb-cccc-ddddeeeeffff", + "bob@example.com", + "+155555502", + "2024-01-12T08:30:00Z", + "2024-02-11T09:10:00Z", + "authenticated", + false + ] + ] + }, + "identities": { + "columns": [ + { "name": "id", "type": "uuid" }, + { "name": "user_id", "type": "uuid" }, + { "name": "provider", "type": "text" }, + { "name": "identity_data", "type": "jsonb" }, + { "name": "created_at", "type": "timestamp with time zone" } + ], + "rows": [ + [ + "2c3d4e5f-0000-1111-2222-333344445555", + "9a1b2c3d-1111-2222-3333-444455556666", + "email", + { "email_verified": true }, + "2024-01-10T12:00:00Z" + ] + ] + }, + "sessions": { + "columns": [ + { "name": "id", "type": "uuid" }, + { "name": "user_id", "type": "uuid" }, + { "name": "created_at", "type": "timestamp with time zone" }, + { "name": "expires_at", "type": "timestamp with time zone" } + ], + "rows": [ + [ + "11112222-3333-4444-5555-666677778888", + "9a1b2c3d-1111-2222-3333-444455556666", + "2024-02-10T12:45:00Z", + "2024-03-10T12:45:00Z" + ] + ] + }, + "refresh_tokens": { + "columns": [ + { "name": "id", "type": "bigint" }, + { "name": "user_id", "type": "uuid" }, + { "name": "token", "type": "text" }, + { "name": "created_at", "type": "timestamp with time zone" } + ], + "rows": [ + [ + 1001, + "9a1b2c3d-1111-2222-3333-444455556666", + "rt_abc123", + "2024-02-10T12:45:00Z" + ] + ] + }, + "schema_migrations": { + "columns": [ + { "name": "version", "type": "text" }, + { "name": "inserted_at", "type": "timestamp with time zone" } + ], + "rows": [ + [ + "20240208094500", + "2024-02-08T09:45:00Z" + ] + ] + } + } + }, + "storage": { + "tables": { + "buckets": { + "columns": [ + { "name": "id", "type": "text" }, + { "name": "name", "type": "text" }, + { "name": "owner", "type": "uuid" }, + { "name": "public", "type": "boolean" }, + { "name": "created_at", "type": "timestamp with time zone" } + ], + "rows": [ + [ + "avatars", + "avatars", + "9a1b2c3d-1111-2222-3333-444455556666", + true, + "2024-01-11T15:20:00Z" + ] + ] + }, + "objects": { + "columns": [ + { "name": "id", "type": "uuid" }, + { "name": "bucket_id", "type": "text" }, + { "name": "name", "type": "text" }, + { "name": "owner", "type": "uuid" }, + { "name": "metadata", "type": "jsonb" }, + { "name": "created_at", "type": "timestamp with time zone" } + ], + "rows": [ + [ + "f1e2d3c4-5555-6666-7777-888899990000", + "avatars", + "public/alice.png", + "9a1b2c3d-1111-2222-3333-444455556666", + { "content_type": "image/png", "size": 24512 }, + "2024-01-11T15:20:10Z" + ] + ] + } + } + }, + "realtime": { + "tables": { + "subscription": { + "columns": [ + { "name": "id", "type": "bigint" }, + { "name": "created_at", "type": "timestamp with time zone" }, + { "name": "topic", "type": "text" }, + { "name": "claims", "type": "jsonb" } + ], + "rows": [ + [ + 5001, + "2024-02-12T10:00:00Z", + "realtime:public:todos", + { "role": "authenticated" } + ] + ] + }, + "schema_migrations": { + "columns": [ + { "name": "version", "type": "integer" }, + { "name": "inserted_at", "type": "timestamp with time zone" } + ], + "rows": [ + [ + 1, + "2024-02-01T00:00:00Z" + ] + ] + } + } + } + }, + "query_results": { + "select now()": { + "columns": [ "now" ], + "rows": [ + [ "2024-02-15T12:00:00Z" ] + ] + } + }, + "default_query_result": { + "columns": [ "result" ], + "rows": [ + [ "Query executed successfully (mock)" ] + ] + } + } + } + } +} diff --git a/sqlit/__init__.py b/sqlit/__init__.py index fc5d8983..b1db33a6 100644 --- a/sqlit/__init__.py +++ b/sqlit/__init__.py @@ -1,35 +1,59 @@ """sqlit - A terminal UI for SQL databases.""" -from importlib.metadata import version, PackageNotFoundError - -try: - __version__ = version("sqlit-tui") -except PackageNotFoundError: - # Package not installed (development mode without editable install) - __version__ = "0.0.0.dev" +from typing import TYPE_CHECKING, Any __author__ = "Peter" __all__ = [ + "__version__", "main", "SSMSTUI", "AuthType", "ConnectionConfig", ] +if TYPE_CHECKING: + from .app import SSMSTUI + from .cli import main + from .config import AuthType, ConnectionConfig + from importlib.metadata import PackageNotFoundError # noqa: F401 + + +_VERSION_CACHE: str | None = None + + +def _get_version() -> str: + global _VERSION_CACHE + if _VERSION_CACHE is not None: + return _VERSION_CACHE + try: + from importlib.metadata import PackageNotFoundError, version + + _VERSION_CACHE = version("sqlit-tui") + except PackageNotFoundError: + # Package not installed (development mode without editable install) + _VERSION_CACHE = "0.0.0.dev" + return _VERSION_CACHE -def __getattr__(name: str): + +def __getattr__(name: str) -> Any: """Lazy import for heavy modules to keep package import side-effect free.""" + if name == "__version__": + return _get_version() if name == "main": from .cli import main + return main if name == "SSMSTUI": from .app import SSMSTUI + return SSMSTUI if name == "AuthType": from .config import AuthType + return AuthType if name == "ConnectionConfig": from .config import ConnectionConfig + return ConnectionConfig raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/sqlit/adapters.py b/sqlit/adapters.py deleted file mode 100644 index 95ada3e6..00000000 --- a/sqlit/adapters.py +++ /dev/null @@ -1,40 +0,0 @@ -"""Database adapters for sqlit - abstraction layer for different database types. - -This module re-exports from sqlit.db for backward compatibility. -New code should import directly from sqlit.db or sqlit.db.adapters. -""" - -# Re-export everything from the new location for backward compatibility -from .db import ( - ColumnInfo, - CockroachDBAdapter, - DatabaseAdapter, - DuckDBAdapter, - MariaDBAdapter, - MySQLAdapter, - OracleAdapter, - PostgreSQLAdapter, - SQLiteAdapter, - SQLServerAdapter, - create_ssh_tunnel, - get_adapter, -) - -__all__ = [ - # Base - "ColumnInfo", - "DatabaseAdapter", - # Adapters - "CockroachDBAdapter", - "DuckDBAdapter", - "MariaDBAdapter", - "MySQLAdapter", - "OracleAdapter", - "PostgreSQLAdapter", - "SQLiteAdapter", - "SQLServerAdapter", - # Factory - "get_adapter", - # Tunnel - "create_ssh_tunnel", -] diff --git a/sqlit/app.py b/sqlit/app.py index 72b5536b..fbe97680 100644 --- a/sqlit/app.py +++ b/sqlit/app.py @@ -2,19 +2,24 @@ from __future__ import annotations +import json +import os +import sys +import tempfile +import time +from collections.abc import Awaitable, Callable +from pathlib import Path from typing import Any -try: - import pyodbc - - PYODBC_AVAILABLE = True -except ImportError: - PYODBC_AVAILABLE = False - from textual.app import App, ComposeResult from textual.binding import Binding from textual.containers import Container, Horizontal, Vertical +from textual.lazy import Lazy +from textual.screen import ModalScreen +from textual.theme import Theme +from textual.timer import Timer from textual.widgets import DataTable, Static, TextArea, Tree +from textual.worker import Worker from .config import ( ConnectionConfig, @@ -24,23 +29,26 @@ ) from .db import DatabaseAdapter from .mocks import MockProfile +from .mock_settings import apply_mock_environment, build_mock_profile_from_settings from .state_machine import ( - get_leader_bindings, UIStateMachine, + get_leader_bindings, ) from .ui.mixins import ( AutocompleteMixin, ConnectionMixin, QueryMixin, ResultsMixin, + TreeFilterMixin, TreeMixin, UINavigationMixin, ) -from .widgets import AutocompleteDropdown, ContextFooter, VimMode +from .widgets import AutocompleteDropdown, ContextFooter, TreeFilterInput, VimMode class SSMSTUI( TreeMixin, + TreeFilterMixin, ConnectionMixin, QueryMixin, AutocompleteMixin, @@ -52,21 +60,57 @@ class SSMSTUI( TITLE = "sqlit" + _SQLIT_THEMES = [ + Theme( + name="sqlit", + primary="#97CB93", + secondary="#6D8DC4", + accent="#6D8DC4", + warning="#f59e0b", + error="#BE728C", + success="#4ADE80", + foreground="#a9b1d6", + background="#1A1B26", + surface="#24283B", + panel="#414868", + dark=True, + variables={ + "border": "#7a7f99", + "border-blurred": "#7a7f99", + "footer-background": "#24283B", + "footer-key-foreground": "#7FA1DE", + "button-color-foreground": "#1A1B26", + "input-selection-background": "#2a3144 40%", + }, + ), + ] + CSS = """ Screen { background: $surface; } + TextArea { + & > .text-area--cursor-line { + background: transparent; + } + &:focus > .text-area--cursor-line { + background: $surface-lighten-1; + } + } + DataTable.flash-cell:focus > .datatable--cursor, DataTable.flash-row:focus > .datatable--cursor, DataTable.flash-all:focus > .datatable--cursor { - background: $success; - color: $background; - text-style: bold; + background: $success 30%; } DataTable.flash-all { - border: solid $success; + border: solid $success 30%; + } + + .flash { + background: $success 30%; } Screen.results-fullscreen #sidebar { @@ -100,7 +144,6 @@ class SSMSTUI( Screen.explorer-fullscreen #sidebar { width: 1fr; - border-right: none; } Screen.explorer-hidden #sidebar { @@ -118,8 +161,9 @@ class SSMSTUI( #sidebar { width: 35; - border-right: solid $primary; + border: round $border; padding: 1; + margin: 0; } #object-tree { @@ -132,19 +176,31 @@ class SSMSTUI( #query-area { height: 50%; - border-bottom: solid $primary; + border: round $border; padding: 1; + margin: 0; } #query-input { height: 1fr; + border: none; } #results-area { height: 50%; padding: 1; + border: round $border; + margin: 0; } + #sidebar.active-pane, + #query-area.active-pane, + #results-area.active-pane { + border: round $primary; + border-title-color: $primary; + } + + #results-table { height: 1fr; } @@ -155,16 +211,13 @@ class SSMSTUI( padding: 0 1; } - .section-label { - height: 1; - color: $text-muted; - padding: 0 1; - margin-bottom: 1; - } - - .section-label.active { - color: $primary; - text-style: bold; + #sidebar, + #query-area, + #results-area { + border-title-align: left; + border-title-color: $border; + border-title-background: $surface; + border-title-style: bold; } #autocomplete-dropdown { @@ -203,20 +256,49 @@ class SSMSTUI( Binding("escape", "exit_insert_mode", "Normal", show=False), Binding("enter", "execute_query", "Execute", show=False), Binding("f5", "execute_query_insert", "Execute", show=False), + Binding("ctrl+enter", "execute_query_insert", "Execute", show=False), Binding("d", "clear_query", "Clear", show=False), Binding("n", "new_query", "New", show=False), Binding("h", "show_history", "History", show=False), Binding("z", "collapse_tree", "Collapse", show=False), + Binding("j", "tree_cursor_down", "Down", show=False), + Binding("k", "tree_cursor_up", "Up", show=False), Binding("v", "view_cell", "View cell", show=False), - Binding("y", "copy_cell", "Copy cell", show=False), + Binding("u", "edit_cell", "Update cell", show=False), + Binding("h", "results_cursor_left", "Left", show=False), + Binding("j", "results_cursor_down", "Down", show=False), + Binding("k", "results_cursor_up", "Up", show=False), + Binding("l", "results_cursor_right", "Right", show=False), + Binding("y", "copy_context", "Copy", show=False), Binding("Y", "copy_row", "Copy row", show=False), Binding("a", "copy_results", "Copy results", show=False), - Binding("ctrl+c", "cancel_operation", "Cancel", show=False), + Binding("x", "clear_results", "Clear results", show=False), + Binding("ctrl+z", "cancel_operation", "Cancel", show=False), + Binding("ctrl+j", "autocomplete_next", "Next suggestion", show=False), + Binding("ctrl+k", "autocomplete_prev", "Prev suggestion", show=False), + Binding("slash", "tree_filter", "Filter", show=False), + Binding("escape", "tree_filter_close", "Close filter", show=False), + Binding("enter", "tree_filter_accept", "Select", show=False), + Binding("n", "tree_filter_next", "Next match", show=False), + Binding("N", "tree_filter_prev", "Prev match", show=False), ] - def __init__(self, mock_profile: MockProfile | None = None): + def __init__( + self, + mock_profile: MockProfile | None = None, + startup_connection: ConnectionConfig | None = None, + ): super().__init__() self._mock_profile = mock_profile + self._startup_connection = startup_connection + self._startup_connect_config: ConnectionConfig | None = None + self._debug_mode = os.environ.get("SQLIT_DEBUG") == "1" + self._startup_profile = os.environ.get("SQLIT_PROFILE_STARTUP") == "1" + self._startup_mark = self._parse_startup_mark(os.environ.get("SQLIT_STARTUP_MARK")) + self._startup_init_time = time.perf_counter() + self._startup_events: list[tuple[str, float]] = [] + self._launch_ms: float | None = None + self._startup_stamp("init_start") self.connections: list[ConnectionConfig] = [] self.current_connection: Any | None = None self.current_config: ConnectionConfig | None = None @@ -245,41 +327,46 @@ def __init__(self, mock_profile: MockProfile | None = None): self._last_notification: str = "" self._last_notification_severity: str = "information" self._last_notification_time: str = "" - self._notification_timer = None + self._notification_timer: Timer | None = None self._notification_history: list = [] self._connection_failed: bool = False - self._leader_timer = None + self._leader_timer: Timer | None = None self._leader_pending: bool = False - self._query_worker = None + self._dialog_open: bool = False + self._last_active_pane: str | None = None + self._query_worker: Worker[Any] | None = None self._query_executing: bool = False self._cancellable_query: Any | None = None self._spinner_index: int = 0 - self._spinner_timer = None + self._spinner_timer: Timer | None = None # Schema indexing state self._schema_indexing: bool = False - self._schema_worker = None + self._schema_worker: Worker[Any] | None = None self._schema_spinner_index: int = 0 - self._schema_spinner_timer = None + self._schema_spinner_timer: Timer | None = None self._table_metadata: dict = {} self._columns_loading: set[str] = set() self._state_machine = UIStateMachine() + self._session_factory: Any | None = None + self._last_query_table: dict | None = None if mock_profile: self._session_factory = self._create_mock_session_factory(mock_profile) + self._startup_stamp("init_end") - def _create_mock_session_factory(self, profile: MockProfile): + def _create_mock_session_factory(self, profile: MockProfile) -> Any: """Create a session factory that uses mock adapters.""" from .services import ConnectionSession - def mock_adapter_factory(db_type: str): + def mock_adapter_factory(db_type: str) -> Any: """Return mock adapter for the given db type.""" return profile.get_adapter(db_type) - def mock_tunnel_factory(config): + def mock_tunnel_factory(config: Any) -> Any: """Return no tunnel for mock connections.""" return None, config.server, int(config.port or "0") - def factory(config): + def factory(config: Any) -> Any: return ConnectionSession.create( config, adapter_factory=mock_adapter_factory, @@ -301,19 +388,19 @@ def results_table(self) -> DataTable: return self.query_one("#results-table", DataTable) @property - def sidebar(self): + def sidebar(self) -> Any: return self.query_one("#sidebar") @property - def main_panel(self): + def main_panel(self) -> Any: return self.query_one("#main-panel") @property - def query_area(self): + def query_area(self) -> Any: return self.query_one("#query-area") @property - def results_area(self): + def results_area(self) -> Any: return self.query_one("#results-area") @property @@ -321,22 +408,44 @@ def status_bar(self) -> Static: return self.query_one("#status-bar", Static) @property - def autocomplete_dropdown(self): + def autocomplete_dropdown(self) -> Any: from .widgets import AutocompleteDropdown + return self.query_one("#autocomplete-dropdown", AutocompleteDropdown) - def push_screen(self, screen, callback=None, wait_for_dismiss: bool = False): + @property + def tree_filter_input(self) -> TreeFilterInput: + return self.query_one("#tree-filter", TreeFilterInput) + + def push_screen( + self, + screen: Any, + callback: Callable[[Any], None] | Callable[[Any], Awaitable[None]] | None = None, + wait_for_dismiss: bool = False, + ) -> Any: """Override push_screen to update footer when screen changes.""" - result = super().push_screen(screen, callback, wait_for_dismiss=wait_for_dismiss) + if wait_for_dismiss: + future = super().push_screen(screen, callback, wait_for_dismiss=True) + self._update_footer_bindings() + self._update_dialog_state() + return future + mount = super().push_screen(screen, callback, wait_for_dismiss=False) self._update_footer_bindings() - return result + self._update_dialog_state() + return mount - def pop_screen(self): + def pop_screen(self) -> Any: """Override pop_screen to update footer when screen changes.""" result = super().pop_screen() self._update_footer_bindings() + self._update_dialog_state() return result + def _update_dialog_state(self) -> None: + """Track whether a modal dialog is open and update pane title styling.""" + self._dialog_open = any(isinstance(screen, ModalScreen) for screen in self.screen_stack) + self._update_section_labels() + def check_action(self, action: str, parameters: tuple) -> bool | None: """Check if an action is allowed in the current state. @@ -345,165 +454,263 @@ def check_action(self, action: str, parameters: tuple) -> bool | None: """ return self._state_machine.check_action(self, action) + def _compute_restart_argv(self) -> list[str]: + """Compute a best-effort argv to restart the app.""" + # Linux provides the most reliable answer via /proc. + try: + cmdline_path = "/proc/self/cmdline" + if os.path.exists(cmdline_path): + raw = open(cmdline_path, "rb").read() + parts = [p.decode(errors="surrogateescape") for p in raw.split(b"\0") if p] + if parts: + return parts + except Exception: + pass + + # Fallback: sys.argv (good enough for most invocations). + argv = [sys.argv[0], *sys.argv[1:]] if sys.argv else [] + if argv: + return argv + return [sys.executable] + + def restart(self) -> None: + """Restart the current process in-place.""" + argv = getattr(self, "_restart_argv", None) or self._compute_restart_argv() + exe = argv[0] + # execv doesn't search PATH; use execvp for bare commands (e.g. "sqlit"). + if os.sep in exe: + os.execv(exe, argv) + else: + os.execvp(exe, argv) + def compose(self) -> ComposeResult: + self._startup_stamp("compose_start") with Vertical(id="main-container"): with Horizontal(id="content"): with Vertical(id="sidebar"): - yield Static( - r"\[E] Explorer", classes="section-label", id="label-explorer" - ) - tree = Tree("Servers", id="object-tree") + yield TreeFilterInput(id="tree-filter") + tree: Tree[Any] = Tree("Servers", id="object-tree") tree.show_root = False tree.guide_depth = 2 yield tree with Vertical(id="main-panel"): with Container(id="query-area"): - yield Static( - r"\[q] Query", classes="section-label", id="label-query" - ) yield TextArea( "", language="sql", id="query-input", read_only=True, ) - yield AutocompleteDropdown(id="autocomplete-dropdown") + yield Lazy(AutocompleteDropdown(id="autocomplete-dropdown")) with Container(id="results-area"): - yield Static( - r"\[r] Results", classes="section-label", id="label-results" - ) - yield DataTable(id="results-table", zebra_stripes=True) + yield Lazy(DataTable(id="results-table", zebra_stripes=True, show_header=False)) yield Static("Not connected", id="status-bar") yield ContextFooter() + self._startup_stamp("compose_end") def on_mount(self) -> None: """Initialize the app.""" - if not PYODBC_AVAILABLE and not self._mock_profile: - self.notify( - "pyodbc not installed. Run: pip install pyodbc", - severity="warning", - timeout=10, - ) + self._startup_stamp("on_mount_start") + self._restart_argv = self._compute_restart_argv() + + for theme in self._SQLIT_THEMES: + self.register_theme(theme) settings = load_settings() + self._startup_stamp("settings_loaded") if "theme" in settings: try: self.theme = settings["theme"] except Exception: - self.theme = "tokyo-night" + self.theme = "sqlit" else: - self.theme = "tokyo-night" + self.theme = "sqlit" - settings = load_settings() self._expanded_paths = set(settings.get("expanded_nodes", [])) + self._startup_stamp("settings_applied") - if self._mock_profile: + self._apply_mock_settings(settings) + + if self._startup_connection: + # Only show the explicit startup connection, not saved ones + self._setup_startup_connection(self._startup_connection) + elif self._mock_profile: self.connections = self._mock_profile.connections.copy() else: - self.connections = load_connections() + self.connections = load_connections(load_credentials=False) + self._startup_stamp("connections_loaded") self.refresh_tree() - self._update_footer_bindings() + self._startup_stamp("tree_refreshed") self.object_tree.focus() + self._startup_stamp("tree_focused") # Move cursor to first node if available if self.object_tree.root.children: self.object_tree.cursor_line = 0 self._update_section_labels() + self._maybe_restore_connection_screen() + self._startup_stamp("restore_checked") + if self._debug_mode: + self.call_after_refresh(self._record_launch_ms) + self.call_after_refresh(self._update_status_bar) + self._update_footer_bindings() + self._startup_stamp("footer_updated") + if self._startup_connect_config: + self.call_after_refresh(lambda: self.connect_to_server(self._startup_connect_config)) # type: ignore[arg-type] + self._log_startup_timing() + + def _apply_mock_settings(self, settings: dict) -> None: + apply_mock_environment(settings) + if self._mock_profile: + return + mock_profile = build_mock_profile_from_settings(settings) + if mock_profile: + self._mock_profile = mock_profile + self._session_factory = self._create_mock_session_factory(mock_profile) - if not self._mock_profile: - self._check_drivers() + def _setup_startup_connection(self, config: ConnectionConfig) -> None: + """Set up a startup connection as the only visible connection.""" + if not config.name: + config.name = "Temp Connection" + self.connections = [config] + self._startup_connect_config = config - def _check_drivers(self) -> None: - """Check if ODBC drivers are installed and show setup if needed.""" - has_mssql = any(c.db_type == "mssql" for c in self.connections) - if not has_mssql: + def _startup_stamp(self, name: str) -> None: + if not self._startup_profile: return + self._startup_events.append((name, time.perf_counter())) - if not PYODBC_AVAILABLE: + def _log_startup_timing(self) -> None: + if not self._startup_profile: return + now = time.perf_counter() + if self._startup_mark is not None: + since_start = (now - self._startup_mark) * 1000 + else: + since_start = None + init_to_mount = (now - self._startup_init_time) * 1000 + + parts = [] + if since_start is not None: + parts.append(f"start_to_mount_ms={since_start:.2f}") + parts.append(f"init_to_mount_ms={init_to_mount:.2f}") + print(f"[sqlit] startup {' '.join(parts)}", file=sys.stderr) + self._log_startup_steps() + + def after_refresh() -> None: + now_refresh = time.perf_counter() + if self._startup_mark is not None: + start_to_refresh = (now_refresh - self._startup_mark) * 1000 + else: + start_to_refresh = None + init_to_refresh = (now_refresh - self._startup_init_time) * 1000 - from .drivers import get_installed_drivers + self._log_startup_step("first_refresh", now_refresh) + refresh_parts = [] + if start_to_refresh is not None: + refresh_parts.append(f"start_to_first_refresh_ms={start_to_refresh:.2f}") + refresh_parts.append(f"init_to_first_refresh_ms={init_to_refresh:.2f}") + print(f"[sqlit] startup {' '.join(refresh_parts)}", file=sys.stderr) - installed = get_installed_drivers() - if not installed: - self.call_later(self._show_driver_setup) + self.call_after_refresh(after_refresh) - def _show_driver_setup(self) -> None: - """Show the driver setup screen.""" - from .drivers import get_installed_drivers - from .ui.screens import DriverSetupScreen + def _log_startup_steps(self) -> None: + for name, ts in self._startup_events: + self._log_startup_step(name, ts) - installed = get_installed_drivers() - self.push_screen(DriverSetupScreen(installed), self._handle_driver_result) + def _log_startup_step(self, name: str, timestamp: float) -> None: + if not self._startup_profile: + return + parts = [f"step={name}"] + if self._startup_mark is not None: + parts.append(f"start_ms={(timestamp - self._startup_mark) * 1000:.2f}") + parts.append(f"init_ms={(timestamp - self._startup_init_time) * 1000:.2f}") + print(f"[sqlit] startup {' '.join(parts)}", file=sys.stderr) + + def _get_restart_cache_path(self) -> Path: + return Path(tempfile.gettempdir()) / "sqlit-driver-install-restore.json" + + @staticmethod + def _parse_startup_mark(value: str | None) -> float | None: + if not value: + return None + try: + return float(value) + except (TypeError, ValueError): + return None + + def _record_launch_ms(self) -> None: + base = self._startup_mark if self._startup_mark is not None else self._startup_init_time + self._launch_ms = (time.perf_counter() - base) * 1000 + self._update_status_bar() + + def _maybe_restore_connection_screen(self) -> None: + """Restore an in-progress connection form after a driver-install restart.""" + cache_path = self._get_restart_cache_path() + if not cache_path.exists(): + return - def _handle_driver_result(self, result) -> None: - """Handle result from driver setup screen.""" - if not result: + try: + payload = json.loads(cache_path.read_text(encoding="utf-8")) + except Exception: + try: + cache_path.unlink(missing_ok=True) + except Exception: + pass return - action = result[0] - if action == "select": - driver = result[1] - self.notify(f"Selected driver: {driver}") - elif action == "install": - commands = result[1] - self._run_driver_install(commands) + try: + cache_path.unlink(missing_ok=True) + except Exception: + pass - def _run_driver_install(self, commands: list[str]) -> None: - """Run driver installation commands.""" - import subprocess + if not isinstance(payload, dict) or payload.get("version") != 1: + return - self.notify("Running installation commands...", timeout=3) + values = payload.get("values") + if not isinstance(values, dict): + return - full_command = " && ".join(commands) + editing = bool(payload.get("editing")) + original_name = payload.get("original_name") + post_install_message = payload.get("post_install_message") + active_tab = payload.get("active_tab") - try: - import shutil - - if shutil.which("gnome-terminal"): - subprocess.Popen([ - "gnome-terminal", "--", "bash", "-c", - f'{full_command}; echo ""; echo "Press Enter to close..."; read' - ]) - elif shutil.which("konsole"): - subprocess.Popen([ - "konsole", "-e", "bash", "-c", - f'{full_command}; echo ""; echo "Press Enter to close..."; read' - ]) - elif shutil.which("xterm"): - subprocess.Popen([ - "xterm", "-e", "bash", "-c", - f'{full_command}; echo ""; echo "Press Enter to close..."; read' - ]) - elif shutil.which("open"): # macOS - import tempfile - with tempfile.NamedTemporaryFile(mode='w', suffix='.sh', delete=False) as f: - f.write("#!/bin/bash\n") - f.write(full_command + "\n") - f.write('echo ""\necho "Press Enter to close..."\nread\n') - script_path = f.name - import os - os.chmod(script_path, 0o755) - subprocess.Popen(["open", "-a", "Terminal", script_path]) - else: - self.notify( - "No terminal found. Run these commands manually:\n" + full_command, - severity="warning", - timeout=15, - ) - return - - self.notify( - "Installation started in new terminal. Restart sqlit when done.", - timeout=10, + config = None + if editing and isinstance(original_name, str) and original_name: + config = next((c for c in self.connections if getattr(c, "name", None) == original_name), None) + + if config is None: + from .config import ConnectionConfig + + config = ConnectionConfig( + name=str(values.get("name", "")), + db_type=str(values.get("db_type", "mssql") or "mssql"), ) - except Exception as e: - self.notify(f"Failed to start installation: {e}", severity="error") + editing = False + + prefill_values = { + "values": values, + "active_tab": active_tab, + } + + from .ui.screens import ConnectionScreen + + self._set_connection_screen_footer() + self.push_screen( + ConnectionScreen( + config, + editing=editing, + prefill_values=prefill_values, + post_install_message=post_install_message if isinstance(post_install_message, str) else None, + ), + self._wrap_connection_result, + ) def watch_theme(self, old_theme: str, new_theme: str) -> None: """Save theme whenever it changes.""" diff --git a/sqlit/cli.py b/sqlit/cli.py index 80e319ac..c2b0dfd5 100644 --- a/sqlit/cli.py +++ b/sqlit/cli.py @@ -4,105 +4,147 @@ from __future__ import annotations import argparse +import os import sys +import time -from .config import AuthType, DatabaseType +from .cli_helpers import add_schema_arguments, build_connection_config_from_args +from .config import AuthType, ConnectionConfig, DatabaseType +from .db.providers import get_connection_schema, get_supported_db_types def main() -> int: """Entry point for the CLI.""" parser = argparse.ArgumentParser( prog="sqlit", - description="A terminal UI for SQL Server, PostgreSQL, MySQL, and SQLite databases", + description="A terminal UI for SQL databases", ) - # Global options for TUI mode parser.add_argument( "--mock", metavar="PROFILE", help="Run with mock data (profiles: sqlite-demo, empty, multi-db)", ) + parser.add_argument( + "--db-type", + choices=[t.value for t in DatabaseType], + help="Temporary connection database type (auto-connects in UI)", + ) + parser.add_argument("--name", help="Temporary connection name (default: Temp )") + parser.add_argument("--server", help="Temporary connection server/host") + parser.add_argument("--host", help="Alias for --server") + parser.add_argument("--port", help="Temporary connection port") + parser.add_argument("--database", help="Temporary connection database name") + parser.add_argument("--username", help="Temporary connection username") + parser.add_argument("--password", help="Temporary connection password") + parser.add_argument("--file-path", help="Temporary connection file path (SQLite/DuckDB)") + parser.add_argument( + "--auth-type", + choices=[t.value for t in AuthType], + help="Temporary connection auth type (SQL Server only)", + ) + parser.add_argument("--supabase-region", help="Supabase region (temporary connection)") + parser.add_argument("--supabase-project-id", help="Supabase project id (temporary connection)") + parser.add_argument( + "--settings", + metavar="PATH", + help="Path to settings JSON file (overrides ~/.sqlit/settings.json)", + ) + parser.add_argument( + "--mock-missing-drivers", + metavar="DB_TYPES", + help="Force missing Python drivers for the given db types (comma-separated), e.g. postgresql,mysql", + ) + parser.add_argument( + "--mock-install", + choices=["real", "success", "fail"], + default="real", + help="Mock the driver install result in the UI (default: real).", + ) + parser.add_argument( + "--mock-pipx", + choices=["auto", "pipx", "pip", "unknown"], + default="auto", + help="Mock installation method for install hints: pipx, pip, or unknown (can't auto-install).", + ) + parser.add_argument( + "--mock-query-delay", + type=float, + default=0.0, + metavar="SECONDS", + help="Add artificial delay to mock query execution (e.g. 3.0 for 3 seconds).", + ) + parser.add_argument( + "--profile-startup", + action="store_true", + help="Log startup timing diagnostics to stderr.", + ) + parser.add_argument( + "--debug", + action="store_true", + help="Show startup timing in the status bar.", + ) subparsers = parser.add_subparsers(dest="command", help="Available commands") - # Connection commands - conn_parser = subparsers.add_parser("connection", help="Manage connections") - conn_subparsers = conn_parser.add_subparsers( - dest="conn_command", help="Connection commands" + conn_parser = subparsers.add_parser( + "connections", + help="Manage saved connections", + aliases=["connection"], ) + conn_subparsers = conn_parser.add_subparsers(dest="conn_command", help="Connection commands") - # connection list conn_subparsers.add_parser("list", help="List all saved connections") - # connection create - create_parser = conn_subparsers.add_parser("create", help="Create a new connection") - create_parser.add_argument("--name", "-n", required=True, help="Connection name") - create_parser.add_argument( - "--db-type", - "-t", - default="mssql", - choices=[t.value for t in DatabaseType], - help="Database type (default: mssql)", - ) - # Server-based database options (SQL Server, PostgreSQL, MySQL) - create_parser.add_argument("--server", "-s", help="Server address") - create_parser.add_argument("--port", "-P", help="Port (default: 1433/5432/3306)") - create_parser.add_argument( - "--database", "-d", default="", help="Database name (empty = browse all)" - ) - create_parser.add_argument("--username", "-u", help="Username") - create_parser.add_argument("--password", "-p", help="Password") - # SQL Server specific options - create_parser.add_argument( - "--auth-type", - "-a", - default="sql", - choices=[t.value for t in AuthType], - help="Authentication type (SQL Server only, default: sql)", - ) - # SQLite options - create_parser.add_argument("--file-path", help="Database file path (SQLite only)") - # SSH tunnel options - create_parser.add_argument("--ssh-enabled", action="store_true", help="Enable SSH tunnel") - create_parser.add_argument("--ssh-host", help="SSH server hostname") - create_parser.add_argument("--ssh-port", default="22", help="SSH server port (default: 22)") - create_parser.add_argument("--ssh-username", help="SSH username") - create_parser.add_argument("--ssh-auth-type", default="key", choices=["key", "password"], help="SSH auth type") - create_parser.add_argument("--ssh-key-path", help="SSH private key path") - create_parser.add_argument("--ssh-password", help="SSH password") - - # connection edit + add_parser = conn_subparsers.add_parser( + "add", + help="Add a new connection", + aliases=["create"], + ) + add_provider_parsers = add_parser.add_subparsers(dest="provider", metavar="PROVIDER") + for db_type in get_supported_db_types(): + schema = get_connection_schema(db_type) + provider_parser = add_provider_parsers.add_parser( + db_type, + help=f"{schema.display_name} options", + description=f"{schema.display_name} connection options", + ) + add_schema_arguments(provider_parser, schema, include_name=True, name_required=True) + edit_parser = conn_subparsers.add_parser("edit", help="Edit an existing connection") edit_parser.add_argument("connection_name", help="Name of connection to edit") edit_parser.add_argument("--name", "-n", help="New connection name") - # Server-based database options (SQL Server, PostgreSQL, MySQL) edit_parser.add_argument("--server", "-s", help="Server address") + edit_parser.add_argument("--host", help="Alias for --server (e.g. Cloudflare D1 Account ID)") edit_parser.add_argument("--port", "-P", help="Port") edit_parser.add_argument("--database", "-d", help="Database name") edit_parser.add_argument("--username", "-u", help="Username") edit_parser.add_argument("--password", "-p", help="Password") - # SQL Server specific options edit_parser.add_argument( "--auth-type", "-a", choices=[t.value for t in AuthType], help="Authentication type (SQL Server only)", ) - # SQLite options edit_parser.add_argument("--file-path", help="Database file path (SQLite only)") - # connection delete delete_parser = conn_subparsers.add_parser("delete", help="Delete a connection") delete_parser.add_argument("connection_name", help="Name of connection to delete") - # query command + connect_parser = subparsers.add_parser("connect", help="Temporary connection (not saved)") + connect_provider_parsers = connect_parser.add_subparsers(dest="provider", metavar="PROVIDER") + for db_type in get_supported_db_types(): + schema = get_connection_schema(db_type) + provider_parser = connect_provider_parsers.add_parser( + db_type, + help=f"{schema.display_name} options", + description=f"{schema.display_name} connection options", + ) + add_schema_arguments(provider_parser, schema, include_name=True, name_required=False) + query_parser = subparsers.add_parser("query", help="Execute a SQL query") - query_parser.add_argument( - "--connection", "-c", required=True, help="Connection name to use" - ) - query_parser.add_argument( - "--database", "-d", help="Database to query (overrides connection default)" - ) + query_parser.add_argument("--connection", "-c", required=True, help="Connection name to use") + query_parser.add_argument("--database", "-d", help="Database to query (overrides connection default)") query_parser.add_argument("--query", "-q", help="SQL query to execute") query_parser.add_argument("--file", "-f", help="SQL file to execute") query_parser.add_argument( @@ -120,9 +162,36 @@ def main() -> int: help="Maximum rows to fetch (default: 1000, use 0 for unlimited)", ) + startup_mark = time.perf_counter() args = parser.parse_args() - - # No command = launch TUI + if args.settings: + os.environ["SQLIT_SETTINGS_PATH"] = str(args.settings) + if args.mock_missing_drivers: + os.environ["SQLIT_MOCK_MISSING_DRIVERS"] = str(args.mock_missing_drivers) + if args.mock_install and args.mock_install != "real": + os.environ["SQLIT_MOCK_INSTALL_RESULT"] = str(args.mock_install) + else: + os.environ.pop("SQLIT_MOCK_INSTALL_RESULT", None) + if args.mock_pipx and args.mock_pipx != "auto": + os.environ["SQLIT_MOCK_PIPX"] = str(args.mock_pipx) + else: + os.environ.pop("SQLIT_MOCK_PIPX", None) + if args.mock_query_delay and args.mock_query_delay > 0: + os.environ["SQLIT_MOCK_QUERY_DELAY"] = str(args.mock_query_delay) + else: + os.environ.pop("SQLIT_MOCK_QUERY_DELAY", None) + if args.profile_startup: + os.environ["SQLIT_PROFILE_STARTUP"] = "1" + else: + os.environ.pop("SQLIT_PROFILE_STARTUP", None) + if args.debug: + os.environ["SQLIT_DEBUG"] = "1" + else: + os.environ.pop("SQLIT_DEBUG", None) + if args.profile_startup or args.debug: + os.environ["SQLIT_STARTUP_MARK"] = str(startup_mark) + else: + os.environ.pop("SQLIT_STARTUP_MARK", None) if args.command is None: from .app import SSMSTUI @@ -136,11 +205,17 @@ def main() -> int: print(f"Available profiles: {', '.join(list_mock_profiles())}") return 1 - app = SSMSTUI(mock_profile=mock_profile) + temp_config = None + try: + temp_config = _build_temp_connection(args) + except ValueError as exc: + print(f"Error: {exc}") + return 1 + + app = SSMSTUI(mock_profile=mock_profile, startup_connection=temp_config) app.run() return 0 - # Import commands lazily to speed up --help from .commands import ( cmd_connection_create, cmd_connection_delete, @@ -149,11 +224,45 @@ def main() -> int: cmd_query, ) - # Handle connection commands - if args.command == "connection": + if args.command == "connect": + from .app import SSMSTUI + + db_type = getattr(args, "provider", None) + if not db_type: + connect_parser.print_help() + return 1 + + mock_profile = None + if args.mock: + from .mocks import get_mock_profile, list_mock_profiles + + mock_profile = get_mock_profile(args.mock) + if mock_profile is None: + print(f"Unknown mock profile: {args.mock}") + print(f"Available profiles: {', '.join(list_mock_profiles())}") + return 1 + + schema = get_connection_schema(db_type) + try: + temp_config = build_connection_config_from_args( + schema, + args, + name=getattr(args, "name", None), + default_name=f"Temp {schema.display_name}", + strict=True, + ) + except ValueError as exc: + print(f"Error: {exc}") + return 1 + + app = SSMSTUI(mock_profile=mock_profile, startup_connection=temp_config) + app.run() + return 0 + + if args.command in {"connections", "connection"}: if args.conn_command == "list": return cmd_connection_list(args) - elif args.conn_command == "create": + elif args.conn_command in {"add", "create"}: return cmd_connection_create(args) elif args.conn_command == "edit": return cmd_connection_edit(args) @@ -163,7 +272,6 @@ def main() -> int: conn_parser.print_help() return 1 - # Handle query command if args.command == "query": return cmd_query(args) @@ -171,5 +279,32 @@ def main() -> int: return 1 +def _build_temp_connection(args: argparse.Namespace) -> ConnectionConfig | None: + """Build a temporary connection config from CLI args, if provided.""" + db_type = getattr(args, "db_type", None) + file_path = getattr(args, "file_path", None) + if not db_type and file_path: + db_type = "sqlite" + setattr(args, "db_type", db_type) + if not db_type: + if any(getattr(args, name, None) for name in ("file_path", "server", "host", "database")): + raise ValueError("--db-type is required for temporary connections") + return None + + try: + DatabaseType(db_type) + except ValueError: + raise ValueError(f"Invalid database type '{db_type}'") + + schema = get_connection_schema(db_type) + return build_connection_config_from_args( + schema, + args, + name=getattr(args, "name", None), + default_name=f"Temp {schema.display_name}", + strict=True, + ) + + if __name__ == "__main__": sys.exit(main()) diff --git a/sqlit/cli_helpers.py b/sqlit/cli_helpers.py new file mode 100644 index 00000000..1ed70ab9 --- /dev/null +++ b/sqlit/cli_helpers.py @@ -0,0 +1,171 @@ +"""CLI helpers for building provider-specific parsers and configs.""" + +from __future__ import annotations + +import argparse +from typing import Any, Iterable + +from .config import ConnectionConfig +from .db.schema import ConnectionSchema, FieldType + +CONNECTION_ARG_NAMES = { + "name", + "server", + "host", + "port", + "database", + "username", + "password", + "file_path", + "auth_type", + "supabase_region", + "supabase_project_id", + "ssh_enabled", + "ssh_host", + "ssh_port", + "ssh_username", + "ssh_auth_type", + "ssh_key_path", + "ssh_password", + "driver", +} + + +def add_schema_arguments( + parser: argparse.ArgumentParser, + schema: ConnectionSchema, + *, + include_name: bool, + name_required: bool, +) -> None: + """Add schema-driven arguments to a parser.""" + if include_name: + parser.add_argument( + "--name", + "-n", + required=name_required, + help="Connection name", + ) + + for field in schema.fields: + arg = f"--{field.name.replace('_', '-')}" + help_text = field.description or field.placeholder or field.label + kwargs: dict[str, Any] = { + "help": help_text, + "dest": field.name, + } + + if field.name == "server": + parser.add_argument(arg, "--host", **kwargs) + continue + + if field.name == "ssh_enabled": + parser.add_argument(arg, action="store_true", help=help_text, dest=field.name) + continue + + if field.field_type in (FieldType.SELECT, FieldType.DROPDOWN) and field.options: + kwargs["choices"] = [opt.value for opt in field.options] + + if field.default: + kwargs["default"] = field.default + + if field.required and field.visible_when is None: + kwargs["required"] = True + + parser.add_argument(arg, **kwargs) + + +def build_connection_config_from_args( + schema: ConnectionSchema, + args: Any, + *, + name: str | None, + default_name: str | None = None, + strict: bool = True, +) -> ConnectionConfig: + """Build a ConnectionConfig from CLI args based on a provider schema.""" + raw_values = _extract_raw_values(schema, args) + + missing = _find_missing_required_fields(schema, raw_values) + if missing: + missing_args = ", ".join(f"--{field.replace('_', '-')}" for field in missing) + raise ValueError(f"Missing required fields: {missing_args}") + + if strict: + extras = _find_unexpected_fields(schema, args) + if extras: + extras_args = ", ".join(f"--{field.replace('_', '-')}" for field in extras) + raise ValueError(f"Unexpected fields for {schema.display_name}: {extras_args}") + + config_name = name or default_name or f"Temp {schema.display_name}" + config_values = { + "name": config_name, + "db_type": schema.db_type, + } + + # Fields where None means "not set" vs "" means "explicitly empty" + nullable_fields = {"password", "ssh_password"} + + for field in schema.fields: + value = raw_values.get(field.name, "") + if value is None and field.name not in nullable_fields: + value = "" + if field.name == "ssh_enabled": + if isinstance(value, bool): + config_values[field.name] = value + else: + config_values[field.name] = str(value).lower() == "enabled" + else: + config_values[field.name] = value + + if "port" in config_values and not config_values["port"]: + config_values["port"] = schema.default_port or "" + + if schema.has_advanced_auth: + auth_type = config_values.get("auth_type") or "sql" + config_values["auth_type"] = auth_type + config_values["trusted_connection"] = auth_type == "windows" + + return ConnectionConfig(**config_values) + + +def _extract_raw_values(schema: ConnectionSchema, args: Any) -> dict[str, Any]: + raw_values: dict[str, Any] = {} + for field in schema.fields: + value = getattr(args, field.name, None) + if field.name == "ssh_enabled" and isinstance(value, bool): + value = "enabled" if value else "disabled" + if (value is None or value == "") and field.default: + value = field.default + raw_values[field.name] = value + return raw_values + + +def _find_missing_required_fields(schema: ConnectionSchema, raw_values: dict[str, Any]) -> list[str]: + missing: list[str] = [] + for field in schema.fields: + if not field.required: + continue + if field.visible_when and not field.visible_when(raw_values): + continue + value = raw_values.get(field.name) + if value is None or value == "": + missing.append(field.name) + return missing + + +def _find_unexpected_fields(schema: ConnectionSchema, args: Any) -> list[str]: + allowed = {field.name for field in schema.fields} + extras: list[str] = [] + for field in CONNECTION_ARG_NAMES: + if field in allowed or field == "name": + continue + value = getattr(args, field, None) + if value is None or value == "" or value is False: + continue + extras.append(field) + return extras + + +def iter_schema_arg_names(schema: ConnectionSchema) -> Iterable[str]: + return (field.name for field in schema.fields) diff --git a/sqlit/commands.py b/sqlit/commands.py index e4acd71c..4417a3f8 100644 --- a/sqlit/commands.py +++ b/sqlit/commands.py @@ -3,27 +3,92 @@ from __future__ import annotations import csv +import getpass import json import sys -from typing import TYPE_CHECKING, Callable +from collections.abc import Callable +from dataclasses import replace +from typing import TYPE_CHECKING, Any +from .cli_helpers import build_connection_config_from_args from .config import ( AUTH_TYPE_LABELS, AuthType, ConnectionConfig, - DATABASE_TYPE_LABELS, DatabaseType, + get_database_type_labels, load_connections, save_connections, ) -from .db.schema import get_default_port, has_advanced_auth, is_file_based +from .db.providers import get_connection_schema, has_advanced_auth, is_file_based from .services import ConnectionSession, QueryResult, QueryService +from .services.credentials import ( + ALLOW_PLAINTEXT_CREDENTIALS_SETTING, + is_keyring_usable, + reset_credentials_service, +) if TYPE_CHECKING: - from .services import HistoryStoreProtocol + pass + + +def _maybe_prompt_plaintext_credentials() -> bool: + """Ensure plaintext credential storage preference is set when keyring isn't usable. + + Returns True if plaintext storage is allowed; False otherwise. + """ + from .config import load_settings, save_settings + + if is_keyring_usable(): + return False + + settings = load_settings() + existing = settings.get(ALLOW_PLAINTEXT_CREDENTIALS_SETTING) + if isinstance(existing, bool): + if existing: + reset_credentials_service() + return bool(existing) + + if not sys.stdin.isatty(): + return False + + answer = input("Keyring isn't available. Save passwords as plaintext in ~/.sqlit/? [y/N]: ").strip().lower() + allow = answer in {"y", "yes"} + settings[ALLOW_PLAINTEXT_CREDENTIALS_SETTING] = allow + save_settings(settings) + if allow: + reset_credentials_service() + return allow + + +def _clear_passwords_if_not_persisted(config: ConnectionConfig) -> None: + config.password = "" + config.ssh_password = "" -def cmd_connection_list(args) -> int: +def _prompt_for_password(config: ConnectionConfig) -> ConnectionConfig: + """Prompt for passwords if they are not set (None). + + Uses getpass for secure input that doesn't appear in bash history. + Returns a new config with passwords filled in (original is not modified). + + Note: Empty string "" means explicitly set to empty (no prompt). + None means not set (prompt for input). + """ + new_config = config + + if config.ssh_enabled and config.ssh_auth_type == "password" and config.ssh_password is None: + ssh_password = getpass.getpass(f"SSH password for '{config.name}': ") + new_config = replace(new_config, ssh_password=ssh_password) + + if not is_file_based(config.db_type) and config.password is None: + db_password = getpass.getpass(f"Password for '{config.name}': ") + new_config = replace(new_config, password=db_password) + + return new_config + + +def cmd_connection_list(args: Any) -> int: """List all saved connections.""" connections = load_connections() if not connections: @@ -32,8 +97,9 @@ def cmd_connection_list(args) -> int: print(f"{'Name':<20} {'Type':<10} {'Connection Info':<40} {'Auth Type':<25}") print("-" * 95) + labels = get_database_type_labels() for conn in connections: - db_type_label = DATABASE_TYPE_LABELS.get(conn.get_db_type(), conn.db_type) + db_type_label = labels.get(conn.get_db_type(), conn.db_type) if is_file_based(conn.db_type): conn_info = conn.file_path[:38] + ".." if len(conn.file_path) > 40 else conn.file_path auth_label = "N/A" @@ -46,22 +112,23 @@ def cmd_connection_list(args) -> int: conn_info = f"{conn.server}@{conn.database}" if conn.database else conn.server conn_info = conn_info[:38] + ".." if len(conn_info) > 40 else conn_info auth_label = f"User: {conn.username}" if conn.username else "N/A" - print( - f"{conn.name:<20} {db_type_label:<10} {conn_info:<40} {auth_label:<25}" - ) + print(f"{conn.name:<20} {db_type_label:<10} {conn_info:<40} {auth_label:<25}") return 0 -def cmd_connection_create(args) -> int: +def cmd_connection_create(args: Any) -> int: """Create a new connection.""" connections = load_connections() + if not getattr(args, "provider", None): + print("Error: provider is required (e.g. 'sqlit connection create supabase').") + return 1 + if any(c.name == args.name for c in connections): print(f"Error: Connection '{args.name}' already exists. Use 'edit' to modify it.") return 1 - # Determine database type - db_type = getattr(args, "db_type", "mssql") or "mssql" + db_type = getattr(args, "provider", None) try: DatabaseType(db_type) except ValueError: @@ -69,80 +136,29 @@ def cmd_connection_create(args) -> int: print(f"Error: Invalid database type '{db_type}'. Valid types: {valid_types}") return 1 - if is_file_based(db_type): - file_path = getattr(args, "file_path", None) - if not file_path: - print(f"Error: --file-path is required for {db_type.upper()} connections.") - return 1 - - config = ConnectionConfig( - name=args.name, - db_type=db_type, - file_path=file_path, - ) - elif has_advanced_auth(db_type): - # SQL Server connection (has Windows/Azure AD auth) - if not args.server: - print("Error: --server is required for SQL Server connections.") - return 1 - - auth_type_str = getattr(args, "auth_type", "sql") or "sql" - try: - auth_type = AuthType(auth_type_str) - except ValueError: - valid_types = ", ".join(t.value for t in AuthType) - print(f"Error: Invalid auth type '{auth_type_str}'. Valid types: {valid_types}") - return 1 - - config = ConnectionConfig( - name=args.name, - db_type=db_type, - server=args.server, - port=args.port or get_default_port(db_type), - database=args.database or "", - username=args.username or "", - password=args.password or "", - auth_type=auth_type.value, - trusted_connection=(auth_type == AuthType.WINDOWS), - ssh_enabled=getattr(args, "ssh_enabled", False) or False, - ssh_host=getattr(args, "ssh_host", "") or "", - ssh_port=getattr(args, "ssh_port", "22") or "22", - ssh_username=getattr(args, "ssh_username", "") or "", - ssh_auth_type=getattr(args, "ssh_auth_type", "key") or "key", - ssh_key_path=getattr(args, "ssh_key_path", "") or "", - ssh_password=getattr(args, "ssh_password", "") or "", - ) - else: - # Server-based databases with simple auth - if not args.server: - db_label = DATABASE_TYPE_LABELS.get(DatabaseType(db_type), db_type.upper()) - print(f"Error: --server is required for {db_label} connections.") - return 1 - - config = ConnectionConfig( + schema = get_connection_schema(db_type) + try: + config = build_connection_config_from_args( + schema, + args, name=args.name, - db_type=db_type, - server=args.server, - port=args.port or get_default_port(db_type), - database=args.database or "", - username=args.username or "", - password=args.password or "", - ssh_enabled=getattr(args, "ssh_enabled", False) or False, - ssh_host=getattr(args, "ssh_host", "") or "", - ssh_port=getattr(args, "ssh_port", "22") or "22", - ssh_username=getattr(args, "ssh_username", "") or "", - ssh_auth_type=getattr(args, "ssh_auth_type", "key") or "key", - ssh_key_path=getattr(args, "ssh_key_path", "") or "", - ssh_password=getattr(args, "ssh_password", "") or "", + default_name=None, + strict=True, ) + except ValueError as exc: + print(f"Error: {exc}") + return 1 connections.append(config) + if (config.password or config.ssh_password) and not is_keyring_usable(): + if not _maybe_prompt_plaintext_credentials(): + _clear_passwords_if_not_persisted(config) save_connections(connections) print(f"Connection '{args.name}' created successfully.") return 0 -def cmd_connection_edit(args) -> int: +def cmd_connection_edit(args: Any) -> int: """Edit an existing connection.""" connections = load_connections() @@ -164,9 +180,9 @@ def cmd_connection_edit(args) -> int: return 1 conn.name = args.name - # SQL Server fields - if args.server: - conn.server = args.server + server = getattr(args, "server", None) or getattr(args, "host", None) + if server: + conn.server = server if args.port: conn.port = args.port if args.database: @@ -185,17 +201,20 @@ def cmd_connection_edit(args) -> int: if args.password is not None: conn.password = args.password - # SQLite fields file_path = getattr(args, "file_path", None) if file_path is not None: conn.file_path = file_path + if (conn.password or conn.ssh_password) and not is_keyring_usable(): + if not _maybe_prompt_plaintext_credentials(): + _clear_passwords_if_not_persisted(conn) + save_connections(connections) print(f"Connection '{conn.name}' updated successfully.") return 0 -def cmd_connection_delete(args) -> int: +def cmd_connection_delete(args: Any) -> int: """Delete a connection.""" connections = load_connections() @@ -215,7 +234,7 @@ def cmd_connection_delete(args) -> int: return 0 -def _stream_csv_output(cursor, columns: list[str]) -> int: +def _stream_csv_output(cursor: Any, columns: list[str]) -> int: """Stream CSV output from cursor using fetchmany.""" writer = csv.writer(sys.stdout) writer.writerow(columns) @@ -231,7 +250,7 @@ def _stream_csv_output(cursor, columns: list[str]) -> int: return row_count -def _stream_json_output(cursor, columns: list[str]) -> int: +def _stream_json_output(cursor: Any, columns: list[str]) -> int: """Stream JSON output from cursor using fetchmany (JSON array format).""" print("[") first = True @@ -254,25 +273,23 @@ def _stream_json_output(cursor, columns: list[str]) -> int: def _output_table(columns: list[str], rows: list[tuple], truncated: bool) -> None: """Output query results in table format with optimized width calculation.""" - MAX_COL_WIDTH = 50 # Cap column width to avoid excessive line length + MAX_COL_WIDTH = 50 - # Calculate column widths (only scan first 100 rows for performance) + # Only scan first 100 rows for performance col_widths = [min(len(col), MAX_COL_WIDTH) for col in columns] for row in rows[:100]: for i, val in enumerate(row): val_str = str(val) if val is not None else "NULL" col_widths[i] = min(MAX_COL_WIDTH, max(col_widths[i], len(val_str))) - # Print header header_parts = [] for i, col in enumerate(columns): - col_display = col[:col_widths[i]] if len(col) > col_widths[i] else col + col_display = col[: col_widths[i]] if len(col) > col_widths[i] else col header_parts.append(col_display.ljust(col_widths[i])) header = " | ".join(header_parts) print(header) print("-" * len(header)) - # Print rows for row in rows: row_parts = [] for i, val in enumerate(row): @@ -282,7 +299,6 @@ def _output_table(columns: list[str], rows: list[tuple], truncated: bool) -> Non row_parts.append(val_str.ljust(col_widths[i])) print(" | ".join(row_parts)) - # Print count with truncation notice if truncated: print(f"\n({len(rows)} rows shown, results truncated)") else: @@ -290,7 +306,7 @@ def _output_table(columns: list[str], rows: list[tuple], truncated: bool) -> Non def cmd_query( - args, + args: Any, *, session_factory: Callable[[ConnectionConfig], ConnectionSession] | None = None, query_service: QueryService | None = None, @@ -319,46 +335,39 @@ def cmd_query( print(f"Error: Connection '{args.connection}' not found.") return 1 - # Override database if specified (only for SQL Server) if args.database and config.db_type == "mssql": - config.database = args.database + config = replace(config, database=args.database) + + config = _prompt_for_password(config) if args.query: query = args.query elif args.file: try: - with open(args.file, "r", encoding="utf-8") as f: + with open(args.file, encoding="utf-8") as f: query = f.read() except FileNotFoundError: print(f"Error: File '{args.file}' not found.") return 1 - except IOError as e: + except OSError as e: print(f"Error reading file: {e}") return 1 else: print("Error: Either --query or --file must be provided.") return 1 - # Determine row limit (0 means unlimited) max_rows = args.limit if args.limit > 0 else None - # Use injected or default factories create_session = session_factory or ConnectionSession.create service = query_service or QueryService() try: - # Use ConnectionSession for automatic resource cleanup with create_session(config) as session: - # For unlimited streaming output (CSV/JSON only), use direct cursor access from .services.query import is_select_query - # Check if connection supports cursors (some adapters like Turso don't) - has_cursor = hasattr(session.connection, "cursor") and callable( - getattr(session.connection, "cursor", None) - ) + has_cursor = hasattr(session.connection, "cursor") and callable(getattr(session.connection, "cursor", None)) if max_rows is None and args.format in ("csv", "json") and is_select_query(query) and has_cursor: - # Stream directly from cursor for unlimited CSV/JSON cursor = session.connection.cursor() cursor.execute(query) @@ -373,12 +382,10 @@ def cmd_query( else: row_count = _stream_json_output(cursor, columns) - # Save to history service._save_to_history(config.name, query) print(f"\n({row_count} row(s) returned)", file=sys.stderr) return 0 - # Standard execution with QueryService (with row limit) result = service.execute( connection=session.connection, adapter=session.adapter, @@ -403,8 +410,7 @@ def cmd_query( print(f"\n({len(rows)} row(s) returned)", file=sys.stderr) elif args.format == "json": json_result = [ - dict(zip(columns, [val if val is not None else None for val in row])) - for row in rows + dict(zip(columns, [val if val is not None else None for val in row])) for row in rows ] print(json.dumps(json_result, indent=2, default=str)) if result.truncated: @@ -414,7 +420,6 @@ def cmd_query( else: _output_table(columns, rows, result.truncated) else: - # NonQueryResult print(f"Query executed successfully. Rows affected: {result.rows_affected}") return 0 diff --git a/sqlit/config.py b/sqlit/config.py index a9c97102..f6638249 100644 --- a/sqlit/config.py +++ b/sqlit/config.py @@ -2,13 +2,20 @@ This module contains domain types (DatabaseType, AuthType, ConnectionConfig) and re-exports persistence functions from stores for backward compatibility. + +NOTE: This module uses lazy imports for db.providers to avoid loading all +adapter classes at import time. Only _get_supported_db_types is loaded +eagerly (needed to create DatabaseType enum). """ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum -from pathlib import Path +from typing import TYPE_CHECKING + +# Only import what's needed to create the DatabaseType enum +from .db.providers import get_supported_db_types as _get_supported_db_types # Re-export store paths for backward compatibility from .stores.base import CONFIG_DIR @@ -17,51 +24,81 @@ SETTINGS_PATH = CONFIG_DIR / "settings.json" HISTORY_PATH = CONFIG_DIR / "query_history.json" -# Re-export persistence functions from stores -from .stores.connections import load_connections, save_connections -from .stores.history import ( - QueryHistoryEntry, - delete_query_from_history, - load_query_history, - save_query_to_history, -) -from .stores.settings import load_settings, save_settings - - -# Import schema capabilities - use function to avoid circular imports -def _is_file_based(db_type: str) -> bool: - from .db.schema import is_file_based - - return is_file_based(db_type) - - -class DatabaseType(Enum): - """Supported database types.""" - - MSSQL = "mssql" - SQLITE = "sqlite" - POSTGRESQL = "postgresql" - MYSQL = "mysql" - ORACLE = "oracle" - MARIADB = "mariadb" - DUCKDB = "duckdb" - COCKROACHDB = "cockroachdb" - TURSO = "turso" - SUPABASE = "supabase" - - -DATABASE_TYPE_LABELS = { - DatabaseType.MSSQL: "SQL Server", - DatabaseType.SQLITE: "SQLite", - DatabaseType.POSTGRESQL: "PostgreSQL", - DatabaseType.MYSQL: "MySQL", - DatabaseType.ORACLE: "Oracle", - DatabaseType.MARIADB: "MariaDB", - DatabaseType.DUCKDB: "DuckDB", - DatabaseType.COCKROACHDB: "CockroachDB", - DatabaseType.TURSO: "Turso", - DatabaseType.SUPABASE: "Supabase", -} + +# Module-level convenience functions for backward compatibility. +# These are wrappers to avoid import cycles with the store modules. +def load_connections(load_credentials: bool = True) -> list[ConnectionConfig]: + """Load saved connections from config file.""" + from .stores.connections import load_connections as _load_connections + + return _load_connections(load_credentials=load_credentials) + + +def save_connections(connections: list[ConnectionConfig]) -> None: + """Save connections to config file.""" + from .stores.connections import save_connections as _save_connections + + _save_connections(connections) + + +def load_settings() -> dict: + """Load app settings from config file.""" + from .stores.settings import load_settings as _load_settings + + return _load_settings() + + +def save_settings(settings: dict) -> None: + """Save app settings to config file.""" + from .stores.settings import save_settings as _save_settings + + _save_settings(settings) + + +def load_query_history(connection_name: str) -> list: + """Load query history for a specific connection, sorted by most recent first.""" + from .stores.history import load_query_history as _load_query_history + + return _load_query_history(connection_name) + + +def save_query_to_history(connection_name: str, query: str) -> None: + """Save a query to history for a connection.""" + from .stores.history import save_query_to_history as _save_query_to_history + + _save_query_to_history(connection_name, query) + + +def delete_query_from_history(connection_name: str, timestamp: str) -> bool: + """Delete a specific query from history by connection name and timestamp.""" + from .stores.history import delete_query_from_history as _delete_query_from_history + + return _delete_query_from_history(connection_name, timestamp) + + +if TYPE_CHECKING: + + class DatabaseType(str, Enum): + MSSQL = "mssql" + POSTGRESQL = "postgresql" + COCKROACHDB = "cockroachdb" + MYSQL = "mysql" + MARIADB = "mariadb" + ORACLE = "oracle" + SQLITE = "sqlite" + DUCKDB = "duckdb" + SUPABASE = "supabase" + TURSO = "turso" + D1 = "d1" + +else: + DatabaseType = Enum("DatabaseType", {t.upper(): t for t in _get_supported_db_types()}) # type: ignore[misc] + + +def get_database_type_labels() -> dict[DatabaseType, str]: + """Get database type display labels (lazy-loaded).""" + from .db.providers import get_display_name + return {db_type: get_display_name(db_type.value) for db_type in DatabaseType} class AuthType(Enum): @@ -83,6 +120,12 @@ class AuthType(Enum): } +def _get_default_driver() -> str: + """Get default ODBC driver (lazy import).""" + from .drivers import SUPPORTED_DRIVERS + return SUPPORTED_DRIVERS[0] + + @dataclass class ConnectionConfig: """Database connection configuration.""" @@ -91,14 +134,14 @@ class ConnectionConfig: db_type: str = "mssql" # Database type: mssql, sqlite, postgresql, mysql # Server-based database fields (SQL Server, PostgreSQL, MySQL) server: str = "" - port: str = "1433" # Default varies: 1433 (MSSQL), 5432 (PostgreSQL), 3306 (MySQL) + port: str = "" # Default derived from schema for server-based databases database: str = "" username: str = "" - password: str = "" + password: str | None = None # SQL Server specific fields auth_type: str = "sql" - driver: str = "ODBC Driver 18 for SQL Server" - trusted_connection: bool = False # Legacy field for backwards compatibility + driver: str = field(default_factory=_get_default_driver) + trusted_connection: bool = False # SQLite specific fields file_path: str = "" # SSH tunnel fields @@ -107,17 +150,25 @@ class ConnectionConfig: ssh_port: str = "22" ssh_username: str = "" ssh_auth_type: str = "key" # "key" or "password" - ssh_password: str = "" + ssh_password: str | None = None ssh_key_path: str = "" # Supabase specific fields supabase_region: str = "" supabase_project_id: str = "" - def __post_init__(self): + def __post_init__(self) -> None: """Handle backwards compatibility with old configs.""" # Old configs without db_type are SQL Server if not hasattr(self, "db_type") or not self.db_type: self.db_type = "mssql" + + # Apply default port for server-based DBs if missing (lazy import) + if not getattr(self, "port", None): + from .db.providers import get_default_port + default_port = get_default_port(self.db_type) + if default_port: + self.port = default_port + # Handle old SQL Server auth compatibility if self.db_type == "mssql": if self.auth_type == "windows" and not self.trusted_connection and self.username: @@ -128,7 +179,7 @@ def get_db_type(self) -> DatabaseType: try: return DatabaseType(self.db_type) except ValueError: - return DatabaseType.MSSQL + return DatabaseType.MSSQL # type: ignore[attr-defined, no-any-return] def get_auth_type(self) -> AuthType: """Get the AuthType enum value.""" @@ -175,15 +226,9 @@ def get_connection_string(self) -> str: elif auth == AuthType.SQL_SERVER: return base + f"UID={self.username};PWD={self.password};" elif auth == AuthType.AD_PASSWORD: - return ( - base - + f"Authentication=ActiveDirectoryPassword;" - f"UID={self.username};PWD={self.password};" - ) + return base + f"Authentication=ActiveDirectoryPassword;" f"UID={self.username};PWD={self.password};" elif auth == AuthType.AD_INTERACTIVE: - return ( - base + f"Authentication=ActiveDirectoryInteractive;" f"UID={self.username};" - ) + return base + f"Authentication=ActiveDirectoryInteractive;" f"UID={self.username};" elif auth == AuthType.AD_INTEGRATED: return base + "Authentication=ActiveDirectoryIntegrated;" @@ -191,7 +236,8 @@ def get_connection_string(self) -> str: def get_display_info(self) -> str: """Get a display string for the connection.""" - if _is_file_based(self.db_type): + from .db.providers import is_file_based + if is_file_based(self.db_type): return self.file_path or self.name if self.db_type == "supabase": diff --git a/sqlit/db/__init__.py b/sqlit/db/__init__.py index 8c814509..84765c81 100644 --- a/sqlit/db/__init__.py +++ b/sqlit/db/__init__.py @@ -1,25 +1,13 @@ """Database abstraction layer for sqlit.""" -from .adapters import ( - ColumnInfo, - CockroachDBAdapter, - DatabaseAdapter, - DuckDBAdapter, - MariaDBAdapter, - MySQLAdapter, - OracleAdapter, - PostgreSQLAdapter, - SQLiteAdapter, - SQLServerAdapter, +from __future__ import annotations + +from importlib import import_module +from typing import TYPE_CHECKING, Any + +from .adapters.base import ColumnInfo, DatabaseAdapter, TableInfo +from .providers import ( get_adapter, -) -from .schema import ( - ConnectionSchema, - FieldType, - SchemaField, - SelectOption, - get_all_schemas, - get_connection_schema, get_default_port, get_display_name, get_supported_db_types, @@ -33,7 +21,23 @@ # Base "ColumnInfo", "DatabaseAdapter", - # Adapters + "TableInfo", + # Factory / providers + "get_adapter", + "get_default_port", + "get_display_name", + "get_supported_db_types", + "has_advanced_auth", + "is_file_based", + "supports_ssh", + # UI schema (lazy wrappers) + "ConnectionSchema", + "FieldType", + "SchemaField", + "SelectOption", + "get_all_schemas", + "get_connection_schema", + # Adapters (lazy via __getattr__) "CockroachDBAdapter", "DuckDBAdapter", "MariaDBAdapter", @@ -42,21 +46,62 @@ "PostgreSQLAdapter", "SQLiteAdapter", "SQLServerAdapter", - # Factory - "get_adapter", - # Schema - "ConnectionSchema", - "FieldType", - "SchemaField", - "SelectOption", - "get_all_schemas", - "get_connection_schema", - "get_default_port", - "get_display_name", - "get_supported_db_types", - "has_advanced_auth", - "is_file_based", - "supports_ssh", + "SupabaseAdapter", + "TursoAdapter", # Tunnel "create_ssh_tunnel", ] + +if TYPE_CHECKING: + from .adapters.cockroachdb import CockroachDBAdapter + from .adapters.duckdb import DuckDBAdapter + from .adapters.mariadb import MariaDBAdapter + from .adapters.mssql import SQLServerAdapter + from .adapters.mysql import MySQLAdapter + from .adapters.oracle import OracleAdapter + from .adapters.postgresql import PostgreSQLAdapter + from .adapters.sqlite import SQLiteAdapter + from .adapters.supabase import SupabaseAdapter + from .adapters.turso import TursoAdapter + from .schema import ConnectionSchema, FieldType, SchemaField, SelectOption + + +def get_connection_schema(db_type: str) -> Any: + from .schema import get_connection_schema as _get_connection_schema + + return _get_connection_schema(db_type) + + +def get_all_schemas() -> Any: + from .schema import get_all_schemas as _get_all_schemas + + return _get_all_schemas() + + +_LAZY_ATTRS: dict[str, tuple[str, str]] = { + # Schema types + "ConnectionSchema": ("sqlit.db.schema", "ConnectionSchema"), + "FieldType": ("sqlit.db.schema", "FieldType"), + "SchemaField": ("sqlit.db.schema", "SchemaField"), + "SelectOption": ("sqlit.db.schema", "SelectOption"), + # Adapters (through sqlit.db.adapters, which itself lazy-loads) + "CockroachDBAdapter": ("sqlit.db.adapters", "CockroachDBAdapter"), + "DuckDBAdapter": ("sqlit.db.adapters", "DuckDBAdapter"), + "MariaDBAdapter": ("sqlit.db.adapters", "MariaDBAdapter"), + "MySQLAdapter": ("sqlit.db.adapters", "MySQLAdapter"), + "OracleAdapter": ("sqlit.db.adapters", "OracleAdapter"), + "PostgreSQLAdapter": ("sqlit.db.adapters", "PostgreSQLAdapter"), + "SQLiteAdapter": ("sqlit.db.adapters", "SQLiteAdapter"), + "SQLServerAdapter": ("sqlit.db.adapters", "SQLServerAdapter"), + "SupabaseAdapter": ("sqlit.db.adapters", "SupabaseAdapter"), + "TursoAdapter": ("sqlit.db.adapters", "TursoAdapter"), +} + + +def __getattr__(name: str) -> Any: + target = _LAZY_ATTRS.get(name) + if target is None: + raise AttributeError(name) + module_name, attr_name = target + module = import_module(module_name) + return getattr(module, attr_name) diff --git a/sqlit/db/adapters/__init__.py b/sqlit/db/adapters/__init__.py index 714af278..fe5a7ea3 100644 --- a/sqlit/db/adapters/__init__.py +++ b/sqlit/db/adapters/__init__.py @@ -1,19 +1,24 @@ +"""Adapter factory and lightweight exports. + +Avoid importing every adapter module at import time; adapters are loaded lazily +via the provider registry when requested. +""" + +from __future__ import annotations + +from importlib import import_module +from typing import TYPE_CHECKING + +from ..providers import PROVIDERS +from ..providers import get_adapter as _get_adapter +from ..providers import get_supported_db_types as _get_supported_adapter_db_types from .base import ColumnInfo, DatabaseAdapter, TableInfo -from .cockroachdb import CockroachDBAdapter -from .duckdb import DuckDBAdapter -from .mariadb import MariaDBAdapter -from .mssql import SQLServerAdapter -from .mysql import MySQLAdapter -from .oracle import OracleAdapter -from .postgresql import PostgreSQLAdapter -from .sqlite import SQLiteAdapter -from .supabase import SupabaseAdapter -from .turso import TursoAdapter __all__ = [ "ColumnInfo", "DatabaseAdapter", "TableInfo", + # Adapter classes (lazy via __getattr__) "CockroachDBAdapter", "DuckDBAdapter", "MariaDBAdapter", @@ -24,24 +29,44 @@ "SQLServerAdapter", "SupabaseAdapter", "TursoAdapter", + # Factory helpers "get_adapter", + "get_supported_adapter_db_types", ] +if TYPE_CHECKING: + from .cockroachdb import CockroachDBAdapter + from .duckdb import DuckDBAdapter + from .mariadb import MariaDBAdapter + from .mssql import SQLServerAdapter + from .mysql import MySQLAdapter + from .oracle import OracleAdapter + from .postgresql import PostgreSQLAdapter + from .sqlite import SQLiteAdapter + from .supabase import SupabaseAdapter + from .turso import TursoAdapter + def get_adapter(db_type: str) -> DatabaseAdapter: - adapters = { - "mssql": SQLServerAdapter(), - "sqlite": SQLiteAdapter(), - "postgresql": PostgreSQLAdapter(), - "mysql": MySQLAdapter(), - "oracle": OracleAdapter(), - "mariadb": MariaDBAdapter(), - "duckdb": DuckDBAdapter(), - "cockroachdb": CockroachDBAdapter(), - "turso": TursoAdapter(), - "supabase": SupabaseAdapter(), - } - adapter = adapters.get(db_type) - if not adapter: - raise ValueError(f"Unknown database type: {db_type}") - return adapter + return _get_adapter(db_type) + + +def get_supported_adapter_db_types() -> list[str]: + """Return the database types supported by the adapter factory.""" + return _get_supported_adapter_db_types() + + +_ADAPTER_PATH_BY_NAME: dict[str, tuple[str, str]] | None = None + + +def __getattr__(name: str) -> type[DatabaseAdapter]: + global _ADAPTER_PATH_BY_NAME + if _ADAPTER_PATH_BY_NAME is None: + _ADAPTER_PATH_BY_NAME = {spec.adapter_path[1]: spec.adapter_path for spec in PROVIDERS.values()} + + adapter_path = _ADAPTER_PATH_BY_NAME.get(name) + if adapter_path is None: + raise AttributeError(name) + module_name, class_name = adapter_path + module = import_module(module_name) + return getattr(module, class_name) diff --git a/sqlit/db/adapters/base.py b/sqlit/db/adapters/base.py index ee349201..ae3f0119 100644 --- a/sqlit/db/adapters/base.py +++ b/sqlit/db/adapters/base.py @@ -2,11 +2,15 @@ from __future__ import annotations +import importlib +import os from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Any +from rich.markup import escape + if TYPE_CHECKING: from ...config import ConnectionConfig @@ -40,6 +44,7 @@ class ColumnInfo: name: str data_type: str + is_primary_key: bool = False # Type alias for table/view info: (schema, name) @@ -53,6 +58,53 @@ class DatabaseAdapter(ABC): Connection schema/metadata is defined separately in db.schema. """ + @property + def install_hint(self) -> str | None: + """Installation hint for the adapter's dependencies.""" + if not self.install_extra or not self.install_package: + return None + return _create_driver_import_error_hint(self.name, self.install_extra, self.install_package).strip() + + @property + def driver_import_names(self) -> tuple[str, ...]: + """Import names used to verify required driver dependencies are installed.""" + return () + + def ensure_driver_available(self) -> None: + """Verify required dependencies can be imported, raising MissingDriverError if not.""" + forced_missing = os.environ.get("SQLIT_MOCK_MISSING_DRIVERS", "").strip() + if forced_missing: + forced = {s.strip() for s in forced_missing.split(",") if s.strip()} + db_type = getattr(self, "_db_type", None) + if db_type in forced: + from ...db.exceptions import MissingDriverError + + if not self.install_extra or not self.install_package: + raise ImportError(f"Missing driver for {self.name}") + raise MissingDriverError(self.name, self.install_extra, self.install_package) + + if not self.driver_import_names: + return + try: + for module_name in self.driver_import_names: + importlib.import_module(module_name) + except ImportError as e: + from ...db.exceptions import MissingDriverError + + if not self.install_extra or not self.install_package: + raise e + raise MissingDriverError(self.name, self.install_extra, self.install_package) from e + + @property + def install_extra(self) -> str | None: + """Name of the [extra] for pip install.""" + return None + + @property + def install_package(self) -> str | None: + """Name of the package for pipx inject.""" + return None + @property @abstractmethod def name(self) -> str: @@ -94,7 +146,7 @@ def format_table_name(self, schema: str, name: str) -> str: return f"{schema}.{name}" @abstractmethod - def connect(self, config: "ConnectionConfig") -> Any: + def connect(self, config: ConnectionConfig) -> Any: """Create a connection to the database.""" pass @@ -146,9 +198,7 @@ def quote_identifier(self, name: str) -> str: pass @abstractmethod - def build_select_query( - self, table: str, limit: int, database: str | None = None, schema: str | None = None - ) -> str: + def build_select_query(self, table: str, limit: int, database: str | None = None, schema: str | None = None) -> str: """Build a SELECT query with limit. Args: @@ -160,9 +210,7 @@ def build_select_query( pass @abstractmethod - def execute_query( - self, conn: Any, query: str, max_rows: int | None = None - ) -> tuple[list[str], list[tuple], bool]: + def execute_query(self, conn: Any, query: str, max_rows: int | None = None) -> tuple[list[str], list[tuple], bool]: """Execute a query and return (columns, rows, truncated). Args: @@ -188,9 +236,7 @@ class CursorBasedAdapter(DatabaseAdapter): Provides common implementations for execute_query and execute_non_query. """ - def execute_query( - self, conn: Any, query: str, max_rows: int | None = None - ) -> tuple[list[str], list[tuple], bool]: + def execute_query(self, conn: Any, query: str, max_rows: int | None = None) -> tuple[list[str], list[tuple], bool]: """Execute a query using cursor-based approach with optional row limit.""" cursor = conn.cursor() cursor.execute(query) @@ -212,7 +258,7 @@ def execute_non_query(self, conn: Any, query: str) -> int: """Execute a non-query using cursor-based approach.""" cursor = conn.cursor() cursor.execute(query) - rowcount = cursor.rowcount + rowcount = int(cursor.rowcount) conn.commit() return rowcount @@ -259,14 +305,12 @@ def get_views(self, conn: Any, database: str | None = None) -> list[TableInfo]: cursor = conn.cursor() if database: cursor.execute( - "SELECT table_name FROM information_schema.views " - "WHERE table_schema = %s ORDER BY table_name", + "SELECT table_name FROM information_schema.views " "WHERE table_schema = %s ORDER BY table_name", (database,), ) else: cursor.execute( - "SELECT table_name FROM information_schema.views " - "WHERE table_schema = DATABASE() ORDER BY table_name" + "SELECT table_name FROM information_schema.views " "WHERE table_schema = DATABASE() ORDER BY table_name" ) return [("", row[0]) for row in cursor.fetchall()] @@ -275,6 +319,23 @@ def get_columns( ) -> list[ColumnInfo]: """Get columns for a table. Schema parameter is ignored (MySQL has no schemas).""" cursor = conn.cursor() + + # Get primary key columns + if database: + cursor.execute( + "SELECT column_name FROM information_schema.key_column_usage " + "WHERE table_schema = %s AND table_name = %s AND constraint_name = 'PRIMARY'", + (database, table), + ) + else: + cursor.execute( + "SELECT column_name FROM information_schema.key_column_usage " + "WHERE table_schema = DATABASE() AND table_name = %s AND constraint_name = 'PRIMARY'", + (table,), + ) + pk_columns = {row[0] for row in cursor.fetchall()} + + # Get all columns if database: cursor.execute( "SELECT column_name, data_type FROM information_schema.columns " @@ -289,7 +350,7 @@ def get_columns( "ORDER BY ordinal_position", (table,), ) - return [ColumnInfo(name=row[0], data_type=row[1]) for row in cursor.fetchall()] + return [ColumnInfo(name=row[0], data_type=row[1], is_primary_key=row[0] in pk_columns) for row in cursor.fetchall()] def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: """Get stored procedures.""" @@ -317,9 +378,7 @@ def quote_identifier(self, name: str) -> str: escaped = name.replace("`", "``") return f"`{escaped}`" - def build_select_query( - self, table: str, limit: int, database: str | None = None, schema: str | None = None - ) -> str: + def build_select_query(self, table: str, limit: int, database: str | None = None, schema: str | None = None) -> str: """Build SELECT LIMIT query. Schema parameter is ignored (MySQL has no schemas).""" if database: return f"SELECT * FROM `{database}`.`{table}` LIMIT {limit}" @@ -371,13 +430,31 @@ def get_columns( """Get columns for a table.""" cursor = conn.cursor() schema = schema or "public" + + # Get primary key columns + cursor.execute( + "SELECT kcu.column_name " + "FROM information_schema.table_constraints tc " + "JOIN information_schema.key_column_usage kcu " + " ON tc.constraint_name = kcu.constraint_name " + " AND tc.table_schema = kcu.table_schema " + "WHERE tc.constraint_type = 'PRIMARY KEY' " + "AND tc.table_schema = %s AND tc.table_name = %s", + (schema, table), + ) + pk_columns = {row[0] for row in cursor.fetchall()} + + # Get all columns cursor.execute( "SELECT column_name, data_type FROM information_schema.columns " "WHERE table_schema = %s AND table_name = %s " "ORDER BY ordinal_position", (schema, table), ) - return [ColumnInfo(name=row[0], data_type=row[1]) for row in cursor.fetchall()] + return [ + ColumnInfo(name=row[0], data_type=row[1], is_primary_key=row[0] in pk_columns) + for row in cursor.fetchall() + ] def quote_identifier(self, name: str) -> str: """Quote identifier using double quotes for PostgreSQL. @@ -387,9 +464,20 @@ def quote_identifier(self, name: str) -> str: escaped = name.replace('"', '""') return f'"{escaped}"' - def build_select_query( - self, table: str, limit: int, database: str | None = None, schema: str | None = None - ) -> str: + def build_select_query(self, table: str, limit: int, database: str | None = None, schema: str | None = None) -> str: """Build SELECT LIMIT query for PostgreSQL.""" schema = schema or "public" return f'SELECT * FROM "{schema}"."{table}" LIMIT {limit}' + + +def _create_driver_import_error_hint(driver_name: str, extra_name: str, package_name: str) -> str: + """Generate a context-aware hint for missing driver installation.""" + from ...install_strategy import detect_strategy + + strategy = detect_strategy(extra_name=extra_name, package_name=package_name) + instructions = escape(strategy.manual_instructions) + return ( + f"{driver_name} driver not found.\n\n" + f"To connect to {driver_name}, run:\n\n" + f"[bold]{instructions}[/bold]\n" + ) diff --git a/sqlit/db/adapters/cockroachdb.py b/sqlit/db/adapters/cockroachdb.py index 47122995..1f0ade76 100644 --- a/sqlit/db/adapters/cockroachdb.py +++ b/sqlit/db/adapters/cockroachdb.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any +from ..schema import get_default_port from .base import PostgresBaseAdapter if TYPE_CHECKING: @@ -17,15 +18,34 @@ class CockroachDBAdapter(PostgresBaseAdapter): def name(self) -> str: return "CockroachDB" + @property + def install_extra(self) -> str: + return "cockroachdb" + + @property + def install_package(self) -> str: + return "psycopg2-binary" + + @property + def driver_import_names(self) -> tuple[str, ...]: + return ("psycopg2",) + @property def supports_stored_procedures(self) -> bool: return False # CockroachDB has limited stored procedure support - def connect(self, config: "ConnectionConfig") -> Any: + def connect(self, config: ConnectionConfig) -> Any: """Connect to CockroachDB database.""" - import psycopg2 + try: + import psycopg2 + except ImportError as e: + from ...db.exceptions import MissingDriverError - port = int(config.port) if config.port else 26257 + if not self.install_extra or not self.install_package: + raise e + raise MissingDriverError(self.name, self.install_extra, self.install_package) from e + + port = int(config.port or get_default_port("cockroachdb")) conn = psycopg2.connect( host=config.server, port=port, @@ -42,9 +62,7 @@ def connect(self, config: "ConnectionConfig") -> Any: def get_databases(self, conn: Any) -> list[str]: """Get list of databases from CockroachDB.""" cursor = conn.cursor() - cursor.execute( - "SELECT database_name FROM [SHOW DATABASES] ORDER BY database_name" - ) + cursor.execute("SELECT database_name FROM [SHOW DATABASES] ORDER BY database_name") return [row[0] for row in cursor.fetchall()] def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: diff --git a/sqlit/db/adapters/d1.py b/sqlit/db/adapters/d1.py new file mode 100644 index 00000000..3cf60de1 --- /dev/null +++ b/sqlit/db/adapters/d1.py @@ -0,0 +1,213 @@ +"""Cloudflare D1 database adapter.""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, cast + +from .base import ColumnInfo, DatabaseAdapter, TableInfo + +if TYPE_CHECKING: + import requests + + from ...config import ConnectionConfig + + +@dataclass +class D1Connection: + """Holds connection details for a D1 database.""" + + session: requests.Session + account_id: str + database_id: str + + +class D1Adapter(DatabaseAdapter): + """Adapter for Cloudflare D1.""" + + @property + def driver_import_names(self) -> tuple[str, ...]: + return ("requests",) + + def _api_base_url(self) -> str: + return os.environ.get("D1_API_BASE_URL", "https://api.cloudflare.com").rstrip("/") + + @property + def name(self) -> str: + """Human-readable name for this database type.""" + return "Cloudflare D1" + + @property + def install_extra(self) -> str: + return "d1" + + @property + def install_package(self) -> str: + return "requests" + + @property + def supports_multiple_databases(self) -> bool: + """D1 supports multiple databases under a single account.""" + return True + + @property + def supports_stored_procedures(self) -> bool: + """D1 is SQLite-based and does not support stored procedures.""" + return False + + def connect(self, config: ConnectionConfig) -> D1Connection: + """Establishes a 'connection' to D1 by preparing authenticated session.""" + try: + import requests + except ImportError as e: + from ...db.exceptions import MissingDriverError + + if not self.install_extra or not self.install_package: + raise e + raise MissingDriverError(self.name, self.install_extra, self.install_package) from e + + session = requests.Session() + session.headers.update({"Authorization": f"Bearer {config.password}"}) + account_id = config.server + + if not config.database: + raise ValueError("Database name is required for Cloudflare D1 connection.") + + database_id = self._find_database_id_by_name(session, account_id, config.database) + if not database_id: + raise ConnectionError(f"Cloudflare D1 database '{config.database}' not found.") + + return D1Connection(session=session, account_id=account_id, database_id=database_id) + + def _get_all_databases(self, session: requests.Session, account_id: str) -> list[dict[str, Any]]: + """Fetches all D1 databases for an account.""" + api_url = f"{self._api_base_url()}/client/v4/accounts/{account_id}/d1/database" + response = session.get(api_url) + response.raise_for_status() + data = cast(dict[str, Any], response.json()) + result = data.get("result", []) + if not isinstance(result, list): + return [] + return [cast(dict[str, Any], item) for item in result if isinstance(item, dict)] + + def _find_database_id_by_name(self, session: requests.Session, account_id: str, name: str) -> str | None: + """Finds a D1 database's UUID by its name.""" + databases = self._get_all_databases(session, account_id) + for db in databases: + if db.get("name") == name: + return db.get("uuid") + return None + + def get_databases(self, conn: D1Connection) -> list[str]: + """Gets a list of all database names for the account.""" + databases = self._get_all_databases(conn.session, conn.account_id) + return [db["name"] for db in databases if "name" in db] + + def _execute(self, conn: D1Connection, query: str) -> dict[str, Any]: + """Internal method to run a command on the D1 execute endpoint.""" + api_url = f"{self._api_base_url()}/client/v4/accounts/{conn.account_id}/d1/database/{conn.database_id}/execute" + response = conn.session.post(api_url, json={"sql": query}) + response.raise_for_status() + # The result is a list containing a single result object + data = cast(dict[str, Any], response.json()) + result_list = data.get("result", []) + if not isinstance(result_list, list) or not result_list or not isinstance(result_list[0], dict): + raise RuntimeError("Unexpected D1 API response format") + return cast(dict[str, Any], result_list[0]) + + def get_tables(self, conn: D1Connection, database: str | None = None) -> list[TableInfo]: + """Gets tables using PRAGMA.""" + result = self._execute(conn, "PRAGMA table_list;") + tables = [] + rows = result.get("results", []) + if not isinstance(rows, list): + return [] + for row in rows: + if not isinstance(row, dict): + continue + if row.get("type") == "table" and not row.get("name", "").startswith("sqlite_"): + tables.append((row.get("schema", ""), row.get("name", ""))) + return tables + + def get_views(self, conn: D1Connection, database: str | None = None) -> list[TableInfo]: + """Gets views using PRAGMA.""" + result = self._execute(conn, "PRAGMA table_list;") + views = [] + rows = result.get("results", []) + if not isinstance(rows, list): + return [] + for row in rows: + if not isinstance(row, dict): + continue + if row.get("type") == "view": + views.append((row.get("schema", ""), row.get("name", ""))) + return views + + def get_columns( + self, conn: D1Connection, table: str, database: str | None = None, schema: str | None = None + ) -> list[ColumnInfo]: + """Gets table columns using PRAGMA.""" + result = self._execute(conn, f"PRAGMA table_info({self.quote_identifier(table)});") + rows = result.get("results", []) + if not isinstance(rows, list): + return [] + cols: list[ColumnInfo] = [] + for col in rows: + if not isinstance(col, dict): + continue + name = col.get("name") + data_type = col.get("type") + # pk > 0 indicates column is part of primary key + pk_value = col.get("pk", 0) + is_pk = isinstance(pk_value, int) and pk_value > 0 + if isinstance(name, str) and isinstance(data_type, str): + cols.append(ColumnInfo(name=name, data_type=data_type, is_primary_key=is_pk)) + return cols + + def get_procedures(self, conn: D1Connection, database: str | None = None) -> list[str]: + """Returns an empty list as D1 does not support stored procedures.""" + return [] + + def quote_identifier(self, name: str) -> str: + """Quotes an identifier with double quotes.""" + return f'"{name}"' + + def build_select_query(self, table: str, limit: int, database: str | None = None, schema: str | None = None) -> str: + """Builds a standard SELECT ... LIMIT query.""" + return f"SELECT * FROM {self.quote_identifier(table)} LIMIT {limit}" + + def execute_query( + self, conn: D1Connection, query: str, max_rows: int | None = None + ) -> tuple[list[str], list[tuple], bool]: + """Executes a query and returns results in the expected format.""" + result = self._execute(conn, query) + rows_dicts = result.get("results", []) + if not isinstance(rows_dicts, list): + return [], [], False + rows_dicts = [row for row in rows_dicts if isinstance(row, dict)] + + if not rows_dicts: + return [], [], False + + columns = [str(k) for k in rows_dicts[0].keys()] + rows = [tuple(row.values()) for row in rows_dicts] + + # D1 doesn't have a concept of server-side cursor, so we can't easily tell if truncated + # unless we add `LIMIT max_rows + 1` to the query, which is complex here. + # For now, we assume not truncated. + truncated = False + if max_rows is not None and len(rows) > max_rows: + rows = rows[:max_rows] + truncated = True + + return columns, rows, truncated + + def execute_non_query(self, conn: D1Connection, query: str) -> int: + """Executes a non-query statement and returns rows affected.""" + result = self._execute(conn, query) + meta = result.get("meta", {}) + # D1 provides `rows_written` for mutations. + if not isinstance(meta, dict): + return 0 + return int(meta.get("rows_written", 0) or 0) diff --git a/sqlit/db/adapters/duckdb.py b/sqlit/db/adapters/duckdb.py index 503e29c9..f9b6cd1b 100644 --- a/sqlit/db/adapters/duckdb.py +++ b/sqlit/db/adapters/duckdb.py @@ -17,6 +17,18 @@ class DuckDBAdapter(DatabaseAdapter): def name(self) -> str: return "DuckDB" + @property + def install_extra(self) -> str: + return "duckdb" + + @property + def install_package(self) -> str: + return "duckdb" + + @property + def driver_import_names(self) -> tuple[str, ...]: + return ("duckdb",) + @property def supports_multiple_databases(self) -> bool: return False @@ -29,17 +41,25 @@ def supports_stored_procedures(self) -> bool: def default_schema(self) -> str: return "main" - def connect(self, config: "ConnectionConfig") -> Any: + def connect(self, config: ConnectionConfig) -> Any: """Connect to DuckDB database file. Note: DuckDB connections have limited thread safety. Operations are serialized via exclusive workers to ensure only one thread accesses the connection at a time. """ - import duckdb + try: + import duckdb + except ImportError as e: + from ...db.exceptions import MissingDriverError + + if not self.install_extra or not self.install_package: + raise e + raise MissingDriverError(self.name, self.install_extra, self.install_package) from e file_path = resolve_file_path(config.file_path) - return duckdb.connect(str(file_path)) + duckdb_any: Any = duckdb + return duckdb_any.connect(str(file_path)) def get_databases(self, conn: Any) -> list[str]: """DuckDB doesn't support multiple databases - return empty list.""" @@ -70,13 +90,28 @@ def get_columns( ) -> list[ColumnInfo]: """Get columns for a table from DuckDB.""" schema = schema or "main" + + # Get primary key columns + result = conn.execute( + "SELECT kcu.column_name " + "FROM information_schema.table_constraints tc " + "JOIN information_schema.key_column_usage kcu " + " ON tc.constraint_name = kcu.constraint_name " + " AND tc.table_schema = kcu.table_schema " + "WHERE tc.constraint_type = 'PRIMARY KEY' " + "AND tc.table_schema = ? AND tc.table_name = ?", + (schema, table), + ) + pk_columns = {row[0] for row in result.fetchall()} + + # Get all columns result = conn.execute( "SELECT column_name, data_type FROM information_schema.columns " "WHERE table_schema = ? AND table_name = ? " "ORDER BY ordinal_position", (schema, table), ) - return [ColumnInfo(name=row[0], data_type=row[1]) for row in result.fetchall()] + return [ColumnInfo(name=row[0], data_type=row[1], is_primary_key=row[0] in pk_columns) for row in result.fetchall()] def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: """DuckDB doesn't support stored procedures - return empty list.""" @@ -90,16 +125,12 @@ def quote_identifier(self, name: str) -> str: escaped = name.replace('"', '""') return f'"{escaped}"' - def build_select_query( - self, table: str, limit: int, database: str | None = None, schema: str | None = None - ) -> str: + def build_select_query(self, table: str, limit: int, database: str | None = None, schema: str | None = None) -> str: """Build SELECT LIMIT query for DuckDB.""" schema = schema or "main" return f'SELECT * FROM "{schema}"."{table}" LIMIT {limit}' - def execute_query( - self, conn: Any, query: str, max_rows: int | None = None - ) -> tuple[list[str], list[tuple], bool]: + def execute_query(self, conn: Any, query: str, max_rows: int | None = None) -> tuple[list[str], list[tuple], bool]: """Execute a query on DuckDB with optional row limit.""" result = conn.execute(query) if result.description: @@ -120,6 +151,6 @@ def execute_non_query(self, conn: Any, query: str) -> int: result = conn.execute(query) # DuckDB doesn't provide rowcount for all operations try: - return result.rowcount if hasattr(result, 'rowcount') else -1 + return result.rowcount if hasattr(result, "rowcount") else -1 except Exception: return -1 diff --git a/sqlit/db/adapters/mariadb.py b/sqlit/db/adapters/mariadb.py index 54c433da..eafe7b30 100644 --- a/sqlit/db/adapters/mariadb.py +++ b/sqlit/db/adapters/mariadb.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any +from ..schema import get_default_port from .base import ColumnInfo, MySQLBaseAdapter, TableInfo if TYPE_CHECKING: @@ -21,12 +22,32 @@ class MariaDBAdapter(MySQLBaseAdapter): def name(self) -> str: return "MariaDB" - def connect(self, config: "ConnectionConfig") -> Any: + @property + def install_extra(self) -> str: + return "mariadb" + + @property + def install_package(self) -> str: + return "mariadb" + + @property + def driver_import_names(self) -> tuple[str, ...]: + return ("mariadb",) + + def connect(self, config: ConnectionConfig) -> Any: """Connect to MariaDB database.""" - import mariadb + try: + import mariadb + except ImportError as e: + from ...db.exceptions import MissingDriverError + + if not self.install_extra or not self.install_package: + raise e + raise MissingDriverError(self.name, self.install_extra, self.install_package) from e - port = int(config.port) if config.port else 3306 - return mariadb.connect( + port = int(config.port or get_default_port("mariadb")) + mariadb_any: Any = mariadb + return mariadb_any.connect( host=config.server, port=port, database=config.database or None, @@ -56,14 +77,12 @@ def get_views(self, conn: Any, database: str | None = None) -> list[TableInfo]: cursor = conn.cursor() if database: cursor.execute( - "SELECT table_name FROM information_schema.views " - "WHERE table_schema = ? ORDER BY table_name", + "SELECT table_name FROM information_schema.views " "WHERE table_schema = ? ORDER BY table_name", (database,), ) else: cursor.execute( - "SELECT table_name FROM information_schema.views " - "WHERE table_schema = DATABASE() ORDER BY table_name" + "SELECT table_name FROM information_schema.views " "WHERE table_schema = DATABASE() ORDER BY table_name" ) return [("", row[0]) for row in cursor.fetchall()] @@ -72,6 +91,23 @@ def get_columns( ) -> list[ColumnInfo]: """Get columns for a table from MariaDB. Schema parameter is ignored.""" cursor = conn.cursor() + + # Get primary key columns + if database: + cursor.execute( + "SELECT column_name FROM information_schema.key_column_usage " + "WHERE table_schema = ? AND table_name = ? AND constraint_name = 'PRIMARY'", + (database, table), + ) + else: + cursor.execute( + "SELECT column_name FROM information_schema.key_column_usage " + "WHERE table_schema = DATABASE() AND table_name = ? AND constraint_name = 'PRIMARY'", + (table,), + ) + pk_columns = {row[0] for row in cursor.fetchall()} + + # Get all columns if database: cursor.execute( "SELECT column_name, data_type FROM information_schema.columns " @@ -86,7 +122,7 @@ def get_columns( "ORDER BY ordinal_position", (table,), ) - return [ColumnInfo(name=row[0], data_type=row[1]) for row in cursor.fetchall()] + return [ColumnInfo(name=row[0], data_type=row[1], is_primary_key=row[0] in pk_columns) for row in cursor.fetchall()] def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: """Get stored procedures from MariaDB.""" diff --git a/sqlit/db/adapters/mssql.py b/sqlit/db/adapters/mssql.py index ad494900..df0f688d 100644 --- a/sqlit/db/adapters/mssql.py +++ b/sqlit/db/adapters/mssql.py @@ -17,6 +17,18 @@ class SQLServerAdapter(DatabaseAdapter): def name(self) -> str: return "SQL Server" + @property + def install_extra(self) -> str: + return "mssql" + + @property + def install_package(self) -> str: + return "pyodbc" + + @property + def driver_import_names(self) -> tuple[str, ...]: + return ("pyodbc",) + @property def supports_multiple_databases(self) -> bool: return True @@ -29,12 +41,9 @@ def supports_stored_procedures(self) -> bool: def default_schema(self) -> str: return "dbo" - def _build_connection_string(self, config: "ConnectionConfig") -> str: + def _build_connection_string(self, config: ConnectionConfig) -> str: """Build ODBC connection string from config. - This method encapsulates the SQL Server-specific connection string - building logic that was previously in ConnectionConfig.get_connection_string(). - Args: config: Connection configuration. @@ -61,23 +70,30 @@ def _build_connection_string(self, config: "ConnectionConfig") -> str: elif auth == AuthType.SQL_SERVER: return base + f"UID={config.username};PWD={config.password};" elif auth == AuthType.AD_PASSWORD: - return ( - base - + f"Authentication=ActiveDirectoryPassword;" - f"UID={config.username};PWD={config.password};" - ) + return base + f"Authentication=ActiveDirectoryPassword;" f"UID={config.username};PWD={config.password};" elif auth == AuthType.AD_INTERACTIVE: - return ( - base + f"Authentication=ActiveDirectoryInteractive;" f"UID={config.username};" - ) + return base + f"Authentication=ActiveDirectoryInteractive;" f"UID={config.username};" elif auth == AuthType.AD_INTEGRATED: return base + "Authentication=ActiveDirectoryIntegrated;" return base + "Trusted_Connection=yes;" - def connect(self, config: "ConnectionConfig") -> Any: + def connect(self, config: ConnectionConfig) -> Any: """Connect to SQL Server using pyodbc.""" - import pyodbc + try: + import pyodbc + except ImportError as e: + from ...db.exceptions import MissingDriverError + + if not self.install_extra or not self.install_package: + raise e + raise MissingDriverError(self.name, self.install_extra, self.install_package) from e + + installed = list(pyodbc.drivers()) + if config.driver not in installed: + from ...db.exceptions import MissingODBCDriverError + + raise MissingODBCDriverError(config.driver, installed) conn_str = self._build_connection_string(config) return pyodbc.connect(conn_str, timeout=10) @@ -113,8 +129,7 @@ def get_views(self, conn: Any, database: str | None = None) -> list[TableInfo]: ) else: cursor.execute( - "SELECT TABLE_SCHEMA, TABLE_NAME FROM INFORMATION_SCHEMA.VIEWS " - "ORDER BY TABLE_SCHEMA, TABLE_NAME" + "SELECT TABLE_SCHEMA, TABLE_NAME FROM INFORMATION_SCHEMA.VIEWS " "ORDER BY TABLE_SCHEMA, TABLE_NAME" ) return [(row[0], row[1]) for row in cursor.fetchall()] @@ -124,6 +139,33 @@ def get_columns( """Get columns for a table from SQL Server.""" cursor = conn.cursor() schema = schema or "dbo" + + # Get primary key columns + if database: + cursor.execute( + f"SELECT kcu.COLUMN_NAME " + f"FROM [{database}].INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc " + f"JOIN [{database}].INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu " + f" ON tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME " + f" AND tc.TABLE_SCHEMA = kcu.TABLE_SCHEMA " + f"WHERE tc.CONSTRAINT_TYPE = 'PRIMARY KEY' " + f"AND tc.TABLE_SCHEMA = ? AND tc.TABLE_NAME = ?", + (schema, table), + ) + else: + cursor.execute( + "SELECT kcu.COLUMN_NAME " + "FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc " + "JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu " + " ON tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME " + " AND tc.TABLE_SCHEMA = kcu.TABLE_SCHEMA " + "WHERE tc.CONSTRAINT_TYPE = 'PRIMARY KEY' " + "AND tc.TABLE_SCHEMA = ? AND tc.TABLE_NAME = ?", + (schema, table), + ) + pk_columns = {row[0] for row in cursor.fetchall()} + + # Get all columns if database: cursor.execute( f"SELECT COLUMN_NAME, DATA_TYPE FROM [{database}].INFORMATION_SCHEMA.COLUMNS " @@ -136,7 +178,7 @@ def get_columns( "WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? ORDER BY ORDINAL_POSITION", (schema, table), ) - return [ColumnInfo(name=row[0], data_type=row[1]) for row in cursor.fetchall()] + return [ColumnInfo(name=row[0], data_type=row[1], is_primary_key=row[0] in pk_columns) for row in cursor.fetchall()] def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: """Get stored procedures from SQL Server.""" @@ -161,18 +203,14 @@ def quote_identifier(self, name: str) -> str: escaped = name.replace("]", "]]") return f"[{escaped}]" - def build_select_query( - self, table: str, limit: int, database: str | None = None, schema: str | None = None - ) -> str: + def build_select_query(self, table: str, limit: int, database: str | None = None, schema: str | None = None) -> str: """Build SELECT TOP query for SQL Server.""" schema = schema or "dbo" if database: return f"SELECT TOP {limit} * FROM [{database}].[{schema}].[{table}]" return f"SELECT TOP {limit} * FROM [{schema}].[{table}]" - def execute_query( - self, conn: Any, query: str, max_rows: int | None = None - ) -> tuple[list[str], list[tuple], bool]: + def execute_query(self, conn: Any, query: str, max_rows: int | None = None) -> tuple[list[str], list[tuple], bool]: """Execute a query on SQL Server with optional row limit.""" cursor = conn.cursor() cursor.execute(query) @@ -193,6 +231,6 @@ def execute_non_query(self, conn: Any, query: str) -> int: """Execute a non-query on SQL Server.""" cursor = conn.cursor() cursor.execute(query) - rowcount = cursor.rowcount + rowcount = int(cursor.rowcount) conn.commit() return rowcount diff --git a/sqlit/db/adapters/mysql.py b/sqlit/db/adapters/mysql.py index 9f0a594e..5859eced 100644 --- a/sqlit/db/adapters/mysql.py +++ b/sqlit/db/adapters/mysql.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any +from ..schema import get_default_port from .base import MySQLBaseAdapter if TYPE_CHECKING: @@ -17,11 +18,30 @@ class MySQLAdapter(MySQLBaseAdapter): def name(self) -> str: return "MySQL" - def connect(self, config: "ConnectionConfig") -> Any: + @property + def install_extra(self) -> str: + return "mysql" + + @property + def install_package(self) -> str: + return "mysql-connector-python" + + @property + def driver_import_names(self) -> tuple[str, ...]: + return ("mysql.connector",) + + def connect(self, config: ConnectionConfig) -> Any: """Connect to MySQL database.""" - import mysql.connector + try: + import mysql.connector + except ImportError as e: + from ...db.exceptions import MissingDriverError + + if not self.install_extra or not self.install_package: + raise e + raise MissingDriverError(self.name, self.install_extra, self.install_package) from e - port = int(config.port) if config.port else 3306 + port = int(config.port or get_default_port("mysql")) return mysql.connector.connect( host=config.server, port=port, diff --git a/sqlit/db/adapters/oracle.py b/sqlit/db/adapters/oracle.py index d5ab5820..fc75a320 100644 --- a/sqlit/db/adapters/oracle.py +++ b/sqlit/db/adapters/oracle.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any +from ..schema import get_default_port from .base import ColumnInfo, DatabaseAdapter, TableInfo if TYPE_CHECKING: @@ -21,6 +22,18 @@ class OracleAdapter(DatabaseAdapter): def name(self) -> str: return "Oracle" + @property + def install_extra(self) -> str: + return "oracle" + + @property + def install_package(self) -> str: + return "oracledb" + + @property + def driver_import_names(self) -> tuple[str, ...]: + return ("oracledb",) + @property def supports_multiple_databases(self) -> bool: # Oracle uses schemas within a single database, not multiple databases @@ -30,11 +43,18 @@ def supports_multiple_databases(self) -> bool: def supports_stored_procedures(self) -> bool: return True - def connect(self, config: "ConnectionConfig") -> Any: + def connect(self, config: ConnectionConfig) -> Any: """Connect to Oracle database.""" - import oracledb + try: + import oracledb + except ImportError as e: + from ...db.exceptions import MissingDriverError - port = int(config.port) if config.port else 1521 + if not self.install_extra or not self.install_package: + raise e + raise MissingDriverError(self.name, self.install_extra, self.install_package) from e + + port = int(config.port or get_default_port("oracle")) # Use Easy Connect string format: host:port/service_name dsn = f"{config.server}:{port}/{config.database}" return oracledb.connect( @@ -50,18 +70,14 @@ def get_databases(self, conn: Any) -> list[str]: def get_tables(self, conn: Any, database: str | None = None) -> list[TableInfo]: """Get list of tables from Oracle. Returns (schema, name) with empty schema.""" cursor = conn.cursor() - cursor.execute( - "SELECT table_name FROM user_tables ORDER BY table_name" - ) + cursor.execute("SELECT table_name FROM user_tables ORDER BY table_name") # user_tables returns only current user's tables, so no schema prefix needed return [("", row[0]) for row in cursor.fetchall()] def get_views(self, conn: Any, database: str | None = None) -> list[TableInfo]: """Get list of views from Oracle. Returns (schema, name) with empty schema.""" cursor = conn.cursor() - cursor.execute( - "SELECT view_name FROM user_views ORDER BY view_name" - ) + cursor.execute("SELECT view_name FROM user_views ORDER BY view_name") return [("", row[0]) for row in cursor.fetchall()] def get_columns( @@ -69,19 +85,29 @@ def get_columns( ) -> list[ColumnInfo]: """Get columns for a table from Oracle. Schema parameter is ignored.""" cursor = conn.cursor() + + # Get primary key columns + cursor.execute( + "SELECT cols.column_name " + "FROM user_constraints cons " + "JOIN user_cons_columns cols ON cons.constraint_name = cols.constraint_name " + "WHERE cons.constraint_type = 'P' AND cons.table_name = :1", + (table.upper(),), + ) + pk_columns = {row[0] for row in cursor.fetchall()} + + # Get all columns cursor.execute( - "SELECT column_name, data_type FROM user_tab_columns " - "WHERE table_name = :1 ORDER BY column_id", + "SELECT column_name, data_type FROM user_tab_columns " "WHERE table_name = :1 ORDER BY column_id", (table.upper(),), ) - return [ColumnInfo(name=row[0], data_type=row[1]) for row in cursor.fetchall()] + return [ColumnInfo(name=row[0], data_type=row[1], is_primary_key=row[0] in pk_columns) for row in cursor.fetchall()] def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: """Get stored procedures from Oracle.""" cursor = conn.cursor() cursor.execute( - "SELECT object_name FROM user_procedures " - "WHERE object_type = 'PROCEDURE' ORDER BY object_name" + "SELECT object_name FROM user_procedures " "WHERE object_type = 'PROCEDURE' ORDER BY object_name" ) return [row[0] for row in cursor.fetchall()] @@ -93,15 +119,11 @@ def quote_identifier(self, name: str) -> str: escaped = name.replace('"', '""') return f'"{escaped}"' - def build_select_query( - self, table: str, limit: int, database: str | None = None, schema: str | None = None - ) -> str: + def build_select_query(self, table: str, limit: int, database: str | None = None, schema: str | None = None) -> str: """Build SELECT query with FETCH FIRST for Oracle 12c+. Schema parameter is ignored.""" return f'SELECT * FROM "{table}" FETCH FIRST {limit} ROWS ONLY' - def execute_query( - self, conn: Any, query: str, max_rows: int | None = None - ) -> tuple[list[str], list[tuple], bool]: + def execute_query(self, conn: Any, query: str, max_rows: int | None = None) -> tuple[list[str], list[tuple], bool]: """Execute a query on Oracle with optional row limit.""" cursor = conn.cursor() cursor.execute(query) @@ -122,6 +144,6 @@ def execute_non_query(self, conn: Any, query: str) -> int: """Execute a non-query on Oracle.""" cursor = conn.cursor() cursor.execute(query) - rowcount = cursor.rowcount + rowcount = int(cursor.rowcount) conn.commit() return rowcount diff --git a/sqlit/db/adapters/postgresql.py b/sqlit/db/adapters/postgresql.py index 746b48c4..48f4ac84 100644 --- a/sqlit/db/adapters/postgresql.py +++ b/sqlit/db/adapters/postgresql.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any +from ..schema import get_default_port from .base import PostgresBaseAdapter if TYPE_CHECKING: @@ -17,11 +18,30 @@ class PostgreSQLAdapter(PostgresBaseAdapter): def name(self) -> str: return "PostgreSQL" - def connect(self, config: "ConnectionConfig") -> Any: + @property + def install_extra(self) -> str: + return "postgres" + + @property + def install_package(self) -> str: + return "psycopg2-binary" + + @property + def driver_import_names(self) -> tuple[str, ...]: + return ("psycopg2",) + + def connect(self, config: ConnectionConfig) -> Any: """Connect to PostgreSQL database.""" - import psycopg2 + try: + import psycopg2 + except ImportError as e: + from ...db.exceptions import MissingDriverError + + if not self.install_extra or not self.install_package: + raise e + raise MissingDriverError(self.name, self.install_extra, self.install_package) from e - port = int(config.port) if config.port else 5432 + port = int(config.port or get_default_port("postgresql")) conn = psycopg2.connect( host=config.server, port=port, @@ -37,10 +57,7 @@ def connect(self, config: "ConnectionConfig") -> Any: def get_databases(self, conn: Any) -> list[str]: """Get list of databases from PostgreSQL.""" cursor = conn.cursor() - cursor.execute( - "SELECT datname FROM pg_database " - "WHERE datistemplate = false ORDER BY datname" - ) + cursor.execute("SELECT datname FROM pg_database " "WHERE datistemplate = false ORDER BY datname") return [row[0] for row in cursor.fetchall()] def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: diff --git a/sqlit/db/adapters/sqlite.py b/sqlit/db/adapters/sqlite.py index f8574882..cd58bd4c 100644 --- a/sqlit/db/adapters/sqlite.py +++ b/sqlit/db/adapters/sqlite.py @@ -1,8 +1,6 @@ """SQLite adapter using built-in sqlite3.""" from __future__ import annotations - -import sqlite3 from typing import TYPE_CHECKING, Any from .base import ColumnInfo, DatabaseAdapter, TableInfo, resolve_file_path @@ -26,8 +24,10 @@ def supports_multiple_databases(self) -> bool: def supports_stored_procedures(self) -> bool: return False - def connect(self, config: "ConnectionConfig") -> Any: + def connect(self, config: ConnectionConfig) -> Any: """Connect to SQLite database file.""" + import sqlite3 + file_path = resolve_file_path(config.file_path) # check_same_thread=False allows connection to be used from background threads # (for async query execution). SQLite serializes access internally. @@ -43,17 +43,14 @@ def get_tables(self, conn: Any, database: str | None = None) -> list[TableInfo]: """Get list of tables from SQLite. Returns (schema, name) with empty schema.""" cursor = conn.cursor() cursor.execute( - "SELECT name FROM sqlite_master WHERE type='table' " - "AND name NOT LIKE 'sqlite_%' ORDER BY name" + "SELECT name FROM sqlite_master WHERE type='table' " "AND name NOT LIKE 'sqlite_%' ORDER BY name" ) return [("", row[0]) for row in cursor.fetchall()] def get_views(self, conn: Any, database: str | None = None) -> list[TableInfo]: """Get list of views from SQLite. Returns (schema, name) with empty schema.""" cursor = conn.cursor() - cursor.execute( - "SELECT name FROM sqlite_master WHERE type='view' ORDER BY name" - ) + cursor.execute("SELECT name FROM sqlite_master WHERE type='view' ORDER BY name") return [("", row[0]) for row in cursor.fetchall()] def get_columns( @@ -65,7 +62,11 @@ def get_columns( quoted_table = self.quote_identifier(table) cursor.execute(f"PRAGMA table_info({quoted_table})") # PRAGMA table_info returns: cid, name, type, notnull, dflt_value, pk - return [ColumnInfo(name=row[1], data_type=row[2] or "TEXT") for row in cursor.fetchall()] + # pk > 0 indicates column is part of primary key + return [ + ColumnInfo(name=row[1], data_type=row[2] or "TEXT", is_primary_key=row[5] > 0) + for row in cursor.fetchall() + ] def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: """SQLite doesn't support stored procedures - return empty list.""" @@ -79,15 +80,11 @@ def quote_identifier(self, name: str) -> str: escaped = name.replace('"', '""') return f'"{escaped}"' - def build_select_query( - self, table: str, limit: int, database: str | None = None, schema: str | None = None - ) -> str: + def build_select_query(self, table: str, limit: int, database: str | None = None, schema: str | None = None) -> str: """Build SELECT LIMIT query for SQLite. Schema parameter is ignored.""" return f'SELECT * FROM "{table}" LIMIT {limit}' - def execute_query( - self, conn: Any, query: str, max_rows: int | None = None - ) -> tuple[list[str], list[tuple], bool]: + def execute_query(self, conn: Any, query: str, max_rows: int | None = None) -> tuple[list[str], list[tuple], bool]: """Execute a query on SQLite with optional row limit.""" cursor = conn.cursor() cursor.execute(query) @@ -108,6 +105,6 @@ def execute_non_query(self, conn: Any, query: str) -> int: """Execute a non-query on SQLite.""" cursor = conn.cursor() cursor.execute(query) - rowcount = cursor.rowcount + rowcount = int(cursor.rowcount) conn.commit() return rowcount diff --git a/sqlit/db/adapters/supabase.py b/sqlit/db/adapters/supabase.py index 5c34b051..cfcf7b04 100644 --- a/sqlit/db/adapters/supabase.py +++ b/sqlit/db/adapters/supabase.py @@ -17,7 +17,7 @@ def name(self) -> str: def supports_multiple_databases(self) -> bool: return False - def connect(self, config: "ConnectionConfig") -> Any: + def connect(self, config: ConnectionConfig) -> Any: from dataclasses import replace transformed = replace( diff --git a/sqlit/db/adapters/turso.py b/sqlit/db/adapters/turso.py index 1e945de8..8ac8ec97 100644 --- a/sqlit/db/adapters/turso.py +++ b/sqlit/db/adapters/turso.py @@ -21,6 +21,18 @@ class TursoAdapter(DatabaseAdapter): def name(self) -> str: return "Turso" + @property + def install_extra(self) -> str: + return "turso" + + @property + def install_package(self) -> str: + return "libsql-client" + + @property + def driver_import_names(self) -> tuple[str, ...]: + return ("libsql_client",) + @property def supports_multiple_databases(self) -> bool: return False @@ -29,13 +41,20 @@ def supports_multiple_databases(self) -> bool: def supports_stored_procedures(self) -> bool: return False - def connect(self, config: "ConnectionConfig") -> Any: + def connect(self, config: ConnectionConfig) -> Any: """Connect to Turso database. Uses config.server for the database URL and config.password for the auth token. Supports libsql://, https://, and http:// URLs. """ - from libsql_client import create_client_sync + try: + from libsql_client import create_client_sync + except ImportError as e: + from ...db.exceptions import MissingDriverError + + if not self.install_extra or not self.install_package: + raise e + raise MissingDriverError(self.name, self.install_extra, self.install_package) from e url = config.server # Ensure URL has proper scheme @@ -61,9 +80,7 @@ def get_tables(self, conn: Any, database: str | None = None) -> list[TableInfo]: def get_views(self, conn: Any, database: str | None = None) -> list[TableInfo]: """Get list of views from Turso. Returns (schema, name) with empty schema.""" - result = conn.execute( - "SELECT name FROM sqlite_master WHERE type='view' ORDER BY name" - ) + result = conn.execute("SELECT name FROM sqlite_master WHERE type='view' ORDER BY name") return [("", row[0]) for row in result.rows] def get_columns( @@ -73,7 +90,8 @@ def get_columns( quoted_table = self.quote_identifier(table) result = conn.execute(f"PRAGMA table_info({quoted_table})") # PRAGMA table_info returns: cid, name, type, notnull, dflt_value, pk - return [ColumnInfo(name=row[1], data_type=row[2] or "TEXT") for row in result.rows] + # pk > 0 indicates column is part of primary key + return [ColumnInfo(name=row[1], data_type=row[2] or "TEXT", is_primary_key=row[5] > 0) for row in result.rows] def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: """Turso doesn't support stored procedures - return empty list.""" @@ -87,15 +105,11 @@ def quote_identifier(self, name: str) -> str: escaped = name.replace('"', '""') return f'"{escaped}"' - def build_select_query( - self, table: str, limit: int, database: str | None = None, schema: str | None = None - ) -> str: + def build_select_query(self, table: str, limit: int, database: str | None = None, schema: str | None = None) -> str: """Build SELECT LIMIT query for Turso. Schema parameter is ignored.""" return f'SELECT * FROM "{table}" LIMIT {limit}' - def execute_query( - self, conn: Any, query: str, max_rows: int | None = None - ) -> tuple[list[str], list[tuple], bool]: + def execute_query(self, conn: Any, query: str, max_rows: int | None = None) -> tuple[list[str], list[tuple], bool]: """Execute a query on Turso with optional row limit.""" result = conn.execute(query) if result.columns: @@ -109,4 +123,4 @@ def execute_query( def execute_non_query(self, conn: Any, query: str) -> int: """Execute a non-query on Turso.""" result = conn.execute(query) - return result.rows_affected + return int(result.rows_affected or 0) diff --git a/sqlit/db/exceptions.py b/sqlit/db/exceptions.py new file mode 100644 index 00000000..9f52abac --- /dev/null +++ b/sqlit/db/exceptions.py @@ -0,0 +1,20 @@ +"""Custom exceptions for the database layer.""" + + +class MissingDriverError(ConnectionError): + """Exception raised when a required database driver package is not installed.""" + + def __init__(self, driver_name: str, extra_name: str, package_name: str): + self.driver_name = driver_name + self.extra_name = extra_name + self.package_name = package_name + super().__init__(f"Missing driver for {driver_name}") + + +class MissingODBCDriverError(ConnectionError): + """Exception raised when a required ODBC driver is not installed (SQL Server).""" + + def __init__(self, selected_driver: str, installed_drivers: list[str]): + self.selected_driver = selected_driver + self.installed_drivers = installed_drivers + super().__init__(f"Missing ODBC driver: {selected_driver}") diff --git a/sqlit/db/providers.py b/sqlit/db/providers.py new file mode 100644 index 00000000..a4fe4c79 --- /dev/null +++ b/sqlit/db/providers.py @@ -0,0 +1,143 @@ +"""Canonical provider registry (Plan B). + +This module is the single source of truth for: +- supported provider ids (db_type) +- display names and capabilities (via ConnectionSchema) +- adapter classes +""" + +from __future__ import annotations + +from dataclasses import dataclass +from importlib import import_module +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Iterable + from .adapters.base import DatabaseAdapter + from .schema import ConnectionSchema + + +@dataclass(frozen=True) +class ProviderSpec: + schema_path: tuple[str, str] + adapter_path: tuple[str, str] + + +PROVIDERS: dict[str, ProviderSpec] = { + "mssql": ProviderSpec( + schema_path=("sqlit.db.schema", "MSSQL_SCHEMA"), + adapter_path=("sqlit.db.adapters.mssql", "SQLServerAdapter"), + ), + "sqlite": ProviderSpec( + schema_path=("sqlit.db.schema", "SQLITE_SCHEMA"), + adapter_path=("sqlit.db.adapters.sqlite", "SQLiteAdapter"), + ), + "postgresql": ProviderSpec( + schema_path=("sqlit.db.schema", "POSTGRESQL_SCHEMA"), + adapter_path=("sqlit.db.adapters.postgresql", "PostgreSQLAdapter"), + ), + "mysql": ProviderSpec( + schema_path=("sqlit.db.schema", "MYSQL_SCHEMA"), + adapter_path=("sqlit.db.adapters.mysql", "MySQLAdapter"), + ), + "oracle": ProviderSpec( + schema_path=("sqlit.db.schema", "ORACLE_SCHEMA"), + adapter_path=("sqlit.db.adapters.oracle", "OracleAdapter"), + ), + "mariadb": ProviderSpec( + schema_path=("sqlit.db.schema", "MARIADB_SCHEMA"), + adapter_path=("sqlit.db.adapters.mariadb", "MariaDBAdapter"), + ), + "duckdb": ProviderSpec( + schema_path=("sqlit.db.schema", "DUCKDB_SCHEMA"), + adapter_path=("sqlit.db.adapters.duckdb", "DuckDBAdapter"), + ), + "cockroachdb": ProviderSpec( + schema_path=("sqlit.db.schema", "COCKROACHDB_SCHEMA"), + adapter_path=("sqlit.db.adapters.cockroachdb", "CockroachDBAdapter"), + ), + "turso": ProviderSpec( + schema_path=("sqlit.db.schema", "TURSO_SCHEMA"), + adapter_path=("sqlit.db.adapters.turso", "TursoAdapter"), + ), + "supabase": ProviderSpec( + schema_path=("sqlit.db.schema", "SUPABASE_SCHEMA"), + adapter_path=("sqlit.db.adapters.supabase", "SupabaseAdapter"), + ), + "d1": ProviderSpec( + schema_path=("sqlit.db.schema", "D1_SCHEMA"), + adapter_path=("sqlit.db.adapters.d1", "D1Adapter"), + ), +} + + +def get_supported_db_types() -> list[str]: + return list(PROVIDERS.keys()) + + +def iter_provider_schemas() -> Iterable[ConnectionSchema]: + return (_get_schema(spec) for spec in PROVIDERS.values()) + + +def get_provider_spec(db_type: str) -> ProviderSpec: + spec = PROVIDERS.get(db_type) + if spec is None: + raise ValueError(f"Unknown database type: {db_type}") + return spec + + +def _get_schema(spec: ProviderSpec) -> ConnectionSchema: + module_name, attr_name = spec.schema_path + module = import_module(module_name) + return getattr(module, attr_name) + + +def get_connection_schema(db_type: str) -> ConnectionSchema: + return _get_schema(get_provider_spec(db_type)) + + +def get_all_schemas() -> dict[str, ConnectionSchema]: + return {k: _get_schema(v) for k, v in PROVIDERS.items()} + + +def get_adapter(db_type: str) -> "DatabaseAdapter": + adapter = get_adapter_class(db_type)() + # Internal: allow adapters to know their provider id for test/mocking hooks. + setattr(adapter, "_db_type", db_type) + return adapter + + +def get_adapter_class(db_type: str) -> type["DatabaseAdapter"]: + spec = get_provider_spec(db_type) + module_name, class_name = spec.adapter_path + module = import_module(module_name) + adapter_cls = getattr(module, class_name) + return adapter_cls + + +def get_default_port(db_type: str) -> str: + spec = PROVIDERS.get(db_type) + if spec is None: + return "1433" + return _get_schema(spec).default_port + + +def get_display_name(db_type: str) -> str: + spec = PROVIDERS.get(db_type) + return _get_schema(spec).display_name if spec else db_type + + +def supports_ssh(db_type: str) -> bool: + spec = PROVIDERS.get(db_type) + return _get_schema(spec).supports_ssh if spec else False + + +def is_file_based(db_type: str) -> bool: + spec = PROVIDERS.get(db_type) + return _get_schema(spec).is_file_based if spec else False + + +def has_advanced_auth(db_type: str) -> bool: + spec = PROVIDERS.get(db_type) + return _get_schema(spec).has_advanced_auth if spec else False diff --git a/sqlit/db/schema.py b/sqlit/db/schema.py index ada7ebb7..f9cf45d4 100644 --- a/sqlit/db/schema.py +++ b/sqlit/db/schema.py @@ -1,15 +1,16 @@ """Connection schema definitions for database types. -This module provides pure metadata about connection parameters for each -database type, decoupled from UI concerns. The UI layer transforms these -schemas into form widgets. +This module defines UI-facing connection metadata (fields + labels + defaults). +The canonical provider registry is `sqlit.db.providers.PROVIDERS`. """ from __future__ import annotations -from dataclasses import dataclass, field +from collections.abc import Callable +from dataclasses import dataclass from enum import Enum -from typing import Callable + +from ..drivers import SUPPORTED_DRIVERS class FieldType(Enum): @@ -100,6 +101,7 @@ def _password_field() -> SchemaField: name="password", label="Password", field_type=FieldType.PASSWORD, + placeholder="(empty = ask every connect)", group="credentials", ) @@ -188,6 +190,7 @@ def _get_ssh_fields() -> tuple[SchemaField, ...]: name="ssh_password", label="Password", field_type=FieldType.PASSWORD, + placeholder="(empty = ask every connect)", visible_when=_ssh_auth_is_password, tab="ssh", ), @@ -197,20 +200,11 @@ def _get_ssh_fields() -> tuple[SchemaField, ...]: SSH_FIELDS = _get_ssh_fields() -# Schema definitions for each database type - def _get_mssql_driver_options() -> tuple[SelectOption, ...]: - """Get available ODBC driver options for SQL Server.""" - # These are checked at runtime in the UI layer - return ( - SelectOption("ODBC Driver 18 for SQL Server", "ODBC Driver 18 for SQL Server"), - SelectOption("ODBC Driver 17 for SQL Server", "ODBC Driver 17 for SQL Server"), - SelectOption("ODBC Driver 13 for SQL Server", "ODBC Driver 13 for SQL Server"), - ) + return tuple(SelectOption(d, d) for d in SUPPORTED_DRIVERS) def _get_mssql_auth_options() -> tuple[SelectOption, ...]: - """Get authentication type options for SQL Server.""" return ( SelectOption("sql", "SQL Server Authentication"), SelectOption("windows", "Windows Authentication"), @@ -244,7 +238,7 @@ def _get_mssql_auth_options() -> tuple[SelectOption, ...]: label="Driver", field_type=FieldType.SELECT, options=_get_mssql_driver_options(), - default="ODBC Driver 18 for SQL Server", + default=SUPPORTED_DRIVERS[0], advanced=True, ), SchemaField( @@ -265,10 +259,12 @@ def _get_mssql_auth_options() -> tuple[SelectOption, ...]: name="password", label="Password", field_type=FieldType.PASSWORD, + placeholder="(empty = ask every connect)", group="credentials", visible_when=lambda v: v.get("auth_type") in _MSSQL_AUTH_NEEDS_PASSWORD, ), - ) + SSH_FIELDS, + ) + + SSH_FIELDS, has_advanced_auth=True, default_port="1433", ) @@ -282,7 +278,8 @@ def _get_mssql_auth_options() -> tuple[SelectOption, ...]: _database_field(), _username_field(), _password_field(), - ) + SSH_FIELDS, + ) + + SSH_FIELDS, default_port="5432", ) @@ -295,7 +292,8 @@ def _get_mssql_auth_options() -> tuple[SelectOption, ...]: _database_field(), _username_field(), _password_field(), - ) + SSH_FIELDS, + ) + + SSH_FIELDS, default_port="3306", ) @@ -308,7 +306,8 @@ def _get_mssql_auth_options() -> tuple[SelectOption, ...]: _database_field(), _username_field(), _password_field(), - ) + SSH_FIELDS, + ) + + SSH_FIELDS, default_port="3306", ) @@ -332,7 +331,8 @@ def _get_mssql_auth_options() -> tuple[SelectOption, ...]: ), _username_field(), _password_field(), - ) + SSH_FIELDS, + ) + + SSH_FIELDS, default_port="1521", ) @@ -345,16 +345,15 @@ def _get_mssql_auth_options() -> tuple[SelectOption, ...]: _database_field(), _username_field(), _password_field(), - ) + SSH_FIELDS, + ) + + SSH_FIELDS, default_port="26257", ) SQLITE_SCHEMA = ConnectionSchema( db_type="sqlite", display_name="SQLite", - fields=( - _file_path_field("/path/to/database.db"), - ), + fields=(_file_path_field("/path/to/database.db"),), supports_ssh=False, is_file_based=True, ) @@ -362,9 +361,7 @@ def _get_mssql_auth_options() -> tuple[SelectOption, ...]: DUCKDB_SCHEMA = ConnectionSchema( db_type="duckdb", display_name="DuckDB", - fields=( - _file_path_field("/path/to/database.duckdb"), - ), + fields=(_file_path_field("/path/to/database.duckdb"),), supports_ssh=False, is_file_based=True, ) @@ -385,10 +382,40 @@ def _get_mssql_auth_options() -> tuple[SelectOption, ...]: label="Auth Token", field_type=FieldType.PASSWORD, required=False, + placeholder="auth token (optional)", description="Database authentication token, optional for local servers", ), ), - supports_ssh=False, # Turso uses HTTPS, SSH tunneling not applicable + supports_ssh=False, +) + + +D1_SCHEMA = ConnectionSchema( + db_type="d1", + display_name="Cloudflare D1", + fields=( + SchemaField( + name="server", + label="Account ID", + placeholder="Your Cloudflare Account ID", + required=True, + ), + SchemaField( + name="password", + label="API Token", + field_type=FieldType.PASSWORD, + required=True, + placeholder="cloudflare api token", + description="Cloudflare API Token with D1 permissions", + ), + SchemaField( + name="database", + label="Database Name", + placeholder="Your D1 database name", + required=True, + ), + ), + supports_ssh=False, ) @@ -438,53 +465,31 @@ def _get_supabase_region_options() -> tuple[SelectOption, ...]: label="Password", field_type=FieldType.PASSWORD, required=True, + placeholder="database password", ), ), supports_ssh=False, ) -# Schema registry -_SCHEMAS: dict[str, ConnectionSchema] = { - "mssql": MSSQL_SCHEMA, - "postgresql": POSTGRESQL_SCHEMA, - "mysql": MYSQL_SCHEMA, - "mariadb": MARIADB_SCHEMA, - "oracle": ORACLE_SCHEMA, - "cockroachdb": COCKROACHDB_SCHEMA, - "sqlite": SQLITE_SCHEMA, - "duckdb": DUCKDB_SCHEMA, - "turso": TURSO_SCHEMA, - "supabase": SUPABASE_SCHEMA, -} - - def get_connection_schema(db_type: str) -> ConnectionSchema: - """Get the connection schema for a database type. + from .providers import get_connection_schema as _get_connection_schema - Args: - db_type: Database type identifier (e.g., "postgresql", "mysql") - - Returns: - ConnectionSchema for the database type - - Raises: - ValueError: If db_type is not recognized - """ - schema = _SCHEMAS.get(db_type) - if schema is None: - raise ValueError(f"Unknown database type: {db_type}") - return schema + return _get_connection_schema(db_type) def get_all_schemas() -> dict[str, ConnectionSchema]: """Get all registered connection schemas.""" - return dict(_SCHEMAS) + from .providers import get_all_schemas as _get_all_schemas + + return _get_all_schemas() def get_supported_db_types() -> list[str]: """Get list of supported database type identifiers.""" - return list(_SCHEMAS.keys()) + from .providers import get_supported_db_types as _get_supported_db_types + + return _get_supported_db_types() def is_file_based(db_type: str) -> bool: @@ -496,8 +501,9 @@ def is_file_based(db_type: str) -> bool: Returns: True if the database is file-based, False otherwise """ - schema = _SCHEMAS.get(db_type) - return schema.is_file_based if schema else False + from .providers import is_file_based as _is_file_based + + return _is_file_based(db_type) def has_advanced_auth(db_type: str) -> bool: @@ -509,8 +515,9 @@ def has_advanced_auth(db_type: str) -> bool: Returns: True if the database has advanced auth options, False otherwise """ - schema = _SCHEMAS.get(db_type) - return schema.has_advanced_auth if schema else False + from .providers import has_advanced_auth as _has_advanced_auth + + return _has_advanced_auth(db_type) def supports_ssh(db_type: str) -> bool: @@ -522,8 +529,9 @@ def supports_ssh(db_type: str) -> bool: Returns: True if the database supports SSH tunneling, False otherwise """ - schema = _SCHEMAS.get(db_type) - return schema.supports_ssh if schema else False + from .providers import supports_ssh as _supports_ssh + + return _supports_ssh(db_type) def get_default_port(db_type: str) -> str: @@ -535,8 +543,9 @@ def get_default_port(db_type: str) -> str: Returns: Default port string, or "1433" as fallback for unknown types """ - schema = _SCHEMAS.get(db_type) - return schema.default_port if schema and schema.default_port else "1433" + from .providers import get_default_port as _get_default_port + + return _get_default_port(db_type) def get_display_name(db_type: str) -> str: @@ -548,5 +557,6 @@ def get_display_name(db_type: str) -> str: Returns: Display name string, or the db_type itself as fallback """ - schema = _SCHEMAS.get(db_type) - return schema.display_name if schema else db_type + from .providers import get_display_name as _get_display_name + + return _get_display_name(db_type) diff --git a/sqlit/db/tunnel.py b/sqlit/db/tunnel.py index 1006b49c..9cefc6b7 100644 --- a/sqlit/db/tunnel.py +++ b/sqlit/db/tunnel.py @@ -10,7 +10,7 @@ from ..config import ConnectionConfig -def create_ssh_tunnel(config: "ConnectionConfig") -> tuple[Any, str, int]: +def create_ssh_tunnel(config: ConnectionConfig) -> tuple[Any, str, int]: """Create an SSH tunnel for the connection if SSH is enabled. Returns: diff --git a/sqlit/drivers.py b/sqlit/drivers.py index f6f0d7ce..3a95bd92 100644 --- a/sqlit/drivers.py +++ b/sqlit/drivers.py @@ -3,10 +3,9 @@ from __future__ import annotations import platform -from dataclasses import dataclass +from dataclasses import dataclass, field - -# Supported SQL Server ODBC drivers in order of preference +# In order of preference SUPPORTED_DRIVERS = [ "ODBC Driver 18 for SQL Server", "ODBC Driver 17 for SQL Server", @@ -16,6 +15,16 @@ "SQL Server", ] +# Supported OS versions per Microsoft documentation (2024-2025) +SUPPORTED_VERSIONS = { + "ubuntu": ["18.04", "20.04", "22.04", "24.04", "24.10"], + "debian": ["9", "10", "11", "12"], + "rhel": ["7", "8", "9"], + "oracle": ["7", "8", "9"], + "sles": ["12", "15"], + "alpine": ["3.17", "3.18", "3.19", "3.20"], +} + @dataclass class InstallCommand: @@ -24,6 +33,7 @@ class InstallCommand: description: str commands: list[str] requires_sudo: bool = True + warnings: list[str] = field(default_factory=list) def get_installed_drivers() -> list[str]: @@ -33,7 +43,7 @@ def get_installed_drivers() -> list[str]: try: import pyodbc - available = [d for d in pyodbc.drivers()] + available = list(pyodbc.drivers()) for driver in SUPPORTED_DRIVERS: if driver in available: installed.append(driver) @@ -54,7 +64,6 @@ def get_os_info() -> tuple[str, str]: system = platform.system().lower() if system == "linux": - # Try to get distro info try: with open("/etc/os-release") as f: info = {} @@ -75,96 +84,190 @@ def get_os_info() -> tuple[str, str]: return system, "" +def _check_version_support(os_type: str, version: str) -> list[str]: + """Check if the OS version is officially supported and return warnings if not.""" + warnings = [] + supported = SUPPORTED_VERSIONS.get(os_type) + + if supported and version: + # For distros that use major version only + major_version = version.split(".")[0] + if version not in supported and major_version not in supported: + warnings.append( + f"Warning: {os_type} {version} may not be officially supported. " + f"Supported versions: {', '.join(supported)}" + ) + return warnings + + def get_install_commands(driver: str = "ODBC Driver 18 for SQL Server") -> InstallCommand | None: - """Get installation commands for the current OS.""" + """Get installation commands for the current OS. + + Commands are based on Microsoft's official documentation: + https://learn.microsoft.com/en-us/sql/connect/odbc/linux-mac/installing-the-microsoft-odbc-driver-for-sql-server + """ os_type, os_version = get_os_info() + driver_pkg = "msodbcsql18" if "18" in driver else "msodbcsql17" if os_type == "macos": return InstallCommand( description="Install via Homebrew", commands=[ + "brew install unixodbc", "brew tap microsoft/mssql-release https://github.com/Microsoft/homebrew-mssql-release", "brew update", - f"HOMEBREW_ACCEPT_EULA=Y brew install {'msodbcsql18' if '18' in driver else 'msodbcsql17'}", + f"HOMEBREW_ACCEPT_EULA=Y brew install {driver_pkg}", ], requires_sudo=False, ) elif os_type == "ubuntu": - driver_pkg = "msodbcsql18" if "18" in driver else "msodbcsql17" version = os_version or "22.04" + warnings = _check_version_support("ubuntu", version) return InstallCommand( - description="Install on Ubuntu", + description=f"Install on Ubuntu {version}", commands=[ - "curl https://packages.microsoft.com/keys/microsoft.asc | sudo tee /etc/apt/trusted.gpg.d/microsoft.asc", - f"curl https://packages.microsoft.com/config/ubuntu/{version}/prod.list | sudo tee /etc/apt/sources.list.d/mssql-release.list", + f"curl -sSL -O https://packages.microsoft.com/config/ubuntu/{version}/packages-microsoft-prod.deb", + "sudo dpkg -i packages-microsoft-prod.deb", + "rm packages-microsoft-prod.deb", "sudo apt-get update", f"sudo ACCEPT_EULA=Y apt-get install -y {driver_pkg}", ], + warnings=warnings, ) elif os_type == "debian": - driver_pkg = "msodbcsql18" if "18" in driver else "msodbcsql17" - # Debian version mapping: 11=bullseye, 12=bookworm version = os_version.split(".")[0] if os_version else "12" + warnings = _check_version_support("debian", version) return InstallCommand( - description="Install on Debian", + description=f"Install on Debian {version}", commands=[ - "curl https://packages.microsoft.com/keys/microsoft.asc | sudo tee /etc/apt/trusted.gpg.d/microsoft.asc", - f"curl https://packages.microsoft.com/config/debian/{version}/prod.list | sudo tee /etc/apt/sources.list.d/mssql-release.list", + f"curl -sSL -O https://packages.microsoft.com/config/debian/{version}/packages-microsoft-prod.deb", + "sudo dpkg -i packages-microsoft-prod.deb", + "rm packages-microsoft-prod.deb", "sudo apt-get update", f"sudo ACCEPT_EULA=Y apt-get install -y {driver_pkg}", ], + warnings=warnings, ) elif os_type == "fedora": - driver_pkg = "msodbcsql18" if "18" in driver else "msodbcsql17" + # Fedora uses RHEL 9 packages return InstallCommand( - description="Install on Fedora", + description="Install on Fedora (using RHEL 9 packages)", commands=[ - "sudo curl https://packages.microsoft.com/config/rhel/9/prod.repo -o /etc/yum.repos.d/mssql-release.repo", - "sudo dnf remove unixODBC-utf16 unixODBC-utf16-devel", + "curl -sSL -O https://packages.microsoft.com/config/rhel/9/packages-microsoft-prod.rpm", + "sudo rpm -i packages-microsoft-prod.rpm", + "rm packages-microsoft-prod.rpm", + "sudo dnf remove -y unixODBC-utf16 unixODBC-utf16-devel 2>/dev/null || true", f"sudo ACCEPT_EULA=Y dnf install -y {driver_pkg}", ], ) elif os_type in ("rhel", "centos", "rocky", "almalinux"): - driver_pkg = "msodbcsql18" if "18" in driver else "msodbcsql17" version = os_version.split(".")[0] if os_version else "9" + warnings = _check_version_support("rhel", version) + return InstallCommand( + description=f"Install on {os_type.upper()} {version}", + commands=[ + f"curl -sSL -O https://packages.microsoft.com/config/rhel/{version}/packages-microsoft-prod.rpm", + "sudo rpm -i packages-microsoft-prod.rpm", + "rm packages-microsoft-prod.rpm", + "sudo yum remove -y unixODBC-utf16 unixODBC-utf16-devel 2>/dev/null || true", + f"sudo ACCEPT_EULA=Y yum install -y {driver_pkg}", + ], + warnings=warnings, + ) + + elif os_type == "ol": # Oracle Linux + version = os_version.split(".")[0] if os_version else "9" + warnings = _check_version_support("oracle", version) return InstallCommand( - description=f"Install on {os_type.upper()}", + description=f"Install on Oracle Linux {version}", commands=[ - f"sudo curl https://packages.microsoft.com/config/rhel/{version}/prod.repo -o /etc/yum.repos.d/mssql-release.repo", - "sudo yum remove unixODBC-utf16 unixODBC-utf16-devel", + f"curl -sSL -O https://packages.microsoft.com/config/rhel/{version}/packages-microsoft-prod.rpm", + "sudo rpm -i packages-microsoft-prod.rpm", + "rm packages-microsoft-prod.rpm", + "sudo yum remove -y unixODBC-utf16 unixODBC-utf16-devel 2>/dev/null || true", f"sudo ACCEPT_EULA=Y yum install -y {driver_pkg}", ], + warnings=warnings, + ) + + elif os_type in ("sles", "opensuse-leap"): + version = os_version.split(".")[0] if os_version else "15" + warnings = _check_version_support("sles", version) + return InstallCommand( + description=f"Install on SLES/openSUSE {version}", + commands=[ + "sudo rpm --import https://packages.microsoft.com/keys/microsoft.asc", + f"curl -sSL -O https://packages.microsoft.com/config/sles/{version}/packages-microsoft-prod.rpm", + "sudo zypper install -y packages-microsoft-prod.rpm", + "rm packages-microsoft-prod.rpm", + "sudo zypper refresh", + f"sudo ACCEPT_EULA=Y zypper install -y {driver_pkg}", + ], + warnings=warnings, + ) + + elif os_type == "alpine": + version = os_version or "3.20" + warnings = _check_version_support("alpine", version) + # Alpine requires direct package download + arch = "amd64" if platform.machine() == "x86_64" else "arm64" + return InstallCommand( + description=f"Install on Alpine Linux {version}", + commands=[ + f"curl -O https://download.microsoft.com/download/fae28b9a-d880-42fd-9b98-d779f0fdd77f/{driver_pkg}_18.5.1.1-1_{arch}.apk", + f"sudo apk add --allow-untrusted {driver_pkg}_18.5.1.1-1_{arch}.apk", + ], + warnings=warnings + ["Note: Alpine package URLs may change with new driver versions"], ) elif os_type == "arch": return InstallCommand( - description="Install on Arch Linux (AUR) - alternatively use: paru -S msodbcsql", + description="Install on Arch Linux (AUR)", commands=[ "yay -S msodbcsql", ], requires_sudo=False, + warnings=["Alternative AUR helpers: paru -S msodbcsql"], ) elif os_type == "windows": + winget_pkg = "Microsoft.msodbcsql.17" if "17" in driver else "Microsoft.msodbcsql.18" return InstallCommand( - description="Install via winget (or download from https://learn.microsoft.com/en-us/sql/connect/odbc/download-odbc-driver-for-sql-server)", + description="Install via winget", commands=[ - "winget install Microsoft.msodbcsql.18", + f"winget install {winget_pkg}", ], requires_sudo=False, + warnings=[ + "Alternative: https://learn.microsoft.com/en-us/sql/connect/odbc/download-odbc-driver-for-sql-server" + ], ) return None -def check_pyodbc_installed() -> bool: - """Check if pyodbc is installed.""" - try: - import pyodbc - return True - except ImportError: - return False +def run_install_in_terminal(driver: str = "ODBC Driver 18 for SQL Server") -> tuple[bool, str]: + """Run driver installation commands in a new terminal window. + + Returns (success, message) tuple. + """ + from .terminal import TerminalType, run_in_terminal + + install_cmd = get_install_commands(driver) + if not install_cmd: + os_type, _ = get_os_info() + return False, f"No installation commands available for {os_type}" + + result = run_in_terminal(install_cmd.commands) + + if not result.success: + if result.terminal == TerminalType.NONE: + cmd_str = " && ".join(install_cmd.commands) + return False, f"No terminal found. Run manually:\n{cmd_str}" + return False, f"Failed to launch terminal: {result.error}" + + return True, "Installation started in new terminal. Restart sqlit when done." diff --git a/sqlit/fields.py b/sqlit/fields.py index bc9a8926..42013304 100644 --- a/sqlit/fields.py +++ b/sqlit/fields.py @@ -7,10 +7,10 @@ from __future__ import annotations +from collections.abc import Callable from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING -# Re-export types from schema for backward compatibility from .db.schema import FieldType, SelectOption if TYPE_CHECKING: @@ -101,7 +101,7 @@ def get_credential_fields() -> list[FieldDefinition]: # Transform functions: convert pure schema metadata to UI field definitions -def schema_field_to_definition(schema_field: "SchemaField") -> FieldDefinition: +def schema_field_to_definition(schema_field: SchemaField) -> FieldDefinition: """Convert a SchemaField (pure metadata) to a FieldDefinition (UI-specific). Args: @@ -136,7 +136,7 @@ def schema_field_to_definition(schema_field: "SchemaField") -> FieldDefinition: ) -def schema_to_field_definitions(schema: "ConnectionSchema") -> list[FieldDefinition]: +def schema_to_field_definitions(schema: ConnectionSchema) -> list[FieldDefinition]: """Convert a ConnectionSchema to a list of FieldDefinitions. Args: diff --git a/sqlit/install_strategy.py b/sqlit/install_strategy.py new file mode 100644 index 00000000..738310fe --- /dev/null +++ b/sqlit/install_strategy.py @@ -0,0 +1,188 @@ +"""Detection for how sqlit-tui should suggest/install optional Python drivers. + +This module intentionally avoids depending on Textual or other app layers so it +can be used from adapters, services, and UI screens. +""" + +from __future__ import annotations + +import importlib.util +import os +import site +import sys +import sysconfig +from dataclasses import dataclass +from pathlib import Path + + +@dataclass(frozen=True) +class InstallStrategy: + """Represents how to install optional Python dependencies for the running app.""" + + kind: str + can_auto_install: bool + manual_instructions: str + auto_install_command: list[str] | None = None + reason_unavailable: str | None = None + + +def _in_venv() -> bool: + if os.environ.get("VIRTUAL_ENV"): + return True + base_prefix = getattr(sys, "base_prefix", sys.prefix) + return sys.prefix != base_prefix + + +def _is_pipx() -> bool: + pipx_override = os.environ.get("SQLIT_MOCK_PIPX", "").strip().lower() + if pipx_override in {"1", "true", "yes", "pipx"}: + return True + if pipx_override in {"0", "false", "no", "pip", "unknown"}: + return False + + exe = sys.executable.lower() + return "pipx" in exe or "/pipx/venvs/" in exe or "\\pipx\\venvs\\" in exe + + +def _is_unknown_install() -> bool: + """Check if we should mock an unknown installation method (e.g., uvx).""" + return os.environ.get("SQLIT_MOCK_PIPX", "").strip().lower() == "unknown" + + +def _pep668_externally_managed() -> bool: + if _in_venv(): + return False + + candidates: list[str] = [] + for key in ("stdlib", "platstdlib"): + try: + value = sysconfig.get_path(key) + except Exception: + value = None + if value: + candidates.append(value) + + for stdlib_path in candidates: + marker = Path(stdlib_path) / "EXTERNALLY-MANAGED" + if marker.exists(): + return True + return False + + +def _pip_available() -> bool: + return importlib.util.find_spec("pip") is not None + + +def _user_site_enabled() -> bool: + # site.ENABLE_USER_SITE already accounts for PYTHONNOUSERSITE and -s/-S in most cases. + try: + return bool(site.ENABLE_USER_SITE) + except Exception: + return False + + +def _install_paths_writable() -> bool: + try: + paths = sysconfig.get_paths() + except Exception: + return False + + for key in ("purelib", "platlib"): + value = paths.get(key) + if not value: + continue + path = Path(value) + # If the directory doesn't exist, check whether we can create it under its parent. + probe = path if path.exists() else path.parent + if probe.exists() and os.access(probe, os.W_OK): + return True + return False + + +def _format_manual_instructions(package_name: str, reason: str) -> str: + """Format manual installation instructions with rich markup.""" + return ( + f"{reason}\n\n" + f"[bold]Install the driver using your preferred package manager:[/]\n\n" + f" [cyan]pip[/] pip install {package_name}\n" + f" [cyan]pipx[/] pipx inject sqlit-tui {package_name}\n" + f" [cyan]uv[/] uv pip install {package_name}\n" + f" [cyan]uvx[/] uvx --with {package_name} sqlit-tui\n" + f" [cyan]poetry[/] poetry add {package_name}\n" + f" [cyan]pdm[/] pdm add {package_name}\n" + f" [cyan]conda[/] conda install {package_name}" + ) + + +def detect_strategy(*, extra_name: str, package_name: str) -> InstallStrategy: + """Detect the best installation strategy for optional driver dependencies.""" + if _is_unknown_install(): + return InstallStrategy( + kind="unknown", + can_auto_install=False, + manual_instructions=_format_manual_instructions( + package_name, + "Unable to detect how sqlit was installed.", + ), + reason_unavailable="Unable to detect installation method.", + ) + + if _is_pipx(): + cmd = ["pipx", "inject", "sqlit-tui", package_name] + return InstallStrategy( + kind="pipx", + can_auto_install=True, + manual_instructions="pipx inject sqlit-tui " + package_name, + auto_install_command=cmd, + ) + + if _pep668_externally_managed(): + return InstallStrategy( + kind="externally-managed", + can_auto_install=False, + manual_instructions=_format_manual_instructions( + package_name, + "This Python environment is externally managed (PEP 668).", + ), + reason_unavailable="Externally managed Python environment (PEP 668).", + ) + + if not _pip_available(): + return InstallStrategy( + kind="no-pip", + can_auto_install=False, + manual_instructions=_format_manual_instructions( + package_name, + "pip is not available for this Python interpreter.", + ), + reason_unavailable="pip is not available.", + ) + + pip_cmd = [sys.executable, "-m", "pip", "install"] + if _in_venv() or _install_paths_writable(): + cmd = [*pip_cmd, package_name] + return InstallStrategy( + kind="pip", + can_auto_install=True, + manual_instructions=f"{sys.executable} -m pip install {package_name}", + auto_install_command=cmd, + ) + + if _user_site_enabled(): + cmd = [*pip_cmd, "--user", package_name] + return InstallStrategy( + kind="pip-user", + can_auto_install=True, + manual_instructions=f"{sys.executable} -m pip install --user {package_name}", + auto_install_command=cmd, + ) + + return InstallStrategy( + kind="pip-unwritable", + can_auto_install=False, + manual_instructions=_format_manual_instructions( + package_name, + "This Python environment is not writable and user-site installs are disabled.", + ), + reason_unavailable="Python environment not writable and user-site disabled.", + ) diff --git a/sqlit/keymap.py b/sqlit/keymap.py index 3798bcd9..ed31359f 100644 --- a/sqlit/keymap.py +++ b/sqlit/keymap.py @@ -16,8 +16,8 @@ from __future__ import annotations from abc import ABC, abstractmethod -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Callable +from dataclasses import dataclass +from typing import TYPE_CHECKING if TYPE_CHECKING: pass @@ -41,6 +41,7 @@ class ActionKeyDef: key: str # The key to press action: str # The action name context: str | None = None # Optional context hint (for documentation) + guard: str | None = None # Guard name (resolved at runtime) class KeymapProvider(ABC): @@ -85,13 +86,9 @@ def get_leader_commands(self) -> list[LeaderCommandDef]: LeaderCommandDef("f", "toggle_fullscreen", "Toggle Maximize", "View"), # Connection LeaderCommandDef("c", "show_connection_picker", "Connect", "Connection"), - LeaderCommandDef( - "x", "disconnect", "Disconnect", "Connection", guard="has_connection" - ), + LeaderCommandDef("x", "disconnect", "Disconnect", "Connection", guard="has_connection"), # Actions - LeaderCommandDef( - "z", "cancel_operation", "Cancel", "Actions", guard="query_executing" - ), + LeaderCommandDef("z", "cancel_operation", "Cancel", "Actions", guard="query_executing"), LeaderCommandDef("t", "change_theme", "Change Theme", "Actions"), LeaderCommandDef("h", "show_help", "Help", "Actions"), LeaderCommandDef("q", "quit", "Quit", "Actions"), @@ -126,13 +123,14 @@ def get_action_keys(self) -> list[ActionKeyDef]: ActionKeyDef("d", "clear_query", "query_normal"), ActionKeyDef("n", "new_query", "query_normal"), ActionKeyDef("h", "show_history", "query_normal"), + ActionKeyDef("y", "copy_context", "query_normal"), # Results ActionKeyDef("v", "view_cell", "results"), - ActionKeyDef("y", "copy_cell", "results"), + ActionKeyDef("y", "copy_context", "results"), ActionKeyDef("Y", "copy_row", "results"), ActionKeyDef("a", "copy_results", "results"), - # Cancel - ActionKeyDef("ctrl+c", "cancel_operation", "global"), + # Cancel (only when query executing) + ActionKeyDef("ctrl+z", "cancel_operation", "global", guard="query_executing"), ] diff --git a/sqlit/mock_settings.py b/sqlit/mock_settings.py new file mode 100644 index 00000000..6d92001f --- /dev/null +++ b/sqlit/mock_settings.py @@ -0,0 +1,275 @@ +"""Helpers for loading mock configuration from settings JSON.""" + +from __future__ import annotations + +import os +from dataclasses import fields +from typing import Any + +from .config import ConnectionConfig +from .db.adapters.base import ColumnInfo +from .mocks import MockDatabaseAdapter, MockProfile, get_mock_profile + + +def apply_mock_environment(settings: dict[str, Any]) -> None: + """Apply environment-based mock settings for driver/install behavior.""" + mock_settings = settings.get("mock") + if not isinstance(mock_settings, dict) or not mock_settings.get("enabled"): + return + + drivers = mock_settings.get("drivers", {}) + if isinstance(drivers, dict): + missing = drivers.get("missing") + if isinstance(missing, list): + os.environ["SQLIT_MOCK_MISSING_DRIVERS"] = ",".join(str(item).strip() for item in missing if str(item).strip()) + elif isinstance(missing, str) and missing.strip(): + os.environ["SQLIT_MOCK_MISSING_DRIVERS"] = missing.strip() + elif missing == []: + os.environ.pop("SQLIT_MOCK_MISSING_DRIVERS", None) + + install_result = str(drivers.get("install_result", "")).strip().lower() + if install_result in {"success", "fail"}: + os.environ["SQLIT_MOCK_INSTALL_RESULT"] = install_result + elif install_result == "real": + os.environ.pop("SQLIT_MOCK_INSTALL_RESULT", None) + + pipx = str(drivers.get("pipx", "")).strip().lower() + if pipx in {"pipx", "pip", "unknown"}: + os.environ["SQLIT_MOCK_PIPX"] = pipx + elif pipx == "auto": + os.environ.pop("SQLIT_MOCK_PIPX", None) + + +def build_mock_profile_from_settings(settings: dict[str, Any]) -> MockProfile | None: + """Build a MockProfile from settings JSON.""" + mock_settings = settings.get("mock") + if not isinstance(mock_settings, dict): + return None + + if not mock_settings.get("enabled"): + return None + + profile_name = str(mock_settings.get("profile") or "settings") + base_profile = get_mock_profile(profile_name) or MockProfile(name=profile_name) + + connections = base_profile.connections + if "connections" in mock_settings: + connections = _parse_connections(mock_settings.get("connections")) + + adapters = dict(base_profile.adapters) + adapters_config = mock_settings.get("adapters") + if isinstance(adapters_config, dict): + for db_type, adapter_config in adapters_config.items(): + if not isinstance(adapter_config, dict): + continue + adapters[str(db_type)] = _build_adapter_from_settings(str(db_type), adapter_config) + + use_default = base_profile.use_default_adapters + if "use_default_adapters" in mock_settings: + use_default = bool(mock_settings.get("use_default_adapters")) + + return MockProfile( + name=profile_name, + connections=connections, + adapters=adapters, + use_default_adapters=use_default, + ) + + +def _parse_connections(raw: Any) -> list[ConnectionConfig]: + if not isinstance(raw, list): + return [] + allowed_fields = {f.name for f in fields(ConnectionConfig)} + connections: list[ConnectionConfig] = [] + for item in raw: + if not isinstance(item, dict): + continue + payload = {k: v for k, v in item.items() if k in allowed_fields} + try: + connections.append(ConnectionConfig(**payload)) + except TypeError: + continue + return connections + + +def _build_adapter_from_settings(db_type: str, config: dict[str, Any]) -> MockDatabaseAdapter: + name = str(config.get("name") or db_type.title()) + default_schema = str(config.get("default_schema") or "") + + connect = config.get("connect") if isinstance(config.get("connect"), dict) else {} + connect_result = str(connect.get("result") or "success") + connect_error = str(connect.get("error_message") or "Connection failed") + required_fields = connect.get("required_fields") + if not isinstance(required_fields, list): + required_fields = [] + required_fields = [str(field) for field in required_fields if str(field).strip()] + allowed = connect.get("allowed") + if not isinstance(allowed, list): + allowed = [] + allowed = [item for item in allowed if isinstance(item, dict)] + auth_error = str(connect.get("auth_error_message") or "Authentication failed") + + tables = _parse_table_list(config.get("tables")) + views = _parse_table_list(config.get("views")) + columns: dict[str, list[ColumnInfo]] = {} + query_results: dict[str, tuple[list[str], list[tuple]]] = {} + + schemas = config.get("schemas") + if isinstance(schemas, dict): + for schema_name, schema_config in schemas.items(): + if not isinstance(schema_config, dict): + continue + schema = str(schema_name) + _ingest_schema(schema, schema_config, tables, views, columns, query_results) + + raw_columns = config.get("columns") + if isinstance(raw_columns, dict): + for key, value in raw_columns.items(): + columns[str(key)] = _parse_columns(value) + + raw_query_results = config.get("query_results") + if isinstance(raw_query_results, dict): + for pattern, result in raw_query_results.items(): + parsed = _parse_query_result(result) + if parsed: + query_results[str(pattern)] = parsed + + default_query_result = None + raw_default = config.get("default_query_result") + if isinstance(raw_default, dict): + parsed_default = _parse_query_result(raw_default) + if parsed_default: + default_query_result = parsed_default + + query_delay = 0.0 + raw_delay = config.get("query_delay") + if isinstance(raw_delay, (int, float)): + query_delay = float(raw_delay) + + return MockDatabaseAdapter( + name=name, + tables=tables, + views=views, + columns=columns, + query_results=query_results, + default_schema=default_schema, + default_query_result=default_query_result, + connect_result=connect_result, + connect_error=connect_error, + required_fields=required_fields, + allowed_connections=allowed, + auth_error=auth_error, + query_delay=query_delay, + ) + + +def _ingest_schema( + schema: str, + schema_config: dict[str, Any], + tables: list[tuple[str, str]], + views: list[tuple[str, str]], + columns: dict[str, list[ColumnInfo]], + query_results: dict[str, tuple[list[str], list[tuple]]], +) -> None: + schema_tables = schema_config.get("tables") + if isinstance(schema_tables, dict): + for table_name, table_config in schema_tables.items(): + if not isinstance(table_config, dict): + continue + table = str(table_name) + tables.append((schema, table)) + cols = _parse_columns(table_config.get("columns")) + if cols: + columns[f"{schema}.{table}"] = cols + rows = _parse_rows(table_config.get("rows")) + if rows and cols: + column_names = [col.name for col in cols] + _add_table_query_results(schema, table, column_names, rows, query_results) + table_query_results = table_config.get("query_results") + if isinstance(table_query_results, dict): + for pattern, result in table_query_results.items(): + parsed = _parse_query_result(result) + if parsed: + query_results[str(pattern)] = parsed + + schema_views = schema_config.get("views") + if isinstance(schema_views, dict): + for view_name, view_config in schema_views.items(): + if not isinstance(view_config, dict): + continue + view = str(view_name) + views.append((schema, view)) + cols = _parse_columns(view_config.get("columns")) + if cols: + columns[f"{schema}.{view}"] = cols + rows = _parse_rows(view_config.get("rows")) + if rows and cols: + column_names = [col.name for col in cols] + _add_table_query_results(schema, view, column_names, rows, query_results) + + +def _parse_table_list(raw: Any) -> list[tuple[str, str]]: + if not isinstance(raw, list): + return [] + tables: list[tuple[str, str]] = [] + for item in raw: + if not isinstance(item, dict): + continue + schema = str(item.get("schema") or "") + name = str(item.get("name") or "") + if name: + tables.append((schema, name)) + return tables + + +def _parse_columns(raw: Any) -> list[ColumnInfo]: + if not isinstance(raw, list): + return [] + columns: list[ColumnInfo] = [] + for item in raw: + if isinstance(item, dict): + name = str(item.get("name") or "") + data_type = str(item.get("type") or "") + if name: + columns.append(ColumnInfo(name=name, data_type=data_type or "text")) + return columns + + +def _parse_rows(raw: Any) -> list[tuple]: + if not isinstance(raw, list): + return [] + rows: list[tuple] = [] + for row in raw: + if isinstance(row, list): + rows.append(tuple(row)) + elif isinstance(row, tuple): + rows.append(row) + return rows + + +def _parse_query_result(raw: Any) -> tuple[list[str], list[tuple]] | None: + if not isinstance(raw, dict): + return None + columns = raw.get("columns") + rows = raw.get("rows") + if not isinstance(columns, list) or not isinstance(rows, list): + return None + column_names = [str(col) for col in columns] + return column_names, _parse_rows(rows) + + +def _add_table_query_results( + schema: str, + table: str, + columns: list[str], + rows: list[tuple], + query_results: dict[str, tuple[list[str], list[tuple]]], +) -> None: + patterns = [ + f'"{schema}"."{table}"', + f"{schema}.{table}", + f'"{table}"', + table, + ] + for pattern in patterns: + query_results.setdefault(pattern, (columns, rows)) diff --git a/sqlit/mocks.py b/sqlit/mocks.py index 39a86612..866eb01b 100644 --- a/sqlit/mocks.py +++ b/sqlit/mocks.py @@ -4,10 +4,13 @@ sqlit --mock=sqlite-demo # Pre-configured SQLite with demo data sqlit --mock=empty # Empty connections, but mock adapters available sqlit --mock=multi-db # Multiple database connections + sqlit --mock=driver-install-success --mock-missing-drivers=postgresql --mock-install=success + sqlit --mock=driver-install-fail --mock-missing-drivers=mysql --mock-install=fail """ from __future__ import annotations +from collections.abc import Callable from dataclasses import dataclass, field from typing import Any @@ -18,13 +21,13 @@ class MockConnection: """Mock database connection object.""" - def __init__(self): + def __init__(self) -> None: self.closed = False - def close(self): + def close(self) -> None: self.closed = True - def cursor(self): + def cursor(self) -> MockCursor: return MockCursor() @@ -61,6 +64,12 @@ def __init__( query_results: dict[str, tuple[list[str], list[tuple]]] | None = None, default_schema: str = "", default_query_result: tuple[list[str], list[tuple]] | None = None, + connect_result: str = "success", + connect_error: str = "Connection failed", + required_fields: list[str] | None = None, + allowed_connections: list[dict[str, Any]] | None = None, + auth_error: str = "Authentication failed", + query_delay: float = 0.0, ): self._name = name self._tables = tables or [] @@ -72,6 +81,21 @@ def __init__( ["id", "name"], [(1, "Sample Row 1"), (2, "Sample Row 2")], ) + self._connect_result = (connect_result or "success").strip().lower() + self._connect_error = connect_error or "Connection failed" + self._required_fields = required_fields or [] + self._allowed_connections = allowed_connections or [] + self._auth_error = auth_error or "Authentication failed" + # Use provided delay or fall back to environment variable + if query_delay > 0: + self._query_delay = query_delay + else: + import os + env_delay = os.environ.get("SQLIT_MOCK_QUERY_DELAY", "") + try: + self._query_delay = float(env_delay) if env_delay else 0.0 + except ValueError: + self._query_delay = 0.0 @property def name(self) -> str: @@ -90,6 +114,18 @@ def supports_stored_procedures(self) -> bool: return False def connect(self, config: ConnectionConfig) -> Any: + if self._connect_result not in {"success", "ok", "pass"}: + raise Exception(self._connect_error) + + if self._required_fields: + missing = [field for field in self._required_fields if not getattr(config, field, None)] + if missing: + message = self._connect_error or f"Missing required fields: {', '.join(missing)}" + raise Exception(message) + + if self._allowed_connections: + if not any(_matches_connection_rule(config, rule) for rule in self._allowed_connections): + raise Exception(self._auth_error) return MockConnection() def get_databases(self, conn: Any) -> list[str]: @@ -104,6 +140,8 @@ def get_views(self, conn: Any, database: str | None = None) -> list[tuple[str, s def get_columns( self, conn: Any, table: str, database: str | None = None, schema: str | None = None ) -> list[ColumnInfo]: + if schema: + return self._columns.get(f"{schema}.{table}", self._columns.get(table, [])) return self._columns.get(table, []) def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: @@ -112,17 +150,18 @@ def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: def quote_identifier(self, name: str) -> str: return f'"{name}"' - def build_select_query( - self, table: str, limit: int, database: str | None = None, schema: str | None = None - ) -> str: + def build_select_query(self, table: str, limit: int, database: str | None = None, schema: str | None = None) -> str: if schema: return f'SELECT * FROM "{schema}"."{table}" LIMIT {limit}' return f'SELECT * FROM "{table}" LIMIT {limit}' - def execute_query( - self, conn: Any, query: str, max_rows: int | None = None - ) -> tuple[list[str], list[tuple], bool]: + def execute_query(self, conn: Any, query: str, max_rows: int | None = None) -> tuple[list[str], list[tuple], bool]: """Execute a query and return (columns, rows, truncated).""" + import time + + if self._query_delay > 0: + time.sleep(self._query_delay) + query_lower = query.lower().strip() # Check for specific query results (case-insensitive pattern matching) @@ -147,6 +186,7 @@ def execute_non_query(self, conn: Any, query: str) -> int: # Default Mock Adapters - used when profiles don't define their own # ============================================================================= + def create_default_sqlite_adapter() -> MockDatabaseAdapter: """Create a default SQLite mock adapter with demo data.""" return MockDatabaseAdapter( @@ -262,7 +302,7 @@ def create_default_mysql_adapter() -> MockDatabaseAdapter: # Registry of default adapters by database type -DEFAULT_MOCK_ADAPTERS: dict[str, callable] = { +DEFAULT_MOCK_ADAPTERS: dict[str, Callable[[], MockDatabaseAdapter]] = { "sqlite": create_default_sqlite_adapter, "postgresql": create_default_postgresql_adapter, "mysql": create_default_mysql_adapter, @@ -278,10 +318,18 @@ def get_default_mock_adapter(db_type: str) -> MockDatabaseAdapter: return MockDatabaseAdapter(name=f"Mock{db_type.title()}") +def _matches_connection_rule(config: ConnectionConfig, rule: dict[str, Any]) -> bool: + for key, value in rule.items(): + if getattr(config, key, None) != value: + return False + return True + + # ============================================================================= # Mock Profiles # ============================================================================= + @dataclass class MockProfile: """A mock profile containing connections and adapter configuration.""" @@ -362,11 +410,53 @@ def _create_multi_db_profile() -> MockProfile: ) +def _create_driver_install_success_profile() -> MockProfile: + """Profile intended for demoing the driver install UX.""" + connections = [ + ConnectionConfig( + name="PostgreSQL (missing driver)", + db_type="postgresql", + server="localhost", + port="5432", + database="postgres", + username="user", + ), + ] + return MockProfile( + name="driver-install-success", + connections=connections, + adapters={}, + use_default_adapters=True, + ) + + +def _create_driver_install_fail_profile() -> MockProfile: + """Profile intended for demoing the driver install failure UX.""" + connections = [ + ConnectionConfig( + name="MySQL (missing driver)", + db_type="mysql", + server="localhost", + port="3306", + database="test_sqlit", + username="user", + ), + ] + return MockProfile( + name="driver-install-fail", + connections=connections, + adapters={}, + use_default_adapters=True, + ) + + # Registry of available mock profiles -MOCK_PROFILES: dict[str, callable] = { +MOCK_PROFILES: dict[str, Callable[[], MockProfile]] = { "sqlite-demo": _create_sqlite_demo_profile, "empty": _create_empty_profile, "multi-db": _create_multi_db_profile, + "driver-install-success": _create_driver_install_success_profile, + "driver-install-fail": _create_driver_install_fail_profile, } diff --git a/sqlit/screens.py b/sqlit/screens.py deleted file mode 100644 index 1734423f..00000000 --- a/sqlit/screens.py +++ /dev/null @@ -1,24 +0,0 @@ -"""Modal screens for sqlit. - -This module re-exports from sqlit.ui.screens for backward compatibility. -New code should import directly from sqlit.ui.screens. -""" - -# Re-export everything from the new location for backward compatibility -from .ui.screens import ( - ConfirmScreen, - ConnectionScreen, - DriverSetupScreen, - HelpScreen, - QueryHistoryScreen, - ValueViewScreen, -) - -__all__ = [ - "ConfirmScreen", - "ConnectionScreen", - "DriverSetupScreen", - "HelpScreen", - "QueryHistoryScreen", - "ValueViewScreen", -] diff --git a/sqlit/services/__init__.py b/sqlit/services/__init__.py index 5c710cc0..7a308c59 100644 --- a/sqlit/services/__init__.py +++ b/sqlit/services/__init__.py @@ -7,6 +7,7 @@ - QueryService: Unified query execution with history tracking - ConnectionSession: Connection lifecycle management with cleanup guarantees - DatabaseExecutor: Serialized database operation execution +- CredentialsService: Secure credential storage using OS keyring Protocols: - AdapterProtocol: Interface for database adapters @@ -16,6 +17,14 @@ """ from .cancellable import CancellableQuery +from .credentials import ( + CredentialsService, + KeyringCredentialsService, + PlaintextCredentialsService, + get_credentials_service, + reset_credentials_service, + set_credentials_service, +) from .executor import DatabaseExecutor from .protocols import ( AdapterFactoryProtocol, @@ -40,6 +49,13 @@ "DatabaseExecutor", # Cancellable query "CancellableQuery", + # Credentials service + "CredentialsService", + "KeyringCredentialsService", + "PlaintextCredentialsService", + "get_credentials_service", + "set_credentials_service", + "reset_credentials_service", # Protocols "AdapterProtocol", "AdapterFactoryProtocol", diff --git a/sqlit/services/cancellable.py b/sqlit/services/cancellable.py index e02b00fa..ef698f9e 100644 --- a/sqlit/services/cancellable.py +++ b/sqlit/services/cancellable.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: from ..config import ConnectionConfig from ..db import DatabaseAdapter - from .query import QueryResult, NonQueryResult + from .query import NonQueryResult, QueryResult @dataclass @@ -44,8 +44,8 @@ class CancellableQuery: """ sql: str - config: "ConnectionConfig" - adapter: "DatabaseAdapter" + config: ConnectionConfig + adapter: DatabaseAdapter def __post_init__(self) -> None: """Initialize internal state.""" @@ -58,7 +58,7 @@ def __post_init__(self) -> None: def execute( self, max_rows: int | None = None, - ) -> "QueryResult | NonQueryResult": + ) -> QueryResult | NonQueryResult: """Execute the query on a dedicated connection. Creates a new connection, executes the query, and returns the result. @@ -102,9 +102,7 @@ def execute( # Execute query using adapter methods if is_select_query(self.sql): - columns, rows, truncated = self.adapter.execute_query( - self._connection, self.sql, max_rows - ) + columns, rows, truncated = self.adapter.execute_query(self._connection, self.sql, max_rows) return QueryResult( columns=columns, rows=rows, @@ -113,9 +111,7 @@ def execute( ) else: # Non-SELECT query - rows_affected = self.adapter.execute_non_query( - self._connection, self.sql - ) + rows_affected = self.adapter.execute_non_query(self._connection, self.sql) return NonQueryResult(rows_affected=rows_affected) finally: diff --git a/sqlit/services/credentials.py b/sqlit/services/credentials.py new file mode 100644 index 00000000..a9418478 --- /dev/null +++ b/sqlit/services/credentials.py @@ -0,0 +1,369 @@ +"""Credentials service for secure password storage. + +This module provides an abstraction for storing and retrieving credentials +securely. The default implementation uses the OS keyring (macOS Keychain, +Windows Credential Locker, Linux Secret Service). A plaintext fallback +is provided for environments without keyring support (with user consent), +and an in-memory fallback is provided for testing. +""" + +from __future__ import annotations + +import secrets +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +from ..stores.base import CONFIG_DIR, JSONFileStore + +if TYPE_CHECKING: + pass + +# Service name used for keyring storage +KEYRING_SERVICE_NAME = "sqlit" + +# Settings key controlling whether plaintext credential storage is allowed. +ALLOW_PLAINTEXT_CREDENTIALS_SETTING = "allow_plaintext_credentials" + + +def is_keyring_usable() -> bool: + """Return True if a usable keyring backend appears to be available.""" + try: + import keyring + except ImportError: + return False + + try: + backend = keyring.get_keyring() + module_name = getattr(backend, "__module__", "") or "" + priority = getattr(backend, "priority", None) + if "keyring.backends.fail" in module_name: + return False + if isinstance(priority, (int, float)) and priority <= 0: + return False + + # Minimal probe: read-only call to surface obvious misconfiguration. + keyring.get_password(KEYRING_SERVICE_NAME, f"probe:{secrets.token_hex(8)}") + return True + except Exception: + return False + +class CredentialsService(ABC): + """Abstract base class for credential storage services.""" + + @abstractmethod + def get_password(self, connection_name: str) -> str | None: + """Retrieve the database password for a connection. + + Args: + connection_name: The unique name of the connection. + + Returns: + The password string, or None if not found. + """ + ... + + @abstractmethod + def set_password(self, connection_name: str, password: str) -> None: + """Store the database password for a connection. + + Args: + connection_name: The unique name of the connection. + password: The password to store. + """ + ... + + @abstractmethod + def delete_password(self, connection_name: str) -> None: + """Delete the database password for a connection. + + Args: + connection_name: The unique name of the connection. + """ + ... + + @abstractmethod + def get_ssh_password(self, connection_name: str) -> str | None: + """Retrieve the SSH password for a connection. + + Args: + connection_name: The unique name of the connection. + + Returns: + The SSH password string, or None if not found. + """ + ... + + @abstractmethod + def set_ssh_password(self, connection_name: str, password: str) -> None: + """Store the SSH password for a connection. + + Args: + connection_name: The unique name of the connection. + password: The SSH password to store. + """ + ... + + @abstractmethod + def delete_ssh_password(self, connection_name: str) -> None: + """Delete the SSH password for a connection. + + Args: + connection_name: The unique name of the connection. + """ + ... + + def rename_connection(self, old_name: str, new_name: str) -> None: + """Rename credentials when a connection is renamed. + + Args: + old_name: The old connection name. + new_name: The new connection name. + """ + # Get existing credentials + db_password = self.get_password(old_name) + ssh_password = self.get_ssh_password(old_name) + + # Store under new name + if db_password: + self.set_password(new_name, db_password) + if ssh_password: + self.set_ssh_password(new_name, ssh_password) + + # Delete old credentials + self.delete_password(old_name) + self.delete_ssh_password(old_name) + + def delete_all_for_connection(self, connection_name: str) -> None: + """Delete all credentials for a connection. + + Args: + connection_name: The unique name of the connection. + """ + self.delete_password(connection_name) + self.delete_ssh_password(connection_name) + + +class KeyringCredentialsService(CredentialsService): + """Credentials service using OS keyring for secure storage. + + This implementation uses the `keyring` library to store passwords + in the OS-provided secure storage: + - macOS: Keychain + - Windows: Credential Locker + - Linux: Secret Service (GNOME Keyring, KDE Wallet, etc.) + + The keyring module is lazy-loaded to avoid import overhead when + not needed. + """ + + def __init__(self) -> None: + self._keyring = None + + def _get_keyring(self): + if self._keyring is None: + import keyring + + self._keyring = keyring + return self._keyring + + def _make_key(self, connection_name: str, key_type: str) -> str: + """Create a unique key for storage. + + Args: + connection_name: The connection name. + key_type: Type of credential ('db' or 'ssh'). + + Returns: + A unique key string. + """ + return f"{connection_name}:{key_type}" + + def get_password(self, connection_name: str) -> str | None: + try: + keyring = self._get_keyring() + key = self._make_key(connection_name, "db") + return keyring.get_password(KEYRING_SERVICE_NAME, key) + except Exception: + return None + + def set_password(self, connection_name: str, password: str) -> None: + if password is None: + self.delete_password(connection_name) + return + try: + keyring = self._get_keyring() + key = self._make_key(connection_name, "db") + keyring.set_password(KEYRING_SERVICE_NAME, key, password) + except Exception: + pass + + def delete_password(self, connection_name: str) -> None: + try: + keyring = self._get_keyring() + key = self._make_key(connection_name, "db") + keyring.delete_password(KEYRING_SERVICE_NAME, key) + except Exception: + pass + + def get_ssh_password(self, connection_name: str) -> str | None: + try: + keyring = self._get_keyring() + key = self._make_key(connection_name, "ssh") + return keyring.get_password(KEYRING_SERVICE_NAME, key) + except Exception: + return None + + def set_ssh_password(self, connection_name: str, password: str) -> None: + if password is None: + self.delete_ssh_password(connection_name) + return + try: + keyring = self._get_keyring() + key = self._make_key(connection_name, "ssh") + keyring.set_password(KEYRING_SERVICE_NAME, key, password) + except Exception: + pass + + def delete_ssh_password(self, connection_name: str) -> None: + try: + keyring = self._get_keyring() + key = self._make_key(connection_name, "ssh") + keyring.delete_password(KEYRING_SERVICE_NAME, key) + except Exception: + pass + + +class PlaintextCredentialsService(CredentialsService): + """Credentials service storing passwords in memory (for testing). + + WARNING: This implementation stores passwords in memory only. + It does NOT persist passwords. Use only for testing. + """ + + def __init__(self) -> None: + self._passwords: dict[str, str] = {} + self._ssh_passwords: dict[str, str] = {} + + def get_password(self, connection_name: str) -> str | None: + return self._passwords.get(connection_name) + + def set_password(self, connection_name: str, password: str) -> None: + if password is not None: + self._passwords[connection_name] = password + else: + self.delete_password(connection_name) + + def delete_password(self, connection_name: str) -> None: + self._passwords.pop(connection_name, None) + + def get_ssh_password(self, connection_name: str) -> str | None: + return self._ssh_passwords.get(connection_name) + + def set_ssh_password(self, connection_name: str, password: str) -> None: + if password is not None: + self._ssh_passwords[connection_name] = password + else: + self.delete_ssh_password(connection_name) + + def delete_ssh_password(self, connection_name: str) -> None: + self._ssh_passwords.pop(connection_name, None) + + +class PlaintextFileCredentialsService(CredentialsService): + """Credentials service storing passwords in a local file. + + WARNING: This stores secrets in plaintext on disk. The credentials file is + created under the config dir with restrictive permissions (0700/0600). + """ + + def __init__(self) -> None: + self._store = JSONFileStore(CONFIG_DIR / "credentials.json") + + def _read_all(self) -> dict[str, str]: + data = self._store._read_json() + return data if isinstance(data, dict) else {} + + def _write_all(self, data: dict[str, str]) -> None: + self._store._write_json(data) + + def _key(self, connection_name: str, kind: str) -> str: + return f"{connection_name}:{kind}" + + def get_password(self, connection_name: str) -> str | None: + return self._read_all().get(self._key(connection_name, "db")) + + def set_password(self, connection_name: str, password: str) -> None: + if password is None: + self.delete_password(connection_name) + return + data = self._read_all() + data[self._key(connection_name, "db")] = password + self._write_all(data) + + def delete_password(self, connection_name: str) -> None: + data = self._read_all() + data.pop(self._key(connection_name, "db"), None) + self._write_all(data) + + def get_ssh_password(self, connection_name: str) -> str | None: + return self._read_all().get(self._key(connection_name, "ssh")) + + def set_ssh_password(self, connection_name: str, password: str) -> None: + if password is None: + self.delete_ssh_password(connection_name) + return + data = self._read_all() + data[self._key(connection_name, "ssh")] = password + self._write_all(data) + + def delete_ssh_password(self, connection_name: str) -> None: + data = self._read_all() + data.pop(self._key(connection_name, "ssh"), None) + self._write_all(data) + + +_credentials_service: CredentialsService | None = None + + +def get_credentials_service() -> CredentialsService: + """Get the global credentials service instance. + + Returns the keyring-based service by default. If keyring isn't usable, + falls back to a plaintext file store if user consent is recorded in + settings; otherwise falls back to an in-memory store (not persisted). + + Returns: + The credentials service instance. + """ + global _credentials_service + if _credentials_service is None: + if is_keyring_usable(): + _credentials_service = KeyringCredentialsService() + else: + from ..stores.settings import load_settings + + settings = load_settings() + allow_plaintext = bool(settings.get(ALLOW_PLAINTEXT_CREDENTIALS_SETTING)) + _credentials_service = PlaintextFileCredentialsService() if allow_plaintext else PlaintextCredentialsService() + return _credentials_service + + +def set_credentials_service(service: CredentialsService | None) -> None: + """Set the global credentials service instance. + + This is primarily useful for testing to inject a mock service. + + Args: + service: The credentials service to use, or None to reset. + """ + global _credentials_service + _credentials_service = service + + +def reset_credentials_service() -> None: + """Reset the credentials service to default. + + Useful for testing to ensure a clean state. + """ + global _credentials_service + _credentials_service = None diff --git a/sqlit/services/executor.py b/sqlit/services/executor.py index 6a9d2f4e..22bbb3dc 100644 --- a/sqlit/services/executor.py +++ b/sqlit/services/executor.py @@ -9,8 +9,9 @@ import asyncio import threading +from collections.abc import Callable from concurrent.futures import Future, ThreadPoolExecutor -from typing import TYPE_CHECKING, Any, Callable, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar if TYPE_CHECKING: from .session import ConnectionSession @@ -38,7 +39,7 @@ class DatabaseExecutor: session: The ConnectionSession this executor is bound to. """ - def __init__(self, session: "ConnectionSession"): + def __init__(self, session: ConnectionSession): """Initialize the executor. Args: @@ -54,7 +55,7 @@ def __init__(self, session: "ConnectionSession"): self._shutdown = False @property - def session(self) -> "ConnectionSession": + def session(self) -> ConnectionSession: """Get the session this executor is bound to.""" return self._session diff --git a/sqlit/services/installer.py b/sqlit/services/installer.py new file mode 100644 index 00000000..c4cc489f --- /dev/null +++ b/sqlit/services/installer.py @@ -0,0 +1,182 @@ +"""Service for handling automatic package installation.""" + +from __future__ import annotations + +import os +import subprocess +import sys +import threading +import time +from collections.abc import Callable +from typing import Any, Protocol + +from ..db.exceptions import MissingDriverError +from ..install_strategy import detect_strategy + + +class InstallerApp(Protocol): + def push_screen(self, screen: Any, callback: Any = None, wait_for_dismiss: bool = False) -> Any: ... + def pop_screen(self) -> Any: ... + def call_from_thread(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: ... + def notify(self, message: str, *, severity: str = "information", timeout: float | int | None = None) -> Any: ... + def restart(self) -> Any: ... + + +class Installer: + """Manages the automatic installation of missing drivers.""" + + def __init__(self, app: InstallerApp): + self.app = app + self._active_process: subprocess.Popen[str] | None = None + + def install(self, error: MissingDriverError) -> None: + """Push a loading screen and run installation in a background thread.""" + from ..ui.screens.loading import LoadingScreen + + cancel_event = threading.Event() + self.app.push_screen( + LoadingScreen( + f"Installing {error.driver_name}... (Esc to cancel)", + on_cancel=cancel_event.set, + ) + ) + + def worker() -> None: + result = self._do_install(error, cancel_event) + self.app.call_from_thread(self._on_install_complete, result) + + threading.Thread(target=worker, daemon=True).start() + + def install_in_background( + self, + error: MissingDriverError, + *, + on_complete: Callable[[bool, str, MissingDriverError], None], + ) -> None: + """Run installation in a background thread and report completion on the main thread.""" + + def worker() -> None: + # Reuse the same implementation, but without the modal LoadingScreen. + result = self._do_install(error, threading.Event()) + success, output, err = result + self.app.call_from_thread(on_complete, success, output, err) + + threading.Thread(target=worker, daemon=True).start() + + def _do_install( + self, error: MissingDriverError, cancel_event: threading.Event + ) -> tuple[bool, str, MissingDriverError]: + """ + Synchronous method to be run in a worker thread. + Determines the command and executes it. + """ + mock_install = os.environ.get("SQLIT_MOCK_INSTALL_RESULT", "").strip().lower() + if mock_install in {"success", "ok", "pass"}: + return True, "Mocked success (SQLIT_MOCK_INSTALL_RESULT=success)", error + if mock_install in {"fail", "error"}: + return False, "Mocked failure (SQLIT_MOCK_INSTALL_RESULT=fail)", error + + if os.environ.get("SQLIT_INSTALL_FORCE_FAIL") == "1": + return False, "Forced failure (SQLIT_INSTALL_FORCE_FAIL=1)", error + + strategy = detect_strategy(extra_name=error.extra_name, package_name=error.package_name) + if not strategy.can_auto_install or not strategy.auto_install_command: + reason = strategy.reason_unavailable or "Automatic installation is not available." + return False, f"{reason}\n\n{strategy.manual_instructions}".strip(), error + + command = strategy.auto_install_command + cwd: str | None = None + + if cancel_event.is_set(): + return False, "Installation cancelled by user.", error + + try: + self._active_process = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + cwd=cwd, + ) + + stdout = "" + stderr = "" + + while True: + if cancel_event.is_set(): + try: + self._active_process.terminate() + self._active_process.wait(timeout=5) + except Exception: + try: + self._active_process.kill() + self._active_process.wait(timeout=5) + except Exception: + pass + try: + stdout, stderr = self._active_process.communicate(timeout=1) + except Exception: + pass + return False, "Installation cancelled by user.", error + + try: + stdout, stderr = self._active_process.communicate(timeout=0.1) + break + except subprocess.TimeoutExpired: + time.sleep(0.05) + continue + + rc = self._active_process.returncode + if rc == 0: + return True, stdout, error + return False, stderr or stdout, error + except FileNotFoundError as e: + return False, str(e), error + finally: + self._active_process = None + + def _on_install_complete(self, result: tuple[bool, str, MissingDriverError]) -> None: + """ + Callback executed on the main thread after installation attempt. + """ + from textual.css.stylesheet import StylesheetParseError + + from ..ui.screens.message import MessageScreen + + success, output, error = result + self.app.pop_screen() # Pop the LoadingScreen + + try: + if success: + restart = getattr(self.app, "restart", None) + self.app.push_screen( + MessageScreen( + "Driver installed", + f"{error.driver_name} installed successfully. Please restart to apply.", + enter_label="Restart", + on_enter=restart if callable(restart) else None, + ) + ) + else: + # Keep the manual instructions in the underlying setup screen. + self.app.push_screen( + MessageScreen( + "Couldn't install automatically", + "Couldn't install automatically, please install manually.", + ) + ) + except StylesheetParseError as e: + # Fallback: avoid crashing the app if the stylesheet can’t be reparsed after install. + try: + details = str(e.args[0]) + except Exception: + details = str(e) + print(f"StylesheetParseError while showing install result:\n{details}", file=sys.stderr) + try: + self.app.notify( + "Installation completed, but UI failed to render result. Please restart sqlit-tui.", + severity="warning", + timeout=10, + ) + except Exception: + pass diff --git a/sqlit/services/protocols.py b/sqlit/services/protocols.py index 4c60ac18..a69e6d49 100644 --- a/sqlit/services/protocols.py +++ b/sqlit/services/protocols.py @@ -20,7 +20,7 @@ class AdapterProtocol(Protocol): to execute queries against a database. """ - def connect(self, config: "ConnectionConfig") -> Any: + def connect(self, config: ConnectionConfig) -> Any: """Connect to the database. Args: @@ -31,9 +31,7 @@ def connect(self, config: "ConnectionConfig") -> Any: """ ... - def execute_query( - self, conn: Any, query: str, max_rows: int | None = None - ) -> tuple[list[str], list[tuple], bool]: + def execute_query(self, conn: Any, query: str, max_rows: int | None = None) -> tuple[list[str], list[tuple], bool]: """Execute a SELECT-type query. Args: @@ -116,7 +114,7 @@ class TunnelFactoryProtocol(Protocol): SSH tunnels for database connections. """ - def __call__(self, config: "ConnectionConfig") -> tuple[Any, str, int]: + def __call__(self, config: ConnectionConfig) -> tuple[Any, str, int]: """Create an SSH tunnel if enabled in config. Args: @@ -137,7 +135,7 @@ class ConnectionStoreProtocol(Protocol): database connection configurations. """ - def load(self) -> list["ConnectionConfig"]: + def load(self) -> list[ConnectionConfig]: """Load all saved connections. Returns: @@ -145,7 +143,7 @@ def load(self) -> list["ConnectionConfig"]: """ ... - def save(self, connections: list["ConnectionConfig"]) -> None: + def save(self, connections: list[ConnectionConfig]) -> None: """Save connections. Args: diff --git a/sqlit/services/query.py b/sqlit/services/query.py index f9ee1811..afaf66bb 100644 --- a/sqlit/services/query.py +++ b/sqlit/services/query.py @@ -14,9 +14,7 @@ from .protocols import AdapterProtocol, HistoryStoreProtocol # Query types that return result sets (SELECT-like queries) -SELECT_KEYWORDS = frozenset( - ["SELECT", "WITH", "SHOW", "DESCRIBE", "EXPLAIN", "PRAGMA"] -) +SELECT_KEYWORDS = frozenset(["SELECT", "WITH", "SHOW", "DESCRIBE", "EXPLAIN", "PRAGMA"]) def is_select_query(query: str) -> bool: @@ -60,7 +58,7 @@ class QueryService: If not provided, uses the default HistoryStore singleton. """ - def __init__(self, history_store: "HistoryStoreProtocol | None" = None): + def __init__(self, history_store: HistoryStoreProtocol | None = None): """Initialize the query service. Args: @@ -71,9 +69,9 @@ def __init__(self, history_store: "HistoryStoreProtocol | None" = None): def execute( self, connection: Any, - adapter: "AdapterProtocol", + adapter: AdapterProtocol, query: str, - config: "ConnectionConfig | None" = None, + config: ConnectionConfig | None = None, max_rows: int | None = None, save_to_history: bool = True, ) -> QueryResult | NonQueryResult: @@ -93,6 +91,7 @@ def execute( Raises: Any exceptions raised by the underlying database driver. """ + result: QueryResult | NonQueryResult if is_select_query(query): columns, rows, truncated = adapter.execute_query(connection, query, max_rows) result = QueryResult( diff --git a/sqlit/services/session.py b/sqlit/services/session.py index aafdafaa..b32b387a 100644 --- a/sqlit/services/session.py +++ b/sqlit/services/session.py @@ -6,8 +6,9 @@ from __future__ import annotations +from collections.abc import Callable from dataclasses import replace -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from ..config import ConnectionConfig @@ -40,8 +41,8 @@ class ConnectionSession: def __init__( self, connection: Any, - adapter: "DatabaseAdapter", - config: "ConnectionConfig", + adapter: DatabaseAdapter, + config: ConnectionConfig, tunnel: Any | None = None, ): """Initialize a connection session. @@ -57,16 +58,15 @@ def __init__( self._config = config self._tunnel = tunnel self._closed = False - self._executor: "DatabaseExecutor | None" = None + self._executor: DatabaseExecutor | None = None @classmethod def create( cls, - config: "ConnectionConfig", - adapter_factory: Callable[[str], "DatabaseAdapter"] | None = None, - tunnel_factory: Callable[["ConnectionConfig"], tuple[Any, str, int]] - | None = None, - ) -> "ConnectionSession": + config: ConnectionConfig, + adapter_factory: Callable[[str], DatabaseAdapter] | None = None, + tunnel_factory: Callable[[ConnectionConfig], tuple[Any, str, int]] | None = None, + ) -> ConnectionSession: """Create a new connection session. This factory method handles SSH tunnel creation (if enabled) and @@ -112,12 +112,12 @@ def connection(self) -> Any: return self._connection @property - def adapter(self) -> "DatabaseAdapter": + def adapter(self) -> DatabaseAdapter: """Get the database adapter.""" return self._adapter @property - def config(self) -> "ConnectionConfig": + def config(self) -> ConnectionConfig: """Get the connection configuration.""" return self._config @@ -137,7 +137,7 @@ def is_closed(self) -> bool: return self._closed @property - def executor(self) -> "DatabaseExecutor": + def executor(self) -> DatabaseExecutor: """Get or create the database executor for serialized operations. The executor is lazily created on first access. All database operations @@ -197,7 +197,7 @@ def close(self) -> None: self._closed = True - def __enter__(self) -> "ConnectionSession": + def __enter__(self) -> ConnectionSession: """Enter the context manager.""" return self diff --git a/sqlit/state_machine.py b/sqlit/state_machine.py index d8b9d926..148adaf5 100644 --- a/sqlit/state_machine.py +++ b/sqlit/state_machine.py @@ -11,21 +11,24 @@ from __future__ import annotations from abc import ABC, abstractmethod -from dataclasses import dataclass, field +from collections.abc import Callable +from dataclasses import dataclass from enum import Enum, auto -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING -from .ui.tree_nodes import ConnectionNode, DatabaseNode, FolderNode, SchemaNode, TableNode, ViewNode +from .ui.tree_nodes import ( + ConnectionNode, + DatabaseNode, + FolderNode, + SchemaNode, + TableNode, + ViewNode, +) if TYPE_CHECKING: from .app import SSMSTUI - -# ============================================================ -# Leader Commands Definition -# ============================================================ - -# Guard functions for leader commands (resolved by name from keymap) +# Guards are referenced by name from the keymap. LEADER_GUARDS: dict[str, Callable[[SSMSTUI], bool]] = { "has_connection": lambda app: app.current_connection is not None, "query_executing": lambda app: getattr(app, "_query_executing", False), @@ -86,14 +89,11 @@ def get_leader_binding_actions() -> set[str]: return {cmd.binding_action for cmd in get_leader_commands()} -def get_leader_bindings(): +def get_leader_bindings() -> tuple: """Generate Textual Bindings from leader commands.""" from textual.binding import Binding - return tuple( - Binding(cmd.key, cmd.binding_action, show=False) - for cmd in get_leader_commands() - ) + return tuple(Binding(cmd.key, cmd.binding_action, show=False) for cmd in get_leader_commands()) class ActionResult(Enum): @@ -127,10 +127,8 @@ class ActionSpec: """Specification for an action.""" guard: Callable[[SSMSTUI], bool] | None = None - # Optional display info - if provided, action shows in footer display_key: str | None = None display_label: str | None = None - # Optional help text - if provided, action shows in help help_key: str | None = None help_description: str | None = None @@ -162,16 +160,13 @@ def get_help_entry(self, category: str) -> HelpEntry | None: class State(ABC): """Base class for hierarchical states.""" - # Override in subclasses to set the help category for this state's actions help_category: str | None = None def __init__(self, parent: State | None = None): self.parent = parent self._actions: dict[str, ActionSpec] = {} self._forbidden: set[str] = set() - # Bindings to display (in order) when this state is active self._display_order: list[str] = [] - # Right-side bindings (like leader key) self._right_bindings: list[str] = [] self._setup_actions() @@ -231,26 +226,21 @@ def forbids(self, *action_names: str) -> None: def check_action(self, app: SSMSTUI, action_name: str) -> ActionResult: """Check if action is allowed in this state or ancestors.""" - # Explicit forbid takes precedence if action_name in self._forbidden: return ActionResult.FORBIDDEN - # Check if this state handles the action if action_name in self._actions: spec = self._actions[action_name] if spec.is_allowed(app): return ActionResult.ALLOWED return ActionResult.FORBIDDEN - # Delegate to parent state if self.parent: return self.parent.check_action(app, action_name) return ActionResult.UNHANDLED - def get_display_bindings( - self, app: SSMSTUI - ) -> tuple[list[DisplayBinding], list[DisplayBinding]]: + def get_display_bindings(self, app: SSMSTUI) -> tuple[list[DisplayBinding], list[DisplayBinding]]: """Get bindings to display in footer (left, right). Returns bindings from this state and ancestors, with this state's @@ -260,7 +250,6 @@ def get_display_bindings( right: list[DisplayBinding] = [] seen: set[str] = set() - # Collect from this state first for action_name in self._display_order: if action_name in seen: continue @@ -281,7 +270,6 @@ def get_display_bindings( right.append(binding) seen.add(action_name) - # Collect from parent (but don't duplicate) if self.parent: parent_left, parent_right = self.parent.get_display_bindings(app) for binding in parent_left: @@ -301,18 +289,12 @@ def is_active(self, app: SSMSTUI) -> bool: pass -# ============================================================ -# Root State -# ============================================================ - - class RootState(State): """Root state - minimal actions available everywhere.""" help_category = "General" def _setup_actions(self) -> None: - # Actions available everywhere self.allows("quit", help="Quit", help_key="^q") self.allows("show_help", help="Show this help", help_key="?") self.allows("leader_key", help="Commands menu", help_key="") @@ -321,11 +303,6 @@ def is_active(self, app: SSMSTUI) -> bool: return True -# ============================================================ -# Modal Active State -# ============================================================ - - class ModalActiveState(State): """State when a modal screen is active. @@ -335,20 +312,14 @@ class ModalActiveState(State): """ def _setup_actions(self) -> None: - # Modal screens handle their own bindings pass def check_action(self, app: SSMSTUI, action_name: str) -> ActionResult: - # Let critical actions through if action_name in ("quit",): return ActionResult.ALLOWED - # Block everything else - modal handles its own bindings return ActionResult.FORBIDDEN - def get_display_bindings( - self, app: SSMSTUI - ) -> tuple[list[DisplayBinding], list[DisplayBinding]]: - # Modal screens provide their own footer/UI + def get_display_bindings(self, app: SSMSTUI) -> tuple[list[DisplayBinding], list[DisplayBinding]]: return [], [] def is_active(self, app: SSMSTUI) -> bool: @@ -357,40 +328,50 @@ def is_active(self, app: SSMSTUI) -> bool: return any(isinstance(screen, ModalScreen) for screen in app.screen_stack[1:]) -# ============================================================ -# Main Screen State (no modal active) -# ============================================================ - - class MainScreenState(State): """Base state for main screen (no modal active).""" help_category = "Navigation" def _setup_actions(self) -> None: - # Navigation (shown in Navigation category) self.allows("focus_explorer", help="Focus Explorer", help_key="e") self.allows("focus_query", help="Focus Query", help_key="q") self.allows("focus_results", help="Focus Results", help_key="r") self.allows("toggle_fullscreen", help="Toggle fullscreen", help_key="f") - # General actions (not shown in Navigation, will be in General) self.allows("show_help") self.allows("change_theme") - self.allows("cancel_operation") # ctrl+c to cancel running operations - # Leader key shown on right side self.allows("leader_key", key="", label="Commands", right=True) def is_active(self, app: SSMSTUI) -> bool: from textual.screen import ModalScreen - return not any( - isinstance(screen, ModalScreen) for screen in app.screen_stack[1:] - ) + if any(isinstance(screen, ModalScreen) for screen in app.screen_stack[1:]): + return False + # Defer to QueryExecutingState if a query is running + return not getattr(app, "_query_executing", False) -# ============================================================ -# Leader Pending State -# ============================================================ +class QueryExecutingState(State): + """State when a query is being executed.""" + + help_category = "Query" + + def _setup_actions(self) -> None: + self.allows("cancel_operation", key="^z", label="Cancel", help="Cancel query") + self.allows("quit") + + def get_display_bindings(self, app: SSMSTUI) -> tuple[list[DisplayBinding], list[DisplayBinding]]: + left: list[DisplayBinding] = [ + DisplayBinding(key="^z", label="Cancel", action="cancel_operation"), + ] + return left, [] + + def is_active(self, app: SSMSTUI) -> bool: + from textual.screen import ModalScreen + + if any(isinstance(screen, ModalScreen) for screen in app.screen_stack[1:]): + return False + return getattr(app, "_query_executing", False) class LeaderPendingState(State): @@ -400,12 +381,9 @@ class LeaderPendingState(State): """ def _setup_actions(self) -> None: - # Leader actions are checked dynamically in check_action - # because the keymap may be swapped for testing pass def check_action(self, app: SSMSTUI, action_name: str) -> ActionResult: - # Check if this is a leader binding action (leader_quit, leader_toggle_explorer, etc.) leader_binding_actions = get_leader_binding_actions() if action_name in leader_binding_actions: leader_commands = get_leader_commands() @@ -414,26 +392,51 @@ def check_action(self, app: SSMSTUI, action_name: str) -> ActionResult: return ActionResult.ALLOWED return ActionResult.FORBIDDEN - # leader_key passes through during pending (to show menu) if action_name == "leader_key": return ActionResult.ALLOWED return ActionResult.FORBIDDEN - def get_display_bindings( - self, app: SSMSTUI - ) -> tuple[list[DisplayBinding], list[DisplayBinding]]: - # During leader pending, we show a minimal indicator - # The actual menu will appear via LeaderMenuScreen + def get_display_bindings(self, app: SSMSTUI) -> tuple[list[DisplayBinding], list[DisplayBinding]]: return [], [DisplayBinding(key="...", label="Waiting", action="leader_pending")] def is_active(self, app: SSMSTUI) -> bool: return getattr(app, "_leader_pending", False) -# ============================================================ -# Tree States -# ============================================================ +class TreeFilterActiveState(State): + """State when tree filter is active.""" + + help_category = "Explorer" + + def _setup_actions(self) -> None: + self.allows("tree_filter_close", help="Close filter", help_key="esc") + self.allows("tree_filter_accept", help="Select item", help_key="enter") + self.allows("tree_filter_next", help="Next match", help_key="n/j") + self.allows("tree_filter_prev", help="Previous match", help_key="N/k") + self.allows("quit") + self.forbids( + "focus_explorer", + "focus_query", + "focus_results", + "leader_key", + "new_connection", + "tree_filter", + ) + + def get_display_bindings(self, app: SSMSTUI) -> tuple[list[DisplayBinding], list[DisplayBinding]]: + left: list[DisplayBinding] = [ + DisplayBinding(key="esc", label="Close", action="tree_filter_close"), + DisplayBinding(key="enter", label="Select", action="tree_filter_accept"), + DisplayBinding(key="n/N", label="Next/Prev", action="tree_filter_next"), + ] + return left, [] + + def is_active(self, app: SSMSTUI) -> bool: + return ( + app.object_tree.has_focus + and getattr(app, "_tree_filter_visible", False) + ) class TreeFocusedState(State): @@ -445,9 +448,15 @@ def _setup_actions(self) -> None: self.allows("new_connection", key="n", label="New", help="New connection") self.allows("refresh_tree", key="f", label="Refresh", help="Refresh tree", help_key="R/f") self.allows("collapse_tree", help="Collapse all", help_key="z") + self.allows("tree_cursor_down") # vim j + self.allows("tree_cursor_up") # vim k + self.allows("tree_filter", help="Filter tree", help_key="/") def is_active(self, app: SSMSTUI) -> bool: - return app.object_tree.has_focus + if not app.object_tree.has_focus: + return False + # Defer to TreeFilterActiveState if filter is visible + return not getattr(app, "_tree_filter_visible", False) class TreeOnConnectionState(State): @@ -463,34 +472,27 @@ def can_connect(app: SSMSTUI) -> bool: config = node.data.config if not app.current_connection: return True - return ( - config - and app.current_config - and config.name != app.current_config.name - ) + return bool(config and app.current_config and config.name != app.current_config.name) def is_connected_to_this(app: SSMSTUI) -> bool: node = app.object_tree.cursor_node if not node or not isinstance(node.data, ConnectionNode): return False config = node.data.config - return ( + return bool( app.current_connection is not None and config and app.current_config and config.name == app.current_config.name ) - # Show connect or disconnect based on state self.allows("connect_selected", can_connect, key="enter", label="Connect", help="Connect/Expand/Columns") self.allows("disconnect", is_connected_to_this, key="x", label="Disconnect", help="Disconnect") self.allows("edit_connection", key="e", label="Edit", help="Edit connection") self.allows("delete_connection", key="d", label="Delete", help="Delete connection") self.allows("duplicate_connection", key="D", label="Duplicate", help="Duplicate connection") - def get_display_bindings( - self, app: SSMSTUI - ) -> tuple[list[DisplayBinding], list[DisplayBinding]]: + def get_display_bindings(self, app: SSMSTUI) -> tuple[list[DisplayBinding], list[DisplayBinding]]: """Custom display logic for connection node.""" left: list[DisplayBinding] = [] seen: set[str] = set() @@ -504,7 +506,6 @@ def get_display_bindings( and config.name == app.current_config.name ) - # Show either Connect or Disconnect, not both if is_connected: left.append(DisplayBinding(key="x", label="Disconnect", action="disconnect")) seen.add("disconnect") @@ -550,9 +551,7 @@ class TreeOnTableState(State): def _setup_actions(self) -> None: self.allows("select_table", key="s", label="Select TOP 100", help="Select TOP 100 (table/view)") - def get_display_bindings( - self, app: SSMSTUI - ) -> tuple[list[DisplayBinding], list[DisplayBinding]]: + def get_display_bindings(self, app: SSMSTUI) -> tuple[list[DisplayBinding], list[DisplayBinding]]: left: list[DisplayBinding] = [] seen: set[str] = set() @@ -577,7 +576,7 @@ def is_active(self, app: SSMSTUI) -> bool: if not app.object_tree.has_focus: return False node = app.object_tree.cursor_node - return node is not None and isinstance(node.data, (TableNode, ViewNode)) + return node is not None and isinstance(node.data, TableNode | ViewNode) class TreeOnFolderState(State): @@ -586,9 +585,7 @@ class TreeOnFolderState(State): def _setup_actions(self) -> None: pass # Just inherits from parent - def get_display_bindings( - self, app: SSMSTUI - ) -> tuple[list[DisplayBinding], list[DisplayBinding]]: + def get_display_bindings(self, app: SSMSTUI) -> tuple[list[DisplayBinding], list[DisplayBinding]]: left: list[DisplayBinding] = [] seen: set[str] = set() @@ -611,12 +608,7 @@ def is_active(self, app: SSMSTUI) -> bool: if not app.object_tree.has_focus: return False node = app.object_tree.cursor_node - return node is not None and isinstance(node.data, (FolderNode, DatabaseNode, SchemaNode)) - - -# ============================================================ -# Query States -# ============================================================ + return node is not None and isinstance(node.data, FolderNode | DatabaseNode | SchemaNode) class QueryFocusedState(State): @@ -635,23 +627,14 @@ class QueryNormalModeState(State): help_category = "Query Editor (Normal)" def _setup_actions(self) -> None: - from .widgets import VimMode - self.allows("enter_insert_mode", key="i", label="Insert Mode", help="Enter INSERT mode") self.allows("execute_query", key="enter", label="Execute", help="Execute query") self.allows("clear_query", key="d", label="Clear", help="Clear query") self.allows("new_query", key="n", label="New", help="New query (clear all)") - self.allows( - "show_history", - lambda app: app.current_config is not None, - key="h", - label="History", - help="Query history", - ) + self.allows("copy_context", key="y", label="Copy query", help="Copy current query") + self.allows("show_history", key="h", label="History", help="Query history") - def get_display_bindings( - self, app: SSMSTUI - ) -> tuple[list[DisplayBinding], list[DisplayBinding]]: + def get_display_bindings(self, app: SSMSTUI) -> tuple[list[DisplayBinding], list[DisplayBinding]]: left: list[DisplayBinding] = [] seen: set[str] = set() @@ -660,8 +643,10 @@ def get_display_bindings( left.append(DisplayBinding(key="enter", label="Execute", action="execute_query")) seen.add("execute_query") - if app.current_config is not None: - left.append(DisplayBinding(key="h", label="History", action="show_history")) + left.append(DisplayBinding(key="y", label="Copy query", action="copy_context")) + seen.add("copy_context") + + left.append(DisplayBinding(key="h", label="History", action="show_history")) seen.add("show_history") left.append(DisplayBinding(key="d", label="Clear", action="clear_query")) @@ -692,7 +677,7 @@ class QueryInsertModeState(State): def _setup_actions(self) -> None: self.allows("exit_insert_mode", key="esc", label="Normal Mode", help="Exit to NORMAL mode") - self.allows("execute_query_insert", key="f5", label="Execute", help="Execute query (stay INSERT)") + self.allows("execute_query_insert", key="f5 | ^enter", label="Execute", help="Execute query (stay INSERT)") self.allows("autocomplete_accept", help="Accept autocomplete", help_key="tab") self.allows("quit") self.forbids( @@ -703,12 +688,10 @@ def _setup_actions(self) -> None: "show_help", ) - def get_display_bindings( - self, app: SSMSTUI - ) -> tuple[list[DisplayBinding], list[DisplayBinding]]: + def get_display_bindings(self, app: SSMSTUI) -> tuple[list[DisplayBinding], list[DisplayBinding]]: left: list[DisplayBinding] = [ DisplayBinding(key="esc", label="Normal Mode", action="exit_insert_mode"), - DisplayBinding(key="f5", label="Execute", action="execute_query_insert"), + DisplayBinding(key="f5 | ^enter", label="Execute", action="execute_query_insert"), DisplayBinding(key="tab", label="Autocomplete", action="autocomplete_accept"), ] return left, [] @@ -716,12 +699,49 @@ def get_display_bindings( def is_active(self, app: SSMSTUI) -> bool: from .widgets import VimMode - return app.query_input.has_focus and app.vim_mode == VimMode.INSERT + if not app.query_input.has_focus or app.vim_mode != VimMode.INSERT: + return False + # Defer to AutocompleteActiveState if autocomplete is visible + return not getattr(app, "_autocomplete_visible", False) -# ============================================================ -# Results States -# ============================================================ +class AutocompleteActiveState(State): + """Query editor with autocomplete dropdown visible.""" + + help_category = "Query Editor (Insert)" + + def _setup_actions(self) -> None: + self.allows("autocomplete_next", help="Next suggestion", help_key="^j") + self.allows("autocomplete_prev", help="Previous suggestion", help_key="^k") + self.allows("autocomplete_accept", help="Accept autocomplete", help_key="tab") + self.allows("autocomplete_close", help="Close autocomplete", help_key="esc") + self.allows("execute_query_insert") + self.allows("quit") + self.forbids( + "exit_insert_mode", # Escape closes autocomplete, not exits insert mode + "focus_explorer", + "focus_results", + "leader_key", + "new_connection", + "show_help", + ) + + def get_display_bindings(self, app: SSMSTUI) -> tuple[list[DisplayBinding], list[DisplayBinding]]: + left: list[DisplayBinding] = [ + DisplayBinding(key="tab", label="Accept", action="autocomplete_accept"), + DisplayBinding(key="^j/^k", label="Next/Prev", action="autocomplete_next"), + DisplayBinding(key="esc", label="Close", action="autocomplete_close"), + ] + return left, [] + + def is_active(self, app: SSMSTUI) -> bool: + from .widgets import VimMode + + return ( + app.query_input.has_focus + and app.vim_mode == VimMode.INSERT + and getattr(app, "_autocomplete_visible", False) + ) class ResultsFocusedState(State): @@ -731,13 +751,17 @@ class ResultsFocusedState(State): def _setup_actions(self) -> None: self.allows("view_cell", key="v", label="View cell", help="View selected cell") - self.allows("copy_cell", key="y", label="Copy cell", help="Copy selected cell") + self.allows("edit_cell", key="u", label="Update cell", help="Update cell (generate UPDATE)") + self.allows("copy_context", key="y", label="Copy cell", help="Copy selected cell") self.allows("copy_row", key="Y", label="Copy row", help="Copy selected row") self.allows("copy_results", key="a", label="Copy all", help="Copy all results") + self.allows("clear_results", key="x", label="Clear", help="Clear results") + self.allows("results_cursor_left") # vim h + self.allows("results_cursor_down") # vim j + self.allows("results_cursor_up") # vim k + self.allows("results_cursor_right") # vim l - def get_display_bindings( - self, app: SSMSTUI - ) -> tuple[list[DisplayBinding], list[DisplayBinding]]: + def get_display_bindings(self, app: SSMSTUI) -> tuple[list[DisplayBinding], list[DisplayBinding]]: left: list[DisplayBinding] = [] seen: set[str] = set() @@ -745,14 +769,16 @@ def get_display_bindings( if is_error: left.append(DisplayBinding(key="v", label="View error", action="view_cell")) - left.append(DisplayBinding(key="y", label="Copy error", action="copy_cell")) + left.append(DisplayBinding(key="y", label="Copy error", action="copy_context")) else: left.append(DisplayBinding(key="v", label="View cell", action="view_cell")) - left.append(DisplayBinding(key="y", label="Copy cell", action="copy_cell")) + left.append(DisplayBinding(key="u", label="Update", action="edit_cell")) + left.append(DisplayBinding(key="y", label="Copy cell", action="copy_context")) left.append(DisplayBinding(key="Y", label="Copy row", action="copy_row")) left.append(DisplayBinding(key="a", label="Copy all", action="copy_results")) + left.append(DisplayBinding(key="x", label="Clear", action="clear_results")) - seen.update(["view_cell", "copy_cell", "copy_row", "copy_results"]) + seen.update(["view_cell", "copy_context", "copy_row", "copy_results", "clear_results"]) right: list[DisplayBinding] = [] if self.parent: @@ -765,59 +791,54 @@ def get_display_bindings( return left, right def is_active(self, app: SSMSTUI) -> bool: - return app.results_table.has_focus - - -# ============================================================ -# State Machine -# ============================================================ + try: + return app.results_table.has_focus + except Exception: + # Results table may not exist yet (Lazy loading) + return False class UIStateMachine: """Hierarchical state machine for UI action validation and binding display.""" - def __init__(self): + def __init__(self) -> None: self.root = RootState() - # Modal state (highest priority, blocks everything) self.modal_active = ModalActiveState(parent=self.root) - # Main screen state (parent of all non-modal states) + self.query_executing = QueryExecutingState(parent=self.root) + self.main_screen = MainScreenState(parent=self.root) - # Leader pending (high priority within main screen) self.leader_pending = LeaderPendingState(parent=self.main_screen) - # Tree hierarchy self.tree_focused = TreeFocusedState(parent=self.main_screen) + self.tree_filter_active = TreeFilterActiveState(parent=self.main_screen) self.tree_on_connection = TreeOnConnectionState(parent=self.tree_focused) self.tree_on_table = TreeOnTableState(parent=self.tree_focused) self.tree_on_folder = TreeOnFolderState(parent=self.tree_focused) - # Query hierarchy self.query_focused = QueryFocusedState(parent=self.main_screen) self.query_normal = QueryNormalModeState(parent=self.query_focused) self.query_insert = QueryInsertModeState(parent=self.query_focused) + self.autocomplete_active = AutocompleteActiveState(parent=self.query_focused) - # Results self.results_focused = ResultsFocusedState(parent=self.main_screen) - # Priority order: most specific states first self._states = [ - self.modal_active, # Highest: blocks when modal open - self.leader_pending, # High: blocks during leader combo - # Tree substates before tree parent + self.modal_active, + self.query_executing, # Before main_screen (more specific when query running) + self.leader_pending, + self.tree_filter_active, # Before tree_focused (more specific when filter active) self.tree_on_connection, self.tree_on_table, self.tree_on_folder, self.tree_focused, - # Query substates before query parent + self.autocomplete_active, # Before query_insert (more specific) self.query_insert, self.query_normal, self.query_focused, - # Results self.results_focused, - # Fallbacks self.main_screen, self.root, ] @@ -833,13 +854,9 @@ def check_action(self, app: SSMSTUI, action_name: str) -> bool: """Check if action is allowed in current state.""" state = self.get_active_state(app) result = state.check_action(app, action_name) - # Only explicitly ALLOWED actions are permitted - # UNHANDLED and FORBIDDEN both block the action return result == ActionResult.ALLOWED - def get_display_bindings( - self, app: SSMSTUI - ) -> tuple[list[DisplayBinding], list[DisplayBinding]]: + def get_display_bindings(self, app: SSMSTUI) -> tuple[list[DisplayBinding], list[DisplayBinding]]: """Get bindings to display in footer for current state.""" state = self.get_active_state(app) return state.get_display_bindings(app) @@ -851,7 +868,6 @@ def get_active_state_name(self, app: SSMSTUI) -> str: def generate_help_text(self) -> str: """Generate help text from all states' help entries.""" - # Collect help entries from all states entries_by_category: dict[str, list[HelpEntry]] = {} for state in self._states: @@ -863,8 +879,7 @@ def generate_help_text(self) -> str: entries_by_category[entry.category].append(entry) entries_by_category["Commands ()"] = [ - HelpEntry(cmd.key, cmd.label, "Commands ()") - for cmd in get_leader_commands() + HelpEntry(f"+{cmd.key}", cmd.label, "Commands ()") for cmd in get_leader_commands() ] category_order = [ @@ -887,7 +902,7 @@ def generate_help_text(self) -> str: lines.append(f"[bold]{category}:[/]") for entry in entries: - key_display = self._format_key_for_help(entry.key).ljust(10) + key_display = self._format_key_for_help(entry.key).ljust(16) lines.append(f" {key_display} {entry.description}") lines.append("") @@ -897,9 +912,33 @@ def generate_help_text(self) -> str: def _format_key_for_help(key: str) -> str: """Format a key for help display, wrapping special keys in angle brackets.""" special_keys = { - "enter", "space", "esc", "escape", "tab", "delete", "backspace", - "up", "down", "left", "right", "home", "end", "pageup", "pagedown", - "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9", "f10", "f11", "f12", + "enter", + "space", + "esc", + "escape", + "tab", + "delete", + "backspace", + "up", + "down", + "left", + "right", + "home", + "end", + "pageup", + "pagedown", + "f1", + "f2", + "f3", + "f4", + "f5", + "f6", + "f7", + "f8", + "f9", + "f10", + "f11", + "f12", } if key.lower() in special_keys: diff --git a/sqlit/stores/base.py b/sqlit/stores/base.py index 224b6c1a..6f87021b 100644 --- a/sqlit/stores/base.py +++ b/sqlit/stores/base.py @@ -45,7 +45,7 @@ def _read_json(self) -> Any: if not self._file_path.exists(): return None try: - with open(self._file_path, "r", encoding="utf-8") as f: + with open(self._file_path, encoding="utf-8") as f: return json.load(f) except (json.JSONDecodeError, TypeError): return None diff --git a/sqlit/stores/connections.py b/sqlit/stores/connections.py index c0f9e25a..67fb2ad5 100644 --- a/sqlit/stores/connections.py +++ b/sqlit/stores/connections.py @@ -8,29 +8,49 @@ if TYPE_CHECKING: from ..config import ConnectionConfig + from ..services.credentials import CredentialsService class ConnectionStore(JSONFileStore): """Store for managing saved database connections. - Connections are stored as a JSON array in ~/.sqlit/connections.json + Connections are stored as a JSON array in ~/.sqlit/connections.json. + Passwords are stored separately in the OS keyring via CredentialsService. """ - _instance: "ConnectionStore | None" = None + _instance: ConnectionStore | None = None - def __init__(self): + def __init__(self, credentials_service: CredentialsService | None = None) -> None: super().__init__(CONFIG_DIR / "connections.json") + self._credentials_service = credentials_service + + @property + def credentials_service(self) -> CredentialsService: + """Get the credentials service (lazy-loaded).""" + if self._credentials_service is None: + from ..services.credentials import get_credentials_service + + return get_credentials_service() + return self._credentials_service @classmethod - def get_instance(cls) -> "ConnectionStore": + def get_instance(cls) -> ConnectionStore: """Get the singleton instance.""" if cls._instance is None: cls._instance = cls() return cls._instance - def load_all(self) -> list["ConnectionConfig"]: + @classmethod + def reset_instance(cls) -> None: + """Reset the singleton instance (useful for testing).""" + cls._instance = None + + def load_all(self, load_credentials: bool = True) -> list[ConnectionConfig]: """Load all saved connections. + Connections are loaded from JSON, and passwords are retrieved + from the credentials service (OS keyring). + Returns: List of ConnectionConfig objects, or empty list if none exist. """ @@ -40,19 +60,89 @@ def load_all(self) -> list["ConnectionConfig"]: if data is None: return [] try: - return [ConnectionConfig(**conn) for conn in data] + migrated = [] + for conn in data: + if isinstance(conn, dict) and "host" in conn and "server" not in conn: + conn = {**conn, "server": conn.get("host", "")} + conn.pop("host", None) + migrated.append(conn) + + configs = [] + for conn in migrated: + config = ConnectionConfig(**conn) + if load_credentials: + # Retrieve passwords from credentials service + self._load_credentials(config) + configs.append(config) + return configs except (TypeError, KeyError): return [] - def save_all(self, connections: list["ConnectionConfig"]) -> None: + def _load_credentials(self, config: ConnectionConfig) -> None: + """Load credentials from the credentials service into config. + + Args: + config: ConnectionConfig to populate with credentials. + """ + if config.password is None: + password = self.credentials_service.get_password(config.name) + if password is not None: + config.password = password + + if config.ssh_password is None: + ssh_password = self.credentials_service.get_ssh_password(config.name) + if ssh_password is not None: + config.ssh_password = ssh_password + + def _save_credentials(self, config: ConnectionConfig) -> None: + """Save credentials from config to the credentials service. + + Args: + config: ConnectionConfig containing credentials to save. + + Note: Empty string "" is a valid password (e.g., CockroachDB insecure mode). + Only None means "delete/no password stored". + """ + if config.password is not None: + self.credentials_service.set_password(config.name, config.password) + else: + self.credentials_service.delete_password(config.name) + + if config.ssh_password is not None: + self.credentials_service.set_ssh_password(config.name, config.ssh_password) + else: + self.credentials_service.delete_ssh_password(config.name) + + def _config_to_dict_without_passwords(self, config: ConnectionConfig) -> dict: + """Convert config to dict without password fields. + + Args: + config: ConnectionConfig to convert. + + Returns: + Dict representation with password fields set to None. + None indicates "load from credentials service on next load". + """ + data = vars(config).copy() + data["password"] = None + data["ssh_password"] = None + return data + + def save_all(self, connections: list[ConnectionConfig]) -> None: """Save all connections. + Passwords are stored in the credentials service (OS keyring), + not in the JSON file. + Args: connections: List of ConnectionConfig objects to save. """ - self._write_json([vars(c) for c in connections]) + for config in connections: + self._save_credentials(config) + + self._write_json([self._config_to_dict_without_passwords(c) for c in connections]) - def get_by_name(self, name: str) -> "ConnectionConfig | None": + def get_by_name(self, name: str) -> ConnectionConfig | None: """Get a connection by name. Args: @@ -66,7 +156,7 @@ def get_by_name(self, name: str) -> "ConnectionConfig | None": return conn return None - def add(self, connection: "ConnectionConfig") -> None: + def add(self, connection: ConnectionConfig) -> None: """Add a new connection. Args: @@ -81,7 +171,7 @@ def add(self, connection: "ConnectionConfig") -> None: connections.append(connection) self.save_all(connections) - def update(self, connection: "ConnectionConfig") -> None: + def update(self, connection: ConnectionConfig) -> None: """Update an existing connection. Args: @@ -101,6 +191,8 @@ def update(self, connection: "ConnectionConfig") -> None: def delete(self, name: str) -> bool: """Delete a connection by name. + Also deletes associated credentials from the keyring. + Args: name: Connection name to delete. @@ -111,6 +203,8 @@ def delete(self, name: str) -> bool: original_count = len(connections) connections = [c for c in connections if c.name != name] if len(connections) < original_count: + # Delete credentials from keyring + self.credentials_service.delete_all_for_connection(name) self.save_all(connections) return True return False @@ -124,15 +218,14 @@ def list_names(self) -> list[str]: return [c.name for c in self.load_all()] -# Module-level convenience functions for backward compatibility _store = ConnectionStore() -def load_connections() -> list["ConnectionConfig"]: +def load_connections(load_credentials: bool = True) -> list[ConnectionConfig]: """Load saved connections from config file.""" - return _store.load_all() + return _store.load_all(load_credentials=load_credentials) -def save_connections(connections: list["ConnectionConfig"]) -> None: +def save_connections(connections: list[ConnectionConfig]) -> None: """Save connections to config file.""" _store.save_all(connections) diff --git a/sqlit/stores/history.py b/sqlit/stores/history.py index faef7f44..8fb12694 100644 --- a/sqlit/stores/history.py +++ b/sqlit/stores/history.py @@ -25,7 +25,7 @@ def to_dict(self) -> dict: } @classmethod - def from_dict(cls, data: dict) -> "QueryHistoryEntry": + def from_dict(cls, data: dict) -> QueryHistoryEntry: """Create from dictionary.""" return cls( query=data["query"], @@ -42,13 +42,13 @@ class HistoryStore(JSONFileStore): """ MAX_ENTRIES_PER_CONNECTION = 100 - _instance: "HistoryStore | None" = None + _instance: HistoryStore | None = None - def __init__(self): + def __init__(self) -> None: super().__init__(CONFIG_DIR / "query_history.json") @classmethod - def get_instance(cls) -> "HistoryStore": + def get_instance(cls) -> HistoryStore: """Get the singleton instance.""" if cls._instance is None: cls._instance = cls() @@ -96,8 +96,7 @@ def save_query(self, connection_name: str, query: str) -> None: # Check if query already exists for entry in all_entries: - if (entry.get("connection_name") == connection_name - and entry.get("query", "").strip() == query_stripped): + if entry.get("connection_name") == connection_name and entry.get("query", "").strip() == query_stripped: entry["timestamp"] = now break else: @@ -110,15 +109,11 @@ def save_query(self, connection_name: str, query: str) -> None: all_entries.append(new_entry.to_dict()) # Limit entries per connection - connection_entries = [ - e for e in all_entries if e.get("connection_name") == connection_name - ] - other_entries = [ - e for e in all_entries if e.get("connection_name") != connection_name - ] + connection_entries = [e for e in all_entries if e.get("connection_name") == connection_name] + other_entries = [e for e in all_entries if e.get("connection_name") != connection_name] connection_entries.sort(key=lambda e: e.get("timestamp", ""), reverse=True) - connection_entries = connection_entries[:self.MAX_ENTRIES_PER_CONNECTION] + connection_entries = connection_entries[: self.MAX_ENTRIES_PER_CONNECTION] self._write_json(other_entries + connection_entries) @@ -136,9 +131,9 @@ def delete_entry(self, connection_name: str, timestamp: str) -> bool: original_count = len(all_entries) all_entries = [ - e for e in all_entries - if not (e.get("timestamp") == timestamp - and e.get("connection_name") == connection_name) + e + for e in all_entries + if not (e.get("timestamp") == timestamp and e.get("connection_name") == connection_name) ] if len(all_entries) < original_count: @@ -158,10 +153,7 @@ def clear_for_connection(self, connection_name: str) -> int: all_entries = self._load_all_entries() original_count = len(all_entries) - all_entries = [ - e for e in all_entries - if e.get("connection_name") != connection_name - ] + all_entries = [e for e in all_entries if e.get("connection_name") != connection_name] deleted = original_count - len(all_entries) if deleted > 0: diff --git a/sqlit/stores/settings.py b/sqlit/stores/settings.py index 8e86f5d2..bb9217c8 100644 --- a/sqlit/stores/settings.py +++ b/sqlit/stores/settings.py @@ -2,28 +2,35 @@ from __future__ import annotations +import os +from pathlib import Path from typing import Any from .base import CONFIG_DIR, JSONFileStore +def _resolve_settings_path() -> Path: + override = os.environ.get("SQLIT_SETTINGS_PATH", "").strip() + if override: + return Path(override).expanduser() + return CONFIG_DIR / "settings.json" + + class SettingsStore(JSONFileStore): """Store for managing application settings. Settings are stored as a JSON object in ~/.sqlit/settings.json """ - _instance: "SettingsStore | None" = None + _instance: SettingsStore | None = None - def __init__(self): - super().__init__(CONFIG_DIR / "settings.json") + def __init__(self, file_path: Path | None = None) -> None: + super().__init__(file_path or _resolve_settings_path()) @classmethod - def get_instance(cls) -> "SettingsStore": + def get_instance(cls) -> SettingsStore: """Get the singleton instance.""" - if cls._instance is None: - cls._instance = cls() - return cls._instance + return _get_store() def load_all(self) -> dict[str, Any]: """Load all settings. @@ -83,14 +90,24 @@ def delete(self, key: str) -> bool: # Module-level convenience functions for backward compatibility -_store = SettingsStore() +_store: SettingsStore | None = None +_store_path: Path | None = None + + +def _get_store() -> SettingsStore: + global _store, _store_path + path = _resolve_settings_path() + if _store is None or _store_path != path: + _store = SettingsStore(file_path=path) + _store_path = path + return _store def load_settings() -> dict: """Load app settings from config file.""" - return _store.load_all() + return _get_store().load_all() def save_settings(settings: dict) -> None: """Save app settings to config file.""" - _store.save_all(settings) + _get_store().save_all(settings) diff --git a/sqlit/terminal.py b/sqlit/terminal.py new file mode 100644 index 00000000..a6365e3f --- /dev/null +++ b/sqlit/terminal.py @@ -0,0 +1,76 @@ +"""Terminal detection and command execution utilities.""" + +from __future__ import annotations + +import os +import shutil +import subprocess +import tempfile +from dataclasses import dataclass +from enum import Enum + + +class TerminalType(Enum): + GNOME = "gnome-terminal" + KONSOLE = "konsole" + XTERM = "xterm" + MACOS = "macos" + WINDOWS = "windows" + NONE = "none" + + +@dataclass +class TerminalResult: + success: bool + terminal: TerminalType + error: str | None = None + + +def detect_terminal() -> TerminalType: + """Detect available terminal emulator.""" + if shutil.which("gnome-terminal"): + return TerminalType.GNOME + if shutil.which("konsole"): + return TerminalType.KONSOLE + if shutil.which("xterm"): + return TerminalType.XTERM + if shutil.which("open") and os.uname().sysname == "Darwin": + return TerminalType.MACOS + if os.name == "nt": + return TerminalType.WINDOWS + return TerminalType.NONE + + +def run_in_terminal(commands: list[str], wait_message: str = "Press Enter to close...") -> TerminalResult: + """Run commands in a new terminal window. + + Returns TerminalResult indicating success/failure and which terminal was used. + """ + terminal = detect_terminal() + full_command = " && ".join(commands) + suffix = f'echo ""; echo "{wait_message}"; read' + + try: + if terminal == TerminalType.GNOME: + subprocess.Popen(["gnome-terminal", "--", "bash", "-c", f"{full_command}; {suffix}"]) + elif terminal == TerminalType.KONSOLE: + subprocess.Popen(["konsole", "-e", "bash", "-c", f"{full_command}; {suffix}"]) + elif terminal == TerminalType.XTERM: + subprocess.Popen(["xterm", "-e", "bash", "-c", f"{full_command}; {suffix}"]) + elif terminal == TerminalType.MACOS: + with tempfile.NamedTemporaryFile(mode="w", suffix=".sh", delete=False) as f: + f.write("#!/bin/bash\n") + f.write(full_command + "\n") + f.write(f'echo ""\necho "{wait_message}"\nread\n') + script_path = f.name + os.chmod(script_path, 0o755) + subprocess.Popen(["open", "-a", "Terminal", script_path]) + elif terminal == TerminalType.WINDOWS: + subprocess.Popen(["cmd", "/c", "start", "cmd", "/k", full_command], shell=True) + else: + return TerminalResult(success=False, terminal=terminal, error="No terminal emulator found") + + return TerminalResult(success=True, terminal=terminal) + + except Exception as e: + return TerminalResult(success=False, terminal=terminal, error=str(e)) diff --git a/sqlit/ui/__init__.py b/sqlit/ui/__init__.py index 5d5caefa..e3ebc938 100644 --- a/sqlit/ui/__init__.py +++ b/sqlit/ui/__init__.py @@ -1,13 +1,7 @@ """UI components for sqlit.""" -from .screens import ( - ConfirmScreen, - ConnectionScreen, - DriverSetupScreen, - HelpScreen, - QueryHistoryScreen, - ValueViewScreen, -) +from importlib import import_module +from typing import TYPE_CHECKING, Any __all__ = [ "ConfirmScreen", @@ -17,3 +11,29 @@ "QueryHistoryScreen", "ValueViewScreen", ] + +_LAZY_ATTRS: dict[str, tuple[str, str]] = { + "ConfirmScreen": ("sqlit.ui.screens.confirm", "ConfirmScreen"), + "ConnectionScreen": ("sqlit.ui.screens.connection", "ConnectionScreen"), + "DriverSetupScreen": ("sqlit.ui.screens.driver_setup", "DriverSetupScreen"), + "HelpScreen": ("sqlit.ui.screens.help", "HelpScreen"), + "QueryHistoryScreen": ("sqlit.ui.screens.query_history", "QueryHistoryScreen"), + "ValueViewScreen": ("sqlit.ui.screens.value_view", "ValueViewScreen"), +} + +if TYPE_CHECKING: + from .screens.confirm import ConfirmScreen + from .screens.connection import ConnectionScreen + from .screens.driver_setup import DriverSetupScreen + from .screens.help import HelpScreen + from .screens.query_history import QueryHistoryScreen + from .screens.value_view import ValueViewScreen + + +def __getattr__(name: str) -> Any: + target = _LAZY_ATTRS.get(name) + if target is None: + raise AttributeError(name) + module_name, attr_name = target + module = import_module(module_name) + return getattr(module, attr_name) diff --git a/sqlit/ui/mixins/__init__.py b/sqlit/ui/mixins/__init__.py index 6253609b..096442f2 100644 --- a/sqlit/ui/mixins/__init__.py +++ b/sqlit/ui/mixins/__init__.py @@ -5,6 +5,7 @@ from .query import QueryMixin from .results import ResultsMixin from .tree import TreeMixin +from .tree_filter import TreeFilterMixin from .ui_navigation import UINavigationMixin __all__ = [ @@ -13,5 +14,6 @@ "QueryMixin", "ResultsMixin", "TreeMixin", + "TreeFilterMixin", "UINavigationMixin", ] diff --git a/sqlit/ui/mixins/autocomplete.py b/sqlit/ui/mixins/autocomplete.py index 058649a9..f53b2502 100644 --- a/sqlit/ui/mixins/autocomplete.py +++ b/sqlit/ui/mixins/autocomplete.py @@ -2,43 +2,24 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import Any from textual.timer import Timer from textual.widgets import TextArea from textual.worker import Worker -if TYPE_CHECKING: - from ...config import ConnectionConfig - from ...widgets import VimMode +from ..protocols import AppProtocol -# Spinner frames for loading animation SPINNER_FRAMES = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"] class AutocompleteMixin: """Mixin providing SQL autocomplete functionality.""" - # These attributes are defined in the main app class - current_connection: Any - current_config: "ConnectionConfig | None" - current_adapter: Any - vim_mode: "VimMode" - _schema_cache: dict - _autocomplete_visible: bool - _autocomplete_items: list[str] - _autocomplete_index: int - _autocomplete_filter: str - _autocomplete_just_applied: bool - # Schema indexing state - _schema_indexing: bool - _schema_worker: Worker | None - _schema_spinner_index: int - _schema_spinner_timer: Timer | None - # Table metadata for lazy column loading: {display_name.lower(): (schema, table, database)} - _table_metadata: dict[str, tuple[str, str, str | None]] - # Track in-flight column loading requests - _columns_loading: set[str] + _schema_worker: Worker[Any] | None = None + _schema_spinner_timer: Timer | None = None + _schema_cache: dict[str, Any] = {} + _table_metadata: dict[str, tuple[str, str, str | None]] = {} def _get_word_before_cursor(self, text: str, cursor_pos: int) -> tuple[str, str]: """Get the current word being typed and the context keyword before it.""" @@ -72,7 +53,7 @@ def _get_word_before_cursor(self, text: str, cursor_pos: int) -> tuple[str, str] return current_word, "" - def _get_autocomplete_suggestions(self, word: str, context: str) -> list[str]: + def _get_autocomplete_suggestions(self: AppProtocol, word: str, context: str) -> list[str]: """Get autocomplete suggestions based on context.""" suggestions = [] @@ -82,7 +63,6 @@ def _get_autocomplete_suggestions(self, word: str, context: str) -> list[str]: suggestions = self._schema_cache["procedures"] elif context.startswith("column:"): table_name = context.split(":", 1)[1].lower() - # Lazy load columns if not cached if table_name not in self._schema_cache["columns"]: self._load_columns_for_table(table_name) suggestions = self._schema_cache["columns"].get(table_name, []) @@ -98,20 +78,17 @@ def _get_autocomplete_suggestions(self, word: str, context: str) -> list[str]: return suggestions[:50] - def _load_columns_for_table(self, table_name: str) -> None: + def _load_columns_for_table(self: AppProtocol, table_name: str) -> None: """Lazy load columns for a specific table (async via worker).""" if not self.current_connection or not self.current_adapter: return - # Initialize _columns_loading if not present if not hasattr(self, "_columns_loading") or self._columns_loading is None: self._columns_loading = set() - # Skip if already loading this table if table_name in self._columns_loading: return - # Check if we have metadata for this table metadata = self._table_metadata.get(table_name) if not metadata: return @@ -120,7 +97,6 @@ def _load_columns_for_table(self, table_name: str) -> None: self._columns_loading.add(table_name) def work() -> None: - """Run in worker thread.""" adapter = self.current_adapter connection = self.current_connection if not adapter or not connection: @@ -132,7 +108,6 @@ def work() -> None: except Exception: column_names = [] - # Update cache on main thread self.call_from_thread( self._on_autocomplete_columns_loaded, table_name, @@ -143,17 +118,15 @@ def work() -> None: self.run_worker(work, name=f"load-columns-{table_name}", thread=True, exclusive=False) def _on_autocomplete_columns_loaded( - self, table_name: str, actual_table_name: str, column_names: list[str] + self: AppProtocol, table_name: str, actual_table_name: str, column_names: list[str] ) -> None: """Handle column load completion for autocomplete on main thread.""" self._columns_loading.discard(table_name) self._schema_cache["columns"][table_name] = column_names - # Also cache by actual table name self._schema_cache["columns"][actual_table_name.lower()] = column_names - def _show_autocomplete(self, suggestions: list[str], filter_text: str) -> None: + def _show_autocomplete(self: AppProtocol, suggestions: list[str], filter_text: str) -> None: """Show the autocomplete dropdown with suggestions.""" - from ...widgets import AutocompleteDropdown if not suggestions: self._hide_autocomplete() @@ -168,12 +141,15 @@ def _show_autocomplete(self, suggestions: list[str], filter_text: str) -> None: dropdown.show() self._autocomplete_visible = True - def _hide_autocomplete(self) -> None: + def _hide_autocomplete(self: AppProtocol) -> None: """Hide the autocomplete dropdown.""" - self.autocomplete_dropdown.hide() + try: + self.autocomplete_dropdown.hide() + except Exception: + pass # Widget not mounted yet self._autocomplete_visible = False - def _apply_autocomplete(self) -> None: + def _apply_autocomplete(self: AppProtocol) -> None: """Apply the selected autocomplete suggestion.""" selected = self.autocomplete_dropdown.get_selected() @@ -188,15 +164,11 @@ def _apply_autocomplete(self) -> None: cursor_pos = self._location_to_offset(text, cursor_loc) word_start = cursor_pos - while word_start > 0 and text[word_start - 1] not in " \t\n,()[]": + while word_start > 0 and text[word_start - 1] not in " \t\n,()[].": word_start -= 1 if word_start > 0 and text[word_start - 1] == ".": - new_text = ( - text[:cursor_pos] - + selected[len(text[word_start:cursor_pos]) :] - + text[cursor_pos:] - ) + new_text = text[:cursor_pos] + selected[len(text[word_start:cursor_pos]) :] + text[cursor_pos:] else: new_text = text[:word_start] + selected + text[cursor_pos:] @@ -208,7 +180,7 @@ def _apply_autocomplete(self) -> None: self._hide_autocomplete() - def _location_to_offset(self, text: str, location: tuple) -> int: + def _location_to_offset(self, text: str, location: tuple[int, int]) -> int: """Convert (row, col) location to text offset.""" row, col = location lines = text.split("\n") @@ -216,7 +188,7 @@ def _location_to_offset(self, text: str, location: tuple) -> int: offset += col return min(offset, len(text)) - def _offset_to_location(self, text: str, offset: int) -> tuple: + def _offset_to_location(self, text: str, offset: int) -> tuple[int, int]: """Convert text offset to (row, col) location.""" lines = text.split("\n") current_offset = 0 @@ -226,7 +198,7 @@ def _offset_to_location(self, text: str, offset: int) -> tuple: current_offset += len(line) + 1 return (len(lines) - 1, len(lines[-1]) if lines else 0) - def on_text_area_changed(self, event: TextArea.Changed) -> None: + def on_text_area_changed(self: AppProtocol, event: TextArea.Changed) -> None: """Handle text changes in the query editor for autocomplete.""" from ...widgets import VimMode @@ -264,7 +236,21 @@ def on_text_area_changed(self, event: TextArea.Changed) -> None: else: self._hide_autocomplete() - def on_key(self, event) -> None: + def action_autocomplete_next(self: AppProtocol) -> None: + """Move to next autocomplete suggestion.""" + if self._autocomplete_visible: + self.autocomplete_dropdown.move_selection(1) + + def action_autocomplete_prev(self: AppProtocol) -> None: + """Move to previous autocomplete suggestion.""" + if self._autocomplete_visible: + self.autocomplete_dropdown.move_selection(-1) + + def action_autocomplete_close(self: AppProtocol) -> None: + """Close autocomplete dropdown without exiting insert mode.""" + self._hide_autocomplete() + + def on_key(self: AppProtocol, event: Any) -> None: """Handle key events for autocomplete navigation.""" from ...widgets import VimMode @@ -289,8 +275,10 @@ def on_key(self, event) -> None: event.stop() elif event.key == "escape": self._hide_autocomplete() + event.prevent_default() + event.stop() - def _load_schema_cache(self) -> None: + def _load_schema_cache(self: AppProtocol) -> None: """Load database schema for autocomplete asynchronously.""" if not self.current_connection or not self.current_config or not self.current_adapter: return @@ -318,7 +306,7 @@ def _load_schema_cache(self) -> None: exclusive=True, ) - def _start_schema_spinner(self) -> None: + def _start_schema_spinner(self: AppProtocol) -> None: """Start the schema indexing spinner animation.""" self._schema_indexing = True self._schema_spinner_index = 0 @@ -328,7 +316,7 @@ def _start_schema_spinner(self) -> None: self._schema_spinner_timer.stop() self._schema_spinner_timer = self.set_interval(0.1, self._animate_schema_spinner) - def _stop_schema_spinner(self) -> None: + def _stop_schema_spinner(self: AppProtocol) -> None: """Stop the schema indexing spinner animation.""" self._schema_indexing = False if hasattr(self, "_schema_spinner_timer") and self._schema_spinner_timer is not None: @@ -336,14 +324,14 @@ def _stop_schema_spinner(self) -> None: self._schema_spinner_timer = None self._update_status_bar() - def _animate_schema_spinner(self) -> None: + def _animate_schema_spinner(self: AppProtocol) -> None: """Update schema spinner animation frame.""" if not self._schema_indexing: return self._schema_spinner_index = (self._schema_spinner_index + 1) % len(SPINNER_FRAMES) self._update_status_bar() - def action_cancel_schema_indexing(self) -> None: + def action_cancel_schema_indexing(self: AppProtocol) -> None: """Cancel ongoing schema indexing.""" if hasattr(self, "_schema_worker") and self._schema_worker is not None: self._schema_worker.cancel() @@ -351,7 +339,7 @@ def action_cancel_schema_indexing(self) -> None: self._stop_schema_spinner() self.notify("Schema indexing cancelled") - async def _load_schema_cache_async(self) -> None: + async def _load_schema_cache_async(self: AppProtocol) -> None: """Load database schema asynchronously in a worker thread. Only loads tables, views, and procedures. Columns are loaded lazily. @@ -376,14 +364,13 @@ async def _load_schema_cache_async(self) -> None: try: # Get database list in thread + databases: list[str | None] if adapter.supports_multiple_databases: db = config.database if db and db.lower() not in ("", "master"): databases = [db] else: - all_dbs = await asyncio.to_thread( - adapter.get_databases, connection - ) + all_dbs = await asyncio.to_thread(adapter.get_databases, connection) system_dbs = {"master", "tempdb", "model", "msdb"} databases = [d for d in all_dbs if d.lower() not in system_dbs] else: @@ -392,9 +379,7 @@ async def _load_schema_cache_async(self) -> None: for database in databases: try: # Get tables in thread (NO columns - lazy loaded) - tables = await asyncio.to_thread( - adapter.get_tables, connection, database - ) + tables = await asyncio.to_thread(adapter.get_tables, connection, database) for schema_name, table_name in tables: display_name = adapter.format_table_name(schema_name, table_name) schema_cache["tables"].append(display_name) @@ -407,9 +392,7 @@ async def _load_schema_cache_async(self) -> None: table_metadata[full_name.lower()] = (schema_name, table_name, database) # Get views in thread (NO columns - lazy loaded) - views = await asyncio.to_thread( - adapter.get_views, connection, database - ) + views = await asyncio.to_thread(adapter.get_views, connection, database) for schema_name, view_name in views: display_name = adapter.format_table_name(schema_name, view_name) schema_cache["views"].append(display_name) @@ -422,9 +405,7 @@ async def _load_schema_cache_async(self) -> None: table_metadata[full_name.lower()] = (schema_name, view_name, database) if adapter.supports_stored_procedures: - procedures = await asyncio.to_thread( - adapter.get_procedures, connection, database - ) + procedures = await asyncio.to_thread(adapter.get_procedures, connection, database) schema_cache["procedures"].extend(procedures) except Exception: @@ -443,7 +424,7 @@ async def _load_schema_cache_async(self) -> None: finally: self._stop_schema_spinner() - def _update_schema_cache(self, schema_cache: dict, table_metadata: dict | None = None) -> None: + def _update_schema_cache(self: AppProtocol, schema_cache: dict, table_metadata: dict | None = None) -> None: """Update the schema cache (called on main thread).""" self._schema_cache = schema_cache if table_metadata is not None: diff --git a/sqlit/ui/mixins/connection.py b/sqlit/ui/mixins/connection.py index eb655e7b..9927bbba 100644 --- a/sqlit/ui/mixins/connection.py +++ b/sqlit/ui/mixins/connection.py @@ -2,53 +2,122 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable - -from textual.widgets import Static +from typing import TYPE_CHECKING, Any +from ..protocols import AppProtocol from ..tree_nodes import ConnectionNode if TYPE_CHECKING: from ...config import ConnectionConfig - from ...services import ConnectionSession + from ...db import DatabaseAdapter -class ConnectionMixin: - """Mixin providing connection management functionality. +def _needs_db_password(config: ConnectionConfig) -> bool: + """Check if the connection needs a database password prompt. - Attributes: - _session_factory: Optional factory for creating ConnectionSession. - Set this in tests to inject a mock session factory. - Defaults to ConnectionSession.create when None. + Returns True if password is None (not set) and the database type uses passwords. + Note: Empty string "" means explicitly set to empty (no prompt needed). """ + from ...db.providers import is_file_based - # These attributes are defined in the main app class - connections: list - current_connection: Any - current_config: "ConnectionConfig | None" - current_adapter: Any - current_ssh_tunnel: Any - _session: "ConnectionSession | None" + # File-based databases (SQLite, DuckDB) don't need passwords + if is_file_based(config.db_type): + return False - # DI seam for testing - set to override session creation - _session_factory: Callable[["ConnectionConfig"], "ConnectionSession"] | None = None + # Check if password is not set (None means prompt needed) + return config.password is None - def connect_to_server(self, config: "ConnectionConfig") -> None: - """Connect to a database (async, non-blocking).""" - from ...services import ConnectionSession - # Check for pyodbc only if it's a SQL Server connection - try: - import pyodbc - PYODBC_AVAILABLE = True - except ImportError: - PYODBC_AVAILABLE = False +def _needs_ssh_password(config: ConnectionConfig) -> bool: + """Check if the connection needs an SSH password prompt. + + Returns True if SSH is enabled with password auth and password is None (not set). + Note: Empty string "" means explicitly set to empty (no prompt needed). + """ + if not config.ssh_enabled: + return False + + if config.ssh_auth_type != "password": + return False + + return config.ssh_password is None + + +class ConnectionMixin: + """Mixin providing connection management functionality.""" + + current_config: ConnectionConfig | None = None + current_adapter: DatabaseAdapter | None = None - if config.db_type == "mssql" and not PYODBC_AVAILABLE: - self.notify("pyodbc not installed. Run: pip install pyodbc", severity="error") + def _populate_credentials_if_missing(self: AppProtocol, config: ConnectionConfig) -> None: + """Populate missing credentials from the credentials service.""" + if config.password is not None and config.ssh_password is not None: return + from ...services.credentials import get_credentials_service + + service = get_credentials_service() + if config.password is None: + password = service.get_password(config.name) + if password is not None: + config.password = password + if config.ssh_password is None: + ssh_password = service.get_ssh_password(config.name) + if ssh_password is not None: + config.ssh_password = ssh_password + + def connect_to_server(self: AppProtocol, config: ConnectionConfig) -> None: + """Connect to a database (async, non-blocking). + + If the connection requires a password that is not stored (empty), + the user will be prompted to enter the password before connecting. + """ + from dataclasses import replace + + from ..screens import PasswordInputScreen + + self._populate_credentials_if_missing(config) + + if _needs_ssh_password(config): + + def on_ssh_password(password: str | None) -> None: + if password is None: + return + temp_config = replace(config, ssh_password=password) + self._connect_with_db_password_check(temp_config) + + self.push_screen( + PasswordInputScreen(config.name, password_type="ssh"), + on_ssh_password, + ) + return + + self._connect_with_db_password_check(config) + + def _connect_with_db_password_check(self: AppProtocol, config: ConnectionConfig) -> None: + """Check for database password and prompt if needed, then connect.""" + from dataclasses import replace + + from ..screens import PasswordInputScreen + + if _needs_db_password(config): + + def on_db_password(password: str | None) -> None: + if password is None: + return + temp_config = replace(config, password=password) + self._do_connect(temp_config) + + self.push_screen( + PasswordInputScreen(config.name, password_type="database"), + on_db_password, + ) + return + + self._do_connect(config) + + def _do_connect(self: AppProtocol, config: ConnectionConfig) -> None: + from ...services import ConnectionSession - # Close any existing session first if hasattr(self, "_session") and self._session: self._session.close() self._session = None @@ -58,18 +127,14 @@ def connect_to_server(self, config: "ConnectionConfig") -> None: self.current_ssh_tunnel = None self.refresh_tree() - # Reset connection failed state self._connection_failed = False - # Use injected factory or default create_session = self._session_factory or ConnectionSession.create - def work() -> "ConnectionSession": - """Create connection in worker thread.""" + def work() -> ConnectionSession: return create_session(config) - def on_success(session: "ConnectionSession") -> None: - """Handle successful connection on main thread.""" + def on_success(session: ConnectionSession) -> None: self._connection_failed = False self._session = session self.current_connection = session.connection @@ -82,15 +147,84 @@ def on_success(session: "ConnectionSession") -> None: self._update_status_bar() def on_error(error: Exception) -> None: - """Handle connection failure on main thread.""" - from ..screens import ErrorScreen + from ...config import save_connections + from ...db.exceptions import MissingDriverError, MissingODBCDriverError + from ...terminal import run_in_terminal + from ..screens import ConfirmScreen, DriverSetupScreen, ErrorScreen, MessageScreen self._connection_failed = True self._update_status_bar() - self.push_screen(ErrorScreen("Connection Failed", str(error))) + + if isinstance(error, MissingDriverError): + from ...services.installer import Installer + from ..screens import PackageSetupScreen + + self.push_screen( + PackageSetupScreen(error, on_install=lambda err: Installer(self).install(err)), + ) + elif isinstance(error, MissingODBCDriverError): + + def on_confirm(confirmed: bool | None) -> None: + if confirmed is not True: + self.push_screen( + MessageScreen( + "Missing ODBC driver", + ( + "SQL Server requires an ODBC driver.\n\n" + "Open connection settings (Advanced) to configure drivers." + ), + ) + ) + return + + def on_driver_result(result: Any) -> None: + if not result: + return + action = result[0] + if action == "select": + driver = result[1] + config.driver = driver + for i, c in enumerate(self.connections): + if c.name == config.name: + self.connections[i] = config + break + save_connections(self.connections) + self.call_later(lambda: self.connect_to_server(config)) + return + if action == "install": + commands = result[1] + res = run_in_terminal(commands) + if res.success: + self.push_screen( + MessageScreen( + "Driver install", + "Installation started in a new terminal.\n\nPlease restart to apply.", + ) + ) + else: + self.push_screen( + MessageScreen( + "Couldn't install automatically", + "Couldn't install automatically, please install manually.", + ), + lambda _=None: self.push_screen( + DriverSetupScreen(error.installed_drivers), on_driver_result + ), + ) + + self.push_screen(DriverSetupScreen(error.installed_drivers), on_driver_result) + + self.push_screen( + ConfirmScreen( + "Missing ODBC driver", + "SQL Server requires an ODBC driver.\n\nOpen driver setup now?", + ), + on_confirm, + ) + else: + self.push_screen(ErrorScreen("Connection Failed", str(error))) def do_work() -> None: - """Worker function with error handling.""" try: session = work() self.call_from_thread(on_success, session) @@ -99,20 +233,17 @@ def do_work() -> None: self.run_worker(do_work, name=f"connect-{config.name}", thread=True, exclusive=True) - def _disconnect_silent(self) -> None: - """Disconnect from current database without notification.""" - # Use session's close method for proper cleanup + def _disconnect_silent(self: AppProtocol) -> None: if hasattr(self, "_session") and self._session: self._session.close() self._session = None - # Clear instance variables self.current_connection = None self.current_config = None self.current_adapter = None self.current_ssh_tunnel = None - def action_disconnect(self) -> None: + def action_disconnect(self: AppProtocol) -> None: """Disconnect from current database.""" if self.current_connection: self._disconnect_silent() @@ -122,15 +253,13 @@ def action_disconnect(self) -> None: self.refresh_tree() self.notify("Disconnected") - def action_new_connection(self) -> None: - """Show new connection dialog.""" + def action_new_connection(self: AppProtocol) -> None: from ..screens import ConnectionScreen self._set_connection_screen_footer() self.push_screen(ConnectionScreen(), self._wrap_connection_result) - def action_edit_connection(self) -> None: - """Edit the selected connection.""" + def action_edit_connection(self: AppProtocol) -> None: from ..screens import ConnectionScreen node = self.object_tree.cursor_node @@ -143,12 +272,9 @@ def action_edit_connection(self) -> None: return self._set_connection_screen_footer() - self.push_screen( - ConnectionScreen(data.config, editing=True), self._wrap_connection_result - ) + self.push_screen(ConnectionScreen(data.config, editing=True), self._wrap_connection_result) - def _set_connection_screen_footer(self) -> None: - """Set footer bindings for connection screen.""" + def _set_connection_screen_footer(self: AppProtocol) -> None: from ...widgets import ContextFooter try: @@ -157,14 +283,18 @@ def _set_connection_screen_footer(self) -> None: return footer.set_bindings([], []) - def _wrap_connection_result(self, result: tuple | None) -> None: - """Wrapper to restore footer after connection dialog.""" + def _wrap_connection_result(self: AppProtocol, result: tuple | None) -> None: self._update_footer_bindings() self.handle_connection_result(result) - def handle_connection_result(self, result: tuple | None) -> None: - """Handle result from connection dialog.""" - from ...config import save_connections + def handle_connection_result(self: AppProtocol, result: tuple | None) -> None: + from ...config import load_settings, save_connections, save_settings + from ...services.credentials import ( + ALLOW_PLAINTEXT_CREDENTIALS_SETTING, + is_keyring_usable, + reset_credentials_service, + ) + from ..screens import ConfirmScreen if not result: return @@ -172,14 +302,64 @@ def handle_connection_result(self, result: tuple | None) -> None: action, config = result if action == "save": - self.connections = [c for c in self.connections if c.name != config.name] - self.connections.append(config) - save_connections(self.connections) - self.refresh_tree() - self.notify(f"Connection '{config.name}' saved") + def do_save(with_config) -> None: # noqa: ANN001 + self.connections = [c for c in self.connections if c.name != with_config.name] + self.connections.append(with_config) + if getattr(self, "_mock_profile", None): + self.notify("Mock mode: connection changes are not persisted") + else: + save_connections(self.connections) + self.refresh_tree() + self.notify(f"Connection '{with_config.name}' saved") + + needs_password_persist = bool(getattr(config, "password", "") or getattr(config, "ssh_password", "")) + if not getattr(self, "_mock_profile", None) and needs_password_persist and not is_keyring_usable(): + settings = load_settings() + allow_plaintext = settings.get(ALLOW_PLAINTEXT_CREDENTIALS_SETTING) + + if allow_plaintext is True: + reset_credentials_service() + do_save(config) + return + + if allow_plaintext is False: + config.password = "" + config.ssh_password = "" + do_save(config) + self.notify("Keyring unavailable: passwords will be prompted when needed", severity="warning") + return + + def on_confirm(confirmed: bool | None) -> None: + settings2 = load_settings() + if confirmed is True: + settings2[ALLOW_PLAINTEXT_CREDENTIALS_SETTING] = True + save_settings(settings2) + reset_credentials_service() + do_save(config) + self.notify("Saved passwords as plaintext in ~/.sqlit/ (0600)", severity="warning") + return + + settings2[ALLOW_PLAINTEXT_CREDENTIALS_SETTING] = False + save_settings(settings2) + config.password = "" + config.ssh_password = "" + do_save(config) + self.notify("Passwords were not saved (keyring unavailable)", severity="warning") + + self.push_screen( + ConfirmScreen( + "Keyring isn't available", + "Save passwords as plaintext in ~/.sqlit/ (protected directory)?", + yes_label="Yes", + no_label="No", + ), + on_confirm, + ) + return + + do_save(config) - def action_duplicate_connection(self) -> None: - """Duplicate the selected connection.""" + def action_duplicate_connection(self: AppProtocol) -> None: from dataclasses import replace from ..screens import ConnectionScreen @@ -206,12 +386,9 @@ def action_duplicate_connection(self) -> None: duplicated = replace(config, name=new_name) self._set_connection_screen_footer() - self.push_screen( - ConnectionScreen(duplicated, editing=False), self._wrap_connection_result - ) + self.push_screen(ConnectionScreen(duplicated, editing=False), self._wrap_connection_result) - def action_delete_connection(self) -> None: - """Delete the selected connection.""" + def action_delete_connection(self: AppProtocol) -> None: from ..screens import ConfirmScreen node = self.object_tree.cursor_node @@ -234,17 +411,32 @@ def action_delete_connection(self) -> None: lambda confirmed: self._do_delete_connection(config) if confirmed else None, ) - def _do_delete_connection(self, config: "ConnectionConfig") -> None: - """Actually delete the connection after confirmation.""" + def _do_delete_connection(self: AppProtocol, config: ConnectionConfig) -> None: from ...config import save_connections self.connections = [c for c in self.connections if c.name != config.name] - save_connections(self.connections) + if getattr(self, "_mock_profile", None): + self.notify("Mock mode: connection changes are not persisted") + else: + save_connections(self.connections) self.refresh_tree() self.notify(f"Connection '{config.name}' deleted") - def action_connect_selected(self) -> None: - """Connect to the selected connection.""" + def _handle_install_confirmation(self: AppProtocol, confirmed: bool | None, error: Any) -> None: + from ...db.adapters.base import _create_driver_import_error_hint + from ...services.installer import Installer + from ..screens import ErrorScreen + + if confirmed is True: + installer = Installer(self) # self is the App instance + self.call_next(installer.install, error) # Schedule the async install method + elif confirmed is False: + hint = _create_driver_import_error_hint(error.driver_name, error.extra_name, error.package_name) + self.push_screen(ErrorScreen("Manual Installation Required", hint)) + else: + return + + def action_connect_selected(self: AppProtocol) -> None: node = self.object_tree.cursor_node if not node or not node.data: @@ -259,8 +451,7 @@ def action_connect_selected(self) -> None: self._disconnect_silent() self.connect_to_server(config) - def action_show_connection_picker(self) -> None: - """Show connection picker dialog.""" + def action_show_connection_picker(self: AppProtocol) -> None: from ..screens import ConnectionPickerScreen self.push_screen( @@ -268,14 +459,12 @@ def action_show_connection_picker(self) -> None: self._handle_connection_picker_result, ) - def _handle_connection_picker_result(self, result: str | None) -> None: - """Handle connection picker selection.""" + def _handle_connection_picker_result(self: AppProtocol, result: str | None) -> None: if result is None: return config = next((c for c in self.connections if c.name == result), None) if config: - # Select the connection node in the tree for node in self.object_tree.root.children: if isinstance(node.data, ConnectionNode) and node.data.config.name == result: self.object_tree.select_node(node) diff --git a/sqlit/ui/mixins/protocols.py b/sqlit/ui/mixins/protocols.py new file mode 100644 index 00000000..97fb817a --- /dev/null +++ b/sqlit/ui/mixins/protocols.py @@ -0,0 +1,94 @@ +"""Protocol definitions for mixin classes. + +These protocols define the attributes and methods that mixins expect +to be available on the host App class. + +Note: mixins must not inherit from Protocol at runtime (can cause metaclass +conflicts with Textual's App metaclass on newer Python versions). +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol + +if TYPE_CHECKING: + from textual.widgets import DataTable, TextArea, Tree + + from ...config import ConnectionConfig + from ...services import ConnectionSession + + +class AppProtocol(Protocol): + """Protocol defining what mixins expect from the main App class.""" + + # Widget attributes (from Textual App) + object_tree: Tree + query_input: TextArea + results_table: DataTable + autocomplete_dropdown: Any # AutocompleteDropdown widget + + # Connection state + connections: list[ConnectionConfig] + current_connection: Any # database connection object + current_config: ConnectionConfig | None + current_adapter: Any # DatabaseAdapter instance + _session: ConnectionSession | None + + # UI state + _expanded_paths: set[str] + _loading_nodes: set[str] + _leader_pending: bool + screen_stack: list[Any] + vim_mode: Any # VimMode enum + + # Result state + _last_result_columns: list[str] + _last_result_rows: list[tuple] + _last_result_row_count: int + _internal_clipboard: str + + # Textual App methods + def notify( + self, + message: str, + *, + title: str = "", + severity: str = "information", + timeout: float = 2.0, + ) -> None: ... + + def call_later(self, callback: Any, *args: Any, **kwargs: Any) -> Any: ... + + def call_from_thread(self, callback: Any, *args: Any, **kwargs: Any) -> Any: ... + + def run_worker( + self, + work: Any, + *, + name: str = "", + group: str = "default", + description: str = "", + exit_on_error: bool = True, + exclusive: bool = False, + ) -> Any: ... + + def push_screen(self, screen: Any) -> Any: ... + + def pop_screen(self) -> Any: ... + + def action_quit(self) -> None: ... + + def copy_to_clipboard(self, text: str) -> None: ... + + def set_interval(self, interval: float, callback: Any, *args: Any, **kwargs: Any) -> Any: ... + + # App-specific methods + def _disconnect_silent(self) -> None: ... + + def connect_to_server(self, config: Any) -> Any: ... + + def _update_footer_bindings(self) -> None: ... + + def action_execute_query(self) -> None: ... + + def _update_status_bar(self) -> None: ... diff --git a/sqlit/ui/mixins/query.py b/sqlit/ui/mixins/query.py index 05d9c642..d321ed18 100644 --- a/sqlit/ui/mixins/query.py +++ b/sqlit/ui/mixins/query.py @@ -6,15 +6,14 @@ from rich.markup import escape as escape_markup from textual.timer import Timer -from textual.widgets import DataTable, TextArea from textual.worker import Worker +from ..protocols import AppProtocol +from ...utils import format_duration_ms + if TYPE_CHECKING: - from ...config import ConnectionConfig - from ...services import CancellableQuery, QueryService - from ...widgets import VimMode + from ...services import QueryService -# Spinner frames for loading animation SPINNER_FRAMES = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"] @@ -27,33 +26,44 @@ class QueryMixin: Defaults to a new QueryService() when None. """ - # These attributes are defined in the main app class - current_connection: Any - current_config: "ConnectionConfig | None" - current_adapter: Any - vim_mode: "VimMode" - _last_result_columns: list[str] - _last_result_rows: list[tuple] - _last_result_row_count: int - _query_worker: Worker | None - _query_executing: bool - _query_start_time: float - _spinner_index: int - _spinner_timer: Timer | None - _cancellable_query: "CancellableQuery | None" - - # DI seam for testing - set to override query service - _query_service: "QueryService | None" = None - - def action_execute_query(self) -> None: + _query_service: QueryService | None = None + + _query_worker: Worker[Any] | None = None + _schema_worker: Worker[Any] | None = None + _cancellable_query: Any | None = None + _spinner_timer: Timer | None = None + _query_cursor_cache: dict[str, tuple[int, int]] | None = None # query text -> cursor (row, col) + + def action_execute_query(self: AppProtocol) -> None: """Execute the current query.""" self._execute_query_common(keep_insert_mode=False) - def action_execute_query_insert(self) -> None: + def action_execute_query_insert(self: AppProtocol) -> None: """Execute query in INSERT mode without leaving it.""" self._execute_query_common(keep_insert_mode=True) - def _execute_query_common(self, keep_insert_mode: bool) -> None: + def action_copy_query(self: AppProtocol) -> None: + """Copy the current query to clipboard.""" + from ...widgets import flash_widget + + query = self.query_input.text.strip() + if not query: + self.notify("Query is empty", severity="warning") + return + self._copy_text(query) + flash_widget(self.query_input) + + def action_copy_context(self: AppProtocol) -> None: + """Copy based on current focus (query or results).""" + if self.query_input.has_focus: + self.action_copy_query() + return + if self.results_table.has_focus: + self.action_copy_cell() + return + self.notify("Nothing to copy", severity="warning") + + def _execute_query_common(self: AppProtocol, keep_insert_mode: bool) -> None: """Common query execution logic.""" if not self.current_connection or not self.current_adapter: self.notify("Connect to a server to execute queries", severity="warning") @@ -65,25 +75,18 @@ def _execute_query_common(self, keep_insert_mode: bool) -> None: self.notify("No query to execute", severity="warning") return - # Cancel any existing query worker if hasattr(self, "_query_worker") and self._query_worker is not None: self._query_worker.cancel() - self.results_table.clear(columns=True) - self.results_table.add_column("Status") - self.results_table.add_row("Executing query...") - - # Start spinner animation self._start_query_spinner() - # Run query in background thread self._query_worker = self.run_worker( self._run_query_async(query, keep_insert_mode), name="query_execution", exclusive=True, ) - def _start_query_spinner(self) -> None: + def _start_query_spinner(self: AppProtocol) -> None: """Start the query execution spinner animation.""" import time @@ -91,12 +94,11 @@ def _start_query_spinner(self) -> None: self._query_start_time = time.perf_counter() self._spinner_index = 0 self._update_status_bar() - # Start timer to animate spinner if hasattr(self, "_spinner_timer") and self._spinner_timer is not None: self._spinner_timer.stop() - self._spinner_timer = self.set_interval(0.1, self._animate_spinner) + self._spinner_timer = self.set_interval(1 / 30, self._animate_spinner) # 30fps - def _stop_query_spinner(self) -> None: + def _stop_query_spinner(self: AppProtocol) -> None: """Stop the query execution spinner animation.""" self._query_executing = False if hasattr(self, "_spinner_timer") and self._spinner_timer is not None: @@ -104,14 +106,14 @@ def _stop_query_spinner(self) -> None: self._spinner_timer = None self._update_status_bar() - def _animate_spinner(self) -> None: + def _animate_spinner(self: AppProtocol) -> None: """Update spinner animation frame.""" if not self._query_executing: return self._spinner_index = (self._spinner_index + 1) % len(SPINNER_FRAMES) self._update_status_bar() - async def _run_query_async(self, query: str, keep_insert_mode: bool) -> None: + async def _run_query_async(self: AppProtocol, query: str, keep_insert_mode: bool) -> None: """Run query asynchronously using a cancellable dedicated connection.""" import asyncio import time @@ -126,7 +128,7 @@ async def _run_query_async(self, query: str, keep_insert_mode: bool) -> None: self._stop_query_spinner() return - # Create cancellable query with dedicated connection + # Dedicated connection enables cancellation by closing it. cancellable = CancellableQuery( sql=query, config=config, @@ -134,11 +136,9 @@ async def _run_query_async(self, query: str, keep_insert_mode: bool) -> None: ) self._cancellable_query = cancellable - # Use injected service or default (for history saving) service = self._query_service or QueryService() try: - # Execute on dedicated connection (cancellable via connection close) max_fetch_rows = 10000 start_time = time.perf_counter() @@ -148,14 +148,10 @@ async def _run_query_async(self, query: str, keep_insert_mode: bool) -> None: ) elapsed_ms = (time.perf_counter() - start_time) * 1000 - # Save to history after successful execution service._save_to_history(config.name, query) - # Update UI (we're back on main thread after await) if isinstance(result, QueryResult): - self._display_query_results( - result.columns, result.rows, result.row_count, result.truncated, elapsed_ms - ) + self._display_query_results(result.columns, result.rows, result.row_count, result.truncated, elapsed_ms) else: self._display_non_query_result(result.rows_affected, elapsed_ms) @@ -163,22 +159,19 @@ async def _run_query_async(self, query: str, keep_insert_mode: bool) -> None: self._restore_insert_mode() except RuntimeError as e: - # Query was cancelled if "cancelled" in str(e).lower(): pass # Already handled by action_cancel_query else: self._display_query_error(str(e)) except Exception as e: - # Don't show error if query was cancelled if not cancellable.is_cancelled: self._display_query_error(str(e)) finally: self._cancellable_query = None - # Always stop the spinner when done self._stop_query_spinner() def _display_query_results( - self, columns: list[str], rows: list[tuple], row_count: int, truncated: bool, elapsed_ms: float + self: AppProtocol, columns: list[str], rows: list[tuple], row_count: int, truncated: bool, elapsed_ms: float ) -> None: """Display query results in the results table (called on main thread).""" self._last_result_columns = columns @@ -186,42 +179,45 @@ def _display_query_results( self._last_result_row_count = row_count self.results_table.clear(columns=True) + self.results_table.show_header = True self.results_table.add_columns(*columns) for row in rows[:1000]: str_row = tuple(escape_markup(str(v)) if v is not None else "NULL" for v in row) self.results_table.add_row(*str_row) - time_str = f"{elapsed_ms:.0f}ms" if elapsed_ms >= 1 else f"{elapsed_ms:.2f}ms" + time_str = format_duration_ms(elapsed_ms) if truncated: self.notify(f"Query returned {row_count}+ rows in {time_str} (truncated)", severity="warning") else: self.notify(f"Query returned {row_count} rows in {time_str}") - def _display_non_query_result(self, affected: int, elapsed_ms: float) -> None: + def _display_non_query_result(self: AppProtocol, affected: int, elapsed_ms: float) -> None: """Display non-query result (called on main thread).""" self._last_result_columns = ["Result"] self._last_result_rows = [(f"{affected} row(s) affected",)] self._last_result_row_count = 1 self.results_table.clear(columns=True) + self.results_table.show_header = True self.results_table.add_column("Result") self.results_table.add_row(f"{affected} row(s) affected") - time_str = f"{elapsed_ms:.0f}ms" if elapsed_ms >= 1 else f"{elapsed_ms:.2f}ms" + time_str = format_duration_ms(elapsed_ms) self.notify(f"Query executed: {affected} row(s) affected in {time_str}") - def _display_query_error(self, error_message: str) -> None: + def _display_query_error(self: AppProtocol, error_message: str) -> None: """Display query error (called on main thread).""" self._last_result_columns = ["Error"] self._last_result_rows = [(error_message,)] self._last_result_row_count = 1 self.results_table.clear(columns=True) + self.results_table.show_header = True self.results_table.add_column("Error") self.results_table.add_row(escape_markup(error_message)) self.notify(f"Query error: {error_message}", severity="error") - def _restore_insert_mode(self) -> None: + def _restore_insert_mode(self: AppProtocol) -> None: """Restore INSERT mode after query execution (called on main thread).""" from ...widgets import VimMode @@ -231,31 +227,29 @@ def _restore_insert_mode(self) -> None: self._update_footer_bindings() self._update_status_bar() - def action_cancel_query(self) -> None: + def action_cancel_query(self: AppProtocol) -> None: """Cancel the currently running query.""" if not getattr(self, "_query_executing", False): self.notify("No query running") return - # Cancel the cancellable query (closes dedicated connection) if hasattr(self, "_cancellable_query") and self._cancellable_query is not None: self._cancellable_query.cancel() - # Also cancel the worker if hasattr(self, "_query_worker") and self._query_worker is not None: self._query_worker.cancel() self._query_worker = None self._stop_query_spinner() - # Update results table to show cancelled state self.results_table.clear(columns=True) + self.results_table.show_header = True self.results_table.add_column("Status") self.results_table.add_row("Query cancelled") self.notify("Query cancelled", severity="warning") - def action_cancel_operation(self) -> None: + def action_cancel_operation(self: AppProtocol) -> None: """Cancel any running operation (query or schema indexing).""" cancelled = False @@ -272,6 +266,7 @@ def action_cancel_operation(self) -> None: # Update results table to show cancelled state self.results_table.clear(columns=True) + self.results_table.show_header = True self.results_table.add_column("Status") self.results_table.add_row("Query cancelled") cancelled = True @@ -289,19 +284,20 @@ def action_cancel_operation(self) -> None: else: self.notify("No operation running") - def action_clear_query(self) -> None: + def action_clear_query(self: AppProtocol) -> None: """Clear the query input.""" self.query_input.text = "" - def action_new_query(self) -> None: + def action_new_query(self: AppProtocol) -> None: """Start a new query (clear input and results).""" self.query_input.text = "" self.results_table.clear(columns=True) + self.results_table.show_header = False - def action_show_history(self) -> None: + def action_show_history(self: AppProtocol) -> None: """Show query history for the current connection.""" if not self.current_config: - self.notify("Not connected to a database", severity="warning") + self.notify("Not connected", severity="warning") return from ...config import load_query_history @@ -313,20 +309,42 @@ def action_show_history(self) -> None: self._handle_history_result, ) - def _handle_history_result(self, result) -> None: + def _handle_history_result(self: AppProtocol, result: Any) -> None: """Handle the result from the history screen.""" if result is None: return action, data = result if action == "select": + # Initialize cursor cache if needed + if self._query_cursor_cache is None: + self._query_cursor_cache = {} + + # Save current query's cursor position before switching + current_query = self.query_input.text + if current_query: + self._query_cursor_cache[current_query] = self.query_input.cursor_location + + # Set new query text self.query_input.text = data + + # Restore cursor position if we have it cached, otherwise go to end + if data in self._query_cursor_cache: + self.query_input.cursor_location = self._query_cursor_cache[data] + else: + # Move cursor to end of query + lines = data.split("\n") + last_line = len(lines) - 1 + last_col = len(lines[-1]) if lines else 0 + self.query_input.cursor_location = (last_line, last_col) elif action == "delete": self._delete_history_entry(data) self.action_show_history() - def _delete_history_entry(self, timestamp: str) -> None: + def _delete_history_entry(self: AppProtocol, timestamp: str) -> None: """Delete a specific history entry by timestamp.""" from ...config import delete_query_from_history + if not self.current_config: + return delete_query_from_history(self.current_config.name, timestamp) diff --git a/sqlit/ui/mixins/results.py b/sqlit/ui/mixins/results.py index e0624c45..0f696e5e 100644 --- a/sqlit/ui/mixins/results.py +++ b/sqlit/ui/mixins/results.py @@ -2,26 +2,15 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any - from textual.widgets import DataTable -if TYPE_CHECKING: - from ...config import ConnectionConfig +from ..protocols import AppProtocol class ResultsMixin: """Mixin providing results handling functionality.""" - # These attributes are defined in the main app class - current_connection: Any - current_config: "ConnectionConfig | None" - _last_result_columns: list[str] - _last_result_rows: list[tuple] - _last_result_row_count: int - _internal_clipboard: str - - def _copy_text(self, text: str) -> bool: + def _copy_text(self: AppProtocol, text: str) -> bool: """Copy text to clipboard if possible, otherwise store internally.""" self._internal_clipboard = text @@ -34,18 +23,20 @@ def _copy_text(self, text: str) -> bool: # Fallback to system clipboard via pyperclip (requires platform support). try: - import pyperclip # type: ignore + import pyperclip # noqa: F401 # type: ignore pyperclip.copy(text) return True except Exception: return False - def _flash_table_yank(self, table: DataTable, scope: str) -> None: + def _flash_table_yank(self: AppProtocol, table: DataTable, scope: str) -> None: """Briefly flash the yanked cell(s) to confirm a copy action.""" + from ...widgets import flash_widget + previous_cursor_type = getattr(table, "cursor_type", "cell") css_class = "flash-cell" - target_cursor_type = "cell" + target_cursor_type: str = "cell" if scope == "row": css_class = "flash-row" @@ -55,26 +46,21 @@ def _flash_table_yank(self, table: DataTable, scope: str) -> None: target_cursor_type = previous_cursor_type try: - table.cursor_type = target_cursor_type + table.cursor_type = target_cursor_type # type: ignore[assignment] except Exception: pass - table.add_class(css_class) - - def _clear() -> None: + def restore_cursor() -> None: try: - table.remove_class(css_class) - try: - table.cursor_type = previous_cursor_type - except Exception: - pass + table.cursor_type = previous_cursor_type # type: ignore[assignment] except Exception: pass - table.set_timer(0.15, _clear) + flash_widget(table, css_class, on_complete=restore_cursor) def _format_tsv(self, columns: list[str], rows: list[tuple]) -> str: """Format columns and rows as TSV.""" + def fmt(value: object) -> str: if value is None: return "NULL" @@ -87,7 +73,7 @@ def fmt(value: object) -> str: lines.append("\t".join(fmt(v) for v in row)) return "\n".join(lines) - def action_view_cell(self) -> None: + def action_view_cell(self: AppProtocol) -> None: """View the full value of the selected cell.""" from ..screens import ValueViewScreen @@ -99,13 +85,9 @@ def action_view_cell(self) -> None: value = table.get_cell_at(table.cursor_coordinate) except Exception: return - self.push_screen( - ValueViewScreen( - str(value) if value is not None else "NULL", title="Cell Value" - ) - ) + self.push_screen(ValueViewScreen(str(value) if value is not None else "NULL", title="Cell Value")) - def action_copy_cell(self) -> None: + def action_copy_cell(self: AppProtocol) -> None: """Copy the selected cell to clipboard (or internal clipboard).""" table = self.results_table if table.row_count <= 0: @@ -118,7 +100,7 @@ def action_copy_cell(self) -> None: self._copy_text(str(value) if value is not None else "NULL") self._flash_table_yank(table, "cell") - def action_copy_row(self) -> None: + def action_copy_row(self: AppProtocol) -> None: """Copy the selected row to clipboard (TSV).""" table = self.results_table if table.row_count <= 0: @@ -133,7 +115,7 @@ def action_copy_row(self) -> None: self._copy_text(text) self._flash_table_yank(table, "row") - def action_copy_results(self) -> None: + def action_copy_results(self: AppProtocol) -> None: """Copy the entire results (last query) to clipboard (TSV).""" if not self._last_result_columns and not self._last_result_rows: self.notify("No results", severity="warning") @@ -142,3 +124,130 @@ def action_copy_results(self) -> None: text = self._format_tsv(self._last_result_columns, self._last_result_rows) self._copy_text(text) self._flash_table_yank(self.results_table, "all") + + def action_results_cursor_left(self: AppProtocol) -> None: + """Move results cursor left (vim h).""" + if self.results_table.has_focus: + self.results_table.action_cursor_left() + + def action_results_cursor_down(self: AppProtocol) -> None: + """Move results cursor down (vim j).""" + if self.results_table.has_focus: + self.results_table.action_cursor_down() + + def action_results_cursor_up(self: AppProtocol) -> None: + """Move results cursor up (vim k).""" + if self.results_table.has_focus: + self.results_table.action_cursor_up() + + def action_results_cursor_right(self: AppProtocol) -> None: + """Move results cursor right (vim l).""" + if self.results_table.has_focus: + self.results_table.action_cursor_right() + + def action_clear_results(self: AppProtocol) -> None: + """Clear the results table.""" + self.results_table.clear(columns=True) + self.results_table.show_header = False + self._last_result_columns = [] + self._last_result_rows = [] + self._last_result_row_count = 0 + + def action_edit_cell(self: AppProtocol) -> None: + """Generate an UPDATE query for the selected cell and enter insert mode.""" + table = self.results_table + if table.row_count <= 0: + self.notify("No results", severity="warning") + return + + if not self._last_result_columns: + self.notify("No column info", severity="warning") + return + + try: + cursor_row, cursor_col = table.cursor_coordinate + value = table.get_cell_at(table.cursor_coordinate) + row_values = table.get_row_at(cursor_row) + except Exception: + return + + # Get column name + if cursor_col >= len(self._last_result_columns): + return + column_name = self._last_result_columns[cursor_col] + + # Check if this column is a primary key - don't allow editing PKs + if hasattr(self, "_last_query_table") and self._last_query_table: + for col in self._last_query_table.get("columns", []): + if col.name == column_name and col.is_primary_key: + self.notify("Cannot edit primary key column", severity="warning") + return + + # Format value for SQL + def sql_value(v: object) -> str: + if v is None: + return "NULL" + if isinstance(v, bool): + return "TRUE" if v else "FALSE" + if isinstance(v, int | float): + return str(v) + # String - escape single quotes + return "'" + str(v).replace("'", "''") + "'" + + # Get table name and primary key columns + table_name = "" + pk_column_names: set[str] = set() + + if hasattr(self, "_last_query_table") and self._last_query_table: + table_info = self._last_query_table + table_name = table_info["name"] + # Get PK columns from column info + for col in table_info.get("columns", []): + if col.is_primary_key: + pk_column_names.add(col.name) + + # Build WHERE clause - prefer PK columns, fall back to all columns + where_parts = [] + for i, col in enumerate(self._last_result_columns): + if i < len(row_values): + # If we have PK info, only use PK columns; otherwise use all columns + if pk_column_names and col not in pk_column_names: + continue + val = row_values[i] + if val is None: + where_parts.append(f"{col} IS NULL") + else: + where_parts.append(f"{col} = {sql_value(val)}") + + # If no where parts (no PKs matched result columns), fall back to all columns + if not where_parts: + for i, col in enumerate(self._last_result_columns): + if i < len(row_values): + val = row_values[i] + if val is None: + where_parts.append(f"{col} IS NULL") + else: + where_parts.append(f"{col} = {sql_value(val)}") + + where_clause = " AND ".join(where_parts) + + # Generate UPDATE query with empty placeholder for the new value + query = f"UPDATE {table_name} SET {column_name} = '' WHERE {where_clause};" + + # Find position inside the empty quotes (after "SET column = '") + set_prefix = f"SET {column_name} = '" + cursor_pos = query.find(set_prefix) + len(set_prefix) + + # Set query and switch to insert mode + self.query_input.text = query + self.query_input.focus() + + # Position cursor inside the empty quotes + self.query_input.cursor_location = (0, cursor_pos) + + # Enter insert mode + from ...widgets import VimMode + self.vim_mode = VimMode.INSERT + self.query_input.read_only = False + self._update_status_bar() + self._update_footer_bindings() diff --git a/sqlit/ui/mixins/tree.py b/sqlit/ui/mixins/tree.py index 9cec92a1..bd10a313 100644 --- a/sqlit/ui/mixins/tree.py +++ b/sqlit/ui/mixins/tree.py @@ -7,6 +7,7 @@ from rich.markup import escape as escape_markup from textual.widgets import Tree +from ..protocols import AppProtocol from ..tree_nodes import ( ColumnNode, ConnectionNode, @@ -20,22 +21,12 @@ ) if TYPE_CHECKING: - from ...config import ConnectionConfig - from ...services import ConnectionSession + pass class TreeMixin: """Mixin providing tree/explorer functionality.""" - # These attributes are defined in the main app class - connections: list - current_connection: Any - current_config: "ConnectionConfig | None" - current_adapter: Any - _expanded_paths: set[str] - _session: "ConnectionSession | None" - _loading_nodes: set - def _db_type_badge(self, db_type: str) -> str: """Get short badge for database type.""" badge_map = { @@ -51,7 +42,7 @@ def _db_type_badge(self, db_type: str) -> str: } return badge_map.get(db_type, db_type.upper() if db_type else "DB") - def refresh_tree(self) -> None: + def refresh_tree(self: AppProtocol) -> None: """Refresh the explorer tree.""" self.object_tree.clear() self.object_tree.root.expand() @@ -60,28 +51,35 @@ def refresh_tree(self) -> None: display_info = escape_markup(conn.get_display_info()) db_type_label = self._db_type_badge(conn.db_type) escaped_name = escape_markup(conn.name) - node = self.object_tree.root.add( - f"[dim]{escaped_name}[/dim] [{db_type_label}] ({display_info})" + # Check if this is the connected server + is_connected = ( + self.current_config is not None + and conn.name == self.current_config.name ) + if is_connected: + label = f"[#4ADE80]* {escaped_name}[/] [{db_type_label}] ({display_info})" + else: + label = f"[dim]{escaped_name}[/dim] [{db_type_label}] ({display_info})" + node = self.object_tree.root.add(label) node.data = ConnectionNode(config=conn) node.allow_expand = True if self.current_connection and self.current_config: self.populate_connected_tree() - def populate_connected_tree(self) -> None: + def populate_connected_tree(self: AppProtocol) -> None: """Populate tree with database objects when connected.""" if not self.current_connection or not self.current_config or not self.current_adapter: return adapter = self.current_adapter - def get_conn_label(config, connected=False): + def get_conn_label(config: Any, connected: Any = False) -> str: display_info = escape_markup(config.get_display_info()) db_type_label = self._db_type_badge(config.db_type) escaped_name = escape_markup(config.name) if connected: - name = f"[green]{escaped_name}[/green]" + name = f"[#4ADE80]* {escaped_name}[/]" else: name = escaped_name return f"{name} [{db_type_label}] ({display_info})" @@ -95,9 +93,7 @@ def get_conn_label(config, connected=False): break if not active_node: - active_node = self.object_tree.root.add( - get_conn_label(self.current_config, connected=True) - ) + active_node = self.object_tree.root.add(get_conn_label(self.current_config, connected=True)) active_node.data = ConnectionNode(config=self.current_config) active_node.remove_children() @@ -130,7 +126,7 @@ def get_conn_label(config, connected=False): except Exception as e: self.notify(f"Error loading objects: {e}", severity="error") - def _add_database_object_nodes(self, parent_node, database: str | None) -> None: + def _add_database_object_nodes(self: AppProtocol, parent_node: Any, database: str | None) -> None: """Add Tables, Views, and optionally Stored Procedures nodes.""" tables_node = parent_node.add("Tables") tables_node.data = FolderNode(folder_type="tables", database=database) @@ -145,7 +141,7 @@ def _add_database_object_nodes(self, parent_node, database: str | None) -> None: procs_node.data = FolderNode(folder_type="procedures", database=database) procs_node.allow_expand = True - def _get_node_path(self, node) -> str: + def _get_node_path(self, node: Any) -> str: """Get a unique path string for a tree node.""" parts = [] current = node @@ -159,13 +155,13 @@ def _get_node_path(self, node) -> str: parts.append(f"folder:{data.folder_type}") elif isinstance(data, SchemaNode): parts.append(f"schema:{data.schema}") - elif isinstance(data, (TableNode, ViewNode)): + elif isinstance(data, TableNode | ViewNode): node_type = "table" if isinstance(data, TableNode) else "view" parts.append(f"{node_type}:{data.schema}.{data.name}") current = current.parent return "/".join(reversed(parts)) - def _restore_subtree_expansion(self, node) -> None: + def _restore_subtree_expansion(self: AppProtocol, node: Any) -> None: """Recursively expand nodes that should be expanded.""" for child in node.children: if child.data: @@ -174,13 +170,13 @@ def _restore_subtree_expansion(self, node) -> None: child.expand() self._restore_subtree_expansion(child) - def _save_expanded_state(self) -> None: + def _save_expanded_state(self: AppProtocol) -> None: """Save which nodes are expanded.""" from ...config import load_settings, save_settings expanded = [] - def collect_expanded(node): + def collect_expanded(node: Any) -> None: if node.is_expanded and node.data: path = self._get_node_path(node) if path: @@ -195,11 +191,11 @@ def collect_expanded(node): settings["expanded_nodes"] = expanded save_settings(settings) - def on_tree_node_collapsed(self, event: Tree.NodeCollapsed) -> None: + def on_tree_node_collapsed(self: AppProtocol, event: Tree.NodeCollapsed) -> None: """Save state when a node is collapsed.""" self.call_later(self._save_expanded_state) - def on_tree_node_expanded(self, event: Tree.NodeExpanded) -> None: + def on_tree_node_expanded(self: AppProtocol, event: Tree.NodeExpanded) -> None: """Load child objects when a node is expanded.""" node = event.node @@ -229,7 +225,7 @@ def on_tree_node_expanded(self, event: Tree.NodeExpanded) -> None: return # Already loading this node # Handle table/view column expansion - if isinstance(data, (TableNode, ViewNode)): + if isinstance(data, TableNode | ViewNode): self._loading_nodes.add(node_path) loading_node = node.add_leaf("[dim italic]Loading...[/]") loading_node.data = LoadingNode() @@ -244,7 +240,7 @@ def on_tree_node_expanded(self, event: Tree.NodeExpanded) -> None: self._load_folder_async(node, data) return - def _load_columns_async(self, node, data: TableNode | ViewNode) -> None: + def _load_columns_async(self: AppProtocol, node: Any, data: TableNode | ViewNode) -> None: """Spawn worker to load columns for a table/view.""" db_name = data.database schema_name = data.schema @@ -267,7 +263,9 @@ def work() -> None: self.run_worker(work, name=f"load-columns-{obj_name}", thread=True, exclusive=False) - def _on_columns_loaded(self, node, db_name: str | None, schema_name: str, obj_name: str, columns: list) -> None: + def _on_columns_loaded( + self: AppProtocol, node: Any, db_name: str | None, schema_name: str, obj_name: str, columns: list + ) -> None: """Handle column load completion on main thread.""" node_path = self._get_node_path(node) self._loading_nodes.discard(node_path) @@ -282,7 +280,7 @@ def _on_columns_loaded(self, node, db_name: str | None, schema_name: str, obj_na child = node.add_leaf(f"[dim]{col_name}[/] [italic dim]{col_type}[/]") child.data = ColumnNode(database=db_name, schema=schema_name, table=obj_name, name=col.name) - def _load_folder_async(self, node, data: FolderNode) -> None: + def _load_folder_async(self: AppProtocol, node: Any, data: FolderNode) -> None: """Spawn worker to load folder contents (tables/views/procedures).""" folder_type = data.folder_type db_name = data.database @@ -302,7 +300,7 @@ def work() -> None: items = [("view", s, v) for s, v in adapter.get_views(conn, db_name)] elif folder_type == "procedures": if adapter.supports_stored_procedures: - items = [("procedure", p) for p in adapter.get_procedures(conn, db_name)] + items = [("procedure", "", p) for p in adapter.get_procedures(conn, db_name)] else: items = [] else: @@ -315,7 +313,7 @@ def work() -> None: self.run_worker(work, name=f"load-folder-{folder_type}", thread=True, exclusive=False) - def _on_folder_loaded(self, node, db_name: str | None, folder_type: str, items: list) -> None: + def _on_folder_loaded(self: AppProtocol, node: Any, db_name: str | None, folder_type: str, items: list) -> None: """Handle folder load completion on main thread.""" node_path = self._get_node_path(node) self._loading_nodes.discard(node_path) @@ -334,15 +332,15 @@ def _on_folder_loaded(self, node, db_name: str | None, folder_type: str, items: else: for item in items: if item[0] == "procedure": - child = node.add(escape_markup(item[1])) - child.data = ProcedureNode(database=db_name, name=item[1]) + child = node.add(escape_markup(item[2])) + child.data = ProcedureNode(database=db_name, name=item[2]) def _add_schema_grouped_items( self, - node, + node: Any, db_name: str | None, folder_type: str, - items: list, + items: list[Any], default_schema: str, ) -> None: """Add tables/views grouped by schema.""" @@ -359,7 +357,7 @@ def schema_sort_key(schema: str) -> tuple[int, str]: sorted_schemas = sorted(by_schema.keys(), key=schema_sort_key) has_multiple_schemas = len(sorted_schemas) > 1 - schema_nodes: dict[str, any] = {} + schema_nodes: dict[str, Any] = {} for schema in sorted_schemas: schema_items = by_schema[schema] @@ -372,7 +370,9 @@ def schema_sort_key(schema: str) -> tuple[int, str]: display_name = schema if schema else default_schema escaped_name = escape_markup(display_name) schema_node = node.add(f"[dim]\\[{escaped_name}][/]") - schema_node.data = SchemaNode(database=db_name, schema=schema or default_schema, folder_type=folder_type) + schema_node.data = SchemaNode( + database=db_name, schema=schema or default_schema, folder_type=folder_type + ) schema_node.allow_expand = True schema_nodes[schema] = schema_node parent = schema_nodes[schema] @@ -386,7 +386,7 @@ def schema_sort_key(schema: str) -> tuple[int, str]: child.data = ViewNode(database=db_name, schema=schema_name, name=obj_name) child.allow_expand = True - def _on_tree_load_error(self, node, error_message: str) -> None: + def _on_tree_load_error(self: AppProtocol, node: Any, error_message: str) -> None: """Handle tree load error on main thread.""" node_path = self._get_node_path(node) self._loading_nodes.discard(node_path) @@ -397,7 +397,7 @@ def _on_tree_load_error(self, node, error_message: str) -> None: self.notify(escape_markup(error_message), severity="error") - def on_tree_node_selected(self, event: Tree.NodeSelected) -> None: + def on_tree_node_selected(self: AppProtocol, event: Tree.NodeSelected) -> None: """Handle tree node selection (double-click/enter).""" node = event.node if not node.data: @@ -413,18 +413,19 @@ def on_tree_node_selected(self, event: Tree.NodeSelected) -> None: self._disconnect_silent() self.connect_to_server(config) - def on_tree_node_highlighted(self, event: Tree.NodeHighlighted) -> None: + def on_tree_node_highlighted(self: AppProtocol, event: Tree.NodeHighlighted) -> None: """Update footer when tree selection changes.""" self._update_footer_bindings() - def action_refresh_tree(self) -> None: + def action_refresh_tree(self: AppProtocol) -> None: """Refresh the explorer.""" self.refresh_tree() self.notify("Refreshed") - def action_collapse_tree(self) -> None: + def action_collapse_tree(self: AppProtocol) -> None: """Collapse all nodes in the explorer.""" - def collapse_all(node): + + def collapse_all(node: Any) -> None: for child in node.children: collapse_all(child) child.collapse() @@ -433,9 +434,19 @@ def collapse_all(node): self._expanded_paths.clear() self._save_expanded_state() - def action_select_table(self) -> None: + def action_tree_cursor_down(self: AppProtocol) -> None: + """Move tree cursor down (vim j).""" + if self.object_tree.has_focus: + self.object_tree.action_cursor_down() + + def action_tree_cursor_up(self: AppProtocol) -> None: + """Move tree cursor up (vim k).""" + if self.object_tree.has_focus: + self.object_tree.action_cursor_up() + + def action_select_table(self: AppProtocol) -> None: """Generate and execute SELECT query for selected table/view.""" - if not self.current_adapter: + if not self.current_adapter or not self._session: return node = self.object_tree.cursor_node @@ -444,8 +455,22 @@ def action_select_table(self) -> None: return data = node.data - if not isinstance(data, (TableNode, ViewNode)): + if not isinstance(data, TableNode | ViewNode): return + # Store table info for edit_cell action + try: + columns = self._session.adapter.get_columns( + self._session.connection, data.name, data.database, data.schema + ) + self._last_query_table = { + "database": data.database, + "schema": data.schema, + "name": data.name, + "columns": columns, + } + except Exception: + self._last_query_table = None + self.query_input.text = self.current_adapter.build_select_query(data.name, 100, data.database, data.schema) self.action_execute_query() diff --git a/sqlit/ui/mixins/tree_filter.py b/sqlit/ui/mixins/tree_filter.py new file mode 100644 index 00000000..2c235ce7 --- /dev/null +++ b/sqlit/ui/mixins/tree_filter.py @@ -0,0 +1,313 @@ +"""Tree filter mixin for SSMSTUI.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from rich.markup import escape as escape_markup + +from ...utils import fuzzy_match, highlight_matches +from ..protocols import AppProtocol +from ..tree_nodes import ( + ColumnNode, + ConnectionNode, + DatabaseNode, + FolderNode, + ProcedureNode, + SchemaNode, + TableNode, + ViewNode, +) + +if TYPE_CHECKING: + pass + + +class TreeFilterMixin: + """Mixin providing tree filter functionality.""" + + _tree_filter_visible: bool = False + _tree_filter_text: str = "" + _tree_filter_matches: list[Any] = [] + _tree_filter_match_index: int = 0 + _tree_original_labels: dict[int, str] = {} + + def action_tree_filter(self: AppProtocol) -> None: + """Open the tree filter.""" + if not self.object_tree.has_focus: + self.object_tree.focus() + + self._tree_filter_visible = True + self._tree_filter_text = "" + self._tree_filter_matches = [] + self._tree_filter_match_index = 0 + self._tree_original_labels = {} + + self.tree_filter_input.show() + self._update_tree_filter() + self._update_footer_bindings() + + def action_tree_filter_close(self: AppProtocol) -> None: + """Close the tree filter and restore tree.""" + self._tree_filter_visible = False + self._tree_filter_text = "" + self.tree_filter_input.hide() + self._restore_tree_labels() + self._show_all_tree_nodes() + self._update_footer_bindings() + + def action_tree_filter_accept(self: AppProtocol) -> None: + """Accept current filter selection and close.""" + self.action_tree_filter_close() + + def action_tree_filter_next(self: AppProtocol) -> None: + """Move to next filter match.""" + if not self._tree_filter_matches: + return + self._tree_filter_match_index = (self._tree_filter_match_index + 1) % len( + self._tree_filter_matches + ) + self._jump_to_current_match() + + def action_tree_filter_prev(self: AppProtocol) -> None: + """Move to previous filter match.""" + if not self._tree_filter_matches: + return + self._tree_filter_match_index = (self._tree_filter_match_index - 1) % len( + self._tree_filter_matches + ) + self._jump_to_current_match() + + def _jump_to_current_match(self: AppProtocol) -> None: + """Jump to the current match in the tree.""" + if not self._tree_filter_matches: + return + node = self._tree_filter_matches[self._tree_filter_match_index] + # Expand ancestors to make node visible + self._expand_ancestors(node) + # Select the node + self.object_tree.select_node(node) + + def _expand_ancestors(self: AppProtocol, node: Any) -> None: + """Expand all ancestor nodes to make a node visible.""" + ancestors = [] + current = node.parent + while current and current != self.object_tree.root: + ancestors.append(current) + current = current.parent + # Expand from root down + for ancestor in reversed(ancestors): + ancestor.expand() + + def on_key(self: AppProtocol, event: Any) -> None: + """Handle key events when tree filter is active.""" + if not self._tree_filter_visible: + # Pass to next mixin in chain (e.g., AutocompleteMixin) + super().on_key(event) # type: ignore[misc] + return + + key = event.key + + # Handle backspace + if key == "backspace": + if self._tree_filter_text: + self._tree_filter_text = self._tree_filter_text[:-1] + self._update_tree_filter() + else: + # Exit filter when backspacing with no text + self.action_tree_filter_close() + event.prevent_default() + event.stop() + return + + # Handle printable characters + if len(key) == 1 and key.isprintable(): + self._tree_filter_text += key + self._update_tree_filter() + event.prevent_default() + event.stop() + return + + # Pass unhandled keys to next mixin + super().on_key(event) # type: ignore[misc] + + def _update_tree_filter(self: AppProtocol) -> None: + """Update the tree based on current filter text.""" + self._restore_tree_labels() + total = self._count_all_nodes() + + if not self._tree_filter_text: + self._show_all_tree_nodes() + self._tree_filter_matches = [] + self.tree_filter_input.set_filter("", 0, total) + return + + # Find all matching nodes + matches: list[Any] = [] + self._find_matching_nodes(self.object_tree.root, matches) + + self._tree_filter_matches = matches + self._tree_filter_match_index = 0 + + # Hide non-matching nodes and highlight matches + self._apply_filter_to_tree() + + # Update filter display + self.tree_filter_input.set_filter( + self._tree_filter_text, len(matches), total + ) + + # Jump to first match + if matches: + self._jump_to_current_match() + + def _find_matching_nodes( + self: AppProtocol, node: Any, matches: list + ) -> bool: + """Recursively find nodes matching the filter. + + Returns True if this node or any descendant matches. + """ + node_matches = False + has_matching_child = False + + # Check children first + for child in node.children: + if self._find_matching_nodes(child, matches): + has_matching_child = True + + # Get node label text for matching + label_text = self._get_node_label_text(node) + if label_text: + matched, indices = fuzzy_match(self._tree_filter_text, label_text) + if matched: + node_matches = True + matches.append(node) + # Store original label and apply highlighting + self._tree_original_labels[id(node)] = str(node.label) + highlighted = highlight_matches( + escape_markup(label_text), indices, style="bold #FFFF00" + ) + # Preserve any existing markup prefix (like icons, colors) + node.set_label(self._rebuild_label_with_highlight(node, highlighted)) + + return node_matches or has_matching_child + + def _get_node_label_text(self, node: Any) -> str: + """Get the plain text label for a node.""" + data = node.data + if data is None: + return "" + + if isinstance(data, ConnectionNode): + return data.config.name + elif isinstance(data, DatabaseNode): + return data.name + elif isinstance(data, FolderNode): + return data.folder_type + elif isinstance(data, SchemaNode): + return data.schema + elif isinstance(data, TableNode): + return data.name + elif isinstance(data, ViewNode): + return data.name + elif isinstance(data, ProcedureNode): + return data.name + elif isinstance(data, ColumnNode): + return data.name + return "" + + def _rebuild_label_with_highlight(self, node: Any, highlighted_text: str) -> str: + """Rebuild the node label with highlighted text.""" + data = node.data + if data is None: + return highlighted_text + + # For simple nodes, just return highlighted text + if isinstance(data, TableNode | ViewNode | ProcedureNode): + return highlighted_text + elif isinstance(data, ColumnNode): + # Columns show "name type" format - only highlight name + return highlighted_text + elif isinstance(data, ConnectionNode): + # Connections have format "[color]* name[/] [TYPE] (info)" + # Just replace the name portion + return highlighted_text + elif isinstance(data, DatabaseNode | SchemaNode | FolderNode): + return highlighted_text + + return highlighted_text + + def _apply_filter_to_tree(self: AppProtocol) -> None: + """Hide nodes that don't match and aren't ancestors of matches.""" + match_ids = {id(n) for n in self._tree_filter_matches} + ancestor_ids = set() + + # Collect all ancestor IDs + for node in self._tree_filter_matches: + current = node.parent + while current and current != self.object_tree.root: + ancestor_ids.add(id(current)) + current = current.parent + + # Hide non-matching, non-ancestor nodes + self._set_node_visibility( + self.object_tree.root, match_ids, ancestor_ids, visible=True + ) + + def _set_node_visibility( + self: AppProtocol, + node: Any, + match_ids: set, + ancestor_ids: set, + visible: bool, + ) -> None: + """Recursively set node visibility.""" + for child in node.children: + child_id = id(child) + is_match = child_id in match_ids + is_ancestor = child_id in ancestor_ids + should_show = is_match or is_ancestor or not self._tree_filter_text + + # Use display style to hide/show + # Note: Textual Tree doesn't have per-node visibility, + # so we'll dim non-matching nodes instead + if not should_show and self._tree_filter_text: + # Dim non-matching nodes + original = self._tree_original_labels.get(child_id, str(child.label)) + if child_id not in self._tree_original_labels: + self._tree_original_labels[child_id] = original + child.set_label(f"[dim]{escape_markup(self._get_node_label_text(child))}[/]") + + self._set_node_visibility(child, match_ids, ancestor_ids, should_show) + + def _show_all_tree_nodes(self: AppProtocol) -> None: + """Show all tree nodes (remove filter dimming).""" + # Labels are restored by _restore_tree_labels + pass + + def _restore_tree_labels(self: AppProtocol) -> None: + """Restore original labels for all modified nodes.""" + def restore_node(node: Any) -> None: + node_id = id(node) + if node_id in self._tree_original_labels: + node.set_label(self._tree_original_labels[node_id]) + for child in node.children: + restore_node(child) + + restore_node(self.object_tree.root) + self._tree_original_labels = {} + + def _count_all_nodes(self: AppProtocol) -> int: + """Count all searchable nodes in the tree.""" + count = 0 + + def count_nodes(node: Any) -> None: + nonlocal count + if node.data and self._get_node_label_text(node): + count += 1 + for child in node.children: + count_nodes(child) + + count_nodes(self.object_tree.root) + return count diff --git a/sqlit/ui/mixins/ui_navigation.py b/sqlit/ui/mixins/ui_navigation.py index 3549cb98..87aa1f62 100644 --- a/sqlit/ui/mixins/ui_navigation.py +++ b/sqlit/ui/mixins/ui_navigation.py @@ -4,29 +4,20 @@ from typing import TYPE_CHECKING, Any -from textual.widgets import DataTable, Static, TextArea, Tree +from textual.timer import Timer +from ..protocols import AppProtocol if TYPE_CHECKING: - from ...widgets import VimMode + pass class UINavigationMixin: """Mixin providing UI navigation and vim mode functionality.""" - # These attributes are defined in the main app class - vim_mode: "VimMode" - current_connection: Any - current_config: Any - _fullscreen_mode: str - _last_notification: str - _last_notification_severity: str - _last_notification_time: str - _notification_timer: Any - _notification_history: list - _leader_timer: Any - _leader_pending: bool - - def _set_fullscreen_mode(self, mode: str) -> None: + _notification_timer: Timer | None = None + _leader_timer: Timer | None = None + + def _set_fullscreen_mode(self: AppProtocol, mode: str) -> None: """Set fullscreen mode: none|explorer|query|results.""" self._fullscreen_mode = mode self.screen.remove_class("results-fullscreen") @@ -40,12 +31,12 @@ def _set_fullscreen_mode(self, mode: str) -> None: elif mode == "explorer": self.screen.add_class("explorer-fullscreen") - def _update_section_labels(self) -> None: + def _update_section_labels(self: AppProtocol) -> None: """Update section labels to highlight the active pane.""" try: - label_explorer = self.query_one("#label-explorer", Static) - label_query = self.query_one("#label-query", Static) - label_results = self.query_one("#label-results", Static) + pane_explorer = self.query_one("#sidebar") + pane_query = self.query_one("#query-area") + pane_results = self.query_one("#results-area") except Exception: return @@ -69,17 +60,66 @@ def _update_section_labels(self) -> None: # Only update labels if a pane is focused (don't clear when dialogs are open) if active_pane: - label_explorer.remove_class("active") - label_query.remove_class("active") - label_results.remove_class("active") - if active_pane == "explorer": - label_explorer.add_class("active") - elif active_pane == "query": - label_query.add_class("active") - elif active_pane == "results": - label_results.add_class("active") - - def action_focus_explorer(self) -> None: + self._last_active_pane = active_pane + + # Update active-pane class based on dialog state + # When dialog is open, remove active-pane class (border reverts to default) + # but title text will stay primary via explicit markup in _sync_active_pane_title + dialog_open = bool(getattr(self, "_dialog_open", False)) + pane_explorer.remove_class("active-pane") + pane_query.remove_class("active-pane") + pane_results.remove_class("active-pane") + + if not dialog_open: + last_active = getattr(self, "_last_active_pane", None) + if last_active == "explorer": + pane_explorer.add_class("active-pane") + elif last_active == "query": + pane_query.add_class("active-pane") + elif last_active == "results": + pane_results.add_class("active-pane") + + self._sync_active_pane_title() + + def _sync_active_pane_title(self: AppProtocol) -> None: + """Adjust pane title color when dialogs are open. + + Keybinding hints [e], [q], [r] are: + - White by default (inactive pane) + - Primary when pane is selected + - White when dialog is open (keybindings disabled) + + The pane title (Explorer, Query, Results) uses CSS border-title-color: + - $border (white) for inactive panes + - $primary for active pane (via .active-pane class) + """ + try: + pane_explorer = self.query_one("#sidebar") + pane_query = self.query_one("#query-area") + pane_results = self.query_one("#results-area") + except Exception: + return + + dialog_open = bool(getattr(self, "_dialog_open", False)) + active_pane = getattr(self, "_last_active_pane", None) + + def set_title(pane: Any, key: str, label: str, *, active: bool) -> None: + if active and dialog_open: + # Active pane with dialog: key matches border (disabled), title stays primary + # Border reverts to default (active-pane class removed) + pane.border_title = f"[$border]\\[{key}][/] [$primary]{label}[/]" + elif active: + # Active pane, no dialog: both key and title primary + pane.border_title = f"[$primary]\\[{key}] {label}[/]" + else: + # Inactive pane: key and title match border color via CSS + pane.border_title = f"\\[{key}] {label}" + + set_title(pane_explorer, "e", "Explorer", active=active_pane == "explorer") + set_title(pane_query, "q", "Query", active=active_pane == "query") + set_title(pane_results, "r", "Results", active=active_pane == "results") + + def action_focus_explorer(self: AppProtocol) -> None: """Focus the Explorer pane.""" if self._fullscreen_mode != "none": self._set_fullscreen_mode("none") @@ -92,7 +132,7 @@ def action_focus_explorer(self) -> None: if self.object_tree.root.children: self.object_tree.cursor_line = 0 - def action_focus_query(self) -> None: + def action_focus_query(self: AppProtocol) -> None: """Focus the Query pane (in NORMAL mode).""" from ...widgets import VimMode @@ -103,13 +143,17 @@ def action_focus_query(self) -> None: self.query_input.focus() self._update_status_bar() - def action_focus_results(self) -> None: + def action_focus_results(self: AppProtocol) -> None: """Focus the Results pane.""" if self._fullscreen_mode != "none": self._set_fullscreen_mode("none") - self.results_table.focus() + try: + self.results_table.focus() + except Exception: + # Results table may not exist yet (Lazy loading) + pass - def action_enter_insert_mode(self) -> None: + def action_enter_insert_mode(self: AppProtocol) -> None: """Enter INSERT mode for query editing.""" from ...widgets import VimMode @@ -119,7 +163,7 @@ def action_enter_insert_mode(self) -> None: self._update_status_bar() self._update_footer_bindings() - def action_exit_insert_mode(self) -> None: + def action_exit_insert_mode(self: AppProtocol) -> None: """Exit INSERT mode, return to NORMAL mode.""" from ...widgets import VimMode @@ -130,17 +174,23 @@ def action_exit_insert_mode(self) -> None: self._update_status_bar() self._update_footer_bindings() - def _update_status_bar(self) -> None: + def _update_status_bar(self: AppProtocol) -> None: """Update status bar with connection and vim mode info.""" from ...widgets import VimMode from .query import SPINNER_FRAMES - status = self.status_bar - if getattr(self, "_connection_failed", False): + try: + status = self.status_bar + except Exception: + return + # Hide connection info while query is executing + if getattr(self, "_query_executing", False): + conn_info = "" + elif getattr(self, "_connection_failed", False): conn_info = "[#ff6b6b]Connection failed[/]" elif self.current_config: display_info = self.current_config.get_display_info() - conn_info = f"[#90EE90]Connected to {self.current_config.name}[/] ({display_info})" + conn_info = f"[#4ADE80]Connected to {self.current_config.name}[/] ({display_info})" else: conn_info = "Not connected" @@ -157,18 +207,17 @@ def _update_status_bar(self) -> None: if getattr(self, "_query_executing", False): import time + from ...utils import format_duration_ms + spinner_idx = getattr(self, "_spinner_index", 0) spinner = SPINNER_FRAMES[spinner_idx % len(SPINNER_FRAMES)] start_time = getattr(self, "_query_start_time", None) if start_time: elapsed_ms = (time.perf_counter() - start_time) * 1000 - if elapsed_ms >= 1000: - elapsed_str = f"{elapsed_ms / 1000:.1f}s" - else: - elapsed_str = f"{elapsed_ms:.0f}ms" - status_parts.append(f"[bold yellow]{spinner} Executing [{elapsed_str}][/] [dim]z to cancel[/]") + elapsed_str = format_duration_ms(elapsed_ms, always_seconds=True) + status_parts.append(f"[bold yellow]{spinner} Executing [{elapsed_str}][/] [dim]^z to cancel[/]") else: - status_parts.append(f"[bold yellow]{spinner} Executing[/] [dim]z to cancel[/]") + status_parts.append(f"[bold yellow]{spinner} Executing[/] [dim]^z to cancel[/]") status_str = " ".join(status_parts) if status_str: @@ -190,11 +239,21 @@ def _update_status_bar(self) -> None: notification = getattr(self, "_last_notification", "") timestamp = getattr(self, "_last_notification_time", "") severity = getattr(self, "_last_notification_severity", "information") + launch_ms = getattr(self, "_launch_ms", None) + show_launch = ( + getattr(self, "_debug_mode", False) + and isinstance(launch_ms, (int, float)) + and not self.current_config + and not getattr(self, "_connection_failed", False) + ) + launch_str = f"[dim]Launched in {launch_ms:.0f}ms[/]" if show_launch else "" + launch_plain = f"Launched in {launch_ms:.0f}ms" if show_launch else "" if notification: # Normal/warning notifications on right side import re - left_plain = re.sub(r'\[.*?\]', '', left_content) + + left_plain = re.sub(r"\[.*?\]", "", left_content) time_prefix = f"[dim]{timestamp}[/] " if timestamp else "" if severity == "warning": @@ -214,16 +273,31 @@ def _update_status_bar(self) -> None: status.update(f"{left_content}{' ' * gap}{notif_str}") else: status.update(f"{left_content} {notif_str}") + elif launch_str: + import re + + left_plain = re.sub(r"\[.*?\]", "", left_content) + try: + total_width = self.size.width - 2 + except Exception: + total_width = 80 + + gap = total_width - len(left_plain) - len(launch_plain) + if gap > 2: + status.update(f"{left_content}{' ' * gap}{launch_str}") + else: + status.update(f"{left_content} {launch_str}") else: status.update(left_content) def notify( - self, + self: AppProtocol, message: str, *, title: str = "", severity: str = "information", - timeout: float = 3.0, + timeout: float | None = None, + markup: bool = True, ) -> None: """Show notification in status bar (takes over full bar temporarily). @@ -257,7 +331,7 @@ def notify( self._last_notification_time = timestamp self._update_status_bar() - def _show_error_in_results(self, message: str, timestamp: str) -> None: + def _show_error_in_results(self: AppProtocol, message: str, timestamp: str) -> None: """Display error message in the results table.""" import textwrap @@ -272,11 +346,12 @@ def _show_error_in_results(self, message: str, timestamp: str) -> None: self._last_result_row_count = 1 self.results_table.clear(columns=True) + self.results_table.show_header = True self.results_table.add_column("Error") self.results_table.add_row(wrapped) self._update_footer_bindings() - def action_toggle_explorer(self) -> None: + def action_toggle_explorer(self: AppProtocol) -> None: """Toggle the visibility of the explorer sidebar.""" if self._fullscreen_mode != "none": self._set_fullscreen_mode("none") @@ -292,7 +367,7 @@ def action_toggle_explorer(self) -> None: self.query_input.focus() self.screen.add_class("explorer-hidden") - def action_change_theme(self) -> None: + def action_change_theme(self: AppProtocol) -> None: """Open the theme selection dialog.""" from ..screens import ThemeScreen @@ -302,7 +377,7 @@ def on_theme_selected(theme: str | None) -> None: self.push_screen(ThemeScreen(self.theme), on_theme_selected) - def action_toggle_fullscreen(self) -> None: + def action_toggle_fullscreen(self: AppProtocol) -> None: """Toggle fullscreen for the currently focused pane.""" if self.object_tree.has_focus: target = "explorer" @@ -328,7 +403,7 @@ def action_toggle_fullscreen(self) -> None: self._update_section_labels() self._update_footer_bindings() - def _update_footer_bindings(self) -> None: + def _update_footer_bindings(self: AppProtocol) -> None: """Update footer with context-appropriate bindings from the state machine.""" from ...widgets import ContextFooter, KeyBinding @@ -339,23 +414,19 @@ def _update_footer_bindings(self) -> None: left_display, right_display = self._state_machine.get_display_bindings(self) - left_bindings = [ - KeyBinding(b.key, b.label, b.action) for b in left_display - ] - right_bindings = [ - KeyBinding(b.key, b.label, b.action) for b in right_display - ] + left_bindings = [KeyBinding(b.key, b.label, b.action) for b in left_display] + right_bindings = [KeyBinding(b.key, b.label, b.action) for b in right_display] footer.set_bindings(left_bindings, right_bindings) - def action_show_help(self) -> None: + def action_show_help(self: AppProtocol) -> None: """Show help with all keybindings.""" from ..screens import HelpScreen help_text = self._state_machine.generate_help_text() self.push_screen(HelpScreen(help_text)) - def action_leader_key(self) -> None: + def action_leader_key(self: AppProtocol) -> None: """Handle leader key (space) press - show command menu after delay.""" from ...widgets import VimMode @@ -369,7 +440,7 @@ def action_leader_key(self) -> None: self._leader_pending = True - def show_menu(): + def show_menu() -> None: if getattr(self, "_leader_pending", False): self._leader_pending = False self._show_leader_menu() @@ -377,14 +448,14 @@ def show_menu(): # Show menu after 200ms delay self._leader_timer = self.set_timer(0.2, show_menu) - def _cancel_leader_pending(self) -> None: + def _cancel_leader_pending(self: AppProtocol) -> None: """Cancel leader pending state and timer.""" self._leader_pending = False if hasattr(self, "_leader_timer") and self._leader_timer is not None: self._leader_timer.stop() self._leader_timer = None - def _execute_leader_command(self, action: str) -> None: + def _execute_leader_command(self: AppProtocol, action: str) -> None: """Execute a leader command by action name. Also clears leader pending state - this is the single place @@ -398,7 +469,7 @@ def _execute_leader_command(self, action: str) -> None: if action_method: action_method() - def _show_leader_menu(self) -> None: + def _show_leader_menu(self: AppProtocol) -> None: """Display the leader menu.""" from textual.screen import ModalScreen @@ -409,47 +480,54 @@ def _show_leader_menu(self) -> None: self.push_screen(LeaderMenuScreen(), self._handle_leader_result) - def _handle_leader_result(self, result: str | None) -> None: + def _handle_leader_result(self: AppProtocol, result: str | None) -> None: """Handle result from leader menu.""" self._update_footer_bindings() if result: self._execute_leader_command(result) - def action_leader_toggle_explorer(self) -> None: + def action_leader_toggle_explorer(self: AppProtocol) -> None: self._execute_leader_command("toggle_explorer") - def action_leader_toggle_fullscreen(self) -> None: + def action_leader_toggle_fullscreen(self: AppProtocol) -> None: self._execute_leader_command("toggle_fullscreen") - def action_leader_show_connection_picker(self) -> None: + def action_leader_show_connection_picker(self: AppProtocol) -> None: self._execute_leader_command("show_connection_picker") - def action_leader_disconnect(self) -> None: + def action_leader_disconnect(self: AppProtocol) -> None: self._execute_leader_command("disconnect") - def action_leader_cancel_operation(self) -> None: + def action_leader_cancel_operation(self: AppProtocol) -> None: self._execute_leader_command("cancel_operation") - def action_leader_change_theme(self) -> None: + def action_leader_change_theme(self: AppProtocol) -> None: self._execute_leader_command("change_theme") - def action_leader_show_help(self) -> None: + def action_leader_show_help(self: AppProtocol) -> None: self._execute_leader_command("show_help") - def action_leader_quit(self) -> None: + def action_leader_quit(self: AppProtocol) -> None: self._execute_leader_command("quit") - def on_descendant_focus(self, event) -> None: + def on_descendant_focus(self: AppProtocol, event: Any) -> None: """Handle focus changes to update section labels and footer.""" from ...widgets import VimMode self._update_section_labels() - if not self.query_input.has_focus and self.vim_mode == VimMode.INSERT: + try: + has_query_focus = self.query_input.has_focus + except Exception: + has_query_focus = False + if not has_query_focus and self.vim_mode == VimMode.INSERT: self.vim_mode = VimMode.NORMAL - self.query_input.read_only = True + try: + self.query_input.read_only = True + except Exception: + pass self._update_footer_bindings() self._update_status_bar() - def on_descendant_blur(self, event) -> None: + def on_descendant_blur(self: AppProtocol, event: Any) -> None: """Handle blur to update section labels.""" self.call_later(self._update_section_labels) diff --git a/sqlit/ui/protocols.py b/sqlit/ui/protocols.py new file mode 100644 index 00000000..dbe9e57a --- /dev/null +++ b/sqlit/ui/protocols.py @@ -0,0 +1,525 @@ +"""Protocol definitions for mixin type safety.""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, overload + +if TYPE_CHECKING: + from textual.screen import Screen + from textual.timer import Timer + from textual.widget import Widget + from textual.widgets import DataTable, Static, TextArea, Tree + from textual.worker import Worker + + from ..config import ConnectionConfig + from ..db import DatabaseAdapter + from ..services import ConnectionSession, QueryService + from ..widgets import VimMode + +QueryType = TypeVar("QueryType", bound="Widget") + + +class AppProtocol(Protocol): + """Protocol defining what mixins expect from the App class. + + This protocol captures the interface that mixin classes depend on, + allowing proper type checking without creating inheritance conflicts. + Mixins should use `self: AppProtocol` in method signatures. + """ + + # === Textual App methods === + + def notify( + self, + message: str, + *, + title: str = "", + severity: str = "information", + timeout: float | None = None, + markup: bool = True, + ) -> None: + """Show notification.""" + ... + + def push_screen( + self, + screen: Screen[Any] | str, + callback: Callable[[Any], None] | Callable[[Any], Awaitable[None]] | None = None, + wait_for_dismiss: bool = False, + ) -> Any: + """Push a screen onto the screen stack.""" + ... + + def pop_screen(self) -> Any: + """Pop the current screen from the screen stack.""" + ... + + def run_worker( + self, + work: Any, + name: str | None = "", + group: str = "default", + description: str = "", + exit_on_error: bool = True, + start: bool = True, + exclusive: bool = False, + thread: bool = False, + ) -> Worker[Any]: + """Run work in a worker thread/task.""" + ... + + def call_later(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> bool: + """Schedule a callback to run later on the main thread.""" + ... + + def call_from_thread(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: + """Call a function from a worker thread on the main thread.""" + ... + + def set_timer( + self, + delay: float, + callback: Callable[[], None] | None = None, + *, + name: str | None = None, + pause: bool = False, + ) -> Timer: + """Set a timer to call a function after a delay.""" + ... + + def set_interval( + self, + interval: float, + callback: Callable[[], None] | None = None, + *, + name: str | None = None, + repeat: int = 0, + pause: bool = False, + ) -> Timer: + """Set an interval timer.""" + ... + + @overload + def query_one(self, selector: str) -> Widget: ... + + @overload + def query_one(self, selector: type[QueryType]) -> QueryType: ... + + @overload + def query_one(self, selector: str, expect_type: type[QueryType]) -> QueryType: ... + + def query_one(self, selector: Any, expect_type: Any = None) -> Any: + """Query for a single widget.""" + ... + + def copy_to_clipboard(self, text: str) -> None: + """Copy text to clipboard.""" + ... + + def exit(self, result: Any = None, return_code: int = 0, message: Any | None = None) -> None: + """Exit the application.""" + ... + + # === Screen-related attributes === + + @property + def screen(self) -> Screen[Any]: + """The current screen.""" + ... + + @property + def screen_stack(self) -> list[Screen[Any]]: + """The screen stack.""" + ... + + @property + def focused(self) -> Any: + """The currently focused widget.""" + ... + + @property + def size(self) -> Any: + """The size of the terminal.""" + ... + + @property + def theme(self) -> str: + """The current theme name.""" + ... + + @theme.setter + def theme(self, value: str) -> None: + """Set the theme.""" + ... + + # === SSMSTUI widget properties === + + @property + def object_tree(self) -> Tree[Any]: + """The explorer tree widget.""" + ... + + @property + def query_input(self) -> TextArea: + """The query input text area.""" + ... + + @property + def results_table(self) -> DataTable: + """The results table widget.""" + ... + + @property + def status_bar(self) -> Static: + """The status bar widget.""" + ... + + @property + def autocomplete_dropdown(self) -> Any: + """The autocomplete dropdown widget.""" + ... + + # === Connection state === + + connections: list[ConnectionConfig] + current_connection: Any + current_config: ConnectionConfig | None + current_adapter: DatabaseAdapter | None + current_ssh_tunnel: Any + + # === Session state === + + _session: ConnectionSession | None + _session_factory: Callable[[ConnectionConfig], ConnectionSession] | None + _connection_failed: bool + + # === Vim mode state === + + vim_mode: VimMode + + # === Tree state === + + _expanded_paths: set[str] + _loading_nodes: set[str] + + # === Query execution state === + + _query_worker: Worker[Any] | None + _query_executing: bool + _query_start_time: float + _spinner_index: int + _spinner_timer: Timer | None + _cancellable_query: Any + + # === Schema cache state === + + _schema_cache: dict[str, Any] + _schema_indexing: bool + _schema_worker: Worker[Any] | None + _schema_spinner_index: int + _schema_spinner_timer: Timer | None + _table_metadata: dict[str, tuple[str, str, str | None]] + _columns_loading: set[str] + + # === Autocomplete state === + + _autocomplete_filter: str + _autocomplete_just_applied: bool + _autocomplete_visible: bool + + # === Results state === + + _last_result_columns: list[str] + _last_result_rows: list[tuple[Any, ...]] + _last_result_row_count: int + _internal_clipboard: str + _last_query_table: dict[str, Any] | None + + # === UI state === + + _fullscreen_mode: str + _last_notification: str + _last_notification_severity: str + _last_notification_time: str + _notification_timer: Timer | None + _notification_history: list[tuple[str, str, str]] + _leader_timer: Timer | None + _leader_pending: bool + _state_machine: Any + _mock_profile: Any + + # === SSMSTUI methods that mixins call on each other === + + def refresh_tree(self) -> None: + """Refresh the explorer tree.""" + ... + + def populate_connected_tree(self) -> None: + """Populate tree with database objects when connected.""" + ... + + def _update_status_bar(self) -> None: + """Update status bar with connection and vim mode info.""" + ... + + def _update_footer_bindings(self) -> None: + """Update footer with context-appropriate bindings.""" + ... + + def _hide_autocomplete(self) -> None: + """Hide the autocomplete dropdown.""" + ... + + def _load_schema_cache(self) -> None: + """Load database schema for autocomplete asynchronously.""" + ... + + def _stop_schema_spinner(self) -> None: + """Stop the schema indexing spinner animation.""" + ... + + def _disconnect_silent(self) -> None: + """Disconnect from current database without notification.""" + ... + + def connect_to_server(self, config: ConnectionConfig) -> None: + """Connect to a database (async, non-blocking).""" + ... + + def action_execute_query(self) -> None: + """Execute the current query.""" + ... + + # === Tree Mixin methods === + + def _db_type_badge(self, db_type: str) -> str: + """Get short badge for database type.""" + ... + + def _add_database_object_nodes(self, node: Any, database: str | None) -> None: + """Add database object nodes (tables, views, etc.) to a folder.""" + ... + + def _restore_subtree_expansion(self, node: Any) -> None: + """Restore expansion state for a subtree.""" + ... + + def _get_node_path(self, node: Any) -> str: + """Get unique path for a node.""" + ... + + def _save_expanded_state(self) -> None: + """Save currently expanded nodes.""" + ... + + def _load_columns_async(self, node: Any, data: Any) -> None: + """Load columns for a table or view asynchronously.""" + ... + + def _load_folder_async(self, node: Any, data: Any) -> None: + """Load folder content asynchronously.""" + ... + + def _on_columns_loaded( + self, node: Any, db_name: str | None, schema_name: str, obj_name: str, columns: list[Any] + ) -> None: + """Handle columns loaded event.""" + ... + + def _on_tree_load_error(self, node: Any, error_message: str) -> None: + """Handle tree loading error.""" + ... + + def _on_folder_loaded(self, node: Any, db_name: str | None, folder_type: str, items: list[Any]) -> None: + """Handle folder content loaded event.""" + ... + + def _add_schema_grouped_items( + self, node: Any, db_name: str | None, folder_type: str, items: list[Any], default_schema: str + ) -> None: + """Add items grouped by schema.""" + ... + + # === Autocomplete Mixin methods === + + def _load_columns_for_table(self, table_name: str) -> None: + """Load columns for a table for autocomplete.""" + ... + + def _on_autocomplete_columns_loaded(self, table_name: str, actual_table_name: str, column_names: list[str]) -> None: + """Handle columns loaded for autocomplete.""" + ... + + def _location_to_offset(self, text: str, location: tuple[int, int]) -> int: + """Convert row/col location to string offset.""" + ... + + def _offset_to_location(self, text: str, offset: int) -> tuple[int, int]: + """Convert string offset to row/col location.""" + ... + + def _get_word_before_cursor(self, text: str, cursor_pos: int) -> tuple[str, str]: + """Get the word before the cursor.""" + ... + + def _get_autocomplete_suggestions(self, word: str, context: str) -> list[str]: + """Get autocomplete suggestions.""" + ... + + def _show_autocomplete(self, suggestions: list[str], filter_text: str) -> None: + """Show the autocomplete dropdown.""" + ... + + def _apply_autocomplete(self) -> None: + """Apply the selected autocomplete suggestion.""" + ... + + def _start_schema_spinner(self) -> None: + """Start the schema indexing spinner.""" + ... + + def _load_schema_cache_async(self) -> Awaitable[None]: + """Load schema cache asynchronously.""" + ... + + def _animate_schema_spinner(self) -> None: + """Animate the schema spinner.""" + ... + + def _update_schema_cache( + self, schema_cache: dict[str, Any], table_metadata: dict[str, tuple[str, str, str | None]] | None = None + ) -> None: + """Update the schema cache.""" + ... + + # === Results Mixin methods === + + def _copy_text(self, text: str) -> bool: + """Copy text to clipboard.""" + ... + + def _flash_table_yank(self, table: DataTable, scope: str) -> None: + """Flash the table to indicate copy.""" + ... + + def _format_tsv(self, columns: list[str], rows: list[tuple[Any, ...]]) -> str: + """Format results as TSV.""" + ... + + # === Query Mixin methods === + + @property + def _query_service(self) -> QueryService | None: + """The query execution service.""" + ... + + def _execute_query_common(self, keep_insert_mode: bool) -> None: + """Common logic for executing queries.""" + ... + + def _start_query_spinner(self) -> None: + """Start the query execution spinner.""" + ... + + def _run_query_async(self, query: str, keep_insert_mode: bool) -> Awaitable[None]: + """Run a query asynchronously.""" + ... + + def _animate_spinner(self) -> None: + """Animate the query spinner.""" + ... + + def _display_query_error(self, error_message: str) -> None: + """Display a query error.""" + ... + + def _stop_query_spinner(self) -> None: + """Stop the query spinner.""" + ... + + def _display_query_results( + self, columns: list[str], rows: list[tuple[Any, ...]], row_count: int, truncated: bool, elapsed_ms: float + ) -> None: + """Display query results in the table.""" + ... + + def _display_non_query_result(self, affected: int, elapsed_ms: float) -> None: + """Display non-query result (rows affected).""" + ... + + def _restore_insert_mode(self) -> None: + """Restore insert mode if it was active.""" + ... + + def _handle_history_result(self, result: Any) -> None: + """Handle result from history screen.""" + ... + + def _delete_history_entry(self, timestamp: str) -> None: + """Delete a history entry.""" + ... + + def action_show_history(self) -> None: + """Show query history.""" + ... + + # === Connection Mixin methods === + + def _handle_install_confirmation(self, confirmed: bool, error: Any) -> None: + """Handle driver installation confirmation.""" + ... + + def _set_connection_screen_footer(self) -> None: + """Set footer for connection screen.""" + ... + + def _wrap_connection_result(self, result: tuple[Any, ...] | None) -> None: + """Wrap connection result.""" + ... + + def call_next(self, *args: Any, **kwargs: Any) -> None: + """Call next handler (Installer pattern).""" + ... + + def handle_connection_result(self, result: tuple[Any, ...] | None) -> None: + """Handle connection result.""" + ... + + def _do_delete_connection(self, config: ConnectionConfig) -> None: + """Delete connection.""" + ... + + def _handle_connection_picker_result(self, result: str | None) -> None: + """Handle connection picker result.""" + ... + + # === UI Navigation Mixin methods === + + def _set_fullscreen_mode(self, mode: str) -> None: + """Set fullscreen mode.""" + ... + + def _update_section_labels(self) -> None: + """Update section labels.""" + ... + + def _show_error_in_results(self, message: str, timestamp: str) -> None: + """Show error in results table.""" + ... + + def _show_leader_menu(self) -> None: + """Show leader menu.""" + ... + + def _cancel_leader_pending(self) -> None: + """Cancel leader pending state.""" + ... + + def _handle_leader_result(self, result: str | None) -> None: + """Handle leader menu result.""" + ... + + def _execute_leader_command(self, action: str) -> None: + """Execute leader command.""" + ... diff --git a/sqlit/ui/screens/__init__.py b/sqlit/ui/screens/__init__.py index d353af47..6448bc01 100644 --- a/sqlit/ui/screens/__init__.py +++ b/sqlit/ui/screens/__init__.py @@ -1,15 +1,7 @@ """Modal screens for sqlit.""" -from .confirm import ConfirmScreen -from .connection import ConnectionScreen -from .connection_picker import ConnectionPickerScreen -from .driver_setup import DriverSetupScreen -from .error import ErrorScreen -from .help import HelpScreen -from .leader_menu import LeaderMenuScreen -from .query_history import QueryHistoryScreen -from .theme import ThemeScreen -from .value_view import ValueViewScreen +from importlib import import_module +from typing import TYPE_CHECKING, Any __all__ = [ "ConfirmScreen", @@ -19,7 +11,50 @@ "ErrorScreen", "HelpScreen", "LeaderMenuScreen", + "MessageScreen", + "PackageSetupScreen", + "PasswordInputScreen", "QueryHistoryScreen", "ThemeScreen", "ValueViewScreen", ] + +_LAZY_ATTRS: dict[str, tuple[str, str]] = { + "ConfirmScreen": ("sqlit.ui.screens.confirm", "ConfirmScreen"), + "ConnectionScreen": ("sqlit.ui.screens.connection", "ConnectionScreen"), + "ConnectionPickerScreen": ("sqlit.ui.screens.connection_picker", "ConnectionPickerScreen"), + "DriverSetupScreen": ("sqlit.ui.screens.driver_setup", "DriverSetupScreen"), + "ErrorScreen": ("sqlit.ui.screens.error", "ErrorScreen"), + "HelpScreen": ("sqlit.ui.screens.help", "HelpScreen"), + "LeaderMenuScreen": ("sqlit.ui.screens.leader_menu", "LeaderMenuScreen"), + "MessageScreen": ("sqlit.ui.screens.message", "MessageScreen"), + "PackageSetupScreen": ("sqlit.ui.screens.package_setup", "PackageSetupScreen"), + "PasswordInputScreen": ("sqlit.ui.screens.password_input", "PasswordInputScreen"), + "QueryHistoryScreen": ("sqlit.ui.screens.query_history", "QueryHistoryScreen"), + "ThemeScreen": ("sqlit.ui.screens.theme", "ThemeScreen"), + "ValueViewScreen": ("sqlit.ui.screens.value_view", "ValueViewScreen"), +} + +if TYPE_CHECKING: + from .confirm import ConfirmScreen + from .connection import ConnectionScreen + from .connection_picker import ConnectionPickerScreen + from .driver_setup import DriverSetupScreen + from .error import ErrorScreen + from .help import HelpScreen + from .leader_menu import LeaderMenuScreen + from .message import MessageScreen + from .package_setup import PackageSetupScreen + from .password_input import PasswordInputScreen + from .query_history import QueryHistoryScreen + from .theme import ThemeScreen + from .value_view import ValueViewScreen + + +def __getattr__(name: str) -> Any: + target = _LAZY_ATTRS.get(name) + if target is None: + raise AttributeError(name) + module_name, attr_name = target + module = import_module(module_name) + return getattr(module, attr_name) diff --git a/sqlit/ui/screens/confirm.py b/sqlit/ui/screens/confirm.py index d5e9a9e6..95a99368 100644 --- a/sqlit/ui/screens/confirm.py +++ b/sqlit/ui/screens/confirm.py @@ -5,7 +5,7 @@ from textual.app import ComposeResult from textual.binding import Binding from textual.screen import ModalScreen -from textual.widgets import OptionList +from textual.widgets import OptionList, Static from textual.widgets.option_list import Option from ...widgets import Dialog @@ -15,10 +15,10 @@ class ConfirmScreen(ModalScreen): """Modal screen for confirmation dialogs.""" BINDINGS = [ - Binding("y", "confirm", "Yes"), - Binding("n", "cancel", "No"), - Binding("escape", "cancel", "Cancel"), - Binding("enter", "select_option", "Select"), + Binding("y", "yes", "Yes", show=False), + Binding("n", "no", "No", show=False), + Binding("escape", "cancel", "Cancel", show=False), + Binding("enter", "select_option", "Select", show=False), ] CSS = """ @@ -31,6 +31,11 @@ class ConfirmScreen(ModalScreen): width: 36; } + #confirm-description { + margin-bottom: 1; + color: $text-muted; + } + #confirm-list { height: auto; border: none; @@ -41,16 +46,28 @@ class ConfirmScreen(ModalScreen): } """ - def __init__(self, title: str): + def __init__( + self, + title: str, + description: str | None = None, + *, + yes_label: str = "Yes", + no_label: str = "No", + ): super().__init__() self.title_text = title + self.description = description + self.yes_label = yes_label + self.no_label = no_label def compose(self) -> ComposeResult: - shortcuts = [("Yes", "Y"), ("No", "N"), ("Cancel", "")] + shortcuts: list[tuple[str, str]] = [("Yes", "y"), ("No", "n")] with Dialog(id="confirm-dialog", title=self.title_text, shortcuts=shortcuts): + if self.description: + yield Static(self.description, id="confirm-description") option_list = OptionList( - Option("Yes", id="yes"), - Option("No", id="no"), + Option(self.yes_label, id="yes"), + Option(self.no_label, id="no"), id="confirm-list", ) yield option_list @@ -58,18 +75,33 @@ def compose(self) -> ComposeResult: def on_mount(self) -> None: self.query_one("#confirm-list", OptionList).focus() - def on_option_list_option_selected(self, event) -> None: - self.dismiss(event.option.id == "yes") + def on_option_list_option_selected(self, event: OptionList.OptionSelected) -> None: + if event.option.id == "yes": + self.dismiss(True) + elif event.option.id == "no": + self.dismiss(False) def action_select_option(self) -> None: option_list = self.query_one("#confirm-list", OptionList) if option_list.highlighted is not None: - self.dismiss( - option_list.get_option_at_index(option_list.highlighted).id == "yes" - ) + option_id = option_list.get_option_at_index(option_list.highlighted).id + if option_id == "yes": + self.dismiss(True) + elif option_id == "no": + self.dismiss(False) - def action_confirm(self) -> None: + def action_yes(self) -> None: self.dismiss(True) - def action_cancel(self) -> None: + def action_no(self) -> None: self.dismiss(False) + + def action_cancel(self) -> None: + # Escape cancels without selecting Yes/No. + self.dismiss(None) + + def check_action(self, action: str, parameters: tuple) -> bool | None: + # Prevent underlying screens from receiving actions when another modal is on top. + if self.app.screen is not self: + return False + return super().check_action(action, parameters) diff --git a/sqlit/ui/screens/connection.py b/sqlit/ui/screens/connection.py index 92dfe730..7b716861 100644 --- a/sqlit/ui/screens/connection.py +++ b/sqlit/ui/screens/connection.py @@ -2,11 +2,21 @@ from __future__ import annotations +import json +import os +import tempfile +from pathlib import Path +from typing import Any + +from rich.markup import escape from textual.app import ComposeResult from textual.binding import Binding from textual.containers import Container, Horizontal +from textual.events import ScreenResume, ScreenSuspend from textual.screen import ModalScreen +from textual.timer import Timer from textual.widgets import ( + Button, Input, OptionList, Select, @@ -18,14 +28,24 @@ from textual.widgets.option_list import Option from ...config import ( - AUTH_TYPE_LABELS, - AuthType, ConnectionConfig, - DATABASE_TYPE_LABELS, DatabaseType, + get_database_type_labels, +) +from ...db import ( + create_ssh_tunnel, + get_adapter, + get_connection_schema, + has_advanced_auth, + is_file_based, + supports_ssh, +) +from ...fields import ( + FieldDefinition, + FieldGroup, + FieldType, + schema_to_field_definitions, ) -from ...db import create_ssh_tunnel, get_adapter, get_connection_schema, has_advanced_auth, is_file_based, supports_ssh -from ...fields import FieldDefinition, FieldGroup, FieldType, schema_to_field_definitions from ...validation import ValidationState, validate_connection_form from ...widgets import Dialog @@ -33,10 +53,15 @@ class ConnectionScreen(ModalScreen): """Modal screen for adding/editing a connection.""" + AUTO_FOCUS = "#conn-name" + + _INSTALL_SPINNER_FRAMES = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"] + BINDINGS = [ Binding("escape", "cancel", "Cancel"), Binding("ctrl+s", "save", "Save", priority=True), Binding("ctrl+t", "test_connection", "Test", priority=True), + Binding("ctrl+i", "install_driver", "Install driver", show=False, priority=True), Binding("tab", "next_field", "Next field", priority=True), Binding("shift+tab", "prev_field", "Previous field", priority=True), Binding("down", "focus_tab_content", "Focus content", show=False), @@ -50,13 +75,13 @@ class ConnectionScreen(ModalScreen): #connection-dialog { width: 62; - height: auto; + height: 85%; max-height: 85%; border: solid $primary; background: $surface; padding: 1; border-title-align: left; - border-title-color: $text-muted; + border-title-color: $primary; border-title-background: $surface; border-title-style: bold; border-subtitle-align: right; @@ -73,10 +98,18 @@ class ConnectionScreen(ModalScreen): margin-bottom: 0; } + #btn-odbc-setup { + width: auto; + border: solid $primary; + background: transparent; + color: $primary; + margin-top: 1; + } + .field-container { position: relative; height: auto; - border: solid $primary-darken-2; + border: solid $panel; background: $surface; padding: 0; margin-top: 0; @@ -92,14 +125,17 @@ class ConnectionScreen(ModalScreen): .field-container.invalid { border: solid $error; + border-title-color: $error; } .field-container.focused { border: solid $primary; + border-title-color: $primary; } .field-container.invalid.focused { border: solid $error; + border-title-color: $error; } .field-container Input { @@ -127,20 +163,21 @@ class ConnectionScreen(ModalScreen): } #connection-tabs { - height: auto; + height: 1fr; } TabbedContent { - height: auto; + height: 1fr; } TabbedContent > ContentSwitcher { - height: auto; + height: 1fr; } TabPane { - height: auto; + height: 1fr; min-height: 18; + overflow-y: auto; } Tab:disabled { @@ -206,6 +243,10 @@ class ConnectionScreen(ModalScreen): margin-top: 0; } + #test-status.success { + color: $success; + } + #test-error { height: 6; border: solid $primary-darken-2; @@ -218,30 +259,68 @@ class ConnectionScreen(ModalScreen): } """ - def __init__(self, config: ConnectionConfig | None = None, editing: bool = False): + def __init__( + self, + config: ConnectionConfig | None = None, + editing: bool = False, + *, + prefill_values: dict[str, Any] | None = None, + post_install_message: str | None = None, + ): super().__init__() self.config = config self.editing = editing - self._field_widgets: dict[str, Input | OptionList] = {} + self._prefill_values = prefill_values or {} + self._post_install_message = post_install_message + self._field_widgets: dict[str, Input | OptionList | Select[str]] = {} self._field_definitions: dict[str, FieldDefinition] = {} self._current_db_type: DatabaseType = self._get_initial_db_type() self._last_test_error: str = "" self._last_test_ok: bool | None = None self._focused_container_id: str | None = None self.validation_state: ValidationState = ValidationState() + self._saved_dialog_subtitle: str | None = None + self._missing_driver_error: Any = None # Stores MissingDriverError if driver is missing + self._install_in_progress: bool = False + self._install_spinner_timer: Timer | None = None + self._install_spinner_index: int = 0 + + def check_action(self, action: str, parameters: tuple) -> bool | None: + if self.app.screen is not self: + return False + return super().check_action(action, parameters) + + def on_screen_suspend(self, event: ScreenSuspend) -> None: + try: + dialog = self.query_one("#connection-dialog", Dialog) + self._saved_dialog_subtitle = dialog.border_subtitle + dialog.border_subtitle = "" + except Exception: + pass + + def on_screen_resume(self, event: ScreenResume) -> None: + try: + dialog = self.query_one("#connection-dialog", Dialog) + if self._saved_dialog_subtitle is not None: + dialog.border_subtitle = self._saved_dialog_subtitle + except Exception: + pass def _get_initial_db_type(self) -> DatabaseType: - """Get the initial database type from config.""" + prefill_db_type = self._prefill_values.get("db_type") + if isinstance(prefill_db_type, str) and prefill_db_type: + try: + return DatabaseType(prefill_db_type) + except Exception: + pass if self.config: return self.config.get_db_type() - return DatabaseType.MSSQL + return DatabaseType.MSSQL # type: ignore[attr-defined, no-any-return] - def _get_adapter_for_type(self, db_type: DatabaseType): - """Get the adapter instance for a database type.""" + def _get_adapter_for_type(self, db_type: DatabaseType) -> Any: return get_adapter(db_type.value) def _get_field_groups_for_type(self, db_type: DatabaseType, tab: str | None = None) -> list[FieldGroup]: - """Get field groups for a database type from the schema registry.""" schema = get_connection_schema(db_type.value) definitions = schema_to_field_definitions(schema) if tab: @@ -249,13 +328,11 @@ def _get_field_groups_for_type(self, db_type: DatabaseType, tab: str | None = No return [FieldGroup(name="connection", fields=definitions)] def _get_field_value(self, field_name: str) -> str: - """Get the current value of a field from config or default.""" if self.config and hasattr(self.config, field_name): return getattr(self.config, field_name) or "" return "" def _get_current_form_values(self) -> dict: - """Get all current form values as a dictionary.""" values = {} for name, widget in self._field_widgets.items(): if isinstance(widget, Input): @@ -275,14 +352,11 @@ def _get_current_form_values(self) -> dict: return values def _create_field_widget(self, field_def: FieldDefinition, group_name: str) -> ComposeResult: - """Create widgets for a field definition.""" field_id = f"field-{field_def.name}" container_id = f"container-{field_def.name}" - # Determine initial visibility initial_visible = True if field_def.visible_when: - # Use config values for initial visibility check initial_values = {} if self.config: for attr in ["auth_type", "driver", "server", "port", "database", "username", "password", "file_path"]: @@ -334,8 +408,6 @@ def _create_field_widget(self, field_def: FieldDefinition, group_name: str) -> C yield Static("", id=f"error-{field_def.name}", classes="error-text hidden") def _create_field_group(self, group: FieldGroup) -> ComposeResult: - """Create widgets for a field group.""" - # Group fields by row_group row_groups: dict[str | None, list[FieldDefinition]] = {} for field_def in group.fields: row_key = field_def.row_group @@ -346,22 +418,18 @@ def _create_field_group(self, group: FieldGroup) -> ComposeResult: with Container(classes="field-group"): for row_key, fields in row_groups.items(): if row_key is None: - # Single field, not in a row for field_def in fields: yield from self._create_field_widget(field_def, group.name) else: - # Multiple fields in a horizontal row with Horizontal(classes="field-row"): for field_def in fields: width_class = "field-flex" if field_def.width == "flex" else "field-fixed" with Container(classes=width_class): yield from self._create_field_widget(field_def, group.name) - def _split_groups_by_advanced( - self, groups: list[FieldGroup] - ) -> tuple[list[FieldGroup], list[FieldGroup]]: - general: list[FieldGroup] = [] - advanced: list[FieldGroup] = [] + def _split_groups_by_advanced(self, groups: list[FieldGroup]) -> tuple[list[FieldGroup], list[FieldGroup]]: + general = [] + advanced = [] for group in groups: general_fields = [f for f in group.fields if not f.advanced] advanced_fields = [f for f in group.fields if f.advanced] @@ -384,7 +452,6 @@ def _split_groups_by_advanced( return general, advanced def _set_advanced_tab_enabled(self, enabled: bool) -> None: - """Enable/disable the Advanced tab (disabled tabs are struck through).""" try: tabs = self.query_one("#connection-tabs", TabbedContent) advanced_pane = self.query_one("#tab-advanced", TabPane) @@ -406,7 +473,6 @@ def _set_advanced_tab_enabled(self, enabled: bool) -> None: pass def _update_ssh_tab_enabled(self, db_type: DatabaseType) -> None: - """Enable/disable the SSH tab based on database type.""" try: tabs = self.query_one("#connection-tabs", TabbedContent) ssh_pane = self.query_one("#tab-ssh", TabPane) @@ -429,19 +495,185 @@ def _update_ssh_tab_enabled(self, db_type: DatabaseType) -> None: except Exception: pass + def _check_driver_availability(self, db_type: DatabaseType) -> None: + from ...db.exceptions import MissingDriverError + + self._missing_driver_error = None + try: + adapter = get_adapter(db_type.value) + adapter.ensure_driver_available() + except MissingDriverError as e: + self._missing_driver_error = e + + self._update_driver_status_ui() + + def _update_driver_status_ui(self) -> None: + try: + test_status = self.query_one("#test-status", Static) + dialog = self.query_one("#connection-dialog", Dialog) + except Exception: + return + + try: + test_status.remove_class("success") + except Exception: + pass + + if self._install_in_progress and self._missing_driver_error: + error = self._missing_driver_error + spinner = self._INSTALL_SPINNER_FRAMES[self._install_spinner_index % len(self._INSTALL_SPINNER_FRAMES)] + test_status.update( + f"[yellow]⚠ Missing driver:[/] {error.package_name}\n" + f"[dim]{spinner} Installing…[/]" + ) + dialog.border_subtitle = "[bold]Installing…[/] Cancel " + return + + if self._missing_driver_error: + error = self._missing_driver_error + from ...install_strategy import detect_strategy + + strategy = detect_strategy(extra_name=error.extra_name, package_name=error.package_name) + if strategy.can_auto_install: + install_cmd = strategy.manual_instructions.split("\n")[0].strip() + test_status.update( + f"[yellow]⚠ Missing driver:[/] {error.package_name}\n" + f"[dim]Install with:[/] {escape(install_cmd)}" + ) + dialog.border_subtitle = "[bold]Install ^i[/] Cancel " + else: + # For unknown install methods, show reason and hint to press ^i for details + reason = strategy.reason_unavailable or "Auto-install not available" + test_status.update( + f"[yellow]⚠ Missing driver:[/] {error.package_name}\n" + f"[dim]{escape(reason)} Press ^i for install instructions.[/]" + ) + dialog.border_subtitle = "[bold]Help ^i[/] Cancel " + else: + if self._post_install_message and (self._last_test_ok is None or self._last_test_ok): + test_status.update(f"✓ {self._post_install_message}") + try: + test_status.add_class("success") + except Exception: + pass + else: + if not self._last_test_ok and self._last_test_ok is not None: + pass + else: + test_status.update("") + dialog.border_subtitle = "[bold]Test ^t[/] Save ^s Cancel " + + def _tick_install_spinner(self) -> None: + self._install_spinner_index += 1 + self._update_driver_status_ui() + + def _get_restart_cache_path(self) -> Path: + return Path(tempfile.gettempdir()) / "sqlit-driver-install-restore.json" + + def _write_restart_cache(self, *, post_install_message: str | None = None) -> None: + try: + values = self._get_current_form_values() + values["name"] = self.query_one("#conn-name", Input).value + db_type = self.query_one("#dbtype-select", Select).value + values["db_type"] = str(db_type) if db_type is not None else "" + try: + tabs = self.query_one("#connection-tabs", TabbedContent) + active_tab = tabs.active + except Exception: + active_tab = "tab-general" + + payload = { + "version": 1, + "editing": bool(self.editing), + "original_name": getattr(self.config, "name", None) if self.editing and self.config else None, + "active_tab": active_tab, + "values": values, + "post_install_message": post_install_message, + } + self._get_restart_cache_path().write_text(json.dumps(payload), encoding="utf-8") + except Exception: + # Best-effort; don't block installation due to caching failure. + pass + + def _clear_restart_cache(self) -> None: + try: + self._get_restart_cache_path().unlink(missing_ok=True) + except Exception: + pass + + def _start_missing_driver_install(self, error: Any) -> None: + from ...db.exceptions import MissingDriverError + from ...services.installer import Installer + + if not isinstance(error, MissingDriverError): + return + if self._install_in_progress: + return + + self._install_in_progress = True + self._install_spinner_index = 0 + self._post_install_message = None + self._update_driver_status_ui() + if self._install_spinner_timer is None: + self._install_spinner_timer = self.set_interval(0.12, self._tick_install_spinner) + + # Cache the form state so we can restore after restart. + self._write_restart_cache() + + def on_complete(success: bool, output: str, err: MissingDriverError) -> None: + self._on_missing_driver_install_complete(success, output, err) + + Installer(self.app).install_in_background(error, on_complete=on_complete) + + def _stop_install_spinner(self) -> None: + if self._install_spinner_timer is not None: + try: + self._install_spinner_timer.stop() + except Exception: + pass + self._install_spinner_timer = None + + def _on_missing_driver_install_complete(self, success: bool, output: str, error: Any) -> None: + from ..screens import MessageScreen + + self._stop_install_spinner() + self._install_in_progress = False + + if success: + self._check_driver_availability(self._current_db_type) + self._post_install_message = "Successfully installed driver" + self._update_driver_status_ui() + + self._write_restart_cache(post_install_message=self._post_install_message) + + if os.environ.get("SQLIT_DISABLE_RESTART") == "1": + self._clear_restart_cache() + return + + restart = getattr(self.app, "restart", None) + if callable(restart): + restart() + return + + self._clear_restart_cache() + self._update_driver_status_ui() + self.app.push_screen( + MessageScreen( + "Couldn't install automatically", + "Couldn't install automatically, please install manually.", + ) + ) + def compose(self) -> ComposeResult: title = "Edit Connection" if self.editing else "New Connection" db_type = self._get_initial_db_type() - shortcuts = [("Test", "^T"), ("Save", "^S"), ("Cancel", "")] + shortcuts = [("Test", "^t"), ("Save", "^s"), ("Cancel", "")] with Dialog(id="connection-dialog", title=title, shortcuts=shortcuts): - - with TabbedContent(id="connection-tabs"): + with TabbedContent(id="connection-tabs", initial="tab-general"): with TabPane("General", id="tab-general"): - name_container = Container( - id="container-name", classes="field-container" - ) + name_container = Container(id="container-name", classes="field-container") name_container.border_title = "Name" with name_container: yield Input( @@ -452,13 +684,12 @@ def compose(self) -> ComposeResult: yield Static("", id="error-name", classes="error-text hidden") db_types = list(DatabaseType) - dbtype_container = Container( - id="container-dbtype", classes="field-container" - ) + labels = get_database_type_labels() + dbtype_container = Container(id="container-dbtype", classes="field-container") dbtype_container.border_title = "Database Type" with dbtype_container: yield Select( - options=[(DATABASE_TYPE_LABELS[dt], dt.value) for dt in db_types], + options=[(labels[dt], dt.value) for dt in db_types], value=db_type.value, allow_blank=False, compact=True, @@ -467,20 +698,18 @@ def compose(self) -> ComposeResult: with Container(id="dynamic-fields-general"): field_groups = self._get_field_groups_for_type(db_type, tab="general") - general_groups, _advanced_groups = self._split_groups_by_advanced( - field_groups - ) + general_groups, _advanced_groups = self._split_groups_by_advanced(field_groups) for group in general_groups: yield from self._create_field_group(group) with TabPane("Advanced", id="tab-advanced"): with Container(id="dynamic-fields-advanced"): field_groups = self._get_field_groups_for_type(db_type, tab="general") - _general_groups, advanced_groups = self._split_groups_by_advanced( - field_groups - ) + _general_groups, advanced_groups = self._split_groups_by_advanced(field_groups) for group in advanced_groups: yield from self._create_field_group(group) + with Container(id="mssql-driver-setup", classes="hidden"): + yield Button("ODBC driver setup…", id="btn-odbc-setup") with TabPane("SSH", id="tab-ssh"): with Container(id="dynamic-fields-ssh"): @@ -492,18 +721,69 @@ def compose(self) -> ComposeResult: yield TextArea("", id="test-error", read_only=True, classes="hidden") def on_mount(self) -> None: - self.query_one("#conn-name", Input).focus() - - # Set initial values for select fields + self.call_after_refresh(self._ensure_initial_tab) self._set_initial_select_values() + self._apply_prefill_values() self._update_field_visibility() self._validate_name_unique() field_groups = self._get_field_groups_for_type(self._current_db_type, tab="general") _general, advanced = self._split_groups_by_advanced(field_groups) self._set_advanced_tab_enabled(bool(advanced)) self._update_ssh_tab_enabled(self._current_db_type) + self._update_mssql_driver_setup_visibility(self._current_db_type) + self._check_driver_availability(self._current_db_type) + + if self._post_install_message and not self._missing_driver_error: + self._update_driver_status_ui() + + def _ensure_initial_tab(self) -> None: + try: + tabs = self.query_one("#connection-tabs", TabbedContent) + except Exception: + return + tabs.active = "tab-general" + + def _apply_prefill_values(self) -> None: + if not self._prefill_values: + return + + values = self._prefill_values.get("values") if "values" in self._prefill_values else self._prefill_values + if not isinstance(values, dict): + return - def on_descendant_focus(self, event) -> None: + name_value = values.get("name") + if isinstance(name_value, str): + try: + self.query_one("#conn-name", Input).value = name_value + except Exception: + pass + + for field_name, widget in self._field_widgets.items(): + value = values.get(field_name) + if value is None: + continue + if isinstance(widget, Input): + widget.value = str(value) + elif isinstance(widget, Select): + widget.value = str(value) + elif isinstance(widget, OptionList): + try: + for idx, opt in enumerate(widget.options): + if getattr(opt, "id", None) == value: + widget.highlighted = idx + break + except Exception: + pass + + active_tab = self._prefill_values.get("active_tab") + if isinstance(active_tab, str) and active_tab: + try: + tabs = self.query_one("#connection-tabs", TabbedContent) + tabs.active = active_tab + except Exception: + pass + + def on_descendant_focus(self, event: Any) -> None: focused = self.focused if focused is None: return @@ -523,9 +803,7 @@ def on_descendant_focus(self, event) -> None: if self._focused_container_id and self._focused_container_id != container_id: try: - self.query_one( - f"#{self._focused_container_id}", Container - ).remove_class("focused") + self.query_one(f"#{self._focused_container_id}", Container).remove_class("focused") except Exception: pass @@ -536,7 +814,6 @@ def on_descendant_focus(self, event) -> None: pass def _set_initial_select_values(self) -> None: - """Set initial highlighted values for select fields based on config.""" for name, widget in self._field_widgets.items(): if isinstance(widget, OptionList): field_def = self._field_definitions.get(name) @@ -557,7 +834,6 @@ def _set_initial_select_values(self) -> None: widget.value = value def _rebuild_dynamic_fields(self, db_type: DatabaseType) -> None: - """Rebuild the dynamic fields for a new database type.""" self._current_db_type = db_type self._field_widgets.clear() self._field_definitions.clear() @@ -570,9 +846,7 @@ def _rebuild_dynamic_fields(self, db_type: DatabaseType) -> None: ssh_container.remove_children() field_groups = self._get_field_groups_for_type(db_type, tab="general") - general_groups, advanced_groups = self._split_groups_by_advanced( - field_groups - ) + general_groups, advanced_groups = self._split_groups_by_advanced(field_groups) self._set_advanced_tab_enabled(bool(advanced_groups)) for group in general_groups: for widget in self._create_field_group_widgets(group): @@ -587,7 +861,6 @@ def _rebuild_dynamic_fields(self, db_type: DatabaseType) -> None: ssh_container.mount(widget) def _create_field_group_widgets(self, group: FieldGroup) -> list: - """Create widget instances for a field group (for mounting).""" widgets = [] row_groups: dict[str | None, list[FieldDefinition]] = {} @@ -618,7 +891,6 @@ def _create_field_group_widgets(self, group: FieldGroup) -> list: return widgets def _create_field_widget_instances(self, field_def: FieldDefinition, group_name: str) -> list: - """Create widget instances for a field (for mounting).""" widgets = [] field_id = f"field-{field_def.name}" container_id = f"container-{field_def.name}" @@ -644,18 +916,14 @@ def _create_field_widget_instances(self, field_def: FieldDefinition, group_name: self._field_widgets[field_def.name] = select self._field_definitions[field_def.name] = field_def container.compose_add_child(select) - container.compose_add_child( - Static("", id=f"error-{field_def.name}", classes="error-text hidden") - ) + container.compose_add_child(Static("", id=f"error-{field_def.name}", classes="error-text hidden")) elif field_def.field_type == FieldType.SELECT: options = [Option(opt.label, id=opt.value) for opt in field_def.options] option_list = OptionList(*options, id=field_id, classes="select-field") self._field_widgets[field_def.name] = option_list self._field_definitions[field_def.name] = field_def container.compose_add_child(option_list) - container.compose_add_child( - Static("", id=f"error-{field_def.name}", classes="error-text hidden") - ) + container.compose_add_child(Static("", id=f"error-{field_def.name}", classes="error-text hidden")) else: value = self._get_field_value(field_def.name) or field_def.default input_widget = Input( @@ -667,9 +935,7 @@ def _create_field_widget_instances(self, field_def: FieldDefinition, group_name: self._field_widgets[field_def.name] = input_widget self._field_definitions[field_def.name] = field_def container.compose_add_child(input_widget) - container.compose_add_child( - Static("", id=f"error-{field_def.name}", classes="error-text hidden") - ) + container.compose_add_child(Static("", id=f"error-{field_def.name}", classes="error-text hidden")) widgets.append(container) return widgets @@ -684,15 +950,16 @@ def on_select_changed(self, event: Select.Changed) -> None: self._rebuild_dynamic_fields(db_type) self._set_initial_select_values() self._update_field_visibility() - self._focus_first_required() + self._focus_first_visible_field() self._update_ssh_tab_enabled(db_type) + self._update_mssql_driver_setup_visibility(db_type) + self._check_driver_availability(db_type) return if event.select.id and str(event.select.id).startswith("field-"): self._update_field_visibility() - def on_option_list_option_highlighted(self, event) -> None: - # A select field changed - update visibility of dependent fields + def on_option_list_option_highlighted(self, event: OptionList.OptionHighlighted) -> None: if event.option_list.id and event.option_list.id.startswith("field-"): self._update_field_visibility() @@ -701,7 +968,6 @@ def on_input_changed(self, event: Input.Changed) -> None: self._validate_name_unique() def _update_field_visibility(self) -> None: - """Update visibility of fields based on current form values.""" values = self._get_current_form_values() for name, field_def in self._field_definitions.items(): @@ -715,9 +981,7 @@ def _update_field_visibility(self) -> None: container.add_class("hidden") def _get_focusable_fields(self) -> list: - """Get list of focusable fields in order, based on active tab. - - Note: Tab bar is intentionally excluded from focusable fields. + """Tab bar is intentionally excluded from focusable fields. Users can switch tabs by clicking or using keyboard shortcuts, but Tab key should cycle through form fields only. """ @@ -731,8 +995,13 @@ def _get_focusable_fields(self) -> list: if active_tab == "tab-ssh": ssh_fields = [ - "ssh_enabled", "ssh_host", "ssh_port", "ssh_username", - "ssh_auth_type", "ssh_key_path", "ssh_password" + "ssh_enabled", + "ssh_host", + "ssh_port", + "ssh_username", + "ssh_auth_type", + "ssh_key_path", + "ssh_password", ] for field in ssh_fields: try: @@ -745,11 +1014,16 @@ def _get_focusable_fields(self) -> list: return fields if active_tab == "tab-general": - fields.extend([ - self.query_one("#conn-name", Input), - self.query_one("#dbtype-select", Select), - ]) - for name, widget in self._field_widgets.items(): + fields.extend( + [ + self.query_one("#conn-name", Input), + self.query_one("#dbtype-select", Select), + ] + ) + for name in self._field_definitions: + widget = self._field_widgets.get(name) + if widget is None: + continue if name.startswith("ssh_"): continue field_def = self._field_definitions.get(name) @@ -763,7 +1037,10 @@ def _get_focusable_fields(self) -> list: pass elif active_tab == "tab-advanced": - for name, widget in self._field_widgets.items(): + for name in self._field_definitions: + widget = self._field_widgets.get(name) + if widget is None: + continue field_def = self._field_definitions.get(name) if field_def and field_def.advanced: try: @@ -772,9 +1049,99 @@ def _get_focusable_fields(self) -> list: fields.append(widget) except Exception: pass + try: + container = self.query_one("#mssql-driver-setup", Container) + if "hidden" not in container.classes: + fields.append(self.query_one("#btn-odbc-setup", Button)) + except Exception: + pass return fields + def _update_mssql_driver_setup_visibility(self, db_type: DatabaseType) -> None: + try: + container = self.query_one("#mssql-driver-setup", Container) + except Exception: + return + if db_type.value == "mssql": + container.remove_class("hidden") + else: + container.add_class("hidden") + + def _set_select_field_value(self, field_name: str, value: str) -> None: + widget = self._field_widgets.get(field_name) + field_def = self._field_definitions.get(field_name) + if not isinstance(widget, OptionList) or not field_def or not field_def.options: + return + for i, opt in enumerate(field_def.options): + if opt.value == value: + widget.highlighted = i + return + + def _open_odbc_driver_setup(self, installed_drivers: list[str] | None = None) -> None: + from ...db.exceptions import MissingDriverError + from ...drivers import get_installed_drivers + from ...terminal import run_in_terminal + from ..screens import DriverSetupScreen, MessageScreen + + try: + get_adapter("mssql").ensure_driver_available() + except MissingDriverError as e: + self._prompt_install_missing_driver(e) + return + + installed = installed_drivers if installed_drivers is not None else get_installed_drivers() + + def on_result(result: Any) -> None: + if not result: + return + action = result[0] + if action == "select": + driver = result[1] + self._set_select_field_value("driver", driver) + return + if action == "install": + commands = result[1] + res = run_in_terminal(commands) + if res.success: + self.app.push_screen( + MessageScreen( + "Driver install", + "Installation started in a new terminal.\n\nPlease restart to apply.", + ) + ) + else: + + def reopen(_: Any = None) -> None: + self._open_odbc_driver_setup(installed_drivers=installed) + + self.app.push_screen( + MessageScreen( + "Couldn't install automatically", + "Couldn't install automatically, please install manually.", + ), + reopen, + ) + + self.app.push_screen(DriverSetupScreen(installed), on_result) + + def on_button_pressed(self, event: Button.Pressed) -> None: + if event.button.id == "btn-odbc-setup": + self._open_odbc_driver_setup() + + def action_open_odbc_setup(self) -> None: + if self._current_db_type.value != "mssql": + return + self._open_odbc_driver_setup() + + def action_install_driver(self) -> None: + if self._install_in_progress: + return + if self._missing_driver_error: + self._prompt_install_missing_driver(self._missing_driver_error) + elif self._current_db_type.value == "mssql": + self._open_odbc_driver_setup() + def _clear_field_error(self, name: str) -> None: try: container = self.query_one(f"#container-{name}", Container) @@ -830,22 +1197,18 @@ def _clear_tab_errors(self) -> None: pass def _apply_validation_to_ui(self) -> None: - """Apply validation_state to UI elements.""" self._clear_tab_errors() self.validation_state.tab_errors.clear() - # Clear all field errors first self._clear_field_error("name") for field_name in self._field_definitions: self._clear_field_error(field_name) for ssh_field in ["ssh_host", "ssh_username", "ssh_key_path"]: self._clear_field_error(ssh_field) - # Apply errors from validation state for field_name, message in self.validation_state.errors.items(): self._set_field_error(field_name, message) - # Determine which tab to mark if field_name == "name": self._set_tab_error("tab-general") self.validation_state.add_tab_error("tab-general") @@ -865,10 +1228,14 @@ def _apply_validation_to_ui(self) -> None: self.validation_state.add_tab_error("tab-general") def _get_existing_names(self) -> set[str]: - """Get set of existing connection names.""" try: connections = getattr(self.app, "connections", []) or [] - return {getattr(c, "name", None) for c in connections} - {None} + names: set[str] = set() + for conn in connections: + name = getattr(conn, "name", None) + if isinstance(name, str) and name: + names.add(name) + return names except Exception: return set() @@ -877,7 +1244,7 @@ def _validate_name_unique(self) -> None: name = self.query_one("#conn-name", Input).value.strip() if not name: return - existing = [] + existing: list[Any] = [] try: existing = getattr(self.app, "connections", []) or [] except Exception: @@ -890,24 +1257,65 @@ def _validate_name_unique(self) -> None: def _focus_first_required(self) -> None: values = self._get_current_form_values() - for field_name, field_def in self._field_definitions.items(): - if not field_def.required: - continue - is_visible = True - if field_def.visible_when: - is_visible = bool(field_def.visible_when(values)) + ordered_fields = list(self._field_definitions.keys()) + + def is_visible(field_def: FieldDefinition) -> bool: + if field_def.visible_when and not bool(field_def.visible_when(values)): + return False if field_def.advanced and not self._show_advanced: - is_visible = False - if not is_visible: + return False + return True + + def is_missing(widget: Any) -> bool: + if isinstance(widget, Input): + return not widget.value.strip() + if isinstance(widget, OptionList): + return widget.highlighted is None + if isinstance(widget, Select): + return widget.value in (None, "") + return False + + for field_name in ordered_fields: + field_def = self._field_definitions.get(field_name) + if not field_def or not field_def.required: + continue + if not is_visible(field_def): continue widget = self._field_widgets.get(field_name) - if isinstance(widget, Input) and not widget.value.strip(): - widget.focus() - return - if isinstance(widget, OptionList) and widget.highlighted is None: + if widget is None: + continue + if is_missing(widget): widget.focus() return + for field_name in ordered_fields: + field_def = self._field_definitions.get(field_name) + if not field_def or not is_visible(field_def): + continue + widget = self._field_widgets.get(field_name) + if widget is None: + continue + widget.focus() + return + + def _focus_first_visible_field(self) -> None: + values = self._get_current_form_values() + ordered_fields = list(self._field_definitions.keys()) + + for field_name in ordered_fields: + field_def = self._field_definitions.get(field_name) + if not field_def: + continue + if field_def.visible_when and not bool(field_def.visible_when(values)): + continue + if field_def.advanced and not self._show_advanced: + continue + widget = self._field_widgets.get(field_name) + if widget is None: + continue + widget.focus() + return + def action_next_field(self) -> None: from textual.widgets import Tabs @@ -950,10 +1358,8 @@ def action_prev_field(self) -> None: fields[-1].focus() def action_focus_tab_content(self) -> None: - """Focus the first field of the active tab when pressing down on tab bar.""" from textual.widgets import Tabs - # Only handle if tab bar is focused try: tabs_widget = self.query_one("#connection-tabs", TabbedContent) tab_bar = tabs_widget.query_one(Tabs) @@ -967,7 +1373,6 @@ def action_focus_tab_content(self) -> None: if active_tab == "tab-general": self.query_one("#conn-name", Input).focus() elif active_tab == "tab-advanced": - # Focus first visible advanced field for name, widget in self._field_widgets.items(): field_def = self._field_definitions.get(name) if field_def and field_def.advanced: @@ -984,20 +1389,17 @@ def action_focus_tab_content(self) -> None: ssh_widget.focus() def _get_config(self) -> ConnectionConfig | None: - """Build a ConnectionConfig from the current form values.""" name_input = self.query_one("#conn-name", Input) name = name_input.value.strip() - # Get selected database type db_type_value = self.query_one("#dbtype-select", Select).value try: db_type = DatabaseType(str(db_type_value)) except Exception: - db_type = DatabaseType.MSSQL + db_type = DatabaseType.MSSQL # type: ignore[attr-defined] values = self._get_current_form_values() - # Name suggestion if not name: suggestion = "" if is_file_based(db_type.value): @@ -1010,7 +1412,6 @@ def _get_config(self) -> ConnectionConfig | None: name_input.value = suggestion name = suggestion - # Run validation editing_name = self.config.name if self.editing and self.config else None self.validation_state = validate_connection_form( name=name, @@ -1021,11 +1422,9 @@ def _get_config(self) -> ConnectionConfig | None: editing_name=editing_name, ) - # Apply validation to UI self._apply_validation_to_ui() if not self.validation_state.is_valid(): - # Focus the first field with an error for field_name in self.validation_state.errors: if field_name == "name": name_input.focus() @@ -1037,20 +1436,18 @@ def _get_config(self) -> ConnectionConfig | None: pass return None - # Build config config_kwargs = { "name": name, "db_type": db_type.value, } - # Add all field values to config for field_name, value in values.items(): if not field_name.startswith("ssh_"): config_kwargs[field_name] = value if has_advanced_auth(db_type.value): auth_type = values.get("auth_type", "sql") - config_kwargs["trusted_connection"] = (auth_type == "windows") + config_kwargs["trusted_connection"] = auth_type == "windows" if supports_ssh(db_type.value): config_kwargs["ssh_enabled"] = values.get("ssh_enabled") == "enabled" @@ -1064,37 +1461,103 @@ def _get_config(self) -> ConnectionConfig | None: return ConnectionConfig(**config_kwargs) def _get_package_install_hint(self, db_type: str) -> str | None: - """Get pip install command for missing database packages.""" - hints = { - "postgresql": "pip install psycopg2-binary", - "mysql": "pip install mysql-connector-python", - "oracle": "pip install oracledb", - "mariadb": "pip install mariadb", - "duckdb": "pip install duckdb", - "cockroachdb": "pip install psycopg2-binary", - "turso": "pip install libsql-client", - } - return hints.get(db_type) + try: + adapter = get_adapter(db_type) + return adapter.install_hint + except (ValueError, ImportError): + return None + + def _prompt_install_missing_driver(self, error: Exception) -> None: + from ...db.exceptions import MissingDriverError + from ...install_strategy import detect_strategy + from ..screens import ConfirmScreen, MessageScreen + + if not isinstance(error, MissingDriverError): + return + + if self._install_in_progress: + return + + strategy = detect_strategy(extra_name=error.extra_name, package_name=error.package_name) + if not strategy.can_auto_install: + self.app.push_screen( + MessageScreen( + "Manual installation required", + strategy.manual_instructions, + ) + ) + return + + self.app.push_screen( + ConfirmScreen( + "Install missing driver?", + f"Missing package: {error.package_name}", + yes_label="Yes", + no_label="No", + ), + lambda confirmed: self._start_missing_driver_install(error) if confirmed else None, + ) def action_test_connection(self) -> None: from dataclasses import replace + from ...db.exceptions import MissingDriverError, MissingODBCDriverError + from ...db.providers import is_file_based + from .password_input import PasswordInputScreen + + if self._missing_driver_error: + self._prompt_install_missing_driver(self._missing_driver_error) + return + config = self._get_config() if not config: return + if config.ssh_enabled and config.ssh_auth_type == "password" and config.ssh_password is None: + + def on_ssh_password(password: str | None) -> None: + if password is None: + return + temp_config = replace(config, ssh_password=password) + self._test_with_config(temp_config) + + self.app.push_screen( + PasswordInputScreen(config.name, password_type="ssh"), + on_ssh_password, + ) + return + + if not is_file_based(config.db_type) and config.password is None: + + def on_db_password(password: str | None) -> None: + if password is None: + return + temp_config = replace(config, password=password) + self._test_with_config(temp_config) + + self.app.push_screen( + PasswordInputScreen(config.name, password_type="database"), + on_db_password, + ) + return + + self._test_with_config(config) + + def _test_with_config(self, config) -> None: + from dataclasses import replace + + from ...db.exceptions import MissingDriverError, MissingODBCDriverError + self.query_one("#test-error", TextArea).add_class("hidden") self.query_one("#test-status", Static).update("Testing…") self._last_test_ok = None self._last_test_error = "" tunnel = None - # Check if we're in mock mode mock_profile = getattr(self.app, "_mock_profile", None) try: if mock_profile: - # Use mock adapter and skip SSH tunnel in mock mode adapter = mock_profile.get_adapter(config.db_type) connect_config = config else: @@ -1108,7 +1571,6 @@ def action_test_connection(self) -> None: conn = adapter.connect(connect_config) conn.close() - # Close tunnel after test if tunnel: tunnel.stop() try: @@ -1119,22 +1581,15 @@ def action_test_connection(self) -> None: pass self._last_test_ok = True self.query_one("#test-status", Static).update("Last test: OK") - except ModuleNotFoundError as e: - hint = self._get_package_install_hint(config.db_type) - if hint: - self.query_one("#test-status", Static).update(f"Last test: failed (missing package)") - err = self.query_one("#test-error", TextArea) - err.text = f"{e}\n\nInstall with:\n {hint}" - err.remove_class("hidden") - self._last_test_error = err.text - else: - self.query_one("#test-status", Static).update("Last test: failed") - err = self.query_one("#test-error", TextArea) - err.text = f"{e}" - err.remove_class("hidden") - self._last_test_error = err.text + except MissingDriverError as e: + self._prompt_install_missing_driver(e) self._last_test_ok = False - except ImportError as e: + self.query_one("#test-status", Static).update("Last test: failed (missing driver)") + except MissingODBCDriverError as e: + self._open_odbc_driver_setup(e.installed_drivers) + self._last_test_ok = False + self.query_one("#test-status", Static).update("Last test: failed (missing ODBC driver)") + except (ModuleNotFoundError, ImportError) as e: hint = self._get_package_install_hint(config.db_type) if hint: self.query_one("#test-status", Static).update("Last test: failed (missing package)") @@ -1163,7 +1618,6 @@ def action_test_connection(self) -> None: err.remove_class("hidden") self._last_test_error = err.text finally: - # Ensure tunnel is closed on any failure if tunnel: try: tunnel.stop() @@ -1172,10 +1626,26 @@ def action_test_connection(self) -> None: def action_save(self) -> None: config = self._get_config() - if config: + if not config: + return + + from ...db.exceptions import MissingDriverError + + if getattr(self.app, "_mock_profile", None): self.dismiss(("save", config)) + return + + try: + get_adapter(config.db_type).ensure_driver_available() + except MissingDriverError as e: + self._prompt_install_missing_driver(e) + return + + self.dismiss(("save", config)) def action_cancel(self) -> None: + if self._install_in_progress: + return self.dismiss(None) @property diff --git a/sqlit/ui/screens/connection_picker.py b/sqlit/ui/screens/connection_picker.py index 4c3be23b..c71e1a94 100644 --- a/sqlit/ui/screens/connection_picker.py +++ b/sqlit/ui/screens/connection_picker.py @@ -9,48 +9,10 @@ from textual.widgets import OptionList from textual.widgets.option_list import Option +from ...utils import fuzzy_match, highlight_matches from ...widgets import Dialog -def fuzzy_match(pattern: str, text: str) -> tuple[bool, list[int]]: - """Check if pattern fuzzy matches text and return matched indices. - - Returns (matches, indices) where indices are positions in text that matched. - """ - if not pattern: - return True, [] - - pattern = pattern.lower() - text_lower = text.lower() - - pattern_idx = 0 - indices = [] - - for i, char in enumerate(text_lower): - if pattern_idx < len(pattern) and char == pattern[pattern_idx]: - indices.append(i) - pattern_idx += 1 - - return pattern_idx == len(pattern), indices - - -def highlight_matches(text: str, indices: list[int]) -> str: - """Highlight matched characters in text.""" - if not indices: - return text - - result = [] - idx_set = set(indices) - - for i, char in enumerate(text): - if i in idx_set: - result.append(f"[bold yellow]{char}[/]") - else: - result.append(char) - - return "".join(result) - - class ConnectionPickerScreen(ModalScreen): """Modal screen for selecting a connection with fuzzy search.""" @@ -120,9 +82,7 @@ def _build_options(self, pattern: str) -> list[Option]: display = highlight_matches(conn.name, indices) db_type = conn.db_type.upper() if conn.db_type else "DB" info = conn.get_display_info() - options.append( - Option(f"{display} [{db_type}] [dim]({info})[/]", id=conn.name) - ) + options.append(Option(f"{display} [{db_type}] [dim]({info})[/]", id=conn.name)) return options def on_mount(self) -> None: @@ -192,7 +152,7 @@ def action_select(self) -> None: self.dismiss(None) - def on_option_list_option_selected(self, event) -> None: + def on_option_list_option_selected(self, event: OptionList.OptionSelected) -> None: if event.option_list.id == "picker-list": if event.option: self.dismiss(event.option.id) diff --git a/sqlit/ui/screens/driver_setup.py b/sqlit/ui/screens/driver_setup.py index 931ef6a1..47374d9e 100644 --- a/sqlit/ui/screens/driver_setup.py +++ b/sqlit/ui/screens/driver_setup.py @@ -5,7 +5,7 @@ from textual.app import ComposeResult from textual.binding import Binding from textual.screen import ModalScreen -from textual.widgets import OptionList, Static +from textual.widgets import OptionList, Static, TextArea from textual.widgets.option_list import Option from ...widgets import Dialog @@ -18,6 +18,7 @@ class DriverSetupScreen(ModalScreen): Binding("escape", "cancel", "Cancel"), Binding("enter", "select", "Select"), Binding("i", "install_driver", "Install"), + Binding("y", "yank", "Yank"), ] CSS = """ @@ -58,6 +59,7 @@ def __init__(self, installed_drivers: list[str] | None = None): super().__init__() self.installed_drivers = installed_drivers or [] self._install_commands: list[str] = [] + self._install_script: str = "" def compose(self) -> ComposeResult: from ...drivers import SUPPORTED_DRIVERS, get_install_commands, get_os_info @@ -67,10 +69,10 @@ def compose(self) -> ComposeResult: if has_drivers: title = "Select ODBC Driver" - shortcuts = [("Select", ""), ("Cancel", "")] + shortcuts = [("Select", ""), ("Yank", "y"), ("Cancel", "")] else: title = "No ODBC Driver Found" - shortcuts = [("Select", ""), ("Install", "I"), ("Cancel", "")] + shortcuts = [("Select", ""), ("Install", "I"), ("Yank", "y"), ("Cancel", "")] with Dialog(id="driver-dialog", title=title, shortcuts=shortcuts): if has_drivers: @@ -89,7 +91,7 @@ def compose(self) -> ComposeResult: options = [] if has_drivers: for driver in self.installed_drivers: - options.append(Option(f"[green]{driver}[/]", id=driver)) + options.append(Option(f"[#4ADE80]{driver}[/]", id=driver)) else: for driver in SUPPORTED_DRIVERS[:3]: # Show top 3 options options.append(Option(f"[dim]{driver}[/] (not installed)", id=driver)) @@ -101,10 +103,12 @@ def compose(self) -> ComposeResult: install_info = get_install_commands() if install_info: self._install_commands = install_info.commands - commands_text = "\n".join(install_info.commands) - yield Static( - f"[bold]{install_info.description}:[/]\n\n{commands_text}", + self._install_script = "\n".join(install_info.commands).strip() + yield TextArea( + f"{install_info.description}:\n\n{self._install_script}\n", id="install-commands", + read_only=True, + language="bash", ) def on_mount(self) -> None: @@ -116,7 +120,7 @@ def action_select(self) -> None: option = option_list.get_option_at_index(option_list.highlighted) self.dismiss(("select", option.id)) - def on_option_list_option_selected(self, event) -> None: + def on_option_list_option_selected(self, event: OptionList.OptionSelected) -> None: self.dismiss(("select", event.option.id)) def action_install_driver(self) -> None: @@ -126,6 +130,7 @@ def action_install_driver(self) -> None: return from ...drivers import get_os_info + os_type, _ = get_os_info() # On Windows, just show instructions @@ -136,8 +141,16 @@ def action_install_driver(self) -> None: ) return - self.notify("Installing driver... This may ask for your password.", timeout=5) self.dismiss(("install", self._install_commands)) + def action_yank(self) -> None: + from ...widgets import flash_widget + + script = self._install_script.strip() + if not script: + return + self.app.copy_to_clipboard(script) + flash_widget(self.query_one("#install-commands", TextArea)) + def action_cancel(self) -> None: self.dismiss(None) diff --git a/sqlit/ui/screens/error.py b/sqlit/ui/screens/error.py index 8c644d49..17996f9b 100644 --- a/sqlit/ui/screens/error.py +++ b/sqlit/ui/screens/error.py @@ -2,8 +2,6 @@ from __future__ import annotations -import textwrap - from textual.app import ComposeResult from textual.binding import Binding from textual.screen import ModalScreen @@ -31,16 +29,14 @@ class ErrorScreen(ModalScreen): width: 60; max-width: 80%; border: solid $error; + border-title-color: $error; border-subtitle-color: $error; + color: $error; } #error-message { padding: 1; } - - #error-message.flash { - background: $error 50%; - } """ def __init__(self, title: str, message: str): @@ -50,16 +46,20 @@ def __init__(self, title: str, message: str): def compose(self) -> ComposeResult: shortcuts = [("Copy", "y"), ("Close", "")] - wrapped = textwrap.fill(self.message, width=56) with Dialog(id="error-dialog", title=self.title_text, shortcuts=shortcuts): - yield Static(wrapped, id="error-message") + yield Static(self.message, id="error-message") def action_close(self) -> None: self.dismiss() + def check_action(self, action: str, parameters: tuple) -> bool | None: + # Prevent underlying screens from receiving actions when another modal is on top. + if self.app.screen is not self: + return False + return super().check_action(action, parameters) + def action_copy_message(self) -> None: + from ...widgets import flash_widget + self.app.copy_to_clipboard(self.message) - # Flash the message to indicate copy - msg = self.query_one("#error-message", Static) - msg.add_class("flash") - self.set_timer(0.15, lambda: msg.remove_class("flash")) + flash_widget(self.query_one("#error-message", Static)) diff --git a/sqlit/ui/screens/help.py b/sqlit/ui/screens/help.py index fe706c10..5f8f0790 100644 --- a/sqlit/ui/screens/help.py +++ b/sqlit/ui/screens/help.py @@ -27,15 +27,18 @@ class HelpScreen(ModalScreen): } #help-dialog { - width: 90; - max-width: 90%; - max-height: 90%; + width: 60; + max-width: 70%; + max-height: 80%; } #help-scroll { height: auto; + max-height: 100%; background: $surface; border: none; + scrollbar-gutter: stable; + color: white; } """ @@ -48,5 +51,5 @@ def compose(self) -> ComposeResult: with VerticalScroll(id="help-scroll"): yield Static(self.help_text) - def action_dismiss(self) -> None: + def action_dismiss(self) -> None: # type: ignore[override] self.dismiss(None) diff --git a/sqlit/ui/screens/leader_menu.py b/sqlit/ui/screens/leader_menu.py index e2d39cc2..f341bed1 100644 --- a/sqlit/ui/screens/leader_menu.py +++ b/sqlit/ui/screens/leader_menu.py @@ -2,11 +2,18 @@ from __future__ import annotations +from typing import TYPE_CHECKING, Any, cast + from textual.app import ComposeResult from textual.binding import Binding from textual.screen import ModalScreen from textual.widgets import Static +from ...widgets import Dialog + +if TYPE_CHECKING: + from ...app import SSMSTUI + class LeaderMenuScreen(ModalScreen): """Modal screen showing leader key commands.""" @@ -24,17 +31,18 @@ class LeaderMenuScreen(ModalScreen): } #leader-menu { + max-width: 35; + margin: 0; + border: solid $primary; + } + + #leader-menu-content { width: auto; height: auto; - max-width: 50; - background: $surface; - border: solid $primary; - padding: 1; - margin: 1 2; } """ - def __init__(self): + def __init__(self) -> None: super().__init__() from ...state_machine import get_leader_commands @@ -50,6 +58,7 @@ def compose(self) -> ComposeResult: lines = [] leader_commands = get_leader_commands() + app = cast("SSMSTUI", self.app) categories: dict[str, list] = {} for cmd in leader_commands: @@ -60,31 +69,34 @@ def compose(self) -> ComposeResult: for category, commands in categories.items(): lines.append(f"[bold $text-muted]{category}[/]") for cmd in commands: - if cmd.is_allowed(self.app): + if cmd.is_allowed(app): lines.append(f" [bold $warning]{cmd.key}[/] {cmd.label}") lines.append("") - lines.append("[$primary]Close: [/]") + # Remove trailing empty line + if lines and lines[-1] == "": + lines.pop() content = "\n".join(lines) - yield Static(content, id="leader-menu") + with Dialog(id="leader-menu", shortcuts=[("Close", "esc")]): + yield Static(content, id="leader-menu-content") - def action_dismiss(self) -> None: + def action_dismiss(self) -> None: # type: ignore[override] self.dismiss(None) def _run_and_dismiss(self, action_name: str) -> None: """Run an app action and dismiss the menu.""" self.dismiss(action_name) - def __getattr__(self, name: str): + def __getattr__(self, name: str) -> Any: """Handle cmd_* actions dynamically from leader commands.""" if name.startswith("action_cmd_"): - action = name[len("action_cmd_"):] + action = name[len("action_cmd_") :] if action in self._cmd_actions: cmd = self._cmd_actions[action] - def handler(): - if cmd.is_allowed(self.app): + def handler() -> None: + if cmd.is_allowed(cast("SSMSTUI", self.app)): self._run_and_dismiss(cmd.action) return handler diff --git a/sqlit/ui/screens/loading.py b/sqlit/ui/screens/loading.py new file mode 100644 index 00000000..15e9d17a --- /dev/null +++ b/sqlit/ui/screens/loading.py @@ -0,0 +1,58 @@ +"""A simple modal loading screen.""" + +from __future__ import annotations + +from collections.abc import Callable + +from textual.app import ComposeResult +from textual.binding import Binding +from textual.containers import Center, Vertical +from textual.screen import ModalScreen +from textual.widgets import Label +from textual.widgets._loading_indicator import LoadingIndicator + + +class LoadingScreen(ModalScreen[None]): + """Screen to display a loading message with a spinner.""" + + BINDINGS = [ + Binding("escape", "cancel", "Cancel", show=False), + ] + + def __init__(self, message: str, *, on_cancel: Callable[[], None] | None = None): + super().__init__() + self.message = message + self._on_cancel = on_cancel + self._cancel_requested = False + + def compose(self) -> ComposeResult: + yield Vertical( + Center(LoadingIndicator(), classes="spinner-container"), + Center(Label(self.message, id="loading-message")), + classes="loading-dialog", + ) + + CSS = """ + LoadingScreen { + align: center middle; + } + + .loading-dialog { + background: $surface; + padding: 1 2; + width: auto; + height: auto; + border: solid $primary; + } + """ + + def action_cancel(self) -> None: + if self._cancel_requested: + return + self._cancel_requested = True + if self._on_cancel is not None: + self._on_cancel() + try: + self.query_one("#loading-message", Label).update("Cancelling...") + except Exception: + pass diff --git a/sqlit/ui/screens/message.py b/sqlit/ui/screens/message.py new file mode 100644 index 00000000..5287ee47 --- /dev/null +++ b/sqlit/ui/screens/message.py @@ -0,0 +1,75 @@ +"""A simple modal message screen (no buttons).""" + +from __future__ import annotations + +from collections.abc import Callable + +from textual.app import ComposeResult +from textual.binding import Binding +from textual.screen import ModalScreen +from textual.widgets import Static + +from ...widgets import Dialog + + +class MessageScreen(ModalScreen): + """Modal screen that shows a message and closes via keyboard.""" + + BINDINGS = [ + Binding("enter", "primary", "Continue", show=False), + Binding("escape", "close", "Close", show=False), + ] + + CSS = """ + MessageScreen { + align: center middle; + background: transparent; + } + + #message-dialog { + width: auto; + min-width: 70; + max-width: 95%; + border: solid $primary; + border-subtitle-color: $primary; + } + + #message-content { + padding: 1 2; + color: $text; + } + """ + + def __init__( + self, + title: str, + message: str, + *, + enter_label: str = "Continue", + on_enter: Callable[[], None] | None = None, + ): + super().__init__() + self._title = title + self.message = message + self._enter_label = enter_label + self._on_enter = on_enter + + def compose(self) -> ComposeResult: + shortcuts = [(self._enter_label, "")] + with Dialog(id="message-dialog", title=self._title, shortcuts=shortcuts): + yield Static(self.message, id="message-content") + + def action_primary(self) -> None: + if self._on_enter is not None: + self._on_enter() + return + self.dismiss() + + def action_close(self) -> None: + self.dismiss() + + def check_action(self, action: str, parameters: tuple) -> bool | None: + # Prevent underlying screens from receiving actions when another modal is on top. + if self.app.screen is not self: + return False + return super().check_action(action, parameters) diff --git a/sqlit/ui/screens/package_setup.py b/sqlit/ui/screens/package_setup.py new file mode 100644 index 00000000..3e49ee67 --- /dev/null +++ b/sqlit/ui/screens/package_setup.py @@ -0,0 +1,110 @@ +"""Package setup screen for missing Python drivers.""" + +from __future__ import annotations + +from collections.abc import Callable + +from rich.markup import escape +from textual.app import ComposeResult +from textual.binding import Binding +from textual.containers import VerticalScroll +from textual.screen import ModalScreen +from textual.widgets import Static + +from ...db.exceptions import MissingDriverError +from ...install_strategy import detect_strategy +from ...widgets import Dialog + + +class PackageSetupScreen(ModalScreen): + """Screen that shows install instructions for a missing Python package.""" + + BINDINGS = [ + Binding("i", "install", "Install"), + Binding("y", "yank", "Yank"), + Binding("escape", "cancel", "Cancel"), + ] + + CSS = """ + PackageSetupScreen { + align: center middle; + background: transparent; + } + + #package-dialog { + width: 80; + height: auto; + max-height: 90%; + } + + #package-message { + margin-bottom: 1; + } + + #package-scroll { + height: auto; + max-height: 12; + border: solid $primary-darken-2; + background: $surface-darken-1; + padding: 1; + margin-top: 1; + overflow-y: auto; + } + """ + + def __init__(self, error: MissingDriverError, *, on_install: Callable[[MissingDriverError], None]): + super().__init__() + self.error = error + self._on_install = on_install + self._instructions_text = "" + self._can_auto_install = True + + def compose(self) -> ComposeResult: + strategy = detect_strategy(extra_name=self.error.extra_name, package_name=self.error.package_name) + self._can_auto_install = strategy.can_auto_install + self._instructions_text = strategy.manual_instructions.strip() + "\n" + + shortcuts = [("Yank", "y"), ("Cancel", "")] + if self._can_auto_install: + shortcuts.insert(0, ("Install", "i")) + with Dialog(id="package-dialog", title="Missing package", shortcuts=shortcuts): + yield Static( + f"This connection requires the [bold]{self.error.driver_name}[/] driver.\n" + f"Package: [bold]{self.error.package_name}[/]", + id="package-message", + ) + + with VerticalScroll(id="package-scroll"): + yield Static(escape(self._instructions_text.strip()), id="package-script") + + def on_mount(self) -> None: + self.query_one("#package-scroll", VerticalScroll).focus() + + def action_install(self) -> None: + if not self._can_auto_install: + try: + self.app.notify( + "Automatic installation isn't available for this Python environment.", + severity="warning", + timeout=6, + ) + except Exception: + pass + return + self._on_install(self.error) + + def action_yank(self) -> None: + from ...widgets import flash_widget + + self.app.copy_to_clipboard(self._instructions_text.strip()) + flash_widget(self.query_one("#package-script", Static)) + + def action_cancel(self) -> None: + self.dismiss(None) + + def check_action(self, action: str, parameters: tuple) -> bool | None: + if self.app.screen is not self: + return False + if action == "install" and not self._can_auto_install: + return False + return super().check_action(action, parameters) diff --git a/sqlit/ui/screens/password_input.py b/sqlit/ui/screens/password_input.py new file mode 100644 index 00000000..ff743452 --- /dev/null +++ b/sqlit/ui/screens/password_input.py @@ -0,0 +1,148 @@ +"""Password input dialog screen.""" + +from __future__ import annotations + +from textual.app import ComposeResult +from textual.binding import Binding +from textual.screen import ModalScreen +from textual.widgets import Input, Static + +from ...widgets import Dialog + + +class PasswordInputScreen(ModalScreen): + """Modal screen for password input. + + This screen prompts the user to enter a password when connecting + to a database that has no stored password. + """ + + BINDINGS = [ + Binding("escape", "cancel", "Cancel"), + Binding("enter", "submit", "Submit", show=False), + ] + + CSS = """ + PasswordInputScreen { + align: center middle; + background: transparent; + } + + #password-dialog { + width: 50; + height: auto; + max-height: 12; + } + + #password-description { + margin-bottom: 1; + color: $text-muted; + height: auto; + } + + #password-container { + border: solid $panel; + background: $surface; + padding: 0; + margin-top: 0; + height: 3; + border-title-align: left; + border-title-color: $text-muted; + border-title-background: $surface; + border-title-style: none; + } + + #password-container.focused { + border: solid $primary; + border-title-color: $primary; + } + + #password-container Input { + border: none; + height: 1; + padding: 0; + background: $surface; + } + + #password-container Input:focus { + border: none; + background-tint: $foreground 5%; + } + """ + + def __init__( + self, + connection_name: str, + *, + title: str = "Password Required", + description: str | None = None, + password_type: str = "database", + ): + """Initialize the password input screen. + + Args: + connection_name: The name of the connection requiring the password. + title: The dialog title. + description: Optional description text. + password_type: Type of password ("database" or "ssh"). + """ + super().__init__() + self.connection_name = connection_name + self.title_text = title + self.password_type = password_type + if description: + self.description = description + else: + if password_type == "ssh": + self.description = f"Enter SSH password for '{connection_name}':" + else: + self.description = f"Enter password for '{connection_name}':" + + def compose(self) -> ComposeResult: + shortcuts: list[tuple[str, str]] = [("Submit", ""), ("Cancel", "")] + with Dialog(id="password-dialog", title=self.title_text, shortcuts=shortcuts): + yield Static(self.description, id="password-description") + from textual.containers import Container + + container = Container(id="password-container") + container.border_title = "Password" + with container: + yield Input( + value="", + placeholder="", + id="password-input", + password=True, + ) + + def on_mount(self) -> None: + self.query_one("#password-input", Input).focus() + + def on_input_submitted(self, event: Input.Submitted) -> None: + if event.input.id == "password-input": + self.dismiss(event.value) + + def on_descendant_focus(self, event) -> None: + try: + container = self.query_one("#password-container") + container.add_class("focused") + except Exception: + pass + + def on_descendant_blur(self, event) -> None: + try: + container = self.query_one("#password-container") + container.remove_class("focused") + except Exception: + pass + + def action_submit(self) -> None: + password = self.query_one("#password-input", Input).value + self.dismiss(password) + + def action_cancel(self) -> None: + self.dismiss(None) + + def check_action(self, action: str, parameters: tuple) -> bool | None: + if self.app.screen is not self: + return False + return super().check_action(action, parameters) diff --git a/sqlit/ui/screens/query_history.py b/sqlit/ui/screens/query_history.py index 399ee0b6..1dc5dd0a 100644 --- a/sqlit/ui/screens/query_history.py +++ b/sqlit/ui/screens/query_history.py @@ -101,9 +101,7 @@ def compose(self) -> ComposeResult: if len(entry.query) > 60: query_preview += "..." - options.append( - Option(f"[dim]{time_str}[/] {query_preview}", id=entry.timestamp) - ) + options.append(Option(f"[dim]{time_str}[/] {query_preview}", id=entry.timestamp)) yield OptionList(*options, id="history-list") else: @@ -121,7 +119,7 @@ def on_mount(self) -> None: except Exception: pass - def on_option_list_option_highlighted(self, event) -> None: + def on_option_list_option_highlighted(self, event: OptionList.OptionHighlighted) -> None: if event.option_list.id == "history-list": idx = event.option_list.highlighted if idx is not None: @@ -147,7 +145,7 @@ def action_select(self) -> None: except Exception: self.dismiss(None) - def on_option_list_option_selected(self, event) -> None: + def on_option_list_option_selected(self, event: OptionList.OptionSelected) -> None: if event.option_list.id == "history-list": idx = event.option_list.highlighted if idx is not None and idx < len(self.history): diff --git a/sqlit/ui/screens/theme.py b/sqlit/ui/screens/theme.py index 9f477f92..2cfe0987 100644 --- a/sqlit/ui/screens/theme.py +++ b/sqlit/ui/screens/theme.py @@ -10,16 +10,25 @@ from ...widgets import Dialog -THEMES = [ - ("textual-dark", "Textual Dark"), - ("textual-light", "Textual Light"), - ("nord", "Nord"), - ("gruvbox", "Gruvbox"), - ("tokyo-night", "Tokyo Night"), - ("solarized-light", "Solarized Light"), - ("catppuccin-mocha", "Catppuccin Mocha"), - ("dracula", "Dracula"), -] +THEME_LABELS = { + "sqlit": "Sqlit", + "sqlit-light": "Sqlit Light", + "textual-dark": "Textual Dark", + "textual-light": "Textual Light", + "nord": "Nord", + "gruvbox": "Gruvbox", + "tokyo-night": "Tokyo Night", + "solarized-light": "Solarized Light", + "solarized-dark": "Solarized Dark", + "monokai": "Monokai", + "flexoki": "Flexoki", + "catppuccin-latte": "Catppuccin Latte", + "rose-pine": "Rose Pine", + "rose-pine-moon": "Rose Pine Moon", + "rose-pine-dawn": "Rose Pine Dawn", + "catppuccin-mocha": "Catppuccin Mocha", + "dracula": "Dracula", +} class ThemeScreen(ModalScreen[str | None]): @@ -54,12 +63,32 @@ class ThemeScreen(ModalScreen[str | None]): def __init__(self, current_theme: str): super().__init__() self.current_theme = current_theme + self._theme_ids: list[str] = [] + + def _build_theme_list(self) -> list[tuple[str, str]]: + available = set(self.app.available_themes) + available.discard("textual-ansi") + ordered: list[tuple[str, str]] = [] + seen: set[str] = set() + + for theme_id, theme_name in THEME_LABELS.items(): + if theme_id in available: + ordered.append((theme_id, theme_name)) + seen.add(theme_id) + + for theme_id in sorted(available - seen): + theme_name = " ".join(part.capitalize() for part in theme_id.split("-")) + ordered.append((theme_id, theme_name)) + + return ordered def compose(self) -> ComposeResult: shortcuts = [("Select", ""), ("Cancel", "")] with Dialog(id="theme-dialog", title="Select Theme", shortcuts=shortcuts): options = [] - for theme_id, theme_name in THEMES: + themes = self._build_theme_list() + self._theme_ids = [theme_id for theme_id, _ in themes] + for theme_id, theme_name in themes: prefix = "> " if theme_id == self.current_theme else " " options.append(Option(f"{prefix}{theme_name}", id=theme_id)) yield OptionList(*options, id="theme-list") @@ -68,12 +97,12 @@ def on_mount(self) -> None: option_list = self.query_one("#theme-list", OptionList) option_list.focus() # Highlight current theme - for i, (theme_id, _) in enumerate(THEMES): + for i, theme_id in enumerate(self._theme_ids): if theme_id == self.current_theme: option_list.highlighted = i break - def on_option_list_option_selected(self, event) -> None: + def on_option_list_option_selected(self, event: OptionList.OptionSelected) -> None: self.dismiss(event.option.id) def action_select_option(self) -> None: diff --git a/sqlit/ui/screens/value_view.py b/sqlit/ui/screens/value_view.py index 51fdcee8..648d490e 100644 --- a/sqlit/ui/screens/value_view.py +++ b/sqlit/ui/screens/value_view.py @@ -42,10 +42,6 @@ class ValueViewScreen(ModalScreen): width: auto; height: auto; } - - #value-text.flash-copy { - background: $success; - } """ def __init__(self, value: str, title: str = "Value"): @@ -84,19 +80,15 @@ def compose(self) -> ComposeResult: def on_mount(self) -> None: self.query_one("#value-scroll").focus() - def action_dismiss(self) -> None: + def action_dismiss(self) -> None: # type: ignore[override] self.dismiss(None) def action_copy(self) -> None: + from ...widgets import flash_widget + copied = getattr(self.app, "_copy_text", None) if callable(copied): copied(self.value) - self._flash_copy() + flash_widget(self.query_one("#value-text")) else: self.notify("Copy unavailable", timeout=2) - - def _flash_copy(self) -> None: - """Flash the text background to indicate copy.""" - text_area = self.query_one("#value-text") - text_area.add_class("flash-copy") - self.set_timer(0.15, lambda: text_area.remove_class("flash-copy")) diff --git a/sqlit/utils.py b/sqlit/utils.py new file mode 100644 index 00000000..65915748 --- /dev/null +++ b/sqlit/utils.py @@ -0,0 +1,76 @@ +"""Utility functions for sqlit.""" + +from __future__ import annotations + + +def fuzzy_match(pattern: str, text: str) -> tuple[bool, list[int]]: + """Check if pattern fuzzy matches text and return matched indices. + + Args: + pattern: The search pattern (e.g., "usrtbl" to match "users_table") + text: The text to search in + + Returns: + Tuple of (matches, indices) where indices are positions in text that matched. + """ + if not pattern: + return True, [] + + pattern = pattern.lower() + text_lower = text.lower() + + pattern_idx = 0 + indices = [] + + for i, char in enumerate(text_lower): + if pattern_idx < len(pattern) and char == pattern[pattern_idx]: + indices.append(i) + pattern_idx += 1 + + return pattern_idx == len(pattern), indices + + +def highlight_matches(text: str, indices: list[int], style: str = "bold yellow") -> str: + """Highlight matched characters in text using Rich markup. + + Args: + text: The original text + indices: List of character indices to highlight + style: Rich style string for highlighting (default: "bold yellow") + + Returns: + Text with Rich markup highlighting the matched characters. + """ + if not indices: + return text + + result = [] + idx_set = set(indices) + + for i, char in enumerate(text): + if i in idx_set: + result.append(f"[{style}]{char}[/]") + else: + result.append(char) + + return "".join(result) + + +def format_duration_ms(ms: float, *, always_seconds: bool = False) -> str: + """Format milliseconds into a human-readable duration string. + + Args: + ms: Duration in milliseconds + always_seconds: If True, always format as seconds (e.g., "0.00s") + + Returns: + Formatted duration string + """ + if always_seconds: + return f"{ms / 1000:.2f}s" + if ms >= 1000: + return f"{ms / 1000:.2f}s" + elif ms >= 1: + return f"{ms:.0f}ms" + else: + return f"{ms:.2f}ms" diff --git a/sqlit/validation.py b/sqlit/validation.py index 9b54d601..562eb323 100644 --- a/sqlit/validation.py +++ b/sqlit/validation.py @@ -6,7 +6,7 @@ from pathlib import Path from typing import TYPE_CHECKING -from .db.schema import is_file_based +from .db.providers import is_file_based if TYPE_CHECKING: from .fields import FieldDefinition @@ -46,7 +46,7 @@ def validate_connection_form( name: str, db_type: str, values: dict, - field_definitions: dict[str, "FieldDefinition"], + field_definitions: dict[str, FieldDefinition], existing_names: set[str], editing_name: str | None = None, ) -> ValidationState: @@ -65,11 +65,9 @@ def validate_connection_form( """ state = ValidationState() - # Validate name uniqueness if name in existing_names and name != editing_name: state.add_error("name", "Name already exists.") - # Validate required fields for field_name, field_def in field_definitions.items(): if not field_def.required: continue @@ -81,7 +79,6 @@ def validate_connection_form( if is_visible and not values.get(field_name): state.add_error(field_name, "Required.") - # File path validation for file-based databases if is_file_based(db_type): fp = values.get("file_path", "").strip() if not fp: @@ -89,7 +86,6 @@ def validate_connection_form( elif not Path(fp).exists(): state.add_error("file_path", "File not found.") - # SSH validation ssh_enabled = values.get("ssh_enabled") == "enabled" if ssh_enabled: if not values.get("ssh_host"): diff --git a/sqlit/widgets.py b/sqlit/widgets.py index a48ed4f0..6ad3ad89 100644 --- a/sqlit/widgets.py +++ b/sqlit/widgets.py @@ -3,11 +3,41 @@ from __future__ import annotations from enum import Enum +from typing import TYPE_CHECKING, Any from textual.app import ComposeResult from textual.containers import Container, Horizontal from textual.widgets import Static +if TYPE_CHECKING: + from collections.abc import Callable + + from textual.widget import Widget + + +def flash_widget( + widget: Widget, + css_class: str = "flash", + duration: float = 0.15, + on_complete: Callable[[], None] | None = None, +) -> None: + """Flash a widget by temporarily adding a CSS class. + + Args: + widget: The widget to flash. + css_class: The CSS class to add (default: "flash"). + duration: How long to show the flash in seconds (default: 0.15). + on_complete: Optional callback to run after flash completes. + """ + widget.add_class(css_class) + + def cleanup() -> None: + widget.remove_class(css_class) + if on_complete: + on_complete() + + widget.set_timer(duration, cleanup) + class VimMode(Enum): """Vim editing modes.""" @@ -33,8 +63,8 @@ class ContextFooter(Horizontal): ContextFooter { height: 1; dock: bottom; - background: $surface; - color: $primary; + background: $footer-background; + color: $footer-key-foreground; padding: 0 1; } @@ -50,7 +80,7 @@ class ContextFooter(Horizontal): } """ - def __init__(self): + def __init__(self) -> None: super().__init__() self._left_bindings: list[KeyBinding] = [] self._right_bindings: list[KeyBinding] = [] @@ -67,6 +97,7 @@ def set_bindings(self, left: list[KeyBinding], right: list[KeyBinding]) -> None: def _rebuild(self) -> None: """Rebuild the footer content with left and right sections.""" + def format_binding(binding: KeyBinding) -> str: if binding.disabled: return f"[$text-muted strike]{binding.label}: {binding.key}[/]" @@ -87,8 +118,9 @@ class Dialog(Container): DEFAULT_CSS = """ Dialog { - border: solid $primary; + border: round $primary; background: $surface; + color: $primary; padding: 1; height: auto; max-height: 85%; @@ -97,7 +129,7 @@ class Dialog(Container): scrollbar-visibility: hidden; border-title-align: left; - border-title-color: $text-muted; + border-title-color: $primary; border-title-background: $surface; border-title-style: bold; @@ -112,8 +144,8 @@ def __init__( self, title: str | None = None, shortcuts: list[tuple[str, str]] | None = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: """Initialize the dialog. Args: @@ -125,12 +157,80 @@ def __init__( if title is not None: self.border_title = title if shortcuts: - subtitle = " ".join( - f"{action}: [bold]{key}[/]" for action, key in shortcuts + # Use a visible separator. Border subtitles can collapse regular spaces, + # so we use non-breaking spaces to preserve padding around the separator. + def format_key(key: str) -> str: + # Wrap key in <> if not already wrapped + if key.startswith("<") and key.endswith(">"): + return key + return f"<{key}>" + + subtitle = "\u00a0·\u00a0".join( + f"{action}: [bold]{format_key(key)}[/]" for action, key in shortcuts ) self.border_subtitle = subtitle +class TreeFilterInput(Static): + """Filter input widget for the explorer tree.""" + + DEFAULT_CSS = """ + TreeFilterInput { + width: 100%; + height: 1; + background: $surface; + display: none; + padding: 0 1; + } + + TreeFilterInput.visible { + display: block; + } + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__("", *args, **kwargs) + self.filter_text: str = "" + self.match_count: int = 0 + self.total_count: int = 0 + + def set_filter(self, text: str, match_count: int = 0, total_count: int = 0) -> None: + """Set the filter text and match count.""" + self.filter_text = text + self.match_count = match_count + self.total_count = total_count + self._rebuild() + + def clear(self) -> None: + """Clear the filter.""" + self.filter_text = "" + self.match_count = 0 + self.total_count = 0 + self._rebuild() + + def _rebuild(self) -> None: + """Rebuild the display.""" + if not self.filter_text: + self.update("[dim]/[/] ") + else: + count_text = f"[dim]{self.match_count}/{self.total_count}[/]" + self.update(f"[dim]/[/] {self.filter_text} {count_text}") + + def show(self) -> None: + """Show the filter input.""" + self.add_class("visible") + self._rebuild() + + def hide(self) -> None: + """Hide the filter input.""" + self.remove_class("visible") + + @property + def is_visible(self) -> bool: + """Check if filter is visible.""" + return "visible" in self.classes + + class AutocompleteDropdown(Static): """Dropdown widget for SQL autocomplete suggestions.""" @@ -143,7 +243,7 @@ class AutocompleteDropdown(Static): height: auto; max-height: 10; background: $surface; - border: solid $primary; + border: round $border; padding: 0; display: none; } @@ -162,7 +262,7 @@ class AutocompleteDropdown(Static): } """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__("", *args, **kwargs) self.items: list[str] = [] self.filtered_items: list[str] = [] @@ -175,9 +275,7 @@ def set_items(self, items: list[str], filter_text: str = "") -> None: self.filter_text = filter_text.lower() if self.filter_text: - self.filtered_items = [ - item for item in items if item.lower().startswith(self.filter_text) - ] + self.filtered_items = [item for item in items if item.lower().startswith(self.filter_text)] else: self.filtered_items = items[:20] diff --git a/tests/conftest.py b/tests/conftest.py index 03636688..91fc8d8c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,6 @@ from __future__ import annotations import os -import shutil import socket import sqlite3 import subprocess @@ -13,13 +12,16 @@ import pytest +_TEST_CONFIG_DIR = Path(tempfile.mkdtemp(prefix="sqlit-test-config-")) +os.environ.setdefault("SQLIT_CONFIG_DIR", str(_TEST_CONFIG_DIR)) + def is_port_open(host: str, port: int, timeout: float = 1.0) -> bool: """Check if a TCP port is open.""" try: with socket.create_connection((host, port), timeout=timeout): return True - except (OSError, socket.timeout): + except (TimeoutError, OSError): return False @@ -42,9 +44,9 @@ def run_cli(*args: str, check: bool = True) -> subprocess.CompletedProcess: text=True, ) if check and result.returncode != 0: - # Ignore RuntimeWarning about module import order stderr_clean = "\n".join( - line for line in result.stderr.split("\n") + line + for line in result.stderr.split("\n") if "RuntimeWarning" not in line and "unpredictable behaviour" not in line ).strip() if stderr_clean: @@ -60,11 +62,6 @@ def cleanup_connection(name: str) -> None: pass -# ============================================================================= -# SQLite Fixtures -# ============================================================================= - - @pytest.fixture(scope="function") def sqlite_db_path(tmp_path: Path) -> Path: """Create a temporary SQLite database file path.""" @@ -77,7 +74,6 @@ def sqlite_db(sqlite_db_path: Path) -> Path: conn = sqlite3.connect(sqlite_db_path) cursor = conn.cursor() - # Create test tables cursor.execute(""" CREATE TABLE test_users ( id INTEGER PRIMARY KEY, @@ -95,13 +91,11 @@ def sqlite_db(sqlite_db_path: Path) -> Path: ) """) - # Create test view cursor.execute(""" CREATE VIEW test_user_emails AS SELECT id, name, email FROM test_users WHERE email IS NOT NULL """) - # Insert test data cursor.executemany( "INSERT INTO test_users (id, name, email) VALUES (?, ?, ?)", [ @@ -131,28 +125,23 @@ def sqlite_connection(sqlite_db: Path) -> str: """Create a sqlit CLI connection for SQLite and clean up after test.""" connection_name = f"test_sqlite_{os.getpid()}" - # Clean up any existing connection with this name cleanup_connection(connection_name) - # Create the connection run_cli( - "connection", "create", - "--name", connection_name, - "--db-type", "sqlite", - "--file-path", str(sqlite_db), + "connections", + "add", + "sqlite", + "--name", + connection_name, + "--file-path", + str(sqlite_db), ) yield connection_name - # Cleanup cleanup_connection(connection_name) -# ============================================================================= -# SQL Server Fixtures -# ============================================================================= - -# SQL Server connection settings for Docker MSSQL_HOST = os.environ.get("MSSQL_HOST", "localhost") MSSQL_PORT = int(os.environ.get("MSSQL_PORT", "1433")) MSSQL_USER = os.environ.get("MSSQL_USER", "sa") @@ -171,7 +160,6 @@ def mssql_server_ready() -> bool: if not mssql_available(): return False - # Wait a bit for SQL Server to be fully ready time.sleep(2) return True @@ -194,7 +182,6 @@ def mssql_db(mssql_server_ready: bool) -> str: driver = drivers[0] - # Connect to master to create test database conn_str = ( f"DRIVER={{{driver}}};" f"SERVER={MSSQL_HOST},{MSSQL_PORT};" @@ -209,18 +196,15 @@ def mssql_db(mssql_server_ready: bool) -> str: conn.autocommit = True cursor = conn.cursor() - # Drop test database if it exists (separate statements to avoid "connection busy" errors) cursor.execute(f"SELECT name FROM sys.databases WHERE name = '{MSSQL_DATABASE}'") if cursor.fetchone(): cursor.execute(f"ALTER DATABASE [{MSSQL_DATABASE}] SET SINGLE_USER WITH ROLLBACK IMMEDIATE") cursor.execute(f"DROP DATABASE [{MSSQL_DATABASE}]") - # Create test database cursor.execute(f"CREATE DATABASE [{MSSQL_DATABASE}]") cursor.close() conn.close() - # Connect to test database and create schema conn_str = ( f"DRIVER={{{driver}}};" f"SERVER={MSSQL_HOST},{MSSQL_PORT};" @@ -232,7 +216,6 @@ def mssql_db(mssql_server_ready: bool) -> str: conn = pyodbc.connect(conn_str, timeout=10) cursor = conn.cursor() - # Create test tables cursor.execute(""" CREATE TABLE test_users ( id INT PRIMARY KEY, @@ -250,13 +233,11 @@ def mssql_db(mssql_server_ready: bool) -> str: ) """) - # Create test view cursor.execute(""" CREATE VIEW test_user_emails AS SELECT id, name, email FROM test_users WHERE email IS NOT NULL """) - # Create test stored procedure cursor.execute(""" CREATE PROCEDURE sp_test_get_users AS @@ -265,7 +246,6 @@ def mssql_db(mssql_server_ready: bool) -> str: END """) - # Insert test data cursor.execute(""" INSERT INTO test_users (id, name, email) VALUES (1, 'Alice', 'alice@example.com'), @@ -289,7 +269,6 @@ def mssql_db(mssql_server_ready: bool) -> str: yield MSSQL_DATABASE - # Cleanup: drop test database try: conn = pyodbc.connect( f"DRIVER={{{driver}}};" @@ -302,7 +281,6 @@ def mssql_db(mssql_server_ready: bool) -> str: ) conn.autocommit = True cursor = conn.cursor() - # Execute each statement separately to avoid "connection busy" errors cursor.execute(f"SELECT name FROM sys.databases WHERE name = '{MSSQL_DATABASE}'") if cursor.fetchone(): cursor.execute(f"ALTER DATABASE [{MSSQL_DATABASE}] SET SINGLE_USER WITH ROLLBACK IMMEDIATE") @@ -318,32 +296,31 @@ def mssql_connection(mssql_db: str) -> str: """Create a sqlit CLI connection for SQL Server and clean up after test.""" connection_name = f"test_mssql_{os.getpid()}" - # Clean up any existing connection with this name cleanup_connection(connection_name) - # Create the connection run_cli( - "connection", "create", - "--name", connection_name, - "--db-type", "mssql", - "--server", f"{MSSQL_HOST},{MSSQL_PORT}" if MSSQL_PORT != 1433 else MSSQL_HOST, - "--database", mssql_db, - "--auth-type", "sql", - "--username", MSSQL_USER, - "--password", MSSQL_PASSWORD, + "connections", + "add", + "mssql", + "--name", + connection_name, + "--server", + f"{MSSQL_HOST},{MSSQL_PORT}" if MSSQL_PORT != 1433 else MSSQL_HOST, + "--database", + mssql_db, + "--auth-type", + "sql", + "--username", + MSSQL_USER, + "--password", + MSSQL_PASSWORD, ) yield connection_name - # Cleanup cleanup_connection(connection_name) -# ============================================================================= -# PostgreSQL Fixtures -# ============================================================================= - -# PostgreSQL connection settings for Docker POSTGRES_HOST = os.environ.get("POSTGRES_HOST", "localhost") POSTGRES_PORT = int(os.environ.get("POSTGRES_PORT", "5432")) POSTGRES_USER = os.environ.get("POSTGRES_USER", "testuser") @@ -362,7 +339,6 @@ def postgres_server_ready() -> bool: if not postgres_available(): return False - # Wait a bit for PostgreSQL to be fully ready time.sleep(1) return True @@ -390,12 +366,10 @@ def postgres_db(postgres_server_ready: bool) -> str: conn.autocommit = True cursor = conn.cursor() - # Drop tables if they exist and recreate cursor.execute("DROP TABLE IF EXISTS test_users CASCADE") cursor.execute("DROP TABLE IF EXISTS test_products CASCADE") cursor.execute("DROP VIEW IF EXISTS test_user_emails") - # Create test tables cursor.execute(""" CREATE TABLE test_users ( id INTEGER PRIMARY KEY, @@ -413,13 +387,11 @@ def postgres_db(postgres_server_ready: bool) -> str: ) """) - # Create test view cursor.execute(""" CREATE VIEW test_user_emails AS SELECT id, name, email FROM test_users WHERE email IS NOT NULL """) - # Insert test data cursor.execute(""" INSERT INTO test_users (id, name, email) VALUES (1, 'Alice', 'alice@example.com'), @@ -441,7 +413,6 @@ def postgres_db(postgres_server_ready: bool) -> str: yield POSTGRES_DATABASE - # Cleanup: drop test tables try: conn = psycopg2.connect( host=POSTGRES_HOST, @@ -466,32 +437,31 @@ def postgres_connection(postgres_db: str) -> str: """Create a sqlit CLI connection for PostgreSQL and clean up after test.""" connection_name = f"test_postgres_{os.getpid()}" - # Clean up any existing connection with this name cleanup_connection(connection_name) - # Create the connection run_cli( - "connection", "create", - "--name", connection_name, - "--db-type", "postgresql", - "--server", POSTGRES_HOST, - "--port", str(POSTGRES_PORT), - "--database", postgres_db, - "--username", POSTGRES_USER, - "--password", POSTGRES_PASSWORD, + "connections", + "add", + "postgresql", + "--name", + connection_name, + "--server", + POSTGRES_HOST, + "--port", + str(POSTGRES_PORT), + "--database", + postgres_db, + "--username", + POSTGRES_USER, + "--password", + POSTGRES_PASSWORD, ) yield connection_name - # Cleanup cleanup_connection(connection_name) -# ============================================================================= -# MySQL Fixtures -# ============================================================================= - -# MySQL connection settings for Docker # Note: We use root user because MySQL's testuser only has localhost access inside the container MYSQL_HOST = os.environ.get("MYSQL_HOST", "localhost") MYSQL_PORT = int(os.environ.get("MYSQL_PORT", "3306")) @@ -511,7 +481,6 @@ def mysql_server_ready() -> bool: if not mysql_available(): return False - # Wait a bit for MySQL to be fully ready time.sleep(1) return True @@ -538,12 +507,10 @@ def mysql_db(mysql_server_ready: bool) -> str: ) cursor = conn.cursor() - # Drop tables if they exist and recreate cursor.execute("DROP TABLE IF EXISTS test_users") cursor.execute("DROP TABLE IF EXISTS test_products") cursor.execute("DROP VIEW IF EXISTS test_user_emails") - # Create test tables cursor.execute(""" CREATE TABLE test_users ( id INT PRIMARY KEY, @@ -561,13 +528,11 @@ def mysql_db(mysql_server_ready: bool) -> str: ) """) - # Create test view cursor.execute(""" CREATE VIEW test_user_emails AS SELECT id, name, email FROM test_users WHERE email IS NOT NULL """) - # Insert test data cursor.execute(""" INSERT INTO test_users (id, name, email) VALUES (1, 'Alice', 'alice@example.com'), @@ -590,7 +555,6 @@ def mysql_db(mysql_server_ready: bool) -> str: yield MYSQL_DATABASE - # Cleanup: drop test tables try: conn = mysql.connector.connect( host=MYSQL_HOST, @@ -615,24 +579,28 @@ def mysql_connection(mysql_db: str) -> str: """Create a sqlit CLI connection for MySQL and clean up after test.""" connection_name = f"test_mysql_{os.getpid()}" - # Clean up any existing connection with this name cleanup_connection(connection_name) - # Create the connection run_cli( - "connection", "create", - "--name", connection_name, - "--db-type", "mysql", - "--server", MYSQL_HOST, - "--port", str(MYSQL_PORT), - "--database", mysql_db, - "--username", MYSQL_USER, - "--password", MYSQL_PASSWORD, + "connections", + "add", + "mysql", + "--name", + connection_name, + "--server", + MYSQL_HOST, + "--port", + str(MYSQL_PORT), + "--database", + mysql_db, + "--username", + MYSQL_USER, + "--password", + MYSQL_PASSWORD, ) yield connection_name - # Cleanup cleanup_connection(connection_name) @@ -659,7 +627,6 @@ def oracle_server_ready() -> bool: if not oracle_available(): return False - # Wait a bit for Oracle to be fully ready time.sleep(2) return True @@ -684,20 +651,18 @@ def oracle_db(oracle_server_ready: bool) -> str: ) cursor = conn.cursor() - # Drop tables if they exist (Oracle doesn't have IF EXISTS, use exception handling) + # Oracle lacks `DROP TABLE IF EXISTS`; ignore "does not exist" errors. for table in ["test_users", "test_products"]: try: cursor.execute(f"DROP TABLE {table} CASCADE CONSTRAINTS") except oracledb.DatabaseError: pass # Table doesn't exist - # Drop view if exists try: cursor.execute("DROP VIEW test_user_emails") except oracledb.DatabaseError: pass - # Create test tables cursor.execute(""" CREATE TABLE test_users ( id NUMBER PRIMARY KEY, @@ -715,13 +680,11 @@ def oracle_db(oracle_server_ready: bool) -> str: ) """) - # Create test view cursor.execute(""" CREATE VIEW test_user_emails AS SELECT id, name, email FROM test_users WHERE email IS NOT NULL """) - # Insert test data cursor.execute(""" INSERT INTO test_users (id, name, email) VALUES (1, 'Alice', 'alice@example.com') """) @@ -750,7 +713,6 @@ def oracle_db(oracle_server_ready: bool) -> str: yield ORACLE_SERVICE - # Cleanup: drop test tables try: conn = oracledb.connect( user=ORACLE_USER, @@ -778,24 +740,28 @@ def oracle_connection(oracle_db: str) -> str: """Create a sqlit CLI connection for Oracle and clean up after test.""" connection_name = f"test_oracle_{os.getpid()}" - # Clean up any existing connection with this name cleanup_connection(connection_name) - # Create the connection run_cli( - "connection", "create", - "--name", connection_name, - "--db-type", "oracle", - "--server", ORACLE_HOST, - "--port", str(ORACLE_PORT), - "--database", oracle_db, - "--username", ORACLE_USER, - "--password", ORACLE_PASSWORD, + "connections", + "add", + "oracle", + "--name", + connection_name, + "--server", + ORACLE_HOST, + "--port", + str(ORACLE_PORT), + "--database", + oracle_db, + "--username", + ORACLE_USER, + "--password", + ORACLE_PASSWORD, ) yield connection_name - # Cleanup cleanup_connection(connection_name) @@ -803,7 +769,6 @@ def oracle_connection(oracle_db: str) -> str: # MariaDB Fixtures # ============================================================================= -# MariaDB connection settings for Docker # Note: Using 127.0.0.1 instead of localhost to force TCP connection (localhost uses Unix socket) MARIADB_HOST = os.environ.get("MARIADB_HOST", "127.0.0.1") MARIADB_PORT = int(os.environ.get("MARIADB_PORT", "3307")) @@ -823,7 +788,6 @@ def mariadb_server_ready() -> bool: if not mariadb_available(): return False - # Wait a bit for MariaDB to be fully ready time.sleep(1) return True @@ -850,12 +814,10 @@ def mariadb_db(mariadb_server_ready: bool) -> str: ) cursor = conn.cursor() - # Drop tables if they exist and recreate cursor.execute("DROP TABLE IF EXISTS test_users") cursor.execute("DROP TABLE IF EXISTS test_products") cursor.execute("DROP VIEW IF EXISTS test_user_emails") - # Create test tables cursor.execute(""" CREATE TABLE test_users ( id INT PRIMARY KEY, @@ -873,13 +835,11 @@ def mariadb_db(mariadb_server_ready: bool) -> str: ) """) - # Create test view cursor.execute(""" CREATE VIEW test_user_emails AS SELECT id, name, email FROM test_users WHERE email IS NOT NULL """) - # Insert test data cursor.execute(""" INSERT INTO test_users (id, name, email) VALUES (1, 'Alice', 'alice@example.com'), @@ -902,7 +862,6 @@ def mariadb_db(mariadb_server_ready: bool) -> str: yield MARIADB_DATABASE - # Cleanup: drop test tables try: conn = mariadb.connect( host=MARIADB_HOST, @@ -927,24 +886,28 @@ def mariadb_connection(mariadb_db: str) -> str: """Create a sqlit CLI connection for MariaDB and clean up after test.""" connection_name = f"test_mariadb_{os.getpid()}" - # Clean up any existing connection with this name cleanup_connection(connection_name) - # Create the connection run_cli( - "connection", "create", - "--name", connection_name, - "--db-type", "mariadb", - "--server", MARIADB_HOST, - "--port", str(MARIADB_PORT), - "--database", mariadb_db, - "--username", MARIADB_USER, - "--password", MARIADB_PASSWORD, + "connections", + "add", + "mariadb", + "--name", + connection_name, + "--server", + MARIADB_HOST, + "--port", + str(MARIADB_PORT), + "--database", + mariadb_db, + "--username", + MARIADB_USER, + "--password", + MARIADB_PASSWORD, ) yield connection_name - # Cleanup cleanup_connection(connection_name) @@ -969,7 +932,6 @@ def duckdb_db(duckdb_db_path: Path) -> Path: conn = duckdb.connect(str(duckdb_db_path)) - # Create test tables conn.execute(""" CREATE TABLE test_users ( id INTEGER PRIMARY KEY, @@ -987,13 +949,11 @@ def duckdb_db(duckdb_db_path: Path) -> Path: ) """) - # Create test view conn.execute(""" CREATE VIEW test_user_emails AS SELECT id, name, email FROM test_users WHERE email IS NOT NULL """) - # Insert test data conn.execute(""" INSERT INTO test_users (id, name, email) VALUES (1, 'Alice', 'alice@example.com'), @@ -1018,20 +978,20 @@ def duckdb_connection(duckdb_db: Path) -> str: """Create a sqlit CLI connection for DuckDB and clean up after test.""" connection_name = f"test_duckdb_{os.getpid()}" - # Clean up any existing connection with this name cleanup_connection(connection_name) - # Create the connection run_cli( - "connection", "create", - "--name", connection_name, - "--db-type", "duckdb", - "--file-path", str(duckdb_db), + "connections", + "add", + "duckdb", + "--name", + connection_name, + "--file-path", + str(duckdb_db), ) yield connection_name - # Cleanup cleanup_connection(connection_name) @@ -1058,7 +1018,6 @@ def cockroachdb_server_ready() -> bool: if not cockroachdb_available(): return False - # Wait a bit for CockroachDB to be fully ready time.sleep(2) return True @@ -1075,7 +1034,6 @@ def cockroachdb_db(cockroachdb_server_ready: bool) -> str: pytest.skip("psycopg2 is not installed") try: - # Connect to default database first conn = psycopg2.connect( host=COCKROACHDB_HOST, port=COCKROACHDB_PORT, @@ -1087,12 +1045,11 @@ def cockroachdb_db(cockroachdb_server_ready: bool) -> str: conn.autocommit = True cursor = conn.cursor() - # Create test database if it doesn't exist + # Database creation requires a connection to an existing DB (e.g. `defaultdb`). cursor.execute(f"DROP DATABASE IF EXISTS {COCKROACHDB_DATABASE}") cursor.execute(f"CREATE DATABASE {COCKROACHDB_DATABASE}") conn.close() - # Connect to test database conn = psycopg2.connect( host=COCKROACHDB_HOST, port=COCKROACHDB_PORT, @@ -1104,7 +1061,6 @@ def cockroachdb_db(cockroachdb_server_ready: bool) -> str: conn.autocommit = True cursor = conn.cursor() - # Create test tables cursor.execute(""" CREATE TABLE test_users ( id INT PRIMARY KEY, @@ -1122,13 +1078,11 @@ def cockroachdb_db(cockroachdb_server_ready: bool) -> str: ) """) - # Create test view cursor.execute(""" CREATE VIEW test_user_emails AS SELECT id, name, email FROM test_users WHERE email IS NOT NULL """) - # Insert test data cursor.execute(""" INSERT INTO test_users (id, name, email) VALUES (1, 'Alice', 'alice@example.com'), @@ -1150,7 +1104,6 @@ def cockroachdb_db(cockroachdb_server_ready: bool) -> str: yield COCKROACHDB_DATABASE - # Cleanup: drop test database try: conn = psycopg2.connect( host=COCKROACHDB_HOST, @@ -1173,20 +1126,23 @@ def cockroachdb_connection(cockroachdb_db: str) -> str: """Create a sqlit CLI connection for CockroachDB and clean up after test.""" connection_name = f"test_cockroachdb_{os.getpid()}" - # Clean up any existing connection with this name cleanup_connection(connection_name) - # Create the connection args = [ - "connection", "create", - "--name", connection_name, - "--db-type", "cockroachdb", - "--server", COCKROACHDB_HOST, - "--port", str(COCKROACHDB_PORT), - "--database", cockroachdb_db, - "--username", COCKROACHDB_USER, + "connections", + "add", + "cockroachdb", + "--name", + connection_name, + "--server", + COCKROACHDB_HOST, + "--port", + str(COCKROACHDB_PORT), + "--database", + cockroachdb_db, + "--username", + COCKROACHDB_USER, ] - # Only add password if it's set if COCKROACHDB_PASSWORD: args.extend(["--password", COCKROACHDB_PASSWORD]) else: @@ -1196,7 +1152,6 @@ def cockroachdb_connection(cockroachdb_db: str) -> str: yield connection_name - # Cleanup cleanup_connection(connection_name) @@ -1220,7 +1175,6 @@ def turso_server_ready() -> bool: if not turso_available(): return False - # Wait a bit for libsql-server to be fully ready time.sleep(1) return True @@ -1241,12 +1195,10 @@ def turso_db(turso_server_ready: bool) -> str: try: client = create_client_sync(turso_url) - # Drop tables if they exist and recreate client.execute("DROP TABLE IF EXISTS test_users") client.execute("DROP TABLE IF EXISTS test_products") client.execute("DROP VIEW IF EXISTS test_user_emails") - # Create test tables client.execute(""" CREATE TABLE test_users ( id INTEGER PRIMARY KEY, @@ -1264,13 +1216,11 @@ def turso_db(turso_server_ready: bool) -> str: ) """) - # Create test view client.execute(""" CREATE VIEW test_user_emails AS SELECT id, name, email FROM test_users WHERE email IS NOT NULL """) - # Insert test data client.execute(""" INSERT INTO test_users (id, name, email) VALUES (1, 'Alice', 'alice@example.com'), @@ -1292,7 +1242,6 @@ def turso_db(turso_server_ready: bool) -> str: yield turso_url - # Cleanup: drop test tables try: client = create_client_sync(turso_url) client.execute("DROP TABLE IF EXISTS test_users") @@ -1308,21 +1257,157 @@ def turso_connection(turso_db: str) -> str: """Create a sqlit CLI connection for Turso and clean up after test.""" connection_name = f"test_turso_{os.getpid()}" - # Clean up any existing connection with this name cleanup_connection(connection_name) - # Create the connection (no auth token needed for local libsql-server) run_cli( - "connection", "create", - "--name", connection_name, - "--db-type", "turso", - "--server", turso_db, - "--password", "", # No auth token for local server + "connections", + "add", + "turso", + "--name", + connection_name, + "--server", + turso_db, + "--password", + "", # No auth token for local server. ) yield connection_name - # Cleanup + cleanup_connection(connection_name) + + +# ============================================================================= +# D1 Fixtures +# ============================================================================= + +# D1 connection settings for Docker (miniflare) +D1_HOST = os.environ.get("D1_HOST", "localhost") +D1_PORT = int(os.environ.get("D1_PORT", "8787")) +D1_ACCOUNT_ID = "test-account" +D1_DATABASE = "test-d1" +D1_API_TOKEN = "test-token" +os.environ["D1_API_BASE_URL"] = f"http://{D1_HOST}:{D1_PORT}" + + +def d1_available() -> bool: + """Check if D1 (miniflare) is available.""" + return is_port_open(D1_HOST, D1_PORT) + + +@pytest.fixture(scope="session") +def d1_server_ready() -> bool: + """Check if D1 is ready and return True/False.""" + if not d1_available(): + return False + time.sleep(1) + return True + + +@pytest.fixture(scope="function") +def d1_db(d1_server_ready: bool) -> str: + """Set up D1 test database.""" + if not d1_server_ready: + pytest.skip("D1 (miniflare) is not available") + + from sqlit.db.adapters.d1 import D1Adapter + + adapter = D1Adapter() + config = { + "name": "d1-temp-setup", + "db_type": "d1", + "server": D1_ACCOUNT_ID, + "password": D1_API_TOKEN, + "database": D1_DATABASE, + } + from sqlit.config import ConnectionConfig + + conn_config = ConnectionConfig(**config) + try: + conn = adapter.connect(conn_config) + + adapter.execute_non_query(conn, "DROP TABLE IF EXISTS test_users") + adapter.execute_non_query(conn, "DROP TABLE IF EXISTS test_products") + adapter.execute_non_query(conn, "DROP VIEW IF EXISTS test_user_emails") + + adapter.execute_non_query( + conn, + """ + CREATE TABLE test_users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + email TEXT UNIQUE + ) + """, + ) + adapter.execute_non_query( + conn, + """ + CREATE TABLE test_products ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + price REAL NOT NULL, + stock INTEGER DEFAULT 0 + ) + """, + ) + adapter.execute_non_query( + conn, + """ + CREATE VIEW test_user_emails AS + SELECT id, name, email FROM test_users WHERE email IS NOT NULL + """, + ) + adapter.execute_non_query( + conn, + "INSERT INTO test_users (id, name, email) VALUES (1, 'Alice', 'alice@example.com')", + ) + adapter.execute_non_query( + conn, + "INSERT INTO test_users (id, name, email) VALUES (2, 'Bob', 'bob@example.com')", + ) + adapter.execute_non_query( + conn, + "INSERT INTO test_users (id, name, email) VALUES (3, 'Charlie', 'charlie@example.com')", + ) + adapter.execute_non_query( + conn, + "INSERT INTO test_products (id, name, price, stock) VALUES (1, 'Widget', 9.99, 100)", + ) + adapter.execute_non_query( + conn, + "INSERT INTO test_products (id, name, price, stock) VALUES (2, 'Gadget', 19.99, 50)", + ) + adapter.execute_non_query( + conn, + "INSERT INTO test_products (id, name, price, stock) VALUES (3, 'Gizmo', 29.99, 25)", + ) + except Exception as e: + pytest.skip(f"Failed to setup D1 database: {e}") + + yield D1_DATABASE + + +@pytest.fixture(scope="function") +def d1_connection(d1_db: str) -> str: + """Create a sqlit CLI connection for D1 and clean up after test.""" + connection_name = f"test_d1_{os.getpid()}" + cleanup_connection(connection_name) + + run_cli( + "connections", + "add", + "d1", + "--name", + connection_name, + "--host", + D1_ACCOUNT_ID, + "--password", + D1_API_TOKEN, + "--database", + d1_db, + ) + + yield connection_name cleanup_connection(connection_name) @@ -1365,13 +1450,9 @@ def ssh_postgres_db(ssh_server_ready: bool) -> str: except ImportError: pytest.skip("psycopg2 is not installed") - # Connect directly to PostgreSQL to set up test data # postgres-ssh container is accessible on port 5433 pg_host = os.environ.get("SSH_DIRECT_PG_HOST", "localhost") pg_port = int(os.environ.get("SSH_DIRECT_PG_PORT", "5433")) - pg_user = POSTGRES_USER - pg_password = POSTGRES_PASSWORD - pg_database = POSTGRES_DATABASE try: conn = psycopg2.connect( @@ -1385,12 +1466,10 @@ def ssh_postgres_db(ssh_server_ready: bool) -> str: conn.autocommit = True cursor = conn.cursor() - # Drop tables if they exist and recreate cursor.execute("DROP TABLE IF EXISTS test_users CASCADE") cursor.execute("DROP TABLE IF EXISTS test_products CASCADE") cursor.execute("DROP VIEW IF EXISTS test_user_emails") - # Create test tables cursor.execute(""" CREATE TABLE test_users ( id INTEGER PRIMARY KEY, @@ -1408,13 +1487,11 @@ def ssh_postgres_db(ssh_server_ready: bool) -> str: ) """) - # Create test view cursor.execute(""" CREATE VIEW test_user_emails AS SELECT id, name, email FROM test_users WHERE email IS NOT NULL """) - # Insert test data cursor.execute(""" INSERT INTO test_users (id, name, email) VALUES (1, 'Alice', 'alice@example.com'), @@ -1436,7 +1513,6 @@ def ssh_postgres_db(ssh_server_ready: bool) -> str: yield POSTGRES_DATABASE - # Cleanup: drop test tables try: conn = psycopg2.connect( host=pg_host, @@ -1461,30 +1537,39 @@ def ssh_connection(ssh_postgres_db: str) -> str: """Create a sqlit CLI connection for PostgreSQL via SSH tunnel.""" connection_name = f"test_ssh_{os.getpid()}" - # Clean up any existing connection with this name cleanup_connection(connection_name) - # Create the connection with SSH tunnel enabled run_cli( - "connection", "create", - "--name", connection_name, - "--db-type", "postgresql", - "--server", SSH_REMOTE_DB_HOST, - "--port", str(SSH_REMOTE_DB_PORT), - "--database", ssh_postgres_db, - "--username", POSTGRES_USER, - "--password", POSTGRES_PASSWORD, + "connections", + "add", + "postgresql", + "--name", + connection_name, + "--server", + SSH_REMOTE_DB_HOST, + "--port", + str(SSH_REMOTE_DB_PORT), + "--database", + ssh_postgres_db, + "--username", + POSTGRES_USER, + "--password", + POSTGRES_PASSWORD, "--ssh-enabled", - "--ssh-host", SSH_HOST, - "--ssh-port", str(SSH_PORT), - "--ssh-username", SSH_USER, - "--ssh-auth-type", "password", - "--ssh-password", SSH_PASSWORD, + "--ssh-host", + SSH_HOST, + "--ssh-port", + str(SSH_PORT), + "--ssh-username", + SSH_USER, + "--ssh-auth-type", + "password", + "--ssh-password", + SSH_PASSWORD, ) yield connection_name - # Cleanup cleanup_connection(connection_name) diff --git a/tests/fixtures/d1/Dockerfile b/tests/fixtures/d1/Dockerfile new file mode 100644 index 00000000..5b870c3e --- /dev/null +++ b/tests/fixtures/d1/Dockerfile @@ -0,0 +1,13 @@ +FROM node:20-slim + +WORKDIR /app + +# Install wrangler, which includes miniflare v3 +RUN npm install -g wrangler + +# Expose the default wrangler dev port +EXPOSE 8787 + +# Command to run wrangler dev. This will automatically pick up the wrangler.toml +# from the working directory and start the local development server. +CMD ["wrangler", "dev"] diff --git a/tests/fixtures/d1/index.js b/tests/fixtures/d1/index.js new file mode 100644 index 00000000..b8e6a560 --- /dev/null +++ b/tests/fixtures/d1/index.js @@ -0,0 +1,72 @@ +export default { + async fetch(request, env) { + const url = new URL(request.url); + const path = url.pathname.replace(/\/+$/, ""); + + const listMatch = path.match(/^\/client\/v4\/accounts\/([^/]+)\/d1\/database$/); + if (listMatch && request.method === "GET") { + return Response.json({ + success: true, + errors: [], + messages: [], + result: [{ name: "test-d1", uuid: "test-d1" }], + }); + } + + const execMatch = path.match( + /^\/client\/v4\/accounts\/([^/]+)\/d1\/database\/([^/]+)\/execute$/ + ); + if (execMatch && request.method === "POST") { + let body; + try { + body = await request.json(); + } catch { + return Response.json( + { success: false, errors: [{ message: "Invalid JSON" }], result: null }, + { status: 400 } + ); + } + + const sql = typeof body?.sql === "string" ? body.sql : ""; + if (!sql) { + return Response.json( + { success: false, errors: [{ message: "Missing sql" }], result: null }, + { status: 400 } + ); + } + + try { + const statement = env.DB.prepare(sql); + const isRead = + /^\s*(select|pragma|with|explain)\b/i.test(sql) || + /^\s*values\b/i.test(sql); + + const result = isRead ? await statement.all() : await statement.run(); + + return Response.json({ + success: true, + errors: [], + messages: [], + result: [ + { + results: Array.isArray(result?.results) ? result.results : [], + meta: typeof result?.meta === "object" && result.meta ? result.meta : {}, + }, + ], + }); + } catch (e) { + return Response.json( + { + success: false, + errors: [{ message: String(e?.message || e) }], + messages: [], + result: null, + }, + { status: 400 } + ); + } + } + + return new Response("Hello from the test worker!"); + }, +}; diff --git a/tests/fixtures/d1/wrangler.toml b/tests/fixtures/d1/wrangler.toml new file mode 100644 index 00000000..c6d62bff --- /dev/null +++ b/tests/fixtures/d1/wrangler.toml @@ -0,0 +1,10 @@ +# wrangler.toml +# This file is used by miniflare to configure the D1 database for integration tests. + +name = "test-worker" +main = "index.js" +compatibility_date = "2023-10-30" + +d1_databases = [ + { binding = "DB", database_name = "test-d1", database_id = "test-d1" } +] diff --git a/tests/integration/drivers/README.md b/tests/integration/drivers/README.md new file mode 100644 index 00000000..2796f9ff --- /dev/null +++ b/tests/integration/drivers/README.md @@ -0,0 +1,36 @@ +# ODBC Driver Installation Tests + +Validates that `sqlit/drivers.py` installation commands work on real Linux distributions. + +## Run + +```bash +./run_tests.sh # all distros +./run_tests.sh ubuntu # single distro (ubuntu|debian|rocky|fedora|alpine|opensuse|arch) +./run_ui_tests.sh # UI screenshots (svg/png) for driver setup flow +./run_tests.sh --clean # remove containers/images +``` + +## What It Does + +1. Spins up SQL Server 2022 container +2. For each distro: runs `get_install_commands()`, executes them, verifies driver works +3. Tests actual connection to SQL Server + +## Requirements + +- Docker + Docker Compose +- ~10GB disk space +- ~10 minutes (all distros) + +## Not Covered + +macOS and Windows cannot be containerized - test manually. + +## UI Screenshots + +`run_ui_tests.sh` runs a headless Textual flow that captures screenshots into: + +- `tests/integration/drivers/artifacts/` + +It saves `.svg` screenshots and (if `rsvg-convert` is available) also writes `.png` versions. diff --git a/tests/integration/drivers/artifacts/.gitkeep b/tests/integration/drivers/artifacts/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/drivers/artifacts/mssql-01-connection.png b/tests/integration/drivers/artifacts/mssql-01-connection.png new file mode 100644 index 00000000..66321c4b Binary files /dev/null and b/tests/integration/drivers/artifacts/mssql-01-connection.png differ diff --git a/tests/integration/drivers/artifacts/mssql-01-connection.svg b/tests/integration/drivers/artifacts/mssql-01-connection.svg new file mode 100644 index 00000000..ee4f67e2 --- /dev/null +++ b/tests/integration/drivers/artifacts/mssql-01-connection.svg @@ -0,0 +1,224 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + sqlit + + + + + + + + + + +[E] Explorer[q] Query + +▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔ + + + + New Connection ─────────────────────────────────────────── + +GeneralAdvancedSSH +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + Name ───────────────────────────────────────────────── +mssql-driver-ui-flow +──────────────────────────────────────────────────────── + Database Type ──────────────────────────────────────── +SQL Server +────────────────────────────────────────────────────────▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ + Server ──────────────────────────────────── Port  +mssql1433───────────────────────────── +───────────────────────────────────────────────────── + Database ───────────────────────────────────────────── +master +──────────────────────────────────────────────────────── + Authentication ─────────────────────────────────────── +SQL Server Authentication +──────────────────────────────────────────────────────── + Username ──────────────── Password ──────────────── +saTestPassword123! +────────────────────────────────────────────────────── + + +────────────────────── Test: ^T · Save: ^S · Cancel: <esc> + + + + + + +Not connected + + + + diff --git a/tests/integration/drivers/artifacts/mssql-02-advanced.png b/tests/integration/drivers/artifacts/mssql-02-advanced.png new file mode 100644 index 00000000..cbef6002 Binary files /dev/null and b/tests/integration/drivers/artifacts/mssql-02-advanced.png differ diff --git a/tests/integration/drivers/artifacts/mssql-02-advanced.svg b/tests/integration/drivers/artifacts/mssql-02-advanced.svg new file mode 100644 index 00000000..54b36974 --- /dev/null +++ b/tests/integration/drivers/artifacts/mssql-02-advanced.svg @@ -0,0 +1,222 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + sqlit + + + + + + + + + + +[E] Explorer[q] Query + + New Connection ───────────────────────────────────────────▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔ + +GeneralAdvancedSSH +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + Driver ───────────────────────────────────────────── +ODBC Driver 18 for SQL Server +ODBC Driver 17 for SQL Server +ODBC Driver 13 for SQL Server +ODBC Driver 11 for SQL Server +SQL Server Native Client 11.0 +SQL Server +────────────────────────────────────────────────────── + +────────────────────▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ + ODBC driver setup…  +───────────────────────────────────────────────── + + + + + + + + + + + + + + + + + +────────────────────── Test: ^T · Save: ^S · Cancel: <esc> + +Not connected + + + + diff --git a/tests/integration/drivers/artifacts/mssql-03-driver-setup-empty.png b/tests/integration/drivers/artifacts/mssql-03-driver-setup-empty.png new file mode 100644 index 00000000..e1a6fc1d Binary files /dev/null and b/tests/integration/drivers/artifacts/mssql-03-driver-setup-empty.png differ diff --git a/tests/integration/drivers/artifacts/mssql-03-driver-setup-empty.svg b/tests/integration/drivers/artifacts/mssql-03-driver-setup-empty.svg new file mode 100644 index 00000000..1aeba36c --- /dev/null +++ b/tests/integration/drivers/artifacts/mssql-03-driver-setup-empty.svg @@ -0,0 +1,230 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + sqlit + + + + + + + + + + +[E] Explorer[q] Query + + New Connection ───────────────────────────────────────────▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔ + +GeneralAdvancedSSH +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + No ODBC Driver Found ─────────────────────────────────────────────────────── + +Detected OS: debian 13 +You need an ODBC driver to connect to SQL Server. + +────────────────────────────────────────────────────────────────────────── +ODBC Driver 18 for SQL Server (not installed) +ODBC Driver 17 for SQL Server (not installed) +ODBC Driver 13 for SQL Server (not installed) +──────────────────────────────────────────────────────────────────────────▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ + +▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔──────────────────── + +Install on Debian 13:  + +curl -sSL -O  +https://packages.microsoft.com/config/debian/13/packages-microsoft-pr +od.deb  +sudo dpkg -i packages-microsoft-prod.deb  +rm packages-microsoft-prod.deb ▄▄ +sudo apt-get update  + +▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ + +───────────────────── Select: <enter> · Install: I · Yank: y · Cancel: <esc> + + + + +──────────────────────────────────────────────────────────── + +Not connected + + + + diff --git a/tests/integration/drivers/artifacts/mssql-04-install-message.png b/tests/integration/drivers/artifacts/mssql-04-install-message.png new file mode 100644 index 00000000..692d7bcb Binary files /dev/null and b/tests/integration/drivers/artifacts/mssql-04-install-message.png differ diff --git a/tests/integration/drivers/artifacts/mssql-04-install-message.svg b/tests/integration/drivers/artifacts/mssql-04-install-message.svg new file mode 100644 index 00000000..2649a8d0 --- /dev/null +++ b/tests/integration/drivers/artifacts/mssql-04-install-message.svg @@ -0,0 +1,224 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + sqlit + + + + + + + + + + +[E] Explorer[q] Query + + New Connection ───────────────────────────────────────────▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔ + +GeneralAdvancedSSH +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + Driver ───────────────────────────────────────────── +ODBC Driver 18 for SQL Server +ODBC Driver 17 for SQL Server +ODBC Driver 13 for SQL Server +ODBC Driver 11 for SQL Server +SQL Server Native Client 11.0 +SQL Server +────────────────────────────────────────────────────── + + Couldn't install automatically ───────────────────────────────────▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ + +───────────────────────── +Couldn't install automatically, please install manually. + + +──────────────────────────────────────────────── Continue: <enter> + + + + + + + + + + + + + +──────────────────────────────────────────────────────────── + +Not connected                                            17:46:13 Installing driver... This may ask for your password. + + + + diff --git a/tests/integration/drivers/artifacts/mssql-05-back-to-connection.png b/tests/integration/drivers/artifacts/mssql-05-back-to-connection.png new file mode 100644 index 00000000..5d665c26 Binary files /dev/null and b/tests/integration/drivers/artifacts/mssql-05-back-to-connection.png differ diff --git a/tests/integration/drivers/artifacts/mssql-05-back-to-connection.svg b/tests/integration/drivers/artifacts/mssql-05-back-to-connection.svg new file mode 100644 index 00000000..b4401aa5 --- /dev/null +++ b/tests/integration/drivers/artifacts/mssql-05-back-to-connection.svg @@ -0,0 +1,227 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + sqlit + + + + + + + + + + +[E] Explorer[q] Query + + New Connection ───────────────────────────────────────────▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔ + +GeneralAdvancedSSH +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + Driver ───────────────────────────────────────────── +ODBC Driver 18 for SQL Server +ODBC Driver 17 for SQL Server +ODBC Driver 13 for SQL Server +ODBC Driver 11 for SQL Server +SQL Server Native Client 11.0 +SQL Server +────────────────────────────────────────────────────── +▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔ + ODBC driver setup… ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ +▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ +───────────────────────────── + + + + + + + + + + + + + + + + + +────────────────────── Test: ^T · Save: ^S · Cancel: <esc> + +Not connected                                            17:09:29 Installing driver... This may ask for your password. + + + + diff --git a/tests/integration/drivers/artifacts/mssql-05-back-to-setup.svg b/tests/integration/drivers/artifacts/mssql-05-back-to-setup.svg new file mode 100644 index 00000000..744bc20c --- /dev/null +++ b/tests/integration/drivers/artifacts/mssql-05-back-to-setup.svg @@ -0,0 +1,231 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + sqlit + + + + + + + + + + +[E] Explorer[q] Query + + New Connection ───────────────────────────────────────────▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔ + +GeneralAdvancedSSH +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + No ODBC Driver Found ─────────────────────────────────────────────────────── + +Detected OS: debian 13 +You need an ODBC driver to connect to SQL Server. + +────────────────────────────────────────────────────────────────────────── +ODBC Driver 18 for SQL Server (not installed) +ODBC Driver 17 for SQL Server (not installed) +ODBC Driver 13 for SQL Server (not installed) +──────────────────────────────────────────────────────────────────────────▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ + +▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔──────────────────── + +Install on Debian 13:  + +curl -sSL -O  +https://packages.microsoft.com/config/debian/13/packages-microsoft-pr +od.deb  +sudo dpkg -i packages-microsoft-prod.deb  +rm packages-microsoft-prod.deb ▄▄ +sudo apt-get update  + +▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ + +───────────────────── Select: <enter> · Install: I · Yank: y · Cancel: <esc> + + + + +──────────────────────────────────────────────────────────── + +Not connected                                            17:46:13 Installing driver... This may ask for your password. + + + + diff --git a/tests/integration/drivers/artifacts/mssql-06-back-to-connection.svg b/tests/integration/drivers/artifacts/mssql-06-back-to-connection.svg new file mode 100644 index 00000000..bf90deed --- /dev/null +++ b/tests/integration/drivers/artifacts/mssql-06-back-to-connection.svg @@ -0,0 +1,224 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + sqlit + + + + + + + + + + +[E] Explorer[q] Query + + New Connection ───────────────────────────────────────────▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔ + +GeneralAdvancedSSH +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + Driver ───────────────────────────────────────────── +ODBC Driver 18 for SQL Server +ODBC Driver 17 for SQL Server +ODBC Driver 13 for SQL Server +ODBC Driver 11 for SQL Server +SQL Server Native Client 11.0 +SQL Server +────────────────────────────────────────────────────── + +────────────────────▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ + ODBC driver setup…  +───────────────────────────────────────────────── + + + + + + + + + + + + + + + + + +────────────────────── Test: ^T · Save: ^S · Cancel: <esc> + +Not connected                                            17:46:13 Installing driver... This may ask for your password. + + + + diff --git a/tests/integration/drivers/artifacts/mssql-06-driver-setup-installed.png b/tests/integration/drivers/artifacts/mssql-06-driver-setup-installed.png new file mode 100644 index 00000000..b911fae0 Binary files /dev/null and b/tests/integration/drivers/artifacts/mssql-06-driver-setup-installed.png differ diff --git a/tests/integration/drivers/artifacts/mssql-06-driver-setup-installed.svg b/tests/integration/drivers/artifacts/mssql-06-driver-setup-installed.svg new file mode 100644 index 00000000..e7a5f1fc --- /dev/null +++ b/tests/integration/drivers/artifacts/mssql-06-driver-setup-installed.svg @@ -0,0 +1,225 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + sqlit + + + + + + + + + + +[E] Explorer[q] Query + + New Connection ───────────────────────────────────────────▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔ + +GeneralAdvancedSSH +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + Driver ───────────────────────────────────────────── +ODBC Driver 18 for SQL Server +ODBC Driver 17 for SQL Server +ODBC Driver 13 for SQL Server +ODBC Driver 11 for SQL Server +SQL Server Native Client 11.0 +SQL Server +────────────────────────────────────────────────────── + Select ODBC Driver ───────────────────────────────────────────────────────── +▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ +Found 1 installed driver(s): +──────────────────── +────────────────────────────────────────────────────────────────────────── +ODBC Driver 18 for SQL Server +────────────────────────────────────────────────────────────────────────── + + +──────────────────────────────────────────── Select: <enter> · Cancel: <esc> + + + + + + + + + + + +──────────────────────────────────────────────────────────── + +Not connected                                            16:59:45 Installing driver... This may ask for your password. + + + + diff --git a/tests/integration/drivers/artifacts/mssql-07-driver-selected.png b/tests/integration/drivers/artifacts/mssql-07-driver-selected.png new file mode 100644 index 00000000..8791b1bf Binary files /dev/null and b/tests/integration/drivers/artifacts/mssql-07-driver-selected.png differ diff --git a/tests/integration/drivers/artifacts/mssql-07-driver-selected.svg b/tests/integration/drivers/artifacts/mssql-07-driver-selected.svg new file mode 100644 index 00000000..bd9466f1 --- /dev/null +++ b/tests/integration/drivers/artifacts/mssql-07-driver-selected.svg @@ -0,0 +1,227 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + sqlit + + + + + + + + + + +[E] Explorer[q] Query + + New Connection ───────────────────────────────────────────▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔ + +GeneralAdvancedSSH +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + Driver ───────────────────────────────────────────── +ODBC Driver 18 for SQL Server +ODBC Driver 17 for SQL Server +ODBC Driver 13 for SQL Server +ODBC Driver 11 for SQL Server +SQL Server Native Client 11.0 +SQL Server +────────────────────────────────────────────────────── +▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔ + ODBC driver setup… ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ +▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ +───────────────────────────── + + + + + + + + + + + + + + + + + +────────────────────── Test: ^T · Save: ^S · Cancel: <esc> + +Not connected                                            16:59:45 Installing driver... This may ask for your password. + + + + diff --git a/tests/integration/drivers/artifacts/terminal-found/mssql-01-advanced.png b/tests/integration/drivers/artifacts/terminal-found/mssql-01-advanced.png new file mode 100644 index 00000000..17cd2b07 Binary files /dev/null and b/tests/integration/drivers/artifacts/terminal-found/mssql-01-advanced.png differ diff --git a/tests/integration/drivers/artifacts/terminal-found/mssql-01-advanced.svg b/tests/integration/drivers/artifacts/terminal-found/mssql-01-advanced.svg new file mode 100644 index 00000000..bf90deed --- /dev/null +++ b/tests/integration/drivers/artifacts/terminal-found/mssql-01-advanced.svg @@ -0,0 +1,224 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + sqlit + + + + + + + + + + +[E] Explorer[q] Query + + New Connection ───────────────────────────────────────────▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔ + +GeneralAdvancedSSH +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + Driver ───────────────────────────────────────────── +ODBC Driver 18 for SQL Server +ODBC Driver 17 for SQL Server +ODBC Driver 13 for SQL Server +ODBC Driver 11 for SQL Server +SQL Server Native Client 11.0 +SQL Server +────────────────────────────────────────────────────── + +────────────────────▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ + ODBC driver setup…  +───────────────────────────────────────────────── + + + + + + + + + + + + + + + + + +────────────────────── Test: ^T · Save: ^S · Cancel: <esc> + +Not connected                                            17:46:13 Installing driver... This may ask for your password. + + + + diff --git a/tests/integration/drivers/artifacts/terminal-found/mssql-02-driver-setup-empty.png b/tests/integration/drivers/artifacts/terminal-found/mssql-02-driver-setup-empty.png new file mode 100644 index 00000000..1b36517e Binary files /dev/null and b/tests/integration/drivers/artifacts/terminal-found/mssql-02-driver-setup-empty.png differ diff --git a/tests/integration/drivers/artifacts/terminal-found/mssql-02-driver-setup-empty.svg b/tests/integration/drivers/artifacts/terminal-found/mssql-02-driver-setup-empty.svg new file mode 100644 index 00000000..744bc20c --- /dev/null +++ b/tests/integration/drivers/artifacts/terminal-found/mssql-02-driver-setup-empty.svg @@ -0,0 +1,231 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + sqlit + + + + + + + + + + +[E] Explorer[q] Query + + New Connection ───────────────────────────────────────────▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔ + +GeneralAdvancedSSH +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + No ODBC Driver Found ─────────────────────────────────────────────────────── + +Detected OS: debian 13 +You need an ODBC driver to connect to SQL Server. + +────────────────────────────────────────────────────────────────────────── +ODBC Driver 18 for SQL Server (not installed) +ODBC Driver 17 for SQL Server (not installed) +ODBC Driver 13 for SQL Server (not installed) +──────────────────────────────────────────────────────────────────────────▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ + +▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔──────────────────── + +Install on Debian 13:  + +curl -sSL -O  +https://packages.microsoft.com/config/debian/13/packages-microsoft-pr +od.deb  +sudo dpkg -i packages-microsoft-prod.deb  +rm packages-microsoft-prod.deb ▄▄ +sudo apt-get update  + +▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ + +───────────────────── Select: <enter> · Install: I · Yank: y · Cancel: <esc> + + + + +──────────────────────────────────────────────────────────── + +Not connected                                            17:46:13 Installing driver... This may ask for your password. + + + + diff --git a/tests/integration/drivers/artifacts/terminal-found/mssql-03-install-message-terminal-found.png b/tests/integration/drivers/artifacts/terminal-found/mssql-03-install-message-terminal-found.png new file mode 100644 index 00000000..50685f35 Binary files /dev/null and b/tests/integration/drivers/artifacts/terminal-found/mssql-03-install-message-terminal-found.png differ diff --git a/tests/integration/drivers/artifacts/terminal-found/mssql-03-install-message-terminal-found.svg b/tests/integration/drivers/artifacts/terminal-found/mssql-03-install-message-terminal-found.svg new file mode 100644 index 00000000..bf8fee0c --- /dev/null +++ b/tests/integration/drivers/artifacts/terminal-found/mssql-03-install-message-terminal-found.svg @@ -0,0 +1,224 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + sqlit + + + + + + + + + + +[E] Explorer[q] Query + + New Connection ───────────────────────────────────────────▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔ + +GeneralAdvancedSSH +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + Driver ───────────────────────────────────────────── +ODBC Driver 18 for SQL Server +ODBC Driver 17 for SQL Server +ODBC Driver 13 for SQL Server +ODBC Driver 11 for SQL Server +SQL Server Native Client 11.0 +SQL Server +────────────────────────────────────────────────────── + Driver install ─────────────────────────────────────────────────── +▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ + +Installation started in a new terminal.───────────────────────── + +Please restart to apply. + + +──────────────────────────────────────────────── Continue: <enter> + + + + + + + + + + + + +──────────────────────────────────────────────────────────── + +Not connected                                            17:46:15 Installing driver... This may ask for your password. + + + + diff --git a/tests/integration/drivers/artifacts/terminal-found/mssql-04-back-to-connection.png b/tests/integration/drivers/artifacts/terminal-found/mssql-04-back-to-connection.png new file mode 100644 index 00000000..19a8f095 Binary files /dev/null and b/tests/integration/drivers/artifacts/terminal-found/mssql-04-back-to-connection.png differ diff --git a/tests/integration/drivers/artifacts/terminal-found/mssql-04-back-to-connection.svg b/tests/integration/drivers/artifacts/terminal-found/mssql-04-back-to-connection.svg new file mode 100644 index 00000000..f1443599 --- /dev/null +++ b/tests/integration/drivers/artifacts/terminal-found/mssql-04-back-to-connection.svg @@ -0,0 +1,224 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + sqlit + + + + + + + + + + +[E] Explorer[q] Query + + New Connection ───────────────────────────────────────────▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔ + +GeneralAdvancedSSH +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + Driver ───────────────────────────────────────────── +ODBC Driver 18 for SQL Server +ODBC Driver 17 for SQL Server +ODBC Driver 13 for SQL Server +ODBC Driver 11 for SQL Server +SQL Server Native Client 11.0 +SQL Server +────────────────────────────────────────────────────── + +────────────────────▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ + ODBC driver setup…  +───────────────────────────────────────────────── + + + + + + + + + + + + + + + + + +────────────────────── Test: ^T · Save: ^S · Cancel: <esc> + +Not connected                                            17:46:15 Installing driver... This may ask for your password. + + + + diff --git a/tests/integration/drivers/artifacts/terminal-found/mssql-06-driver-setup-installed.png b/tests/integration/drivers/artifacts/terminal-found/mssql-06-driver-setup-installed.png new file mode 100644 index 00000000..8d692720 Binary files /dev/null and b/tests/integration/drivers/artifacts/terminal-found/mssql-06-driver-setup-installed.png differ diff --git a/tests/integration/drivers/artifacts/terminal-found/mssql-06-driver-setup-installed.svg b/tests/integration/drivers/artifacts/terminal-found/mssql-06-driver-setup-installed.svg new file mode 100644 index 00000000..4d3e8782 --- /dev/null +++ b/tests/integration/drivers/artifacts/terminal-found/mssql-06-driver-setup-installed.svg @@ -0,0 +1,225 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + sqlit + + + + + + + + + + +[E] Explorer[q] Query + + New Connection ───────────────────────────────────────────▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔ + +GeneralAdvancedSSH +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + Driver ───────────────────────────────────────────── +ODBC Driver 18 for SQL Server +ODBC Driver 17 for SQL Server +ODBC Driver 13 for SQL Server +ODBC Driver 11 for SQL Server +SQL Server Native Client 11.0 +SQL Server +────────────────────────────────────────────────────── + Select ODBC Driver ───────────────────────────────────────────────────────── +▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ +Found 1 installed driver(s): +──────────────────── +────────────────────────────────────────────────────────────────────────── +ODBC Driver 18 for SQL Server +────────────────────────────────────────────────────────────────────────── + + +──────────────────────────────────────────── Select: <enter> · Cancel: <esc> + + + + + + + + + + + +──────────────────────────────────────────────────────────── + +Not connected                                            17:09:30 Installing driver... This may ask for your password. + + + + diff --git a/tests/integration/drivers/artifacts/terminal-found/mssql-07-driver-selected.png b/tests/integration/drivers/artifacts/terminal-found/mssql-07-driver-selected.png new file mode 100644 index 00000000..7225537a Binary files /dev/null and b/tests/integration/drivers/artifacts/terminal-found/mssql-07-driver-selected.png differ diff --git a/tests/integration/drivers/artifacts/terminal-found/mssql-07-driver-selected.svg b/tests/integration/drivers/artifacts/terminal-found/mssql-07-driver-selected.svg new file mode 100644 index 00000000..98e7cbf6 --- /dev/null +++ b/tests/integration/drivers/artifacts/terminal-found/mssql-07-driver-selected.svg @@ -0,0 +1,227 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + sqlit + + + + + + + + + + +[E] Explorer[q] Query + + New Connection ───────────────────────────────────────────▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔ + +GeneralAdvancedSSH +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + Driver ───────────────────────────────────────────── +ODBC Driver 18 for SQL Server +ODBC Driver 17 for SQL Server +ODBC Driver 13 for SQL Server +ODBC Driver 11 for SQL Server +SQL Server Native Client 11.0 +SQL Server +────────────────────────────────────────────────────── +▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔ + ODBC driver setup… ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ +▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ +───────────────────────────── + + + + + + + + + + + + + + + + + +────────────────────── Test: ^T · Save: ^S · Cancel: <esc> + +Not connected                                            17:09:30 Installing driver... This may ask for your password. + + + + diff --git a/tests/integration/drivers/artifacts/terminal-found/mssql-07-driver-setup-installed.png b/tests/integration/drivers/artifacts/terminal-found/mssql-07-driver-setup-installed.png new file mode 100644 index 00000000..53022aec Binary files /dev/null and b/tests/integration/drivers/artifacts/terminal-found/mssql-07-driver-setup-installed.png differ diff --git a/tests/integration/drivers/artifacts/terminal-found/mssql-07-driver-setup-installed.svg b/tests/integration/drivers/artifacts/terminal-found/mssql-07-driver-setup-installed.svg new file mode 100644 index 00000000..b16e269b --- /dev/null +++ b/tests/integration/drivers/artifacts/terminal-found/mssql-07-driver-setup-installed.svg @@ -0,0 +1,225 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + sqlit + + + + + + + + + + +[E] Explorer[q] Query + + New Connection ───────────────────────────────────────────▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔ + +GeneralAdvancedSSH +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + Driver ───────────────────────────────────────────── +ODBC Driver 18 for SQL Server +ODBC Driver 17 for SQL Server +ODBC Driver 13 for SQL Server +ODBC Driver 11 for SQL Server +SQL Server Native Client 11.0 +SQL Server +────────────────────────────────────────────────────── + Select ODBC Driver ───────────────────────────────────────────────────────── +▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ +Found 1 installed driver(s): +──────────────────── +────────────────────────────────────────────────────────────────────────── +ODBC Driver 18 for SQL Server +────────────────────────────────────────────────────────────────────────── + + +────────────────────────────────── Select: <enter> · Yank: y · Cancel: <esc> + + + + + + + + + + + +──────────────────────────────────────────────────────────── + +Not connected                                            17:46:15 Installing driver... This may ask for your password. + + + + diff --git a/tests/integration/drivers/artifacts/terminal-found/mssql-08-driver-selected.png b/tests/integration/drivers/artifacts/terminal-found/mssql-08-driver-selected.png new file mode 100644 index 00000000..19a8f095 Binary files /dev/null and b/tests/integration/drivers/artifacts/terminal-found/mssql-08-driver-selected.png differ diff --git a/tests/integration/drivers/artifacts/terminal-found/mssql-08-driver-selected.svg b/tests/integration/drivers/artifacts/terminal-found/mssql-08-driver-selected.svg new file mode 100644 index 00000000..f1443599 --- /dev/null +++ b/tests/integration/drivers/artifacts/terminal-found/mssql-08-driver-selected.svg @@ -0,0 +1,224 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + sqlit + + + + + + + + + + +[E] Explorer[q] Query + + New Connection ───────────────────────────────────────────▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔ + +GeneralAdvancedSSH +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + Driver ───────────────────────────────────────────── +ODBC Driver 18 for SQL Server +ODBC Driver 17 for SQL Server +ODBC Driver 13 for SQL Server +ODBC Driver 11 for SQL Server +SQL Server Native Client 11.0 +SQL Server +────────────────────────────────────────────────────── + +────────────────────▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ + ODBC driver setup…  +───────────────────────────────────────────────── + + + + + + + + + + + + + + + + + +────────────────────── Test: ^T · Save: ^S · Cancel: <esc> + +Not connected                                            17:46:15 Installing driver... This may ask for your password. + + + + diff --git a/tests/integration/drivers/docker-compose.ui.yml b/tests/integration/drivers/docker-compose.ui.yml new file mode 100644 index 00000000..a5e3002c --- /dev/null +++ b/tests/integration/drivers/docker-compose.ui.yml @@ -0,0 +1,10 @@ +services: + ui-test: + build: + context: ../../.. + dockerfile: tests/integration/drivers/dockerfiles/ui.Dockerfile + container_name: sqlit-test-odbc-driver-ui + environment: + SQLIT_TEST_SCREENSHOTS_DIR: /artifacts + volumes: + - ./artifacts:/artifacts diff --git a/tests/integration/drivers/dockerfiles/alpine.Dockerfile b/tests/integration/drivers/dockerfiles/alpine.Dockerfile new file mode 100644 index 00000000..fb4a8f78 --- /dev/null +++ b/tests/integration/drivers/dockerfiles/alpine.Dockerfile @@ -0,0 +1,33 @@ +FROM alpine:3.20 + +# Install Python and basic tools +RUN apk add --no-cache \ + python3 \ + py3-pip \ + curl \ + bash \ + unixodbc \ + unixodbc-dev \ + g++ \ + python3-dev + +# Create virtual environment +RUN python3 -m venv /opt/venv +ENV PATH="/opt/venv/bin:$PATH" + +# Install Python dependencies +RUN pip install --upgrade pip && \ + pip install pyodbc + +# Copy the sqlit package +WORKDIR /app +COPY sqlit/ /app/sqlit/ +COPY pyproject.toml README.md /app/ + +# Install sqlit in development mode +RUN pip install -e . + +# Copy test script +COPY tests/integration/drivers/test_driver_install.py /app/test_driver_install.py + +CMD ["python3", "/app/test_driver_install.py"] diff --git a/tests/integration/drivers/dockerfiles/arch.Dockerfile b/tests/integration/drivers/dockerfiles/arch.Dockerfile new file mode 100644 index 00000000..b3c94269 --- /dev/null +++ b/tests/integration/drivers/dockerfiles/arch.Dockerfile @@ -0,0 +1,32 @@ +FROM archlinux:latest + +RUN pacman -Syu --noconfirm && pacman -S --noconfirm \ + python \ + python-pip \ + python-virtualenv \ + curl \ + base-devel \ + git \ + unixodbc + +RUN useradd -m builder && echo "builder ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers + +USER builder +WORKDIR /home/builder +RUN git clone https://aur.archlinux.org/yay.git && cd yay && makepkg -si --noconfirm && cd .. && rm -rf yay + +USER root +RUN python -m venv /opt/venv +ENV PATH="/opt/venv/bin:$PATH" + +RUN pip install --upgrade pip && pip install pyodbc + +WORKDIR /app +COPY sqlit/ /app/sqlit/ +COPY pyproject.toml README.md /app/ +RUN pip install -e . + +COPY tests/integration/drivers/test_driver_install.py /app/test_driver_install.py + +USER builder +CMD ["sudo", "-E", "/opt/venv/bin/python", "/app/test_driver_install.py"] diff --git a/tests/integration/drivers/dockerfiles/debian.Dockerfile b/tests/integration/drivers/dockerfiles/debian.Dockerfile new file mode 100644 index 00000000..f4773cf7 --- /dev/null +++ b/tests/integration/drivers/dockerfiles/debian.Dockerfile @@ -0,0 +1,27 @@ +FROM debian:12 + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && apt-get install -y \ + python3 \ + python3-pip \ + python3-venv \ + curl \ + gnupg \ + unixodbc-dev \ + && rm -rf /var/lib/apt/lists/* + +RUN python3 -m venv /opt/venv +ENV PATH="/opt/venv/bin:$PATH" + +RUN pip install --upgrade pip && pip install pyodbc + +WORKDIR /app +COPY sqlit/ /app/sqlit/ +COPY pyproject.toml README.md /app/ + +RUN pip install -e . + +COPY tests/integration/drivers/test_driver_install.py /app/test_driver_install.py + +CMD ["python3", "/app/test_driver_install.py"] diff --git a/tests/integration/drivers/dockerfiles/fedora.Dockerfile b/tests/integration/drivers/dockerfiles/fedora.Dockerfile new file mode 100644 index 00000000..a395cebd --- /dev/null +++ b/tests/integration/drivers/dockerfiles/fedora.Dockerfile @@ -0,0 +1,25 @@ +FROM fedora:40 + +RUN dnf install -y \ + python3 \ + python3-pip \ + curl \ + unixODBC-devel \ + gcc \ + python3-devel \ + && dnf clean all + +RUN python3 -m venv /opt/venv +ENV PATH="/opt/venv/bin:$PATH" + +RUN pip install --upgrade pip && pip install pyodbc + +WORKDIR /app +COPY sqlit/ /app/sqlit/ +COPY pyproject.toml README.md /app/ + +RUN pip install -e . + +COPY tests/integration/drivers/test_driver_install.py /app/test_driver_install.py + +CMD ["python3", "/app/test_driver_install.py"] diff --git a/tests/integration/drivers/dockerfiles/opensuse.Dockerfile b/tests/integration/drivers/dockerfiles/opensuse.Dockerfile new file mode 100644 index 00000000..26b1cf5a --- /dev/null +++ b/tests/integration/drivers/dockerfiles/opensuse.Dockerfile @@ -0,0 +1,25 @@ +FROM opensuse/leap:15 + +RUN zypper refresh && zypper install -y \ + python311 \ + python311-pip \ + python311-devel \ + curl \ + unixODBC-devel \ + gcc \ + && zypper clean -a + +RUN python3.11 -m venv /opt/venv +ENV PATH="/opt/venv/bin:$PATH" + +RUN pip install --upgrade pip && pip install pyodbc + +WORKDIR /app +COPY sqlit/ /app/sqlit/ +COPY pyproject.toml README.md /app/ + +RUN pip install -e . + +COPY tests/integration/drivers/test_driver_install.py /app/test_driver_install.py + +CMD ["python3", "/app/test_driver_install.py"] diff --git a/tests/integration/drivers/dockerfiles/rocky.Dockerfile b/tests/integration/drivers/dockerfiles/rocky.Dockerfile new file mode 100644 index 00000000..cd1f3de1 --- /dev/null +++ b/tests/integration/drivers/dockerfiles/rocky.Dockerfile @@ -0,0 +1,26 @@ +FROM rockylinux:9 + +RUN dnf install -y epel-release && \ + dnf config-manager --set-enabled crb && \ + dnf install -y \ + python3.11 \ + python3.11-pip \ + python3.11-devel \ + unixODBC-devel \ + gcc \ + && dnf clean all + +RUN python3.11 -m venv /opt/venv +ENV PATH="/opt/venv/bin:$PATH" + +RUN pip install --upgrade pip && pip install pyodbc + +WORKDIR /app +COPY sqlit/ /app/sqlit/ +COPY pyproject.toml README.md /app/ + +RUN pip install -e . + +COPY tests/integration/drivers/test_driver_install.py /app/test_driver_install.py + +CMD ["python3", "/app/test_driver_install.py"] diff --git a/tests/integration/drivers/dockerfiles/ubuntu.Dockerfile b/tests/integration/drivers/dockerfiles/ubuntu.Dockerfile new file mode 100644 index 00000000..444c952e --- /dev/null +++ b/tests/integration/drivers/dockerfiles/ubuntu.Dockerfile @@ -0,0 +1,27 @@ +FROM ubuntu:22.04 + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && apt-get install -y \ + python3 \ + python3-pip \ + python3-venv \ + curl \ + gnupg \ + unixodbc-dev \ + && rm -rf /var/lib/apt/lists/* + +RUN python3 -m venv /opt/venv +ENV PATH="/opt/venv/bin:$PATH" + +RUN pip install --upgrade pip && pip install pyodbc + +WORKDIR /app +COPY sqlit/ /app/sqlit/ +COPY pyproject.toml README.md /app/ + +RUN pip install -e . + +COPY tests/integration/drivers/test_driver_install.py /app/test_driver_install.py + +CMD ["python3", "/app/test_driver_install.py"] diff --git a/tests/integration/drivers/dockerfiles/ui.Dockerfile b/tests/integration/drivers/dockerfiles/ui.Dockerfile new file mode 100644 index 00000000..20e03121 --- /dev/null +++ b/tests/integration/drivers/dockerfiles/ui.Dockerfile @@ -0,0 +1,14 @@ +FROM python:3.12-slim + +WORKDIR /app + +COPY . /app + +RUN apt-get update \ + && apt-get install -y --no-install-recommends librsvg2-bin \ + && rm -rf /var/lib/apt/lists/* + +RUN python -m pip install --upgrade pip \ + && pip install -e ".[test]" + +CMD ["python", "tests/integration/drivers/test_driver_setup_ui_flow.py"] diff --git a/tests/integration/drivers/run_tests.sh b/tests/integration/drivers/run_tests.sh new file mode 100755 index 00000000..c48e8d6b --- /dev/null +++ b/tests/integration/drivers/run_tests.sh @@ -0,0 +1,49 @@ +#!/bin/bash +set -e + +cd "$(dirname "$0")" + +start_mssql() { + echo "Starting SQL Server..." + docker-compose up -d mssql + + echo "Waiting for SQL Server to be healthy..." + timeout=120 + while [ $timeout -gt 0 ]; do + if docker-compose ps mssql | grep -q "healthy"; then + return 0 + fi + echo " Waiting... ($timeout seconds remaining)" + sleep 5 + timeout=$((timeout - 5)) + done + + echo "ERROR: SQL Server failed to start" + docker-compose logs mssql + exit 1 +} + +case "${1:-all}" in + --clean) + docker-compose down -v --rmi local 2>/dev/null || true + ;; + all) + start_mssql + docker-compose up --build \ + test-ubuntu \ + test-debian \ + test-rocky \ + test-fedora \ + test-alpine \ + test-opensuse \ + test-arch + ;; + ubuntu|debian|rocky|fedora|alpine|opensuse|arch) + start_mssql + docker-compose up --build "test-$1" + ;; + *) + echo "Usage: $0 [all|ubuntu|debian|rocky|fedora|alpine|opensuse|arch|--clean]" + exit 1 + ;; +esac diff --git a/tests/integration/drivers/run_ui_tests.sh b/tests/integration/drivers/run_ui_tests.sh new file mode 100755 index 00000000..84189781 --- /dev/null +++ b/tests/integration/drivers/run_ui_tests.sh @@ -0,0 +1,6 @@ +#!/bin/bash +set -euo pipefail + +cd "$(dirname "$0")" + +docker compose -f docker-compose.ui.yml up --build --abort-on-container-exit --exit-code-from ui-test ui-test diff --git a/tests/integration/drivers/test_driver_install.py b/tests/integration/drivers/test_driver_install.py new file mode 100644 index 00000000..ee3ce409 --- /dev/null +++ b/tests/integration/drivers/test_driver_install.py @@ -0,0 +1,200 @@ +#!/usr/bin/env python3 +"""Integration test for ODBC driver installation.""" + +from __future__ import annotations + +import importlib +import os +import subprocess +import sys +import time + + +def log(message: str, level: str = "INFO") -> None: + distro = os.environ.get("DISTRO_NAME", "unknown") + print(f"[{distro}] [{level}] {message}", flush=True) + + +def run_command(command: str) -> tuple[int, str, str]: + log(f"Running: {command}") + result = subprocess.run(command, shell=True, capture_output=True, text=True) + if result.stdout: + for line in result.stdout.strip().split("\n"): + log(f" stdout: {line}") + if result.stderr: + for line in result.stderr.strip().split("\n"): + log(f" stderr: {line}") + return result.returncode, result.stdout, result.stderr + + +def test_no_driver_initially() -> bool: + log("Step 1: Checking that no ODBC driver is installed initially...") + + from sqlit.drivers import get_installed_drivers + + drivers = get_installed_drivers() + if drivers: + log(f"Unexpected: Found drivers already installed: {drivers}", "WARN") + return True + + log("Confirmed: No SQL Server ODBC drivers found initially") + return True + + +def test_get_install_commands() -> list[str] | None: + log("Step 2: Getting installation commands for this OS...") + + from sqlit.drivers import get_install_commands, get_os_info + + os_type, os_version = get_os_info() + log(f"Detected OS: {os_type} {os_version}") + + install_cmd = get_install_commands() + if not install_cmd: + log(f"No installation commands available for {os_type}", "ERROR") + return None + + log(f"Installation method: {install_cmd.description}") + for warning in install_cmd.warnings: + log(warning, "WARN") + + log(f"Commands to execute ({len(install_cmd.commands)}):") + for i, cmd in enumerate(install_cmd.commands, 1): + log(f" {i}. {cmd}") + + return install_cmd.commands + + +def test_execute_install_commands(commands: list[str]) -> bool: + log("Step 3: Executing installation commands...") + + for i, command in enumerate(commands, 1): + # Strip sudo since we're running as root in the container + if command.startswith("sudo "): + command = command[5:] + command = command.replace(" sudo ", " ") + + # Make AUR helpers non-interactive for testing + # Also run as non-root user since yay doesn't like running as root + if command.startswith("yay -S "): + command = command.replace("yay -S ", "yay -S --noconfirm ") + command = f"su - builder -c '{command}'" + + log(f"Executing command {i}/{len(commands)}") + exit_code, _, _ = run_command(command) + + if exit_code != 0: + if "|| true" in command or "2>/dev/null" in command: + log(f"Command {i} failed but was optional, continuing...") + continue + log(f"Command {i} failed with exit code {exit_code}", "ERROR") + return False + + log("All installation commands completed successfully") + return True + + +def test_driver_installed() -> str | None: + log("Step 4: Verifying driver is now installed...") + + import sqlit.drivers + + importlib.reload(sqlit.drivers) + + from sqlit.drivers import get_best_driver, get_installed_drivers + + drivers = get_installed_drivers() + if not drivers: + log("No ODBC drivers found after installation", "ERROR") + return None + + log(f"Found installed drivers: {drivers}") + best = get_best_driver() + log(f"Best available driver: {best}") + return best + + +def test_connection(driver: str) -> bool: + log("Step 5: Testing connection to SQL Server...") + + host = os.environ.get("MSSQL_HOST", "localhost") + port = os.environ.get("MSSQL_PORT", "1433") + user = os.environ.get("MSSQL_USER", "sa") + password = os.environ.get("MSSQL_PASSWORD") + + if not password: + log("MSSQL_PASSWORD environment variable not set", "ERROR") + return False + + log(f"Connecting to {host}:{port} as {user}...") + + try: + import pyodbc + + conn_str = ( + f"DRIVER={{{driver}}};" + f"SERVER={host},{port};" + f"UID={user};" + f"PWD={password};" + f"TrustServerCertificate=yes;" + ) + + max_retries = 5 + for attempt in range(1, max_retries + 1): + try: + log(f"Connection attempt {attempt}/{max_retries}...") + conn = pyodbc.connect(conn_str, timeout=10) + cursor = conn.cursor() + cursor.execute("SELECT @@VERSION") + version = cursor.fetchone()[0] + log("Connected successfully!") + log(f"SQL Server version: {version[:80]}...") + cursor.close() + conn.close() + return True + except pyodbc.Error as e: + log(f"Connection attempt {attempt} failed: {e}", "WARN") + if attempt < max_retries: + time.sleep(5) + + log("All connection attempts failed", "ERROR") + return False + + except Exception as e: + log(f"Connection test failed with exception: {e}", "ERROR") + return False + + +def main() -> int: + log("=" * 60) + log("ODBC Driver Installation Integration Test") + log("=" * 60) + + distro = os.environ.get("DISTRO_NAME", "unknown") + log(f"Testing on: {distro}") + + if not test_no_driver_initially(): + return 1 + + commands = test_get_install_commands() + if commands is None: + return 1 + + if not test_execute_install_commands(commands): + return 1 + + driver = test_driver_installed() + if driver is None: + return 1 + + if not test_connection(driver): + return 1 + + log("=" * 60) + log("ALL TESTS PASSED", "SUCCESS") + log("=" * 60) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/integration/drivers/test_driver_setup_ui_flow.py b/tests/integration/drivers/test_driver_setup_ui_flow.py new file mode 100644 index 00000000..cf6a41a1 --- /dev/null +++ b/tests/integration/drivers/test_driver_setup_ui_flow.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +import os +import subprocess +import tempfile +import time +from pathlib import Path + + +def _clean_screenshots_dir(outdir: Path) -> None: + resolved = outdir.resolve() + if resolved == Path("/"): + raise AssertionError("Refusing to clean screenshots in '/'") + if not outdir.exists(): + return + for path in outdir.rglob("*"): + if path.is_file() and path.suffix.lower() in (".svg", ".png"): + path.unlink(missing_ok=True) + + +def _maybe_screenshot(app, name: str) -> None: + outdir = os.environ.get("SQLIT_TEST_SCREENSHOTS_DIR") + if not outdir: + return + Path(outdir).mkdir(parents=True, exist_ok=True) + safe = "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in name) + app.save_screenshot(path=outdir, filename=f"{safe}.svg") + + +async def _wait_for(pilot, predicate, timeout_s: float, label: str) -> None: + start = time.monotonic() + while time.monotonic() - start < timeout_s: + if predicate(): + return + await pilot.pause(0.1) + app = getattr(pilot, "app", None) + screen_name = getattr(getattr(app, "screen", None), "__class__", type("x", (), {})).__name__ if app else "unknown" + raise AssertionError(f"Timed out waiting for: {label} (current screen: {screen_name})") + + +async def main() -> None: + os.environ.setdefault("SQLIT_CONFIG_DIR", tempfile.mkdtemp(prefix="sqlit-test-config-")) + outdir = os.environ.get("SQLIT_TEST_SCREENSHOTS_DIR") + if outdir: + _clean_screenshots_dir(Path(outdir)) + + import sqlit.terminal as terminal_module + import sqlit.ui.screens.connection as connection_screen_module + from sqlit.app import SSMSTUI + from sqlit.config import ConnectionConfig + from sqlit.ui.screens.connection import ConnectionScreen + + class _DummyAdapter: + def ensure_driver_available(self) -> None: + return + + config = ConnectionConfig( + name="mssql-driver-ui-flow", + db_type="mssql", + server="mssql", + port="1433", + database="master", + username="sa", + password="TestPassword123!", + ) + + app = SSMSTUI() + async with app.run_test(size=(120, 40)) as pilot: + # Avoid requiring pyodbc in this UI-only test. + connection_screen_module.get_adapter = lambda _db_type: _DummyAdapter() # type: ignore[assignment] + + app.push_screen(ConnectionScreen(config)) + await pilot.pause(0.2) + _maybe_screenshot(app, "mssql-01-connection") + + screen = app.screen + tabs = screen.query_one("#connection-tabs") + tabs.active = "tab-advanced" + await pilot.pause(0.2) + _maybe_screenshot(app, "mssql-02-advanced") + + # Open driver setup (no drivers installed). + screen._open_odbc_driver_setup(installed_drivers=[]) + await _wait_for(pilot, lambda: app.screen.__class__.__name__ == "DriverSetupScreen", 5, "DriverSetupScreen") + await pilot.pause(0.2) + _maybe_screenshot(app, "mssql-03-driver-setup-empty") + + # Trigger "Install" to show post-action message (run_in_terminal will likely fail in headless env). + await pilot.press("i") + await _wait_for(pilot, lambda: app.screen.__class__.__name__ == "MessageScreen", 5, "MessageScreen") + await pilot.pause(0.2) + _maybe_screenshot(app, "mssql-04-install-message") + await pilot.press("enter") + + # After acknowledging, we return to the original setup screen (manual instructions remain there). + await _wait_for(pilot, lambda: app.screen.__class__.__name__ == "DriverSetupScreen", 5, "DriverSetupScreen") + await pilot.pause(0.2) + _maybe_screenshot(app, "mssql-05-back-to-setup") + + await pilot.press("escape") + await _wait_for(pilot, lambda: app.screen.__class__.__name__ == "ConnectionScreen", 5, "ConnectionScreen") + await pilot.pause(0.2) + _maybe_screenshot(app, "mssql-06-back-to-connection") + + # Simulate terminal found (successful run_in_terminal) and capture the success-path UI. + os.environ["SQLIT_TEST_SCREENSHOTS_DIR"] = str( + Path(os.environ["SQLIT_TEST_SCREENSHOTS_DIR"]) / "terminal-found" + ) + + terminal_module.run_in_terminal = ( # type: ignore[assignment] + lambda _commands, wait_message="Press Enter to close...": terminal_module.TerminalResult( + success=True, terminal=terminal_module.TerminalType.XTERM, error=None + ) + ) + + screen = app.screen + tabs = screen.query_one("#connection-tabs") + tabs.active = "tab-advanced" + await pilot.pause(0.2) + _maybe_screenshot(app, "mssql-01-advanced") + + screen._open_odbc_driver_setup(installed_drivers=[]) + await _wait_for(pilot, lambda: app.screen.__class__.__name__ == "DriverSetupScreen", 5, "DriverSetupScreen") + await pilot.pause(0.2) + _maybe_screenshot(app, "mssql-02-driver-setup-empty") + + await pilot.press("i") + await _wait_for(pilot, lambda: app.screen.__class__.__name__ == "MessageScreen", 5, "MessageScreen") + await pilot.pause(0.2) + _maybe_screenshot(app, "mssql-03-install-message-terminal-found") + await pilot.press("enter") + + await _wait_for(pilot, lambda: app.screen.__class__.__name__ == "ConnectionScreen", 5, "ConnectionScreen") + await pilot.pause(0.2) + _maybe_screenshot(app, "mssql-04-back-to-connection") + + # Open driver setup (drivers present) and select one. + screen = app.screen + tabs = screen.query_one("#connection-tabs") + tabs.active = "tab-advanced" + await pilot.pause(0.2) + + screen._open_odbc_driver_setup(installed_drivers=["ODBC Driver 18 for SQL Server"]) + await _wait_for(pilot, lambda: app.screen.__class__.__name__ == "DriverSetupScreen", 5, "DriverSetupScreen") + await pilot.pause(0.2) + _maybe_screenshot(app, "mssql-07-driver-setup-installed") + + await pilot.press("enter") + await _wait_for(pilot, lambda: app.screen.__class__.__name__ == "ConnectionScreen", 5, "ConnectionScreen") + + # Verify the driver field reflects the selection. + screen = app.screen + tabs = screen.query_one("#connection-tabs") + tabs.active = "tab-advanced" + await pilot.pause(0.2) + _maybe_screenshot(app, "mssql-08-driver-selected") + + # Convert SVGs to PNGs inside the container (avoids host permission issues). + outdir = os.environ.get("SQLIT_TEST_SCREENSHOTS_DIR") + if outdir: + subprocess.run( + [ + "bash", + "-lc", + f"command -v rsvg-convert >/dev/null 2>&1 && " + f"find {outdir!s} -name '*.svg' -print0 | " + 'xargs -0 -I{} bash -lc \'rsvg-convert "$1" -o "${1%.svg}.png"\' _ {}', + ], + check=False, + ) + + +if __name__ == "__main__": + import asyncio + + asyncio.run(main()) diff --git a/tests/integration/python_packages/README.md b/tests/integration/python_packages/README.md new file mode 100644 index 00000000..b890966f --- /dev/null +++ b/tests/integration/python_packages/README.md @@ -0,0 +1,26 @@ +# Python Driver Install Flow Tests + +Validates the end-user flow for missing Python DB drivers: + +1. User attempts to save a new connection +2. App prompts to install the missing driver +3. App shows a loading screen while installing +4. App shows success message, or failure with manual install instructions + +## Run + +```bash +./run_tests.sh +``` + +## Screenshots + +This integration test can export Textual screenshots (SVG) from inside the container to the host: + +- Output directory: `tests/integration/python_packages/artifacts/` +- Enable via `SQLIT_TEST_SCREENSHOTS_DIR=/artifacts` (already set in `tests/integration/python_packages/docker-compose.yml`) + +## Requirements + +- Docker + Docker Compose +- Network access (to download driver wheels) diff --git a/tests/integration/python_packages/dockerfiles/debian.Dockerfile b/tests/integration/python_packages/dockerfiles/debian.Dockerfile new file mode 100644 index 00000000..9c59a830 --- /dev/null +++ b/tests/integration/python_packages/dockerfiles/debian.Dockerfile @@ -0,0 +1,13 @@ +FROM python:3.12-slim + +WORKDIR /app + +COPY . /app + +RUN python -m pip install --upgrade pip \ + && pip install -e ".[test]" + +ENV SQLIT_CONFIG_DIR=/tmp/sqlit-config +ENV PYTHONUNBUFFERED=1 + +CMD ["python", "-u", "/app/tests/integration/python_packages/test_package_install_flow.py"] diff --git a/tests/integration/python_packages/run_tests.sh b/tests/integration/python_packages/run_tests.sh new file mode 100755 index 00000000..0b52234c --- /dev/null +++ b/tests/integration/python_packages/run_tests.sh @@ -0,0 +1,6 @@ +#!/bin/bash +set -euo pipefail + +cd "$(dirname "$0")" + +docker compose up --build --abort-on-container-exit --exit-code-from test-debian test-debian diff --git a/tests/integration/python_packages/test_package_install_flow.py b/tests/integration/python_packages/test_package_install_flow.py new file mode 100644 index 00000000..20735bf6 --- /dev/null +++ b/tests/integration/python_packages/test_package_install_flow.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +import asyncio +import importlib +import os +import tempfile +import time +from pathlib import Path + + +def _repo_root() -> Path: + return Path(__file__).resolve().parents[3] + + +def _clean_screenshots_dir(outdir: Path) -> None: + resolved = outdir.resolve() + if resolved == Path("/"): + raise AssertionError("Refusing to clean screenshots in '/'") + if not outdir.exists(): + return + for path in outdir.rglob("*"): + if path.is_file() and path.suffix.lower() in (".svg", ".png"): + path.unlink(missing_ok=True) + + +def _maybe_screenshot(app, name: str) -> None: + outdir = os.environ.get("SQLIT_TEST_SCREENSHOTS_DIR") + if not outdir: + return + Path(outdir).mkdir(parents=True, exist_ok=True) + safe = "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in name) + app.save_screenshot(path=outdir, filename=f"{safe}.svg") + + +def _assert_missing(module_name: str) -> None: + try: + importlib.import_module(module_name) + except Exception: + return + raise AssertionError(f"Expected {module_name} to be missing, but it imported successfully") + + +def _assert_present(module_name: str) -> None: + try: + importlib.import_module(module_name) + except Exception as e: + raise AssertionError(f"Expected {module_name} to be importable, but got: {e}") from e + + +async def _wait_for(pilot, predicate, timeout_s: float, label: str) -> None: + start = time.monotonic() + while time.monotonic() - start < timeout_s: + if predicate(): + return + await pilot.pause(0.1) + app = getattr(pilot, "app", None) + screen_name = getattr(getattr(app, "screen", None), "__class__", type("x", (), {})).__name__ if app else "unknown" + raise AssertionError(f"Timed out waiting for: {label} (current screen: {screen_name})") + + +async def _run_flow(*, force_fail: bool, db_type: str) -> None: + os.environ.setdefault("SQLIT_CONFIG_DIR", tempfile.mkdtemp(prefix="sqlit-test-config-")) + os.environ["SQLIT_INSTALL_PROJECT_ROOT"] = str(_repo_root()) + os.environ["SQLIT_DISABLE_RESTART"] = "1" + + if force_fail: + os.environ["SQLIT_INSTALL_FORCE_FAIL"] = "1" + else: + os.environ.pop("SQLIT_INSTALL_FORCE_FAIL", None) + + from sqlit.app import SSMSTUI + from sqlit.config import ConnectionConfig + from sqlit.ui.screens.connection import ConnectionScreen + + if db_type == "postgresql": + config = ConnectionConfig( + name="pg-install-flow", + db_type="postgresql", + server="localhost", + port="5432", + database="postgres", + username="test", + password="test", + ) + expected_manual = 'pip install "sqlit-tui[postgres]"' + elif db_type == "mysql": + config = ConnectionConfig( + name="mysql-install-flow", + db_type="mysql", + server="localhost", + port="3306", + database="test_sqlit", + username="test", + password="test", + ) + expected_manual = 'pip install "sqlit-tui[mysql]"' + else: + raise AssertionError(f"Unsupported db_type for test: {db_type}") + + app = SSMSTUI() + async with app.run_test(size=(120, 40)) as pilot: + app.push_screen(ConnectionScreen(config)) + await pilot.pause(0.2) + _maybe_screenshot(app, f"{db_type}-01-connection") + + # Attempt to save should show a confirmation dialog + app.screen.action_save() + await _wait_for( + pilot, + lambda: app.screen.__class__.__name__ == "ConfirmScreen", + timeout_s=5, + label="ConfirmScreen", + ) + # Give Textual a render tick so screenshots capture the modal contents + await pilot.pause(0.2) + _maybe_screenshot(app, f"{db_type}-02-confirm") + await pilot.press("y") + + # Return to the ConnectionScreen and show an in-dialog loading indicator. + await _wait_for( + pilot, + lambda: app.screen.__class__.__name__ == "ConnectionScreen", + timeout_s=5, + label="ConnectionScreen (after confirm)", + ) + + await _wait_for( + pilot, + lambda: app.screen.__class__.__name__ in ("MessageScreen", "ConnectionScreen"), + timeout_s=180, + label="MessageScreen or ConnectionScreen", + ) + + # If failure, we get a MessageScreen and then return to the connection screen. + if force_fail: + assert app.screen.__class__.__name__ == "MessageScreen" + await pilot.pause(0.2) + _maybe_screenshot(app, f"{db_type}-04-result") + await pilot.press("enter") + await _wait_for( + pilot, + lambda: app.screen.__class__.__name__ == "ConnectionScreen", + timeout_s=5, + label="ConnectionScreen (after failure)", + ) + await pilot.pause(0.2) + _maybe_screenshot(app, f"{db_type}-05-back-to-setup") + + from textual.widgets import Static + + text = str(app.screen.query_one("#test-status", Static).content) + if expected_manual not in text: + raise AssertionError(f"Expected manual install hint in connection screen, got:\n{text}") + else: + assert app.screen.__class__.__name__ == "ConnectionScreen" + await pilot.pause(0.2) + _maybe_screenshot(app, f"{db_type}-04-result") + + os.environ.pop("SQLIT_INSTALL_FORCE_FAIL", None) + os.environ.pop("SQLIT_DISABLE_RESTART", None) + + +async def main() -> None: + outdir = os.environ.get("SQLIT_TEST_SCREENSHOTS_DIR") + if outdir: + _clean_screenshots_dir(Path(outdir)) + + # Success path: install missing psycopg2 + _assert_missing("psycopg2") + await _run_flow(force_fail=False, db_type="postgresql") + _assert_present("psycopg2") + + # Failure path: forced install failure yields manual instructions + _assert_missing("mysql.connector") + await _run_flow(force_fail=True, db_type="mysql") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/integration/test_database_browsing_flow.py b/tests/integration/test_database_browsing_flow.py new file mode 100644 index 00000000..ca6e1a9e --- /dev/null +++ b/tests/integration/test_database_browsing_flow.py @@ -0,0 +1,496 @@ +"""Integration tests for browsing all databases without pre-selecting one. + +This test module verifies that when connecting with an empty database field: +1. The connection succeeds +2. All databases are visible in the explorer tree +3. Clicking on a database expands to show Tables/Views folders +4. Clicking on Tables shows the tables +5. Queries can be executed successfully (where supported) + +Applicable providers: MySQL, PostgreSQL, MSSQL, MariaDB, CockroachDB +(All providers that support multiple databases) + +Note: PostgreSQL and CockroachDB don't support cross-database queries. +When connected without a database, they connect to a default database (postgres/defaultdb) +and can only query tables in that database. For these providers, we test tree +navigation only. +""" + +from __future__ import annotations + +import os +import tempfile +import time +from typing import Any + +import pytest + +from sqlit.app import SSMSTUI +from sqlit.config import ConnectionConfig +from sqlit.ui.tree_nodes import ( + ConnectionNode, + DatabaseNode, + FolderNode, + LoadingNode, + SchemaNode, + TableNode, +) + +# ============================================================================== +# Test Helpers +# ============================================================================== + + +async def wait_for_condition( + pilot: Any, + condition: callable, + timeout_seconds: float = 10.0, + poll_interval: float = 0.1, + description: str = "", +) -> bool: + """Wait for a condition to become true with timeout.""" + start = time.monotonic() + while time.monotonic() - start < timeout_seconds: + if condition(): + return True + await pilot.pause(poll_interval) + raise AssertionError(f"Timed out waiting for: {description or 'condition'}") + + +def find_node_by_type(node: Any, node_type: type, name: str | None = None) -> Any | None: + """Recursively find a node by its data type and optionally name.""" + if node.data and isinstance(node.data, node_type): + if name is None: + return node + if hasattr(node.data, "name") and node.data.name == name: + return node + for child in node.children: + result = find_node_by_type(child, node_type, name) + if result: + return result + return None + + +def find_database_node(tree_root: Any, db_name: str) -> Any | None: + """Find a database node by name.""" + for child in tree_root.children: + result = find_node_by_type(child, DatabaseNode, db_name) + if result: + return result + return None + + +def find_folder_node(parent: Any, folder_type: str) -> Any | None: + """Find a folder node (Tables, Views, etc.) under a parent.""" + for child in parent.children: + if isinstance(child.data, FolderNode) and child.data.folder_type == folder_type: + return child + return None + + +def has_loading_children(node: Any) -> bool: + """Check if a node has a loading placeholder child.""" + for child in node.children: + if isinstance(child.data, LoadingNode): + return True + return False + + +def has_table_children(node: Any) -> bool: + """Check if a node has TableNode children (directly or under schema folders).""" + for child in node.children: + if isinstance(child.data, TableNode): + return True + # Check under schema folders + if isinstance(child.data, SchemaNode): + for schema_child in child.children: + if isinstance(schema_child.data, TableNode): + return True + return False + + +def find_table_node(parent: Any, table_name: str) -> Any | None: + """Find a table node by name, checking both direct children and schema folders.""" + for child in parent.children: + if isinstance(child.data, TableNode) and child.data.name == table_name: + return child + # Check under schema folders + if isinstance(child.data, SchemaNode): + for schema_child in child.children: + if isinstance(schema_child.data, TableNode) and schema_child.data.name == table_name: + return schema_child + return None + + +# ============================================================================== +# Base Test Class +# ============================================================================== + + +class BaseDatabaseBrowsingTest: + """Base class for database browsing tests.""" + + # Subclasses must set these + DB_TYPE: str = "" + TEST_DATABASE: str = "" # The database containing test tables + SERVER_HOST: str = "localhost" + SERVER_PORT: str = "" + USERNAME: str = "" + PASSWORD: str = "" + + # Whether this provider supports cross-database queries + # MySQL, MariaDB, MSSQL: True + # PostgreSQL, CockroachDB: False (each DB is isolated) + SUPPORTS_CROSS_DB_QUERIES: bool = True + + # Whether this provider can fetch table metadata from other databases + # MySQL, MariaDB, MSSQL: True + # PostgreSQL, CockroachDB: False (can only see tables in connected database) + CAN_FETCH_CROSS_DB_TABLES: bool = True + + @pytest.fixture + def connection_config(self) -> ConnectionConfig: + """Create a connection config with empty database field.""" + return ConnectionConfig( + name=f"test-browse-{self.DB_TYPE}", + db_type=self.DB_TYPE, + server=self.SERVER_HOST, + port=self.SERVER_PORT, + database="", # Empty to browse all databases + username=self.USERNAME, + password=self.PASSWORD, + ) + + @pytest.fixture + def temp_config_dir(self): + """Create a temporary config directory for tests.""" + with tempfile.TemporaryDirectory(prefix="sqlit-test-") as tmpdir: + original = os.environ.get("SQLIT_CONFIG_DIR") + os.environ["SQLIT_CONFIG_DIR"] = tmpdir + yield tmpdir + if original: + os.environ["SQLIT_CONFIG_DIR"] = original + else: + os.environ.pop("SQLIT_CONFIG_DIR", None) + + @pytest.mark.asyncio + async def test_browse_all_databases_and_query(self, connection_config: ConnectionConfig, temp_config_dir: str): + """Test: Connect without database, browse to DB, expand Tables, run query. + + For providers that don't support cross-database queries (PostgreSQL, CockroachDB), + we only test tree navigation, not query execution. + """ + app = SSMSTUI() + + async with app.run_test(size=(120, 40)) as pilot: + # Wait for app to mount + await pilot.pause(0.1) + + # Set connections AFTER mount (on_mount loads from disk, overwriting pre-set values) + app.connections = [connection_config] + app.refresh_tree() + await pilot.pause(0.1) + + # Wait for tree to be populated + await wait_for_condition( + pilot, + lambda: len(app.object_tree.root.children) > 0, + timeout_seconds=5.0, + description="tree to be populated with connections", + ) + + # Step 1: Get the connection node from the tree + cursor_node = app.object_tree.root.children[0] + assert cursor_node is not None + assert isinstance(cursor_node.data, ConnectionNode) + + # Connect to the server + app.connect_to_server(connection_config) + await pilot.pause(0.5) + + # Step 2: Wait for connection and tree population + # The tree should now show the connection with a "Databases" folder + await wait_for_condition( + pilot, + lambda: app.current_connection is not None, + timeout_seconds=15.0, + description="connection to be established", + ) + + # Step 3: Verify database list is shown + # The connected node should have children (Databases folder with databases) + connected_node = None + for child in app.object_tree.root.children: + if isinstance(child.data, ConnectionNode) and child.data.config.name == connection_config.name: + connected_node = child + break + assert connected_node is not None, "Connected node not found" + + # Wait for tree to be populated with databases + await wait_for_condition( + pilot, + lambda: len(connected_node.children) > 0, + timeout_seconds=10.0, + description="tree to be populated", + ) + + # Find the test database node + db_node = find_database_node(app.object_tree.root, self.TEST_DATABASE) + assert db_node is not None, f"Database '{self.TEST_DATABASE}' not found in tree" + + # For providers that can't fetch cross-database tables (PostgreSQL, CockroachDB), + # we can only verify that databases are visible. We can't expand tables from + # other databases because the adapter can only see tables in the connected database. + if not self.CAN_FETCH_CROSS_DB_TABLES: + # Test passes - we verified connection and database visibility + return + + # Step 4: Expand the database node to see Tables/Views + db_node.expand() + await pilot.pause(0.3) + + # Find the Tables folder + tables_folder = find_folder_node(db_node, "tables") + assert tables_folder is not None, "Tables folder not found" + + # Step 5: Expand Tables folder + tables_folder.expand() + await pilot.pause(0.5) + + # Wait for tables to load (not just the loading placeholder) + await wait_for_condition( + pilot, + lambda: not has_loading_children(tables_folder) and len(tables_folder.children) > 0, + timeout_seconds=10.0, + description="tables to be loaded", + ) + + # Step 6: Verify tables are visible + assert has_table_children(tables_folder), "No tables found in Tables folder" + + # Find our test table + table_node = find_table_node(tables_folder, "test_users") + assert table_node is not None, "test_users table not found" + + # Step 7: Execute a query (only for providers that support cross-DB queries) + if self.SUPPORTS_CROSS_DB_QUERIES: + # Use the adapter's build_select_query to get the right syntax + query = app.current_adapter.build_select_query( + "test_users", 100, table_node.data.database, table_node.data.schema + ) + app.query_input.text = query + + # Execute the query + app.action_execute_query() + await pilot.pause(0.5) + + # Wait for query to complete + await wait_for_condition( + pilot, + lambda: not getattr(app, "_query_executing", False), + timeout_seconds=15.0, + description="query to complete", + ) + + # Step 8: Verify results + # Check that we got some results (the test_users table should have data) + assert app._last_result_row_count > 0, "Query returned no results" + assert "name" in [col.lower() for col in app._last_result_columns], "Expected 'name' column in results" + + +# ============================================================================== +# Provider-Specific Tests +# ============================================================================== + + +class TestMySQLDatabaseBrowsing(BaseDatabaseBrowsingTest): + """Test database browsing for MySQL.""" + + DB_TYPE = "mysql" + TEST_DATABASE = os.environ.get("MYSQL_DATABASE", "test_sqlit") + SERVER_HOST = os.environ.get("MYSQL_HOST", "localhost") + SERVER_PORT = os.environ.get("MYSQL_PORT", "3306") + USERNAME = os.environ.get("MYSQL_USER", "root") + PASSWORD = os.environ.get("MYSQL_PASSWORD", "TestPassword123!") + + @pytest.fixture + def connection_config(self) -> ConnectionConfig: + return ConnectionConfig( + name="test-browse-mysql", + db_type="mysql", + server=self.SERVER_HOST, + port=self.SERVER_PORT, + database="", # Empty to browse all databases + username=self.USERNAME, + password=self.PASSWORD, + ) + + @pytest.fixture(autouse=True) + def check_mysql_available(self, mysql_server_ready: bool, mysql_db: str): + """Skip if MySQL is not available.""" + if not mysql_server_ready: + pytest.skip("MySQL is not available") + # Update TEST_DATABASE with the actual database name from fixture + self.TEST_DATABASE = mysql_db + + @pytest.mark.asyncio + async def test_browse_all_databases_and_query(self, connection_config: ConnectionConfig, temp_config_dir: str): + """Test MySQL database browsing with empty database field.""" + await super().test_browse_all_databases_and_query(connection_config, temp_config_dir) + + +class TestPostgreSQLDatabaseBrowsing(BaseDatabaseBrowsingTest): + """Test database browsing for PostgreSQL. + + Note: PostgreSQL doesn't support cross-database queries. Each database is isolated. + When connected without a database, it connects to 'postgres' by default. + We can see all databases but can only query tables in the connected database. + """ + + DB_TYPE = "postgresql" + TEST_DATABASE = os.environ.get("POSTGRES_DATABASE", "test_sqlit") + SERVER_HOST = os.environ.get("POSTGRES_HOST", "localhost") + SERVER_PORT = os.environ.get("POSTGRES_PORT", "5432") + USERNAME = os.environ.get("POSTGRES_USER", "testuser") + PASSWORD = os.environ.get("POSTGRES_PASSWORD", "TestPassword123!") + SUPPORTS_CROSS_DB_QUERIES = False + CAN_FETCH_CROSS_DB_TABLES = False + + @pytest.fixture + def connection_config(self) -> ConnectionConfig: + return ConnectionConfig( + name="test-browse-postgresql", + db_type="postgresql", + server=self.SERVER_HOST, + port=self.SERVER_PORT, + database="", # Empty to browse all databases + username=self.USERNAME, + password=self.PASSWORD, + ) + + @pytest.fixture(autouse=True) + def check_postgres_available(self, postgres_server_ready: bool, postgres_db: str): + """Skip if PostgreSQL is not available.""" + if not postgres_server_ready: + pytest.skip("PostgreSQL is not available") + self.TEST_DATABASE = postgres_db + + @pytest.mark.asyncio + async def test_browse_all_databases_and_query(self, connection_config: ConnectionConfig, temp_config_dir: str): + """Test PostgreSQL database browsing with empty database field.""" + await super().test_browse_all_databases_and_query(connection_config, temp_config_dir) + + +class TestMSSQLDatabaseBrowsing(BaseDatabaseBrowsingTest): + """Test database browsing for SQL Server.""" + + DB_TYPE = "mssql" + TEST_DATABASE = os.environ.get("MSSQL_DATABASE", "test_sqlit") + SERVER_HOST = os.environ.get("MSSQL_HOST", "localhost") + SERVER_PORT = os.environ.get("MSSQL_PORT", "1433") + USERNAME = os.environ.get("MSSQL_USER", "sa") + PASSWORD = os.environ.get("MSSQL_PASSWORD", "TestPassword123!") + + @pytest.fixture + def connection_config(self) -> ConnectionConfig: + server = self.SERVER_HOST + if self.SERVER_PORT and self.SERVER_PORT != "1433": + server = f"{self.SERVER_HOST},{self.SERVER_PORT}" + return ConnectionConfig( + name="test-browse-mssql", + db_type="mssql", + server=server, + port="", # Port included in server for MSSQL + database="", # Empty to browse all databases + username=self.USERNAME, + password=self.PASSWORD, + auth_type="sql", + ) + + @pytest.fixture(autouse=True) + def check_mssql_available(self, mssql_server_ready: bool, mssql_db: str): + """Skip if MSSQL is not available.""" + if not mssql_server_ready: + pytest.skip("SQL Server is not available") + self.TEST_DATABASE = mssql_db + + @pytest.mark.asyncio + async def test_browse_all_databases_and_query(self, connection_config: ConnectionConfig, temp_config_dir: str): + """Test SQL Server database browsing with empty database field.""" + await super().test_browse_all_databases_and_query(connection_config, temp_config_dir) + + +class TestMariaDBDatabaseBrowsing(BaseDatabaseBrowsingTest): + """Test database browsing for MariaDB.""" + + DB_TYPE = "mariadb" + TEST_DATABASE = os.environ.get("MARIADB_DATABASE", "test_sqlit") + SERVER_HOST = os.environ.get("MARIADB_HOST", "127.0.0.1") + SERVER_PORT = os.environ.get("MARIADB_PORT", "3307") + USERNAME = os.environ.get("MARIADB_USER", "root") + PASSWORD = os.environ.get("MARIADB_PASSWORD", "TestPassword123!") + + @pytest.fixture + def connection_config(self) -> ConnectionConfig: + return ConnectionConfig( + name="test-browse-mariadb", + db_type="mariadb", + server=self.SERVER_HOST, + port=self.SERVER_PORT, + database="", # Empty to browse all databases + username=self.USERNAME, + password=self.PASSWORD, + ) + + @pytest.fixture(autouse=True) + def check_mariadb_available(self, mariadb_server_ready: bool, mariadb_db: str): + """Skip if MariaDB is not available.""" + if not mariadb_server_ready: + pytest.skip("MariaDB is not available") + self.TEST_DATABASE = mariadb_db + + @pytest.mark.asyncio + async def test_browse_all_databases_and_query(self, connection_config: ConnectionConfig, temp_config_dir: str): + """Test MariaDB database browsing with empty database field.""" + await super().test_browse_all_databases_and_query(connection_config, temp_config_dir) + + +class TestCockroachDBDatabaseBrowsing(BaseDatabaseBrowsingTest): + """Test database browsing for CockroachDB. + + Note: CockroachDB uses PostgreSQL wire protocol and has the same limitation - + it doesn't support cross-database queries. Each database is isolated. + """ + + DB_TYPE = "cockroachdb" + TEST_DATABASE = os.environ.get("COCKROACHDB_DATABASE", "test_sqlit") + SERVER_HOST = os.environ.get("COCKROACHDB_HOST", "localhost") + SERVER_PORT = os.environ.get("COCKROACHDB_PORT", "26257") + USERNAME = os.environ.get("COCKROACHDB_USER", "root") + PASSWORD = os.environ.get("COCKROACHDB_PASSWORD", "") + SUPPORTS_CROSS_DB_QUERIES = False + CAN_FETCH_CROSS_DB_TABLES = False + + @pytest.fixture + def connection_config(self) -> ConnectionConfig: + return ConnectionConfig( + name="test-browse-cockroachdb", + db_type="cockroachdb", + server=self.SERVER_HOST, + port=self.SERVER_PORT, + database="", # Empty to browse all databases + username=self.USERNAME, + password=self.PASSWORD, + ) + + @pytest.fixture(autouse=True) + def check_cockroachdb_available(self, cockroachdb_server_ready: bool, cockroachdb_db: str): + """Skip if CockroachDB is not available.""" + if not cockroachdb_server_ready: + pytest.skip("CockroachDB is not available") + self.TEST_DATABASE = cockroachdb_db + + @pytest.mark.asyncio + async def test_browse_all_databases_and_query(self, connection_config: ConnectionConfig, temp_config_dir: str): + """Test CockroachDB database browsing with empty database field.""" + await super().test_browse_all_databases_and_query(connection_config, temp_config_dir) diff --git a/tests/test_cockroachdb.py b/tests/test_cockroachdb.py index 8be48201..5f49c05d 100644 --- a/tests/test_cockroachdb.py +++ b/tests/test_cockroachdb.py @@ -2,8 +2,6 @@ from __future__ import annotations -import pytest - from .test_database_base import BaseDatabaseTestsWithLimit, DatabaseTestConfig @@ -26,20 +24,32 @@ def config(self) -> DatabaseTestConfig: def test_create_cockroachdb_connection(self, cockroachdb_db, cli_runner): """Test creating a CockroachDB connection via CLI.""" - from .conftest import COCKROACHDB_HOST, COCKROACHDB_PORT, COCKROACHDB_USER, COCKROACHDB_PASSWORD + from .conftest import ( + COCKROACHDB_HOST, + COCKROACHDB_PASSWORD, + COCKROACHDB_PORT, + COCKROACHDB_USER, + ) connection_name = "test_create_cockroachdb" try: args = [ - "connection", "create", - "--name", connection_name, - "--db-type", "cockroachdb", - "--server", COCKROACHDB_HOST, - "--port", str(COCKROACHDB_PORT), - "--database", cockroachdb_db, - "--username", COCKROACHDB_USER, - "--password", COCKROACHDB_PASSWORD or "", + "connections", + "add", + "cockroachdb", + "--name", + connection_name, + "--server", + COCKROACHDB_HOST, + "--port", + str(COCKROACHDB_PORT), + "--database", + cockroachdb_db, + "--username", + COCKROACHDB_USER, + "--password", + COCKROACHDB_PASSWORD or "", ] result = cli_runner(*args) assert result.returncode == 0 @@ -56,19 +66,31 @@ def test_create_cockroachdb_connection(self, cockroachdb_db, cli_runner): def test_delete_cockroachdb_connection(self, cockroachdb_db, cli_runner): """Test deleting a CockroachDB connection.""" - from .conftest import COCKROACHDB_HOST, COCKROACHDB_PORT, COCKROACHDB_USER, COCKROACHDB_PASSWORD + from .conftest import ( + COCKROACHDB_HOST, + COCKROACHDB_PASSWORD, + COCKROACHDB_PORT, + COCKROACHDB_USER, + ) connection_name = "test_delete_cockroachdb" args = [ - "connection", "create", - "--name", connection_name, - "--db-type", "cockroachdb", - "--server", COCKROACHDB_HOST, - "--port", str(COCKROACHDB_PORT), - "--database", cockroachdb_db, - "--username", COCKROACHDB_USER, - "--password", COCKROACHDB_PASSWORD or "", + "connections", + "add", + "cockroachdb", + "--name", + connection_name, + "--server", + COCKROACHDB_HOST, + "--port", + str(COCKROACHDB_PORT), + "--database", + cockroachdb_db, + "--username", + COCKROACHDB_USER, + "--password", + COCKROACHDB_PASSWORD or "", ] cli_runner(*args) diff --git a/tests/test_credentials_service.py b/tests/test_credentials_service.py new file mode 100644 index 00000000..58c26df0 --- /dev/null +++ b/tests/test_credentials_service.py @@ -0,0 +1,558 @@ +"""Tests for the credentials service.""" + +from __future__ import annotations + +import json +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from sqlit.config import ConnectionConfig +from sqlit.services.credentials import ( + KEYRING_SERVICE_NAME, + CredentialsService, + KeyringCredentialsService, + PlaintextFileCredentialsService, + PlaintextCredentialsService, + get_credentials_service, + reset_credentials_service, + set_credentials_service, +) + + +class TestPlaintextCredentialsService: + """Tests for PlaintextCredentialsService.""" + + def test_set_and_get_password(self) -> None: + """Test setting and getting database password.""" + service = PlaintextCredentialsService() + service.set_password("test_conn", "my_password") + assert service.get_password("test_conn") == "my_password" + + def test_get_password_not_found(self) -> None: + """Test getting a password that doesn't exist.""" + service = PlaintextCredentialsService() + assert service.get_password("nonexistent") is None + + def test_delete_password(self) -> None: + """Test deleting a password.""" + service = PlaintextCredentialsService() + service.set_password("test_conn", "my_password") + service.delete_password("test_conn") + assert service.get_password("test_conn") is None + + def test_delete_nonexistent_password(self) -> None: + """Test deleting a password that doesn't exist (should not raise).""" + service = PlaintextCredentialsService() + service.delete_password("nonexistent") # Should not raise + + def test_set_and_get_ssh_password(self) -> None: + """Test setting and getting SSH password.""" + service = PlaintextCredentialsService() + service.set_ssh_password("test_conn", "ssh_pass") + assert service.get_ssh_password("test_conn") == "ssh_pass" + + def test_get_ssh_password_not_found(self) -> None: + """Test getting an SSH password that doesn't exist.""" + service = PlaintextCredentialsService() + assert service.get_ssh_password("nonexistent") is None + + def test_delete_ssh_password(self) -> None: + """Test deleting an SSH password.""" + service = PlaintextCredentialsService() + service.set_ssh_password("test_conn", "ssh_pass") + service.delete_ssh_password("test_conn") + assert service.get_ssh_password("test_conn") is None + + def test_set_empty_password_stores_empty(self) -> None: + """Test that setting an empty password stores it (not deletes). + + Empty string means "explicitly set to empty" which is valid for + databases that support passwordless auth (e.g., CockroachDB insecure mode). + """ + service = PlaintextCredentialsService() + service.set_password("test_conn", "password") + service.set_password("test_conn", "") + assert service.get_password("test_conn") == "" + + def test_set_empty_ssh_password_stores_empty(self) -> None: + """Test that setting an empty SSH password stores it (not deletes). + + Empty string means "explicitly set to empty" which is valid for + some SSH configurations. + """ + service = PlaintextCredentialsService() + service.set_ssh_password("test_conn", "password") + service.set_ssh_password("test_conn", "") + assert service.get_ssh_password("test_conn") == "" + + def test_set_none_password_deletes(self) -> None: + """Test that setting None deletes the password.""" + service = PlaintextCredentialsService() + service.set_password("test_conn", "password") + service.set_password("test_conn", None) + assert service.get_password("test_conn") is None + + def test_set_none_ssh_password_deletes(self) -> None: + """Test that setting None deletes the SSH password.""" + service = PlaintextCredentialsService() + service.set_ssh_password("test_conn", "password") + service.set_ssh_password("test_conn", None) + assert service.get_ssh_password("test_conn") is None + + def test_rename_connection(self) -> None: + """Test renaming a connection moves credentials.""" + service = PlaintextCredentialsService() + service.set_password("old_name", "db_pass") + service.set_ssh_password("old_name", "ssh_pass") + + service.rename_connection("old_name", "new_name") + + # Old credentials should be gone + assert service.get_password("old_name") is None + assert service.get_ssh_password("old_name") is None + + # New credentials should exist + assert service.get_password("new_name") == "db_pass" + assert service.get_ssh_password("new_name") == "ssh_pass" + + def test_delete_all_for_connection(self) -> None: + """Test deleting all credentials for a connection.""" + service = PlaintextCredentialsService() + service.set_password("test_conn", "db_pass") + service.set_ssh_password("test_conn", "ssh_pass") + + service.delete_all_for_connection("test_conn") + + assert service.get_password("test_conn") is None + assert service.get_ssh_password("test_conn") is None + + def test_multiple_connections(self) -> None: + """Test storing credentials for multiple connections.""" + service = PlaintextCredentialsService() + service.set_password("conn1", "pass1") + service.set_password("conn2", "pass2") + service.set_ssh_password("conn1", "ssh1") + service.set_ssh_password("conn2", "ssh2") + + assert service.get_password("conn1") == "pass1" + assert service.get_password("conn2") == "pass2" + assert service.get_ssh_password("conn1") == "ssh1" + assert service.get_ssh_password("conn2") == "ssh2" + + +class TestKeyringCredentialsService: + """Tests for KeyringCredentialsService.""" + + def _create_service_with_mock_keyring(self) -> tuple[KeyringCredentialsService, MagicMock]: + """Create a service with a mock keyring injected.""" + service = KeyringCredentialsService() + mock_keyring = MagicMock() + service._keyring = mock_keyring + return service, mock_keyring + + def test_lazy_loading(self) -> None: + """Test that keyring is lazy-loaded.""" + service = KeyringCredentialsService() + assert service._keyring is None + + def test_make_key(self) -> None: + """Test key generation for keyring storage.""" + service = KeyringCredentialsService() + assert service._make_key("my_conn", "db") == "my_conn:db" + assert service._make_key("my_conn", "ssh") == "my_conn:ssh" + + def test_set_password(self) -> None: + """Test setting password via keyring.""" + service, mock_keyring = self._create_service_with_mock_keyring() + + service.set_password("test_conn", "my_password") + + mock_keyring.set_password.assert_called_once_with( + KEYRING_SERVICE_NAME, "test_conn:db", "my_password" + ) + + def test_get_password(self) -> None: + """Test getting password via keyring.""" + service, mock_keyring = self._create_service_with_mock_keyring() + mock_keyring.get_password.return_value = "stored_password" + + result = service.get_password("test_conn") + + assert result == "stored_password" + mock_keyring.get_password.assert_called_once_with( + KEYRING_SERVICE_NAME, "test_conn:db" + ) + + def test_delete_password(self) -> None: + """Test deleting password via keyring.""" + service, mock_keyring = self._create_service_with_mock_keyring() + + service.delete_password("test_conn") + + mock_keyring.delete_password.assert_called_once_with( + KEYRING_SERVICE_NAME, "test_conn:db" + ) + + def test_set_ssh_password(self) -> None: + """Test setting SSH password via keyring.""" + service, mock_keyring = self._create_service_with_mock_keyring() + + service.set_ssh_password("test_conn", "ssh_pass") + + mock_keyring.set_password.assert_called_once_with( + KEYRING_SERVICE_NAME, "test_conn:ssh", "ssh_pass" + ) + + def test_get_ssh_password(self) -> None: + """Test getting SSH password via keyring.""" + service, mock_keyring = self._create_service_with_mock_keyring() + mock_keyring.get_password.return_value = "ssh_stored" + + result = service.get_ssh_password("test_conn") + + assert result == "ssh_stored" + mock_keyring.get_password.assert_called_once_with( + KEYRING_SERVICE_NAME, "test_conn:ssh" + ) + + def test_delete_ssh_password(self) -> None: + """Test deleting SSH password via keyring.""" + service, mock_keyring = self._create_service_with_mock_keyring() + + service.delete_ssh_password("test_conn") + + mock_keyring.delete_password.assert_called_once_with( + KEYRING_SERVICE_NAME, "test_conn:ssh" + ) + + def test_set_empty_password_stores_empty(self) -> None: + """Test that setting empty password stores it (not deletes).""" + service, mock_keyring = self._create_service_with_mock_keyring() + + service.set_password("test_conn", "") + + mock_keyring.set_password.assert_called_once_with( + KEYRING_SERVICE_NAME, "test_conn:db", "" + ) + + def test_set_none_password_deletes(self) -> None: + """Test that setting None password calls delete.""" + service, mock_keyring = self._create_service_with_mock_keyring() + + service.set_password("test_conn", None) + + mock_keyring.delete_password.assert_called_once_with( + KEYRING_SERVICE_NAME, "test_conn:db" + ) + + def test_keyring_error_returns_none(self) -> None: + """Test that keyring errors return None for get operations.""" + service, mock_keyring = self._create_service_with_mock_keyring() + mock_keyring.get_password.side_effect = Exception("Keyring error") + + result = service.get_password("test_conn") + assert result is None + + def test_keyring_error_on_set_silently_fails(self) -> None: + """Test that keyring errors on set are silently caught.""" + service, mock_keyring = self._create_service_with_mock_keyring() + mock_keyring.set_password.side_effect = Exception("Keyring error") + + # Should not raise + service.set_password("test_conn", "password") + + +class TestGlobalCredentialsService: + """Tests for global credentials service functions.""" + + def teardown_method(self) -> None: + """Reset global service after each test.""" + reset_credentials_service() + + def test_set_and_get_service(self) -> None: + """Test setting and getting the global service.""" + service = PlaintextCredentialsService() + set_credentials_service(service) + assert get_credentials_service() is service + + def test_reset_service(self) -> None: + """Test resetting the global service.""" + service = PlaintextCredentialsService() + set_credentials_service(service) + reset_credentials_service() + + # Should create a new service + new_service = get_credentials_service() + assert new_service is not service + + @patch("sqlit.services.credentials.KeyringCredentialsService") + @patch("sqlit.services.credentials.is_keyring_usable", return_value=True) + def test_default_service_is_keyring(self, _mock_usable: MagicMock, mock_keyring_class: MagicMock) -> None: + """Test that default service is keyring-based.""" + mock_instance = MagicMock() + mock_keyring_class.return_value = mock_instance + + service = get_credentials_service() + + assert service is mock_instance + + @patch("sqlit.services.credentials.is_keyring_usable", return_value=False) + @patch("sqlit.stores.settings.load_settings", return_value={}) + def test_fallback_to_in_memory_when_no_consent( + self, _mock_settings: MagicMock, _mock_usable: MagicMock + ) -> None: + """Test fallback to in-memory plaintext when keyring isn't usable and consent not recorded.""" + reset_credentials_service() + service = get_credentials_service() + assert isinstance(service, PlaintextCredentialsService) + + @patch("sqlit.services.credentials.is_keyring_usable", return_value=False) + @patch("sqlit.stores.settings.load_settings", return_value={"allow_plaintext_credentials": True}) + def test_plaintext_file_when_consent_recorded( + self, _mock_settings: MagicMock, _mock_usable: MagicMock + ) -> None: + """Test fallback to plaintext file store when user consent is recorded.""" + reset_credentials_service() + service = get_credentials_service() + assert isinstance(service, PlaintextFileCredentialsService) + + +def test_plaintext_file_credentials_service_roundtrip(tmp_path, monkeypatch): + monkeypatch.setattr("sqlit.services.credentials.CONFIG_DIR", tmp_path) + service = PlaintextFileCredentialsService() + service.set_password("conn", "dbpass") + service.set_ssh_password("conn", "sshpass") + assert service.get_password("conn") == "dbpass" + assert service.get_ssh_password("conn") == "sshpass" + + +class TestConnectionStoreWithCredentials: + """Integration tests for ConnectionStore with credentials service.""" + + def setup_method(self) -> None: + """Set up test fixtures.""" + self.tmpdir = tempfile.mkdtemp() + self.creds_service = PlaintextCredentialsService() + set_credentials_service(self.creds_service) + + def teardown_method(self) -> None: + """Clean up after tests.""" + reset_credentials_service() + import shutil + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def _create_store(self) -> "ConnectionStore": + """Create a ConnectionStore with the temp directory.""" + from sqlit.stores.base import JSONFileStore + from sqlit.stores.connections import ConnectionStore + + # Create a subclass that uses our temp path + class TempConnectionStore(ConnectionStore): + def __init__(self, tmpdir: str, creds_service): + # Don't call parent __init__, just set up manually + JSONFileStore.__init__(self, Path(tmpdir) / "connections.json") + self._credentials_service = creds_service + + return TempConnectionStore(self.tmpdir, self.creds_service) + + def test_save_removes_passwords_from_json(self) -> None: + """Test that saving connections removes passwords from JSON file.""" + store = self._create_store() + + config = ConnectionConfig( + name="test_db", + db_type="postgresql", + server="localhost", + username="user", + password="secret_password", + ssh_password="ssh_secret", + ) + + store.save_all([config]) + + # Read the JSON file directly + json_path = Path(self.tmpdir) / "connections.json" + with open(json_path) as f: + saved_data = json.load(f) + + # Passwords should be null in JSON (indicating "load from credentials service") + assert saved_data[0]["password"] is None + assert saved_data[0]["ssh_password"] is None + + # But should be in the credentials service + assert self.creds_service.get_password("test_db") == "secret_password" + assert self.creds_service.get_ssh_password("test_db") == "ssh_secret" + + def test_load_restores_passwords_from_credentials_service(self) -> None: + """Test that loading connections restores passwords.""" + # Set up credentials in the service + self.creds_service.set_password("test_db", "secret_password") + self.creds_service.set_ssh_password("test_db", "ssh_secret") + + # Write a config file with null passwords (indicates "load from credentials service") + json_path = Path(self.tmpdir) / "connections.json" + with open(json_path, "w") as f: + json.dump( + [ + { + "name": "test_db", + "db_type": "postgresql", + "server": "localhost", + "username": "user", + "password": None, # null = load from credentials service + "ssh_password": None, # null = load from credentials service + "port": "5432", + "database": "", + "auth_type": "sql", + "driver": "ODBC Driver 18 for SQL Server", + "trusted_connection": False, + "file_path": "", + "ssh_enabled": False, + "ssh_host": "", + "ssh_port": "22", + "ssh_username": "", + "ssh_auth_type": "key", + "ssh_key_path": "", + "supabase_region": "", + "supabase_project_id": "", + } + ], + f, + ) + + store = self._create_store() + loaded = store.load_all() + + assert len(loaded) == 1 + assert loaded[0].password == "secret_password" + assert loaded[0].ssh_password == "ssh_secret" + + def test_delete_removes_credentials(self) -> None: + """Test that deleting a connection removes credentials.""" + store = self._create_store() + + config = ConnectionConfig( + name="test_db", + db_type="postgresql", + server="localhost", + password="secret", + ssh_password="ssh_secret", + ) + + store.save_all([config]) + + # Verify credentials exist + assert self.creds_service.get_password("test_db") == "secret" + assert self.creds_service.get_ssh_password("test_db") == "ssh_secret" + + # Delete the connection + store.delete("test_db") + + # Credentials should be gone + assert self.creds_service.get_password("test_db") is None + assert self.creds_service.get_ssh_password("test_db") is None + + def test_empty_password_is_stored(self) -> None: + """Test that empty password is stored (explicitly set to empty). + + Empty string means the user explicitly set an empty password, + which is valid for databases supporting passwordless auth. + None means "not set" which would trigger a prompt. + """ + store = self._create_store() + + # Create config with empty password (explicitly empty, e.g., CockroachDB insecure) + config = ConnectionConfig( + name="test_db", + db_type="postgresql", + server="localhost", + username="user", + password="", # Empty = explicitly empty, no prompt + ) + + store.save_all([config]) + + # Load and verify password is still empty + loaded = store.load_all() + assert loaded[0].password == "" + + # Credentials service should have empty string stored + assert self.creds_service.get_password("test_db") == "" + + def test_none_password_means_prompt_on_connect(self) -> None: + """Test that None password means prompt on connect.""" + store = self._create_store() + + # Create config with None password (user wants to be prompted) + config = ConnectionConfig( + name="test_db", + db_type="postgresql", + server="localhost", + username="user", + password=None, # None = prompt on connect + ) + + store.save_all([config]) + + # Load and verify password is still None + loaded = store.load_all() + assert loaded[0].password is None + + # Credentials service should not have a password + assert self.creds_service.get_password("test_db") is None + + def test_migration_from_plaintext_preserves_existing_passwords(self) -> None: + """Test that existing plaintext passwords in JSON are preserved during migration.""" + # Write a config file WITH passwords (simulating old format) + json_path = Path(self.tmpdir) / "connections.json" + with open(json_path, "w") as f: + json.dump( + [ + { + "name": "legacy_db", + "db_type": "postgresql", + "server": "localhost", + "username": "user", + "password": "legacy_password", # Old plaintext password + "ssh_password": "legacy_ssh", + "port": "5432", + "database": "", + "auth_type": "sql", + "driver": "ODBC Driver 18 for SQL Server", + "trusted_connection": False, + "file_path": "", + "ssh_enabled": True, + "ssh_host": "bastion", + "ssh_port": "22", + "ssh_username": "user", + "ssh_auth_type": "password", + "ssh_key_path": "", + "supabase_region": "", + "supabase_project_id": "", + } + ], + f, + ) + + store = self._create_store() + loaded = store.load_all() + + # Legacy passwords from JSON should be loaded + assert loaded[0].password == "legacy_password" + assert loaded[0].ssh_password == "legacy_ssh" + + # Re-save to migrate to keyring + store.save_all(loaded) + + # Now passwords should be in keyring + assert self.creds_service.get_password("legacy_db") == "legacy_password" + assert self.creds_service.get_ssh_password("legacy_db") == "legacy_ssh" + + # And JSON should be clean (null indicates load from credentials service) + with open(json_path) as f: + saved_data = json.load(f) + assert saved_data[0]["password"] is None + assert saved_data[0]["ssh_password"] is None diff --git a/tests/test_d1.py b/tests/test_d1.py new file mode 100644 index 00000000..92632a63 --- /dev/null +++ b/tests/test_d1.py @@ -0,0 +1,56 @@ +"""Integration tests for Cloudflare D1 database operations.""" + +from __future__ import annotations + +from .test_database_base import BaseDatabaseTestsWithLimit, DatabaseTestConfig + + +class TestD1Integration(BaseDatabaseTestsWithLimit): + """Integration tests for Cloudflare D1 database operations via CLI. + + These tests require a running miniflare instance (via Docker). + Tests are skipped if D1 (miniflare) is not available. + """ + + @property + def config(self) -> DatabaseTestConfig: + return DatabaseTestConfig( + db_type="d1", + display_name="Cloudflare D1", + connection_fixture="d1_connection", + db_fixture="d1_db", + create_connection_args=lambda: [], # Uses fixtures + ) + + def test_create_d1_connection(self, d1_db, cli_runner): + """Test creating a D1 connection via CLI.""" + from .conftest import D1_ACCOUNT_ID, D1_API_TOKEN + + connection_name = "test_create_d1" + + try: + # Create connection + result = cli_runner( + "connections", + "add", + "d1", + "--name", + connection_name, + "--host", + D1_ACCOUNT_ID, + "--database", + d1_db, + "--password", + D1_API_TOKEN, + ) + assert result.returncode == 0 + assert "created successfully" in result.stdout + + # Verify it appears in list + result = cli_runner("connection", "list") + assert connection_name in result.stdout + assert "Cloudflare D1" in result.stdout + + finally: + # Cleanup + cli_runner("connection", "delete", connection_name, check=False) diff --git a/tests/test_database_base.py b/tests/test_database_base.py index 5aeb145a..fe953675 100644 --- a/tests/test_database_base.py +++ b/tests/test_database_base.py @@ -11,8 +11,6 @@ from dataclasses import dataclass from typing import TYPE_CHECKING -import pytest - if TYPE_CHECKING: from collections.abc import Callable @@ -56,8 +54,10 @@ def test_query_select(self, request, cli_runner): connection = request.getfixturevalue(self.config.connection_fixture) result = cli_runner( "query", - "-c", connection, - "-q", "SELECT * FROM test_users ORDER BY id", + "-c", + connection, + "-q", + "SELECT * FROM test_users ORDER BY id", ) assert result.returncode == 0 assert "Alice" in result.stdout @@ -70,8 +70,10 @@ def test_query_with_where(self, request, cli_runner): connection = request.getfixturevalue(self.config.connection_fixture) result = cli_runner( "query", - "-c", connection, - "-q", "SELECT name, email FROM test_users WHERE id = 1", + "-c", + connection, + "-q", + "SELECT name, email FROM test_users WHERE id = 1", ) assert result.returncode == 0 assert "Alice" in result.stdout @@ -81,11 +83,17 @@ def test_query_with_where(self, request, cli_runner): def test_query_json_format(self, request, cli_runner): """Test query output in JSON format.""" connection = request.getfixturevalue(self.config.connection_fixture) + # Use --limit for databases that don't support LIMIT syntax result = cli_runner( "query", - "-c", connection, - "-q", "SELECT id, name FROM test_users ORDER BY id LIMIT 2", - "--format", "json", + "-c", + connection, + "-q", + "SELECT id, name FROM test_users ORDER BY id", + "--format", + "json", + "--limit", + "2", ) assert result.returncode == 0 @@ -93,30 +101,42 @@ def test_query_json_format(self, request, cli_runner): data = json.loads(result.stdout) assert len(data) == 2 - assert data[0]["name"] == "Alice" - assert data[1]["name"] == "Bob" + # Oracle returns uppercase column names + first_name = data[0].get("name") or data[0].get("NAME") + second_name = data[1].get("name") or data[1].get("NAME") + assert first_name == "Alice" + assert second_name == "Bob" def test_query_csv_format(self, request, cli_runner): """Test query output in CSV format.""" connection = request.getfixturevalue(self.config.connection_fixture) + # Use --limit for databases that don't support LIMIT syntax result = cli_runner( "query", - "-c", connection, - "-q", "SELECT id, name FROM test_users ORDER BY id LIMIT 2", - "--format", "csv", + "-c", + connection, + "-q", + "SELECT id, name FROM test_users ORDER BY id", + "--format", + "csv", + "--limit", + "2", ) assert result.returncode == 0 - assert "id,name" in result.stdout - assert "1,Alice" in result.stdout - assert "2,Bob" in result.stdout + # Oracle may return uppercase column names + assert "id,name" in result.stdout.lower() + assert "Alice" in result.stdout + assert "Bob" in result.stdout def test_query_view(self, request, cli_runner): """Test querying a view.""" connection = request.getfixturevalue(self.config.connection_fixture) result = cli_runner( "query", - "-c", connection, - "-q", "SELECT * FROM test_user_emails ORDER BY id", + "-c", + connection, + "-q", + "SELECT * FROM test_user_emails ORDER BY id", ) assert result.returncode == 0 assert "Alice" in result.stdout @@ -127,8 +147,10 @@ def test_query_aggregate(self, request, cli_runner): connection = request.getfixturevalue(self.config.connection_fixture) result = cli_runner( "query", - "-c", connection, - "-q", "SELECT COUNT(*) as user_count FROM test_users", + "-c", + connection, + "-q", + "SELECT COUNT(*) as user_count FROM test_users", ) assert result.returncode == 0 assert "3" in result.stdout @@ -138,20 +160,23 @@ def test_query_insert(self, request, cli_runner): connection = request.getfixturevalue(self.config.connection_fixture) result = cli_runner( "query", - "-c", connection, - "-q", "INSERT INTO test_users (id, name, email) VALUES (4, 'David', 'david@example.com')", + "-c", + connection, + "-q", + "INSERT INTO test_users (id, name, email) VALUES (4, 'David', 'david@example.com')", ) assert result.returncode == 0 # Verify the insert result = cli_runner( "query", - "-c", connection, - "-q", "SELECT * FROM test_users WHERE id = 4", + "-c", + connection, + "-q", + "SELECT * FROM test_users WHERE id = 4", ) assert "David" in result.stdout - def test_cancellable_query_select(self, request): """Test CancellableQuery execution (used by TUI). @@ -222,12 +247,16 @@ def test_streaming_csv_output(self, request, cli_runner): # No --max-rows to trigger the streaming path for cursor-based adapters result = cli_runner( "query", - "-c", connection, - "-q", "SELECT id, name FROM test_users ORDER BY id", - "--format", "csv", + "-c", + connection, + "-q", + "SELECT id, name FROM test_users ORDER BY id", + "--format", + "csv", ) assert result.returncode == 0 - assert "id,name" in result.stdout + # Oracle may return uppercase column names + assert "id,name" in result.stdout.lower() assert "Alice" in result.stdout def test_streaming_json_output(self, request, cli_runner): @@ -235,14 +264,19 @@ def test_streaming_json_output(self, request, cli_runner): connection = request.getfixturevalue(self.config.connection_fixture) result = cli_runner( "query", - "-c", connection, - "-q", "SELECT id, name FROM test_users ORDER BY id", - "--format", "json", + "-c", + connection, + "-q", + "SELECT id, name FROM test_users ORDER BY id", + "--format", + "json", ) assert result.returncode == 0 data = json.loads(result.stdout) assert len(data) == 3 - assert data[0]["name"] == "Alice" + # Oracle returns uppercase column names + first_name = data[0].get("name") or data[0].get("NAME") + assert first_name == "Alice" def test_adapter_interface_compliance(self, request): """Verify adapter implements required interface without relying on cursor. @@ -256,16 +290,16 @@ def test_adapter_interface_compliance(self, request): # Required methods that should work without cursor required_methods = [ - 'connect', - 'execute_query', - 'execute_non_query', - 'get_tables', - 'get_views', - 'get_columns', - 'get_databases', - 'get_procedures', - 'quote_identifier', - 'build_select_query', + "connect", + "execute_query", + "execute_non_query", + "get_tables", + "get_views", + "get_columns", + "get_databases", + "get_procedures", + "quote_identifier", + "build_select_query", ] for method_name in required_methods: @@ -275,9 +309,9 @@ def test_adapter_interface_compliance(self, request): # Required properties required_properties = [ - 'name', - 'supports_multiple_databases', - 'supports_stored_procedures', + "name", + "supports_multiple_databases", + "supports_stored_procedures", ] for prop_name in required_properties: @@ -317,6 +351,44 @@ def test_query_service_execution(self, request): row_values = [str(v) for v in result.rows[0]] assert "Alice" in row_values + def test_primary_key_detection(self, request): + """Test that adapter correctly detects primary key columns. + + This tests that get_columns returns ColumnInfo with is_primary_key=True + for primary key columns. The test_users table has 'id' as PRIMARY KEY. + """ + from sqlit.config import load_connections + from sqlit.db.adapters import get_adapter + from sqlit.services.session import ConnectionSession + + connection_name = request.getfixturevalue(self.config.connection_fixture) + connections = load_connections() + config = next((c for c in connections if c.name == connection_name), None) + assert config is not None, f"Connection {connection_name} not found" + + with ConnectionSession.create(config, get_adapter) as session: + # Get columns for test_users table (has 'id' as PRIMARY KEY) + columns = session.adapter.get_columns( + session.connection, + "test_users", + database=config.database if session.adapter.supports_multiple_databases else None, + ) + + assert len(columns) >= 3, f"Expected at least 3 columns, got {len(columns)}" + + # Find the 'id' column (case-insensitive for Oracle which uppercases) + id_column = next( + (col for col in columns if col.name.lower() == "id"), + None, + ) + assert id_column is not None, f"Column 'id' not found. Columns: {[c.name for c in columns]}" + assert id_column.is_primary_key, f"Column 'id' should be marked as primary key" + + # Non-PK columns should not be marked as primary key + non_pk_columns = [col for col in columns if col.name.lower() != "id"] + for col in non_pk_columns: + assert not col.is_primary_key, f"Column '{col.name}' should NOT be marked as primary key" + class BaseDatabaseTestsWithLimit(BaseDatabaseTests): """Base tests for databases that support LIMIT syntax.""" @@ -326,8 +398,10 @@ def test_query_limit(self, request, cli_runner): connection = request.getfixturevalue(self.config.connection_fixture) result = cli_runner( "query", - "-c", connection, - "-q", "SELECT * FROM test_users ORDER BY id LIMIT 2", + "-c", + connection, + "-q", + "SELECT * FROM test_users ORDER BY id LIMIT 2", ) assert result.returncode == 0 assert "Alice" in result.stdout diff --git a/tests/test_duckdb.py b/tests/test_duckdb.py index 7cab948e..e368a18e 100644 --- a/tests/test_duckdb.py +++ b/tests/test_duckdb.py @@ -2,8 +2,6 @@ from __future__ import annotations -import pytest - from .test_database_base import BaseDatabaseTestsWithLimit, DatabaseTestConfig @@ -22,8 +20,8 @@ def config(self) -> DatabaseTestConfig: connection_fixture="duckdb_connection", db_fixture="duckdb_db", create_connection_args=lambda db: [ - "--db-type", "duckdb", - "--file-path", str(db), + "--file-path", + str(db), ], ) @@ -34,10 +32,13 @@ def test_create_duckdb_connection(self, duckdb_db, cli_runner): try: # Create connection result = cli_runner( - "connection", "create", - "--name", connection_name, - "--db-type", "duckdb", - "--file-path", str(duckdb_db), + "connections", + "add", + "duckdb", + "--name", + connection_name, + "--file-path", + str(duckdb_db), ) assert result.returncode == 0 assert "created successfully" in result.stdout @@ -55,8 +56,10 @@ def test_query_duckdb_join(self, duckdb_connection, cli_runner): """Test JOIN query on DuckDB.""" result = cli_runner( "query", - "-c", duckdb_connection, - "-q", """ + "-c", + duckdb_connection, + "-q", + """ SELECT u.name, p.name as product, p.price FROM test_users u CROSS JOIN test_products p @@ -71,16 +74,20 @@ def test_query_duckdb_update(self, duckdb_connection, cli_runner): """Test UPDATE statement on DuckDB.""" result = cli_runner( "query", - "-c", duckdb_connection, - "-q", "UPDATE test_users SET name = 'Alicia' WHERE id = 1", + "-c", + duckdb_connection, + "-q", + "UPDATE test_users SET name = 'Alicia' WHERE id = 1", ) assert result.returncode == 0 # Verify the update result = cli_runner( "query", - "-c", duckdb_connection, - "-q", "SELECT name FROM test_users WHERE id = 1", + "-c", + duckdb_connection, + "-q", + "SELECT name FROM test_users WHERE id = 1", ) assert "Alicia" in result.stdout @@ -90,10 +97,13 @@ def test_delete_duckdb_connection(self, duckdb_db, cli_runner): # Create connection first cli_runner( - "connection", "create", - "--name", connection_name, - "--db-type", "duckdb", - "--file-path", str(duckdb_db), + "connections", + "add", + "duckdb", + "--name", + connection_name, + "--file-path", + str(duckdb_db), ) # Delete it @@ -109,8 +119,10 @@ def test_query_duckdb_invalid_query(self, duckdb_connection, cli_runner): """Test handling of invalid SQL query.""" result = cli_runner( "query", - "-c", duckdb_connection, - "-q", "SELECT * FROM nonexistent_table", + "-c", + duckdb_connection, + "-q", + "SELECT * FROM nonexistent_table", check=False, ) # Should fail gracefully diff --git a/tests/test_install_strategy.py b/tests/test_install_strategy.py new file mode 100644 index 00000000..5ba92ae2 --- /dev/null +++ b/tests/test_install_strategy.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from sqlit.install_strategy import detect_strategy + + +def test_detect_strategy_pipx_override(monkeypatch): + monkeypatch.setenv("SQLIT_MOCK_PIPX", "pipx") + strategy = detect_strategy(extra_name="postgres", package_name="psycopg2-binary") + assert strategy.kind == "pipx" + assert strategy.can_auto_install is True + assert strategy.auto_install_command == ["pipx", "inject", "sqlit-tui", "psycopg2-binary"] + + +def test_detect_strategy_externally_managed_disables_auto_install(monkeypatch, tmp_path): + marker_dir = tmp_path / "stdlib" + marker_dir.mkdir() + (marker_dir / "EXTERNALLY-MANAGED").write_text("managed", encoding="utf-8") + + monkeypatch.delenv("SQLIT_MOCK_PIPX", raising=False) + monkeypatch.setattr("sqlit.install_strategy._in_venv", lambda: False) + monkeypatch.setattr("sqlit.install_strategy.sysconfig.get_path", lambda _key: str(marker_dir)) + + strategy = detect_strategy(extra_name="postgres", package_name="psycopg2-binary") + assert strategy.kind == "externally-managed" + assert strategy.can_auto_install is False + assert "externally managed" in (strategy.reason_unavailable or "").lower() + assert "pipx inject" in strategy.manual_instructions + + +def test_detect_strategy_pip_user_fallback(monkeypatch): + monkeypatch.delenv("SQLIT_MOCK_PIPX", raising=False) + monkeypatch.setattr("sqlit.install_strategy._in_venv", lambda: False) + monkeypatch.setattr("sqlit.install_strategy._pep668_externally_managed", lambda: False) + monkeypatch.setattr("sqlit.install_strategy._pip_available", lambda: True) + monkeypatch.setattr("sqlit.install_strategy._install_paths_writable", lambda: False) + monkeypatch.setattr("sqlit.install_strategy._user_site_enabled", lambda: True) + + strategy = detect_strategy(extra_name="postgres", package_name="psycopg2-binary") + assert strategy.kind == "pip-user" + assert strategy.can_auto_install is True + assert "--user" in (strategy.auto_install_command or []) diff --git a/tests/test_installer_cancel.py b/tests/test_installer_cancel.py new file mode 100644 index 00000000..063f984b --- /dev/null +++ b/tests/test_installer_cancel.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +import subprocess +import threading +from unittest.mock import patch + +from sqlit.db.exceptions import MissingDriverError +from sqlit.services.installer import Installer + + +class _FakeProcess: + def __init__(self, started_event: threading.Event): + self.started_event = started_event + self.terminated = False + self.killed = False + self.returncode: int | None = None + self.started_event.set() + + def communicate(self, timeout: float | None = None): # noqa: ANN001 + if self.terminated or self.killed: + self.returncode = -15 if self.terminated else -9 + return "", "" + raise subprocess.TimeoutExpired(cmd="fake", timeout=timeout) + + def terminate(self) -> None: + self.terminated = True + + def kill(self) -> None: + self.killed = True + + def wait(self, timeout: float | None = None) -> int: # noqa: ARG002 + self.returncode = -15 if self.terminated else -9 if self.killed else 0 + return self.returncode + + +def test_installer_cancel_terminates_process(): + installer = Installer(app=object()) + error = MissingDriverError("PostgreSQL", "postgres", "psycopg2-binary") + cancel_event = threading.Event() + + started = threading.Event() + proc_holder: dict[str, _FakeProcess] = {} + + def fake_popen(*args, **kwargs): # noqa: ANN001,ARG001 + proc = _FakeProcess(started) + proc_holder["proc"] = proc + return proc + + with patch("sqlit.services.installer.subprocess.Popen", new=fake_popen): + result_holder: dict[str, tuple[bool, str, MissingDriverError]] = {} + + def run(): + result_holder["result"] = installer._do_install(error, cancel_event) + + thread = threading.Thread(target=run, daemon=True) + thread.start() + + assert started.wait(timeout=1) + cancel_event.set() + thread.join(timeout=5) + assert not thread.is_alive() + + success, output, _ = result_holder["result"] + assert success is False + assert "cancelled" in output.lower() + assert proc_holder["proc"].terminated is True diff --git a/tests/test_mariadb.py b/tests/test_mariadb.py index 63753e16..d5c7e5e3 100644 --- a/tests/test_mariadb.py +++ b/tests/test_mariadb.py @@ -2,8 +2,6 @@ from __future__ import annotations -import pytest - from .test_database_base import BaseDatabaseTestsWithLimit, DatabaseTestConfig @@ -26,21 +24,28 @@ def config(self) -> DatabaseTestConfig: def test_create_mariadb_connection(self, mariadb_db, cli_runner): """Test creating a MariaDB connection via CLI.""" - from .conftest import MARIADB_HOST, MARIADB_PORT, MARIADB_USER, MARIADB_PASSWORD + from .conftest import MARIADB_HOST, MARIADB_PASSWORD, MARIADB_PORT, MARIADB_USER connection_name = "test_create_mariadb" try: # Create connection result = cli_runner( - "connection", "create", - "--name", connection_name, - "--db-type", "mariadb", - "--server", MARIADB_HOST, - "--port", str(MARIADB_PORT), - "--database", mariadb_db, - "--username", MARIADB_USER, - "--password", MARIADB_PASSWORD, + "connections", + "add", + "mariadb", + "--name", + connection_name, + "--server", + MARIADB_HOST, + "--port", + str(MARIADB_PORT), + "--database", + mariadb_db, + "--username", + MARIADB_USER, + "--password", + MARIADB_PASSWORD, ) assert result.returncode == 0 assert "created successfully" in result.stdout @@ -56,20 +61,27 @@ def test_create_mariadb_connection(self, mariadb_db, cli_runner): def test_delete_mariadb_connection(self, mariadb_db, cli_runner): """Test deleting a MariaDB connection.""" - from .conftest import MARIADB_HOST, MARIADB_PORT, MARIADB_USER, MARIADB_PASSWORD + from .conftest import MARIADB_HOST, MARIADB_PASSWORD, MARIADB_PORT, MARIADB_USER connection_name = "test_delete_mariadb" # Create connection first cli_runner( - "connection", "create", - "--name", connection_name, - "--db-type", "mariadb", - "--server", MARIADB_HOST, - "--port", str(MARIADB_PORT), - "--database", mariadb_db, - "--username", MARIADB_USER, - "--password", MARIADB_PASSWORD, + "connections", + "add", + "mariadb", + "--name", + connection_name, + "--server", + MARIADB_HOST, + "--port", + str(MARIADB_PORT), + "--database", + mariadb_db, + "--username", + MARIADB_USER, + "--password", + MARIADB_PASSWORD, ) # Delete it diff --git a/tests/test_mocks.py b/tests/test_mocks.py index 6b50463f..a7f0e407 100644 --- a/tests/test_mocks.py +++ b/tests/test_mocks.py @@ -42,12 +42,12 @@ def test_unknown_table_returns_empty_columns(self): assert adapter.get_columns(MockConnection(), "nonexistent") == [] def test_execute_query_pattern_matching(self): - adapter = MockDatabaseAdapter(query_results={ - "users": (["id", "name"], [(1, "Alice"), (2, "Bob")]), - }) - cols, rows, truncated = adapter.execute_query( - MockConnection(), "SELECT * FROM users WHERE id = 1" + adapter = MockDatabaseAdapter( + query_results={ + "users": (["id", "name"], [(1, "Alice"), (2, "Bob")]), + } ) + cols, rows, truncated = adapter.execute_query(MockConnection(), "SELECT * FROM users WHERE id = 1") assert cols == ["id", "name"] assert len(rows) == 2 assert not truncated @@ -67,12 +67,12 @@ def test_execute_query_returns_default_for_unknown(self): assert rows == [("default",)] def test_execute_query_respects_max_rows(self): - adapter = MockDatabaseAdapter(query_results={ - "users": (["id"], [(1,), (2,), (3,), (4,), (5,)]), - }) - cols, rows, truncated = adapter.execute_query( - MockConnection(), "SELECT * FROM users", max_rows=2 + adapter = MockDatabaseAdapter( + query_results={ + "users": (["id"], [(1,), (2,), (3,), (4,), (5,)]), + } ) + cols, rows, truncated = adapter.execute_query(MockConnection(), "SELECT * FROM users", max_rows=2) assert len(rows) == 2 assert truncated is True @@ -136,19 +136,14 @@ class TestAdapterInterfaceCompliance: def test_method_signatures_match_base(self): import inspect + from sqlit.db.adapters.base import DatabaseAdapter for method_name in ["build_select_query", "execute_query"]: - base_params = list(inspect.signature( - getattr(DatabaseAdapter, method_name) - ).parameters.keys()) - mock_params = list(inspect.signature( - getattr(MockDatabaseAdapter, method_name) - ).parameters.keys()) - - assert base_params == mock_params, ( - f"{method_name}: {mock_params} != {base_params}" - ) + base_params = list(inspect.signature(getattr(DatabaseAdapter, method_name)).parameters.keys()) + mock_params = list(inspect.signature(getattr(MockDatabaseAdapter, method_name)).parameters.keys()) + + assert base_params == mock_params, f"{method_name}: {mock_params} != {base_params}" def test_all_abstract_methods_callable(self): adapter = MockDatabaseAdapter() diff --git a/tests/test_mssql.py b/tests/test_mssql.py index c80f9fb9..20b7596c 100644 --- a/tests/test_mssql.py +++ b/tests/test_mssql.py @@ -2,183 +2,102 @@ from __future__ import annotations -import json +from .test_database_base import BaseDatabaseTests, DatabaseTestConfig -import pytest - -class TestMSSQLIntegration: +class TestMSSQLIntegration(BaseDatabaseTests): """Integration tests for SQL Server database operations via CLI. These tests require a running SQL Server instance (via Docker). Tests are skipped if SQL Server is not available. """ + @property + def config(self) -> DatabaseTestConfig: + return DatabaseTestConfig( + db_type="mssql", + display_name="SQL Server", + connection_fixture="mssql_connection", + db_fixture="mssql_db", + create_connection_args=lambda: [], # Uses fixtures + uses_limit=False, # MSSQL uses TOP instead of LIMIT + ) + def test_create_mssql_connection(self, mssql_db, cli_runner): """Test creating a SQL Server connection via CLI.""" - from .conftest import MSSQL_HOST, MSSQL_PORT, MSSQL_USER, MSSQL_PASSWORD + from .conftest import MSSQL_HOST, MSSQL_PASSWORD, MSSQL_PORT, MSSQL_USER connection_name = "test_create_mssql" try: - # Create connection result = cli_runner( - "connection", "create", - "--name", connection_name, - "--db-type", "mssql", - "--server", f"{MSSQL_HOST},{MSSQL_PORT}" if MSSQL_PORT != 1433 else MSSQL_HOST, - "--database", mssql_db, - "--auth-type", "sql", - "--username", MSSQL_USER, - "--password", MSSQL_PASSWORD, + "connections", + "add", + "mssql", + "--name", + connection_name, + "--server", + f"{MSSQL_HOST},{MSSQL_PORT}" if MSSQL_PORT != 1433 else MSSQL_HOST, + "--database", + mssql_db, + "--auth-type", + "sql", + "--username", + MSSQL_USER, + "--password", + MSSQL_PASSWORD, ) assert result.returncode == 0 assert "created successfully" in result.stdout - # Verify it appears in list result = cli_runner("connection", "list") assert connection_name in result.stdout assert "SQL Server" in result.stdout finally: - # Cleanup cli_runner("connection", "delete", connection_name, check=False) - def test_list_connections_shows_mssql(self, mssql_connection, cli_runner): - """Test that connection list shows SQL Server connections correctly.""" - result = cli_runner("connection", "list") - assert result.returncode == 0 - assert mssql_connection in result.stdout - assert "SQL Server" in result.stdout - - def test_query_mssql_select(self, mssql_connection, cli_runner): - """Test executing SELECT query on SQL Server.""" - result = cli_runner( - "query", - "-c", mssql_connection, - "-q", "SELECT * FROM test_users ORDER BY id", - ) - assert result.returncode == 0 - assert "Alice" in result.stdout - assert "Bob" in result.stdout - assert "Charlie" in result.stdout - assert "3 row(s) returned" in result.stdout - - def test_query_mssql_with_where(self, mssql_connection, cli_runner): - """Test executing SELECT with WHERE clause on SQL Server.""" - result = cli_runner( - "query", - "-c", mssql_connection, - "-q", "SELECT name, email FROM test_users WHERE id = 1", - ) - assert result.returncode == 0 - assert "Alice" in result.stdout - assert "alice@example.com" in result.stdout - assert "1 row(s) returned" in result.stdout - def test_query_mssql_top(self, mssql_connection, cli_runner): - """Test SQL Server specific TOP clause.""" + """Test SQL Server specific TOP clause (MSSQL's equivalent of LIMIT).""" result = cli_runner( "query", - "-c", mssql_connection, - "-q", "SELECT TOP 2 * FROM test_users ORDER BY id", + "-c", + mssql_connection, + "-q", + "SELECT TOP 2 * FROM test_users ORDER BY id", ) assert result.returncode == 0 assert "Alice" in result.stdout assert "Bob" in result.stdout assert "2 row(s) returned" in result.stdout - def test_query_mssql_json_format(self, mssql_connection, cli_runner): - """Test query output in JSON format.""" - result = cli_runner( - "query", - "-c", mssql_connection, - "-q", "SELECT TOP 2 id, name FROM test_users ORDER BY id", - "--format", "json", - ) - assert result.returncode == 0 - - # Parse JSON output (row count message goes to stderr, not stdout) - data = json.loads(result.stdout) - - assert len(data) == 2 - assert data[0]["name"] == "Alice" - assert data[1]["name"] == "Bob" - - def test_query_mssql_csv_format(self, mssql_connection, cli_runner): - """Test query output in CSV format.""" - result = cli_runner( - "query", - "-c", mssql_connection, - "-q", "SELECT TOP 2 id, name FROM test_users ORDER BY id", - "--format", "csv", - ) - assert result.returncode == 0 - assert "id,name" in result.stdout - assert "1,Alice" in result.stdout - assert "2,Bob" in result.stdout - - def test_query_mssql_view(self, mssql_connection, cli_runner): - """Test querying a view on SQL Server.""" - result = cli_runner( - "query", - "-c", mssql_connection, - "-q", "SELECT * FROM test_user_emails ORDER BY id", - ) - assert result.returncode == 0 - assert "Alice" in result.stdout - assert "3 row(s) returned" in result.stdout - - def test_query_mssql_aggregate(self, mssql_connection, cli_runner): - """Test aggregate query on SQL Server.""" - result = cli_runner( - "query", - "-c", mssql_connection, - "-q", "SELECT COUNT(*) as user_count FROM test_users", - ) - assert result.returncode == 0 - assert "3" in result.stdout - - def test_query_mssql_insert(self, mssql_connection, cli_runner): - """Test INSERT statement on SQL Server.""" - result = cli_runner( - "query", - "-c", mssql_connection, - "-q", "INSERT INTO test_users (id, name, email) VALUES (4, 'David', 'david@example.com')", - ) - assert result.returncode == 0 - - # Verify the insert - result = cli_runner( - "query", - "-c", mssql_connection, - "-q", "SELECT * FROM test_users WHERE id = 4", - ) - assert "David" in result.stdout - def test_delete_mssql_connection(self, mssql_db, cli_runner): """Test deleting a SQL Server connection.""" - from .conftest import MSSQL_HOST, MSSQL_PORT, MSSQL_USER, MSSQL_PASSWORD + from .conftest import MSSQL_HOST, MSSQL_PASSWORD, MSSQL_PORT, MSSQL_USER connection_name = "test_delete_mssql" - # Create connection first cli_runner( - "connection", "create", - "--name", connection_name, - "--db-type", "mssql", - "--server", f"{MSSQL_HOST},{MSSQL_PORT}" if MSSQL_PORT != 1433 else MSSQL_HOST, - "--database", mssql_db, - "--auth-type", "sql", - "--username", MSSQL_USER, - "--password", MSSQL_PASSWORD, + "connections", + "add", + "mssql", + "--name", + connection_name, + "--server", + f"{MSSQL_HOST},{MSSQL_PORT}" if MSSQL_PORT != 1433 else MSSQL_HOST, + "--database", + mssql_db, + "--auth-type", + "sql", + "--username", + MSSQL_USER, + "--password", + MSSQL_PASSWORD, ) - # Delete it result = cli_runner("connection", "delete", connection_name) assert result.returncode == 0 assert "deleted successfully" in result.stdout - # Verify it's gone result = cli_runner("connection", "list") assert connection_name not in result.stdout diff --git a/tests/test_mysql.py b/tests/test_mysql.py index a59fafe8..ca061508 100644 --- a/tests/test_mysql.py +++ b/tests/test_mysql.py @@ -2,8 +2,6 @@ from __future__ import annotations -import pytest - from .test_database_base import BaseDatabaseTestsWithLimit, DatabaseTestConfig @@ -26,21 +24,28 @@ def config(self) -> DatabaseTestConfig: def test_create_mysql_connection(self, mysql_db, cli_runner): """Test creating a MySQL connection via CLI.""" - from .conftest import MYSQL_HOST, MYSQL_PORT, MYSQL_USER, MYSQL_PASSWORD + from .conftest import MYSQL_HOST, MYSQL_PASSWORD, MYSQL_PORT, MYSQL_USER connection_name = "test_create_mysql" try: # Create connection result = cli_runner( - "connection", "create", - "--name", connection_name, - "--db-type", "mysql", - "--server", MYSQL_HOST, - "--port", str(MYSQL_PORT), - "--database", mysql_db, - "--username", MYSQL_USER, - "--password", MYSQL_PASSWORD, + "connections", + "add", + "mysql", + "--name", + connection_name, + "--server", + MYSQL_HOST, + "--port", + str(MYSQL_PORT), + "--database", + mysql_db, + "--username", + MYSQL_USER, + "--password", + MYSQL_PASSWORD, ) assert result.returncode == 0 assert "created successfully" in result.stdout @@ -56,20 +61,27 @@ def test_create_mysql_connection(self, mysql_db, cli_runner): def test_delete_mysql_connection(self, mysql_db, cli_runner): """Test deleting a MySQL connection.""" - from .conftest import MYSQL_HOST, MYSQL_PORT, MYSQL_USER, MYSQL_PASSWORD + from .conftest import MYSQL_HOST, MYSQL_PASSWORD, MYSQL_PORT, MYSQL_USER connection_name = "test_delete_mysql" # Create connection first cli_runner( - "connection", "create", - "--name", connection_name, - "--db-type", "mysql", - "--server", MYSQL_HOST, - "--port", str(MYSQL_PORT), - "--database", mysql_db, - "--username", MYSQL_USER, - "--password", MYSQL_PASSWORD, + "connections", + "add", + "mysql", + "--name", + connection_name, + "--server", + MYSQL_HOST, + "--port", + str(MYSQL_PORT), + "--database", + mysql_db, + "--username", + MYSQL_USER, + "--password", + MYSQL_PASSWORD, ) # Delete it diff --git a/tests/test_oracle.py b/tests/test_oracle.py index 8d9c8279..c0e021ba 100644 --- a/tests/test_oracle.py +++ b/tests/test_oracle.py @@ -2,35 +2,51 @@ from __future__ import annotations -import json +from .test_database_base import BaseDatabaseTests, DatabaseTestConfig -import pytest - -class TestOracleIntegration: +class TestOracleIntegration(BaseDatabaseTests): """Integration tests for Oracle database operations via CLI. These tests require a running Oracle instance (via Docker). Tests are skipped if Oracle is not available. """ + @property + def config(self) -> DatabaseTestConfig: + return DatabaseTestConfig( + db_type="oracle", + display_name="Oracle", + connection_fixture="oracle_connection", + db_fixture="oracle_db", + create_connection_args=lambda: [], # Uses fixtures + uses_limit=False, # Oracle uses FETCH FIRST instead of LIMIT + ) + def test_create_oracle_connection(self, oracle_db, cli_runner): """Test creating an Oracle connection via CLI.""" - from .conftest import ORACLE_HOST, ORACLE_PORT, ORACLE_USER, ORACLE_PASSWORD + from .conftest import ORACLE_HOST, ORACLE_PASSWORD, ORACLE_PORT, ORACLE_USER connection_name = "test_create_oracle" try: # Create connection result = cli_runner( - "connection", "create", - "--name", connection_name, - "--db-type", "oracle", - "--server", ORACLE_HOST, - "--port", str(ORACLE_PORT), - "--database", oracle_db, - "--username", ORACLE_USER, - "--password", ORACLE_PASSWORD, + "connections", + "add", + "oracle", + "--name", + connection_name, + "--server", + ORACLE_HOST, + "--port", + str(ORACLE_PORT), + "--database", + oracle_db, + "--username", + ORACLE_USER, + "--password", + ORACLE_PASSWORD, ) assert result.returncode == 0 assert "created successfully" in result.stdout @@ -44,136 +60,43 @@ def test_create_oracle_connection(self, oracle_db, cli_runner): # Cleanup cli_runner("connection", "delete", connection_name, check=False) - def test_list_connections_shows_oracle(self, oracle_connection, cli_runner): - """Test that connection list shows Oracle connections correctly.""" - result = cli_runner("connection", "list") - assert result.returncode == 0 - assert oracle_connection in result.stdout - assert "Oracle" in result.stdout - - def test_query_oracle_select(self, oracle_connection, cli_runner): - """Test executing SELECT query on Oracle.""" - result = cli_runner( - "query", - "-c", oracle_connection, - "-q", "SELECT * FROM test_users ORDER BY id", - ) - assert result.returncode == 0 - assert "Alice" in result.stdout - assert "Bob" in result.stdout - assert "Charlie" in result.stdout - assert "3 row(s) returned" in result.stdout - - def test_query_oracle_with_where(self, oracle_connection, cli_runner): - """Test executing SELECT with WHERE clause on Oracle.""" - result = cli_runner( - "query", - "-c", oracle_connection, - "-q", "SELECT name, email FROM test_users WHERE id = 1", - ) - assert result.returncode == 0 - assert "Alice" in result.stdout - assert "alice@example.com" in result.stdout - assert "1 row(s) returned" in result.stdout - def test_query_oracle_fetch_first(self, oracle_connection, cli_runner): - """Test Oracle FETCH FIRST clause.""" + """Test Oracle FETCH FIRST clause (Oracle's equivalent of LIMIT).""" result = cli_runner( "query", - "-c", oracle_connection, - "-q", "SELECT * FROM test_users ORDER BY id FETCH FIRST 2 ROWS ONLY", + "-c", + oracle_connection, + "-q", + "SELECT * FROM test_users ORDER BY id FETCH FIRST 2 ROWS ONLY", ) assert result.returncode == 0 assert "Alice" in result.stdout assert "Bob" in result.stdout assert "2 row(s) returned" in result.stdout - def test_query_oracle_json_format(self, oracle_connection, cli_runner): - """Test query output in JSON format.""" - result = cli_runner( - "query", - "-c", oracle_connection, - "-q", "SELECT id, name FROM test_users ORDER BY id FETCH FIRST 2 ROWS ONLY", - "--format", "json", - ) - assert result.returncode == 0 - - # Parse JSON output (row count message goes to stderr, not stdout) - data = json.loads(result.stdout) - - assert len(data) == 2 - # Oracle returns column names in uppercase - assert data[0].get("name") == "Alice" or data[0].get("NAME") == "Alice" - assert data[1].get("name") == "Bob" or data[1].get("NAME") == "Bob" - - def test_query_oracle_csv_format(self, oracle_connection, cli_runner): - """Test query output in CSV format.""" - result = cli_runner( - "query", - "-c", oracle_connection, - "-q", "SELECT id, name FROM test_users ORDER BY id FETCH FIRST 2 ROWS ONLY", - "--format", "csv", - ) - assert result.returncode == 0 - # Oracle may return uppercase column names - assert "id,name" in result.stdout.lower() or "ID,NAME" in result.stdout - assert "Alice" in result.stdout - assert "Bob" in result.stdout - - def test_query_oracle_view(self, oracle_connection, cli_runner): - """Test querying a view on Oracle.""" - result = cli_runner( - "query", - "-c", oracle_connection, - "-q", "SELECT * FROM test_user_emails ORDER BY id", - ) - assert result.returncode == 0 - assert "Alice" in result.stdout - assert "3 row(s) returned" in result.stdout - - def test_query_oracle_aggregate(self, oracle_connection, cli_runner): - """Test aggregate query on Oracle.""" - result = cli_runner( - "query", - "-c", oracle_connection, - "-q", "SELECT COUNT(*) as user_count FROM test_users", - ) - assert result.returncode == 0 - assert "3" in result.stdout - - def test_query_oracle_insert(self, oracle_connection, cli_runner): - """Test INSERT statement on Oracle.""" - result = cli_runner( - "query", - "-c", oracle_connection, - "-q", "INSERT INTO test_users (id, name, email) VALUES (4, 'David', 'david@example.com')", - ) - assert result.returncode == 0 - - # Verify the insert - result = cli_runner( - "query", - "-c", oracle_connection, - "-q", "SELECT * FROM test_users WHERE id = 4", - ) - assert "David" in result.stdout - def test_delete_oracle_connection(self, oracle_db, cli_runner): """Test deleting an Oracle connection.""" - from .conftest import ORACLE_HOST, ORACLE_PORT, ORACLE_USER, ORACLE_PASSWORD + from .conftest import ORACLE_HOST, ORACLE_PASSWORD, ORACLE_PORT, ORACLE_USER connection_name = "test_delete_oracle" # Create connection first cli_runner( - "connection", "create", - "--name", connection_name, - "--db-type", "oracle", - "--server", ORACLE_HOST, - "--port", str(ORACLE_PORT), - "--database", oracle_db, - "--username", ORACLE_USER, - "--password", ORACLE_PASSWORD, + "connections", + "add", + "oracle", + "--name", + connection_name, + "--server", + ORACLE_HOST, + "--port", + str(ORACLE_PORT), + "--database", + oracle_db, + "--username", + ORACLE_USER, + "--password", + ORACLE_PASSWORD, ) # Delete it diff --git a/tests/test_password_prompts.py b/tests/test_password_prompts.py new file mode 100644 index 00000000..6a427832 --- /dev/null +++ b/tests/test_password_prompts.py @@ -0,0 +1,444 @@ +"""Tests for password prompt functionality.""" + +from __future__ import annotations + +from dataclasses import replace +from unittest.mock import MagicMock, patch + +import pytest + +from sqlit.commands import _prompt_for_password +from sqlit.config import ConnectionConfig +from sqlit.ui.mixins.connection import _needs_db_password, _needs_ssh_password + + +class TestNeedsDbPassword: + """Test _needs_db_password helper function.""" + + def test_file_based_database_does_not_need_password(self) -> None: + """SQLite and DuckDB don't need passwords.""" + sqlite_config = ConnectionConfig( + name="test", + db_type="sqlite", + file_path="/tmp/test.db", + password="", + ) + assert not _needs_db_password(sqlite_config) + + duckdb_config = ConnectionConfig( + name="test", + db_type="duckdb", + file_path="/tmp/test.duckdb", + password="", + ) + assert not _needs_db_password(duckdb_config) + + def test_server_database_with_none_password_needs_prompt(self) -> None: + """PostgreSQL/MySQL with None password (not set) needs prompt.""" + postgres_config = ConnectionConfig( + name="test", + db_type="postgresql", + server="localhost", + username="user", + password=None, + ) + assert _needs_db_password(postgres_config) + + mysql_config = ConnectionConfig( + name="test", + db_type="mysql", + server="localhost", + username="user", + password=None, + ) + assert _needs_db_password(mysql_config) + + def test_server_database_with_empty_password_no_prompt(self) -> None: + """PostgreSQL/MySQL with empty string password (explicitly empty) doesn't need prompt.""" + postgres_config = ConnectionConfig( + name="test", + db_type="postgresql", + server="localhost", + username="user", + password="", # Explicitly empty, valid for some DBs + ) + assert not _needs_db_password(postgres_config) + + mysql_config = ConnectionConfig( + name="test", + db_type="mysql", + server="localhost", + username="user", + password="", + ) + assert not _needs_db_password(mysql_config) + + def test_server_database_with_stored_password_does_not_need_prompt(self) -> None: + """Database with stored password doesn't need prompt.""" + config = ConnectionConfig( + name="test", + db_type="postgresql", + server="localhost", + username="user", + password="stored_password", + ) + assert not _needs_db_password(config) + + def test_mssql_with_none_password_needs_prompt(self) -> None: + """SQL Server with SQL auth and None password needs prompt.""" + config = ConnectionConfig( + name="test", + db_type="mssql", + server="localhost", + username="sa", + password=None, + auth_type="sql", + ) + assert _needs_db_password(config) + + def test_mssql_windows_auth_with_none_password(self) -> None: + """SQL Server with Windows auth and None password still technically needs prompt. + + Note: In practice Windows auth doesn't use the password field, + but the function returns True because password is None. + """ + config = ConnectionConfig( + name="test", + db_type="mssql", + server="localhost", + password=None, + auth_type="windows", + trusted_connection=True, + ) + assert _needs_db_password(config) + + def test_mssql_windows_auth_with_empty_password_no_prompt(self) -> None: + """SQL Server with Windows auth and empty string password doesn't need prompt.""" + config = ConnectionConfig( + name="test", + db_type="mssql", + server="localhost", + password="", # Explicitly empty + auth_type="windows", + trusted_connection=True, + ) + assert not _needs_db_password(config) + + +class TestNeedsSshPassword: + """Test _needs_ssh_password helper function.""" + + def test_ssh_disabled_does_not_need_password(self) -> None: + """Config without SSH doesn't need SSH password.""" + config = ConnectionConfig( + name="test", + db_type="postgresql", + server="localhost", + ssh_enabled=False, + ) + assert not _needs_ssh_password(config) + + def test_ssh_key_auth_does_not_need_password(self) -> None: + """SSH with key auth doesn't need password.""" + config = ConnectionConfig( + name="test", + db_type="postgresql", + server="localhost", + ssh_enabled=True, + ssh_auth_type="key", + ssh_key_path="~/.ssh/id_rsa", + ssh_password="", + ) + assert not _needs_ssh_password(config) + + def test_ssh_password_auth_with_none_password_needs_prompt(self) -> None: + """SSH with password auth and None password (not set) needs prompt.""" + config = ConnectionConfig( + name="test", + db_type="postgresql", + server="localhost", + ssh_enabled=True, + ssh_auth_type="password", + ssh_host="bastion.example.com", + ssh_username="user", + ssh_password=None, + ) + assert _needs_ssh_password(config) + + def test_ssh_password_auth_with_empty_password_no_prompt(self) -> None: + """SSH with password auth and empty string password (explicitly empty) no prompt.""" + config = ConnectionConfig( + name="test", + db_type="postgresql", + server="localhost", + ssh_enabled=True, + ssh_auth_type="password", + ssh_host="bastion.example.com", + ssh_username="user", + ssh_password="", # Explicitly empty + ) + assert not _needs_ssh_password(config) + + def test_ssh_password_auth_with_stored_password_does_not_need_prompt(self) -> None: + """SSH with stored password doesn't need prompt.""" + config = ConnectionConfig( + name="test", + db_type="postgresql", + server="localhost", + ssh_enabled=True, + ssh_auth_type="password", + ssh_host="bastion.example.com", + ssh_username="user", + ssh_password="stored_password", + ) + assert not _needs_ssh_password(config) + + +class TestCliPromptForPassword: + """Test CLI _prompt_for_password function.""" + + def test_file_based_no_prompt(self) -> None: + """File-based databases don't trigger password prompt.""" + config = ConnectionConfig( + name="test", + db_type="sqlite", + file_path="/tmp/test.db", + ) + + with patch("sqlit.commands.getpass.getpass") as mock_getpass: + result = _prompt_for_password(config) + mock_getpass.assert_not_called() + assert result == config + + @patch("sqlit.commands.getpass.getpass", return_value="test_password") + def test_database_password_prompt(self, mock_getpass: MagicMock) -> None: + """None database password triggers getpass prompt.""" + config = ConnectionConfig( + name="mydb", + db_type="postgresql", + server="localhost", + username="user", + password=None, + ) + + result = _prompt_for_password(config) + + mock_getpass.assert_called_once_with("Password for 'mydb': ") + assert result.password == "test_password" + assert result.name == "mydb" + assert result.server == "localhost" + + @patch("sqlit.commands.getpass.getpass") + def test_empty_password_no_prompt(self, mock_getpass: MagicMock) -> None: + """Empty string password (explicitly set) does not trigger prompt.""" + config = ConnectionConfig( + name="mydb", + db_type="postgresql", + server="localhost", + username="user", + password="", # Explicitly empty + ) + + result = _prompt_for_password(config) + + mock_getpass.assert_not_called() + assert result == config + assert result.password == "" + + @patch("sqlit.commands.getpass.getpass") + def test_stored_password_no_prompt(self, mock_getpass: MagicMock) -> None: + """Stored database password doesn't trigger prompt.""" + config = ConnectionConfig( + name="mydb", + db_type="postgresql", + server="localhost", + username="user", + password="stored_password", + ) + + result = _prompt_for_password(config) + + mock_getpass.assert_not_called() + assert result == config + assert result.password == "stored_password" + + @patch("sqlit.commands.getpass.getpass") + def test_ssh_password_prompt(self, mock_getpass: MagicMock) -> None: + """None SSH password triggers getpass prompt.""" + mock_getpass.side_effect = ["ssh_pass", "db_pass"] + + config = ConnectionConfig( + name="mydb", + db_type="postgresql", + server="localhost", + username="user", + password=None, + ssh_enabled=True, + ssh_auth_type="password", + ssh_host="bastion.example.com", + ssh_username="sshuser", + ssh_password=None, + ) + + result = _prompt_for_password(config) + + assert mock_getpass.call_count == 2 + mock_getpass.assert_any_call("SSH password for 'mydb': ") + mock_getpass.assert_any_call("Password for 'mydb': ") + assert result.ssh_password == "ssh_pass" + assert result.password == "db_pass" + + @patch("sqlit.commands.getpass.getpass", return_value="ssh_pass") + def test_ssh_password_only(self, mock_getpass: MagicMock) -> None: + """SSH password prompt without database password.""" + config = ConnectionConfig( + name="mydb", + db_type="postgresql", + server="localhost", + username="user", + password="stored_db_password", + ssh_enabled=True, + ssh_auth_type="password", + ssh_host="bastion.example.com", + ssh_username="sshuser", + ssh_password=None, + ) + + result = _prompt_for_password(config) + + mock_getpass.assert_called_once_with("SSH password for 'mydb': ") + assert result.ssh_password == "ssh_pass" + assert result.password == "stored_db_password" + + @patch("sqlit.commands.getpass.getpass", return_value="") + def test_user_enters_empty_password(self, mock_getpass: MagicMock) -> None: + """User can enter empty password (just press Enter) when prompted.""" + config = ConnectionConfig( + name="mydb", + db_type="postgresql", + server="localhost", + username="user", + password=None, # None triggers prompt + ) + + result = _prompt_for_password(config) + + mock_getpass.assert_called_once() + assert result.password == "" + + def test_original_config_not_modified(self) -> None: + """Original config object is not modified.""" + original = ConnectionConfig( + name="mydb", + db_type="postgresql", + server="localhost", + username="user", + password=None, + ) + + with patch("sqlit.commands.getpass.getpass", return_value="new_password"): + result = _prompt_for_password(original) + + # Original should still have None password + assert original.password is None + # Result should have the new password + assert result.password == "new_password" + # They should be different objects + assert result is not original + + +class TestPasswordPromptIntegration: + """Integration tests for the full password prompt flow.""" + + @patch("sqlit.commands.getpass.getpass", return_value="test123") + def test_cli_query_with_none_password(self, mock_getpass: MagicMock) -> None: + """CLI query command prompts for password when config has None password.""" + from sqlit.commands import cmd_query + from sqlit.config import save_connections + + # Create a test connection with None password (not set) + config = ConnectionConfig( + name="test_connection", + db_type="postgresql", + server="localhost", + port="5432", + database="testdb", + username="testuser", + password=None, # None = not set, will prompt + ) + + # Save it + save_connections([config]) + + # Mock arguments + args = MagicMock() + args.connection = "test_connection" + args.database = None + args.query = "SELECT 1" + args.file = None + args.format = "table" + args.limit = 1000 + + # Mock the session factory to avoid actual connection + def mock_session_factory(config): + raise Exception("Connection test - should have prompted for password") + + # This will fail at connection but should have prompted for password + try: + cmd_query(args, session_factory=mock_session_factory) + except Exception: + pass + + # Verify getpass was called + mock_getpass.assert_called_once_with("Password for 'test_connection': ") + + def test_config_with_both_passwords_none(self) -> None: + """Config with both DB and SSH passwords None needs both prompts.""" + config = ConnectionConfig( + name="mydb", + db_type="postgresql", + server="localhost", + username="user", + password=None, + ssh_enabled=True, + ssh_auth_type="password", + ssh_host="bastion.example.com", + ssh_username="sshuser", + ssh_password=None, + ) + + assert _needs_db_password(config) + assert _needs_ssh_password(config) + + with patch("sqlit.commands.getpass.getpass") as mock_getpass: + mock_getpass.side_effect = ["ssh_password", "db_password"] + result = _prompt_for_password(config) + + assert result.ssh_password == "ssh_password" + assert result.password == "db_password" + assert mock_getpass.call_count == 2 + + def test_config_with_both_passwords_empty_no_prompt(self) -> None: + """Config with both DB and SSH passwords empty (explicit) doesn't need prompts.""" + config = ConnectionConfig( + name="mydb", + db_type="postgresql", + server="localhost", + username="user", + password="", # Explicitly empty + ssh_enabled=True, + ssh_auth_type="password", + ssh_host="bastion.example.com", + ssh_username="sshuser", + ssh_password="", # Explicitly empty + ) + + assert not _needs_db_password(config) + assert not _needs_ssh_password(config) + + with patch("sqlit.commands.getpass.getpass") as mock_getpass: + result = _prompt_for_password(config) + + mock_getpass.assert_not_called() + assert result.password == "" + assert result.ssh_password == "" diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index ba22257b..3caba62a 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -2,8 +2,6 @@ from __future__ import annotations -import pytest - from .test_database_base import BaseDatabaseTestsWithLimit, DatabaseTestConfig @@ -26,21 +24,33 @@ def config(self) -> DatabaseTestConfig: def test_create_postgres_connection(self, postgres_db, cli_runner): """Test creating a PostgreSQL connection via CLI.""" - from .conftest import POSTGRES_HOST, POSTGRES_PORT, POSTGRES_USER, POSTGRES_PASSWORD + from .conftest import ( + POSTGRES_HOST, + POSTGRES_PASSWORD, + POSTGRES_PORT, + POSTGRES_USER, + ) connection_name = "test_create_postgres" try: # Create connection result = cli_runner( - "connection", "create", - "--name", connection_name, - "--db-type", "postgresql", - "--server", POSTGRES_HOST, - "--port", str(POSTGRES_PORT), - "--database", postgres_db, - "--username", POSTGRES_USER, - "--password", POSTGRES_PASSWORD, + "connections", + "add", + "postgresql", + "--name", + connection_name, + "--server", + POSTGRES_HOST, + "--port", + str(POSTGRES_PORT), + "--database", + postgres_db, + "--username", + POSTGRES_USER, + "--password", + POSTGRES_PASSWORD, ) assert result.returncode == 0 assert "created successfully" in result.stdout @@ -56,20 +66,32 @@ def test_create_postgres_connection(self, postgres_db, cli_runner): def test_delete_postgres_connection(self, postgres_db, cli_runner): """Test deleting a PostgreSQL connection.""" - from .conftest import POSTGRES_HOST, POSTGRES_PORT, POSTGRES_USER, POSTGRES_PASSWORD + from .conftest import ( + POSTGRES_HOST, + POSTGRES_PASSWORD, + POSTGRES_PORT, + POSTGRES_USER, + ) connection_name = "test_delete_postgres" # Create connection first cli_runner( - "connection", "create", - "--name", connection_name, - "--db-type", "postgresql", - "--server", POSTGRES_HOST, - "--port", str(POSTGRES_PORT), - "--database", postgres_db, - "--username", POSTGRES_USER, - "--password", POSTGRES_PASSWORD, + "connections", + "add", + "postgresql", + "--name", + connection_name, + "--server", + POSTGRES_HOST, + "--port", + str(POSTGRES_PORT), + "--database", + postgres_db, + "--username", + POSTGRES_USER, + "--password", + POSTGRES_PASSWORD, ) # Delete it diff --git a/tests/test_schema_capabilities.py b/tests/test_schema_capabilities.py index d93aa941..f65de4c2 100644 --- a/tests/test_schema_capabilities.py +++ b/tests/test_schema_capabilities.py @@ -3,6 +3,7 @@ from sqlit.db.schema import ( get_default_port, get_display_name, + get_supported_db_types, has_advanced_auth, is_file_based, supports_ssh, @@ -32,3 +33,29 @@ def test_unknown_type_returns_fallback(self): class TestGetDisplayName: def test_unknown_type_returns_input(self): assert get_display_name("nonexistent") == "nonexistent" + + +class TestCatalogConsistency: + def test_provider_schema_ids_match_keys(self): + from sqlit.db.providers import PROVIDERS, get_connection_schema + + for db_type, spec in PROVIDERS.items(): + schema = get_connection_schema(db_type) + assert schema.db_type == db_type + + def test_database_type_enum_matches_schema(self): + from sqlit.config import DatabaseType + + assert {t.value for t in DatabaseType} == set(get_supported_db_types()) + + def test_adapter_factory_matches_schema(self): + from sqlit.db.adapters import get_supported_adapter_db_types + + assert set(get_supported_adapter_db_types()) == set(get_supported_db_types()) + + def test_display_names_match_schema(self): + from sqlit.db.providers import get_connection_schema + + for db_type in get_supported_db_types(): + schema = get_connection_schema(db_type) + assert schema.display_name == get_display_name(db_type) diff --git a/tests/test_sqlite.py b/tests/test_sqlite.py index 1dcff7df..203f5670 100644 --- a/tests/test_sqlite.py +++ b/tests/test_sqlite.py @@ -2,10 +2,6 @@ from __future__ import annotations -import json - -import pytest - from .test_database_base import BaseDatabaseTestsWithLimit, DatabaseTestConfig @@ -20,8 +16,8 @@ def config(self) -> DatabaseTestConfig: connection_fixture="sqlite_connection", db_fixture="sqlite_db", create_connection_args=lambda db: [ - "--db-type", "sqlite", - "--file-path", str(db), + "--file-path", + str(db), ], ) @@ -32,10 +28,13 @@ def test_create_sqlite_connection(self, sqlite_db, cli_runner): try: # Create connection result = cli_runner( - "connection", "create", - "--name", connection_name, - "--db-type", "sqlite", - "--file-path", str(sqlite_db), + "connections", + "add", + "sqlite", + "--name", + connection_name, + "--file-path", + str(sqlite_db), ) assert result.returncode == 0 assert "created successfully" in result.stdout @@ -54,8 +53,10 @@ def test_query_sqlite_join(self, sqlite_connection, cli_runner): # This test verifies that complex queries work result = cli_runner( "query", - "-c", sqlite_connection, - "-q", """ + "-c", + sqlite_connection, + "-q", + """ SELECT u.name, p.name as product, p.price FROM test_users u CROSS JOIN test_products p @@ -70,16 +71,20 @@ def test_query_sqlite_update(self, sqlite_connection, cli_runner): """Test UPDATE statement on SQLite.""" result = cli_runner( "query", - "-c", sqlite_connection, - "-q", "UPDATE test_products SET stock = 200 WHERE id = 1", + "-c", + sqlite_connection, + "-q", + "UPDATE test_products SET stock = 200 WHERE id = 1", ) assert result.returncode == 0 # Verify the update result = cli_runner( "query", - "-c", sqlite_connection, - "-q", "SELECT stock FROM test_products WHERE id = 1", + "-c", + sqlite_connection, + "-q", + "SELECT stock FROM test_products WHERE id = 1", ) assert "200" in result.stdout @@ -89,10 +94,13 @@ def test_delete_sqlite_connection(self, sqlite_db, cli_runner): # Create connection first cli_runner( - "connection", "create", - "--name", connection_name, - "--db-type", "sqlite", - "--file-path", str(sqlite_db), + "connections", + "add", + "sqlite", + "--name", + connection_name, + "--file-path", + str(sqlite_db), ) # Delete it @@ -108,8 +116,10 @@ def test_query_sqlite_invalid_query(self, sqlite_connection, cli_runner): """Test handling of invalid SQL query.""" result = cli_runner( "query", - "-c", sqlite_connection, - "-q", "SELECT * FROM nonexistent_table", + "-c", + sqlite_connection, + "-q", + "SELECT * FROM nonexistent_table", check=False, ) assert result.returncode != 0 diff --git a/tests/test_ssh.py b/tests/test_ssh.py index 566d928d..dee5018f 100644 --- a/tests/test_ssh.py +++ b/tests/test_ssh.py @@ -24,8 +24,9 @@ def slow_down_ssh_tests(self): Note: SSH tests may fail when run together rapidly due to SSH server connection limits. Run individually with: pytest tests/test_ssh.py -k """ - time.sleep(1) # Wait before each test + time.sleep(2) # Wait before each test yield + time.sleep(1) # Wait after each test for connections to fully close def test_create_ssh_connection(self, ssh_postgres_db, cli_runner): """Test creating a PostgreSQL connection with SSH tunnel via CLI.""" @@ -45,20 +46,32 @@ def test_create_ssh_connection(self, ssh_postgres_db, cli_runner): try: # Create connection with SSH tunnel result = cli_runner( - "connection", "create", - "--name", connection_name, - "--db-type", "postgresql", - "--server", SSH_REMOTE_DB_HOST, - "--port", str(SSH_REMOTE_DB_PORT), - "--database", ssh_postgres_db, - "--username", POSTGRES_USER, - "--password", POSTGRES_PASSWORD, + "connections", + "add", + "postgresql", + "--name", + connection_name, + "--server", + SSH_REMOTE_DB_HOST, + "--port", + str(SSH_REMOTE_DB_PORT), + "--database", + ssh_postgres_db, + "--username", + POSTGRES_USER, + "--password", + POSTGRES_PASSWORD, "--ssh-enabled", - "--ssh-host", SSH_HOST, - "--ssh-port", str(SSH_PORT), - "--ssh-username", SSH_USER, - "--ssh-auth-type", "password", - "--ssh-password", SSH_PASSWORD, + "--ssh-host", + SSH_HOST, + "--ssh-port", + str(SSH_PORT), + "--ssh-username", + SSH_USER, + "--ssh-auth-type", + "password", + "--ssh-password", + SSH_PASSWORD, ) assert result.returncode == 0 assert "created successfully" in result.stdout @@ -76,8 +89,10 @@ def test_query_via_ssh_tunnel(self, ssh_connection, cli_runner): """Test executing SELECT query through SSH tunnel.""" result = cli_runner( "query", - "-c", ssh_connection, - "-q", "SELECT * FROM test_users ORDER BY id", + "-c", + ssh_connection, + "-q", + "SELECT * FROM test_users ORDER BY id", ) assert result.returncode == 0 assert "Alice" in result.stdout @@ -89,8 +104,10 @@ def test_query_with_where_via_ssh(self, ssh_connection, cli_runner): """Test executing SELECT with WHERE clause through SSH tunnel.""" result = cli_runner( "query", - "-c", ssh_connection, - "-q", "SELECT name, email FROM test_users WHERE id = 1", + "-c", + ssh_connection, + "-q", + "SELECT name, email FROM test_users WHERE id = 1", ) assert result.returncode == 0 assert "Alice" in result.stdout @@ -101,9 +118,12 @@ def test_query_json_format_via_ssh(self, ssh_connection, cli_runner): """Test query output in JSON format through SSH tunnel.""" result = cli_runner( "query", - "-c", ssh_connection, - "-q", "SELECT id, name FROM test_users ORDER BY id LIMIT 2", - "--format", "json", + "-c", + ssh_connection, + "-q", + "SELECT id, name FROM test_users ORDER BY id LIMIT 2", + "--format", + "json", ) assert result.returncode == 0 @@ -118,9 +138,12 @@ def test_query_csv_format_via_ssh(self, ssh_connection, cli_runner): """Test query output in CSV format through SSH tunnel.""" result = cli_runner( "query", - "-c", ssh_connection, - "-q", "SELECT id, name FROM test_users ORDER BY id LIMIT 2", - "--format", "csv", + "-c", + ssh_connection, + "-q", + "SELECT id, name FROM test_users ORDER BY id LIMIT 2", + "--format", + "csv", ) assert result.returncode == 0 assert "id,name" in result.stdout @@ -131,8 +154,10 @@ def test_query_aggregate_via_ssh(self, ssh_connection, cli_runner): """Test aggregate query through SSH tunnel.""" result = cli_runner( "query", - "-c", ssh_connection, - "-q", "SELECT COUNT(*) as user_count FROM test_users", + "-c", + ssh_connection, + "-q", + "SELECT COUNT(*) as user_count FROM test_users", ) assert result.returncode == 0 assert "3" in result.stdout @@ -141,16 +166,20 @@ def test_insert_via_ssh(self, ssh_connection, cli_runner): """Test INSERT statement through SSH tunnel.""" result = cli_runner( "query", - "-c", ssh_connection, - "-q", "INSERT INTO test_users (id, name, email) VALUES (4, 'David', 'david@example.com')", + "-c", + ssh_connection, + "-q", + "INSERT INTO test_users (id, name, email) VALUES (4, 'David', 'david@example.com')", ) assert result.returncode == 0 # Verify the insert result = cli_runner( "query", - "-c", ssh_connection, - "-q", "SELECT * FROM test_users WHERE id = 4", + "-c", + ssh_connection, + "-q", + "SELECT * FROM test_users WHERE id = 4", ) assert "David" in result.stdout @@ -171,20 +200,32 @@ def test_delete_ssh_connection(self, ssh_postgres_db, cli_runner): # Create connection first cli_runner( - "connection", "create", - "--name", connection_name, - "--db-type", "postgresql", - "--server", SSH_REMOTE_DB_HOST, - "--port", str(SSH_REMOTE_DB_PORT), - "--database", ssh_postgres_db, - "--username", POSTGRES_USER, - "--password", POSTGRES_PASSWORD, + "connections", + "add", + "postgresql", + "--name", + connection_name, + "--server", + SSH_REMOTE_DB_HOST, + "--port", + str(SSH_REMOTE_DB_PORT), + "--database", + ssh_postgres_db, + "--username", + POSTGRES_USER, + "--password", + POSTGRES_PASSWORD, "--ssh-enabled", - "--ssh-host", SSH_HOST, - "--ssh-port", str(SSH_PORT), - "--ssh-username", SSH_USER, - "--ssh-auth-type", "password", - "--ssh-password", SSH_PASSWORD, + "--ssh-host", + SSH_HOST, + "--ssh-port", + str(SSH_PORT), + "--ssh-username", + SSH_USER, + "--ssh-auth-type", + "password", + "--ssh-password", + SSH_PASSWORD, ) # Delete it diff --git a/tests/test_turso.py b/tests/test_turso.py index f1c7541e..e585def6 100644 --- a/tests/test_turso.py +++ b/tests/test_turso.py @@ -2,8 +2,6 @@ from __future__ import annotations -import pytest - from .test_database_base import BaseDatabaseTestsWithLimit, DatabaseTestConfig @@ -29,32 +27,35 @@ def test_create_turso_connection(self, turso_db, cli_runner): connection_name = "test_create_turso" try: - # Create connection result = cli_runner( - "connection", "create", - "--name", connection_name, - "--db-type", "turso", - "--server", turso_db, - "--password", "", + "connections", + "add", + "turso", + "--name", + connection_name, + "--server", + turso_db, + "--password", + "", ) assert result.returncode == 0 assert "created successfully" in result.stdout - # Verify it appears in list result = cli_runner("connection", "list") assert connection_name in result.stdout assert "Turso" in result.stdout finally: - # Cleanup cli_runner("connection", "delete", connection_name, check=False) def test_query_turso_join(self, turso_connection, cli_runner): """Test JOIN query on Turso.""" result = cli_runner( "query", - "-c", turso_connection, - "-q", """ + "-c", + turso_connection, + "-q", + """ SELECT u.name, p.name as product, p.price FROM test_users u CROSS JOIN test_products p @@ -69,16 +70,19 @@ def test_query_turso_update(self, turso_connection, cli_runner): """Test UPDATE statement on Turso.""" result = cli_runner( "query", - "-c", turso_connection, - "-q", "UPDATE test_products SET stock = 200 WHERE id = 1", + "-c", + turso_connection, + "-q", + "UPDATE test_products SET stock = 200 WHERE id = 1", ) assert result.returncode == 0 - # Verify the update result = cli_runner( "query", - "-c", turso_connection, - "-q", "SELECT stock FROM test_products WHERE id = 1", + "-c", + turso_connection, + "-q", + "SELECT stock FROM test_products WHERE id = 1", ) assert "200" in result.stdout @@ -86,21 +90,22 @@ def test_delete_turso_connection(self, turso_db, cli_runner): """Test deleting a Turso connection.""" connection_name = "test_delete_turso" - # Create connection first cli_runner( - "connection", "create", - "--name", connection_name, - "--db-type", "turso", - "--server", turso_db, - "--password", "", + "connections", + "add", + "turso", + "--name", + connection_name, + "--server", + turso_db, + "--password", + "", ) - # Delete it result = cli_runner("connection", "delete", connection_name) assert result.returncode == 0 assert "deleted successfully" in result.stdout - # Verify it's gone result = cli_runner("connection", "list") assert connection_name not in result.stdout @@ -108,8 +113,10 @@ def test_query_turso_invalid_query(self, turso_connection, cli_runner): """Test handling of invalid SQL query.""" result = cli_runner( "query", - "-c", turso_connection, - "-q", "SELECT * FROM nonexistent_table", + "-c", + turso_connection, + "-q", + "SELECT * FROM nonexistent_table", check=False, ) assert result.returncode != 0 diff --git a/tests/ui/conftest.py b/tests/ui/conftest.py index c9eb17fb..37da4e93 100644 --- a/tests/ui/conftest.py +++ b/tests/ui/conftest.py @@ -89,10 +89,12 @@ def mock_failing_adapter(): @pytest.fixture def patch_stores(mock_connection_store, mock_settings_store): """Patch all stores with mocks for isolated testing.""" - with patch("sqlit.config.load_connections", mock_connection_store.load_all), \ - patch("sqlit.config.save_connections", mock_connection_store.save_all), \ - patch("sqlit.config.load_settings", mock_settings_store.load_all), \ - patch("sqlit.config.save_settings", mock_settings_store.save_all): + with ( + patch("sqlit.config.load_connections", mock_connection_store.load_all), + patch("sqlit.config.save_connections", mock_connection_store.save_all), + patch("sqlit.config.load_settings", mock_settings_store.load_all), + patch("sqlit.config.save_settings", mock_settings_store.save_all), + ): yield { "connections": mock_connection_store, "settings": mock_settings_store, @@ -102,10 +104,12 @@ def patch_stores(mock_connection_store, mock_settings_store): @pytest.fixture def patch_stores_with_data(mock_connection_store_with_data, mock_settings_store): """Patch stores with sample data.""" - with patch("sqlit.config.load_connections", mock_connection_store_with_data.load_all), \ - patch("sqlit.config.save_connections", mock_connection_store_with_data.save_all), \ - patch("sqlit.config.load_settings", mock_settings_store.load_all), \ - patch("sqlit.config.save_settings", mock_settings_store.save_all): + with ( + patch("sqlit.config.load_connections", mock_connection_store_with_data.load_all), + patch("sqlit.config.save_connections", mock_connection_store_with_data.save_all), + patch("sqlit.config.load_settings", mock_settings_store.load_all), + patch("sqlit.config.save_settings", mock_settings_store.save_all), + ): yield { "connections": mock_connection_store_with_data, "settings": mock_settings_store, diff --git a/tests/ui/explorer/test_markup_escaping.py b/tests/ui/explorer/test_markup_escaping.py index e790a937..98e0655e 100644 --- a/tests/ui/explorer/test_markup_escaping.py +++ b/tests/ui/explorer/test_markup_escaping.py @@ -8,10 +8,8 @@ from dataclasses import dataclass -import pytest from rich.markup import escape as escape_markup -from sqlit.db.adapters.base import ColumnInfo from sqlit.ui.mixins.tree import TreeMixin from sqlit.ui.tree_nodes import SchemaNode, TableNode @@ -25,7 +23,7 @@ class MockColumnInfo: class MockTreeNode: """Mock tree node that tracks added children and their labels.""" - def __init__(self, label: str = "", data: tuple = None, parent: "MockTreeNode | None" = None): + def __init__(self, label: str = "", data: tuple = None, parent: MockTreeNode | None = None): self.label = label self.data = data self.parent = parent @@ -33,13 +31,13 @@ def __init__(self, label: str = "", data: tuple = None, parent: "MockTreeNode | self.allow_expand = False self._labels_added: list[str] = [] - def add(self, label: str) -> "MockTreeNode": + def add(self, label: str) -> MockTreeNode: self._labels_added.append(label) child = MockTreeNode(label, parent=self) self.children.append(child) return child - def add_leaf(self, label: str) -> "MockTreeNode": + def add_leaf(self, label: str) -> MockTreeNode: self._labels_added.append(label) child = MockTreeNode(label, parent=self) self.children.append(child) @@ -246,8 +244,8 @@ def test_procedure_name_with_brackets_escaped(self): parent = MockTreeNode("Procedures", ("folder", "procedures", "db")) items = [ - ("procedure", "sp_get_data[v2]"), - ("procedure", "proc[test]/run"), + ("procedure", "", "sp_get_data[v2]"), + ("procedure", "", "proc[test]/run"), ] mixin._on_folder_loaded(parent, "db", "procedures", items) diff --git a/tests/ui/explorer/test_tree_expansion.py b/tests/ui/explorer/test_tree_expansion.py index 76565614..dc4b66b0 100644 --- a/tests/ui/explorer/test_tree_expansion.py +++ b/tests/ui/explorer/test_tree_expansion.py @@ -4,8 +4,6 @@ from unittest.mock import MagicMock -import pytest - from sqlit.ui.mixins.tree import TreeMixin from sqlit.ui.tree_nodes import FolderNode, LoadingNode, SchemaNode, TableNode @@ -13,7 +11,7 @@ class MockTreeNode: """Mock tree node for testing expansion.""" - def __init__(self, label: str = "", data: tuple = None, parent: "MockTreeNode | None" = None): + def __init__(self, label: str = "", data: tuple = None, parent: MockTreeNode | None = None): self.label = label self.data = data self.parent = parent @@ -21,12 +19,12 @@ def __init__(self, label: str = "", data: tuple = None, parent: "MockTreeNode | self.allow_expand = False self.is_expanded = False - def add(self, label: str) -> "MockTreeNode": + def add(self, label: str) -> MockTreeNode: child = MockTreeNode(label, parent=self) self.children.append(child) return child - def add_leaf(self, label: str) -> "MockTreeNode": + def add_leaf(self, label: str) -> MockTreeNode: return self.add(label) def remove(self): diff --git a/tests/ui/keybindings/conftest.py b/tests/ui/keybindings/conftest.py index fdaf8908..b49c54b1 100644 --- a/tests/ui/keybindings/conftest.py +++ b/tests/ui/keybindings/conftest.py @@ -5,9 +5,9 @@ import pytest from sqlit.keymap import ( + ActionKeyDef, KeymapProvider, LeaderCommandDef, - ActionKeyDef, reset_keymap, ) diff --git a/tests/ui/keybindings/test_contextual.py b/tests/ui/keybindings/test_contextual.py index b1c98435..22c6bc92 100644 --- a/tests/ui/keybindings/test_contextual.py +++ b/tests/ui/keybindings/test_contextual.py @@ -65,10 +65,11 @@ async def test_edit_connection_blocked_when_query_focused(self): mock_connections = MockConnectionStore(connections) mock_settings = MockSettingsStore({"theme": "tokyo-night"}) - with patch("sqlit.app.load_connections", mock_connections.load_all), \ - patch("sqlit.app.load_settings", mock_settings.load_all), \ - patch("sqlit.app.save_settings", mock_settings.save_all): - + with ( + patch("sqlit.app.load_connections", mock_connections.load_all), + patch("sqlit.app.load_settings", mock_settings.load_all), + patch("sqlit.app.save_settings", mock_settings.save_all), + ): app = SSMSTUI() async with app.run_test(size=(100, 35)) as pilot: diff --git a/tests/ui/keybindings/test_dialogs.py b/tests/ui/keybindings/test_dialogs.py index 43581578..e4aea1a1 100644 --- a/tests/ui/keybindings/test_dialogs.py +++ b/tests/ui/keybindings/test_dialogs.py @@ -31,9 +31,7 @@ async def test_normal_actions_blocked_when_error_dialog_open(self): await pilot.pause() # Verify dialog is shown - has_error = any( - isinstance(screen, ErrorScreen) for screen in app.screen_stack - ) + has_error = any(isinstance(screen, ErrorScreen) for screen in app.screen_stack) assert has_error # Try to focus query - should be blocked by modal @@ -91,7 +89,6 @@ async def test_error_dialog_copy_action(self): copied_text = {"value": None} async with app.run_test(size=(100, 35)) as pilot: - original_copy = app.copy_to_clipboard def mock_copy(text): copied_text["value"] = text @@ -162,9 +159,12 @@ def capture_result(result): @pytest.mark.asyncio async def test_confirm_dialog_escape_cancels(self): - """Confirm dialog escape key should cancel and close.""" + """Confirm dialog escape key should cancel and close. + + Note: Escape returns None (cancelled) vs False (explicit No). + """ app = SSMSTUI() - result_holder = {"result": None} + result_holder = {"result": "not_called"} def capture_result(result): result_holder["result"] = result @@ -178,7 +178,7 @@ def capture_result(result): await pilot.pause() assert not any(isinstance(screen, ConfirmScreen) for screen in app.screen_stack) - assert result_holder["result"] is False + assert result_holder["result"] is None # Escape returns None (cancelled) @pytest.mark.asyncio async def test_help_dialog_blocks_normal_actions(self): @@ -249,9 +249,7 @@ async def test_leader_commands_blocked_when_dialog_open(self): await pilot.pause() # Only the error dialog should be open, not theme picker - error_count = sum( - 1 for screen in app.screen_stack if isinstance(screen, ErrorScreen) - ) + error_count = sum(1 for screen in app.screen_stack if isinstance(screen, ErrorScreen)) assert error_count == 1 # Screen stack should just have main screen + error dialog assert len(app.screen_stack) == 2 diff --git a/tests/ui/keybindings/test_keymap_provider.py b/tests/ui/keybindings/test_keymap_provider.py index 0a6335c2..2304b284 100644 --- a/tests/ui/keybindings/test_keymap_provider.py +++ b/tests/ui/keybindings/test_keymap_provider.py @@ -3,11 +3,11 @@ from __future__ import annotations from sqlit.keymap import ( - LeaderCommandDef, ActionKeyDef, + LeaderCommandDef, get_keymap, - set_keymap, reset_keymap, + set_keymap, ) from sqlit.state_machine import get_leader_commands diff --git a/tests/ui/keybindings/test_leader.py b/tests/ui/keybindings/test_leader.py index cc4cfc8a..f0e2b5c1 100644 --- a/tests/ui/keybindings/test_leader.py +++ b/tests/ui/keybindings/test_leader.py @@ -27,9 +27,7 @@ async def test_leader_show_connection_picker(self): await pilot.press(leader_key, connection_key) await pilot.pause() - has_picker = any( - isinstance(screen, ConnectionPickerScreen) for screen in app.screen_stack - ) + has_picker = any(isinstance(screen, ConnectionPickerScreen) for screen in app.screen_stack) assert has_picker @pytest.mark.asyncio @@ -47,9 +45,7 @@ async def test_leader_show_help(self): await pilot.press(leader_key, help_key) await pilot.pause() - has_help = any( - isinstance(screen, HelpScreen) for screen in app.screen_stack - ) + has_help = any(isinstance(screen, HelpScreen) for screen in app.screen_stack) assert has_help @pytest.mark.asyncio @@ -65,7 +61,5 @@ async def test_leader_commands_blocked_without_leader_key(self): await pilot.press(connection_key) await pilot.pause() - has_picker = any( - isinstance(screen, ConnectionPickerScreen) for screen in app.screen_stack - ) + has_picker = any(isinstance(screen, ConnectionPickerScreen) for screen in app.screen_stack) assert not has_picker diff --git a/tests/ui/keybindings/test_state_machine.py b/tests/ui/keybindings/test_state_machine.py index 70569c8b..b4f58dbc 100644 --- a/tests/ui/keybindings/test_state_machine.py +++ b/tests/ui/keybindings/test_state_machine.py @@ -2,44 +2,89 @@ from __future__ import annotations -from sqlit.state_machine import UIStateMachine +from sqlit.state_machine import QueryExecutingState, UIStateMachine from sqlit.ui.tree_nodes import ConnectionNode from sqlit.widgets import VimMode -class TestStateMachineActionValidation: - """Test that the state machine correctly validates actions.""" +class MockConfig: + name = "test-conn" + + +class MockNode: + def __init__(self, data=None): + self.data = data + + +class MockWidget: + has_focus = False + cursor_node = None + root = MockNode() - def test_edit_connection_only_allowed_on_connection_node(self): - """edit_connection should only be allowed when tree is on a connection.""" - class MockConfig: - name = "test-conn" +class MockApp: + def __init__(self): + self._leader_pending = False + self._query_executing = False + self._autocomplete_visible = False + self.current_connection = None + self.current_config = None + self.screen_stack = [None] + self._vim_mode = VimMode.NORMAL - class MockNode: - def __init__(self, data=None): - self.data = data + object_tree = MockWidget() + query_input = MockWidget() + results_table = MockWidget() - class MockWidget: - has_focus = False - cursor_node = None - root = MockNode() + @property + def vim_mode(self): + return self._vim_mode - class MockApp: - def __init__(self): - self._leader_pending = False - self.current_connection = None - self.current_config = None - self.screen_stack = [None] - object_tree = MockWidget() - query_input = MockWidget() - results_table = MockWidget() +class TestQueryExecutingState: + """Test that cancel_operation is only allowed when query is executing.""" + + def test_cancel_not_allowed_when_idle(self): + """cancel_operation should be blocked when no query is running.""" + sm = UIStateMachine() + app = MockApp() + app._query_executing = False + + assert sm.check_action(app, "cancel_operation") is False + + def test_cancel_allowed_when_query_executing(self): + """cancel_operation should be allowed when a query is running.""" + sm = UIStateMachine() + app = MockApp() + app._query_executing = True - @property - def vim_mode(self): - return VimMode.NORMAL + assert sm.check_action(app, "cancel_operation") is True + def test_active_state_is_query_executing_when_running(self): + """Active state should be QueryExecutingState when query is running.""" + sm = UIStateMachine() + app = MockApp() + app._query_executing = True + + state = sm.get_active_state(app) + assert isinstance(state, QueryExecutingState) + + def test_footer_shows_cancel_when_executing(self): + """Footer should show cancel binding when query is executing.""" + sm = UIStateMachine() + app = MockApp() + app._query_executing = True + + left, right = sm.get_display_bindings(app) + actions = [b.action for b in left] + assert "cancel_operation" in actions + + +class TestStateMachineActionValidation: + """Test that the state machine correctly validates actions.""" + + def test_edit_connection_only_allowed_on_connection_node(self): + """edit_connection should only be allowed when tree is on a connection.""" sm = UIStateMachine() app = MockApp() diff --git a/tests/ui/mocks.py b/tests/ui/mocks.py index bf123bd7..6f8fe5c2 100644 --- a/tests/ui/mocks.py +++ b/tests/ui/mocks.py @@ -8,7 +8,7 @@ from dataclasses import dataclass, field from typing import Any -from sqlit.config import ConnectionConfig, DatabaseType +from sqlit.config import ConnectionConfig from sqlit.db.adapters.base import ColumnInfo, DatabaseAdapter @@ -20,7 +20,7 @@ def __init__(self, connections: list[ConnectionConfig] | None = None): self.save_called = False self.last_saved: list[ConnectionConfig] = [] - def load_all(self) -> list[ConnectionConfig]: + def load_all(self, load_credentials: bool = True) -> list[ConnectionConfig]: return self.connections.copy() def save_all(self, connections: list[ConnectionConfig]) -> None: @@ -147,7 +147,7 @@ def supports_multiple_databases(self) -> bool: def supports_stored_procedures(self) -> bool: return False - def connect(self, config: "ConnectionConfig") -> Any: + def connect(self, config: ConnectionConfig) -> Any: if self._should_fail_connect: raise ConnectionError(self._connect_error) self._connected = True @@ -173,9 +173,7 @@ def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: def quote_identifier(self, name: str) -> str: return f'"{name}"' - def build_select_query( - self, table: str, limit: int, database: str | None = None - ) -> str: + def build_select_query(self, table: str, limit: int, database: str | None = None) -> str: return f'SELECT * FROM "{table}" LIMIT {limit}' def execute_query(self, conn: Any, query: str) -> tuple[list[str], list[tuple]]: diff --git a/tests/ui/snapshots/01_default_sqlite_screen.svg b/tests/ui/snapshots/01_default_sqlite_screen.svg index c961da05..d1cea7bc 100644 --- a/tests/ui/snapshots/01_default_sqlite_screen.svg +++ b/tests/ui/snapshots/01_default_sqlite_screen.svg @@ -159,7 +159,7 @@ - + diff --git a/tests/ui/snapshots/02_edit_mode_postgresql.svg b/tests/ui/snapshots/02_edit_mode_postgresql.svg index a051d923..fae54699 100644 --- a/tests/ui/snapshots/02_edit_mode_postgresql.svg +++ b/tests/ui/snapshots/02_edit_mode_postgresql.svg @@ -159,7 +159,7 @@ - + diff --git a/tests/ui/snapshots/03_before_cancel.svg b/tests/ui/snapshots/03_before_cancel.svg index c961da05..d1cea7bc 100644 --- a/tests/ui/snapshots/03_before_cancel.svg +++ b/tests/ui/snapshots/03_before_cancel.svg @@ -159,7 +159,7 @@ - + diff --git a/tests/ui/snapshots/04_empty_name_after_save_attempt.svg b/tests/ui/snapshots/04_empty_name_after_save_attempt.svg index cce2cc15..d7fd2bae 100644 --- a/tests/ui/snapshots/04_empty_name_after_save_attempt.svg +++ b/tests/ui/snapshots/04_empty_name_after_save_attempt.svg @@ -160,7 +160,7 @@ - + diff --git a/tests/ui/snapshots/04_empty_name_before_save.svg b/tests/ui/snapshots/04_empty_name_before_save.svg index c961da05..d1cea7bc 100644 --- a/tests/ui/snapshots/04_empty_name_before_save.svg +++ b/tests/ui/snapshots/04_empty_name_before_save.svg @@ -159,7 +159,7 @@ - + diff --git a/tests/ui/snapshots/05_after_switch_to_postgresql.svg b/tests/ui/snapshots/05_after_switch_to_postgresql.svg index 41cc1ffa..58df3ebf 100644 --- a/tests/ui/snapshots/05_after_switch_to_postgresql.svg +++ b/tests/ui/snapshots/05_after_switch_to_postgresql.svg @@ -160,7 +160,7 @@ - + diff --git a/tests/ui/snapshots/05_before_switch_to_postgresql.svg b/tests/ui/snapshots/05_before_switch_to_postgresql.svg index c961da05..d1cea7bc 100644 --- a/tests/ui/snapshots/05_before_switch_to_postgresql.svg +++ b/tests/ui/snapshots/05_before_switch_to_postgresql.svg @@ -159,7 +159,7 @@ - + diff --git a/tests/ui/snapshots/06_sqlite_file_path_field.svg b/tests/ui/snapshots/06_sqlite_file_path_field.svg index c961da05..d1cea7bc 100644 --- a/tests/ui/snapshots/06_sqlite_file_path_field.svg +++ b/tests/ui/snapshots/06_sqlite_file_path_field.svg @@ -159,7 +159,7 @@ - + diff --git a/tests/ui/snapshots/07_sqlite_filled_form.svg b/tests/ui/snapshots/07_sqlite_filled_form.svg index ff1179c6..ddc6aadb 100644 --- a/tests/ui/snapshots/07_sqlite_filled_form.svg +++ b/tests/ui/snapshots/07_sqlite_filled_form.svg @@ -159,7 +159,7 @@ - + diff --git a/tests/ui/snapshots/08_postgresql_after_switch.svg b/tests/ui/snapshots/08_postgresql_after_switch.svg index 4ecc40ac..82b1ace1 100644 --- a/tests/ui/snapshots/08_postgresql_after_switch.svg +++ b/tests/ui/snapshots/08_postgresql_after_switch.svg @@ -161,7 +161,7 @@ - + diff --git a/tests/ui/snapshots/08_postgresql_filled_form.svg b/tests/ui/snapshots/08_postgresql_filled_form.svg index 0c84ca87..0ac4db1d 100644 --- a/tests/ui/snapshots/08_postgresql_filled_form.svg +++ b/tests/ui/snapshots/08_postgresql_filled_form.svg @@ -159,7 +159,7 @@ - + diff --git a/tests/ui/snapshots/09_duplicate_name_validation.svg b/tests/ui/snapshots/09_duplicate_name_validation.svg index 4c19d01a..989791d5 100644 --- a/tests/ui/snapshots/09_duplicate_name_validation.svg +++ b/tests/ui/snapshots/09_duplicate_name_validation.svg @@ -159,7 +159,7 @@ - + diff --git a/tests/ui/snapshots/10_after_test_connection.svg b/tests/ui/snapshots/10_after_test_connection.svg index 8c292a83..dc6bb2fd 100644 --- a/tests/ui/snapshots/10_after_test_connection.svg +++ b/tests/ui/snapshots/10_after_test_connection.svg @@ -160,7 +160,7 @@ - + diff --git a/tests/ui/snapshots/10_before_test_connection.svg b/tests/ui/snapshots/10_before_test_connection.svg index c6f63bc5..7564e505 100644 --- a/tests/ui/snapshots/10_before_test_connection.svg +++ b/tests/ui/snapshots/10_before_test_connection.svg @@ -159,7 +159,7 @@ - + diff --git a/tests/ui/snapshots/11_navigation_after_tab.svg b/tests/ui/snapshots/11_navigation_after_tab.svg index 881699b0..18a8e6c0 100644 --- a/tests/ui/snapshots/11_navigation_after_tab.svg +++ b/tests/ui/snapshots/11_navigation_after_tab.svg @@ -159,7 +159,7 @@ - + diff --git a/tests/ui/snapshots/11_navigation_initial_focus.svg b/tests/ui/snapshots/11_navigation_initial_focus.svg index c961da05..d1cea7bc 100644 --- a/tests/ui/snapshots/11_navigation_initial_focus.svg +++ b/tests/ui/snapshots/11_navigation_initial_focus.svg @@ -159,7 +159,7 @@ - + diff --git a/tests/ui/snapshots/12_navigation_after_shift_tab.svg b/tests/ui/snapshots/12_navigation_after_shift_tab.svg index 881699b0..18a8e6c0 100644 --- a/tests/ui/snapshots/12_navigation_after_shift_tab.svg +++ b/tests/ui/snapshots/12_navigation_after_shift_tab.svg @@ -159,7 +159,7 @@ - + diff --git a/tests/ui/snapshots/12_navigation_after_two_tabs.svg b/tests/ui/snapshots/12_navigation_after_two_tabs.svg index 70182c48..2f292d35 100644 --- a/tests/ui/snapshots/12_navigation_after_two_tabs.svg +++ b/tests/ui/snapshots/12_navigation_after_two_tabs.svg @@ -160,7 +160,7 @@ - + diff --git a/tests/ui/snapshots/13_ssh_fields_hidden.svg b/tests/ui/snapshots/13_ssh_fields_hidden.svg index 91461bce..7c92dadf 100644 --- a/tests/ui/snapshots/13_ssh_fields_hidden.svg +++ b/tests/ui/snapshots/13_ssh_fields_hidden.svg @@ -160,7 +160,7 @@ - + diff --git a/tests/ui/test_connect_action.py b/tests/ui/test_connect_action.py index 253bd0b9..7c32c1ac 100644 --- a/tests/ui/test_connect_action.py +++ b/tests/ui/test_connect_action.py @@ -24,20 +24,18 @@ async def test_connection_picker_select_highlights_in_tree(self): mock_connections = MockConnectionStore(connections) mock_settings = MockSettingsStore({"theme": "tokyo-night"}) - with patch("sqlit.app.load_connections", mock_connections.load_all), \ - patch("sqlit.app.load_settings", mock_settings.load_all), \ - patch("sqlit.app.save_settings", mock_settings.save_all): - + with ( + patch("sqlit.app.load_connections", mock_connections.load_all), + patch("sqlit.app.load_settings", mock_settings.load_all), + patch("sqlit.app.save_settings", mock_settings.save_all), + ): app = SSMSTUI() async with app.run_test(size=(100, 35)) as pilot: app.action_show_connection_picker() await pilot.pause() - picker = next( - (s for s in app.screen_stack if isinstance(s, ConnectionPickerScreen)), - None - ) + picker = next((s for s in app.screen_stack if isinstance(s, ConnectionPickerScreen)), None) assert picker is not None with patch.object(app, "connect_to_server"): @@ -59,20 +57,18 @@ async def test_connection_picker_fuzzy_search_selects_correct_connection(self): mock_connections = MockConnectionStore(connections) mock_settings = MockSettingsStore({"theme": "tokyo-night"}) - with patch("sqlit.app.load_connections", mock_connections.load_all), \ - patch("sqlit.app.load_settings", mock_settings.load_all), \ - patch("sqlit.app.save_settings", mock_settings.save_all): - + with ( + patch("sqlit.app.load_connections", mock_connections.load_all), + patch("sqlit.app.load_settings", mock_settings.load_all), + patch("sqlit.app.save_settings", mock_settings.save_all), + ): app = SSMSTUI() async with app.run_test(size=(100, 35)) as pilot: app.action_show_connection_picker() await pilot.pause() - picker = next( - (s for s in app.screen_stack if isinstance(s, ConnectionPickerScreen)), - None - ) + picker = next((s for s in app.screen_stack if isinstance(s, ConnectionPickerScreen)), None) assert picker is not None picker.search_text = "ora" diff --git a/tests/ui/test_connection_screen.py b/tests/ui/test_connection_screen.py index a7c7805e..fabd6955 100644 --- a/tests/ui/test_connection_screen.py +++ b/tests/ui/test_connection_screen.py @@ -11,19 +11,11 @@ def _get_providers_with_advanced_tab() -> set[str]: - return { - db_type - for db_type, schema in get_all_schemas().items() - if any(f.advanced for f in schema.fields) - } + return {db_type for db_type, schema in get_all_schemas().items() if any(f.advanced for f in schema.fields)} def _get_providers_without_advanced_tab() -> set[str]: - return { - db_type - for db_type, schema in get_all_schemas().items() - if not any(f.advanced for f in schema.fields) - } + return {db_type for db_type, schema in get_all_schemas().items() if not any(f.advanced for f in schema.fields)} class TestConnectionScreen: @@ -172,7 +164,7 @@ async def test_sqlite_tab_navigation_excludes_tab_bar(self): config = ConnectionConfig(name="", db_type="sqlite", file_path="") app = ConnectionScreenTestApp(config, editing=False) - async with app.run_test(size=(100, 35)) as pilot: + async with app.run_test(size=(100, 35)) as _pilot: screen = app.screen # Get the list of focusable fields @@ -220,7 +212,6 @@ async def test_tab_key_cycles_through_sqlite_fields(self): await pilot.press("tab") assert screen.focused.id == "conn-name", "Tab should cycle back to conn-name, not go to tab bar" - @pytest.mark.asyncio async def test_shift_tab_from_first_field_goes_to_tab_bar(self): """Pressing Shift+Tab from the first field should focus the tab bar. @@ -254,8 +245,7 @@ async def test_shift_tab_from_first_field_goes_to_tab_bar(self): # Focus should be on the Tabs widget (tab bar) assert isinstance(screen.focused, Tabs), ( - f"Shift+Tab from first field should focus tab bar, " - f"but focused is {type(screen.focused).__name__}" + f"Shift+Tab from first field should focus tab bar, " f"but focused is {type(screen.focused).__name__}" ) @@ -266,7 +256,7 @@ async def test_advanced_tab_enabled(self, db_type): config = ConnectionConfig(name="test", db_type=db_type) app = ConnectionScreenTestApp(config, editing=True) - async with app.run_test(size=(100, 35)) as pilot: + async with app.run_test(size=(100, 35)) as _pilot: screen = app.screen tabs = screen.query_one("#connection-tabs") advanced_pane = screen.query_one("#tab-advanced") @@ -280,7 +270,7 @@ async def test_advanced_tab_disabled(self, db_type): config = ConnectionConfig(name="test", db_type=db_type) app = ConnectionScreenTestApp(config, editing=True) - async with app.run_test(size=(100, 35)) as pilot: + async with app.run_test(size=(100, 35)) as _pilot: screen = app.screen tabs = screen.query_one("#connection-tabs") advanced_pane = screen.query_one("#tab-advanced") diff --git a/tests/ui/test_explorer_toggle.py b/tests/ui/test_explorer_toggle.py index 2034fc08..d33e4b96 100644 --- a/tests/ui/test_explorer_toggle.py +++ b/tests/ui/test_explorer_toggle.py @@ -20,9 +20,7 @@ async def test_leader_menu_blocked_when_dialog_open(self): app._show_leader_menu() await pilot.pause() - has_leader_menu = any( - isinstance(screen, LeaderMenuScreen) for screen in app.screen_stack - ) + has_leader_menu = any(isinstance(screen, LeaderMenuScreen) for screen in app.screen_stack) assert not has_leader_menu @@ -91,6 +89,9 @@ async def test_results_fullscreen_hides_explorer_and_query(self): app = SSMSTUI() async with app.run_test(size=(100, 35)) as pilot: + # Wait for Lazy widget to render the results table + await pilot.pause() + app.action_focus_results() await pilot.pause() diff --git a/tests/ui/test_password_input.py b/tests/ui/test_password_input.py new file mode 100644 index 00000000..8e41b541 --- /dev/null +++ b/tests/ui/test_password_input.py @@ -0,0 +1,392 @@ +"""Tests for password input screen and connection flow.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from textual.widgets import Input + +from sqlit.config import ConnectionConfig +from sqlit.ui.screens.password_input import PasswordInputScreen + + +class TestPasswordInputScreen: + """Test the PasswordInputScreen modal.""" + + @pytest.mark.asyncio + async def test_password_input_screen_renders(self) -> None: + """Password input screen renders with correct title and description.""" + from textual.app import App + + class TestApp(App): + pass + + app = TestApp() + async with app.run_test() as pilot: + screen = PasswordInputScreen("test_connection") + app.push_screen(screen) + await pilot.pause() + + # Check that the input field exists and is masked + input_widget = screen.query_one("#password-input", Input) + assert input_widget is not None + assert input_widget.password is True + + @pytest.mark.asyncio + async def test_password_input_submit_with_enter(self) -> None: + """Pressing Enter submits the password.""" + from textual.app import App + + class TestApp(App): + pass + + app = TestApp() + async with app.run_test() as pilot: + result = None + + def on_dismiss(password): + nonlocal result + result = password + + screen = PasswordInputScreen("test_connection") + app.push_screen(screen, on_dismiss) + await pilot.pause() + + # Type a password + input_widget = screen.query_one("#password-input", Input) + input_widget.value = "my_secret_password" + await pilot.pause() + + # Press Enter to submit + await pilot.press("enter") + await pilot.pause() + + # Should have dismissed with the password + assert result == "my_secret_password" + + @pytest.mark.asyncio + async def test_password_input_cancel_with_escape(self) -> None: + """Pressing Escape cancels and returns None.""" + from textual.app import App + + class TestApp(App): + pass + + app = TestApp() + async with app.run_test() as pilot: + result = "not_set" + + def on_dismiss(password): + nonlocal result + result = password + + screen = PasswordInputScreen("test_connection") + app.push_screen(screen, on_dismiss) + await pilot.pause() + + # Type a password but then cancel + input_widget = screen.query_one("#password-input", Input) + input_widget.value = "my_secret_password" + await pilot.pause() + + # Press Escape to cancel + await pilot.press("escape") + await pilot.pause() + + # Should have dismissed with None + assert result is None + + @pytest.mark.asyncio + async def test_password_input_shows_connection_name(self) -> None: + """Password input shows the connection name in description.""" + from textual.app import App + from textual.widgets import Static + + class TestApp(App): + pass + + app = TestApp() + async with app.run_test() as pilot: + screen = PasswordInputScreen("my_database") + app.push_screen(screen) + await pilot.pause() + + description = screen.query_one("#password-description", Static) + assert "my_database" in str(description.render()) + + @pytest.mark.asyncio + async def test_password_input_ssh_type(self) -> None: + """SSH password type shows appropriate message.""" + from textual.app import App + from textual.widgets import Static + + class TestApp(App): + pass + + app = TestApp() + async with app.run_test() as pilot: + screen = PasswordInputScreen("test_connection", password_type="ssh") + app.push_screen(screen) + await pilot.pause() + + description = screen.query_one("#password-description", Static) + rendered_text = str(description.render()) + assert "SSH password" in rendered_text + assert "test_connection" in rendered_text + + @pytest.mark.asyncio + async def test_password_input_custom_description(self) -> None: + """Custom description is displayed.""" + from textual.app import App + from textual.widgets import Static + + class TestApp(App): + pass + + app = TestApp() + async with app.run_test() as pilot: + screen = PasswordInputScreen( + "test_connection", + description="Please enter your custom password:", + ) + app.push_screen(screen) + await pilot.pause() + + description = screen.query_one("#password-description", Static) + assert "custom password" in str(description.render()) + + @pytest.mark.asyncio + async def test_password_not_visible_in_input(self) -> None: + """Password characters are not visible when typing.""" + from textual.app import App + + class TestApp(App): + pass + + app = TestApp() + async with app.run_test() as pilot: + screen = PasswordInputScreen("test_connection") + app.push_screen(screen) + await pilot.pause() + + input_widget = screen.query_one("#password-input", Input) + input_widget.value = "secret123" + await pilot.pause() + + # The password property should be True, which masks the input + assert input_widget.password is True + + +class TestConnectionPasswordFlow: + """Test the connection flow with password prompts.""" + + @pytest.mark.asyncio + async def test_connect_with_none_password_shows_prompt(self) -> None: + """Connecting with None password shows password input screen.""" + from sqlit.app import SSMSTUI + from sqlit.mocks import get_mock_profile + + mock_profile = get_mock_profile("empty") + app = SSMSTUI(mock_profile=mock_profile) + + async with app.run_test() as pilot: + # Create a connection with None password (not set) + config = ConnectionConfig( + name="test_db", + db_type="postgresql", + server="localhost", + username="user", + password=None, # None = prompt needed + ) + app.connections = [config] + + # Trigger connect + app.connect_to_server(config) + await pilot.pause() + + # Should have pushed PasswordInputScreen + assert isinstance(app.screen, PasswordInputScreen) + assert app.screen.connection_name == "test_db" + + @pytest.mark.asyncio + async def test_connect_with_stored_password_no_prompt(self) -> None: + """Connecting with stored password doesn't show prompt.""" + from sqlit.app import SSMSTUI + from sqlit.mocks import get_mock_profile + + mock_profile = get_mock_profile("empty") + app = SSMSTUI(mock_profile=mock_profile) + + async with app.run_test() as pilot: + # Create a connection with stored password + config = ConnectionConfig( + name="test_db", + db_type="postgresql", + server="localhost", + username="user", + password="stored_password", + ) + app.connections = [config] + + # Mock the session factory + mock_session = MagicMock() + mock_session.connection = MagicMock() + mock_session.adapter = MagicMock() + mock_session.tunnel = None + mock_session.config = config + + app._session_factory = lambda c: mock_session + + # Trigger connect + app.connect_to_server(config) + await pilot.pause(0.5) # Wait for worker thread + + # Should NOT have pushed PasswordInputScreen + # (the app screen should be the main app, not PasswordInputScreen) + assert not isinstance(app.screen, PasswordInputScreen) + + @pytest.mark.asyncio + async def test_ssh_password_prompt_before_db_password(self) -> None: + """SSH password is prompted before database password.""" + from sqlit.app import SSMSTUI + from sqlit.mocks import get_mock_profile + + mock_profile = get_mock_profile("empty") + app = SSMSTUI(mock_profile=mock_profile) + + async with app.run_test() as pilot: + # Create a connection with SSH enabled and both passwords None (not set) + config = ConnectionConfig( + name="test_db", + db_type="postgresql", + server="localhost", + username="user", + password=None, # None = prompt needed + ssh_enabled=True, + ssh_auth_type="password", + ssh_host="bastion.example.com", + ssh_username="sshuser", + ssh_password=None, # None = prompt needed + ) + app.connections = [config] + + # Trigger connect + app.connect_to_server(config) + await pilot.pause() + + # Should have pushed PasswordInputScreen for SSH first + assert isinstance(app.screen, PasswordInputScreen) + assert app.screen.password_type == "ssh" + + @pytest.mark.asyncio + async def test_cancel_password_prompt_aborts_connection(self) -> None: + """Cancelling password prompt aborts the connection.""" + from sqlit.app import SSMSTUI + from sqlit.mocks import get_mock_profile + + mock_profile = get_mock_profile("empty") + app = SSMSTUI(mock_profile=mock_profile) + + async with app.run_test() as pilot: + # Create a connection with None password (not set) + config = ConnectionConfig( + name="test_db", + db_type="postgresql", + server="localhost", + username="user", + password=None, # None = prompt needed + ) + app.connections = [config] + + # Trigger connect + app.connect_to_server(config) + await pilot.pause() + + # Should show password prompt + assert isinstance(app.screen, PasswordInputScreen) + + # Cancel the prompt + await pilot.press("escape") + await pilot.pause() + + # Should not have connected (current_connection should be None) + assert app.current_connection is None + + @pytest.mark.asyncio + async def test_password_from_prompt_used_for_connection(self) -> None: + """Password entered in prompt is used for connection.""" + from sqlit.app import SSMSTUI + from sqlit.mocks import get_mock_profile + + mock_profile = get_mock_profile("empty") + app = SSMSTUI(mock_profile=mock_profile) + + # Track what config was used for connection + connection_config = None + + def mock_session_factory(config): + nonlocal connection_config + connection_config = config + mock_session = MagicMock() + mock_session.connection = MagicMock() + mock_session.adapter = MagicMock() + mock_session.tunnel = None + mock_session.config = config + return mock_session + + app._session_factory = mock_session_factory + + async with app.run_test() as pilot: + # Create a connection with None password (not set) + config = ConnectionConfig( + name="test_db", + db_type="postgresql", + server="localhost", + username="user", + password=None, # None = prompt needed + ) + app.connections = [config] + + # Trigger connect + app.connect_to_server(config) + await pilot.pause() + + # Should show password prompt + assert isinstance(app.screen, PasswordInputScreen) + + # Enter password + input_widget = app.screen.query_one("#password-input", Input) + input_widget.value = "entered_password" + await pilot.pause() + + # Submit + await pilot.press("enter") + await pilot.pause(0.5) # Wait for connection worker + + # Check that the connection was made with the entered password + assert connection_config is not None + assert connection_config.password == "entered_password" + # Original config should still have None password + assert config.password is None + + @pytest.mark.asyncio + async def test_file_based_database_no_password_prompt(self) -> None: + """File-based databases (SQLite) don't prompt for password.""" + from sqlit.app import SSMSTUI + from sqlit.mocks import get_mock_profile + + mock_profile = get_mock_profile("sqlite-demo") + app = SSMSTUI(mock_profile=mock_profile) + + async with app.run_test() as pilot: + # Get the SQLite demo connection + if app.connections: + config = app.connections[0] + + # Trigger connect + app.connect_to_server(config) + await pilot.pause(0.5) + + # Should NOT show password prompt (SQLite doesn't need passwords) + assert not isinstance(app.screen, PasswordInputScreen) diff --git a/tests/ui/test_query_history.py b/tests/ui/test_query_history.py new file mode 100644 index 00000000..ee347505 --- /dev/null +++ b/tests/ui/test_query_history.py @@ -0,0 +1,161 @@ +"""UI tests for query history functionality.""" + +from __future__ import annotations + +from unittest.mock import patch + +import pytest + +from sqlit.app import SSMSTUI + +from .mocks import MockConnectionStore, MockSettingsStore, create_test_connection + + +class TestQueryHistoryCursorMemory: + """Tests for cursor position memory when switching between queries.""" + + @pytest.mark.asyncio + async def test_cursor_position_remembered_when_switching_queries(self): + """Test that cursor position is saved and restored when switching queries via history.""" + connections = [create_test_connection("test-db", "sqlite")] + mock_connections = MockConnectionStore(connections) + mock_settings = MockSettingsStore({"theme": "tokyo-night"}) + + with ( + patch("sqlit.app.load_connections", mock_connections.load_all), + patch("sqlit.app.load_settings", mock_settings.load_all), + patch("sqlit.app.save_settings", mock_settings.save_all), + ): + app = SSMSTUI() + + async with app.run_test(size=(100, 35)) as pilot: + # Set first query and position cursor at a specific location + query_a = "SELECT * FROM users" + app.query_input.text = query_a + await pilot.pause() + + # Move cursor to position (0, 7) - after "SELECT " + app.query_input.cursor_location = (0, 7) + await pilot.pause() + + # Verify cursor is at expected position + assert app.query_input.cursor_location == (0, 7) + + # Simulate selecting a different query from history + # This calls _handle_history_result directly + query_b = "SELECT id, name FROM products" + app._handle_history_result(("select", query_b)) + await pilot.pause() + + # Verify query changed + assert app.query_input.text == query_b + + # Move cursor to a different position in query B + app.query_input.cursor_location = (0, 10) + await pilot.pause() + + # Now switch back to query A + app._handle_history_result(("select", query_a)) + await pilot.pause() + + # Verify query A is back + assert app.query_input.text == query_a + + # Verify cursor position is restored to (0, 7) + assert app.query_input.cursor_location == (0, 7) + + @pytest.mark.asyncio + async def test_cursor_position_at_end_for_new_query(self): + """Test that cursor goes to end for a query not previously edited.""" + connections = [create_test_connection("test-db", "sqlite")] + mock_connections = MockConnectionStore(connections) + mock_settings = MockSettingsStore({"theme": "tokyo-night"}) + + with ( + patch("sqlit.app.load_connections", mock_connections.load_all), + patch("sqlit.app.load_settings", mock_settings.load_all), + patch("sqlit.app.save_settings", mock_settings.save_all), + ): + app = SSMSTUI() + + async with app.run_test(size=(100, 35)) as pilot: + # Start with empty query + app.query_input.text = "" + await pilot.pause() + + # Select a query from history that was never edited before + new_query = "SELECT * FROM orders" + app._handle_history_result(("select", new_query)) + await pilot.pause() + + # Verify cursor is at end of query + expected_col = len(new_query) + assert app.query_input.cursor_location == (0, expected_col) + + @pytest.mark.asyncio + async def test_cursor_position_for_multiline_query(self): + """Test cursor position memory works for multiline queries.""" + connections = [create_test_connection("test-db", "sqlite")] + mock_connections = MockConnectionStore(connections) + mock_settings = MockSettingsStore({"theme": "tokyo-night"}) + + with ( + patch("sqlit.app.load_connections", mock_connections.load_all), + patch("sqlit.app.load_settings", mock_settings.load_all), + patch("sqlit.app.save_settings", mock_settings.save_all), + ): + app = SSMSTUI() + + async with app.run_test(size=(100, 35)) as pilot: + # Set multiline query + query_multiline = "SELECT *\nFROM users\nWHERE id = 1" + app.query_input.text = query_multiline + await pilot.pause() + + # Position cursor on second line (row 1, col 5) - "FROM " + app.query_input.cursor_location = (1, 5) + await pilot.pause() + + # Switch to another query + query_other = "SELECT 1" + app._handle_history_result(("select", query_other)) + await pilot.pause() + + # Switch back + app._handle_history_result(("select", query_multiline)) + await pilot.pause() + + # Verify cursor is restored to (1, 5) + assert app.query_input.cursor_location == (1, 5) + + @pytest.mark.asyncio + async def test_cursor_cache_handles_same_query_text(self): + """Test that identical query text shares cursor position.""" + connections = [create_test_connection("test-db", "sqlite")] + mock_connections = MockConnectionStore(connections) + mock_settings = MockSettingsStore({"theme": "tokyo-night"}) + + with ( + patch("sqlit.app.load_connections", mock_connections.load_all), + patch("sqlit.app.load_settings", mock_settings.load_all), + patch("sqlit.app.save_settings", mock_settings.save_all), + ): + app = SSMSTUI() + + async with app.run_test(size=(100, 35)) as pilot: + # Set query and cursor position + query = "SELECT * FROM users" + app.query_input.text = query + app.query_input.cursor_location = (0, 5) + await pilot.pause() + + # Switch away + app._handle_history_result(("select", "SELECT 1")) + await pilot.pause() + + # Select the same query text again (simulating it appearing twice in history) + app._handle_history_result(("select", query)) + await pilot.pause() + + # Cursor should be at the remembered position + assert app.query_input.cursor_location == (0, 5) diff --git a/tests/ui/test_tree_schema_grouping.py b/tests/ui/test_tree_schema_grouping.py index 99bbc02b..5f79abf6 100644 --- a/tests/ui/test_tree_schema_grouping.py +++ b/tests/ui/test_tree_schema_grouping.py @@ -46,12 +46,12 @@ def __init__(self, label: str = "", data: tuple = None): self.children: list[MockTreeNode] = [] self.allow_expand = False - def add(self, label: str) -> "MockTreeNode": + def add(self, label: str) -> MockTreeNode: child = MockTreeNode(label) self.children.append(child) return child - def add_leaf(self, label: str) -> "MockTreeNode": + def add_leaf(self, label: str) -> MockTreeNode: return self.add(label) @@ -370,8 +370,10 @@ def test_escape_markup_handles_opening_tags(self): name = "[bold]test" escaped = escape_markup(name) # Rich should be able to render this without error - from rich.console import Console from io import StringIO + + from rich.console import Console + console = Console(file=StringIO(), force_terminal=True) # This should not raise an error console.print(f"[dim]{escaped}[/dim]") @@ -394,9 +396,10 @@ def test_escape_markup_handles_complex_names(self): "name/with/slashes", ] - from rich.console import Console from io import StringIO + from rich.console import Console + for name in problematic_names: escaped = escape_markup(name) console = Console(file=StringIO(), force_terminal=True)