diff --git a/sqlalchemy_kusto/dialect_kql.py b/sqlalchemy_kusto/dialect_kql.py index 5be6be3..dac4054 100644 --- a/sqlalchemy_kusto/dialect_kql.py +++ b/sqlalchemy_kusto/dialect_kql.py @@ -107,7 +107,7 @@ def visit_select( from_object = select_stmt.get_final_froms()[0] if hasattr(from_object, "element"): query = self._get_most_inner_element(from_object.element) - (main, lets) = self._extract_let_statements(query.text) + main, lets = self._extract_let_statements(query.text) compiled_query_lines.extend(lets) compiled_query_lines.append( f"let {from_object.name} = ({self._convert_schema_in_statement(main)});" @@ -142,9 +142,15 @@ def visit_select( ) compiled_query_lines.append(f"| where {converted_where_clause}") + # Add summarize first if it exists + if "summarize" in projections_parts_dict: + compiled_query_lines.append(projections_parts_dict.pop("summarize")) + + # Then add extend after summarize if "extend" in projections_parts_dict: compiled_query_lines.append(projections_parts_dict.pop("extend")) + # Add remaining parts (project, sort) for statement_part in projections_parts_dict.values(): if statement_part: compiled_query_lines.append(statement_part) diff --git a/tests/unit/test_dialect_kql.py b/tests/unit/test_dialect_kql.py index 04b8fc0..8791486 100644 --- a/tests/unit/test_dialect_kql.py +++ b/tests/unit/test_dialect_kql.py @@ -175,9 +175,10 @@ def test_group_by_text(): ).replace("\n", "") # raw query text from query query_expected = ( - '["ActiveUsersLastMonth"]| extend ["ActiveUserMetric"] = ["ActiveUsers"], ' - '["EventInfo_Time"] = ["EventInfo_Time"] / time(1d)' + '["ActiveUsersLastMonth"]' '| summarize by ["EventInfo_Time"] / time(1d)' + '| extend ["ActiveUserMetric"] = ["ActiveUsers"], ' + '["EventInfo_Time"] = ["EventInfo_Time"] / time(1d)' '| project ["EventInfo_Time"], ["ActiveUserMetric"]' '| order by ["ActiveUserMetric"] desc' ) @@ -224,20 +225,19 @@ def test_group_by_text_vaccine_dataset(): query.compile(engine, compile_kwargs={"literal_binds": True}) ).replace("\n", "") query_expected = ( - 'database("superset").["CovidVaccineData"]| ' - 'extend ["country_name"] = ["country_name"]| ' - 'summarize by ["country_name"]| ' - 'project ["country_name"]| order by ["country_name"] asc' + 'database("superset").["CovidVaccineData"]' + '| summarize by ["country_name"]' + '| extend ["country_name"] = ["country_name"]' + '| project ["country_name"]' + '| order by ["country_name"] asc' ) assert query_compiled == query_expected def test_is_kql_function(): - assert KustoKqlCompiler._is_kql_function( - """case(Size <= 3, "Small", + assert KustoKqlCompiler._is_kql_function("""case(Size <= 3, "Small", Size <= 10, "Medium", - "Large")""" - ) + "Large")""") assert KustoKqlCompiler._is_kql_function("""bin(time(16d), 7d)""") assert KustoKqlCompiler._is_kql_function( """iff((EventType in ("Heavy Rain", "Flash Flood", "Flood")), "Rain event", "Not rain event")""" @@ -328,8 +328,8 @@ def test_distinct_count_by_text(): # raw query text from query query_expected = ( '["ActiveUsersLastMonth"]' - '| extend ["EventInfo_Time"] = ["EventInfo_Time"] / time(1d)' '| summarize ["DistinctUsers"] = dcount(["ActiveUsers"]) by ["EventInfo_Time"] / time(1d)' + '| extend ["EventInfo_Time"] = ["EventInfo_Time"] / time(1d)' '| project ["EventInfo_Time"], ["DistinctUsers"]' '| order by ["ActiveUserMetric"] desc' ) @@ -354,8 +354,8 @@ def test_distinct_count_alt_by_text(): # raw query text from query query_expected = ( '["ActiveUsersLastMonth"]' - '| extend ["EventInfo_Time"] = ["EventInfo_Time"] / time(1d)' '| summarize ["DistinctUsers"] = dcount(["ActiveUsers"]) by ["EventInfo_Time"] / time(1d)' + '| extend ["EventInfo_Time"] = ["EventInfo_Time"] / time(1d)' '| project ["EventInfo_Time"], ["DistinctUsers"]' '| order by ["ActiveUserMetric"] desc' ) @@ -549,6 +549,28 @@ def test_match_aggregates(column_name: str, expected_aggregate: str): assert kql_agg is None +def test_calculated_measure_with_adhoc_measure_and_constant(): + """Test calculated measure with an ad hoc measure and a constant. + + Measure 1 = count(*), Measure 2 = "Measure 1" * 2 + Measure 2 should compile to ["Measure 1"] * 2 + The extend clause must come after summarize for this to work. + """ + measure_1 = literal_column("count(*)").label("Measure 1") + measure_2 = literal_column('"Measure 1" * 2').label("Measure 2") + query = select([measure_1, measure_2]).select_from(text("SalesData")) + query_compiled = str( + query.compile(engine, compile_kwargs={"literal_binds": True}) + ).replace("\n", "") + query_expected = ( + '["SalesData"]' + '| summarize ["Measure 1"] = count() ' + '| extend ["Measure 2"] = ["Measure 1"] * 2' + '| project ["Measure 1"], ["Measure 2"]' + ) + assert query_compiled == query_expected + + @pytest.mark.parametrize( ("query_table_name", "expected_table_name"), [