Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 147 additions & 62 deletions sqlalchemy_kusto/dialect_kql.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,71 @@ def _legacy_join(self, select_stmt: selectable.Select, **kwargs):
def visit_join(self, join, asfrom=True, from_linter=None, **kwargs):
return ""

def _should_use_summarize_for_count_star(self, column, select):
"""Determine if COUNT(*) should use summarize or extend."""
# Check if this is a literal_column (used in aggregations) or a regular column (used in drill-by)
return (
# Case for test_select_count: literal_column("count(*)")
(
hasattr(column, "element")
and isinstance(column.element, sql.expression.ClauseElement)
and not isinstance(column.element, sql.expression.ColumnClause)
)
# Also use summarize if there are where clauses (as in test_select_count)
or (select._whereclause is not None)
)

@staticmethod
def _is_count_star_column(column_name: str) -> bool:
"""Check if the column name is COUNT(*)."""
return (
column_name.upper() == "COUNT(*)"
or column_name == "COUNT(*)"
or column_name.upper() == '"COUNT(*)"'
or column_name == '"COUNT(*)"'
)

def _process_column(
self, column, select, summarize_columns, extend_columns, projection_columns
):
"""Process a single column and update the appropriate collections."""
column_name, column_alias = self._extract_column_name_and_alias(column)
column_alias = self._escape_and_quote_columns(column_alias, True)
has_aggregates = False

# Special handling for COUNT(*) as a column name
if self._is_count_star_column(column_name):
if self._should_use_summarize_for_count_star(column, select):
# This is likely a literal_column or has where clauses, so use summarize
has_aggregates = True
summarize_columns.add(f"{column_alias} = count()")
else:
# This is likely a regular column without where clauses, so use extend
extend_columns.add(f"{column_alias} = count()")
else:
# Do we have aggregate columns?
kql_agg = self._extract_maybe_agg_column_parts(column_name)
if kql_agg:
has_aggregates = True
summarize_columns.add(
self._build_column_projection(kql_agg, column_alias)
)
# No aggregates - do the columns have aliases?
elif column_alias and column_alias != column_name:
extend_columns.add(
self._build_column_projection(column_name, column_alias, True)
)

# Add to projection columns
if column_alias:
projection_columns.append(
self._escape_and_quote_columns(column_alias, True)
)
else:
projection_columns.append(self._escape_and_quote_columns(column_name))

return has_aggregates

def _get_projection_or_summarize(self, select: selectable.Select) -> dict[str, str]:
"""Builds the ending part of the query either project or summarize."""
columns = select.inner_columns
Expand All @@ -200,68 +265,51 @@ def _get_projection_or_summarize(self, select: selectable.Select) -> dict[str, s
summarize_statement = ""
extend_statement = ""
project_statement = ""
has_aggregates = False
# The following is the logic
# With Columns :
# - Do we have a group by clause ? --Yes---> Do we have aggregate columns ? --Yes--> Summarize new column(s)
# | | with by clause
# N N --> Add to projection
# |
# |
# - Do the columns have aliases ? --Yes---> Extend with aliases
# |
# N---> Add to projection

# Process columns if they exist
if columns is not None:
summarize_columns = set()
extend_columns = set()
projection_columns = []
summarize_columns: set[str] = set()
extend_columns: set[str] = set()
projection_columns: list[str] = []
has_aggregates = False

# Process each column (except *)
for column in [c for c in columns if c.name != "*"]:
column_name, column_alias = self._extract_column_name_and_alias(column)
column_alias = self._escape_and_quote_columns(column_alias, True)
# Do we have a group by clause ?
# Do we have aggregate columns ?
kql_agg = self._extract_maybe_agg_column_parts(column_name)
if kql_agg:
has_aggregates = True
summarize_columns.add(
self._build_column_projection(kql_agg, column_alias)
)
# No group by clause
# Do the columns have aliases ?
# Add additional and to handle case where : SELECT column_name as column_name
elif column_alias and column_alias != column_name:
extend_columns.add(
self._build_column_projection(column_name, column_alias, True)
)
if column_alias:
projection_columns.append(
self._escape_and_quote_columns(column_alias, True)
)
else:
projection_columns.append(
self._escape_and_quote_columns(column_name)
)
# group by columns
column_has_agg = self._process_column(
column,
select,
summarize_columns,
extend_columns,
projection_columns,
)
has_aggregates = has_aggregates or column_has_agg

# Process group by columns
by_columns = self._group_by(group_by_cols)
if has_aggregates or bool(
by_columns
): # Summarize can happen with or without aggregate being created

# Build statements
if has_aggregates or bool(by_columns):
summarize_statement = f"| summarize {', '.join(summarize_columns)} "
if by_columns:
summarize_statement = (
f"{summarize_statement} by {', '.join(by_columns)}"
)

if extend_columns:
extend_statement = f"| extend {', '.join(sorted(extend_columns))}"

project_statement = (
f"| project {', '.join(projection_columns)}"
if projection_columns
else ""
)

# Process order by
unwrapped_order_by = self._get_order_by(order_by_cols)
sort_statement = (
f"| order by {', '.join(unwrapped_order_by)}" if unwrapped_order_by else ""
)

return {
"extend": extend_statement,
"summarize": summarize_statement,
Expand Down Expand Up @@ -351,20 +399,27 @@ def replacer(match):
def _escape_and_quote_columns(name: str | None, is_alias=False) -> str:
if name is None:
return ""
if (

result = None

# Special handling for COUNT(*) as a column name
if KustoKqlCompiler._is_count_star_column(name):
if is_alias:
# When it's an alias, we need to quote it
result = '["COUNT(*)"]'
else:
# When it's a function, we need to convert it to count()
result = "count()"
elif (
KustoKqlCompiler._is_kql_function(name)
or KustoKqlCompiler._is_number_literal(name)
) and not is_alias:
return name
if name.startswith('"') and name.endswith('"'):
name = name[1:-1]
# First, check if the name is already wrapped in ["ColumnName"] (escaped format)
if name.startswith('["') and name.endswith('"]'):
return name # Return as is if already properly escaped
# Remove surrounding spaces
# Handle mathematical operations (wrap only the column part before operators)
# Find the position of the first operator or space that separates the column name
if not is_alias:
result = name
elif name.startswith('["') and name.endswith('"]'):
# Return as is if already properly escaped
result = name
elif not is_alias:
# Handle mathematical operations (wrap only the column part before operators)
for operator in ["/", "+", "-", "*"]:
if operator in name:
# Split the name at the first operator and wrap the left part
Expand All @@ -374,10 +429,20 @@ def _escape_and_quote_columns(name: str | None, is_alias=False) -> str:
if col_part.startswith('"') and col_part.endswith('"'):
col_part = col_part[1:-1].strip()
col_part = col_part.replace('"', '\\"')
return f'["{col_part}"] {operator} {parts[1].strip()}' # Wrap the column part
# No operators found, just wrap the entire name
name = name.replace('"', '\\"')
return f'["{name}"]'
result = f'["{col_part}"] {operator} {parts[1].strip()}' # Wrap the column part
break

# If no special case was matched, apply default formatting
if result is None:
# Handle quoted names
if name.startswith('"') and name.endswith('"'):
name = name[1:-1]

# No operators found, just wrap the entire name
name = name.replace('"', '\\"')
result = f'["{name}"]'

return result

@staticmethod
def _sql_to_kql_where(where_clause: str) -> str:
Expand Down Expand Up @@ -565,12 +630,32 @@ def _extract_let_statements(clause) -> tuple[str, list[str]]:
@staticmethod
def _extract_column_name_and_alias(column: Column) -> tuple[str, str | None]:
if hasattr(column, "element"):
column_name = str(column.element)
column_alias = str(column.name)

# Special handling for COUNT(*) as a column name
if KustoKqlCompiler._is_count_star_column(column_name):
return "COUNT(*)", column_alias

return KustoKqlCompiler._convert_quoted_columns(
str(column.element)
), KustoKqlCompiler._convert_quoted_columns(column.name)
column_name
), KustoKqlCompiler._convert_quoted_columns(column_alias)
if hasattr(column, "name"):
return KustoKqlCompiler._convert_quoted_columns(str(column.name)), None
return KustoKqlCompiler._convert_quoted_columns(str(column)), None
column_name = str(column.name)

# Special handling for COUNT(*) as a column name
if KustoKqlCompiler._is_count_star_column(column_name):
return "COUNT(*)", None

return KustoKqlCompiler._convert_quoted_columns(column_name), None

column_str = str(column)

# Special handling for COUNT(*) as a column name
if KustoKqlCompiler._is_count_star_column(column_str):
return "COUNT(*)", None

return KustoKqlCompiler._convert_quoted_columns(column_str), None

@staticmethod
def _build_column_projection(
Expand Down
37 changes: 37 additions & 0 deletions tests/unit/test_dialect_kql.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,12 @@ def test_escape_and_quote_columns():
KustoKqlCompiler._escape_and_quote_columns("EventInfo_Time / time(1d)")
== '["EventInfo_Time"] / time(1d)'
)
# Test COUNT(*) handling
assert KustoKqlCompiler._escape_and_quote_columns("COUNT(*)") == "count()"
assert (
KustoKqlCompiler._escape_and_quote_columns("COUNT(*)", is_alias=True)
== '["COUNT(*)"]'
)


def test_use_table():
Expand Down Expand Up @@ -436,6 +442,37 @@ def test_select_count():
assert query_compiled == query_expected


def test_drill_by_with_count_star():
"""Test that simulates the drill-by functionality with COUNT(*) as a column name."""
# This test simulates what happens in Superset's drill-by functionality
# where COUNT(*) is used as a column name
kql_query = "logs"

# In drill-by, Superset might use COUNT(*) directly as a column name
column_count = column("COUNT(*)").label("count_value")

query = (
select([column_count])
.select_from(TextAsFrom(text(kql_query), ["*"]).alias("inner_qry"))
.limit(5)
)

query_compiled = str(
query.compile(engine, compile_kwargs={"literal_binds": True})
).replace("\n", "")

# The expected result should properly handle COUNT(*) as a column name
query_expected = (
'let inner_qry = (["logs"]);'
"inner_qry"
'| extend ["count_value"] = count()'
'| project ["count_value"]'
"| take 5"
)

assert query_compiled == query_expected


def test_select_with_let():
kql_query = "let x = 5; let y = 3; MyTable | where Field1 == x and Field2 == y"
query = (
Expand Down