Skip to content

Commit ef9c21b

Browse files
authored
fix(driver): validate FROM clause in COUNT query generation (#251)
Fixes COUNT query generation to properly validate FROM clause existence before attempting to create COUNT(*) queries. Addresses upstream bug report where select_with_total raised confusing error for SQL with only ORDER BY clause.
1 parent 720a001 commit ef9c21b

File tree

2 files changed

+240
-0
lines changed

2 files changed

+240
-0
lines changed

sqlspec/driver/_common.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -985,6 +985,13 @@ def _create_count_query(self, original_sql: "SQL") -> "SQL":
985985
expr = original_sql.expression
986986

987987
if isinstance(expr, exp.Select):
988+
if not expr.args.get("from"):
989+
msg = (
990+
"Cannot create COUNT query: SELECT statement missing FROM clause. "
991+
"COUNT queries require a FROM clause to determine which table to count rows from."
992+
)
993+
raise ImproperConfigurationError(msg)
994+
988995
if expr.args.get("group"):
989996
subquery = expr.subquery(alias="grouped_data")
990997
count_expr = exp.select(exp.Count(this=exp.Star())).from_(subquery)
Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
"""Tests for _create_count_query edge cases and validation.
2+
3+
This module tests COUNT query generation validation, particularly for edge cases
4+
where SELECT statements are missing required clauses (FROM, etc.).
5+
"""
6+
7+
# pyright: reportPrivateUsage=false
8+
9+
import pytest
10+
11+
from sqlspec.core import SQL, StatementConfig
12+
from sqlspec.driver._sync import SyncDriverAdapterBase
13+
from sqlspec.exceptions import ImproperConfigurationError
14+
15+
16+
class MockSyncDriver(SyncDriverAdapterBase):
    """Stub sync driver that exists only so ``_create_count_query`` can be called.

    The constructor sets ``statement_config`` and nothing else; the base-class
    initializer is not invoked. ``connection`` yields ``None`` and every other
    member defined here raises ``NotImplementedError`` when used.
    """

    def __init__(self) -> None:
        self.statement_config = StatementConfig()

    # -- properties -------------------------------------------------------

    @property
    def connection(self) -> None:
        return None

    @property
    def data_dictionary(self):
        raise NotImplementedError("Mock driver - not implemented")

    # -- connection lifecycle hooks ---------------------------------------

    def create_connection(self, *args, **kwargs):
        raise NotImplementedError("Mock driver - not implemented")

    def close_connection(self, *args, **kwargs):
        raise NotImplementedError("Mock driver - not implemented")

    def with_cursor(self, *args, **kwargs):
        raise NotImplementedError("Mock driver - not implemented")

    def handle_database_exceptions(self, *args, **kwargs):
        raise NotImplementedError("Mock driver - not implemented")

    # -- transaction hooks ------------------------------------------------

    def begin(self, *args, **kwargs):
        raise NotImplementedError("Mock driver - not implemented")

    def commit(self, *args, **kwargs):
        raise NotImplementedError("Mock driver - not implemented")

    def rollback(self, *args, **kwargs):
        raise NotImplementedError("Mock driver - not implemented")

    # -- execution hooks --------------------------------------------------

    def _try_special_handling(self, *args, **kwargs):
        raise NotImplementedError("Mock driver - not implemented")

    def _execute_statement(self, *args, **kwargs):
        raise NotImplementedError("Mock driver - not implemented")

    def _execute_many(self, *args, **kwargs):
        raise NotImplementedError("Mock driver - not implemented")
59+
60+
61+
class TestCountQueryValidation:
    """Test COUNT query generation validation.

    The first group of tests feeds SELECT statements without a FROM clause to
    ``_create_count_query`` and expects ``ImproperConfigurationError``. The
    remaining tests check the generated COUNT statement for valid inputs.
    String checks are done against the upper-cased rendering so they do not
    depend on the identifier casing of the generated SQL.
    """

    def test_count_query_missing_from_clause_with_order_by(self) -> None:
        """Test COUNT query fails with clear error when FROM clause missing (ORDER BY only).

        This is the reported bug scenario from upstream.
        """
        driver = MockSyncDriver()
        sql = driver.prepare_statement(SQL("SELECT * ORDER BY id"), statement_config=driver.statement_config)
        sql.compile()  # Parse the SQL to populate expression

        with pytest.raises(ImproperConfigurationError, match="missing FROM clause"):
            driver._create_count_query(sql)

    def test_count_query_missing_from_clause_with_where(self) -> None:
        """Test COUNT query fails when only WHERE clause present (no FROM)."""
        driver = MockSyncDriver()
        sql = driver.prepare_statement(SQL("SELECT * WHERE active = true"), statement_config=driver.statement_config)
        sql.compile()

        with pytest.raises(ImproperConfigurationError, match="missing FROM clause"):
            driver._create_count_query(sql)

    def test_count_query_select_star_no_from(self) -> None:
        """Test COUNT query fails for SELECT * without FROM clause."""
        driver = MockSyncDriver()
        sql = driver.prepare_statement(SQL("SELECT *"), statement_config=driver.statement_config)
        sql.compile()

        with pytest.raises(ImproperConfigurationError, match="missing FROM clause"):
            driver._create_count_query(sql)

    def test_count_query_select_columns_no_from(self) -> None:
        """Test COUNT query fails for SELECT columns without FROM clause."""
        driver = MockSyncDriver()
        sql = driver.prepare_statement(SQL("SELECT id, name"), statement_config=driver.statement_config)
        sql.compile()

        with pytest.raises(ImproperConfigurationError, match="missing FROM clause"):
            driver._create_count_query(sql)

    def test_count_query_valid_select_with_from(self) -> None:
        """Test COUNT query succeeds with valid SELECT...FROM."""
        driver = MockSyncDriver()
        sql = driver.prepare_statement(SQL("SELECT * FROM users ORDER BY id"), statement_config=driver.statement_config)
        sql.compile()

        count_sql = driver._create_count_query(sql)

        count_str = str(count_sql).upper()
        assert "COUNT(*)" in count_str
        assert "FROM USERS" in count_str
        assert "ORDER BY" not in count_str

    def test_count_query_with_where_and_from(self) -> None:
        """Test COUNT query preserves WHERE clause when FROM present."""
        driver = MockSyncDriver()
        sql = driver.prepare_statement(
            SQL("SELECT * FROM users WHERE active = true ORDER BY id"), statement_config=driver.statement_config
        )
        sql.compile()

        count_sql = driver._create_count_query(sql)

        count_str = str(count_sql).upper()
        assert "COUNT(*)" in count_str
        assert "FROM USERS" in count_str
        assert "WHERE" in count_str
        assert "ACTIVE" in count_str
        assert "ORDER BY" not in count_str

    def test_count_query_with_group_by(self) -> None:
        """Test COUNT query wraps grouped query in subquery."""
        driver = MockSyncDriver()
        sql = driver.prepare_statement(
            SQL("SELECT status, COUNT(*) FROM users GROUP BY status"), statement_config=driver.statement_config
        )
        sql.compile()

        count_sql = driver._create_count_query(sql)

        count_str = str(count_sql)
        assert "COUNT(*)" in count_str.upper()
        assert "grouped_data" in count_str.lower()

    def test_count_query_removes_limit_offset(self) -> None:
        """Test COUNT query removes LIMIT and OFFSET clauses."""
        driver = MockSyncDriver()
        sql = driver.prepare_statement(
            SQL("SELECT * FROM users ORDER BY id LIMIT 10 OFFSET 20"), statement_config=driver.statement_config
        )
        sql.compile()

        count_sql = driver._create_count_query(sql)

        count_str = str(count_sql).upper()
        assert "LIMIT" not in count_str
        assert "OFFSET" not in count_str
        assert "ORDER BY" not in count_str

    def test_count_query_with_having(self) -> None:
        """Test COUNT query preserves HAVING clause.

        A grouped SELECT is wrapped whole in the ``grouped_data`` subquery, so
        the HAVING filter must still be present in the generated statement;
        dropping it would count groups the original query excludes.
        """
        driver = MockSyncDriver()
        sql = driver.prepare_statement(
            SQL("SELECT status, COUNT(*) as cnt FROM users GROUP BY status HAVING cnt > 5"),
            statement_config=driver.statement_config,
        )
        sql.compile()

        count_sql = driver._create_count_query(sql)

        count_str = str(count_sql)
        assert "COUNT(*)" in count_str.upper()
        assert "grouped_data" in count_str.lower()
        assert "HAVING" in count_str.upper()
175+
176+
177+
class TestCountQueryEdgeCases:
    """Exercise ``_create_count_query`` with less common statement shapes."""

    def test_complex_select_with_join(self) -> None:
        """A joined, filtered, ordered and limited SELECT yields a bare COUNT."""
        driver = MockSyncDriver()
        statement = driver.prepare_statement(
            SQL("""
                SELECT u.id, u.name, o.total
                FROM users u
                JOIN orders o ON u.id = o.user_id
                WHERE u.active = true
                  AND o.total > 100
                ORDER BY o.total DESC
                LIMIT 10
            """),
            statement_config=driver.statement_config,
        )
        statement.compile()

        # Upper-case once so every check below is case-insensitive.
        rendered = str(driver._create_count_query(statement)).upper()

        assert "COUNT(*)" in rendered
        assert "FROM USERS" in rendered
        for dropped_clause in ("ORDER BY", "LIMIT"):
            assert dropped_clause not in rendered

    def test_select_with_subquery_in_from(self) -> None:
        """A derived table in FROM is accepted as a valid FROM clause."""
        driver = MockSyncDriver()
        statement = driver.prepare_statement(
            SQL("""
                SELECT t.id
                FROM (SELECT id FROM users WHERE active = true) t
                ORDER BY t.id
            """),
            statement_config=driver.statement_config,
        )
        statement.compile()

        rendered = str(driver._create_count_query(statement))

        assert "COUNT(*)" in rendered.upper()

    def test_error_message_clarity(self) -> None:
        """The raised error spells out why a FROM clause is needed."""
        driver = MockSyncDriver()
        statement = driver.prepare_statement(SQL("SELECT * ORDER BY id"), statement_config=driver.statement_config)
        statement.compile()

        expected_reason = "COUNT queries require a FROM clause to determine which table to count rows from"
        with pytest.raises(ImproperConfigurationError, match=expected_reason):
            driver._create_count_query(statement)

0 commit comments

Comments
 (0)