diff --git a/application_sdk/activities/metadata_extraction/sql.py b/application_sdk/activities/metadata_extraction/sql.py index dda62d193..9ab39092f 100644 --- a/application_sdk/activities/metadata_extraction/sql.py +++ b/application_sdk/activities/metadata_extraction/sql.py @@ -850,7 +850,7 @@ async def transform_data( dataframe=dataframe, **workflow_args ) await transformed_output.write_daft_dataframe(transform_metadata) - return await transformed_output.get_statistics() + return await transformed_output.get_statistics(typename=typename) @activity.defn @auto_heartbeater diff --git a/application_sdk/outputs/__init__.py b/application_sdk/outputs/__init__.py index 4db007c60..c26256577 100644 --- a/application_sdk/outputs/__init__.py +++ b/application_sdk/outputs/__init__.py @@ -34,6 +34,7 @@ logger = get_logger(__name__) activity.logger = logger + if TYPE_CHECKING: import daft # type: ignore import pandas as pd @@ -330,7 +331,7 @@ async def get_statistics( Exception: If there's an error writing the statistics """ try: - statistics = await self.write_statistics() + statistics = await self.write_statistics(typename) if not statistics: raise ValueError("No statistics data available") statistics = ActivityStatistics.model_validate(statistics) @@ -390,7 +391,9 @@ async def _flush_buffer(self, chunk: "pd.DataFrame", chunk_part: int): logger.error(f"Error flushing buffer to files: {str(e)}") raise e - async def write_statistics(self) -> Optional[Dict[str, Any]]: + async def write_statistics( + self, typename: Optional[str] = None + ) -> Optional[Dict[str, Any]]: """Write statistics about the output to a JSON file. This method writes statistics including total record count and chunk count @@ -407,10 +410,28 @@ async def write_statistics(self) -> Optional[Dict[str, Any]]: "partitions": self.partitions, } - # Write the statistics to a json file - output_file_name = f"{self.output_path}/statistics.json.ignore" - with open(output_file_name, "w") as f: - f.write(orjson.dumps(statistics).decode("utf-8")) + # Ensure typename is included in the statistics payload (if provided) + if typename: + statistics["typename"] = typename + + # Write the statistics to a json file inside a dedicated statistics/ folder + statistics_dir = os.path.join(self.output_path, "statistics") + os.makedirs(statistics_dir, exist_ok=True) + output_file_name = os.path.join(statistics_dir, "statistics.json.ignore") + # If chunk_start is provided, include it in the statistics filename + try: + cs = getattr(self, "chunk_start", None) + if cs is not None: + output_file_name = os.path.join( + statistics_dir, f"statistics-chunk-{cs}.json.ignore" + ) + except Exception: + # If accessing chunk_start fails, fallback to default filename + pass + + # Write the statistics dictionary to the JSON file + with open(output_file_name, "wb") as f: + f.write(orjson.dumps(statistics)) destination_file_path = get_object_store_prefix(output_file_name) # Push the file to the object store @@ -418,6 +439,7 @@ async def write_statistics(self) -> Optional[Dict[str, Any]]: source=output_file_name, destination=destination_file_path, ) + return statistics except Exception as e: logger.error(f"Error writing statistics: {str(e)}") diff --git a/tests/unit/outputs/test_output.py b/tests/unit/outputs/test_output.py index 23468e75a..bea3fa4ef 100644 --- a/tests/unit/outputs/test_output.py +++ b/tests/unit/outputs/test_output.py @@ -1,5 +1,6 @@ """Unit tests for output interface.""" +import os from typing import Any from unittest.mock import AsyncMock, mock_open, patch @@ -116,13 +117,15 @@ async def test_write_statistics_success(self): self.output.chunk_count = 5 self.output.partitions = [1, 2, 1, 2, 1] - # Mock the open function, orjson.dumps, and object store upload + # Mock the open function, orjson.dumps, os.makedirs, and object store upload with patch("builtins.open", mock_open()) as mock_file, patch( "orjson.dumps", return_value=b'{"total_record_count": 100, "chunk_count": 5, "partitions": [1,2,1,2,1]}', ) as mock_orjson, patch( + "application_sdk.outputs.os.makedirs", + ) as mock_makedirs, patch( "application_sdk.outputs.get_object_store_prefix", - return_value="path/statistics.json.ignore", + return_value="path/statistics/statistics.json.ignore", ), patch( "application_sdk.services.objectstore.ObjectStore.upload_file", new_callable=AsyncMock, @@ -136,7 +139,12 @@ async def test_write_statistics_success(self): "chunk_count": 5, # This is len(self.partitions) which is 5 "partitions": [1, 2, 1, 2, 1], } - mock_file.assert_called_once_with("/test/path/statistics.json.ignore", "w") + expected_stats_dir = os.path.join("/test/path", "statistics") + mock_makedirs.assert_called_once_with(expected_stats_dir, exist_ok=True) + expected_file_path = os.path.join( + expected_stats_dir, "statistics.json.ignore" + ) + mock_file.assert_called_once_with(expected_file_path, "wb") mock_orjson.assert_called_once_with( { "total_record_count": 100, @@ -147,8 +155,10 @@ async def test_write_statistics_success(self): # Verify the upload call mock_push.assert_awaited_once() upload_kwargs = mock_push.await_args.kwargs # type: ignore[attr-defined] - assert upload_kwargs["source"] == "/test/path/statistics.json.ignore" - assert upload_kwargs["destination"] == "path/statistics.json.ignore" + assert upload_kwargs["source"] == expected_file_path + assert ( + upload_kwargs["destination"] == "path/statistics/statistics.json.ignore" + ) @pytest.mark.asyncio async def test_write_statistics_error(self):