diff --git a/sqlalchemy_kusto/dialect_kql.py b/sqlalchemy_kusto/dialect_kql.py index ffa8f75..46ea86e 100644 --- a/sqlalchemy_kusto/dialect_kql.py +++ b/sqlalchemy_kusto/dialect_kql.py @@ -192,6 +192,71 @@ def _legacy_join(self, select_stmt: selectable.Select, **kwargs): def visit_join(self, join, asfrom=True, from_linter=None, **kwargs): return "" + def _should_use_summarize_for_count_star(self, column, select): + """Determine if COUNT(*) should use summarize or extend.""" + # Check if this is a literal_column (used in aggregations) or a regular column (used in drill-by) + return ( + # Case for test_select_count: literal_column("count(*)") + ( + hasattr(column, "element") + and isinstance(column.element, sql.expression.ClauseElement) + and not isinstance(column.element, sql.expression.ColumnClause) + ) + # Also use summarize if there are where clauses (as in test_select_count) + or (select._whereclause is not None) + ) + + @staticmethod + def _is_count_star_column(column_name: str) -> bool: + """Check if the column name is COUNT(*).""" + return ( + column_name.upper() == "COUNT(*)" + or column_name == "COUNT(*)" + or column_name.upper() == '"COUNT(*)"' + or column_name == '"COUNT(*)"' + ) + + def _process_column( + self, column, select, summarize_columns, extend_columns, projection_columns + ): + """Process a single column and update the appropriate collections.""" + column_name, column_alias = self._extract_column_name_and_alias(column) + column_alias = self._escape_and_quote_columns(column_alias, True) + has_aggregates = False + + # Special handling for COUNT(*) as a column name + if self._is_count_star_column(column_name): + if self._should_use_summarize_for_count_star(column, select): + # This is likely a literal_column or has where clauses, so use summarize + has_aggregates = True + summarize_columns.add(f"{column_alias} = count()") + else: + # This is likely a regular column without where clauses, so use extend + extend_columns.add(f"{column_alias} = count()") + else: + # Do we have aggregate columns? + kql_agg = self._extract_maybe_agg_column_parts(column_name) + if kql_agg: + has_aggregates = True + summarize_columns.add( + self._build_column_projection(kql_agg, column_alias) + ) + # No aggregates - do the columns have aliases? + elif column_alias and column_alias != column_name: + extend_columns.add( + self._build_column_projection(column_name, column_alias, True) + ) + + # Add to projection columns + if column_alias: + projection_columns.append( + self._escape_and_quote_columns(column_alias, True) + ) + else: + projection_columns.append(self._escape_and_quote_columns(column_name)) + + return has_aggregates + def _get_projection_or_summarize(self, select: selectable.Select) -> dict[str, str]: """Builds the ending part of the query either project or summarize.""" columns = select.inner_columns @@ -200,68 +265,51 @@ def _get_projection_or_summarize(self, select: selectable.Select) -> dict[str, s summarize_statement = "" extend_statement = "" project_statement = "" - has_aggregates = False - # The following is the logic - # With Columns : - # - Do we have a group by clause ? --Yes---> Do we have aggregate columns ? --Yes--> Summarize new column(s) - # | | with by clause - # N N --> Add to projection - # | - # | - # - Do the columns have aliases ? --Yes---> Extend with aliases - # | - # N---> Add to projection + + # Process columns if they exist if columns is not None: - summarize_columns = set() - extend_columns = set() - projection_columns = [] + summarize_columns: set[str] = set() + extend_columns: set[str] = set() + projection_columns: list[str] = [] + has_aggregates = False + + # Process each column (except *) for column in [c for c in columns if c.name != "*"]: - column_name, column_alias = self._extract_column_name_and_alias(column) - column_alias = self._escape_and_quote_columns(column_alias, True) - # Do we have a group by clause ? - # Do we have aggregate columns ? - kql_agg = self._extract_maybe_agg_column_parts(column_name) - if kql_agg: - has_aggregates = True - summarize_columns.add( - self._build_column_projection(kql_agg, column_alias) - ) - # No group by clause - # Do the columns have aliases ? - # Add additional and to handle case where : SELECT column_name as column_name - elif column_alias and column_alias != column_name: - extend_columns.add( - self._build_column_projection(column_name, column_alias, True) - ) - if column_alias: - projection_columns.append( - self._escape_and_quote_columns(column_alias, True) - ) - else: - projection_columns.append( - self._escape_and_quote_columns(column_name) - ) - # group by columns + column_has_agg = self._process_column( + column, + select, + summarize_columns, + extend_columns, + projection_columns, + ) + has_aggregates = has_aggregates or column_has_agg + + # Process group by columns by_columns = self._group_by(group_by_cols) - if has_aggregates or bool( - by_columns - ): # Summarize can happen with or without aggregate being created + + # Build statements + if has_aggregates or bool(by_columns): summarize_statement = f"| summarize {', '.join(summarize_columns)} " if by_columns: summarize_statement = ( f"{summarize_statement} by {', '.join(by_columns)}" ) + if extend_columns: extend_statement = f"| extend {', '.join(sorted(extend_columns))}" + project_statement = ( f"| project {', '.join(projection_columns)}" if projection_columns else "" ) + + # Process order by unwrapped_order_by = self._get_order_by(order_by_cols) sort_statement = ( f"| order by {', '.join(unwrapped_order_by)}" if unwrapped_order_by else "" ) + return { "extend": extend_statement, "summarize": summarize_statement, @@ -351,20 +399,27 @@ def replacer(match): def _escape_and_quote_columns(name: str | None, is_alias=False) -> str: if name is None: return "" - if ( + + result = None + + # Special handling for COUNT(*) as a column name + if KustoKqlCompiler._is_count_star_column(name): + if is_alias: + # When it's an alias, we need to quote it + result = '["COUNT(*)"]' + else: + # When it's a function, we need to convert it to count() + result = "count()" + elif ( KustoKqlCompiler._is_kql_function(name) or KustoKqlCompiler._is_number_literal(name) ) and not is_alias: - return name - if name.startswith('"') and name.endswith('"'): - name = name[1:-1] - # First, check if the name is already wrapped in ["ColumnName"] (escaped format) - if name.startswith('["') and name.endswith('"]'): - return name # Return as is if already properly escaped - # Remove surrounding spaces - # Handle mathematical operations (wrap only the column part before operators) - # Find the position of the first operator or space that separates the column name - if not is_alias: + result = name + elif name.startswith('["') and name.endswith('"]'): + # Return as is if already properly escaped + result = name + elif not is_alias: + # Handle mathematical operations (wrap only the column part before operators) for operator in ["/", "+", "-", "*"]: if operator in name: # Split the name at the first operator and wrap the left part @@ -374,10 +429,20 @@ def _escape_and_quote_columns(name: str | None, is_alias=False) -> str: if col_part.startswith('"') and col_part.endswith('"'): col_part = col_part[1:-1].strip() col_part = col_part.replace('"', '\\"') - return f'["{col_part}"] {operator} {parts[1].strip()}' # Wrap the column part - # No operators found, just wrap the entire name - name = name.replace('"', '\\"') - return f'["{name}"]' + result = f'["{col_part}"] {operator} {parts[1].strip()}' # Wrap the column part + break + + # If no special case was matched, apply default formatting + if result is None: + # Handle quoted names + if name.startswith('"') and name.endswith('"'): + name = name[1:-1] + + # No operators found, just wrap the entire name + name = name.replace('"', '\\"') + result = f'["{name}"]' + + return result @staticmethod def _sql_to_kql_where(where_clause: str) -> str: @@ -565,12 +630,32 @@ def _extract_let_statements(clause) -> tuple[str, list[str]]: @staticmethod def _extract_column_name_and_alias(column: Column) -> tuple[str, str | None]: if hasattr(column, "element"): + column_name = str(column.element) + column_alias = str(column.name) + + # Special handling for COUNT(*) as a column name + if KustoKqlCompiler._is_count_star_column(column_name): + return "COUNT(*)", column_alias + return KustoKqlCompiler._convert_quoted_columns( - str(column.element) - ), KustoKqlCompiler._convert_quoted_columns(column.name) + column_name + ), KustoKqlCompiler._convert_quoted_columns(column_alias) if hasattr(column, "name"): - return KustoKqlCompiler._convert_quoted_columns(str(column.name)), None - return KustoKqlCompiler._convert_quoted_columns(str(column)), None + column_name = str(column.name) + + # Special handling for COUNT(*) as a column name + if KustoKqlCompiler._is_count_star_column(column_name): + return "COUNT(*)", None + + return KustoKqlCompiler._convert_quoted_columns(column_name), None + + column_str = str(column) + + # Special handling for COUNT(*) as a column name + if KustoKqlCompiler._is_count_star_column(column_str): + return "COUNT(*)", None + + return KustoKqlCompiler._convert_quoted_columns(column_str), None @staticmethod def _build_column_projection( diff --git a/tests/unit/test_dialect_kql.py b/tests/unit/test_dialect_kql.py index 3ea3736..48d8336 100644 --- a/tests/unit/test_dialect_kql.py +++ b/tests/unit/test_dialect_kql.py @@ -372,6 +372,12 @@ def test_escape_and_quote_columns(): KustoKqlCompiler._escape_and_quote_columns("EventInfo_Time / time(1d)") == '["EventInfo_Time"] / time(1d)' ) + # Test COUNT(*) handling + assert KustoKqlCompiler._escape_and_quote_columns("COUNT(*)") == "count()" + assert ( + KustoKqlCompiler._escape_and_quote_columns("COUNT(*)", is_alias=True) + == '["COUNT(*)"]' + ) def test_use_table(): @@ -436,6 +442,37 @@ def test_select_count(): assert query_compiled == query_expected +def test_drill_by_with_count_star(): + """Test that simulates the drill-by functionality with COUNT(*) as a column name.""" + # This test simulates what happens in Superset's drill-by functionality + # where COUNT(*) is used as a column name + kql_query = "logs" + + # In drill-by, Superset might use COUNT(*) directly as a column name + column_count = column("COUNT(*)").label("count_value") + + query = ( + select([column_count]) + .select_from(TextAsFrom(text(kql_query), ["*"]).alias("inner_qry")) + .limit(5) + ) + + query_compiled = str( + query.compile(engine, compile_kwargs={"literal_binds": True}) + ).replace("\n", "") + + # The expected result should properly handle COUNT(*) as a column name + query_expected = ( + 'let inner_qry = (["logs"]);' + "inner_qry" + '| extend ["count_value"] = count()' + '| project ["count_value"]' + "| take 5" + ) + + assert query_compiled == query_expected + + def test_select_with_let(): kql_query = "let x = 5; let y = 3; MyTable | where Field1 == x and Field2 == y" query = (