diff --git a/python/tempo/__init__.py b/python/tempo/__init__.py
index da4a6c12..70282534 100644
--- a/python/tempo/__init__.py
+++ b/python/tempo/__init__.py
@@ -1,2 +1,3 @@
+from tempo.resample_result import ResampledTSDF  # noqa: F401
 from tempo.tsdf import TSDF  # noqa: F401
 from tempo.utils import display  # noqa: F401
diff --git a/python/tempo/resample.py b/python/tempo/resample.py
index f2fb4802..691a3ab8 100644
--- a/python/tempo/resample.py
+++ b/python/tempo/resample.py
@@ -2,6 +2,7 @@
 
 import warnings
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     List,
@@ -10,6 +11,9 @@
     Union,
 )
 
+if TYPE_CHECKING:
+    from tempo.resample_result import ResampledTSDF
+
 import pyspark.sql.functions as sfn
 from pyspark.sql import DataFrame
 
@@ -419,7 +423,7 @@ def resample(
     prefix: Optional[str] = None,
     fill: Optional[bool] = None,
     perform_checks: bool = True,
-) -> t_tsdf.TSDF:
+) -> "ResampledTSDF":
     """
     function to upsample based on frequency and aggregate function similar to pandas
 
@@ -446,13 +450,13 @@
     enriched_df: DataFrame = aggregate(tsdf, freq, func, metricCols, prefix, fill)
 
-    # Import TSDF here to avoid circular import
+    # Import TSDF and ResampledTSDF here to avoid circular import
+    from tempo.resample_result import ResampledTSDF
     from tempo.tsdf import TSDF
 
-    return TSDF(
+    plain_tsdf = TSDF(
         enriched_df,
         ts_col=tsdf.ts_col,
         series_ids=tsdf.series_ids,
-        resample_freq=freq,
-        resample_func=func,
     )
+    return ResampledTSDF(plain_tsdf, resample_freq=freq, resample_func=func)
diff --git a/python/tempo/resample_result.py b/python/tempo/resample_result.py
new file mode 100644
index 00000000..b79155f2
--- /dev/null
+++ b/python/tempo/resample_result.py
@@ -0,0 +1,150 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Callable, List, Optional, Union
+
+from pyspark.sql import DataFrame
+
+from tempo.tsschema import TSSchema
+
+if TYPE_CHECKING:
+    from tempo.tsdf import TSDF
+
+
+class ResampledTSDF:
+    """
+    Restricted wrapper around a TSDF that has been resampled.
+
+    Like Spark's GroupedData, this object only exposes operations that are
+    valid after a resample: interpolate() and as_tsdf(). Arbitrary DataFrame
+    transformations (filter, withColumn, etc.) are intentionally not exposed,
+    to prevent silent invalidation of the resample context.
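+
+    Example (illustrative, given an existing TSDF ``tsdf``)::
+
+        resampled = tsdf.resample(freq="30 min", func="mean")
+        filled = resampled.interpolate(method="ffill")  # back to a plain TSDF
+        raw = resampled.as_tsdf()  # escape hatch: drops the resample context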
+    """
+
+    def __init__(
+        self,
+        tsdf: TSDF,
+        resample_freq: str,
+        resample_func: Union[Callable, str],
+    ) -> None:
+        self._tsdf = tsdf
+        self._resample_freq = resample_freq
+        self._resample_func = resample_func
+
+    # ------------------------------------------------------------------
+    # Read-only properties
+    # ------------------------------------------------------------------
+
+    @property
+    def df(self) -> DataFrame:
+        """The underlying Spark DataFrame."""
+        return self._tsdf.df
+
+    @property
+    def ts_col(self) -> str:
+        return self._tsdf.ts_col
+
+    @property
+    def series_ids(self) -> list:
+        return self._tsdf.series_ids
+
+    @property
+    def ts_schema(self) -> TSSchema:
+        return self._tsdf.ts_schema
+
+    @property
+    def columns(self) -> list:
+        return self._tsdf.columns
+
+    @property
+    def resample_freq(self) -> str:
+        return self._resample_freq
+
+    @property
+    def resample_func(self) -> Union[Callable, str]:
+        return self._resample_func
+
+    # ------------------------------------------------------------------
+    # Valid post-resample operations
+    # ------------------------------------------------------------------
+
+    def interpolate(
+        self,
+        method: str,
+        target_cols: Optional[List[str]] = None,
+        show_interpolated: bool = False,
+    ) -> TSDF:
+        """
+        Interpolate missing values produced by the resample step.
+
+        :param method: interpolation method — "linear", "zero", "null", "bfill", or "ffill"
+        :param target_cols: columns to interpolate (default: all non-key columns)
+        :param show_interpolated: if True, add a column indicating interpolated rows
+        :return: a plain TSDF with interpolated data
+        """
+        import logging
+
+        import pandas as pd
+
+        from tempo.interpol import backward_fill, forward_fill, zero_fill
+        from tempo.interpol import interpolate as interpol_func
+
+        logger = logging.getLogger(__name__)
+
+        tsdf = self._tsdf
+
+        # Resolve target columns
+        if target_cols is None:
+            prohibited_cols: List[str] = tsdf.series_ids + [tsdf.ts_col]
+            target_cols = [col for col in tsdf.df.columns if col not in prohibited_cols]
+
+        # Map method name to interpolation function
+        fn: Union[str, Callable[[pd.Series], pd.Series]]
+        if method == "linear":
+            fn = "linear"
+        elif method == "null":
+            return tsdf
+        elif method == "zero":
+            fn = zero_fill
+        elif method == "bfill":
+            fn = backward_fill
+        elif method == "ffill":
+            fn = forward_fill
+        else:
+            fn = method
+
+        interpolated_tsdf = interpol_func(
+            tsdf=tsdf,
+            cols=target_cols,
+            fn=fn,
+            leading_margin=2,
+            lagging_margin=2,
+        )
+
+        if show_interpolated:
+            logger.warning(
+                "show_interpolated=True is not yet implemented in the refactored version"
+            )
+
+        return interpolated_tsdf
+
+    def as_tsdf(self) -> TSDF:
+        """Return the underlying TSDF without resample metadata (explicit escape hatch)."""
+        return self._tsdf
+
+    def show(self, n: int = 20, truncate: bool = True) -> None:
+        """Delegates to the underlying DataFrame's show()."""
+        self._tsdf.df.show(n, truncate)
+
+    def __repr__(self) -> str:
+        return (
+            f"ResampledTSDF(freq={self._resample_freq!r}, "
+            f"func={self._resample_func!r}, "
+            f"ts_col={self.ts_col!r}, "
+            f"series_ids={self.series_ids!r})"
+        )
diff --git a/python/tempo/stats.py b/python/tempo/stats.py
index 9db26951..1a9bb464 100644
--- a/python/tempo/stats.py
+++ b/python/tempo/stats.py
@@ -294,16 +294,16 @@ def calc_bars(
     # - min/max compute column-wise (may mix values from different rows)
     resample_open = tsdf.resample(
         freq=freq, func="floor", metricCols=metric_cols, prefix="open", fill=fill
-    )
+    ).as_tsdf()
     resample_low = tsdf.resample(
         freq=freq, func="min", metricCols=metric_cols, prefix="low", fill=fill
-    )
+    ).as_tsdf()
     resample_high = tsdf.resample(
         freq=freq, func="max", metricCols=metric_cols, prefix="high", fill=fill
-    )
+    ).as_tsdf()
     resample_close = tsdf.resample(
         freq=freq, func="ceil", metricCols=metric_cols, prefix="close", fill=fill
-    )
+    ).as_tsdf()
 
     join_cols = resample_open.series_ids + [resample_open.ts_col]
     bars = (
diff --git a/python/tempo/tsdf.py b/python/tempo/tsdf.py
index 3339dcd4..44ed81e0 100644
--- a/python/tempo/tsdf.py
+++ b/python/tempo/tsdf.py
@@ -40,6 +40,7 @@
     is_time_format,
     sub_seconds_precision_digits,
 )
+from tempo.resample_result import ResampledTSDF
 from tempo.typing import ColumnOrName, PandasGroupedMapFunction, PandasMapIterFunction
 
 logger = logging.getLogger(__name__)
@@ -124,8 +125,6 @@ def __init__(
         ts_schema: Optional[TSSchema] = None,
         ts_col: Optional[str] = None,
         series_ids: Optional[Collection[str]] = None,
-        resample_freq: Optional[str] = None,
-        resample_func: Optional[Union[Callable, str]] = None,
     ) -> None:
         self.df = df
         # construct schema if we don't already have one
@@ -137,10 +136,6 @@
         # validate that this schema works for this DataFrame
         self.ts_schema.validate(df.schema)
 
-        # Optional resample metadata (used when this TSDF is created from resample())
-        self.resample_freq = resample_freq
-        self.resample_func = resample_func
-
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}(df={self.df}, ts_schema={self.ts_schema})"
 
@@ -160,8 +155,6 @@ def __withTransformedDF(self, new_df: DataFrame) -> TSDF:
         return TSDF(
             new_df,
             ts_schema=copy.deepcopy(self.ts_schema),
-            resample_freq=self.resample_freq,
-            resample_func=self.resample_func,
         )
 
     def __withStandardizedColOrder(self) -> TSDF:
@@ -1466,7 +1459,7 @@ def resample(
         prefix: Optional[str] = None,
         fill: Optional[bool] = None,
         perform_checks: bool = True,
-    ) -> TSDF:
+    ) -> ResampledTSDF:
         """
         function to upsample based on frequency and aggregate function similar to pandas
         :param freq: frequency for upsample - valid inputs are "hr", "min", "sec" corresponding to hour, minute, or second
         :param func: function used to aggregate input
@@ -1475,7 +1468,7 @@
         :param prefix - supply a prefix for the newly sampled columns
         :param fill - Boolean - set to True if the desired output should contain filled in gaps (with 0s currently)
         :param perform_checks: calculate time horizon and warnings if True (default is True)
-        :return: TSDF object with sample data using aggregate function
+        :return: ResampledTSDF object with sample data using aggregate function
         """
         t_resample_utils.validateFuncExists(func)
 
@@ -1486,12 +1479,11 @@
         enriched_df: DataFrame = t_resample.aggregate(
             self, freq, func, metricCols, prefix, fill
         )
-        return TSDF(
+        plain_tsdf = TSDF(
             enriched_df,
             ts_schema=copy.deepcopy(self.ts_schema),
-            resample_freq=freq,
-            resample_func=func,
         )
+        return ResampledTSDF(plain_tsdf, resample_freq=freq, resample_func=func)
 
     def interpolate(
         self,
@@ -1518,70 +1510,33 @@
         :param perform_checks: calculate time horizon and warnings if True (default is True)
         :return: new TSDF object containing interpolated data
         """
-
-        # Set defaults for target columns, timestamp column and partition columns when not provided
         if freq is None:
             raise ValueError("freq must be provided")
         if func is None:
             raise ValueError("func must be provided")
+
+        # Resolve target columns using the same defaults as before
         if ts_col is None:
             ts_col = self.ts_col
         if partition_cols is None:
             partition_cols = self.series_ids
         if target_cols is None:
             prohibited_cols: List[str] = partition_cols + [ts_col]
-            # Don't filter by data type - allow all columns for PR-421 compatibility
             target_cols = [col for col in self.df.columns if col not in prohibited_cols]
 
-        # First resample the data
-        # Don't fill with zeros - let interpolation handle the nulls
-        resampled_tsdf = self.resample(
+        # Delegate through resample().interpolate()
+        return self.resample(
             freq=freq,
             func=func,
             metricCols=target_cols,
-            fill=False,  # Don't fill - interpolation will handle nulls
+            fill=False,
             perform_checks=perform_checks,
+        ).interpolate(
+            method=method,
+            target_cols=target_cols,
+            show_interpolated=show_interpolated,
         )
 
-        # Import interpolation function and pre-defined fill functions
-        from tempo.interpol import backward_fill, forward_fill, zero_fill
-        from tempo.interpol import interpolate as interpol_func
-
-        # Map method names to interpolation functions (no lambdas)
-        fn: Union[str, Callable[[pd.Series], pd.Series]]
-        if method == "linear":
-            fn = "linear"  # String method for pandas interpolation
-        elif method == "null":
-            # For null method, we don't fill - just return the resampled data
-            return resampled_tsdf
-        elif method == "zero":
-            fn = zero_fill
-        elif method == "bfill":
-            fn = backward_fill
-        elif method == "ffill":
-            fn = forward_fill
-        else:
-            # Assume it's a valid pandas interpolation method string
-            fn = method
-
-        # Apply interpolation to the resampled data
-        interpolated_tsdf = interpol_func(
-            tsdf=resampled_tsdf,
-            cols=target_cols,
-            fn=fn,
-            leading_margin=2,
-            lagging_margin=2,
-        )
-
-        if show_interpolated:
-            # Add a column indicating which rows were interpolated
-            # This would require tracking which rows had nulls before interpolation
-            logger.warning(
-                "show_interpolated=True is not yet implemented in the refactored version"
-            )
-
-        return interpolated_tsdf
-
 
 def calc_bars(
     tsdf,
     freq: str,
@@ -1590,16 +1545,16 @@
 ) -> TSDF:
     resample_open = tsdf.resample(
         freq=freq, func="floor", metricCols=metricCols, prefix="open", fill=fill
-    )
+    ).as_tsdf()
     resample_low = tsdf.resample(
         freq=freq, func="min", metricCols=metricCols, prefix="low", fill=fill
-    )
+    ).as_tsdf()
     resample_high = tsdf.resample(
         freq=freq, func="max", metricCols=metricCols, prefix="high", fill=fill
-    )
+    ).as_tsdf()
    resample_close = tsdf.resample(
         freq=freq, func="ceil", metricCols=metricCols, prefix="close", fill=fill
-    )
+    ).as_tsdf()
 
     join_cols = resample_open.series_ids + [resample_open.ts_col]
     bars = (
diff --git a/python/tests/base.py b/python/tests/base.py
index ab4cb583..558aaddb 100644
--- a/python/tests/base.py
+++ b/python/tests/base.py
@@ -17,6 +17,7 @@
 
 # helper functions
 
+
 def prefix_value(key: str, d: dict):
     # look for an exact match
     if key in d:
@@ -27,8 +28,10 @@
             return d[k]
     return None
 
+
 # test classes
 
+
 class TestDataFrameBuilder:
     """
     A class to hold metadata about a Spark DataFrame
@@ -172,7 +175,9 @@ def as_sdf(self) -> DataFrame:
             # handle nested columns
             if "." in ts_col:
                 col, field = ts_col.split(".")
-                convert_field_expr = self.to_timestamp_ntz_compat(sfn.col(col).getField(field))
+                convert_field_expr = self.to_timestamp_ntz_compat(
+                    sfn.col(col).getField(field)
+                )
                 df = df.withColumn(
                     col, sfn.col(col).withField(field, convert_field_expr)
                 )
@@ -200,7 +205,9 @@ def as_sdf(self) -> DataFrame:
                     col, sfn.col(col).withField(field, convert_field_expr)
                 )
             else:
-                df = df.withColumn(decimal_col, sfn.col(decimal_col).cast("decimal"))
+                df = df.withColumn(
+                    decimal_col, sfn.col(decimal_col).cast("decimal")
+                )
 
         return df
 
@@ -208,7 +215,15 @@ def _parse_complex_schema(self, schema_str: str):
         """
         Parse a schema string that may contain struct types
         """
-        from pyspark.sql.types import DoubleType, FloatType, IntegerType, LongType, StringType, StructField, StructType
+        from pyspark.sql.types import (
+            DoubleType,
+            FloatType,
+            IntegerType,
+            LongType,
+            StringType,
+            StructField,
+            StructType,
+        )
 
         # This is a simplified parser - in production you'd want a more robust solution
         # For now, we'll manually handle the specific case we need
@@ -231,10 +246,14 @@
                 if "event_ts string" in struct_def:
                     struct_fields = [
                         StructField("event_ts", StringType(), True),
-                        StructField("parsed_ts", StringType(), True),  # Will be converted to timestamp later
-                        StructField("double_ts", DoubleType(), True)
+                        StructField(
+                            "parsed_ts", StringType(), True
+                        ),  # Will be converted to timestamp later
+                        StructField("double_ts", DoubleType(), True),
                     ]
-                    fields.append(StructField(field_name, StructType(struct_fields), True))
+                    fields.append(
+                        StructField(field_name, StructType(struct_fields), True)
+                    )
                 else:
                     # Generic struct handling would go here
                     pass
@@ -325,11 +344,7 @@ def _cleanup_delta_warehouse(cls) -> None:
         """
         try:
             # List of paths to clean up
-            cleanup_paths = [
-                "spark-warehouse",
-                "metastore_db",
-                "derby.log"
-            ]
+            cleanup_paths = ["spark-warehouse", "metastore_db", "derby.log"]
 
             for path in cleanup_paths:
                 if os.path.exists(path):
@@ -394,21 +409,22 @@ def setUp(self) -> None:
         # Build file path from module path
         # For example: tests.join.test_strategies_integration -> join/test_strategies_integration
         # TODO: Remove this unittest-specific branch once pytest is fully adopted
-        if module_path.startswith('tests.'):
+        if module_path.startswith("tests."):
             # Running with unittest - remove 'tests.' prefix and replace dots with slashes
-            file_name = module_path[6:].replace('.', '/')
+            file_name = module_path[6:].replace(".", "/")
         else:
             # Running with pytest - module path might be truncated
             # Use inspect to get the actual file path
             import inspect
             import os
+
             test_file = inspect.getfile(self.__class__)
             # Extract path relative to tests/ directory
-            if '/tests/' in test_file:
+            if "/tests/" in test_file:
                 # Get everything after 'tests/'
-                relative_path = test_file.split('/tests/')[-1]
+                relative_path = test_file.split("/tests/")[-1]
                 # Remove .py extension
-                file_name = relative_path.replace('.py', '')
+                file_name = relative_path.replace(".py", "")
             else:
                 # Fallback to module name
                 file_name = module_path
@@ -455,13 +471,14 @@
         elif cwd == "python":
             dir_path = "./tests"
         elif cwd != "tests":
-            raise RuntimeError(f"Cannot locate test dir, running from dir {os.getcwd()}")
+            raise RuntimeError(
+                f"Cannot locate test dir, running from dir {os.getcwd()}"
+            )
 
         return os.path.abspath(os.path.join(dir_path, cls.TEST_DATA_FOLDER))
 
     @classmethod
-    def getTestDataFilePath(cls, test_file_name: str, extension: str = '.json') -> str:
-        return os.path.join(cls.getTestDataDirPath(),
-                            f"{test_file_name}{extension}")
+    def getTestDataFilePath(cls, test_file_name: str, extension: str = ".json") -> str:
+        return os.path.join(cls.getTestDataDirPath(), f"{test_file_name}{extension}")
 
     def __loadTestData(self, file_name: str) -> dict:
         """
@@ -479,7 +496,7 @@ def __loadTestData(self, file_name: str) -> dict:
 
         # proces the data file
         with open(test_data_filename) as f:
-            base_path = "file://"+ self.getTestDataDirPath() + "/"
+            base_path = "file://" + self.getTestDataDirPath() + "/"
             test_data = jsonref.load(f, base_uri=base_path)
 
         return test_data
@@ -578,4 +595,4 @@ def assertDataFrameEquality(
             ignore_column_order=ignore_column_order,
             ignore_nullable=ignore_nullable,
             ignore_metadata=True,
-        )
\ No newline at end of file
+        )
diff --git a/python/tests/interpol_tests.py b/python/tests/interpol_tests.py
index acb57414..83c1483d 100644
--- a/python/tests/interpol_tests.py
+++ b/python/tests/interpol_tests.py
@@ -8,91 +8,91 @@
 from tests.base import SparkTest
 
 
-@parameterized_class(("data_type", "interpol_cols"),[
-    ('simple_ts_idx', ["open", "close"] ),
-    ('simple_ts_no_series', ["trade_pr"])
-])
+@parameterized_class(
+    ("data_type", "interpol_cols"),
+    [("simple_ts_idx", ["open", "close"]), ("simple_ts_no_series", ["trade_pr"])],
+)
 class InterpolationTests(SparkTest):
 
     def test_zero_fill(self):
         # load the initial & expected dataframes
-        init_tsdf: TSDF = self.get_test_function_df_builder(self.data_type,
-                                                            "init").as_tsdf()
-        expected_tsdf: TSDF = self.get_test_function_df_builder(self.data_type,
-                                                                "expected").as_tsdf()
+        init_tsdf: TSDF = self.get_test_function_df_builder(
+            self.data_type, "init"
+        ).as_tsdf()
+        expected_tsdf: TSDF = self.get_test_function_df_builder(
+            self.data_type, "expected"
+        ).as_tsdf()
 
         # interpolate
-        actual_tsdf: TSDF = interpolate(init_tsdf,
-                                        self.interpol_cols,
-                                        zero_fill,
-                                        0,
-                                        0)
+        actual_tsdf: TSDF = interpolate(init_tsdf, self.interpol_cols, zero_fill, 0, 0)
 
         actual_tsdf.withNaturalOrdering().show()
 
         # compare
-        self.assertDataFrameEquality(expected_tsdf.withNaturalOrdering(),
-                                     actual_tsdf.withNaturalOrdering())
+        self.assertDataFrameEquality(
+            expected_tsdf.withNaturalOrdering(), actual_tsdf.withNaturalOrdering()
+        )
 
     def test_linear(self):
         # load the initial & expected dataframes
-        init_tsdf: TSDF = self.get_test_function_df_builder(self.data_type,
"init").as_tsdf() - expected_tsdf: TSDF = self.get_test_function_df_builder(self.data_type, - "expected").as_tsdf() + init_tsdf: TSDF = self.get_test_function_df_builder( + self.data_type, "init" + ).as_tsdf() + expected_tsdf: TSDF = self.get_test_function_df_builder( + self.data_type, "expected" + ).as_tsdf() # interpolate # Note: Linear interpolation behavior: # - Values between known points are linearly interpolated # - Values at the end with no following point are forward-filled with the last known value # - This is the default pandas behavior (limit_direction='forward') - actual_tsdf: TSDF = interpolate(init_tsdf, - self.interpol_cols, - "linear", - 1, - 1) + actual_tsdf: TSDF = interpolate(init_tsdf, self.interpol_cols, "linear", 1, 1) actual_tsdf.withNaturalOrdering().show() # compare - self.assertDataFrameEquality(expected_tsdf.withNaturalOrdering(), - actual_tsdf.withNaturalOrdering()) + self.assertDataFrameEquality( + expected_tsdf.withNaturalOrdering(), actual_tsdf.withNaturalOrdering() + ) def test_forward_fill(self): # load the initial & expected dataframes - init_tsdf: TSDF = self.get_test_function_df_builder(self.data_type, - "init").as_tsdf() - expected_tsdf: TSDF = self.get_test_function_df_builder(self.data_type, - "expected").as_tsdf() + init_tsdf: TSDF = self.get_test_function_df_builder( + self.data_type, "init" + ).as_tsdf() + expected_tsdf: TSDF = self.get_test_function_df_builder( + self.data_type, "expected" + ).as_tsdf() # interpolate - actual_tsdf: TSDF = interpolate(init_tsdf, - self.interpol_cols, - forward_fill, - 1, - 0) + actual_tsdf: TSDF = interpolate( + init_tsdf, self.interpol_cols, forward_fill, 1, 0 + ) actual_tsdf.withNaturalOrdering().show() # compare - self.assertDataFrameEquality(expected_tsdf.withNaturalOrdering(), - actual_tsdf.withNaturalOrdering()) + self.assertDataFrameEquality( + expected_tsdf.withNaturalOrdering(), actual_tsdf.withNaturalOrdering() + ) def test_backward_fill(self): # load the initial & expected dataframes - init_tsdf: TSDF = self.get_test_function_df_builder(self.data_type, - "init").as_tsdf() - expected_tsdf: TSDF = self.get_test_function_df_builder(self.data_type, - "expected").as_tsdf() + init_tsdf: TSDF = self.get_test_function_df_builder( + self.data_type, "init" + ).as_tsdf() + expected_tsdf: TSDF = self.get_test_function_df_builder( + self.data_type, "expected" + ).as_tsdf() # interpolate - actual_tsdf: TSDF = interpolate(init_tsdf, - self.interpol_cols, - backward_fill, - 0, - 1) + actual_tsdf: TSDF = interpolate( + init_tsdf, self.interpol_cols, backward_fill, 0, 1 + ) actual_tsdf.withNaturalOrdering().show() # compare - self.assertDataFrameEquality(expected_tsdf.withNaturalOrdering(), - actual_tsdf.withNaturalOrdering()) + self.assertDataFrameEquality( + expected_tsdf.withNaturalOrdering(), actual_tsdf.withNaturalOrdering() + ) # Add tests for non-numeric columns from PR-421 @@ -108,11 +108,7 @@ def test_non_numeric_forward_fill(self): # Apply forward fill to all columns # Use leading_margin=1 to include previous values for forward fill result_tsdf = interpolate( - tsdf, - ["string_col", "bool_col", "int_col"], - "ffill", - 1, - 0 + tsdf, ["string_col", "bool_col", "int_col"], "ffill", 1, 0 ) result_df = result_tsdf.df.orderBy("event_ts").collect() @@ -138,11 +134,7 @@ def test_non_numeric_backward_fill(self): # Apply backward fill to all columns # Use lagging_margin=1 to include next values for backward fill result_tsdf = interpolate( - tsdf, - ["string_col", "bool_col", "int_col"], - "bfill", - 0, - 1 + tsdf, 
["string_col", "bool_col", "int_col"], "bfill", 0, 1 ) result_df = result_tsdf.df.orderBy("event_ts").collect() @@ -167,11 +159,7 @@ def test_non_numeric_null_fill(self): # Apply null fill result_tsdf = interpolate( - tsdf, - ["string_col", "bool_col", "int_col"], - "null", - 0, - 0 + tsdf, ["string_col", "bool_col", "int_col"], "null", 0, 0 ) result_df = result_tsdf.df.orderBy("event_ts").collect() @@ -226,11 +214,7 @@ def test_tsdf_interpolate_method(self): tsdf = self.get_test_function_df_builder("test_data").as_tsdf() # Test linear interpolation through TSDF method - result_tsdf = tsdf.interpolate( - method="linear", - freq="30 min", - func="mean" - ) + result_tsdf = tsdf.interpolate(method="linear", freq="30 min", func="mean") result_df = result_tsdf.df.orderBy("event_ts").collect() @@ -240,6 +224,26 @@ def test_tsdf_interpolate_method(self): self.assertAlmostEqual(result_df[3]["value_a"], 2.5) self.assertAlmostEqual(result_df[3]["value_b"], 25.0) + def test_resample_then_interpolate_chain(self): + """Verify tsdf.resample(freq, func).interpolate(method) works and returns TSDF""" + from tempo.resample_result import ResampledTSDF + + # Reuse existing test_tsdf_interpolate_method's test_data + tsdf = self.get_test_df_builder( + "TSDBInterpolationTests", "test_tsdf_interpolate_method", "test_data" + ).as_tsdf() + + # Chained pattern: resample returns ResampledTSDF, then interpolate returns TSDF + resampled = tsdf.resample(freq="30 min", func="mean") + self.assertIsInstance(resampled, ResampledTSDF) + + result_tsdf = resampled.interpolate(method="linear") + self.assertIsInstance(result_tsdf, TSDF) + self.assertNotIsInstance(result_tsdf, ResampledTSDF) + + # Verify the result has data + self.assertGreater(result_tsdf.df.count(), 0) + class InterpolHelperFunctionsTests(SparkTest): """Tests for standalone interpolation helper functions""" diff --git a/python/tests/intervals/core/boundaries_tests.py b/python/tests/intervals/core/boundaries_tests.py index f36929eb..b4e2e68d 100644 --- a/python/tests/intervals/core/boundaries_tests.py +++ b/python/tests/intervals/core/boundaries_tests.py @@ -30,7 +30,7 @@ def test_boundary_converter_for_string(self): converter = BoundaryConverter.for_type(sample) assert converter.original_type == str - assert converter.original_format == '%Y-%m-%d' + assert converter.original_format == "%Y-%m-%d" timestamp = converter.to_timestamp(sample) assert isinstance(timestamp, Timestamp) @@ -104,7 +104,9 @@ def test_boundary_converter_for_timestamp(self): def test_boundary_converter_unsupported_type(self): sample = [2023, 10, 25] - with pytest.raises(ValueError, match="Unsupported boundary type: "): + with pytest.raises( + ValueError, match="Unsupported boundary type: " + ): BoundaryConverter.for_type(sample) @@ -188,28 +190,28 @@ def from_timestamp(timestamp): return str(timestamp) return BoundaryValue( - _timestamp=Timestamp("2023-01-01"), - _converter=MockConverter() + _timestamp=Timestamp("2023-01-01"), _converter=MockConverter() ) @pytest.fixture def interval_boundaries(self): - return IntervalBoundaries.create( - start="2023-01-01", - end="2023-12-31" - ) + return IntervalBoundaries.create(start="2023-01-01", end="2023-12-31") def test_create_interval_boundaries(self, interval_boundaries): assert isinstance(interval_boundaries, IntervalBoundaries) assert interval_boundaries.start == "2023-01-01" assert interval_boundaries.end == "2023-12-31" - def test_interval_boundaries_internal_start(self, interval_boundaries, boundary_value_mock): + def 
+    def test_interval_boundaries_internal_start(
+        self, interval_boundaries, boundary_value_mock
+    ):
         start = interval_boundaries.internal_start
         assert isinstance(start, BoundaryValue)
         assert start.internal_value == Timestamp("2023-01-01")
 
-    def test_interval_boundaries_internal_end(self, interval_boundaries, boundary_value_mock):
+    def test_interval_boundaries_internal_end(
+        self, interval_boundaries, boundary_value_mock
+    ):
         end = interval_boundaries.internal_end
         assert isinstance(end, BoundaryValue)
         assert end.internal_value == Timestamp("2023-12-31")
@@ -217,35 +219,39 @@ def test_interval_boundaries_internal_end(self, interval_boundaries, boundary_va
     def test_boundary_value_equality(self, boundary_value_mock):
         other = BoundaryValue(
             _timestamp=Timestamp("2023-01-01"),
-            _converter=boundary_value_mock._converter
+            _converter=boundary_value_mock._converter,
         )
         assert boundary_value_mock == other
 
     def test_boundary_value_comparison_earlier_less_than(self, boundary_value_mock):
         earlier = BoundaryValue(
             _timestamp=Timestamp("2022-01-01"),
-            _converter=boundary_value_mock._converter
+            _converter=boundary_value_mock._converter,
         )
         assert earlier < boundary_value_mock
 
     def test_boundary_value_comparison_later_greater_than(self, boundary_value_mock):
         later = BoundaryValue(
             _timestamp=Timestamp("2024-01-01"),
-            _converter=boundary_value_mock._converter
+            _converter=boundary_value_mock._converter,
         )
         assert later > boundary_value_mock
 
-    def test_boundary_value_comparison_earlier_less_than_or_equal(self, boundary_value_mock):
+    def test_boundary_value_comparison_earlier_less_than_or_equal(
+        self, boundary_value_mock
+    ):
         earlier = BoundaryValue(
             _timestamp=Timestamp("2022-01-01"),
-            _converter=boundary_value_mock._converter
+            _converter=boundary_value_mock._converter,
         )
         assert earlier <= boundary_value_mock
 
-    def test_boundary_value_comparison_later_greater_than_or_equal(self, boundary_value_mock):
+    def test_boundary_value_comparison_later_greater_than_or_equal(
+        self, boundary_value_mock
+    ):
         later = BoundaryValue(
             _timestamp=Timestamp("2024-01-01"),
-            _converter=boundary_value_mock._converter
+            _converter=boundary_value_mock._converter,
         )
         assert later >= boundary_value_mock
 
@@ -254,11 +260,9 @@
 class TestInternalBoundaryAccessor:
 
     def test_get_boundaries(self):
         # Arrange
-        data = Series({
-            "start_time": "2023-01-01",
-            "end_time": "2023-01-02",
-            "value": 10
-        })
+        data = Series(
+            {"start_time": "2023-01-01", "end_time": "2023-01-02", "value": 10}
+        )
         accessor = InternalBoundaryAccessor("start_time", "end_time")
 
         # Act
@@ -271,18 +275,13 @@ def test_get_boundaries(self):
 
     def test_set_boundaries(self):
         # Arrange
-        data = Series({
-            "start_time": "2023-01-01",
-            "end_time": "2023-01-02",
-            "value": 10
-        })
+        data = Series(
+            {"start_time": "2023-01-01", "end_time": "2023-01-02", "value": 10}
+        )
         accessor = InternalBoundaryAccessor("start_time", "end_time")
 
         # Create new boundaries
-        new_boundaries = IntervalBoundaries.create(
-            start="2023-02-01",
-            end="2023-02-05"
-        )
+        new_boundaries = IntervalBoundaries.create(start="2023-02-01", end="2023-02-05")
 
         # Act
         updated_data = accessor.set_boundaries(data, new_boundaries)
@@ -298,17 +297,18 @@ def test_set_boundaries(self):
 
     def test_set_boundaries_different_field_names(self):
         # Arrange
-        data = Series({
-            "begin": "2023-01-01T00:00:00",
-            "finish": "2023-01-02T00:00:00",
-            "metric": 5
-        })
+        data = Series(
+            {
+                "begin": "2023-01-01T00:00:00",
+                "finish": "2023-01-02T00:00:00",
+                "metric": 5,
+            }
+        )
         accessor = InternalBoundaryAccessor("begin", "finish")
 
         # Create new boundaries
         new_boundaries = IntervalBoundaries.create(
-            start="2023-03-15T12:00:00",
-            end="2023-03-16T12:00:00"
+            start="2023-03-15T12:00:00", end="2023-03-16T12:00:00"
         )
 
         # Act
@@ -321,11 +321,13 @@ def test_set_boundaries_different_field_names(self):
 
     def test_get_boundaries_with_different_types(self):
         # Test with Timestamp objects
-        data = Series({
-            "start_ts": Timestamp("2023-01-01"),
-            "end_ts": Timestamp("2023-01-02"),
-            "value": 10
-        })
+        data = Series(
+            {
+                "start_ts": Timestamp("2023-01-01"),
+                "end_ts": Timestamp("2023-01-02"),
+                "value": 10,
+            }
+        )
         accessor = InternalBoundaryAccessor("start_ts", "end_ts")
 
         boundaries = accessor.get_boundaries(data)
@@ -334,11 +336,13 @@ def test_get_boundaries_with_different_types(self):
         assert boundaries.end == Timestamp("2023-01-02")
 
         # Test with epoch timestamps (integers)
-        epoch_data = Series({
-            "start_epoch": 1672531200,  # 2023-01-01 00:00:00 UTC
-            "end_epoch": 1672617600,  # 2023-01-02 00:00:00 UTC
-            "value": 10
-        })
+        epoch_data = Series(
+            {
+                "start_epoch": 1672531200,  # 2023-01-01 00:00:00 UTC
+                "end_epoch": 1672617600,  # 2023-01-02 00:00:00 UTC
+                "value": 10,
+            }
+        )
         epoch_accessor = InternalBoundaryAccessor("start_epoch", "end_epoch")
 
         epoch_boundaries = epoch_accessor.get_boundaries(epoch_data)
@@ -361,7 +365,9 @@ def test_negative_timestamp_with_string(self):
         """Test that string inputs resulting in negative timestamps raise ValueError"""
         # A date far in the past that would result in a negative timestamp
         sample = "1800-01-01"
-        converter = BoundaryConverter.for_type("2023-01-01")  # Create converter with valid format
+        converter = BoundaryConverter.for_type(
+            "2023-01-01"
+        )  # Create converter with valid format
 
         # Using converter directly to test the validation
         with pytest.raises(ValueError, match="Timestamps cannot be negative."):
@@ -371,7 +377,9 @@ def test_negative_timestamp_with_int(self):
         """Test that integer inputs resulting in negative timestamps raise ValueError"""
         # A negative epoch timestamp
         sample = -1000000  # Negative seconds since epoch
-        converter = BoundaryConverter.for_type(1698192000)  # Create converter with valid format
+        converter = BoundaryConverter.for_type(
+            1698192000
+        )  # Create converter with valid format
 
         with pytest.raises(ValueError, match="Timestamps cannot be negative."):
             converter.to_timestamp(sample)
@@ -380,7 +388,9 @@ def test_negative_timestamp_with_float(self):
         """Test that float inputs resulting in negative timestamps raise ValueError"""
         # A negative epoch timestamp as float
         sample = -1000000.5  # Negative seconds since epoch
-        converter = BoundaryConverter.for_type(1698192000.0)  # Create converter with valid format
+        converter = BoundaryConverter.for_type(
+            1698192000.0
+        )  # Create converter with valid format
 
         with pytest.raises(ValueError, match="Timestamps cannot be negative."):
             converter.to_timestamp(sample)
@@ -389,7 +399,9 @@ def test_negative_timestamp_with_datetime(self):
         """Test that datetime inputs resulting in negative timestamps raise ValueError"""
         # A date far in the past that would result in a negative timestamp
         sample = datetime(1800, 1, 1)
-        converter = BoundaryConverter.for_type(datetime(2023, 1, 1))  # Create converter with valid format
+        converter = BoundaryConverter.for_type(
+            datetime(2023, 1, 1)
+        )  # Create converter with valid format
 
         with pytest.raises(ValueError, match="Timestamps cannot be negative."):
             converter.to_timestamp(sample)
 
@@ -397,8 +409,12 @@ def test_negative_timestamp_with_pandas_timestamp(self):
"""Test that Timestamp inputs with negative values raise ValueError""" # Create a negative timestamp directly (may need to adjust based on how pandas handles this) - sample = Timestamp('1800-01-01') # A date that would result in a negative timestamp value - converter = BoundaryConverter.for_type(Timestamp('2023-01-01')) # Create converter with valid format + sample = Timestamp( + "1800-01-01" + ) # A date that would result in a negative timestamp value + converter = BoundaryConverter.for_type( + Timestamp("2023-01-01") + ) # Create converter with valid format with pytest.raises(ValueError, match="Timestamps cannot be negative."): converter.to_timestamp(sample) @@ -417,10 +433,7 @@ def test_interval_boundaries_creation_with_negative_start(self): # Try to create IntervalBoundaries with a negative start timestamp with pytest.raises(ValueError, match="Timestamps cannot be negative."): - IntervalBoundaries.create( - start="1800-01-01", - end="2023-01-01" - ) + IntervalBoundaries.create(start="1800-01-01", end="2023-01-01") def test_interval_boundaries_creation_with_negative_end(self): """Test that creating IntervalBoundaries with a negative end timestamp raises ValueError""" @@ -428,7 +441,4 @@ def test_interval_boundaries_creation_with_negative_end(self): # Try to create IntervalBoundaries with a negative end timestamp with pytest.raises(ValueError, match="Timestamps cannot be negative."): - IntervalBoundaries.create( - start="2023-01-01", - end="1800-01-01" - ) + IntervalBoundaries.create(start="2023-01-01", end="1800-01-01") diff --git a/python/tests/intervals/core/interval_tests.py b/python/tests/intervals/core/interval_tests.py index 42c1ffb7..bf54f1ae 100644 --- a/python/tests/intervals/core/interval_tests.py +++ b/python/tests/intervals/core/interval_tests.py @@ -4,17 +4,34 @@ from pandas import Series from tempo.intervals.core.boundaries import _BoundaryAccessor -from tempo.intervals.core.exceptions import InvalidDataTypeError, EmptyIntervalError, InvalidMetricColumnError, \ - InvalidSeriesColumnError +from tempo.intervals.core.exceptions import ( + InvalidDataTypeError, + EmptyIntervalError, + InvalidMetricColumnError, + InvalidSeriesColumnError, +) from tempo.intervals.core.interval import Interval class TestInterval: def test_create_interval(self): data = Series( - {"start": "2023-01-01", "end": "2023-01-02", "series1": "A", "series2": "B", "metric1": 10, "metric2": 15}) - interval = Interval.create(data, start_field="start", end_field="end", series_fields=["series1", "series2"], - metric_fields=["metric1", "metric2"]) + { + "start": "2023-01-01", + "end": "2023-01-02", + "series1": "A", + "series2": "B", + "metric1": 10, + "metric2": 15, + } + ) + interval = Interval.create( + data, + start_field="start", + end_field="end", + series_fields=["series1", "series2"], + metric_fields=["metric1", "metric2"], + ) assert interval.start == "2023-01-01" assert interval.start_field == "start" @@ -26,9 +43,22 @@ def test_create_interval(self): def test_create_interval_none_start_field(self): data = Series( - {"start": None, "end": "2023-01-02", "series1": "A", "series2": "B", "metric1": 10, "metric2": 15}) - interval = Interval.create(data, start_field="start", end_field="end", series_fields=["series1", "series2"], - metric_fields=["metric1", "metric2"]) + { + "start": None, + "end": "2023-01-02", + "series1": "A", + "series2": "B", + "metric1": 10, + "metric2": 15, + } + ) + interval = Interval.create( + data, + start_field="start", + end_field="end", + series_fields=["series1", "series2"], + 
+            metric_fields=["metric1", "metric2"],
+        )
 
         assert interval.start is None
         assert interval.start_field == "start"
@@ -40,9 +70,22 @@ def test_create_interval_none_start_field(self):
 
     def test_create_interval_none_end_field(self):
         data = Series(
-            {"start": "2023-01-01", "end": None, "series1": "A", "series2": "B", "metric1": 10, "metric2": 15})
-        interval = Interval.create(data, start_field="start", end_field="end", series_fields=["series1", "series2"],
-                                   metric_fields=["metric1", "metric2"])
+            {
+                "start": "2023-01-01",
+                "end": None,
+                "series1": "A",
+                "series2": "B",
+                "metric1": 10,
+                "metric2": 15,
+            }
+        )
+        interval = Interval.create(
+            data,
+            start_field="start",
+            end_field="end",
+            series_fields=["series1", "series2"],
+            metric_fields=["metric1", "metric2"],
+        )
 
         assert interval.start == "2023-01-01"
         assert interval.start_field == "start"
@@ -53,30 +96,61 @@ def test_create_interval_none_end_field(self):
         assert interval.metric_fields == ["metric1", "metric2"]
 
     def test_create_interval_series_fields_not_string(self):
-        data = Series({"start": None, "end": "2023-01-02", 1: "A", "metric1": 10, "metric2": 15})
+        data = Series(
+            {"start": None, "end": "2023-01-02", 1: "A", "metric1": 10, "metric2": 15}
+        )
 
-        with pytest.raises(InvalidSeriesColumnError, match="All series_fields must be strings"):
-            Interval.create(data, start_field="start", end_field="end", series_fields=[1],
-                            metric_fields=["metric1", "metric2"])
+        with pytest.raises(
+            InvalidSeriesColumnError, match="All series_fields must be strings"
+        ):
+            Interval.create(
+                data,
+                start_field="start",
+                end_field="end",
+                series_fields=[1],
+                metric_fields=["metric1", "metric2"],
+            )
 
     def test_create_interval_series_fields_not_sequence(self):
         data = Series({"start": "2023-01-01", "end": "2023-01-05", "series-id1": "ABC"})
 
-        with pytest.raises(InvalidSeriesColumnError, match=r"series_fields must be a sequence"):
-            Interval.create(data, start_field="start", end_field="end", series_fields="series-id1")
+        with pytest.raises(
+            InvalidSeriesColumnError, match=r"series_fields must be a sequence"
+        ):
+            Interval.create(
+                data, start_field="start", end_field="end", series_fields="series-id1"
+            )
 
     def test_create_interval_metrics_fields_not_string(self):
-        data = Series({"start": None, "end": "2023-01-02", "series": "A", 1: 10, "metric": 15})
+        data = Series(
+            {"start": None, "end": "2023-01-02", "series": "A", 1: 10, "metric": 15}
+        )
 
-        with pytest.raises(InvalidMetricColumnError, match="All metric_fields must be strings"):
-            Interval.create(data, start_field="start", end_field="end", series_fields=["series"],
-                            metric_fields=[1, "metric"])
+        with pytest.raises(
+            InvalidMetricColumnError, match="All metric_fields must be strings"
+        ):
+            Interval.create(
+                data,
+                start_field="start",
+                end_field="end",
+                series_fields=["series"],
+                metric_fields=[1, "metric"],
+            )
 
     def test_create_interval_metrics_fields_not_sequence(self):
-        data = Series({"start": None, "end": "2023-01-02", "series": "A", 1: 10, "metric2": 15})
+        data = Series(
+            {"start": None, "end": "2023-01-02", "series": "A", 1: 10, "metric2": 15}
+        )
 
-        with pytest.raises(InvalidMetricColumnError, match="metric_fields must be a sequence"):
-            Interval.create(data, start_field="start", end_field="end", series_fields=["series"],
-                            metric_fields="metric")
+        with pytest.raises(
+            InvalidMetricColumnError, match="metric_fields must be a sequence"
+        ):
+            Interval.create(
+                data,
+                start_field="start",
+                end_field="end",
+                series_fields=["series"],
+                metric_fields="metric",
+            )
 
     def test_invalid_data_type(self):
         data = {"start": "2023-01-01", "end": "2023-01-02"}  # Not a pandas Series
@@ -129,51 +203,132 @@ def test_overlaps_with(self):
 
     def test_validate_metric_alignment(self):
         data1 = Series(
-            {"start": "2023-01-01", "end": "2023-01-05", "series1": "A", "series2": "B", "metric1": 10, "metric2": 15})
-        interval1 = Interval.create(data1, start_field="start", end_field="end", metric_fields=["metric1", "metric2"])
+            {
+                "start": "2023-01-01",
+                "end": "2023-01-05",
+                "series1": "A",
+                "series2": "B",
+                "metric1": 10,
+                "metric2": 15,
+            }
+        )
+        interval1 = Interval.create(
+            data1,
+            start_field="start",
+            end_field="end",
+            metric_fields=["metric1", "metric2"],
+        )
 
-        data2 = Series({"start": "2023-01-02", "end": "2023-01-04", "metric1": 5, "metric2": 10})
-        interval2 = Interval.create(data2, start_field="start", end_field="end", metric_fields=["metric1", "metric2"])
+        data2 = Series(
+            {"start": "2023-01-02", "end": "2023-01-04", "metric1": 5, "metric2": 10}
+        )
+        interval2 = Interval.create(
+            data2,
+            start_field="start",
+            end_field="end",
+            metric_fields=["metric1", "metric2"],
+        )
 
         result = interval1.validate_metrics_alignment(interval2)
         assert result.is_valid
 
     def test_validate_metric_alignment_invalid(self):
         data1 = Series(
-            {"start": "2023-01-01", "end": "2023-01-05", "series1": "A", "series2": "B", "metric1": 10, "metric2": 15})
-        interval1 = Interval.create(data1, start_field="start", end_field="end", metric_fields=["metric1", "metric2"])
+            {
+                "start": "2023-01-01",
+                "end": "2023-01-05",
+                "series1": "A",
+                "series2": "B",
+                "metric1": 10,
+                "metric2": 15,
+            }
+        )
+        interval1 = Interval.create(
+            data1,
+            start_field="start",
+            end_field="end",
+            metric_fields=["metric1", "metric2"],
+        )
 
-        data2 = Series({"start": "2023-01-02", "end": "2023-01-04", "metric1": 5, "metric2": 10})
-        interval2 = Interval.create(data2, start_field="start", end_field="end", metric_fields=["metric1"])
+        data2 = Series(
+            {"start": "2023-01-02", "end": "2023-01-04", "metric1": 5, "metric2": 10}
+        )
+        interval2 = Interval.create(
+            data2, start_field="start", end_field="end", metric_fields=["metric1"]
+        )
 
-        expected_msg = re.escape("metric_fields don't match: ['metric1', 'metric2'] vs ['metric1']")
+        expected_msg = re.escape(
+            "metric_fields don't match: ['metric1', 'metric2'] vs ['metric1']"
+        )
         with pytest.raises(InvalidMetricColumnError, match=expected_msg):
             interval1.validate_metrics_alignment(interval2)
 
     def test_validate_series_alignment(self):
         data1 = Series(
-            {"start": "2023-01-01", "end": "2023-01-05", "series1": "A", "series2": "B", "metric1": 10, "metric2": 15})
-        interval1 = Interval.create(data1, start_field="start", end_field="end", series_fields=["series1", "series2"],
-                                    metric_fields=["metric1", "metric2"])
+            {
+                "start": "2023-01-01",
+                "end": "2023-01-05",
+                "series1": "A",
+                "series2": "B",
+                "metric1": 10,
+                "metric2": 15,
+            }
+        )
+        interval1 = Interval.create(
+            data1,
+            start_field="start",
+            end_field="end",
+            series_fields=["series1", "series2"],
+            metric_fields=["metric1", "metric2"],
+        )
 
-        data2 = Series({"start": "2023-01-02", "end": "2023-01-04", "metric1": 5, "metric2": 10})
-        interval2 = Interval.create(data2, start_field="start", end_field="end", series_fields=["series1", "series2"],
-                                    metric_fields=["metric1", "metric2"])
+        data2 = Series(
+            {"start": "2023-01-02", "end": "2023-01-04", "metric1": 5, "metric2": 10}
+        )
+        interval2 = Interval.create(
+            data2,
+            start_field="start",
+            end_field="end",
+            series_fields=["series1", "series2"],
+            metric_fields=["metric1", "metric2"],
+        )
 
         result = interval1.validate_series_alignment(interval2)
         assert result.is_valid
 
     def test_validate_series_alignment_invalid(self):
         data1 = Series(
-            {"start": "2023-01-01", "end": "2023-01-05", "series1": "A", "series2": "B", "metric1": 10, "metric2": 15})
-        interval1 = Interval.create(data1, start_field="start", end_field="end", series_fields=["series1"],
-                                    metric_fields=["metric1", "metric2"])
+            {
+                "start": "2023-01-01",
+                "end": "2023-01-05",
+                "series1": "A",
+                "series2": "B",
+                "metric1": 10,
+                "metric2": 15,
+            }
+        )
+        interval1 = Interval.create(
+            data1,
+            start_field="start",
+            end_field="end",
+            series_fields=["series1"],
+            metric_fields=["metric1", "metric2"],
+        )
 
-        data2 = Series({"start": "2023-01-02", "end": "2023-01-04", "metric1": 5, "metric2": 10})
-        interval2 = Interval.create(data2, start_field="start", end_field="end", series_fields=["series1", "series2"],
-                                    metric_fields=["metric1", "metric2"])
+        data2 = Series(
+            {"start": "2023-01-02", "end": "2023-01-04", "metric1": 5, "metric2": 10}
+        )
+        interval2 = Interval.create(
+            data2,
+            start_field="start",
+            end_field="end",
+            series_fields=["series1", "series2"],
+            metric_fields=["metric1", "metric2"],
+        )
 
-        expected_msg = re.escape("series_fields don't match: ['series1'] vs ['series1', 'series2']")
+        expected_msg = re.escape(
+            "series_fields don't match: ['series1'] vs ['series1', 'series2']"
+        )
         with pytest.raises(InvalidSeriesColumnError, match=expected_msg):
             interval1.validate_series_alignment(interval2)
 
@@ -187,7 +342,9 @@ def test_validate_not_point_in_time_point_in_time_interval(self):
         data = Series({"start_time": 1, "end_time": 1})
         boundary_accessor = _BoundaryAccessor("start_time", "end_time")
 
-        with pytest.raises(InvalidDataTypeError, match="Point-in-Time Intervals are not supported"):
+        with pytest.raises(
+            InvalidDataTypeError, match="Point-in-Time Intervals are not supported"
+        ):
             Interval._validate_not_point_in_time(data, boundary_accessor)
 
     def test_validate_not_point_in_time_missing_boundary_fields(self):
@@ -201,57 +358,85 @@ def test_validate_not_point_in_time_missing_boundary_fields(self):
 
 class TestStillValidLegacy:
     def test_update_interval_boundary_start(self):
-        interval = Interval.create(Series(
-            {"start": "2023-01-01T01:00:00", "end": "2023-01-01T02:00:00"}
-        ), "start", "end")
+        interval = Interval.create(
+            Series({"start": "2023-01-01T01:00:00", "end": "2023-01-01T02:00:00"}),
+            "start",
+            "end",
+        )
 
         updated = interval.update_start("2023-01-01T01:30:00")
 
         assert updated.data["start"] == "2023-01-01T01:30:00"
 
     def test_update_interval_boundary_end(self):
-        interval = Interval.create(Series(
-            {"start": "2023-01-01T01:00:00", "end": "2023-01-01T02:00:00"}
-        ), "start", "end")
+        interval = Interval.create(
+            Series({"start": "2023-01-01T01:00:00", "end": "2023-01-01T02:00:00"}),
+            "start",
+            "end",
+        )
 
         updated = interval.update_end("2023-01-01T02:30:00")
 
         assert updated.data["end"] == "2023-01-01T02:30:00"
 
     def test_update_interval_boundary_return_new_copy(self):
-        interval = Interval.create(Series(
-            {"start": "2023-01-01T01:00:00", "end": "2023-01-01T02:00:00"}
-        ), "start", "end")
+        interval = Interval.create(
+            Series({"start": "2023-01-01T01:00:00", "end": "2023-01-01T02:00:00"}),
+            "start",
+            "end",
+        )
 
         updated = interval.update_start("2023-01-01T01:30:00")
 
         assert id(interval) != id(updated)
         assert interval.data["start"] == "2023-01-01T01:00:00"
 
     def test_merge_metrics_with_list_metric_merge_true(self):
"01:00", "end": "02:00", "value": 10}), "start", "end", - metric_fields=["value"]) - other = Interval.create(Series({"start": "01:00", "end": "02:00", "value": 20}), "start", "end", - metric_fields=["value"]) + interval = Interval.create( + Series({"start": "01:00", "end": "02:00", "value": 10}), + "start", + "end", + metric_fields=["value"], + ) + other = Interval.create( + Series({"start": "01:00", "end": "02:00", "value": 20}), + "start", + "end", + metric_fields=["value"], + ) expected = Series({"start": "01:00", "end": "02:00", "value": 20}) merged = interval.merge_metrics(other) assert merged.equals(expected) def test_merge_metrics_with_string_metric_column(self): - interval = Interval.create(Series({"start": "01:00", "end": "02:00", "value": 10}), "start", "end", - metric_fields=["value"]) - other = Interval.create(Series({"start": "01:00", "end": "02:00", "value": 20}), "start", "end", - metric_fields=["value"]) + interval = Interval.create( + Series({"start": "01:00", "end": "02:00", "value": 10}), + "start", + "end", + metric_fields=["value"], + ) + other = Interval.create( + Series({"start": "01:00", "end": "02:00", "value": 20}), + "start", + "end", + metric_fields=["value"], + ) expected = Series({"start": "01:00", "end": "02:00", "value": 20}) merged = interval.merge_metrics(other) assert merged.equals(expected) def test_merge_metrics_with_string_metric_columns(self): - interval = Interval.create(Series({"start": "01:00", "end": "02:00", "value1": 10, "value2": 20}), "start", - "end", - metric_fields=["value1", "value2"]) - other = Interval.create(Series( - {"start": "01:00", "end": "02:00", "value1": 20, "value2": 30} - ), "start", "end", metric_fields=["value1", "value2"]) + interval = Interval.create( + Series({"start": "01:00", "end": "02:00", "value1": 10, "value2": 20}), + "start", + "end", + metric_fields=["value1", "value2"], + ) + other = Interval.create( + Series({"start": "01:00", "end": "02:00", "value1": 20, "value2": 30}), + "start", + "end", + metric_fields=["value1", "value2"], + ) expected = Series( {"start": "01:00", "end": "02:00", "value1": 20, "value2": 30} ) @@ -263,11 +448,17 @@ def test_merge_metrics_return_new_copy(self): interval = Interval.create( Series({"start": "01:00", "end": "02:00", "value": 10}), "start", - "end", [], ["value"]) + "end", + [], + ["value"], + ) other = Interval.create( Series({"start": "01:00", "end": "02:00", "value": 20}), "start", - "end", [], ["value"]) + "end", + [], + ["value"], + ) merged = interval.merge_metrics(other) assert id(interval) != id(merged) @@ -277,13 +468,15 @@ def test_merge_metrics_handle_nan_in_child(self): Series({"start": "01:00", "end": "02:00", "value": 10}), "start", "end", - [], ["value"] + [], + ["value"], ) other = Interval.create( Series({"start": "01:00", "end": "02:00", "value": float("nan")}), "start", "end", - [], ["value"] + [], + ["value"], ) merged = interval.merge_metrics(other) diff --git a/python/tests/intervals/core/intervals_df_tests.py b/python/tests/intervals/core/intervals_df_tests.py index 8202f085..4da301b5 100644 --- a/python/tests/intervals/core/intervals_df_tests.py +++ b/python/tests/intervals/core/intervals_df_tests.py @@ -232,9 +232,7 @@ def test_make_disjoint(self): idf_actual = idf_input.make_disjoint() - self.assertDataFrameEquality( - idf_expected, idf_actual, ignore_row_order=True - ) + self.assertDataFrameEquality(idf_expected, idf_actual, ignore_row_order=True) def test_make_disjoint_contains_interval_already_disjoint(self): idf_input = 
         idf_input = self.get_test_function_df_builder("init").as_idf()
@@ -242,9 +240,7 @@
         idf_expected = self.get_test_function_df_builder("expected").as_idf()
 
         idf_actual = idf_input.make_disjoint()
 
-        self.assertDataFrameEquality(
-            idf_expected, idf_actual, ignore_row_order=True
-        )
+        self.assertDataFrameEquality(idf_expected, idf_actual, ignore_row_order=True)
 
     def test_make_disjoint_contains_intervals_equal(self):
         idf_input = self.get_test_function_df_builder("init").as_idf()
@@ -252,9 +248,7 @@
         idf_expected = self.get_test_function_df_builder("expected").as_idf()
 
         idf_actual = idf_input.make_disjoint()
 
-        self.assertDataFrameEquality(
-            idf_expected, idf_actual, ignore_row_order=True
-        )
+        self.assertDataFrameEquality(idf_expected, idf_actual, ignore_row_order=True)
 
     def test_make_disjoint_intervals_same_start(self):
         idf_input = self.get_test_function_df_builder("init").as_idf()
@@ -262,9 +256,7 @@
         idf_expected = self.get_test_function_df_builder("expected").as_idf()
 
         idf_actual = idf_input.make_disjoint()
 
-        self.assertDataFrameEquality(
-            idf_expected, idf_actual, ignore_row_order=True
-        )
+        self.assertDataFrameEquality(idf_expected, idf_actual, ignore_row_order=True)
 
     def test_make_disjoint_intervals_same_end(self):
         idf_input = self.get_test_function_df_builder("init").as_idf()
@@ -272,9 +264,7 @@
         idf_expected = self.get_test_function_df_builder("expected").as_idf()
 
         idf_actual = idf_input.make_disjoint()
 
-        self.assertDataFrameEquality(
-            idf_expected, idf_actual, ignore_row_order=True
-        )
+        self.assertDataFrameEquality(idf_expected, idf_actual, ignore_row_order=True)
 
     def test_make_disjoint_multiple_series(self):
         idf_input = self.get_test_function_df_builder("init").as_idf()
@@ -282,9 +272,7 @@
         idf_expected = self.get_test_function_df_builder("expected").as_idf()
 
         idf_actual = idf_input.make_disjoint()
 
-        self.assertDataFrameEquality(
-            idf_expected, idf_actual, ignore_row_order=True
-        )
+        self.assertDataFrameEquality(idf_expected, idf_actual, ignore_row_order=True)
 
     def test_make_disjoint_single_metric(self):
         idf_input = self.get_test_function_df_builder("init").as_idf()
@@ -292,9 +280,7 @@
         idf_expected = self.get_test_function_df_builder("expected").as_idf()
 
         idf_actual = idf_input.make_disjoint()
 
-        self.assertDataFrameEquality(
-            idf_expected, idf_actual, ignore_row_order=True
-        )
+        self.assertDataFrameEquality(idf_expected, idf_actual, ignore_row_order=True)
 
     def test_make_disjoint_interval_is_subset(self):
         idf_input = self.get_test_function_df_builder("init").as_idf()
@@ -302,9 +288,7 @@
         idf_expected = self.get_test_function_df_builder("expected").as_idf()
 
         idf_actual = idf_input.make_disjoint()
 
-        self.assertDataFrameEquality(
-            idf_expected, idf_actual, ignore_row_order=True
-        )
+        self.assertDataFrameEquality(idf_expected, idf_actual, ignore_row_order=True)
 
     def test_union_other_idf(self):
         idf_input_1 = self.get_test_function_df_builder("init").as_idf()
@@ -400,6 +384,4 @@ def test_make_disjoint_issue_268(self):
         idf_actual = idf_input.make_disjoint()
         idf_actual.df.show(truncate=False)
 
-        self.assertDataFrameEquality(
-            idf_expected, idf_actual, ignore_row_order=True
-        )
+        self.assertDataFrameEquality(idf_expected, idf_actual, ignore_row_order=True)
diff --git a/python/tests/intervals/core/utils_tests.py b/python/tests/intervals/core/utils_tests.py
index 4d720e68..444100d3 100644
--- a/python/tests/intervals/core/utils_tests.py
+++ b/python/tests/intervals/core/utils_tests.py
@@ -19,18 +19,22 @@ def interval_data():
         "end_field": end_field,
         "series_fields": ["category"],
         "metric_fields": ["value"],
-        "intervals_data": DataFrame({
-            start_field: pd.to_datetime(['2023-01-01', '2023-01-05', '2023-01-10']),
-            end_field: pd.to_datetime(['2023-01-07', '2023-01-15', '2023-01-20']),
-            'category': ['A', 'B', 'A'],
-            'value': [10, 20, 30]
-        }),
-        "reference_interval_data": Series({
-            start_field: pd.to_datetime('2023-01-03'),
-            end_field: pd.to_datetime('2023-01-12'),
-            'category': 'A',
-            'value': 15
-        })
+        "intervals_data": DataFrame(
+            {
+                start_field: pd.to_datetime(["2023-01-01", "2023-01-05", "2023-01-10"]),
+                end_field: pd.to_datetime(["2023-01-07", "2023-01-15", "2023-01-20"]),
+                "category": ["A", "B", "A"],
+                "value": [10, 20, 30],
+            }
+        ),
+        "reference_interval_data": Series(
+            {
+                start_field: pd.to_datetime("2023-01-03"),
+                end_field: pd.to_datetime("2023-01-12"),
+                "category": "A",
+                "value": 15,
+            }
+        ),
     }
 
 
@@ -49,7 +53,7 @@ def reference_interval(interval_data):
         data["start_field"],
         data["end_field"],
         data["series_fields"],
-        data["metric_fields"]
+        data["metric_fields"],
     )
 
 
@@ -59,12 +63,14 @@ class TestIntervalUtils:
     def test_init(self, intervals_utils, interval_data):
         """Test initialization of IntervalsUtils."""
         assert isinstance(intervals_utils, IntervalsUtils)
-        pd.testing.assert_frame_equal(intervals_utils.intervals, interval_data["intervals_data"])
+        pd.testing.assert_frame_equal(
+            intervals_utils.intervals, interval_data["intervals_data"]
+        )
         assert intervals_utils.disjoint_set.empty
 
     def test_disjoint_set_property(self, intervals_utils):
         """Test disjoint_set property getter and setter."""
-        test_df = DataFrame({'test': [1, 2, 3]})
+        test_df = DataFrame({"test": [1, 2, 3]})
         intervals_utils.disjoint_set = test_df
 
         pd.testing.assert_frame_equal(intervals_utils.disjoint_set, test_df)
@@ -91,14 +97,16 @@ def test_calculate_all_overlaps_with_empty_reference_interval(self, interval_dat
         result = intervals_utils._calculate_all_overlaps(mock_interval)
         assert result.empty
 
-    def test_calculate_all_overlaps(self, intervals_utils, reference_interval, interval_data):
+    def test_calculate_all_overlaps(
+        self, intervals_utils, reference_interval, interval_data
+    ):
         """Test _calculate_all_overlaps with overlapping intervals."""
         result = intervals_utils._calculate_all_overlaps(reference_interval)
 
         start_field = interval_data["start_field"]
         assert len(result) == 3
-        assert pd.to_datetime('2023-01-01') in result[start_field].values
-        assert pd.to_datetime('2023-01-05') in result[start_field].values
+        assert pd.to_datetime("2023-01-01") in result[start_field].values
+        assert pd.to_datetime("2023-01-05") in result[start_field].values
 
     def test_find_overlaps(self, intervals_utils, reference_interval, interval_data):
         """Test find_overlaps method."""
@@ -109,14 +117,18 @@ def test_find_overlaps(self, intervals_utils, reference_interval, interval_data)
 
         # Add the reference interval to the intervals and test again
         intervals_with_reference = interval_data["intervals_data"].copy()
-        intervals_with_reference.loc[len(intervals_with_reference)] = interval_data["reference_interval_data"]
+        intervals_with_reference.loc[len(intervals_with_reference)] = interval_data[
+            "reference_interval_data"
+        ]
         utils_with_reference = IntervalsUtils(intervals_with_reference)
         result = utils_with_reference.find_overlaps(reference_interval)
 
         # Should still be 3 as the reference interval itself should be excluded
         assert len(result) == 3
 
-    def test_add_as_disjoint_empty_disjoint_set(self, intervals_utils, reference_interval, interval_data):
+    def test_add_as_disjoint_empty_disjoint_set(
+        self, intervals_utils, reference_interval, interval_data
+    ):
         """Test add_as_disjoint with empty disjoint set."""
intervals_utils.add_as_disjoint(reference_interval) @@ -129,75 +141,90 @@ def test_add_as_disjoint_empty_disjoint_set(self, intervals_utils, reference_int assert result.iloc[0][start_field] == reference_data[start_field] assert result.iloc[0][end_field] == reference_data[end_field] - @patch('tempo.intervals.core.utils.IntervalTransformer') - def test_add_as_disjoint_no_overlaps(self, mock_transformer, intervals_utils, reference_interval, interval_data): + @patch("tempo.intervals.core.utils.IntervalTransformer") + def test_add_as_disjoint_no_overlaps( + self, mock_transformer, intervals_utils, reference_interval, interval_data + ): """Test add_as_disjoint with no overlapping intervals.""" start_field = interval_data["start_field"] end_field = interval_data["end_field"] # Set up a non-empty disjoint set with no overlaps with reference - non_overlapping_df = DataFrame({ - start_field: pd.to_datetime(['2023-01-25']), - end_field: pd.to_datetime(['2023-01-30']), - 'category': ['C'], - 'value': [40] - }) + non_overlapping_df = DataFrame( + { + start_field: pd.to_datetime(["2023-01-25"]), + end_field: pd.to_datetime(["2023-01-30"]), + "category": ["C"], + "value": [40], + } + ) intervals_utils.disjoint_set = non_overlapping_df # Set up a mock for find_overlaps to return empty DataFrame - with patch.object(IntervalsUtils, 'find_overlaps', return_value=DataFrame()): + with patch.object(IntervalsUtils, "find_overlaps", return_value=DataFrame()): result = intervals_utils.add_as_disjoint(reference_interval) # Should add the reference interval to the disjoint set assert len(result) == 2 - assert pd.to_datetime('2023-01-03') in result[start_field].values - assert pd.to_datetime('2023-01-25') in result[start_field].values + assert pd.to_datetime("2023-01-03") in result[start_field].values + assert pd.to_datetime("2023-01-25") in result[start_field].values - @patch('tempo.intervals.core.utils.IntervalTransformer') - def test_add_as_disjoint_with_duplicate(self, mock_transformer, intervals_utils, reference_interval, interval_data): + @patch("tempo.intervals.core.utils.IntervalTransformer") + def test_add_as_disjoint_with_duplicate( + self, mock_transformer, intervals_utils, reference_interval, interval_data + ): """Test add_as_disjoint with duplicate interval.""" # Set up disjoint set with the reference interval already in it - intervals_utils.disjoint_set = DataFrame([interval_data["reference_interval_data"]]) + intervals_utils.disjoint_set = DataFrame( + [interval_data["reference_interval_data"]] + ) # Set up a mock for find_overlaps to return empty DataFrame - with patch.object(IntervalsUtils, 'find_overlaps', return_value=DataFrame()): + with patch.object(IntervalsUtils, "find_overlaps", return_value=DataFrame()): result = intervals_utils.add_as_disjoint(reference_interval) # Should not add duplicate, so length should still be 1 assert len(result) == 1 - @patch('tempo.intervals.overlap.transformer.IntervalTransformer.resolve_overlap') - def test_add_as_disjoint_single_overlap(self, mock_resolve_overlap, intervals_utils, reference_interval, - interval_data): + @patch("tempo.intervals.overlap.transformer.IntervalTransformer.resolve_overlap") + def test_add_as_disjoint_single_overlap( + self, mock_resolve_overlap, intervals_utils, reference_interval, interval_data + ): """Test add_as_disjoint with a single overlapping interval.""" start_field = interval_data["start_field"] end_field = interval_data["end_field"] # Set up a disjoint set with one interval that overlaps with reference - overlapping_df = 
DataFrame({ - start_field: [pd.to_datetime('2023-01-01')], - end_field: [pd.to_datetime('2023-01-05')], - 'category': ['A'], - 'value': [10] - }) + overlapping_df = DataFrame( + { + start_field: [pd.to_datetime("2023-01-01")], + end_field: [pd.to_datetime("2023-01-05")], + "category": ["A"], + "value": [10], + } + ) intervals_utils.disjoint_set = overlapping_df # Mock the find_overlaps method to return the overlapping interval - with patch.object(IntervalsUtils, 'find_overlaps', return_value=overlapping_df): + with patch.object(IntervalsUtils, "find_overlaps", return_value=overlapping_df): # Mock the resolve_overlap method to return resolved intervals mock_resolve_overlap.return_value = [ - Series({ - start_field: pd.to_datetime('2023-01-01'), - end_field: pd.to_datetime('2023-01-03'), - 'category': 'A', - 'value': 10 - }), - Series({ - start_field: pd.to_datetime('2023-01-03'), - end_field: pd.to_datetime('2023-01-05'), - 'category': 'A', - 'value': 15 - }) + Series( + { + start_field: pd.to_datetime("2023-01-01"), + end_field: pd.to_datetime("2023-01-03"), + "category": "A", + "value": 10, + } + ), + Series( + { + start_field: pd.to_datetime("2023-01-03"), + end_field: pd.to_datetime("2023-01-05"), + "category": "A", + "value": 15, + } + ), ] result = intervals_utils.add_as_disjoint(reference_interval) @@ -206,31 +233,51 @@ def test_add_as_disjoint_single_overlap(self, mock_resolve_overlap, intervals_ut assert len(result) == 2 mock_resolve_overlap.assert_called_once() - @patch('tempo.intervals.core.utils.IntervalsUtils.resolve_all_overlaps') - def test_add_as_disjoint_multiple_overlaps(self, mock_resolve_all_overlaps, intervals_utils, reference_interval, - interval_data): + @patch("tempo.intervals.core.utils.IntervalsUtils.resolve_all_overlaps") + def test_add_as_disjoint_multiple_overlaps( + self, + mock_resolve_all_overlaps, + intervals_utils, + reference_interval, + interval_data, + ): """Test add_as_disjoint with multiple overlapping intervals.""" start_field = interval_data["start_field"] end_field = interval_data["end_field"] # Set up a disjoint set with multiple intervals that overlap with reference - overlapping_df = DataFrame({ - start_field: [pd.to_datetime('2023-01-01'), pd.to_datetime('2023-01-05')], - end_field: [pd.to_datetime('2023-01-05'), pd.to_datetime('2023-01-10')], - 'category': ['A', 'B'], - 'value': [10, 20] - }) + overlapping_df = DataFrame( + { + start_field: [ + pd.to_datetime("2023-01-01"), + pd.to_datetime("2023-01-05"), + ], + end_field: [pd.to_datetime("2023-01-05"), pd.to_datetime("2023-01-10")], + "category": ["A", "B"], + "value": [10, 20], + } + ) intervals_utils.disjoint_set = overlapping_df # Mock the find_overlaps method to return all overlapping intervals - with patch.object(IntervalsUtils, 'find_overlaps', return_value=overlapping_df): + with patch.object(IntervalsUtils, "find_overlaps", return_value=overlapping_df): # Mock the resolve_all_overlaps method - mock_resolved_df = DataFrame({ - start_field: [pd.to_datetime('2023-01-01'), pd.to_datetime('2023-01-03'), pd.to_datetime('2023-01-05')], - end_field: [pd.to_datetime('2023-01-03'), pd.to_datetime('2023-01-05'), pd.to_datetime('2023-01-10')], - 'category': ['A', 'A', 'B'], - 'value': [10, 15, 20] - }) + mock_resolved_df = DataFrame( + { + start_field: [ + pd.to_datetime("2023-01-01"), + pd.to_datetime("2023-01-03"), + pd.to_datetime("2023-01-05"), + ], + end_field: [ + pd.to_datetime("2023-01-03"), + pd.to_datetime("2023-01-05"), + pd.to_datetime("2023-01-10"), + ], + "category": ["A", "A", "B"], 
+ "value": [10, 15, 20], + } + ) mock_resolve_all_overlaps.return_value = mock_resolved_df result = intervals_utils.add_as_disjoint(reference_interval) @@ -239,45 +286,58 @@ def test_add_as_disjoint_multiple_overlaps(self, mock_resolve_all_overlaps, inte pd.testing.assert_frame_equal(result, mock_resolved_df) mock_resolve_all_overlaps.assert_called_once() - @patch('tempo.intervals.core.utils.IntervalTransformer') - def test_add_as_disjoint_mixed_overlaps(self, mock_transformer, intervals_utils, reference_interval, interval_data): + @patch("tempo.intervals.core.utils.IntervalTransformer") + def test_add_as_disjoint_mixed_overlaps( + self, mock_transformer, intervals_utils, reference_interval, interval_data + ): """Test add_as_disjoint with both overlapping and non-overlapping intervals.""" start_field = interval_data["start_field"] end_field = interval_data["end_field"] # Set up a disjoint set with some overlapping and some non-overlapping intervals - mixed_df = DataFrame({ - start_field: [pd.to_datetime('2023-01-01'), pd.to_datetime('2023-01-25')], - end_field: [pd.to_datetime('2023-01-05'), pd.to_datetime('2023-01-30')], - 'category': ['A', 'C'], - 'value': [10, 40] - }) + mixed_df = DataFrame( + { + start_field: [ + pd.to_datetime("2023-01-01"), + pd.to_datetime("2023-01-25"), + ], + end_field: [pd.to_datetime("2023-01-05"), pd.to_datetime("2023-01-30")], + "category": ["A", "C"], + "value": [10, 40], + } + ) intervals_utils.disjoint_set = mixed_df # Mock to return only the overlapping interval - overlapping_df = DataFrame({ - start_field: [pd.to_datetime('2023-01-01')], - end_field: [pd.to_datetime('2023-01-05')], - 'category': ['A'], - 'value': [10] - }) - - with patch.object(IntervalsUtils, 'find_overlaps', return_value=overlapping_df): + overlapping_df = DataFrame( + { + start_field: [pd.to_datetime("2023-01-01")], + end_field: [pd.to_datetime("2023-01-05")], + "category": ["A"], + "value": [10], + } + ) + + with patch.object(IntervalsUtils, "find_overlaps", return_value=overlapping_df): # Mock the transformer's resolve_overlap method mock_resolve_instance = MagicMock() mock_resolve_instance.resolve_overlap.return_value = [ - Series({ - start_field: pd.to_datetime('2023-01-01'), - end_field: pd.to_datetime('2023-01-03'), - 'category': 'A', - 'value': 10 - }), - Series({ - start_field: pd.to_datetime('2023-01-03'), - end_field: pd.to_datetime('2023-01-05'), - 'category': 'A', - 'value': 15 - }) + Series( + { + start_field: pd.to_datetime("2023-01-01"), + end_field: pd.to_datetime("2023-01-03"), + "category": "A", + "value": 10, + } + ), + Series( + { + start_field: pd.to_datetime("2023-01-03"), + end_field: pd.to_datetime("2023-01-05"), + "category": "A", + "value": 15, + } + ), ] mock_transformer.return_value = mock_resolve_instance @@ -285,25 +345,31 @@ def test_add_as_disjoint_mixed_overlaps(self, mock_transformer, intervals_utils, # Should include both resolved overlaps and original non-overlapping interval assert len(result) == 3 - assert pd.to_datetime('2023-01-25') in result[start_field].values + assert pd.to_datetime("2023-01-25") in result[start_field].values - def test_calculate_all_overlaps_touching_intervals(self, intervals_utils, interval_data): + def test_calculate_all_overlaps_touching_intervals( + self, intervals_utils, interval_data + ): """Test that intervals that touch at endpoints but don't overlap are not considered overlapping.""" start_field = interval_data["start_field"] end_field = interval_data["end_field"] # Create a reference interval that touches but 
doesn't overlap touching_interval = Interval.create( - Series({ - start_field: pd.to_datetime('2023-01-20'), # Starts exactly when another ends - end_field: pd.to_datetime('2023-01-25'), - 'category': 'A', - 'value': 15 - }), + Series( + { + start_field: pd.to_datetime( + "2023-01-20" + ), # Starts exactly when another ends + end_field: pd.to_datetime("2023-01-25"), + "category": "A", + "value": 15, + } + ), start_field, end_field, interval_data["series_fields"], - interval_data["metric_fields"] + interval_data["metric_fields"], ) result = intervals_utils._calculate_all_overlaps(touching_interval) @@ -316,21 +382,29 @@ def test_timezone_aware_timestamps(self, interval_data): end_field = interval_data["end_field"] # Create timezone-aware interval data - tz_intervals = DataFrame({ - start_field: [pd.to_datetime('2023-01-01').tz_localize('UTC'), - pd.to_datetime('2023-01-05').tz_localize('UTC')], - end_field: [pd.to_datetime('2023-01-07').tz_localize('UTC'), - pd.to_datetime('2023-01-15').tz_localize('UTC')], - 'category': ['A', 'B'], - 'value': [10, 20] - }) - - tz_reference = Series({ - start_field: pd.to_datetime('2023-01-03').tz_localize('UTC'), - end_field: pd.to_datetime('2023-01-12').tz_localize('UTC'), - 'category': 'A', - 'value': 15 - }) + tz_intervals = DataFrame( + { + start_field: [ + pd.to_datetime("2023-01-01").tz_localize("UTC"), + pd.to_datetime("2023-01-05").tz_localize("UTC"), + ], + end_field: [ + pd.to_datetime("2023-01-07").tz_localize("UTC"), + pd.to_datetime("2023-01-15").tz_localize("UTC"), + ], + "category": ["A", "B"], + "value": [10, 20], + } + ) + + tz_reference = Series( + { + start_field: pd.to_datetime("2023-01-03").tz_localize("UTC"), + end_field: pd.to_datetime("2023-01-12").tz_localize("UTC"), + "category": "A", + "value": 15, + } + ) tz_utils = IntervalsUtils(tz_intervals) @@ -339,7 +413,7 @@ def test_timezone_aware_timestamps(self, interval_data): start_field, end_field, interval_data["series_fields"], - interval_data["metric_fields"] + interval_data["metric_fields"], ) # Test overlaps work correctly with timezone-aware data @@ -350,8 +424,10 @@ def test_timezone_aware_timestamps(self, interval_data): class TestResolveAllOverlaps: """Additional test suite specifically for the resolve_all_overlaps method.""" - @patch('tempo.intervals.overlap.transformer.IntervalTransformer.resolve_overlap') - def test_resolve_all_overlaps(self, mock_resolve_overlap, intervals_utils, reference_interval, interval_data): + @patch("tempo.intervals.overlap.transformer.IntervalTransformer.resolve_overlap") + def test_resolve_all_overlaps( + self, mock_resolve_overlap, intervals_utils, reference_interval, interval_data + ): """Test resolve_all_overlaps with multiple intervals.""" start_field = interval_data["start_field"] end_field = interval_data["end_field"] @@ -360,41 +436,49 @@ def test_resolve_all_overlaps(self, mock_resolve_overlap, intervals_utils, refer # This ensures that every call to resolve_overlap will return the same value # and we won't run out of side effects mock_resolve_overlap.return_value = [ - Series({ - start_field: pd.to_datetime('2023-01-01'), - end_field: pd.to_datetime('2023-01-03'), - 'category': 'A', - 'value': 10 - }), - Series({ - start_field: pd.to_datetime('2023-01-03'), - end_field: pd.to_datetime('2023-01-07'), - 'category': 'A', - 'value': 15 - }) + Series( + { + start_field: pd.to_datetime("2023-01-01"), + end_field: pd.to_datetime("2023-01-03"), + "category": "A", + "value": 10, + } + ), + Series( + { + start_field: 
pd.to_datetime("2023-01-03"), + end_field: pd.to_datetime("2023-01-07"), + "category": "A", + "value": 15, + } + ), ] # Use patch.object to mock find_overlaps to limit which intervals are found - with patch.object(IntervalsUtils, 'find_overlaps') as mock_find_overlaps: + with patch.object(IntervalsUtils, "find_overlaps") as mock_find_overlaps: # Only return a subset of intervals to control testing flow mock_find_overlaps.return_value = interval_data["intervals_data"].iloc[:1] # Mock add_as_disjoint to avoid dependency on that method - with patch.object(IntervalsUtils, 'add_as_disjoint') as mock_add_as_disjoint: - mock_result = DataFrame({ - start_field: [ - pd.to_datetime('2023-01-01'), - pd.to_datetime('2023-01-03'), - pd.to_datetime('2023-01-05') - ], - end_field: [ - pd.to_datetime('2023-01-03'), - pd.to_datetime('2023-01-07'), - pd.to_datetime('2023-01-12') - ], - 'category': ['A', 'A', 'B'], - 'value': [10, 15, 20] - }) + with patch.object( + IntervalsUtils, "add_as_disjoint" + ) as mock_add_as_disjoint: + mock_result = DataFrame( + { + start_field: [ + pd.to_datetime("2023-01-01"), + pd.to_datetime("2023-01-03"), + pd.to_datetime("2023-01-05"), + ], + end_field: [ + pd.to_datetime("2023-01-03"), + pd.to_datetime("2023-01-07"), + pd.to_datetime("2023-01-12"), + ], + "category": ["A", "A", "B"], + "value": [10, 15, 20], + } + ) mock_add_as_disjoint.return_value = mock_result result = intervals_utils.resolve_all_overlaps(reference_interval) @@ -416,12 +500,16 @@ def test_resolve_all_overlaps_empty_input(self, reference_interval, interval_dat assert len(result) == 1 assert result.iloc[0][start_field] == reference_data[start_field] - def test_resolve_all_overlaps_with_single_item(self, intervals_utils, reference_interval, interval_data): + def test_resolve_all_overlaps_with_single_item( + self, intervals_utils, reference_interval, interval_data + ): """Test resolve_all_overlaps with a single item in intervals.""" # Create utils with just one interval single_interval_utils = IntervalsUtils(interval_data["intervals_data"].iloc[:1]) - with patch('tempo.intervals.overlap.transformer.IntervalTransformer.resolve_overlap') as mock_resolve: + with patch( + "tempo.intervals.overlap.transformer.IntervalTransformer.resolve_overlap" + ) as mock_resolve: # Mock the resolution result mock_resolve.return_value = [interval_data["intervals_data"].iloc[0]] @@ -439,33 +527,33 @@ def test_resolve_all_overlaps_with_complex_overlaps(self, interval_data): metric_fields = interval_data["metric_fields"] # Create a more complex set of overlapping intervals - complex_intervals = DataFrame({ - start_field: pd.to_datetime([ - '2023-01-01', '2023-01-03', '2023-01-05', '2023-01-08' - ]), - end_field: pd.to_datetime([ - '2023-01-06', '2023-01-07', '2023-01-10', '2023-01-15' - ]), - 'category': ['A', 'B', 'A', 'C'], - 'value': [10, 20, 30, 40] - }) + complex_intervals = DataFrame( + { + start_field: pd.to_datetime( + ["2023-01-01", "2023-01-03", "2023-01-05", "2023-01-08"] + ), + end_field: pd.to_datetime( + ["2023-01-06", "2023-01-07", "2023-01-10", "2023-01-15"] + ), + "category": ["A", "B", "A", "C"], + "value": [10, 20, 30, 40], + } + ) utils = IntervalsUtils(complex_intervals) # Create a reference interval that spans across multiple intervals - reference_data = Series({ - start_field: pd.to_datetime('2023-01-02'), - end_field: pd.to_datetime('2023-01-12'), - 'category': 'X', - 'value': 50 - }) + reference_data = Series( + { + start_field: pd.to_datetime("2023-01-02"), + end_field: pd.to_datetime("2023-01-12"), + 
"category": "X", + "value": 50, + } + ) reference_interval = Interval.create( - reference_data, - start_field, - end_field, - series_fields, - metric_fields + reference_data, start_field, end_field, series_fields, metric_fields ) # Store transformer instances to access their properties later @@ -479,102 +567,144 @@ def resolve_overlap_wrapper(self, *args, **kwargs): return original_resolve_overlap(self, *args, **kwargs) # Patch resolve_overlap with our wrapper to track instances - with patch('tempo.intervals.overlap.transformer.IntervalTransformer.resolve_overlap', - side_effect=resolve_overlap_wrapper) as mock_resolve: + with patch( + "tempo.intervals.overlap.transformer.IntervalTransformer.resolve_overlap", + side_effect=resolve_overlap_wrapper, + ) as mock_resolve: # Setup the return values for the mocked resolve_overlap mock_resolve.side_effect = [ # First interval (2023-01-01) [ - Series({ - start_field: pd.to_datetime('2023-01-01'), - end_field: pd.to_datetime('2023-01-02'), - 'category': 'A', - 'value': 10 - }), - Series({ - start_field: pd.to_datetime('2023-01-02'), - end_field: pd.to_datetime('2023-01-06'), - 'category': 'X', - 'value': 50 - }) + Series( + { + start_field: pd.to_datetime("2023-01-01"), + end_field: pd.to_datetime("2023-01-02"), + "category": "A", + "value": 10, + } + ), + Series( + { + start_field: pd.to_datetime("2023-01-02"), + end_field: pd.to_datetime("2023-01-06"), + "category": "X", + "value": 50, + } + ), ], # Second interval (2023-01-03) [ - Series({ - start_field: pd.to_datetime('2023-01-03'), - end_field: pd.to_datetime('2023-01-07'), - 'category': 'X', - 'value': 50 - }) + Series( + { + start_field: pd.to_datetime("2023-01-03"), + end_field: pd.to_datetime("2023-01-07"), + "category": "X", + "value": 50, + } + ) ], # Third interval (2023-01-05) [ - Series({ - start_field: pd.to_datetime('2023-01-05'), - end_field: pd.to_datetime('2023-01-10'), - 'category': 'X', - 'value': 50 - }) + Series( + { + start_field: pd.to_datetime("2023-01-05"), + end_field: pd.to_datetime("2023-01-10"), + "category": "X", + "value": 50, + } + ) ], # Fourth interval (2023-01-08) [ - Series({ - start_field: pd.to_datetime('2023-01-08'), - end_field: pd.to_datetime('2023-01-12'), - 'category': 'X', - 'value': 50 - }), - Series({ - start_field: pd.to_datetime('2023-01-12'), - end_field: pd.to_datetime('2023-01-15'), - 'category': 'C', - 'value': 40 - }) - ] + Series( + { + start_field: pd.to_datetime("2023-01-08"), + end_field: pd.to_datetime("2023-01-12"), + "category": "X", + "value": 50, + } + ), + Series( + { + start_field: pd.to_datetime("2023-01-12"), + end_field: pd.to_datetime("2023-01-15"), + "category": "C", + "value": 40, + } + ), + ], ] # Also mock add_as_disjoint to avoid dealing with its complexity - with patch.object(IntervalsUtils, 'add_as_disjoint') as mock_add: + with patch.object(IntervalsUtils, "add_as_disjoint") as mock_add: # After several calls, we expect a specific result - expected_result = DataFrame({ - start_field: pd.to_datetime([ - '2023-01-01', '2023-01-02', '2023-01-03', '2023-01-05', - '2023-01-08', '2023-01-12' - ]), - end_field: pd.to_datetime([ - '2023-01-02', '2023-01-06', '2023-01-07', '2023-01-10', - '2023-01-12', '2023-01-15' - ]), - 'category': ['A', 'X', 'X', 'X', 'X', 'C'], - 'value': [10, 50, 50, 50, 50, 40] - }) + expected_result = DataFrame( + { + start_field: pd.to_datetime( + [ + "2023-01-01", + "2023-01-02", + "2023-01-03", + "2023-01-05", + "2023-01-08", + "2023-01-12", + ] + ), + end_field: pd.to_datetime( + [ + "2023-01-02", + 
"2023-01-06", + "2023-01-07", + "2023-01-10", + "2023-01-12", + "2023-01-15", + ] + ), + "category": ["A", "X", "X", "X", "X", "C"], + "value": [10, 50, 50, 50, 50, 40], + } + ) # Control the behavior of add_as_disjoint as it builds up # This simulates the gradual building of the disjoint set mock_add.side_effect = [ # First call with first resolved interval - DataFrame({ - start_field: pd.to_datetime(['2023-01-01', '2023-01-02']), - end_field: pd.to_datetime(['2023-01-02', '2023-01-06']), - 'category': ['A', 'X'], - 'value': [10, 50] - }), + DataFrame( + { + start_field: pd.to_datetime(["2023-01-01", "2023-01-02"]), + end_field: pd.to_datetime(["2023-01-02", "2023-01-06"]), + "category": ["A", "X"], + "value": [10, 50], + } + ), # Second call adds the next interval - DataFrame({ - start_field: pd.to_datetime(['2023-01-01', '2023-01-02', '2023-01-03']), - end_field: pd.to_datetime(['2023-01-02', '2023-01-06', '2023-01-07']), - 'category': ['A', 'X', 'X'], - 'value': [10, 50, 50] - }), + DataFrame( + { + start_field: pd.to_datetime( + ["2023-01-01", "2023-01-02", "2023-01-03"] + ), + end_field: pd.to_datetime( + ["2023-01-02", "2023-01-06", "2023-01-07"] + ), + "category": ["A", "X", "X"], + "value": [10, 50, 50], + } + ), # And so on... - DataFrame({ - start_field: pd.to_datetime(['2023-01-01', '2023-01-02', '2023-01-03', '2023-01-05']), - end_field: pd.to_datetime(['2023-01-02', '2023-01-06', '2023-01-07', '2023-01-10']), - 'category': ['A', 'X', 'X', 'X'], - 'value': [10, 50, 50, 50] - }), + DataFrame( + { + start_field: pd.to_datetime( + ["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-05"] + ), + end_field: pd.to_datetime( + ["2023-01-02", "2023-01-06", "2023-01-07", "2023-01-10"] + ), + "category": ["A", "X", "X", "X"], + "value": [10, 50, 50, 50], + } + ), # Final call returns the expected result - expected_result + expected_result, ] result = utils.resolve_all_overlaps(reference_interval) @@ -586,7 +716,9 @@ def resolve_overlap_wrapper(self, *args, **kwargs): assert mock_resolve.call_count == 4 assert mock_add.call_count == 4 - def test_resolve_all_overlaps_with_partially_overlapping_intervals(self, interval_data): + def test_resolve_all_overlaps_with_partially_overlapping_intervals( + self, interval_data + ): """Test resolve_all_overlaps with intervals that only partially overlap.""" start_field = interval_data["start_field"] end_field = interval_data["end_field"] @@ -594,78 +726,102 @@ def test_resolve_all_overlaps_with_partially_overlapping_intervals(self, interva metric_fields = interval_data["metric_fields"] # Create intervals that partially overlap - partial_intervals = DataFrame({ - start_field: pd.to_datetime(['2023-01-01', '2023-01-10']), - end_field: pd.to_datetime(['2023-01-05', '2023-01-15']), - 'category': ['A', 'B'], - 'value': [10, 20] - }) + partial_intervals = DataFrame( + { + start_field: pd.to_datetime(["2023-01-01", "2023-01-10"]), + end_field: pd.to_datetime(["2023-01-05", "2023-01-15"]), + "category": ["A", "B"], + "value": [10, 20], + } + ) utils = IntervalsUtils(partial_intervals) # Create a reference interval that partially overlaps - reference_data = Series({ - start_field: pd.to_datetime('2023-01-03'), - end_field: pd.to_datetime('2023-01-12'), - 'category': 'X', - 'value': 30 - }) + reference_data = Series( + { + start_field: pd.to_datetime("2023-01-03"), + end_field: pd.to_datetime("2023-01-12"), + "category": "X", + "value": 30, + } + ) reference_interval = Interval.create( - reference_data, - start_field, - end_field, - series_fields, - 
metric_fields + reference_data, start_field, end_field, series_fields, metric_fields ) - with patch('tempo.intervals.overlap.transformer.IntervalTransformer.resolve_overlap') as mock_resolve: + with patch( + "tempo.intervals.overlap.transformer.IntervalTransformer.resolve_overlap" + ) as mock_resolve: # Define side effect for the first interval mock_resolve.side_effect = [ # First interval split into before overlap and overlap [ - Series({ - start_field: pd.to_datetime('2023-01-01'), - end_field: pd.to_datetime('2023-01-03'), - 'category': 'A', - 'value': 10 - }), - Series({ - start_field: pd.to_datetime('2023-01-03'), - end_field: pd.to_datetime('2023-01-05'), - 'category': 'X', - 'value': 30 - }) + Series( + { + start_field: pd.to_datetime("2023-01-01"), + end_field: pd.to_datetime("2023-01-03"), + "category": "A", + "value": 10, + } + ), + Series( + { + start_field: pd.to_datetime("2023-01-03"), + end_field: pd.to_datetime("2023-01-05"), + "category": "X", + "value": 30, + } + ), ], # Second interval split into overlap and after overlap [ - Series({ - start_field: pd.to_datetime('2023-01-10'), - end_field: pd.to_datetime('2023-01-12'), - 'category': 'X', - 'value': 30 - }), - Series({ - start_field: pd.to_datetime('2023-01-12'), - end_field: pd.to_datetime('2023-01-15'), - 'category': 'B', - 'value': 20 - }) - ] + Series( + { + start_field: pd.to_datetime("2023-01-10"), + end_field: pd.to_datetime("2023-01-12"), + "category": "X", + "value": 30, + } + ), + Series( + { + start_field: pd.to_datetime("2023-01-12"), + end_field: pd.to_datetime("2023-01-15"), + "category": "B", + "value": 20, + } + ), + ], ] # Mock disjoint interval handling - with patch.object(IntervalsUtils, 'add_as_disjoint') as mock_add: - result_df = DataFrame({ - start_field: pd.to_datetime([ - '2023-01-01', '2023-01-03', '2023-01-05', '2023-01-10', '2023-01-12' - ]), - end_field: pd.to_datetime([ - '2023-01-03', '2023-01-05', '2023-01-10', '2023-01-12', '2023-01-15' - ]), - 'category': ['A', 'X', 'X', 'X', 'B'], - 'value': [10, 30, 30, 30, 20] - }) + with patch.object(IntervalsUtils, "add_as_disjoint") as mock_add: + result_df = DataFrame( + { + start_field: pd.to_datetime( + [ + "2023-01-01", + "2023-01-03", + "2023-01-05", + "2023-01-10", + "2023-01-12", + ] + ), + end_field: pd.to_datetime( + [ + "2023-01-03", + "2023-01-05", + "2023-01-10", + "2023-01-12", + "2023-01-15", + ] + ), + "category": ["A", "X", "X", "X", "B"], + "value": [10, 30, 30, 30, 20], + } + ) # Configure mock to return our expected final result mock_add.return_value = result_df @@ -679,7 +835,9 @@ def test_resolve_all_overlaps_with_partially_overlapping_intervals(self, interva # Verify the expected number of calls assert mock_resolve.call_count == 2 - def test_resolve_all_overlaps_with_completely_contained_intervals(self, interval_data): + def test_resolve_all_overlaps_with_completely_contained_intervals( + self, interval_data + ): """Test resolve_all_overlaps with intervals completely contained within the reference.""" start_field = interval_data["start_field"] end_field = interval_data["end_field"] @@ -687,92 +845,120 @@ def test_resolve_all_overlaps_with_completely_contained_intervals(self, interval metric_fields = interval_data["metric_fields"] # Create intervals completely contained within reference - contained_intervals = DataFrame({ - start_field: pd.to_datetime(['2023-01-05', '2023-01-07']), - end_field: pd.to_datetime(['2023-01-06', '2023-01-09']), - 'category': ['A', 'B'], - 'value': [10, 20] - }) + contained_intervals = DataFrame( + { 
+ start_field: pd.to_datetime(["2023-01-05", "2023-01-07"]), + end_field: pd.to_datetime(["2023-01-06", "2023-01-09"]), + "category": ["A", "B"], + "value": [10, 20], + } + ) utils = IntervalsUtils(contained_intervals) # Create a reference interval that contains others - reference_data = Series({ - start_field: pd.to_datetime('2023-01-01'), - end_field: pd.to_datetime('2023-01-15'), - 'category': 'X', - 'value': 30 - }) + reference_data = Series( + { + start_field: pd.to_datetime("2023-01-01"), + end_field: pd.to_datetime("2023-01-15"), + "category": "X", + "value": 30, + } + ) reference_interval = Interval.create( - reference_data, - start_field, - end_field, - series_fields, - metric_fields + reference_data, start_field, end_field, series_fields, metric_fields ) - with patch('tempo.intervals.overlap.transformer.IntervalTransformer.resolve_overlap') as mock_resolve: + with patch( + "tempo.intervals.overlap.transformer.IntervalTransformer.resolve_overlap" + ) as mock_resolve: # For contained intervals, we may get specific divisions mock_resolve.side_effect = [ # First interval splits reference into before, during, after [ - Series({ - start_field: pd.to_datetime('2023-01-01'), - end_field: pd.to_datetime('2023-01-05'), - 'category': 'X', - 'value': 30 - }), - Series({ - start_field: pd.to_datetime('2023-01-05'), - end_field: pd.to_datetime('2023-01-06'), - 'category': 'A', # Contained interval takes precedence - 'value': 10 - }), - Series({ - start_field: pd.to_datetime('2023-01-06'), - end_field: pd.to_datetime('2023-01-15'), - 'category': 'X', - 'value': 30 - }) + Series( + { + start_field: pd.to_datetime("2023-01-01"), + end_field: pd.to_datetime("2023-01-05"), + "category": "X", + "value": 30, + } + ), + Series( + { + start_field: pd.to_datetime("2023-01-05"), + end_field: pd.to_datetime("2023-01-06"), + "category": "A", # Contained interval takes precedence + "value": 10, + } + ), + Series( + { + start_field: pd.to_datetime("2023-01-06"), + end_field: pd.to_datetime("2023-01-15"), + "category": "X", + "value": 30, + } + ), ], # Second interval further divides [ - Series({ - start_field: pd.to_datetime('2023-01-06'), - end_field: pd.to_datetime('2023-01-07'), - 'category': 'X', - 'value': 30 - }), - Series({ - start_field: pd.to_datetime('2023-01-07'), - end_field: pd.to_datetime('2023-01-09'), - 'category': 'B', # Contained interval takes precedence - 'value': 20 - }), - Series({ - start_field: pd.to_datetime('2023-01-09'), - end_field: pd.to_datetime('2023-01-15'), - 'category': 'X', - 'value': 30 - }) - ] + Series( + { + start_field: pd.to_datetime("2023-01-06"), + end_field: pd.to_datetime("2023-01-07"), + "category": "X", + "value": 30, + } + ), + Series( + { + start_field: pd.to_datetime("2023-01-07"), + end_field: pd.to_datetime("2023-01-09"), + "category": "B", # Contained interval takes precedence + "value": 20, + } + ), + Series( + { + start_field: pd.to_datetime("2023-01-09"), + end_field: pd.to_datetime("2023-01-15"), + "category": "X", + "value": 30, + } + ), + ], ] # Mock the disjoint interval handling - with patch.object(IntervalsUtils, 'add_as_disjoint') as mock_add: - expected_result = DataFrame({ - start_field: pd.to_datetime([ - '2023-01-01', '2023-01-05', '2023-01-06', '2023-01-07', - '2023-01-09', '2023-01-09' - ]), - end_field: pd.to_datetime([ - '2023-01-05', '2023-01-06', '2023-01-07', '2023-01-09', - '2023-01-09', '2023-01-15' - ]), - 'category': ['X', 'A', 'X', 'B', 'X', 'X'], - 'value': [30, 10, 30, 20, 30, 30] - }) + with patch.object(IntervalsUtils, 
"add_as_disjoint") as mock_add: + expected_result = DataFrame( + { + start_field: pd.to_datetime( + [ + "2023-01-01", + "2023-01-05", + "2023-01-06", + "2023-01-07", + "2023-01-09", + "2023-01-09", + ] + ), + end_field: pd.to_datetime( + [ + "2023-01-05", + "2023-01-06", + "2023-01-07", + "2023-01-09", + "2023-01-09", + "2023-01-15", + ] + ), + "category": ["X", "A", "X", "B", "X", "X"], + "value": [30, 10, 30, 20, 30, 30], + } + ) # Configure mock to return our expected final result mock_add.return_value = expected_result @@ -792,44 +978,46 @@ def test_resolve_all_overlaps_with_identical_intervals(self, interval_data): metric_fields = interval_data["metric_fields"] # Create an interval identical to reference - identical_data = Series({ - start_field: pd.to_datetime('2023-01-01'), - end_field: pd.to_datetime('2023-01-10'), - 'category': 'A', - 'value': 10 - }) + identical_data = Series( + { + start_field: pd.to_datetime("2023-01-01"), + end_field: pd.to_datetime("2023-01-10"), + "category": "A", + "value": 10, + } + ) identical_df = DataFrame([identical_data]) utils = IntervalsUtils(identical_df) # Create the same reference interval reference_interval = Interval.create( - identical_data.copy(), - start_field, - end_field, - series_fields, - metric_fields + identical_data.copy(), start_field, end_field, series_fields, metric_fields ) - with patch('tempo.intervals.overlap.transformer.IntervalTransformer.resolve_overlap') as mock_resolve: + with patch( + "tempo.intervals.overlap.transformer.IntervalTransformer.resolve_overlap" + ) as mock_resolve: # When intervals are identical, could choose either one mock_resolve.return_value = [ - Series({ - start_field: pd.to_datetime('2023-01-01'), - end_field: pd.to_datetime('2023-01-10'), - 'category': 'A', # Keep original since it's identical - 'value': 10 - }) + Series( + { + start_field: pd.to_datetime("2023-01-01"), + end_field: pd.to_datetime("2023-01-10"), + "category": "A", # Keep original since it's identical + "value": 10, + } + ) ] result = utils.resolve_all_overlaps(reference_interval) # Should have exactly one interval assert len(result) == 1 - assert result.iloc[0][start_field] == pd.to_datetime('2023-01-01') - assert result.iloc[0][end_field] == pd.to_datetime('2023-01-10') - assert result.iloc[0]['category'] == 'A' - assert result.iloc[0]['value'] == 10 + assert result.iloc[0][start_field] == pd.to_datetime("2023-01-01") + assert result.iloc[0][end_field] == pd.to_datetime("2023-01-10") + assert result.iloc[0]["category"] == "A" + assert result.iloc[0]["value"] == 10 def test_resolve_all_overlaps_recursive_call_structure(self, interval_data): """Test that resolve_all_overlaps correctly builds up disjoint intervals through recursion.""" @@ -839,33 +1027,35 @@ def test_resolve_all_overlaps_recursive_call_structure(self, interval_data): metric_fields = interval_data["metric_fields"] # Create a set of three intervals that will need recursive resolution - intervals_data = DataFrame({ - start_field: pd.to_datetime(['2023-01-01', '2023-01-05', '2023-01-08']), - end_field: pd.to_datetime(['2023-01-06', '2023-01-10', '2023-01-12']), - 'category': ['A', 'B', 'C'], - 'value': [10, 20, 30] - }) + intervals_data = DataFrame( + { + start_field: pd.to_datetime(["2023-01-01", "2023-01-05", "2023-01-08"]), + end_field: pd.to_datetime(["2023-01-06", "2023-01-10", "2023-01-12"]), + "category": ["A", "B", "C"], + "value": [10, 20, 30], + } + ) utils = IntervalsUtils(intervals_data) # Reference interval that overlaps with all three - reference_data = 
Series({ - start_field: pd.to_datetime('2023-01-03'), - end_field: pd.to_datetime('2023-01-11'), - 'category': 'X', - 'value': 40 - }) + reference_data = Series( + { + start_field: pd.to_datetime("2023-01-03"), + end_field: pd.to_datetime("2023-01-11"), + "category": "X", + "value": 40, + } + ) reference_interval = Interval.create( - reference_data, - start_field, - end_field, - series_fields, - metric_fields + reference_data, start_field, end_field, series_fields, metric_fields ) # Create a mock for IntervalTransformer to track instantiation - with patch('tempo.intervals.overlap.transformer.IntervalTransformer') as mock_transformer_cls: + with patch( + "tempo.intervals.overlap.transformer.IntervalTransformer" + ) as mock_transformer_cls: # SIMPLER APPROACH: Just call the mock directly 3 times # This verifies that mock_transformer_cls.call_count can reach 3 # without relying on resolve_all_overlaps to do it @@ -874,7 +1064,9 @@ def test_resolve_all_overlaps_recursive_call_structure(self, interval_data): mock_transformer_cls.return_value.resolve_overlap.return_value = [] # Directly call the mock - this will increment call_count - mock_transformer_cls(interval=reference_interval, other=reference_interval) + mock_transformer_cls( + interval=reference_interval, other=reference_interval + ) # Verify the mock was called 3 times assert mock_transformer_cls.call_count >= 3 @@ -891,37 +1083,41 @@ def test_resolve_all_overlaps_with_mock_and_functionality(self, interval_data): metric_fields = interval_data["metric_fields"] # Create a set of three intervals that will need recursive resolution - intervals_data = DataFrame({ - start_field: pd.to_datetime(['2023-01-01', '2023-01-05', '2023-01-08']), - end_field: pd.to_datetime(['2023-01-06', '2023-01-10', '2023-01-12']), - 'category': ['A', 'B', 'C'], - 'value': [10, 20, 30] - }) + intervals_data = DataFrame( + { + start_field: pd.to_datetime(["2023-01-01", "2023-01-05", "2023-01-08"]), + end_field: pd.to_datetime(["2023-01-06", "2023-01-10", "2023-01-12"]), + "category": ["A", "B", "C"], + "value": [10, 20, 30], + } + ) utils = IntervalsUtils(intervals_data) # Reference interval that overlaps with all three - reference_data = Series({ - start_field: pd.to_datetime('2023-01-03'), - end_field: pd.to_datetime('2023-01-11'), - 'category': 'X', - 'value': 40 - }) + reference_data = Series( + { + start_field: pd.to_datetime("2023-01-03"), + end_field: pd.to_datetime("2023-01-11"), + "category": "X", + "value": 40, + } + ) reference_interval = Interval.create( - reference_data, - start_field, - end_field, - series_fields, - metric_fields + reference_data, start_field, end_field, series_fields, metric_fields ) # Create a mock for IntervalTransformer, patching where it's actually imported in utils.py - with patch('tempo.intervals.core.utils.IntervalTransformer') as mock_transformer_cls: + with patch( + "tempo.intervals.core.utils.IntervalTransformer" + ) as mock_transformer_cls: # First part: Directly call the mock to ensure call_count works for i in range(3): mock_transformer_cls.return_value.resolve_overlap.return_value = [] - mock_transformer_cls(interval=reference_interval, other=reference_interval) + mock_transformer_cls( + interval=reference_interval, other=reference_interval + ) # Reset the mock to prepare for the real test mock_transformer_cls.reset_mock() @@ -936,44 +1132,54 @@ def custom_resolve(*args, **kwargs): other_data = mock_transformer.other.data other_start = other_data[start_field] - if other_start == pd.to_datetime('2023-01-01'): + if 
other_start == pd.to_datetime("2023-01-01"): return [ - Series({ - start_field: pd.to_datetime('2023-01-01'), - end_field: pd.to_datetime('2023-01-03'), - 'category': 'A', - 'value': 10 - }), - Series({ - start_field: pd.to_datetime('2023-01-03'), - end_field: pd.to_datetime('2023-01-06'), - 'category': 'X', - 'value': 40 - }) + Series( + { + start_field: pd.to_datetime("2023-01-01"), + end_field: pd.to_datetime("2023-01-03"), + "category": "A", + "value": 10, + } + ), + Series( + { + start_field: pd.to_datetime("2023-01-03"), + end_field: pd.to_datetime("2023-01-06"), + "category": "X", + "value": 40, + } + ), ] - elif other_start == pd.to_datetime('2023-01-05'): + elif other_start == pd.to_datetime("2023-01-05"): return [ - Series({ - start_field: pd.to_datetime('2023-01-05'), - end_field: pd.to_datetime('2023-01-10'), - 'category': 'X', - 'value': 40 - }) + Series( + { + start_field: pd.to_datetime("2023-01-05"), + end_field: pd.to_datetime("2023-01-10"), + "category": "X", + "value": 40, + } + ) ] - elif other_start == pd.to_datetime('2023-01-08'): + elif other_start == pd.to_datetime("2023-01-08"): return [ - Series({ - start_field: pd.to_datetime('2023-01-08'), - end_field: pd.to_datetime('2023-01-11'), - 'category': 'X', - 'value': 40 - }), - Series({ - start_field: pd.to_datetime('2023-01-11'), - end_field: pd.to_datetime('2023-01-12'), - 'category': 'C', - 'value': 30 - }) + Series( + { + start_field: pd.to_datetime("2023-01-08"), + end_field: pd.to_datetime("2023-01-11"), + "category": "X", + "value": 40, + } + ), + Series( + { + start_field: pd.to_datetime("2023-01-11"), + end_field: pd.to_datetime("2023-01-12"), + "category": "C", + "value": 30, + } + ), ] return [] @@ -981,7 +1187,9 @@ def custom_resolve(*args, **kwargs): # Second part: Setup mocks to make resolve_all_overlaps work with our mock # We'll patch _calculate_all_overlaps to return the intervals that should overlap - with patch.object(IntervalsUtils, '_calculate_all_overlaps') as mock_calc_overlaps: + with patch.object( + IntervalsUtils, "_calculate_all_overlaps" + ) as mock_calc_overlaps: # Return the first interval as an overlap first_interval = intervals_data.iloc[[0]] mock_calc_overlaps.return_value = first_interval @@ -1026,39 +1234,41 @@ def test_resolve_all_overlaps_with_no_overlaps(self, interval_data): metric_fields = interval_data["metric_fields"] # Create intervals that don't overlap with reference - non_overlapping = DataFrame({ - start_field: pd.to_datetime(['2023-01-01', '2023-01-05']), - end_field: pd.to_datetime(['2023-01-03', '2023-01-07']), - 'category': ['A', 'B'], - 'value': [10, 20] - }) + non_overlapping = DataFrame( + { + start_field: pd.to_datetime(["2023-01-01", "2023-01-05"]), + end_field: pd.to_datetime(["2023-01-03", "2023-01-07"]), + "category": ["A", "B"], + "value": [10, 20], + } + ) utils = IntervalsUtils(non_overlapping) # Reference interval that doesn't overlap - reference_data = Series({ - start_field: pd.to_datetime('2023-01-10'), - end_field: pd.to_datetime('2023-01-15'), - 'category': 'X', - 'value': 30 - }) + reference_data = Series( + { + start_field: pd.to_datetime("2023-01-10"), + end_field: pd.to_datetime("2023-01-15"), + "category": "X", + "value": 30, + } + ) reference_interval = Interval.create( - reference_data, - start_field, - end_field, - series_fields, - metric_fields + reference_data, start_field, end_field, series_fields, metric_fields ) # Override _calculate_all_overlaps to return empty (simulating no overlaps) - with patch.object(IntervalsUtils, 
'_calculate_all_overlaps', return_value=DataFrame()): + with patch.object( + IntervalsUtils, "_calculate_all_overlaps", return_value=DataFrame() + ): result = utils.resolve_all_overlaps(reference_interval) # Should just return the reference interval assert len(result) == 1 - assert result.iloc[0][start_field] == pd.to_datetime('2023-01-10') - assert result.iloc[0][end_field] == pd.to_datetime('2023-01-15') + assert result.iloc[0][start_field] == pd.to_datetime("2023-01-10") + assert result.iloc[0][end_field] == pd.to_datetime("2023-01-15") class TestAddAsDisjoint: @@ -1073,23 +1283,21 @@ def test_add_as_disjoint_no_overlap(self, interval_data): utils = IntervalsUtils(intervals_df) # Create a new interval with the same columns structure as the intervals_df - new_data = Series({ - start_field: pd.to_datetime('2023-01-25'), - end_field: pd.to_datetime('2023-01-30'), - "category": "D", # Match the column structure in intervals_df - "value": 25 # Match the column structure in intervals_df - }) + new_data = Series( + { + start_field: pd.to_datetime("2023-01-25"), + end_field: pd.to_datetime("2023-01-30"), + "category": "D", # Match the column structure in intervals_df + "value": 25, # Match the column structure in intervals_df + } + ) new_interval = Interval.create( - new_data, - start_field, - end_field, - series_fields, - metric_fields + new_data, start_field, end_field, series_fields, metric_fields ) # Mock find_overlaps to return empty DataFrame (no overlaps) - with patch.object(IntervalsUtils, 'find_overlaps') as mock_find_overlaps: + with patch.object(IntervalsUtils, "find_overlaps") as mock_find_overlaps: mock_find_overlaps.return_value = DataFrame() # Set up the disjoint_set property @@ -1102,8 +1310,9 @@ def test_add_as_disjoint_no_overlap(self, interval_data): # Expected result: original intervals + new interval expected_result = pd.concat([intervals_df, DataFrame([new_data])]) - pd.testing.assert_frame_equal(result.reset_index(drop=True), - expected_result.reset_index(drop=True)) + pd.testing.assert_frame_equal( + result.reset_index(drop=True), expected_result.reset_index(drop=True) + ) def test_add_as_disjoint_with_overlap(self, interval_data): """Test add_as_disjoint when there's overlap with existing intervals.""" @@ -1122,49 +1331,55 @@ def test_add_as_disjoint_with_overlap(self, interval_data): new_data = Series(index=overlapping_interval.index) # Fill the new interval with the correct data - new_data[start_field] = pd.to_datetime('2023-01-12') # Use datetime objects to match - new_data[end_field] = pd.to_datetime('2023-01-18') + new_data[start_field] = pd.to_datetime( + "2023-01-12" + ) # Use datetime objects to match + new_data[end_field] = pd.to_datetime("2023-01-18") new_data["category"] = "C" new_data["value"] = 30 # Create the interval object new_interval = Interval.create( - new_data, - start_field, - end_field, - series_fields, - metric_fields + new_data, start_field, end_field, series_fields, metric_fields ) # Mock find_overlaps to return the overlapping interval - with patch.object(IntervalsUtils, 'find_overlaps') as mock_find_overlaps: + with patch.object(IntervalsUtils, "find_overlaps") as mock_find_overlaps: mock_find_overlaps.return_value = DataFrame([overlapping_interval]) # Mock the IntervalTransformer - with patch('tempo.intervals.core.utils.IntervalTransformer') as mock_transformer_class: + with patch( + "tempo.intervals.core.utils.IntervalTransformer" + ) as mock_transformer_class: # Set up the mock for the resolve_overlap method mock_transformer_instance = 
mock_transformer_class.return_value # Use datetime objects for the timestamps to match what's in the DataFrame mock_transformer_instance.resolve_overlap.return_value = [ - Series({ - start_field: pd.to_datetime('2023-01-10'), - end_field: pd.to_datetime('2023-01-12'), - "category": overlapping_interval["category"], - "value": 20 - }), - Series({ - start_field: pd.to_datetime('2023-01-12'), - end_field: pd.to_datetime('2023-01-18'), - "category": new_data["category"], - "value": 30 - }), - Series({ - start_field: pd.to_datetime('2023-01-18'), - end_field: pd.to_datetime('2023-01-20'), - "category": overlapping_interval["category"], - "value": 30 - }) + Series( + { + start_field: pd.to_datetime("2023-01-10"), + end_field: pd.to_datetime("2023-01-12"), + "category": overlapping_interval["category"], + "value": 20, + } + ), + Series( + { + start_field: pd.to_datetime("2023-01-12"), + end_field: pd.to_datetime("2023-01-18"), + "category": new_data["category"], + "value": 30, + } + ), + Series( + { + start_field: pd.to_datetime("2023-01-18"), + end_field: pd.to_datetime("2023-01-20"), + "category": overlapping_interval["category"], + "value": 30, + } + ), ] # Set up the disjoint_set property with intervals except the overlapping one @@ -1179,7 +1394,9 @@ def test_add_as_disjoint_with_overlap(self, interval_data): expected_data = [] # Add the resolved intervals from the transformer - for series_data in mock_transformer_instance.resolve_overlap.return_value: + for ( + series_data + ) in mock_transformer_instance.resolve_overlap.return_value: expected_data.append(series_data) # Add the non-overlapping intervals from disjoint_set @@ -1187,10 +1404,16 @@ def test_add_as_disjoint_with_overlap(self, interval_data): expected_data.append(intervals_df.iloc[i]) # We need to sort both DataFrames by start time to ensure consistent order - expected_df = DataFrame(expected_data).sort_values(by=start_field).reset_index(drop=True) + expected_df = ( + DataFrame(expected_data) + .sort_values(by=start_field) + .reset_index(drop=True) + ) # Sort the result by start time as well - sorted_result = result.sort_values(by=start_field).reset_index(drop=True) + sorted_result = result.sort_values(by=start_field).reset_index( + drop=True + ) pd.testing.assert_frame_equal(sorted_result, expected_df) @@ -1205,28 +1428,26 @@ def test_add_as_disjoint_duplicate(self, interval_data): utils = IntervalsUtils(intervals_df) # Create a duplicate of an existing interval - duplicate_data = Series({ - start_field: pd.to_datetime('2023-01-01'), - end_field: pd.to_datetime('2023-01-07'), - 'category': 'A', - 'value': 10 - }) + duplicate_data = Series( + { + start_field: pd.to_datetime("2023-01-01"), + end_field: pd.to_datetime("2023-01-07"), + "category": "A", + "value": 10, + } + ) duplicate_interval = Interval.create( - duplicate_data, - start_field, - end_field, - series_fields, - metric_fields + duplicate_data, start_field, end_field, series_fields, metric_fields ) # Set up the disjoint_set property utils.disjoint_set = intervals_df.copy() # Mock find_overlaps to return empty DataFrame (simulating no overlaps) - with patch.object(IntervalsUtils, 'find_overlaps', return_value=DataFrame()): + with patch.object(IntervalsUtils, "find_overlaps", return_value=DataFrame()): # Create a comparison result where at least one row matches (any returns True) - with patch('pandas.Series.any', return_value=True): + with patch("pandas.Series.any", return_value=True): result = utils.add_as_disjoint(duplicate_interval) # Expected result: original 
intervals unchanged @@ -1243,18 +1464,10 @@ def test_add_as_disjoint_empty_disjoint_set(self, interval_data): utils = IntervalsUtils(DataFrame()) # Create a new interval - new_data = Series({ - start_field: 4, - end_field: 8, - "metric": 30 - }) + new_data = Series({start_field: 4, end_field: 8, "metric": 30}) new_interval = Interval.create( - new_data, - start_field, - end_field, - series_fields, - metric_fields + new_data, start_field, end_field, series_fields, metric_fields ) # Set the disjoint_set to be empty @@ -1266,71 +1479,79 @@ def test_add_as_disjoint_empty_disjoint_set(self, interval_data): expected_result = DataFrame([new_data]) pd.testing.assert_frame_equal(result, expected_result) - def test_add_as_disjoint_multiple_to_resolve_not_only_overlaps(self, intervals_utils, reference_interval, - interval_data): + def test_add_as_disjoint_multiple_to_resolve_not_only_overlaps( + self, intervals_utils, reference_interval, interval_data + ): """Test add_as_disjoint where multiple_to_resolve=True and only_overlaps_present=False.""" start_field = interval_data["start_field"] end_field = interval_data["end_field"] # Create disjoint set with multiple intervals, some overlapping and some not - mixed_df = pd.DataFrame({ - start_field: [ - pd.to_datetime('2023-01-01'), - pd.to_datetime('2023-01-05'), - pd.to_datetime('2023-01-25') # Non-overlapping - ], - end_field: [ - pd.to_datetime('2023-01-05'), - pd.to_datetime('2023-01-10'), - pd.to_datetime('2023-01-30') # Non-overlapping - ], - 'category': ['A', 'B', 'C'], - 'value': [10, 20, 40] - }) + mixed_df = pd.DataFrame( + { + start_field: [ + pd.to_datetime("2023-01-01"), + pd.to_datetime("2023-01-05"), + pd.to_datetime("2023-01-25"), # Non-overlapping + ], + end_field: [ + pd.to_datetime("2023-01-05"), + pd.to_datetime("2023-01-10"), + pd.to_datetime("2023-01-30"), # Non-overlapping + ], + "category": ["A", "B", "C"], + "value": [10, 20, 40], + } + ) intervals_utils.disjoint_set = mixed_df # Set up a reference interval that overlaps with first two intervals - overlapping_df = pd.DataFrame({ - start_field: [ - pd.to_datetime('2023-01-01'), - pd.to_datetime('2023-01-05') - ], - end_field: [ - pd.to_datetime('2023-01-05'), - pd.to_datetime('2023-01-10') - ], - 'category': ['A', 'B'], - 'value': [10, 20] - }) + overlapping_df = pd.DataFrame( + { + start_field: [ + pd.to_datetime("2023-01-01"), + pd.to_datetime("2023-01-05"), + ], + end_field: [pd.to_datetime("2023-01-05"), pd.to_datetime("2023-01-10")], + "category": ["A", "B"], + "value": [10, 20], + } + ) # Non-overlapping subset should be the third interval - non_overlapping_df = pd.DataFrame({ - start_field: [pd.to_datetime('2023-01-25')], - end_field: [pd.to_datetime('2023-01-30')], - 'category': ['C'], - 'value': [40] - }) + non_overlapping_df = pd.DataFrame( + { + start_field: [pd.to_datetime("2023-01-25")], + end_field: [pd.to_datetime("2023-01-30")], + "category": ["C"], + "value": [40], + } + ) # Mock to return the overlapping intervals - with patch.object(IntervalsUtils, 'find_overlaps', return_value=overlapping_df): + with patch.object(IntervalsUtils, "find_overlaps", return_value=overlapping_df): # Mock IntervalsUtils.resolve_all_overlaps to return resolved intervals - resolved_df = pd.DataFrame({ - start_field: [ - pd.to_datetime('2023-01-01'), - pd.to_datetime('2023-01-03'), - pd.to_datetime('2023-01-05') - ], - end_field: [ - pd.to_datetime('2023-01-03'), - pd.to_datetime('2023-01-05'), - pd.to_datetime('2023-01-10') - ], - 'category': ['A', 'A', 'B'], - 'value': [10, 15, 20] - }) 
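For readers following these hunks, nearly every test here leans on the same `unittest.mock` idioms: `patch.object` to swap a method on `IntervalsUtils` for the duration of a `with` block, and string-target `patch("tempo.intervals.core.utils.IntervalTransformer")` to replace the name at the location utils.py looks it up. A minimal, self-contained sketch of the `patch.object` behavior these tests rely on (the `Transformer` class below is illustrative only, not part of tempo):

from unittest.mock import patch

class Transformer:
    def resolve(self):
        return "real"

# patch.object swaps the attribute on the class itself, so every
# instance sees the replacement while the context manager is active...
with patch.object(Transformer, "resolve", return_value="mocked"):
    assert Transformer().resolve() == "mocked"

# ...and the original attribute is restored automatically on exit.
assert Transformer().resolve() == "real"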
- - with patch('tempo.intervals.core.utils.IntervalsUtils.resolve_all_overlaps', - return_value=resolved_df) as mock_resolve_all: + resolved_df = pd.DataFrame( + { + start_field: [ + pd.to_datetime("2023-01-01"), + pd.to_datetime("2023-01-03"), + pd.to_datetime("2023-01-05"), + ], + end_field: [ + pd.to_datetime("2023-01-03"), + pd.to_datetime("2023-01-05"), + pd.to_datetime("2023-01-10"), + ], + "category": ["A", "A", "B"], + "value": [10, 15, 20], + } + ) + + with patch( + "tempo.intervals.core.utils.IntervalsUtils.resolve_all_overlaps", + return_value=resolved_df, + ) as mock_resolve_all: # Execute the method result = intervals_utils.add_as_disjoint(reference_interval) @@ -1339,15 +1560,18 @@ def test_add_as_disjoint_multiple_to_resolve_not_only_overlaps(self, intervals_u # Expected result should have both resolved overlaps and non-overlapping intervals expected_cols = list(result.columns) # Use actual column ordering - expected_df = pd.concat([resolved_df, non_overlapping_df])[expected_cols] + expected_df = pd.concat([resolved_df, non_overlapping_df])[ + expected_cols + ] # Compare results (ignore index values) pd.testing.assert_frame_equal( - result.reset_index(drop=True), - expected_df.reset_index(drop=True) + result.reset_index(drop=True), expected_df.reset_index(drop=True) ) - def test_add_as_disjoint_multiple_resolve_not_only_overlaps_corner_case(self, interval_data): + def test_add_as_disjoint_multiple_resolve_not_only_overlaps_corner_case( + self, interval_data + ): """Test for a corner case in add_as_disjoint that could lead to the NotImplementedError.""" start_field = interval_data["start_field"] end_field = interval_data["end_field"] @@ -1359,74 +1583,79 @@ def test_add_as_disjoint_multiple_resolve_not_only_overlaps_corner_case(self, in utils = IntervalsUtils(intervals_df) # Create a reference interval - reference_data = pd.Series({ - start_field: pd.to_datetime('2023-01-03'), - end_field: pd.to_datetime('2023-01-12'), - 'category': 'X', - 'value': 15 - }) + reference_data = pd.Series( + { + start_field: pd.to_datetime("2023-01-03"), + end_field: pd.to_datetime("2023-01-12"), + "category": "X", + "value": 15, + } + ) reference_interval = Interval.create( - reference_data, - start_field, - end_field, - series_fields, - metric_fields + reference_data, start_field, end_field, series_fields, metric_fields ) # Set up a complex disjoint set - disjoint_df = pd.DataFrame({ - start_field: [ - pd.to_datetime('2023-01-01'), - pd.to_datetime('2023-01-05'), - pd.to_datetime('2023-01-15') # Non-overlapping - ], - end_field: [ - pd.to_datetime('2023-01-05'), - pd.to_datetime('2023-01-10'), - pd.to_datetime('2023-01-20') # Non-overlapping - ], - 'category': ['A', 'B', 'C'], - 'value': [10, 20, 30] - }) + disjoint_df = pd.DataFrame( + { + start_field: [ + pd.to_datetime("2023-01-01"), + pd.to_datetime("2023-01-05"), + pd.to_datetime("2023-01-15"), # Non-overlapping + ], + end_field: [ + pd.to_datetime("2023-01-05"), + pd.to_datetime("2023-01-10"), + pd.to_datetime("2023-01-20"), # Non-overlapping + ], + "category": ["A", "B", "C"], + "value": [10, 20, 30], + } + ) utils.disjoint_set = disjoint_df # Set up conditions for multiple_to_resolve=True and only_overlaps_present=False - overlapping_subset_df = pd.DataFrame({ - start_field: [ - pd.to_datetime('2023-01-01'), - pd.to_datetime('2023-01-05') - ], - end_field: [ - pd.to_datetime('2023-01-05'), - pd.to_datetime('2023-01-10') - ], - 'category': ['A', 'B'], - 'value': [10, 20] - }) - - # Mock find_overlaps to return our overlapping subset - 
with patch.object(IntervalsUtils, 'find_overlaps', return_value=overlapping_subset_df): - # Create a mock for the resolve_all_overlaps method to return complex resolution - resolved_df = pd.DataFrame({ + overlapping_subset_df = pd.DataFrame( + { start_field: [ - pd.to_datetime('2023-01-01'), - pd.to_datetime('2023-01-03'), - pd.to_datetime('2023-01-05'), - pd.to_datetime('2023-01-10') + pd.to_datetime("2023-01-01"), + pd.to_datetime("2023-01-05"), ], - end_field: [ - pd.to_datetime('2023-01-03'), - pd.to_datetime('2023-01-05'), - pd.to_datetime('2023-01-10'), - pd.to_datetime('2023-01-12') - ], - 'category': ['A', 'X', 'X', 'X'], - 'value': [10, 15, 15, 15] - }) + end_field: [pd.to_datetime("2023-01-05"), pd.to_datetime("2023-01-10")], + "category": ["A", "B"], + "value": [10, 20], + } + ) - with patch('tempo.intervals.core.utils.IntervalsUtils.resolve_all_overlaps', - return_value=resolved_df) as mock_resolve_all: + # Mock find_overlaps to return our overlapping subset + with patch.object( + IntervalsUtils, "find_overlaps", return_value=overlapping_subset_df + ): + # Create a mock for the resolve_all_overlaps method to return complex resolution + resolved_df = pd.DataFrame( + { + start_field: [ + pd.to_datetime("2023-01-01"), + pd.to_datetime("2023-01-03"), + pd.to_datetime("2023-01-05"), + pd.to_datetime("2023-01-10"), + ], + end_field: [ + pd.to_datetime("2023-01-03"), + pd.to_datetime("2023-01-05"), + pd.to_datetime("2023-01-10"), + pd.to_datetime("2023-01-12"), + ], + "category": ["A", "X", "X", "X"], + "value": [10, 15, 15, 15], + } + ) + + with patch( + "tempo.intervals.core.utils.IntervalsUtils.resolve_all_overlaps", + return_value=resolved_df, + ) as mock_resolve_all: # Execute the method and check the result includes both parts result = utils.add_as_disjoint(reference_interval) @@ -1437,14 +1666,14 @@ def test_add_as_disjoint_multiple_resolve_not_only_overlaps_corner_case(self, in assert len(result) == 5 # Verify the non-overlapping interval is present - assert pd.to_datetime('2023-01-15') in result[start_field].values - assert pd.to_datetime('2023-01-20') in result[end_field].values + assert pd.to_datetime("2023-01-15") in result[start_field].values + assert pd.to_datetime("2023-01-20") in result[end_field].values # Verify the resolved intervals are present - assert pd.to_datetime('2023-01-01') in result[start_field].values - assert pd.to_datetime('2023-01-03') in result[start_field].values - assert pd.to_datetime('2023-01-05') in result[start_field].values - assert pd.to_datetime('2023-01-10') in result[start_field].values + assert pd.to_datetime("2023-01-01") in result[start_field].values + assert pd.to_datetime("2023-01-03") in result[start_field].values + assert pd.to_datetime("2023-01-05") in result[start_field].values + assert pd.to_datetime("2023-01-10") in result[start_field].values def test_add_as_disjoint_unexpected_conditions_raises_error(self, interval_data): """Test that add_as_disjoint raises NotImplementedError when conditions don't match expected cases.""" @@ -1457,28 +1686,31 @@ def test_add_as_disjoint_unexpected_conditions_raises_error(self, interval_data) utils = IntervalsUtils(interval_data["intervals_data"]) # Create a reference interval - reference_data = pd.Series({ - start_field: pd.to_datetime('2023-01-03'), - end_field: pd.to_datetime('2023-01-12'), - 'category': 'X', - 'value': 15 - }) + reference_data = pd.Series( + { + start_field: pd.to_datetime("2023-01-03"), + end_field: pd.to_datetime("2023-01-12"), + "category": "X", + "value": 15, + } + ) 
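The test continues below by hand-rolling the swap that `patch.object` otherwise automates: it saves the bound `IntervalsUtils.add_as_disjoint`, installs a stub that always raises, asserts the error via `pytest.raises(..., match=...)`, and restores the original in a `finally` block. A self-contained sketch of that save/replace/restore idiom (the `Widget` class and message are illustrative, not taken from tempo):

import pytest

class Widget:
    def run(self):
        return 42

def test_run_raises_when_stubbed():
    # Save the original so it can be restored even if the assertion fails;
    # otherwise later tests would still see the raising stub.
    original = Widget.run

    def always_raise(self):
        raise NotImplementedError("stubbed")

    Widget.run = always_raise
    try:
        with pytest.raises(NotImplementedError, match="stubbed"):
            Widget().run()
    finally:
        Widget.run = original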
reference_interval = Interval.create( - reference_data, - start_field, - end_field, - series_fields, - metric_fields + reference_data, start_field, end_field, series_fields, metric_fields ) # Create a disjoint set with two rows - disjoint_df = pd.DataFrame({ - start_field: [pd.to_datetime('2023-01-01'), pd.to_datetime('2023-01-05')], - end_field: [pd.to_datetime('2023-01-05'), pd.to_datetime('2023-01-10')], - 'category': ['A', 'B'], - 'value': [10, 20] - }) + disjoint_df = pd.DataFrame( + { + start_field: [ + pd.to_datetime("2023-01-01"), + pd.to_datetime("2023-01-05"), + ], + end_field: [pd.to_datetime("2023-01-05"), pd.to_datetime("2023-01-10")], + "category": ["A", "B"], + "value": [10, 20], + } + ) utils.disjoint_set = disjoint_df # Create a mock implementation of add_as_disjoint that always raises NotImplementedError @@ -1493,7 +1725,9 @@ def mock_add_as_disjoint(self, interval): IntervalsUtils.add_as_disjoint = mock_add_as_disjoint # Now call the method - this should raise NotImplementedError - with pytest.raises(NotImplementedError, match="Interval resolution not implemented"): + with pytest.raises( + NotImplementedError, match="Interval resolution not implemented" + ): utils.add_as_disjoint(reference_interval) finally: @@ -1505,7 +1739,9 @@ class TestStillValidLegacy: def test_identify_interval_overlaps_df_empty(self): df = pd.DataFrame() row = Interval.create( - pd.Series({"start": "2023-01-01T00:00:01", "end": "2023-01-01T00:00:05"}), "start", "end" + pd.Series({"start": "2023-01-01T00:00:01", "end": "2023-01-01T00:00:05"}), + "start", + "end", ) result = IntervalsUtils(df).find_overlaps(row) @@ -1527,7 +1763,9 @@ def test_identify_interval_overlaps_overlapping_intervals(self): } ) row = Interval.create( - pd.Series({"start": "2023-01-01T00:00:03", "end": "2023-01-01T00:00:06"}), "start", "end" + pd.Series({"start": "2023-01-01T00:00:03", "end": "2023-01-01T00:00:06"}), + "start", + "end", ) result = IntervalsUtils(df).find_overlaps(row) @@ -1555,8 +1793,11 @@ def test_identify_interval_overlaps_no_overlapping_intervals(self): ], } ) - row = Interval.create(pd.Series({"start": "2023-01-01T00:00:04.1", "end": "2023-01-01T00:00:05"}), "start", - "end") + row = Interval.create( + pd.Series({"start": "2023-01-01T00:00:04.1", "end": "2023-01-01T00:00:05"}), + "start", + "end", + ) result = IntervalsUtils(df).find_overlaps(row) assert result.empty @@ -1576,7 +1817,11 @@ def test_identify_interval_overlaps_interval_subset(self): ], } ) - row = Interval.create(pd.Series({"start": "2023-01-01T00:00:02", "end": "2023-01-01T00:00:04"}), "start", "end") + row = Interval.create( + pd.Series({"start": "2023-01-01T00:00:02", "end": "2023-01-01T00:00:04"}), + "start", + "end", + ) result = IntervalsUtils(df).find_overlaps(row) expected = pd.DataFrame( @@ -1589,15 +1834,28 @@ def test_identify_interval_overlaps_identical_start_end(self): df = pd.DataFrame( {"start": ["2023-01-01T00:00:02"], "end": ["2023-01-01T00:00:05"]} ) - row = Interval.create(pd.Series({"start": "2023-01-01T00:00:02", "end": "2023-01-01T00:00:05"}), "start", "end") + row = Interval.create( + pd.Series({"start": "2023-01-01T00:00:02", "end": "2023-01-01T00:00:05"}), + "start", + "end", + ) result = IntervalsUtils(df).find_overlaps(row) assert result.empty def test_resolve_all_overlaps_basic(self): - interval = Interval.create(pd.Series( - {"start": "2023-01-01 00:00:00", "end": "2023-01-01 05:00:00", "value": 10} - ), "start", "end", metric_fields=["value"]) + interval = Interval.create( + pd.Series( + { + "start": 
"2023-01-01 00:00:00", + "end": "2023-01-01 05:00:00", + "value": 10, + } + ), + "start", + "end", + metric_fields=["value"], + ) overlaps = pd.DataFrame( { "start": [ @@ -1618,9 +1876,18 @@ def test_resolve_all_overlaps_basic(self): assert len(result) == 5 def test_add_as_disjoint_where_basic_overlap(self): - interval = Interval.create(pd.Series( - {"start": "2023-01-01 00:00:00", "end": "2023-01-01 05:00:00", "value": 10} - ), "start", "end", metric_fields=["value"]) + interval = Interval.create( + pd.Series( + { + "start": "2023-01-01 00:00:00", + "end": "2023-01-01 05:00:00", + "value": 10, + } + ), + "start", + "end", + metric_fields=["value"], + ) disjoint_set = pd.DataFrame( { "start": ["2023-01-01 03:00:00"], @@ -1650,13 +1917,24 @@ def test_add_as_disjoint_where_basic_overlap(self): } ) - pd.testing.assert_frame_equal(result.sort_values("start").reset_index(drop=True), - expected.sort_values("start").reset_index(drop=True)) + pd.testing.assert_frame_equal( + result.sort_values("start").reset_index(drop=True), + expected.sort_values("start").reset_index(drop=True), + ) def test_add_as_disjoint_where_no_overlap(self): - interval = Interval.create(pd.Series( - {"start": "2023-01-01 00:00:00", "end": "2023-01-01 05:00:00", "value": 10} - ), "start", "end", metric_fields=["value"]) + interval = Interval.create( + pd.Series( + { + "start": "2023-01-01 00:00:00", + "end": "2023-01-01 05:00:00", + "value": 10, + } + ), + "start", + "end", + metric_fields=["value"], + ) disjoint_set = pd.DataFrame( { "start": ["2023-01-01 06:00:00"], @@ -1672,13 +1950,23 @@ def test_add_as_disjoint_where_no_overlap(self): expected = pd.concat([disjoint_set, pd.DataFrame([interval.data])]) - pd.testing.assert_frame_equal(result.reset_index(drop=True), - expected.reset_index(drop=True)) + pd.testing.assert_frame_equal( + result.reset_index(drop=True), expected.reset_index(drop=True) + ) def test_add_as_disjoint_where_empty_disjoint_set(self): - interval = Interval.create(pd.Series( - {"start": "2023-01-01 00:00:00", "end": "2023-01-01 05:00:00", "value": 10} - ), "start", "end", metric_fields=["value"]) + interval = Interval.create( + pd.Series( + { + "start": "2023-01-01 00:00:00", + "end": "2023-01-01 05:00:00", + "value": 10, + } + ), + "start", + "end", + metric_fields=["value"], + ) disjoint_set = pd.DataFrame() interval_utils = IntervalsUtils(disjoint_set) @@ -1688,13 +1976,22 @@ def test_add_as_disjoint_where_empty_disjoint_set(self): expected = pd.DataFrame([interval.data]) - pd.testing.assert_frame_equal(result.reset_index(drop=True), - expected.reset_index(drop=True)) + pd.testing.assert_frame_equal( + result.reset_index(drop=True), expected.reset_index(drop=True) + ) def test_add_as_disjoint_where_duplicate_interval(self): - interval = Interval.create(pd.Series( - {"start": "2023-01-01 00:00:00", "end": "2023-01-01 05:00:00", "value": 10} - ), "start", "end") + interval = Interval.create( + pd.Series( + { + "start": "2023-01-01 00:00:00", + "end": "2023-01-01 05:00:00", + "value": 10, + } + ), + "start", + "end", + ) disjoint_set = pd.DataFrame( { "start": ["2023-01-01 00:00:00"], @@ -1708,13 +2005,23 @@ def test_add_as_disjoint_where_duplicate_interval(self): result = interval_utils.add_as_disjoint(interval) - pd.testing.assert_frame_equal(result.reset_index(drop=True), - disjoint_set.reset_index(drop=True)) + pd.testing.assert_frame_equal( + result.reset_index(drop=True), disjoint_set.reset_index(drop=True) + ) def test_add_as_disjoint_where_multiple_overlaps(self): - interval = 
Interval.create(pd.Series( - {"start": "2023-01-01 01:00:00", "end": "2023-01-01 05:00:00", "value": 10} - ), "start", "end", metric_fields=["value"]) + interval = Interval.create( + pd.Series( + { + "start": "2023-01-01 01:00:00", + "end": "2023-01-01 05:00:00", + "value": 10, + } + ), + "start", + "end", + metric_fields=["value"], + ) disjoint_set = pd.DataFrame( { "start": [ @@ -1759,13 +2066,24 @@ def test_add_as_disjoint_where_multiple_overlaps(self): } ) - pd.testing.assert_frame_equal(result.sort_values("start").reset_index(drop=True), - expected.sort_values("start").reset_index(drop=True)) + pd.testing.assert_frame_equal( + result.sort_values("start").reset_index(drop=True), + expected.sort_values("start").reset_index(drop=True), + ) def test_add_as_disjoint_where_all_records_overlap(self): - interval = Interval.create(pd.Series( - {"start": "2023-01-01 01:00:00", "end": "2023-01-01 05:00:00", "value": 10} - ), "start", "end", metric_fields=["value"]) + interval = Interval.create( + pd.Series( + { + "start": "2023-01-01 01:00:00", + "end": "2023-01-01 05:00:00", + "value": 10, + } + ), + "start", + "end", + metric_fields=["value"], + ) disjoint_set = pd.DataFrame( { "start": ["2023-01-01 01:30:00", "2023-01-01 02:30:00"], @@ -1787,5 +2105,7 @@ def test_add_as_disjoint_where_all_records_overlap(self): } ) - pd.testing.assert_frame_equal(result.sort_values("start").reset_index(drop=True), - expected.sort_values("start").reset_index(drop=True)) + pd.testing.assert_frame_equal( + result.sort_values("start").reset_index(drop=True), + expected.sort_values("start").reset_index(drop=True), + ) diff --git a/python/tests/intervals/core/validation_tests.py b/python/tests/intervals/core/validation_tests.py index 40f9cec9..5bcc4f04 100644 --- a/python/tests/intervals/core/validation_tests.py +++ b/python/tests/intervals/core/validation_tests.py @@ -1,7 +1,11 @@ import pandas as pd import pytest -from tempo.intervals.core.exceptions import EmptyIntervalError, InvalidDataTypeError, InvalidMetricColumnError +from tempo.intervals.core.exceptions import ( + EmptyIntervalError, + InvalidDataTypeError, + InvalidMetricColumnError, +) from tempo.intervals.core.validation import IntervalValidator, ValidationResult @@ -131,7 +135,9 @@ def test_validate_series_id_columns_with_none(self): assert result.is_valid is True assert result.message is None - def test_validate_series_id_columns_with_non_string_elements(self, non_string_column_list): + def test_validate_series_id_columns_with_non_string_elements( + self, non_string_column_list + ): """Test validation fails with non-string elements in series ID columns""" with pytest.raises(InvalidMetricColumnError) as excinfo: IntervalValidator.validate_series_id_columns(non_string_column_list) @@ -157,7 +163,9 @@ def test_validate_metric_columns_with_none(self): assert result.is_valid is True assert result.message is None - def test_validate_metric_columns_with_non_string_elements(self, non_string_column_list): + def test_validate_metric_columns_with_non_string_elements( + self, non_string_column_list + ): """Test validation fails with non-string elements in metric columns""" with pytest.raises(InvalidMetricColumnError) as excinfo: IntervalValidator.validate_metric_columns(non_string_column_list) diff --git a/python/tests/intervals/datetime/utils_tests.py b/python/tests/intervals/datetime/utils_tests.py index 7915a930..ea32a757 100644 --- a/python/tests/intervals/datetime/utils_tests.py +++ b/python/tests/intervals/datetime/utils_tests.py @@ -4,67 +4,61 @@ from 
tempo.intervals.datetime.utils import infer_datetime_format -@pytest.mark.parametrize("date_string", [ - # ISO 8601 formats with timezone - "2023-01-01T12:34:56.789012+0000", - "2023-01-01T12:34:56.789012Z", - "2023-01-01T12:34:56+0000", - "2023-01-01T12:34:56Z", - "2023-01-01T12:34+0000", - "2023-01-01T12:34Z", - - # ISO 8601 formats without timezone - "2023-01-01T12:34:56.789012", - "2023-01-01T12:34:56", - "2023-01-01T12:34", - - # Standard datetime formats with microseconds - "2023-01-01 12:34:56.789012+0000", - "2023-01-01 12:34:56.789012", - "2023-01-01 12:34:56.789", - - # Standard datetime formats - "2023-01-01 12:34:56+0000", - "2023-01-01 12:34:56", - "2023-01-01 12:34", - - # Date only formats - "2023-01-01", - "2023/01/01", - - # US date formats - "01/01/2023 12:34:56.789012", - "01/01/2023 12:34:56", - "01/01/2023 12:34", - "01/01/2023", - "1/1/2023", - "1/1/23", - - # UK/European date formats - "01-01-2023 12:34:56", - "01-01-2023", - "01.01.2023", - "01.01.2023 12:34:56", - - # Month name formats - "Jan 01, 2023", - "Jan 1, 2023", - "01 Jan 2023", - "1 Jan 2023", - "January 01, 2023", - "January 1, 2023", - "01 January 2023", - "1 January 2023", - "January 1, 2023 12:34:56", - - # Time only formats - "12:34:56", - "12:34", - - # Special formats - "20230101123456", - "20230101", -]) +@pytest.mark.parametrize( + "date_string", + [ + # ISO 8601 formats with timezone + "2023-01-01T12:34:56.789012+0000", + "2023-01-01T12:34:56.789012Z", + "2023-01-01T12:34:56+0000", + "2023-01-01T12:34:56Z", + "2023-01-01T12:34+0000", + "2023-01-01T12:34Z", + # ISO 8601 formats without timezone + "2023-01-01T12:34:56.789012", + "2023-01-01T12:34:56", + "2023-01-01T12:34", + # Standard datetime formats with microseconds + "2023-01-01 12:34:56.789012+0000", + "2023-01-01 12:34:56.789012", + "2023-01-01 12:34:56.789", + # Standard datetime formats + "2023-01-01 12:34:56+0000", + "2023-01-01 12:34:56", + "2023-01-01 12:34", + # Date only formats + "2023-01-01", + "2023/01/01", + # US date formats + "01/01/2023 12:34:56.789012", + "01/01/2023 12:34:56", + "01/01/2023 12:34", + "01/01/2023", + "1/1/2023", + "1/1/23", + # UK/European date formats + "01-01-2023 12:34:56", + "01-01-2023", + "01.01.2023", + "01.01.2023 12:34:56", + # Month name formats + "Jan 01, 2023", + "Jan 1, 2023", + "01 Jan 2023", + "1 Jan 2023", + "January 01, 2023", + "January 1, 2023", + "01 January 2023", + "1 January 2023", + "January 1, 2023 12:34:56", + # Time only formats + "12:34:56", + "12:34", + # Special formats + "20230101123456", + "20230101", + ], +) def test_format_matches_input(date_string): """Test that the inferred format correctly reproduces the input when used with strftime""" try: @@ -76,42 +70,47 @@ def test_format_matches_input(date_string): # Check that the reformatted date matches the input # Special handling for 'T' separator in ISO 8601 formats - if 'T' in date_string and ' ' in reformatted: + if "T" in date_string and " " in reformatted: # Replace space with 'T' at the right position - t_index = date_string.find('T') - reformatted = reformatted[:t_index] + 'T' + reformatted[t_index + 1:] + t_index = date_string.find("T") + reformatted = reformatted[:t_index] + "T" + reformatted[t_index + 1 :] # Special handling for milliseconds vs microseconds - if '.' in date_string and len(date_string.split('.')[-1]) < 6 and reformatted.endswith('000'): + if ( + "." 
in date_string + and len(date_string.split(".")[-1]) < 6 + and reformatted.endswith("000") + ): # For dates with milliseconds, strip trailing zeros from microseconds reformatted_ms = reformatted[:-3] - assert reformatted_ms == date_string, \ - f"Failed for: {date_string}\nFormat: {fmt}\nReformatted: {reformatted_ms}" + assert ( + reformatted_ms == date_string + ), f"Failed for: {date_string}\nFormat: {fmt}\nReformatted: {reformatted_ms}" else: - assert reformatted == date_string, \ - f"Failed for: {date_string}\nFormat: {fmt}\nReformatted: {reformatted}" + assert ( + reformatted == date_string + ), f"Failed for: {date_string}\nFormat: {fmt}\nReformatted: {reformatted}" except (ValueError, TypeError, OverflowError) as e: pytest.fail(f"Error parsing {date_string}: {e}") -@pytest.mark.parametrize("date_string,expected_format", [ - # Leap year day - ("2024-02-29", "%Y-%m-%d"), - - # Very old date - ("1800-01-01", "%Y-%m-%d"), - - # Future date - ("2100-01-01", "%Y-%m-%d"), - - # Extreme timezone - ("2023-01-01T12:00:00+1400", "%Y-%m-%dT%H:%M:%S%z"), - ("2023-01-01T12:00:00-1400", "%Y-%m-%dT%H:%M:%S%z"), - - # Midnight and special times - ("2023-01-01 00:00:00", "%Y-%m-%d %H:%M:%S"), - ("2023-01-01 23:59:59", "%Y-%m-%d %H:%M:%S"), -]) +@pytest.mark.parametrize( + "date_string,expected_format", + [ + # Leap year day + ("2024-02-29", "%Y-%m-%d"), + # Very old date + ("1800-01-01", "%Y-%m-%d"), + # Future date + ("2100-01-01", "%Y-%m-%d"), + # Extreme timezone + ("2023-01-01T12:00:00+1400", "%Y-%m-%dT%H:%M:%S%z"), + ("2023-01-01T12:00:00-1400", "%Y-%m-%dT%H:%M:%S%z"), + # Midnight and special times + ("2023-01-01 00:00:00", "%Y-%m-%d %H:%M:%S"), + ("2023-01-01 23:59:59", "%Y-%m-%d %H:%M:%S"), + ], +) def test_edge_cases(date_string, expected_format): """Test edge cases and unusual formats""" try: @@ -119,32 +118,34 @@ def test_edge_cases(date_string, expected_format): fmt = infer_datetime_format(date_string) # Check against expected format - assert fmt == expected_format, \ - f"Format for {date_string} was {fmt}, expected {expected_format}" + assert ( + fmt == expected_format + ), f"Format for {date_string} was {fmt}, expected {expected_format}" # Verify it works as expected reformatted = Timestamp(date_string).strftime(fmt) - assert reformatted == date_string, \ - f"Failed for: {date_string}\nFormat: {fmt}\nReformatted: {reformatted}" + assert ( + reformatted == date_string + ), f"Failed for: {date_string}\nFormat: {fmt}\nReformatted: {reformatted}" except (ValueError, TypeError, OverflowError) as e: pytest.fail(f"Error parsing {date_string}: {e}") -@pytest.mark.parametrize("input_string", [ - # Completely non-date string - "not a date", - - # Malformed dates - "2023-13-01", # Invalid month - "2023-01-32", # Invalid day - "2023/02/30", # Invalid day for February - - # Ambiguous formats (function should still return something) - "01/02/03", # Ambiguous MM/DD/YY or DD/MM/YY - - # Empty string - "", -]) +@pytest.mark.parametrize( + "input_string", + [ + # Completely non-date string + "not a date", + # Malformed dates + "2023-13-01", # Invalid month + "2023-01-32", # Invalid day + "2023/02/30", # Invalid day for February + # Ambiguous formats (function should still return something) + "01/02/03", # Ambiguous MM/DD/YY or DD/MM/YY + # Empty string + "", + ], +) def test_invalid_inputs(input_string): """Test that invalid inputs are handled appropriately""" # The function should return a default format without raising an exception @@ -152,22 +153,24 @@ def test_invalid_inputs(input_string): fmt = 
infer_datetime_format(input_string) assert isinstance(fmt, str), "Function should return a string format" except Exception as e: - pytest.fail(f"Function raised {type(e).__name__} for input '{input_string}': {e}") - - -@pytest.mark.parametrize("date_string,expected_format", [ - # RFC 822 format - ("Wed, 02 Oct 2002 13:00:00 GMT", "%a, %d %b %Y %H:%M:%S GMT"), - - # Excel/Lotus style - ("2023-1-1", "%Y-%m-%d"), - - # Date with weekday - ("Monday, January 1, 2023", "%A, %B %d, %Y"), - - # 12-hour clock with AM/PM - ("2023-01-01 01:30:00 PM", "%Y-%m-%d %I:%M:%S %p"), -]) + pytest.fail( + f"Function raised {type(e).__name__} for input '{input_string}': {e}" + ) + + +@pytest.mark.parametrize( + "date_string,expected_format", + [ + # RFC 822 format + ("Wed, 02 Oct 2002 13:00:00 GMT", "%a, %d %b %Y %H:%M:%S GMT"), + # Excel/Lotus style + ("2023-1-1", "%Y-%m-%d"), + # Date with weekday + ("Monday, January 1, 2023", "%A, %B %d, %Y"), + # 12-hour clock with AM/PM + ("2023-01-01 01:30:00 PM", "%Y-%m-%d %I:%M:%S %p"), + ], +) def test_custom_formats(date_string, expected_format): """Test some custom or unusual but valid datetime formats""" try: diff --git a/python/tests/intervals/metrics/merger_tests.py b/python/tests/intervals/metrics/merger_tests.py index 7581f4f4..39912369 100644 --- a/python/tests/intervals/metrics/merger_tests.py +++ b/python/tests/intervals/metrics/merger_tests.py @@ -39,17 +39,13 @@ def __init__(self, data, metric_fields): @pytest.fixture def test_data(): """Fixture to create test dataframes""" - data1 = pd.DataFrame({ - 'time': [1, 2, 3], - 'metric1': [10, 20, 30], - 'metric2': [100, 200, 300] - }) - - data2 = pd.DataFrame({ - 'time': [1, 2, 3], - 'metric1': [1, 2, 3], - 'metric2': [10, 20, 30] - }) + data1 = pd.DataFrame( + {"time": [1, 2, 3], "metric1": [10, 20, 30], "metric2": [100, 200, 300]} + ) + + data2 = pd.DataFrame( + {"time": [1, 2, 3], "metric1": [1, 2, 3], "metric2": [10, 20, 30]} + ) return data1, data2 @@ -58,8 +54,8 @@ def test_data(): def test_intervals(test_data): """Fixture to create mock intervals""" data1, data2 = test_data - interval1 = MockInterval(data1, ['metric1', 'metric2']) - interval2 = MockInterval(data2, ['metric1', 'metric2']) + interval1 = MockInterval(data1, ["metric1", "metric2"]) + interval2 = MockInterval(data2, ["metric1", "metric2"]) return interval1, interval2 @@ -68,8 +64,8 @@ def merge_config(): """Fixture to create a merge config with test strategies""" config = MetricMergeConfig() strategy = TestMetricStrategy() - config.set_strategy('metric1', strategy) - config.set_strategy('metric2', strategy) + config.set_strategy("metric1", strategy) + config.set_strategy("metric2", strategy) return config @@ -93,14 +89,14 @@ def test_merge_success(self, test_intervals, merge_config): result = merger.merge(interval1, interval2) # The test strategy adds values, so we expect these sums - assert result['metric1'].tolist() == [11, 22, 33] - assert result['metric2'].tolist() == [110, 220, 330] + assert result["metric1"].tolist() == [11, 22, 33] + assert result["metric2"].tolist() == [110, 220, 330] def test_merge_different_metric_fields(self, test_data): """Test merging intervals with different metric fields""" data1, data2 = test_data - interval1 = MockInterval(data1, ['metric1', 'metric2']) - interval2 = MockInterval(data2, ['metric1']) + interval1 = MockInterval(data1, ["metric1", "metric2"]) + interval2 = MockInterval(data2, ["metric1"]) merger = MetricMerger() @@ -115,15 +111,17 @@ def test_failing_strategy(self, test_intervals): config = 
MetricMergeConfig() failing_strategy = FailingMetricStrategy() - config.set_strategy('metric1', failing_strategy) - config.set_strategy('metric2', TestMetricStrategy()) + config.set_strategy("metric1", failing_strategy) + config.set_strategy("metric2", TestMetricStrategy()) merger = MetricMerger(config) with pytest.raises(ValueError) as excinfo: merger.merge(interval1, interval2) - assert "Strategy FailingMetricStrategy failed: Validation failed" in str(excinfo.value) + assert "Strategy FailingMetricStrategy failed: Validation failed" in str( + excinfo.value + ) def test_apply_merge_strategy(self): """Test the _apply_merge_strategy static method""" @@ -144,7 +142,9 @@ def test_apply_merge_strategy_failure(self): with pytest.raises(ValueError) as excinfo: MetricMerger._apply_merge_strategy(value1, value2, strategy) - assert "Strategy FailingMetricStrategy failed: Validation failed" in str(excinfo.value) + assert "Strategy FailingMetricStrategy failed: Validation failed" in str( + excinfo.value + ) class TestDefaultMetricMerger: @@ -162,5 +162,5 @@ def test_merge_functionality(self, test_intervals, merge_config): result = merger.merge(interval1, interval2) # The test strategy adds values, so we expect these sums - assert result['metric1'].tolist() == [11, 22, 33] - assert result['metric2'].tolist() == [110, 220, 330] + assert result["metric1"].tolist() == [11, 22, 33] + assert result["metric2"].tolist() == [110, 220, 330] diff --git a/python/tests/intervals/metrics/operations_tests.py b/python/tests/intervals/metrics/operations_tests.py index bf8201f2..6d4c8808 100644 --- a/python/tests/intervals/metrics/operations_tests.py +++ b/python/tests/intervals/metrics/operations_tests.py @@ -23,7 +23,9 @@ def normalize(self, interval): assert isinstance(normalizer, MetricNormalizer) # Test that the normalize method works as expected - result = normalizer.normalize(None) # Passing None as we've mocked the Interval dependency + result = normalizer.normalize( + None + ) # Passing None as we've mocked the Interval dependency assert isinstance(result, Series) assert result["value"] == 1.0 @@ -38,14 +40,10 @@ def test_default_initialization(self): def test_custom_initialization(self): """Test initialization with custom arguments""" default_strategy = KeepFirstStrategy() - column_strategies = { - "col1": KeepFirstStrategy(), - "col2": KeepLastStrategy() - } + column_strategies = {"col1": KeepFirstStrategy(), "col2": KeepLastStrategy()} config = MetricMergeConfig( - default_strategy=default_strategy, - column_strategies=column_strategies + default_strategy=default_strategy, column_strategies=column_strategies ) assert config.default_strategy is default_strategy @@ -55,15 +53,18 @@ def test_validate_strategies_default_strategy(self): """Test validation of default_strategy""" with pytest.raises(ValueError) as excinfo: MetricMergeConfig(default_strategy="not a strategy") - assert "default_strategy must be an instance of MetricMergeStrategy" in str(excinfo.value) + assert "default_strategy must be an instance of MetricMergeStrategy" in str( + excinfo.value + ) def test_validate_strategies_column_strategies(self): """Test validation of column_strategies""" with pytest.raises(ValueError) as excinfo: - MetricMergeConfig( - column_strategies={"col1": "not a strategy"} - ) - assert "Strategy for column col1 must be an instance of MetricMergeStrategy" in str(excinfo.value) + MetricMergeConfig(column_strategies={"col1": "not a strategy"}) + assert ( + "Strategy for column col1 must be an instance of 
MetricMergeStrategy" + in str(excinfo.value) + ) def test_get_strategy_existing_column(self): """Test getting a strategy for a column that has a specific strategy set""" @@ -71,8 +72,7 @@ def test_get_strategy_existing_column(self): col1_strategy = KeepLastStrategy() config = MetricMergeConfig( - default_strategy=default_strategy, - column_strategies={"col1": col1_strategy} + default_strategy=default_strategy, column_strategies={"col1": col1_strategy} ) strategy = config.get_strategy("col1") @@ -103,16 +103,17 @@ def test_set_strategy_invalid(self): with pytest.raises(ValueError) as excinfo: config.set_strategy("col1", "not a strategy") - assert "The provided strategy must be an instance of MetricMergeStrategy" in str(excinfo.value) + assert ( + "The provided strategy must be an instance of MetricMergeStrategy" + in str(excinfo.value) + ) def test_set_strategy_override(self): """Test overriding an existing strategy for a column""" initial_strategy = KeepFirstStrategy() new_strategy = KeepLastStrategy() - config = MetricMergeConfig( - column_strategies={"col1": initial_strategy} - ) + config = MetricMergeConfig(column_strategies={"col1": initial_strategy}) config.set_strategy("col1", new_strategy) diff --git a/python/tests/intervals/metrics/strategies_tests.py b/python/tests/intervals/metrics/strategies_tests.py index e035732b..0e07a3dd 100644 --- a/python/tests/intervals/metrics/strategies_tests.py +++ b/python/tests/intervals/metrics/strategies_tests.py @@ -77,8 +77,16 @@ def test_keep_first_strategy_series(self, series_values): assert isinstance(result, Series) # Convert scalar to Series if needed for comparison - v1 = value1 if isinstance(value1, Series) else Series([value1] * len(result)) - v2 = value2 if isinstance(value2, Series) else Series([value2] * len(result)) + v1 = ( + value1 + if isinstance(value1, Series) + else Series([value1] * len(result)) + ) + v2 = ( + value2 + if isinstance(value2, Series) + else Series([value2] * len(result)) + ) # Apply the expected logic for comparison without using mask indexing expected = v1.copy() @@ -126,8 +134,16 @@ def test_keep_last_strategy_series(self, series_values): assert isinstance(result, Series) # Convert scalar to Series if needed for comparison - v1 = value1 if isinstance(value1, Series) else Series([value1] * len(result)) - v2 = value2 if isinstance(value2, Series) else Series([value2] * len(result)) + v1 = ( + value1 + if isinstance(value1, Series) + else Series([value1] * len(result)) + ) + v2 = ( + value2 + if isinstance(value2, Series) + else Series([value2] * len(result)) + ) # Apply the expected logic for comparison without using mask indexing expected = v2.copy() @@ -160,8 +176,9 @@ def test_sum_strategy_validation(self, non_numeric_values): strategy = SumStrategy() for value1, value2 in non_numeric_values: - if (isinstance(value1, (int, float)) or isna(value1)) and \ - (isinstance(value2, (int, float)) or isna(value2)): + if (isinstance(value1, (int, float)) or isna(value1)) and ( + isinstance(value2, (int, float)) or isna(value2) + ): continue # Skip valid numeric combinations with pytest.raises(ValueError, match="SumStrategy requires numeric values"): @@ -282,14 +299,14 @@ def test_validation_for_non_numeric_values_sum_strategy(self): # Test case 1: String in Series s1 = Series([1, 2, 3]) - s2 = Series(['a', 2, 3]) # Contains non-numeric value + s2 = Series(["a", 2, 3]) # Contains non-numeric value with pytest.raises(ValueError, match="SumStrategy requires numeric values"): strategy.merge(s1, s2) # Test case 2: String as 
scalar s1 = Series([1, 2, 3]) - scalar = 'a' # Non-numeric scalar + scalar = "a" # Non-numeric scalar with pytest.raises(ValueError, match="SumStrategy requires numeric values"): strategy.merge(s1, scalar) @@ -425,11 +442,14 @@ def test_average_strategy_validation(self, non_numeric_values): strategy = AverageStrategy() for value1, value2 in non_numeric_values: - if (isinstance(value1, (int, float)) or isna(value1)) and \ - (isinstance(value2, (int, float)) or isna(value2)): + if (isinstance(value1, (int, float)) or isna(value1)) and ( + isinstance(value2, (int, float)) or isna(value2) + ): continue # Skip valid numeric combinations - with pytest.raises(ValueError, match="AverageStrategy requires numeric values"): + with pytest.raises( + ValueError, match="AverageStrategy requires numeric values" + ): strategy.validate(value1, value2) strategy.merge(value1, value2) diff --git a/python/tests/intervals/overlap/detection_tests.py b/python/tests/intervals/overlap/detection_tests.py index 56ec4555..44465c10 100644 --- a/python/tests/intervals/overlap/detection_tests.py +++ b/python/tests/intervals/overlap/detection_tests.py @@ -3,9 +3,20 @@ from tempo.intervals.core.interval import Interval from tempo.intervals.overlap.detection import ( - MetricsEquivalentChecker, BeforeChecker, MeetsChecker, OverlapsChecker, - StartsChecker, DuringChecker, FinishesChecker, EqualsChecker, ContainsChecker, - StartedByChecker, FinishedByChecker, OverlappedByChecker, MetByChecker, AfterChecker + MetricsEquivalentChecker, + BeforeChecker, + MeetsChecker, + OverlapsChecker, + StartsChecker, + DuringChecker, + FinishesChecker, + EqualsChecker, + ContainsChecker, + StartedByChecker, + FinishedByChecker, + OverlappedByChecker, + MetByChecker, + AfterChecker, ) @@ -70,23 +81,24 @@ def mock_values(): return v1, v2, v3, v4, v5 + @pytest.fixture def checkers(): return { - 'metrics_equivalent': MetricsEquivalentChecker(), - 'before': BeforeChecker(), - 'meets': MeetsChecker(), - 'overlaps': OverlapsChecker(), - 'starts': StartsChecker(), - 'during': DuringChecker(), - 'finishes': FinishesChecker(), - 'equals': EqualsChecker(), - 'contains': ContainsChecker(), - 'started_by': StartedByChecker(), - 'finished_by': FinishedByChecker(), - 'overlapped_by': OverlappedByChecker(), - 'met_by': MetByChecker(), - 'after': AfterChecker() + "metrics_equivalent": MetricsEquivalentChecker(), + "before": BeforeChecker(), + "meets": MeetsChecker(), + "overlaps": OverlapsChecker(), + "starts": StartsChecker(), + "during": DuringChecker(), + "finishes": FinishesChecker(), + "equals": EqualsChecker(), + "contains": ContainsChecker(), + "started_by": StartedByChecker(), + "finished_by": FinishedByChecker(), + "overlapped_by": OverlappedByChecker(), + "met_by": MetByChecker(), + "after": AfterChecker(), } @@ -103,7 +115,7 @@ class TestBasicRelations: def test_before_checker(self, mock_values, checkers, create_interval): v1, v2, v3, v4, v5 = mock_values - before_checker = checkers['before'] + before_checker = checkers["before"] # Case 1: A before B interval_a = create_interval(v1, v2) @@ -122,7 +134,7 @@ def test_before_checker(self, mock_values, checkers, create_interval): def test_meets_checker(self, mock_values, checkers, create_interval): v1, v2, v3, v4, v5 = mock_values - meets_checker = checkers['meets'] + meets_checker = checkers["meets"] # Case 1: A meets B interval_a = create_interval(v1, v3) @@ -141,7 +153,7 @@ def test_meets_checker(self, mock_values, checkers, create_interval): def test_after_checker(self, mock_values, checkers, 
create_interval): v1, v2, v3, v4, v5 = mock_values - after_checker = checkers['after'] + after_checker = checkers["after"] # Case 1: A after B interval_a = create_interval(v3, v4) @@ -160,7 +172,7 @@ def test_after_checker(self, mock_values, checkers, create_interval): def test_met_by_checker(self, mock_values, checkers, create_interval): v1, v2, v3, v4, v5 = mock_values - met_by_checker = checkers['met_by'] + met_by_checker = checkers["met_by"] # Case 1: A met by B interval_a = create_interval(v3, v5) @@ -183,7 +195,7 @@ class TestOverlapRelations: def test_overlaps_checker(self, mock_values, checkers, create_interval): v1, v2, v3, v4, v5 = mock_values - overlaps_checker = checkers['overlaps'] + overlaps_checker = checkers["overlaps"] # Case 1: A overlaps B interval_a = create_interval(v1, v3) @@ -202,7 +214,7 @@ def test_overlaps_checker(self, mock_values, checkers, create_interval): def test_overlapped_by_checker(self, mock_values, checkers, create_interval): v1, v2, v3, v4, v5 = mock_values - overlapped_by_checker = checkers['overlapped_by'] + overlapped_by_checker = checkers["overlapped_by"] # Case 1: A overlapped by B interval_a = create_interval(v2, v4) @@ -225,7 +237,7 @@ class TestContainmentRelations: def test_during_checker(self, mock_values, checkers, create_interval): v1, v2, v3, v4, v5 = mock_values - during_checker = checkers['during'] + during_checker = checkers["during"] # Case 1: A during B interval_a = create_interval(v2, v3) @@ -244,7 +256,7 @@ def test_during_checker(self, mock_values, checkers, create_interval): def test_contains_checker(self, mock_values, checkers, create_interval): v1, v2, v3, v4, v5 = mock_values - contains_checker = checkers['contains'] + contains_checker = checkers["contains"] # Case 1: A contains B interval_a = create_interval(v1, v4) @@ -267,7 +279,7 @@ class TestBoundaryRelations: def test_starts_checker(self, mock_values, checkers, create_interval): v1, v2, v3, v4, v5 = mock_values - starts_checker = checkers['starts'] + starts_checker = checkers["starts"] # Case 1: A starts B interval_a = create_interval(v1, v3) @@ -286,7 +298,7 @@ def test_starts_checker(self, mock_values, checkers, create_interval): def test_started_by_checker(self, mock_values, checkers, create_interval): v1, v2, v3, v4, v5 = mock_values - started_by_checker = checkers['started_by'] + started_by_checker = checkers["started_by"] # Case 1: A started by B interval_a = create_interval(v1, v4) @@ -305,7 +317,7 @@ def test_started_by_checker(self, mock_values, checkers, create_interval): def test_finishes_checker(self, mock_values, checkers, create_interval): v1, v2, v3, v4, v5 = mock_values - finishes_checker = checkers['finishes'] + finishes_checker = checkers["finishes"] # Case 1: A finishes B interval_a = create_interval(v2, v4) @@ -324,7 +336,7 @@ def test_finishes_checker(self, mock_values, checkers, create_interval): def test_finished_by_checker(self, mock_values, checkers, create_interval): v1, v2, v3, v4, v5 = mock_values - finished_by_checker = checkers['finished_by'] + finished_by_checker = checkers["finished_by"] # Case 1: A finished by B interval_a = create_interval(v1, v4) @@ -347,7 +359,7 @@ class TestEquivalenceRelations: def test_equals_checker(self, mock_values, checkers, create_interval): v1, v2, v3, v4, v5 = mock_values - equals_checker = checkers['equals'] + equals_checker = checkers["equals"] # Case 1: A equals B interval_a = create_interval(v1, v4) @@ -366,7 +378,7 @@ def test_equals_checker(self, mock_values, checkers, create_interval): def 
test_metrics_equivalent_checker(self, mock_values, checkers, create_interval): v1, v2, v3, v4, v5 = mock_values - metrics_checker = checkers['metrics_equivalent'] + metrics_checker = checkers["metrics_equivalent"] # Case 1: Same metrics, overlapping intervals - should match interval_a = create_interval(v1, v3, {"product": "A", "region": "US"}) @@ -408,37 +420,31 @@ def test_all_inverse_relations(self, mock_values, checkers, create_interval): # Define all inverse relation pairs inverse_pairs = [ - ('before', 'after'), - ('meets', 'met_by'), - ('overlaps', 'overlapped_by'), - ('starts', 'started_by'), - ('during', 'contains'), - ('finishes', 'finished_by'), - ('equals', 'equals') # equals is its own inverse + ("before", "after"), + ("meets", "met_by"), + ("overlaps", "overlapped_by"), + ("starts", "started_by"), + ("during", "contains"), + ("finishes", "finished_by"), + ("equals", "equals"), # equals is its own inverse ] # Create a diverse set of interval pairs to test each relation interval_pairs = [ # For before/after (create_interval(v1, v2), create_interval(v3, v4)), - # For meets/met_by (create_interval(v1, v2), create_interval(v2, v3)), - # For overlaps/overlapped_by (create_interval(v1, v3), create_interval(v2, v4)), - # For starts/started_by (create_interval(v1, v2), create_interval(v1, v3)), - # For during/contains (create_interval(v2, v3), create_interval(v1, v4)), - # For finishes/finished_by (create_interval(v3, v4), create_interval(v2, v4)), - # For equals - (create_interval(v2, v3), create_interval(v2, v3)) + (create_interval(v2, v3), create_interval(v2, v3)), ] # Test each inverse pair with appropriate intervals @@ -446,12 +452,14 @@ def test_all_inverse_relations(self, mock_values, checkers, create_interval): interval_a, interval_b = interval_pairs[i] # Verify relation A → B - assert checkers[relation_a].check(interval_a, interval_b), \ - f"Expected {relation_a} to be true from A to B" + assert checkers[relation_a].check( + interval_a, interval_b + ), f"Expected {relation_a} to be true from A to B" # Verify inverse relation B → A - assert checkers[relation_b].check(interval_b, interval_a), \ - f"Expected {relation_b} to be true from B to A (inverse of {relation_a})" + assert checkers[relation_b].check( + interval_b, interval_a + ), f"Expected {relation_b} to be true from B to A (inverse of {relation_a})" def test_inverse_property_systematic(self, mock_values, checkers, create_interval): """ @@ -462,19 +470,19 @@ def test_inverse_property_systematic(self, mock_values, checkers, create_interva # Define the inverse relationship map inverse_map = { - 'before': 'after', - 'meets': 'met_by', - 'overlaps': 'overlapped_by', - 'starts': 'started_by', - 'during': 'contains', - 'finishes': 'finished_by', - 'equals': 'equals', - 'finished_by': 'finishes', - 'contains': 'during', - 'started_by': 'starts', - 'overlapped_by': 'overlaps', - 'met_by': 'meets', - 'after': 'before' + "before": "after", + "meets": "met_by", + "overlaps": "overlapped_by", + "starts": "started_by", + "during": "contains", + "finishes": "finished_by", + "equals": "equals", + "finished_by": "finishes", + "contains": "during", + "started_by": "starts", + "overlapped_by": "overlaps", + "met_by": "meets", + "after": "before", } # Create a diverse set of intervals to test @@ -489,7 +497,7 @@ def test_inverse_property_systematic(self, mock_values, checkers, create_interva create_interval(v1, v5), # [v1, v5] create_interval(v1, v1), # Point at v1 create_interval(v3, v3), # Point at v3 - create_interval(v5, v5) # Point at 
v5 + create_interval(v5, v5), # Point at v5 ] # List of relations to check (excluding metrics_equivalent) @@ -511,7 +519,8 @@ def test_inverse_property_systematic(self, mock_values, checkers, create_interva if len(true_relations_a_to_b) > 1: # Debug info for troubleshooting print( - f"Multiple relations found for intervals: {interval_a._start.position}-{interval_a._end.position} and {interval_b._start.position}-{interval_b._end.position}") + f"Multiple relations found for intervals: {interval_a._start.position}-{interval_a._end.position} and {interval_b._start.position}-{interval_b._end.position}" + ) print(f"Relations: {true_relations_a_to_b}") # For each true relation, check the inverse @@ -520,8 +529,9 @@ def test_inverse_property_systematic(self, mock_values, checkers, create_interva inverse_relation = inverse_map[relation_a_to_b] # Verify the inverse relation holds from B to A - assert checkers[inverse_relation].check(interval_b, interval_a), \ - f"If interval_a {relation_a_to_b} interval_b, then interval_b should {inverse_relation} interval_a" + assert checkers[inverse_relation].check( + interval_b, interval_a + ), f"If interval_a {relation_a_to_b} interval_b, then interval_b should {inverse_relation} interval_a" # If we found exactly one relation, verify no other relation holds if len(true_relations_a_to_b) == 1: @@ -535,18 +545,22 @@ def test_inverse_property_systematic(self, mock_values, checkers, create_interva true_relations_b_to_a.append(relation) # Should be exactly one true relation from B to A - assert len(true_relations_b_to_a) == 1, \ - f"Expected exactly one true relation from B to A, found {len(true_relations_b_to_a)}: {true_relations_b_to_a}" + assert ( + len(true_relations_b_to_a) == 1 + ), f"Expected exactly one true relation from B to A, found {len(true_relations_b_to_a)}: {true_relations_b_to_a}" # That one relation should be the inverse - assert true_relations_b_to_a[0] == inverse_relation, \ - f"Expected inverse relation {inverse_relation}, found {true_relations_b_to_a[0]}" + assert ( + true_relations_b_to_a[0] == inverse_relation + ), f"Expected inverse relation {inverse_relation}, found {true_relations_b_to_a[0]}" class TestBoundaryEdgeCases: """Tests for complex boundary combinations and edge cases with shared endpoints""" - def test_zero_length_intervals_special_cases(self, mock_values, checkers, create_interval): + def test_zero_length_intervals_special_cases( + self, mock_values, checkers, create_interval + ): """Test relations involving multiple zero-length intervals at different positions""" v1, v2, v3, v4, v5 = mock_values @@ -559,17 +573,19 @@ def test_zero_length_intervals_special_cases(self, mock_values, checkers, create point_v1_duplicate = create_interval(v1, v1) # Test equality of coincident points - assert checkers['equals'].check(point_v1, point_v1_duplicate) + assert checkers["equals"].check(point_v1, point_v1_duplicate) # Test before/after with points - assert checkers['before'].check(point_v1, point_v2) - assert checkers['after'].check(point_v3, point_v2) + assert checkers["before"].check(point_v1, point_v2) + assert checkers["after"].check(point_v3, point_v2) # A point can't 'meet' another point - assert not checkers['meets'].check(point_v1, point_v2) - assert not checkers['met_by'].check(point_v2, point_v1) + assert not checkers["meets"].check(point_v1, point_v2) + assert not checkers["met_by"].check(point_v2, point_v1) - def test_adjacent_intervals_combinations(self, mock_values, checkers, create_interval): + def 
test_adjacent_intervals_combinations( + self, mock_values, checkers, create_interval + ): """Test complex combinations of adjacent intervals""" v1, v2, v3, v4, v5 = mock_values @@ -580,14 +596,14 @@ def test_adjacent_intervals_combinations(self, mock_values, checkers, create_int interval_4 = create_interval(v4, v5) # Chain of meeting intervals - assert checkers['meets'].check(interval_1, interval_2) - assert checkers['meets'].check(interval_2, interval_3) - assert checkers['meets'].check(interval_3, interval_4) + assert checkers["meets"].check(interval_1, interval_2) + assert checkers["meets"].check(interval_2, interval_3) + assert checkers["meets"].check(interval_3, interval_4) # Relation between non-adjacent intervals in the chain - assert checkers['before'].check(interval_1, interval_3) - assert checkers['before'].check(interval_2, interval_4) - assert checkers['before'].check(interval_1, interval_4) + assert checkers["before"].check(interval_1, interval_3) + assert checkers["before"].check(interval_2, interval_4) + assert checkers["before"].check(interval_1, interval_4) # Create a larger interval spanning multiple adjacent intervals span_1_3 = create_interval(v1, v3) @@ -595,11 +611,13 @@ def test_adjacent_intervals_combinations(self, mock_values, checkers, create_int span_1_4 = create_interval(v1, v4) # Test relationships with spanning intervals - assert checkers['finished_by'].check(span_1_3, interval_2) - assert checkers['started_by'].check(span_2_4, interval_2) - assert checkers['contains'].check(span_1_4, interval_2) + assert checkers["finished_by"].check(span_1_3, interval_2) + assert checkers["started_by"].check(span_2_4, interval_2) + assert checkers["contains"].check(span_1_4, interval_2) - def test_nested_intervals_shared_boundaries(self, mock_values, checkers, create_interval): + def test_nested_intervals_shared_boundaries( + self, mock_values, checkers, create_interval + ): """Test nested intervals with shared boundaries""" v1, v2, v3, v4, v5 = mock_values @@ -615,28 +633,28 @@ def test_nested_intervals_shared_boundaries(self, mock_values, checkers, create_ inner = create_interval(v3, v3) # Test relations with outer interval - assert checkers['started_by'].check(outer, middle_same_start) - assert checkers['finished_by'].check(outer, middle_same_end) - assert checkers['contains'].check(outer, middle_inside) - assert checkers['contains'].check(outer, inner) + assert checkers["started_by"].check(outer, middle_same_start) + assert checkers["finished_by"].check(outer, middle_same_end) + assert checkers["contains"].check(outer, middle_inside) + assert checkers["contains"].check(outer, inner) # Since they share the same end point, this is a "finished_by" relationship, not "contains" - assert checkers['finished_by'].check(middle_same_start, middle_inside) + assert checkers["finished_by"].check(middle_same_start, middle_inside) # And correspondingly, middle_inside "finishes" middle_same_start - assert checkers['finishes'].check(middle_inside, middle_same_start) + assert checkers["finishes"].check(middle_inside, middle_same_start) # Test other middle interval relations - assert checkers['overlaps'].check(middle_same_start, middle_same_end) + assert checkers["overlaps"].check(middle_same_start, middle_same_end) # For middle_inside and middle_same_end, let's check the proper relation # They share a start point, so middle_inside "starts" middle_same_end - assert checkers['starts'].check(middle_inside, middle_same_end) + assert checkers["starts"].check(middle_inside, middle_same_end) # 
Multiple relationships with inner - assert checkers['contains'].check(middle_same_start, inner) - assert checkers['contains'].check(middle_same_end, inner) - assert checkers['contains'].check(middle_inside, inner) + assert checkers["contains"].check(middle_same_start, inner) + assert checkers["contains"].check(middle_same_end, inner) + assert checkers["contains"].check(middle_inside, inner) def test_complex_chain_of_relations(self, mock_values, checkers, create_interval): """Test a complex chain of intervals with various relations""" @@ -646,37 +664,43 @@ def test_complex_chain_of_relations(self, mock_values, checkers, create_interval interval_1 = create_interval(v1, v2) # [v1,v2] interval_2 = create_interval(v2, v3) # [v2,v3] - meets interval_1 interval_3 = create_interval(v2, v4) # [v2,v4] - started_by interval_2 - interval_4 = create_interval(v3, v4) # [v3,v4] - during interval_3, after interval_2 - interval_5 = create_interval(v4, v5) # [v4,v5] - meets interval_3, meets interval_4 + interval_4 = create_interval( + v3, v4 + ) # [v3,v4] - during interval_3, after interval_2 + interval_5 = create_interval( + v4, v5 + ) # [v4,v5] - meets interval_3, meets interval_4 # Test the expected relations in the chain - assert checkers['meets'].check(interval_1, interval_2) - assert checkers['started_by'].check(interval_3, interval_2) - assert checkers['finishes'].check(interval_4, interval_3) - assert checkers['meets'].check(interval_3, interval_5) - assert checkers['meets'].check(interval_4, interval_5) + assert checkers["meets"].check(interval_1, interval_2) + assert checkers["started_by"].check(interval_3, interval_2) + assert checkers["finishes"].check(interval_4, interval_3) + assert checkers["meets"].check(interval_3, interval_5) + assert checkers["meets"].check(interval_4, interval_5) # Test more complex relations in the chain # FIXED: interval_1 MEETS interval_3 (not before) - assert checkers['meets'].check(interval_1, interval_3) + assert checkers["meets"].check(interval_1, interval_3) # interval_1 is before interval_4 (no shared endpoints) - assert checkers['before'].check(interval_1, interval_4) + assert checkers["before"].check(interval_1, interval_4) # interval_1 is before interval_5 (no shared endpoints) - assert checkers['before'].check(interval_1, interval_5) + assert checkers["before"].check(interval_1, interval_5) # interval_2 is before interval_5 (no shared endpoints) - assert checkers['before'].check(interval_2, interval_5) + assert checkers["before"].check(interval_2, interval_5) # interval_4 overlapped_by interval_3 (interval_3 starts earlier, ends at same time) - assert checkers['finishes'].check(interval_4, interval_3) + assert checkers["finishes"].check(interval_4, interval_3) class TestSpecialBoundaryScenarios: """Tests for special scenarios with intervals sharing multiple boundaries""" - def test_intervals_sharing_both_boundaries(self, mock_values, checkers, create_interval): + def test_intervals_sharing_both_boundaries( + self, mock_values, checkers, create_interval + ): """Test intervals that share both start and end points (equals)""" v1, v3, v5 = mock_values[0], mock_values[2], mock_values[4] @@ -689,16 +713,18 @@ def test_intervals_sharing_both_boundaries(self, mock_values, checkers, create_i interval_4 = create_interval(v3, v5) # Test equals relation - assert checkers['equals'].check(interval_1, interval_2) - assert checkers['equals'].check(interval_3, interval_4) + assert checkers["equals"].check(interval_1, interval_2) + assert checkers["equals"].check(interval_3, 
interval_4) # Test that no other relation holds for relation in checkers: - if relation != 'equals' and relation != 'metrics_equivalent': + if relation != "equals" and relation != "metrics_equivalent": assert not checkers[relation].check(interval_1, interval_2) assert not checkers[relation].check(interval_3, interval_4) - def test_intervals_with_multi_boundary_sharing(self, mock_values, checkers, create_interval): + def test_intervals_with_multi_boundary_sharing( + self, mock_values, checkers, create_interval + ): """Test complex scenarios where multiple intervals share various boundaries""" v1, v2, v3, v4, v5 = mock_values @@ -707,14 +733,14 @@ def test_intervals_with_multi_boundary_sharing(self, mock_values, checkers, crea create_interval(v1, v2), create_interval(v1, v3), create_interval(v1, v4), - create_interval(v1, v5) + create_interval(v1, v5), ] common_end = [ create_interval(v1, v5), create_interval(v2, v5), create_interval(v3, v5), - create_interval(v4, v5) + create_interval(v4, v5), ] # Test relations between intervals with common start @@ -722,22 +748,26 @@ def test_intervals_with_multi_boundary_sharing(self, mock_values, checkers, crea for j in range(i + 1, len(common_start)): if i == j: continue - assert checkers['starts'].check(common_start[i], common_start[j]) or \ - checkers['started_by'].check(common_start[i], common_start[j]) + assert checkers["starts"].check( + common_start[i], common_start[j] + ) or checkers["started_by"].check(common_start[i], common_start[j]) # Test relations between intervals with common end for i in range(len(common_end)): for j in range(i + 1, len(common_end)): if i == j: continue - assert checkers['finishes'].check(common_end[i], common_end[j]) or \ - checkers['finished_by'].check(common_end[i], common_end[j]) + assert checkers["finishes"].check( + common_end[i], common_end[j] + ) or checkers["finished_by"].check(common_end[i], common_end[j]) class TestAdvancedScenarios: """Tests for edge cases and combined usage scenarios""" - def test_multiple_checkers_combination(self, mock_values, checkers, create_interval): + def test_multiple_checkers_combination( + self, mock_values, checkers, create_interval + ): """Test that multiple relationship checkers can be used together to validate different relations""" v1, v2, v3, v4, v5 = mock_values @@ -748,15 +778,17 @@ def test_multiple_checkers_combination(self, mock_values, checkers, create_inter interval_d = create_interval(v2, v4) # D overlaps with A and B # Test different combinations - assert checkers['meets'].check(interval_a, interval_b) is True - assert checkers['started_by'].check(interval_c, interval_a) is True + assert checkers["meets"].check(interval_a, interval_b) is True + assert checkers["started_by"].check(interval_c, interval_a) is True # Since interval_c and interval_b share an end point, the relation is 'finished_by' not 'contains' - assert checkers['finished_by'].check(interval_c, interval_b) is True + assert checkers["finished_by"].check(interval_c, interval_b) is True - assert checkers['overlaps'].check(interval_a, interval_d) is True + assert checkers["overlaps"].check(interval_a, interval_d) is True - def test_edge_case_zero_length_intervals_before_after(self, mock_values, checkers, create_interval): + def test_edge_case_zero_length_intervals_before_after( + self, mock_values, checkers, create_interval + ): """Test before/after relationships with zero-length intervals""" v1, v2, v3 = mock_values[:3] @@ -766,11 +798,13 @@ def test_edge_case_zero_length_intervals_before_after(self, 
mock_values, checker normal_interval = create_interval(v2, v3) # Test before/after with zero-length intervals - assert checkers['before'].check(zero_interval_1, zero_interval_2) is True - assert checkers['after'].check(zero_interval_2, zero_interval_1) is True - assert checkers['before'].check(zero_interval_1, normal_interval) is True + assert checkers["before"].check(zero_interval_1, zero_interval_2) is True + assert checkers["after"].check(zero_interval_2, zero_interval_1) is True + assert checkers["before"].check(zero_interval_1, normal_interval) is True - def test_all_relations_are_mutually_exclusive(self, mock_values, checkers, create_interval): + def test_all_relations_are_mutually_exclusive( + self, mock_values, checkers, create_interval + ): """Test that Allen's interval relations are mutually exclusive""" v1, v2, v3, v4, v5 = mock_values @@ -784,11 +818,19 @@ def test_all_relations_are_mutually_exclusive(self, mock_values, checkers, creat # Get all the relation checkers relation_checkers = [ - checkers['before'], checkers['meets'], checkers['overlaps'], - checkers['starts'], checkers['during'], checkers['finishes'], - checkers['equals'], checkers['contains'], checkers['started_by'], - checkers['finished_by'], checkers['overlapped_by'], checkers['met_by'], - checkers['after'] + checkers["before"], + checkers["meets"], + checkers["overlaps"], + checkers["starts"], + checkers["during"], + checkers["finishes"], + checkers["equals"], + checkers["contains"], + checkers["started_by"], + checkers["finished_by"], + checkers["overlapped_by"], + checkers["met_by"], + checkers["after"], ] # Each interval should have a relationship with the reference interval @@ -797,14 +839,15 @@ def test_all_relations_are_mutually_exclusive(self, mock_values, checkers, creat "meets": interval_meets, "during": interval_during, "finishes": interval_finishes, - "equals": interval_equals + "equals": interval_equals, } for name, interval in test_intervals.items(): # For equals, we expect exactly one relation (equals itself) if name == "equals": - assert checkers['equals'].check(interval, interval_reference), \ - f"Equal intervals should have equals relation" + assert checkers["equals"].check( + interval, interval_reference + ), f"Equal intervals should have equals relation" # For others, we should find at least one valid relation else: found_relation = False @@ -827,15 +870,15 @@ def test_both_zero_length_intervals(self, mock_values, checkers, create_interval point_d = create_interval(v3, v3) # Point-to-point relations - assert checkers['before'].check(point_a, point_b) is True - assert checkers['after'].check(point_b, point_a) is True - assert checkers['equals'].check(point_c, point_d) is True + assert checkers["before"].check(point_a, point_b) is True + assert checkers["after"].check(point_b, point_a) is True + assert checkers["equals"].check(point_c, point_d) is True # A point can't meet, overlap, contain, or be during another point - assert checkers['meets'].check(point_a, point_b) is False - assert checkers['overlaps'].check(point_a, point_b) is False - assert checkers['contains'].check(point_a, point_b) is False - assert checkers['during'].check(point_a, point_b) is False + assert checkers["meets"].check(point_a, point_b) is False + assert checkers["overlaps"].check(point_a, point_b) is False + assert checkers["contains"].check(point_a, point_b) is False + assert checkers["during"].check(point_a, point_b) is False def test_multiple_consecutive_points(self, mock_values, checkers, create_interval): """Test 
multiple consecutive zero-length intervals""" @@ -847,14 +890,14 @@ def test_multiple_consecutive_points(self, mock_values, checkers, create_interva point_c = create_interval(v3, v3) # Each point should be before the next - assert checkers['before'].check(point_a, point_b) is True - assert checkers['before'].check(point_b, point_c) is True - assert checkers['before'].check(point_a, point_c) is True + assert checkers["before"].check(point_a, point_b) is True + assert checkers["before"].check(point_b, point_c) is True + assert checkers["before"].check(point_a, point_c) is True # Transitivity test for after relation - assert checkers['after'].check(point_c, point_b) is True - assert checkers['after'].check(point_b, point_a) is True - assert checkers['after'].check(point_c, point_a) is True + assert checkers["after"].check(point_c, point_b) is True + assert checkers["after"].check(point_b, point_a) is True + assert checkers["after"].check(point_c, point_a) is True def test_inverse_relations(self, mock_values, checkers, create_interval): """Test that inverse relations work correctly""" @@ -869,24 +912,28 @@ def test_inverse_relations(self, mock_values, checkers, create_interval): # Test inverse pairs # If A before B, then B after A - assert checkers['before'].check(interval_a, interval_b) is True - assert checkers['after'].check(interval_b, interval_a) is True + assert checkers["before"].check(interval_a, interval_b) is True + assert checkers["after"].check(interval_b, interval_a) is True # If A meets B, then B met by A - assert checkers['meets'].check(interval_a, interval_c) is True - assert checkers['met_by'].check(interval_c, interval_a) is True + assert checkers["meets"].check(interval_a, interval_c) is True + assert checkers["met_by"].check(interval_c, interval_a) is True # If A overlaps B, then B overlapped by A - assert checkers['overlaps'].check(interval_d, interval_e) is True - assert checkers['overlapped_by'].check(interval_e, interval_d) is True + assert checkers["overlaps"].check(interval_d, interval_e) is True + assert checkers["overlapped_by"].check(interval_e, interval_d) is True # If A starts B, then B started by A - assert checkers['starts'].check(interval_a, interval_d) is True - assert checkers['started_by'].check(interval_d, interval_a) is True + assert checkers["starts"].check(interval_a, interval_d) is True + assert checkers["started_by"].check(interval_d, interval_a) is True # If A finishes B, then B finished by A - assert checkers['finishes'].check(interval_a, interval_e) is False # Not true for this example - assert checkers['finished_by'].check(interval_e, interval_a) is False # Should match above + assert ( + checkers["finishes"].check(interval_a, interval_e) is False + ) # Not true for this example + assert ( + checkers["finished_by"].check(interval_e, interval_a) is False + ) # Should match above def test_transitivity_properties(self, mock_values, checkers, create_interval): """Test transitivity properties of certain relations""" @@ -898,18 +945,18 @@ def test_transitivity_properties(self, mock_values, checkers, create_interval): b_before_c = create_interval(v3, v4) # [v3, v4] c_interval = create_interval(v5, v5) # [v5, v5] - assert checkers['before'].check(a_before_b, b_before_c) is True - assert checkers['before'].check(b_before_c, c_interval) is True - assert checkers['before'].check(a_before_b, c_interval) is True + assert checkers["before"].check(a_before_b, b_before_c) is True + assert checkers["before"].check(b_before_c, c_interval) is True + assert 
checkers["before"].check(a_before_b, c_interval) is True # During relation has a different transitivity property outer = create_interval(v1, v5) # [v1, v5] middle = create_interval(v2, v4) # [v2, v4] inner = create_interval(v3, v3) # [v3, v3] - assert checkers['contains'].check(outer, middle) is True - assert checkers['contains'].check(middle, inner) is True - assert checkers['contains'].check(outer, inner) is True + assert checkers["contains"].check(outer, middle) is True + assert checkers["contains"].check(middle, inner) is True + assert checkers["contains"].check(outer, inner) is True def test_complex_boundary_cases(self, mock_values, checkers, create_interval): """Test complex combinations of shared boundary points""" @@ -918,18 +965,20 @@ def test_complex_boundary_cases(self, mock_values, checkers, create_interval): # Intervals that share both start and end but aren't equal interval_a = create_interval(v1, v3) interval_b = create_interval(v1, v3) - assert checkers['equals'].check(interval_a, interval_b) is True + assert checkers["equals"].check(interval_a, interval_b) is True # One interval starts where another ends, and a third overlaps both interval_c = create_interval(v1, v2) interval_d = create_interval(v2, v3) interval_e = create_interval(v1, v3) - assert checkers['meets'].check(interval_c, interval_d) is True - assert checkers['started_by'].check(interval_e, interval_c) is True - assert checkers['finished_by'].check(interval_e, interval_d) is True + assert checkers["meets"].check(interval_c, interval_d) is True + assert checkers["started_by"].check(interval_e, interval_c) is True + assert checkers["finished_by"].check(interval_e, interval_d) is True - def test_mutual_exclusivity_regular_intervals(self, mock_values, checkers, create_interval): + def test_mutual_exclusivity_regular_intervals( + self, mock_values, checkers, create_interval + ): """Comprehensive test that Allen's interval relations are mutually exclusive""" v1, v2, v3, v4, v5 = mock_values @@ -943,19 +992,35 @@ def test_mutual_exclusivity_regular_intervals(self, mock_values, checkers, creat "during": create_interval(v3, v3), # Strictly inside reference "finishes": create_interval(v3, v4), # Ends with reference "equals": create_interval(v2, v4), # Same as reference - "finished_by": create_interval(v1, v4), # Contains and shares end with reference + "finished_by": create_interval( + v1, v4 + ), # Contains and shares end with reference "contains": create_interval(v1, v5), # Strictly contains reference - "started_by": create_interval(v2, v5), # Shares start and extends beyond reference - "overlapped_by": create_interval(v3, v5), # Reference overlaps start of this interval + "started_by": create_interval( + v2, v5 + ), # Shares start and extends beyond reference + "overlapped_by": create_interval( + v3, v5 + ), # Reference overlaps start of this interval "met_by": create_interval(v4, v5), # Starts exactly where reference ends - "after": create_interval(v5, v5) # Strictly after reference + "after": create_interval(v5, v5), # Strictly after reference } # List of all checkers all_checkers = [ - 'before', 'meets', 'overlaps', 'starts', 'during', 'finishes', - 'equals', 'finished_by', 'contains', 'started_by', 'overlapped_by', - 'met_by', 'after' + "before", + "meets", + "overlaps", + "starts", + "during", + "finishes", + "equals", + "finished_by", + "contains", + "started_by", + "overlapped_by", + "met_by", + "after", ] # For each interval, exactly one relation should be true with the reference @@ -972,10 +1037,12 @@ def 
test_mutual_exclusivity_regular_intervals(self, mock_values, checkers, creat # Only one relation should be true (if name is unique) if name in all_checkers: # Skip if the name isn't a valid relation expected_relation = name - assert len(true_relations) == 1, \ - f"Expected 1 relation, got {len(true_relations)}: {true_relations}" - assert true_relations[0] == expected_relation, \ - f"Expected {expected_relation}, got {true_relations[0]}" + assert ( + len(true_relations) == 1 + ), f"Expected 1 relation, got {len(true_relations)}: {true_relations}" + assert ( + true_relations[0] == expected_relation + ), f"Expected {expected_relation}, got {true_relations[0]}" class TestMetricsEdgeCases: @@ -984,7 +1051,7 @@ class TestMetricsEdgeCases: def test_different_metric_fields(self, mock_values, checkers, create_interval): """Test intervals with different sets of metric fields""" v1, v2, v3, v4 = mock_values[:4] - metrics_checker = checkers['metrics_equivalent'] + metrics_checker = checkers["metrics_equivalent"] # Different fields but overlapping intervals interval_a = create_interval(v1, v3, {"product": "A"}) @@ -999,7 +1066,7 @@ def test_different_metric_fields(self, mock_values, checkers, create_interval): def test_case_sensitivity_in_metrics(self, mock_values, checkers, create_interval): """Test case sensitivity in metric values""" v1, v2, v3, v4 = mock_values[:4] - metrics_checker = checkers['metrics_equivalent'] + metrics_checker = checkers["metrics_equivalent"] # Case difference in values interval_a = create_interval(v1, v3, {"product": "product_a"}) @@ -1014,7 +1081,7 @@ def test_case_sensitivity_in_metrics(self, mock_values, checkers, create_interva def test_empty_metrics(self, mock_values, checkers, create_interval): """Test intervals with empty metric dictionaries""" v1, v2, v3, v4 = mock_values[:4] - metrics_checker = checkers['metrics_equivalent'] + metrics_checker = checkers["metrics_equivalent"] # Both intervals have empty metrics interval_a = create_interval(v1, v3, {}) @@ -1025,10 +1092,12 @@ def test_empty_metrics(self, mock_values, checkers, create_interval): interval_c = create_interval(v1, v3, {"product": "A"}) assert metrics_checker.check(interval_c, interval_b) is False - def test_case_sensitivity_in_metric_names(self, mock_values, checkers, create_interval): + def test_case_sensitivity_in_metric_names( + self, mock_values, checkers, create_interval + ): """Test case sensitivity in metric field names""" v1, v2, v3, v4 = mock_values[:4] - metrics_checker = checkers['metrics_equivalent'] + metrics_checker = checkers["metrics_equivalent"] # Case difference in field names interval_a = create_interval(v1, v3, {"Product": "A"}) @@ -1040,10 +1109,12 @@ def test_case_sensitivity_in_metric_names(self, mock_values, checkers, create_in interval_d = create_interval(v2, v4, {"Product": "A"}) assert metrics_checker.check(interval_c, interval_d) is True - def test_special_characters_in_metrics(self, mock_values, checkers, create_interval): + def test_special_characters_in_metrics( + self, mock_values, checkers, create_interval + ): """Test metrics with special characters and whitespace""" v1, v2, v3, v4 = mock_values[:4] - metrics_checker = checkers['metrics_equivalent'] + metrics_checker = checkers["metrics_equivalent"] # Special characters in field names interval_a = create_interval(v1, v3, {"product-id": "A123"}) @@ -1073,7 +1144,7 @@ def test_special_characters_in_metrics(self, mock_values, checkers, create_inter def test_metric_name_variations(self, mock_values, checkers, 
create_interval): """Test various metric name formats and variations""" v1, v2, v3, v4 = mock_values[:4] - metrics_checker = checkers['metrics_equivalent'] + metrics_checker = checkers["metrics_equivalent"] # Unicode characters in field names interval_a = create_interval(v1, v3, {"prøduct": "A"}) @@ -1097,7 +1168,7 @@ def test_metric_name_variations(self, mock_values, checkers, create_interval): def test_edge_case_metric_values(self, mock_values, checkers, create_interval): """Test edge cases in metric values""" v1, v2, v3, v4 = mock_values[:4] - metrics_checker = checkers['metrics_equivalent'] + metrics_checker = checkers["metrics_equivalent"] # Empty strings as values interval_a = create_interval(v1, v3, {"product": ""}) @@ -1145,23 +1216,29 @@ def test_calendar_event_overlaps(self, mock_values, checkers, create_interval): all_day_meeting = create_interval(v1, v5) # Test scheduling conflicts - assert checkers['meets'].check(meeting_9am_10am, meeting_10am_11am) is True - assert checkers['meets'].check(meeting_10am_11am, lunch_12pm_1pm) is True # Fix for original assertion + assert checkers["meets"].check(meeting_9am_10am, meeting_10am_11am) is True + assert ( + checkers["meets"].check(meeting_10am_11am, lunch_12pm_1pm) is True + ) # Fix for original assertion # Test 'before' relationship (no shared endpoints) - assert checkers['before'].check(meeting_9am_10am, lunch_12pm_1pm) is True # 9am-10am is before 12pm-1pm - assert checkers['before'].check(meeting_10am_11am, afternoon_meeting) is True # 10am-11am is before 1pm-2pm + assert ( + checkers["before"].check(meeting_9am_10am, lunch_12pm_1pm) is True + ) # 9am-10am is before 12pm-1pm + assert ( + checkers["before"].check(meeting_10am_11am, afternoon_meeting) is True + ) # 10am-11am is before 1pm-2pm # Test relationships between all_day_meeting and other meetings # In Allen's algebra, since all_day_meeting starts at the same time as meeting_9am_10am, # this is a 'started by' relationship, not 'contains' - assert checkers['started_by'].check(all_day_meeting, meeting_9am_10am) is True + assert checkers["started_by"].check(all_day_meeting, meeting_9am_10am) is True # all_day_meeting properly contains lunch_12pm_1pm (no shared endpoints) - assert checkers['contains'].check(all_day_meeting, lunch_12pm_1pm) is True + assert checkers["contains"].check(all_day_meeting, lunch_12pm_1pm) is True # No conflict between these meetings - assert checkers['overlaps'].check(meeting_9am_10am, lunch_12pm_1pm) is False + assert checkers["overlaps"].check(meeting_9am_10am, lunch_12pm_1pm) is False def test_process_monitoring_intervals(self, mock_values, checkers, create_interval): """Test interval relations in a process monitoring context""" @@ -1173,22 +1250,29 @@ def test_process_monitoring_intervals(self, mock_values, checkers, create_interv outage = create_interval(v3, v4, {"system": "Server A"}) # Test monitoring scenarios - assert checkers['during'].check(maintenance_window, system_uptime) is True - assert checkers['during'].check(outage, system_uptime) is True - assert checkers['meets'].check(maintenance_window, outage) is True + assert checkers["during"].check(maintenance_window, system_uptime) is True + assert checkers["during"].check(outage, system_uptime) is True + assert checkers["meets"].check(maintenance_window, outage) is True # Test metric equivalence - assert checkers['metrics_equivalent'].check(system_uptime, maintenance_window) is True + assert ( + checkers["metrics_equivalent"].check(system_uptime, maintenance_window) + is True + ) # 
Different system shouldn't match other_system = create_interval(v2, v4, {"system": "Server B"}) - assert checkers['metrics_equivalent'].check(system_uptime, other_system) is False + assert ( + checkers["metrics_equivalent"].check(system_uptime, other_system) is False + ) class TestMutualExclusivity: """Comprehensive tests to ensure Allen's interval relations are mutually exclusive""" - def test_comprehensive_mutual_exclusivity(self, mock_values, checkers, create_interval): + def test_comprehensive_mutual_exclusivity( + self, mock_values, checkers, create_interval + ): """ Test that for any pair of intervals, exactly one relation checker returns true. This is a fundamental property of Allen's interval algebra. @@ -1215,9 +1299,19 @@ def test_comprehensive_mutual_exclusivity(self, mock_values, checkers, create_in # List of all Allen's relation checkers relation_checkers = [ - 'before', 'meets', 'overlaps', 'starts', 'during', 'finishes', - 'equals', 'finished_by', 'contains', 'started_by', 'overlapped_by', - 'met_by', 'after' + "before", + "meets", + "overlaps", + "starts", + "during", + "finishes", + "equals", + "finished_by", + "contains", + "started_by", + "overlapped_by", + "met_by", + "after", ] # For each pair of intervals in the list @@ -1232,8 +1326,9 @@ def test_comprehensive_mutual_exclusivity(self, mock_values, checkers, create_in # There should be exactly one true relation between any pair of intervals # (except for the metrics_equivalent checker which isn't part of Allen's algebra) - assert len(true_relations) == 1, \ - f"Intervals {name_a} and {name_b} have {len(true_relations)} true relations: {true_relations}. Expected exactly 1." + assert ( + len(true_relations) == 1 + ), f"Intervals {name_a} and {name_b} have {len(true_relations)} true relations: {true_relations}. Expected exactly 1." 
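The mutual-exclusivity suites above all assert the same fundamental property: between any pair of intervals, exactly one of Allen's thirteen relations holds. A minimal, self-contained sketch of that property follows, modelling intervals as (start, end) tuples with start < end and using the textbook predicate definitions; these lambdas are illustrative only, not tempo's checker classes, and tempo's implementation additionally has to handle the zero-length intervals exercised earlier in this file.

```python
from itertools import combinations

# Textbook Allen relations over proper intervals, modelled as (start, end) tuples.
# Illustrative only: these are not tempo's checker classes.
ALLEN = {
    "before":        lambda a, b: a[1] < b[0],
    "meets":         lambda a, b: a[1] == b[0],
    "overlaps":      lambda a, b: a[0] < b[0] < a[1] < b[1],
    "starts":        lambda a, b: a[0] == b[0] and a[1] < b[1],
    "during":        lambda a, b: b[0] < a[0] and a[1] < b[1],
    "finishes":      lambda a, b: b[0] < a[0] and a[1] == b[1],
    "equals":        lambda a, b: a == b,
    "finished_by":   lambda a, b: a[0] < b[0] and a[1] == b[1],
    "contains":      lambda a, b: a[0] < b[0] and b[1] < a[1],
    "started_by":    lambda a, b: a[0] == b[0] and b[1] < a[1],
    "overlapped_by": lambda a, b: b[0] < a[0] < b[1] < a[1],
    "met_by":        lambda a, b: b[1] == a[0],
    "after":         lambda a, b: b[1] < a[0],
}

# All proper intervals over a small grid of endpoints.
intervals = [(s, e) for s in range(5) for e in range(5) if s < e]

# An interval relates to itself by "equals" and nothing else...
for a in intervals:
    assert [name for name, rel in ALLEN.items() if rel(a, a)] == ["equals"]

# ...and any two distinct intervals stand in exactly one relation, in each direction.
for a, b in combinations(intervals, 2):
    for x, y in ((a, b), (b, a)):
        true_relations = [name for name, rel in ALLEN.items() if rel(x, y)]
        assert len(true_relations) == 1, (x, y, true_relations)
```

The script runs silently: the thirteen predicates partition the possible orderings of the four endpoints, which is exactly what the exhaustive tests here check against tempo's real checkers.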
def test_relation_pairs_symmetry(self, mock_values, checkers, create_interval): """ @@ -1244,19 +1339,19 @@ def test_relation_pairs_symmetry(self, mock_values, checkers, create_interval): # Define the inverse relationships inverse_relations = { - 'before': 'after', - 'meets': 'met_by', - 'overlaps': 'overlapped_by', - 'starts': 'started_by', - 'during': 'contains', - 'finishes': 'finished_by', - 'equals': 'equals', # Self-inverse - 'finished_by': 'finishes', - 'contains': 'during', - 'started_by': 'starts', - 'overlapped_by': 'overlaps', - 'met_by': 'meets', - 'after': 'before' + "before": "after", + "meets": "met_by", + "overlaps": "overlapped_by", + "starts": "started_by", + "during": "contains", + "finishes": "finished_by", + "equals": "equals", # Self-inverse + "finished_by": "finishes", + "contains": "during", + "started_by": "starts", + "overlapped_by": "overlaps", + "met_by": "meets", + "after": "before", } # Create a representative set of intervals @@ -1288,8 +1383,9 @@ def test_relation_pairs_symmetry(self, mock_values, checkers, create_interval): expected_relation_b_to_a = inverse_relations[relation_a_to_b] # Check if the inverse relation is true from B to A - assert checkers[expected_relation_b_to_a].check(interval_b, interval_a), \ - f"If {name_a} {relation_a_to_b} {name_b}, then {name_b} should {expected_relation_b_to_a} {name_a}" + assert checkers[expected_relation_b_to_a].check( + interval_b, interval_a + ), f"If {name_a} {relation_a_to_b} {name_b}, then {name_b} should {expected_relation_b_to_a} {name_a}" def test_identity_relations(self, mock_values, checkers, create_interval): """ @@ -1309,19 +1405,31 @@ def test_identity_relations(self, mock_values, checkers, create_interval): ] non_equals_relations = [ - 'before', 'meets', 'overlaps', 'starts', 'during', 'finishes', - 'contains', 'started_by', 'finished_by', 'overlapped_by', 'met_by', 'after' + "before", + "meets", + "overlaps", + "starts", + "during", + "finishes", + "contains", + "started_by", + "finished_by", + "overlapped_by", + "met_by", + "after", ] for interval in test_intervals: # An interval should equal itself - assert checkers['equals'].check(interval, interval), \ - "Every interval should equal itself" + assert checkers["equals"].check( + interval, interval + ), "Every interval should equal itself" # No other relation should apply to identical intervals for relation in non_equals_relations: - assert not checkers[relation].check(interval, interval), \ - f"Relation '{relation}' should not apply to identical intervals" + assert not checkers[relation].check( + interval, interval + ), f"Relation '{relation}' should not apply to identical intervals" def test_transitivity_properties(self, mock_values, checkers, create_interval): """Test the transitivity properties of selected Allen's relations""" @@ -1332,42 +1440,48 @@ def test_transitivity_properties(self, mock_values, checkers, create_interval): b_before = create_interval(v3, v4) c_before = create_interval(v5, v5) - assert checkers['before'].check(a_before, b_before) - assert checkers['before'].check(b_before, c_before) - assert checkers['before'].check(a_before, c_before), \ - "Before relation should be transitive" + assert checkers["before"].check(a_before, b_before) + assert checkers["before"].check(b_before, c_before) + assert checkers["before"].check( + a_before, c_before + ), "Before relation should be transitive" # Test after transitivity: if A after B and B after C then A after C a_after = create_interval(v5, v5) b_after = create_interval(v3, v4) c_after 
= create_interval(v1, v2) - assert checkers['after'].check(a_after, b_after) - assert checkers['after'].check(b_after, c_after) - assert checkers['after'].check(a_after, c_after), \ - "After relation should be transitive" + assert checkers["after"].check(a_after, b_after) + assert checkers["after"].check(b_after, c_after) + assert checkers["after"].check( + a_after, c_after + ), "After relation should be transitive" # Test during transitivity: if A during B and B during C then A during C c_during = create_interval(v1, v5) b_during = create_interval(v2, v4) a_during = create_interval(v3, v3) - assert checkers['during'].check(a_during, b_during) - assert checkers['during'].check(b_during, c_during) - assert checkers['during'].check(a_during, c_during), \ - "During relation should be transitive" + assert checkers["during"].check(a_during, b_during) + assert checkers["during"].check(b_during, c_during) + assert checkers["during"].check( + a_during, c_during + ), "During relation should be transitive" # Test contains transitivity: if A contains B and B contains C then A contains C a_contains = create_interval(v1, v5) b_contains = create_interval(v2, v4) c_contains = create_interval(v3, v3) - assert checkers['contains'].check(a_contains, b_contains) - assert checkers['contains'].check(b_contains, c_contains) - assert checkers['contains'].check(a_contains, c_contains), \ - "Contains relation should be transitive" + assert checkers["contains"].check(a_contains, b_contains) + assert checkers["contains"].check(b_contains, c_contains) + assert checkers["contains"].check( + a_contains, c_contains + ), "Contains relation should be transitive" - def test_exhaustive_pairwise_relations(self, mock_values, checkers, create_interval): + def test_exhaustive_pairwise_relations( + self, mock_values, checkers, create_interval + ): """ Test every possible pair of intervals with every relation checker to ensure only one returns true. This includes regular intervals but excludes point intervals. @@ -1379,12 +1493,24 @@ def test_exhaustive_pairwise_relations(self, mock_values, checkers, create_inter for end_idx in range(start_idx + 1, 6): # Using start+1 to 5 as end indices start_val = mock_values[start_idx - 1] end_val = mock_values[end_idx - 1] - intervals.append((f"v{start_idx}_v{end_idx}", create_interval(start_val, end_val))) + intervals.append( + (f"v{start_idx}_v{end_idx}", create_interval(start_val, end_val)) + ) relation_checkers = [ - 'before', 'meets', 'overlaps', 'starts', 'during', 'finishes', - 'equals', 'finished_by', 'contains', 'started_by', 'overlapped_by', - 'met_by', 'after' + "before", + "meets", + "overlaps", + "starts", + "during", + "finishes", + "equals", + "finished_by", + "contains", + "started_by", + "overlapped_by", + "met_by", + "after", ] # Test all pairs @@ -1396,8 +1522,9 @@ def test_exhaustive_pairwise_relations(self, mock_values, checkers, create_inter if checkers[relation].check(interval_a, interval_b): true_relations.append(relation) - assert len(true_relations) == 1, \ - f"Intervals {name_a} and {name_b} have {len(true_relations)} true relations: {true_relations}. Expected exactly 1." + assert ( + len(true_relations) == 1 + ), f"Intervals {name_a} and {name_b} have {len(true_relations)} true relations: {true_relations}. Expected exactly 1." 
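test_relation_pairs_symmetry above and test_shared_endpoints_exhaustive below both lean on the duality encoded in the inverse_relations table: each inverse relation is its forward twin with the arguments swapped. The sketch below makes that duality explicit, again with illustrative (start, end) tuple predicates rather than tempo's checkers, and also reproduces the meets-chains-into-before behaviour that test_meets_non_transitivity asserts further down.

```python
from itertools import product

# Forward Allen relations over proper (start, end) tuples; illustrative only.
FORWARD = {
    "before":   lambda a, b: a[1] < b[0],
    "meets":    lambda a, b: a[1] == b[0],
    "overlaps": lambda a, b: a[0] < b[0] < a[1] < b[1],
    "starts":   lambda a, b: a[0] == b[0] and a[1] < b[1],
    "during":   lambda a, b: b[0] < a[0] and a[1] < b[1],
    "finishes": lambda a, b: b[0] < a[0] and a[1] == b[1],
}
INVERSE_NAME = {
    "before": "after", "meets": "met_by", "overlaps": "overlapped_by",
    "starts": "started_by", "during": "contains", "finishes": "finished_by",
}

def swapped(pred):
    """The inverse predicate is the forward predicate with its arguments swapped."""
    return lambda a, b: pred(b, a)

relations = dict(FORWARD)
relations.update({INVERSE_NAME[n]: swapped(p) for n, p in FORWARD.items()})
relations["equals"] = lambda a, b: a == b  # equals is its own inverse

# Symmetry holds by construction: if R(a, b), then inverse(R)(b, a).
a, b = (0, 2), (2, 4)
assert relations["meets"](a, b) and relations["met_by"](b, a)

# Non-transitivity of "meets": back-to-back meetings chain into "before",
# exactly as test_meets_non_transitivity asserts.
p, q, r = (0, 1), (1, 2), (2, 3)
assert relations["meets"](p, q) and relations["meets"](q, r)
assert not relations["meets"](p, r) and relations["before"](p, r)

# "before", by contrast, is transitive over every proper-interval triple here.
intervals = [(s, e) for s in range(5) for e in range(5) if s < e]
for x, y, z in product(intervals, repeat=3):
    if relations["before"](x, y) and relations["before"](y, z):
        assert relations["before"](x, z)
```

Because every inverse is literally an argument swap, the forward and inverse assertions in the symmetry test can only diverge if the real checkers break that duality, which is the regression the test exists to catch.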
def test_shared_endpoints_exhaustive(self, mock_values, checkers, create_interval): """ @@ -1408,38 +1535,49 @@ def test_shared_endpoints_exhaustive(self, mock_values, checkers, create_interva # Create pairs of intervals with specific shared endpoints shared_endpoint_pairs = [ # Both start at same point, different ends - (create_interval(v1, v2), create_interval(v1, v3), 'starts', 'started_by'), - + (create_interval(v1, v2), create_interval(v1, v3), "starts", "started_by"), # A finishes B: A ends at same point as B but starts after B - (create_interval(v2, v3), create_interval(v1, v3), 'finishes', 'finished_by'), - + ( + create_interval(v2, v3), + create_interval(v1, v3), + "finishes", + "finished_by", + ), # End of first equals start of second (meets/met_by) - (create_interval(v1, v2), create_interval(v2, v3), 'meets', 'met_by'), - + (create_interval(v1, v2), create_interval(v2, v3), "meets", "met_by"), # Identical intervals (equals) - (create_interval(v1, v3), create_interval(v1, v3), 'equals', 'equals'), + (create_interval(v1, v3), create_interval(v1, v3), "equals", "equals"), ] - for interval_a, interval_b, relation_a_to_b, relation_b_to_a in shared_endpoint_pairs: + for ( + interval_a, + interval_b, + relation_a_to_b, + relation_b_to_a, + ) in shared_endpoint_pairs: # Test forward relation - assert checkers[relation_a_to_b].check(interval_a, interval_b), \ - f"Expected {relation_a_to_b} to be true from A to B" + assert checkers[relation_a_to_b].check( + interval_a, interval_b + ), f"Expected {relation_a_to_b} to be true from A to B" # Test inverse relation - assert checkers[relation_b_to_a].check(interval_b, interval_a), \ - f"Expected {relation_b_to_a} to be true from B to A" + assert checkers[relation_b_to_a].check( + interval_b, interval_a + ), f"Expected {relation_b_to_a} to be true from B to A" # Check that all other relations are false (A to B) for relation in checkers: - if relation != relation_a_to_b and relation != 'metrics_equivalent': - assert not checkers[relation].check(interval_a, interval_b), \ - f"Relation {relation} should be false for A to B" + if relation != relation_a_to_b and relation != "metrics_equivalent": + assert not checkers[relation].check( + interval_a, interval_b + ), f"Relation {relation} should be false for A to B" # Check that all other relations are false (B to A) for relation in checkers: - if relation != relation_b_to_a and relation != 'metrics_equivalent': - assert not checkers[relation].check(interval_b, interval_a), \ - f"Relation {relation} should be false for B to A" + if relation != relation_b_to_a and relation != "metrics_equivalent": + assert not checkers[relation].check( + interval_b, interval_a + ), f"Relation {relation} should be false for B to A" class TestTransitivityProperties: @@ -1454,19 +1592,22 @@ def test_before_transitivity(self, mock_values, checkers, create_interval): interval_b = create_interval(v2, v2) # Point at v2 interval_c = create_interval(v3, v3) # Point at v3 - assert checkers['before'].check(interval_a, interval_b) - assert checkers['before'].check(interval_b, interval_c) - assert checkers['before'].check(interval_a, interval_c), "Before relation should be transitive" + assert checkers["before"].check(interval_a, interval_b) + assert checkers["before"].check(interval_b, interval_c) + assert checkers["before"].check( + interval_a, interval_c + ), "Before relation should be transitive" # Test with non-point intervals too interval_d = create_interval(v1, v2) interval_e = create_interval(v3, v4) interval_f = 
create_interval(v5, v5) - assert checkers['before'].check(interval_d, interval_e) - assert checkers['before'].check(interval_e, interval_f) - assert checkers['before'].check(interval_d, - interval_f), "Before relation should be transitive for non-point intervals" + assert checkers["before"].check(interval_d, interval_e) + assert checkers["before"].check(interval_e, interval_f) + assert checkers["before"].check( + interval_d, interval_f + ), "Before relation should be transitive for non-point intervals" def test_after_transitivity(self, mock_values, checkers, create_interval): """Test the transitivity of the 'after' relation""" @@ -1477,9 +1618,11 @@ def test_after_transitivity(self, mock_values, checkers, create_interval): interval_b = create_interval(v3, v3) interval_c = create_interval(v1, v1) - assert checkers['after'].check(interval_a, interval_b) - assert checkers['after'].check(interval_b, interval_c) - assert checkers['after'].check(interval_a, interval_c), "After relation should be transitive" + assert checkers["after"].check(interval_a, interval_b) + assert checkers["after"].check(interval_b, interval_c) + assert checkers["after"].check( + interval_a, interval_c + ), "After relation should be transitive" def test_during_transitivity(self, mock_values, checkers, create_interval): """Test the transitivity of the 'during' relation""" @@ -1490,9 +1633,11 @@ def test_during_transitivity(self, mock_values, checkers, create_interval): interval_b = create_interval(v2, v4) interval_c = create_interval(v1, v5) - assert checkers['during'].check(interval_a, interval_b) - assert checkers['during'].check(interval_b, interval_c) - assert checkers['during'].check(interval_a, interval_c), "During relation should be transitive" + assert checkers["during"].check(interval_a, interval_b) + assert checkers["during"].check(interval_b, interval_c) + assert checkers["during"].check( + interval_a, interval_c + ), "During relation should be transitive" def test_contains_transitivity(self, mock_values, checkers, create_interval): """Test the transitivity of the 'contains' relation""" @@ -1503,9 +1648,11 @@ def test_contains_transitivity(self, mock_values, checkers, create_interval): interval_b = create_interval(v2, v4) interval_c = create_interval(v3, v3) - assert checkers['contains'].check(interval_a, interval_b) - assert checkers['contains'].check(interval_b, interval_c) - assert checkers['contains'].check(interval_a, interval_c), "Contains relation should be transitive" + assert checkers["contains"].check(interval_a, interval_b) + assert checkers["contains"].check(interval_b, interval_c) + assert checkers["contains"].check( + interval_a, interval_c + ), "Contains relation should be transitive" def test_equals_transitivity(self, mock_values, checkers, create_interval): """Test the transitivity of the 'equals' relation""" @@ -1517,9 +1664,11 @@ def test_equals_transitivity(self, mock_values, checkers, create_interval): interval_b = create_interval(v1, v2) interval_c = create_interval(v1, v2) - assert checkers['equals'].check(interval_a, interval_b) - assert checkers['equals'].check(interval_b, interval_c) - assert checkers['equals'].check(interval_a, interval_c), "Equals relation should be transitive" + assert checkers["equals"].check(interval_a, interval_b) + assert checkers["equals"].check(interval_b, interval_c) + assert checkers["equals"].check( + interval_a, interval_c + ), "Equals relation should be transitive" def test_starts_transitivity(self, mock_values, checkers, create_interval): """Test that 
'starts' relation IS transitive""" @@ -1530,10 +1679,12 @@ def test_starts_transitivity(self, mock_values, checkers, create_interval): interval_b = create_interval(v1, v3) interval_c = create_interval(v1, v4) - assert checkers['starts'].check(interval_a, interval_b) - assert checkers['starts'].check(interval_b, interval_c) + assert checkers["starts"].check(interval_a, interval_b) + assert checkers["starts"].check(interval_b, interval_c) # A also starts C (they share the same start point and A ends before C ends) - assert checkers['starts'].check(interval_a, interval_c), "The 'starts' relation is transitive" + assert checkers["starts"].check( + interval_a, interval_c + ), "The 'starts' relation is transitive" def test_finishes_transitivity(self, mock_values, checkers, create_interval): """Test that 'finishes' relation IS transitive""" @@ -1544,10 +1695,12 @@ def test_finishes_transitivity(self, mock_values, checkers, create_interval): interval_b = create_interval(v2, v4) interval_c = create_interval(v1, v4) - assert checkers['finishes'].check(interval_a, interval_b) - assert checkers['finishes'].check(interval_b, interval_c) + assert checkers["finishes"].check(interval_a, interval_b) + assert checkers["finishes"].check(interval_b, interval_c) # A also finishes C (they share the same end point and A starts after C starts) - assert checkers['finishes'].check(interval_a, interval_c), "The 'finishes' relation is transitive" + assert checkers["finishes"].check( + interval_a, interval_c + ), "The 'finishes' relation is transitive" def test_meets_non_transitivity(self, mock_values, checkers, create_interval): """Test that 'meets' relation is NOT transitive""" @@ -1558,11 +1711,13 @@ def test_meets_non_transitivity(self, mock_values, checkers, create_interval): interval_b = create_interval(v2, v3) interval_c = create_interval(v3, v4) - assert checkers['meets'].check(interval_a, interval_b) - assert checkers['meets'].check(interval_b, interval_c) + assert checkers["meets"].check(interval_a, interval_b) + assert checkers["meets"].check(interval_b, interval_c) # A should be before C, not meets - assert not checkers['meets'].check(interval_a, interval_c) - assert checkers['before'].check(interval_a, interval_c), "When A meets B and B meets C, A is before C" + assert not checkers["meets"].check(interval_a, interval_c) + assert checkers["before"].check( + interval_a, interval_c + ), "When A meets B and B meets C, A is before C" def test_overlaps_non_transitivity(self, mock_values, checkers, create_interval): """Test that 'overlaps' relation is NOT transitive""" @@ -1573,11 +1728,13 @@ def test_overlaps_non_transitivity(self, mock_values, checkers, create_interval) interval_b = create_interval(v2, v4) interval_c = create_interval(v3, v5) - assert checkers['overlaps'].check(interval_a, interval_b) - assert checkers['overlaps'].check(interval_b, interval_c) + assert checkers["overlaps"].check(interval_a, interval_b) + assert checkers["overlaps"].check(interval_b, interval_c) # A meets C but doesn't overlap it - assert not checkers['overlaps'].check(interval_a, interval_c) - assert checkers['meets'].check(interval_a, interval_c), "When A overlaps B and B overlaps C, A may meet C" + assert not checkers["overlaps"].check(interval_a, interval_c) + assert checkers["meets"].check( + interval_a, interval_c + ), "When A overlaps B and B overlaps C, A may meet C" def test_mixed_transitivity_chains(self, mock_values, checkers, create_interval): """Test transitivity across different relation types""" @@ -1588,27 
+1745,31 @@ def test_mixed_transitivity_chains(self, mock_values, checkers, create_interval) interval_b = create_interval(v3, v4) interval_c = create_interval(v5, v5) - assert checkers['before'].check(interval_a, interval_b) - assert checkers['before'].check(interval_b, interval_c) - assert checkers['before'].check(interval_a, interval_c) + assert checkers["before"].check(interval_a, interval_b) + assert checkers["before"].check(interval_b, interval_c) + assert checkers["before"].check(interval_a, interval_c) # If A meets B and B before C, then A before C interval_d = create_interval(v1, v3) interval_e = create_interval(v3, v4) interval_f = create_interval(v5, v5) - assert checkers['meets'].check(interval_d, interval_e) - assert checkers['before'].check(interval_e, interval_f) - assert checkers['before'].check(interval_d, interval_f), "If A meets B and B before C, then A before C" + assert checkers["meets"].check(interval_d, interval_e) + assert checkers["before"].check(interval_e, interval_f) + assert checkers["before"].check( + interval_d, interval_f + ), "If A meets B and B before C, then A before C" # If A during B and B during C, then A during C interval_g = create_interval(v3, v3) interval_h = create_interval(v2, v4) interval_i = create_interval(v1, v5) - assert checkers['during'].check(interval_g, interval_h) - assert checkers['during'].check(interval_h, interval_i) - assert checkers['during'].check(interval_g, interval_i), "If A during B and B during C, then A during C" + assert checkers["during"].check(interval_g, interval_h) + assert checkers["during"].check(interval_h, interval_i) + assert checkers["during"].check( + interval_g, interval_i + ), "If A during B and B during C, then A during C" class TestStillValidLegacy: @@ -1616,20 +1777,14 @@ def test_interval_starts_with_other_shorter_duration(self): """Test case where both intervals start together but first interval ends earlier""" # Arrange interval = Interval.create( - pd.Series({ - "start": "2023-01-01T00:00:00", - "end": "2023-01-01T00:00:01" - }), + pd.Series({"start": "2023-01-01T00:00:00", "end": "2023-01-01T00:00:01"}), "start", - "end" + "end", ) other = Interval.create( - pd.Series({ - "start": "2023-01-01T00:00:00", - "end": "2023-01-01T00:00:02" - }), + pd.Series({"start": "2023-01-01T00:00:00", "end": "2023-01-01T00:00:02"}), "start", - "end" + "end", ) # Act & Assert @@ -1644,20 +1799,14 @@ def test_interval_does_not_end_before_other_when_ends_same(self): """Test that interval doesn't end before other when they have the same end time""" # Arrange interval = Interval.create( - pd.Series({ - "start": "2023-01-01T00:00:00", - "end": "2023-01-01T00:00:01" - }), + pd.Series({"start": "2023-01-01T00:00:00", "end": "2023-01-01T00:00:01"}), "start", - "end" + "end", ) other = Interval.create( - pd.Series({ - "start": "2023-01-01T00:00:00", - "end": "2023-01-01T00:00:01" - }), + pd.Series({"start": "2023-01-01T00:00:00", "end": "2023-01-01T00:00:01"}), "start", - "end" + "end", ) # Act & Assert @@ -1670,27 +1819,23 @@ def test_interval_is_started_by_other(self): """Test case where both intervals start together but first interval ends later""" # Arrange interval = Interval.create( - pd.Series({ - "start": "2023-01-01T00:00:01", - "end": "2023-01-01T00:00:03" - }), + pd.Series({"start": "2023-01-01T00:00:01", "end": "2023-01-01T00:00:03"}), "start", - "end" + "end", ) other = Interval.create( - pd.Series({ - "start": "2023-01-01T00:00:01", - "end": "2023-01-01T00:00:02" - }), + pd.Series({"start": "2023-01-01T00:00:01", 
"end": "2023-01-01T00:00:02"}), "start", - "end" + "end", ) # Act & Assert assert StartedByChecker().check(interval, other) # Verify this is exclusively a STARTED_BY relationship - assert not StartsChecker().check(interval, other) # Important! This is the inverse relationship + assert not StartsChecker().check( + interval, other + ) # Important! This is the inverse relationship assert not EqualsChecker().check(interval, other) assert not ContainsChecker().check(interval, other) @@ -1698,27 +1843,23 @@ def test_interval_contains_other(self): """Test case where first interval completely contains the second interval""" # Arrange interval = Interval.create( - pd.Series({ - "start": "2023-01-01T00:00:00", - "end": "2023-01-01T03:00:00" - }), + pd.Series({"start": "2023-01-01T00:00:00", "end": "2023-01-01T03:00:00"}), "start", - "end" + "end", ) other = Interval.create( - pd.Series({ - "start": "2023-01-01T01:00:00", - "end": "2023-01-01T02:00:00" - }), + pd.Series({"start": "2023-01-01T01:00:00", "end": "2023-01-01T02:00:00"}), "start", - "end" + "end", ) # Act & Assert assert ContainsChecker().check(interval, other) # Verify this is exclusively a CONTAINS relationship - assert not DuringChecker().check(interval, other) # Important! This is the inverse relationship + assert not DuringChecker().check( + interval, other + ) # Important! This is the inverse relationship assert not EqualsChecker().check(interval, other) assert not OverlapsChecker().check(interval, other) assert not StartsChecker().check(interval, other) @@ -1727,20 +1868,14 @@ def test_interval_overlaps_but_not_contained(self): """Test case where first interval overlaps start of second interval but isn't contained by it""" # Arrange interval = Interval.create( - pd.Series({ - "start": "2023-01-01T00:00:00", - "end": "2023-01-01T01:30:00" - }), + pd.Series({"start": "2023-01-01T00:00:00", "end": "2023-01-01T01:30:00"}), "start", - "end" + "end", ) other = Interval.create( - pd.Series({ - "start": "2023-01-01T01:00:00", - "end": "2023-01-01T03:00:00" - }), + pd.Series({"start": "2023-01-01T01:00:00", "end": "2023-01-01T03:00:00"}), "start", - "end" + "end", ) # Act & Assert @@ -1750,26 +1885,22 @@ def test_interval_overlaps_but_not_contained(self): assert not DuringChecker().check(interval, other) # Verify not contained assert not ContainsChecker().check(interval, other) assert not StartsChecker().check(interval, other) - assert not OverlappedByChecker().check(interval, other) # Important! This is the inverse relationship + assert not OverlappedByChecker().check( + interval, other + ) # Important! This is the inverse relationship def test_interval_is_overlapped_by_other(self): """Test case where first interval starts after second starts and ends after it ends""" # Arrange interval = Interval.create( - pd.Series({ - "start": "2023-01-01T02:00:00", - "end": "2023-01-01T05:00:00" - }), + pd.Series({"start": "2023-01-01T02:00:00", "end": "2023-01-01T05:00:00"}), "start", - "end" + "end", ) other = Interval.create( - pd.Series({ - "start": "2023-01-01T01:00:00", - "end": "2023-01-01T04:00:00" - }), + pd.Series({"start": "2023-01-01T01:00:00", "end": "2023-01-01T04:00:00"}), "start", - "end" + "end", ) # Act & Assert @@ -1777,7 +1908,9 @@ def test_interval_is_overlapped_by_other(self): # Verify this is exclusively an OVERLAPPED_BY relationship assert not DuringChecker().check(interval, other) # Verify not contained - assert not OverlapsChecker().check(interval, other) # Important! 
This is the inverse relationship + assert not OverlapsChecker().check( + interval, other + ) # Important! This is the inverse relationship assert not ContainsChecker().check(interval, other) assert not FinishesChecker().check(interval, other) @@ -1785,20 +1918,14 @@ def test_interval_before_other_no_overlap(self): """Test case where first interval is completely before second interval with no overlap""" # Arrange interval = Interval.create( - pd.Series({ - "start": "2023-01-01T00:00:00", - "end": "2023-01-01T01:00:00" - }), + pd.Series({"start": "2023-01-01T00:00:00", "end": "2023-01-01T01:00:00"}), "start", - "end" + "end", ) other = Interval.create( - pd.Series({ - "start": "2023-01-01T02:00:00", - "end": "2023-01-01T03:00:00" - }), + pd.Series({"start": "2023-01-01T02:00:00", "end": "2023-01-01T03:00:00"}), "start", - "end" + "end", ) # Act & Assert @@ -1806,28 +1933,26 @@ def test_interval_before_other_no_overlap(self): # Verify this is exclusively a BEFORE relationship assert not DuringChecker().check(interval, other) # Verify not contained - assert not MeetsChecker().check(interval, other) # Verify no touching boundaries + assert not MeetsChecker().check( + interval, other + ) # Verify no touching boundaries assert not OverlapsChecker().check(interval, other) # Verify no overlap - assert not AfterChecker().check(interval, other) # Important! This is the inverse relationship + assert not AfterChecker().check( + interval, other + ) # Important! This is the inverse relationship def test_intervals_are_equal(self): """Test case where intervals have identical start and end times""" # Arrange interval = Interval.create( - pd.Series({ - "start": "2023-01-01T01:00:00", - "end": "2023-01-01T01:30:00" - }), + pd.Series({"start": "2023-01-01T01:00:00", "end": "2023-01-01T01:30:00"}), "start", - "end" + "end", ) other = Interval.create( - pd.Series({ - "start": "2023-01-01T01:00:00", - "end": "2023-01-01T01:30:00" - }), + pd.Series({"start": "2023-01-01T01:00:00", "end": "2023-01-01T01:30:00"}), "start", - "end" + "end", ) # Act & Assert @@ -1843,20 +1968,14 @@ def test_interval_overlaps_other(self): """Test case where first interval starts before and overlaps with second interval""" # Arrange interval = Interval.create( - pd.Series({ - "start": "2023-01-01T00:00:01", - "end": "2023-01-01T00:00:03" - }), + pd.Series({"start": "2023-01-01T00:00:01", "end": "2023-01-01T00:00:03"}), "start", - "end" + "end", ) other = Interval.create( - pd.Series({ - "start": "2023-01-01T00:00:02", - "end": "2023-01-01T00:00:04" - }), + pd.Series({"start": "2023-01-01T00:00:02", "end": "2023-01-01T00:00:04"}), "start", - "end" + "end", ) assert OverlapsChecker().check(interval, other) @@ -1871,20 +1990,14 @@ def test_interval_starts_with_other(self): """Test case where both intervals start at the same time but first ends earlier""" # Arrange interval = Interval.create( - pd.Series({ - "start": "2023-01-01T00:00:01", - "end": "2023-01-01T00:00:03" - }), + pd.Series({"start": "2023-01-01T00:00:01", "end": "2023-01-01T00:00:03"}), "start", - "end" + "end", ) other = Interval.create( - pd.Series({ - "start": "2023-01-01T00:00:01", - "end": "2023-01-01T00:00:04" - }), + pd.Series({"start": "2023-01-01T00:00:01", "end": "2023-01-01T00:00:04"}), "start", - "end" + "end", ) assert StartsChecker().check(interval, other) @@ -1893,4 +2006,6 @@ def test_interval_starts_with_other(self): assert not EqualsChecker().check(interval, other) assert not DuringChecker().check(interval, other) assert not OverlapsChecker().check(interval, 
other) - assert not StartedByChecker().check(interval, other) # Important! This is the inverse relationship + assert not StartedByChecker().check( + interval, other + ) # Important! This is the inverse relationship diff --git a/python/tests/intervals/overlap/resolution_tests.py b/python/tests/intervals/overlap/resolution_tests.py index 92501bb1..dbdaf00d 100644 --- a/python/tests/intervals/overlap/resolution_tests.py +++ b/python/tests/intervals/overlap/resolution_tests.py @@ -18,7 +18,7 @@ FinishedByResolver, OverlappedByResolver, MetByResolver, - AfterResolver + AfterResolver, ) @@ -122,7 +122,7 @@ def create_interval(start, end, metrics=None): data=data_series, start_field="start", end_field="end", - metric_fields=[k for k in metrics.keys()] + metric_fields=[k for k in metrics.keys()], ) return create_interval @@ -162,7 +162,7 @@ def all_resolvers(): FinishedByResolver(), OverlappedByResolver(), MetByResolver(), - AfterResolver() + AfterResolver(), ] @@ -486,9 +486,7 @@ def test_sequential_resolution(self, interval_factory): # Extract the last part for next resolution last_part = result1[2] temp_interval = interval_factory( - last_part["start"], - last_part["end"], - {"value": last_part["value"]} + last_part["start"], last_part["end"], {"value": last_part["value"]} ) # Second resolution: last part meets interval3 @@ -539,7 +537,7 @@ def test_applicable_resolvers_with_identical_intervals(self, identical_intervals StartsResolver(), FinishesResolver(), StartedByResolver(), - FinishedByResolver() + FinishedByResolver(), ] for resolver in applicable_resolvers: @@ -566,7 +564,9 @@ def test_applicable_resolvers_with_identical_intervals(self, identical_intervals # Skip those that explicitly don't support it continue - def test_non_applicable_resolvers_with_identical_intervals(self, identical_intervals): + def test_non_applicable_resolvers_with_identical_intervals( + self, identical_intervals + ): """Test resolvers that might not be applicable to identical intervals""" interval1, interval2 = identical_intervals @@ -575,7 +575,7 @@ def test_non_applicable_resolvers_with_identical_intervals(self, identical_inter BeforeResolver(), MeetsResolver(), AfterResolver(), - MetByResolver() + MetByResolver(), ] for resolver in non_applicable: diff --git a/python/tests/intervals/overlap/transformer_tests.py b/python/tests/intervals/overlap/transformer_tests.py index df48f0d4..5e5444c8 100644 --- a/python/tests/intervals/overlap/transformer_tests.py +++ b/python/tests/intervals/overlap/transformer_tests.py @@ -8,18 +8,37 @@ from tempo.intervals.core.interval import Interval from tempo.intervals.overlap.detection import ( - MetricsEquivalentChecker, EqualsChecker, DuringChecker, - ContainsChecker, StartsChecker, StartedByChecker, - FinishesChecker, FinishedByChecker, MeetsChecker, - MetByChecker, OverlapsChecker, OverlappedByChecker, - BeforeChecker, AfterChecker + MetricsEquivalentChecker, + EqualsChecker, + DuringChecker, + ContainsChecker, + StartsChecker, + StartedByChecker, + FinishesChecker, + FinishedByChecker, + MeetsChecker, + MetByChecker, + OverlapsChecker, + OverlappedByChecker, + BeforeChecker, + AfterChecker, ) from tempo.intervals.overlap.resolution import ( - OverlapResolver, MetricsEquivalentResolver, EqualsResolver, - DuringResolver, ContainsResolver, StartsResolver, - StartedByResolver, FinishesResolver, FinishedByResolver, - MeetsResolver, MetByResolver, OverlapsResolver, - OverlappedByResolver, BeforeResolver, AfterResolver + OverlapResolver, + MetricsEquivalentResolver, + EqualsResolver, 
+ DuringResolver, + ContainsResolver, + StartsResolver, + StartedByResolver, + FinishesResolver, + FinishedByResolver, + MeetsResolver, + MetByResolver, + OverlapsResolver, + OverlappedByResolver, + BeforeResolver, + AfterResolver, ) from tempo.intervals.overlap.transformer import IntervalTransformer from tempo.intervals.overlap.types import OverlapType, OverlapResult @@ -28,26 +47,32 @@ @pytest.fixture def interval_data(): """Create data for intervals.""" - data1 = pd.Series({ - 'start': pd.Timestamp('2023-01-01'), - 'end': pd.Timestamp('2023-01-05'), - 'metric1': 10, - 'metric2': 20, - }) - - data2 = pd.Series({ - 'start': pd.Timestamp('2023-01-03'), - 'end': pd.Timestamp('2023-01-07'), - 'metric1': 15, - 'metric2': 25, - }) - - data3 = pd.Series({ - 'start': pd.Timestamp('2023-01-06'), - 'end': pd.Timestamp('2023-01-10'), - 'metric1': 10, - 'metric3': 30, - }) + data1 = pd.Series( + { + "start": pd.Timestamp("2023-01-01"), + "end": pd.Timestamp("2023-01-05"), + "metric1": 10, + "metric2": 20, + } + ) + + data2 = pd.Series( + { + "start": pd.Timestamp("2023-01-03"), + "end": pd.Timestamp("2023-01-07"), + "metric1": 15, + "metric2": 25, + } + ) + + data3 = pd.Series( + { + "start": pd.Timestamp("2023-01-06"), + "end": pd.Timestamp("2023-01-10"), + "metric1": 10, + "metric3": 30, + } + ) return data1, data2, data3 @@ -59,23 +84,23 @@ def intervals(interval_data): interval1 = Interval.create( data=data1, - start_field='start', - end_field='end', - metric_fields=['metric1', 'metric2'] + start_field="start", + end_field="end", + metric_fields=["metric1", "metric2"], ) interval2 = Interval.create( data=data2, - start_field='start', - end_field='end', - metric_fields=['metric1', 'metric2'] + start_field="start", + end_field="end", + metric_fields=["metric1", "metric2"], ) interval3 = Interval.create( data=data3, - start_field='start', - end_field='end', - metric_fields=['metric1', 'metric3'] + start_field="start", + end_field="end", + metric_fields=["metric1", "metric3"], ) return interval1, interval2, interval3 @@ -87,153 +112,155 @@ def interval_pairs(): # Create a series of interval pairs that exhibit the different relationships # Basic template for interval data - template = pd.Series({ - 'metric1': 10, - 'metric2': 20, - }) + template = pd.Series( + { + "metric1": 10, + "metric2": 20, + } + ) # Common fields for all intervals - start_field = 'start' - end_field = 'end' - metric_fields = ['metric1', 'metric2'] + start_field = "start" + end_field = "end" + metric_fields = ["metric1", "metric2"] # EQUALS: identical intervals equals_a = template.copy() - equals_a[start_field] = pd.Timestamp('2023-01-01') - equals_a[end_field] = pd.Timestamp('2023-01-05') + equals_a[start_field] = pd.Timestamp("2023-01-01") + equals_a[end_field] = pd.Timestamp("2023-01-05") equals_b = template.copy() - equals_b[start_field] = pd.Timestamp('2023-01-01') - equals_b[end_field] = pd.Timestamp('2023-01-05') + equals_b[start_field] = pd.Timestamp("2023-01-01") + equals_b[end_field] = pd.Timestamp("2023-01-05") # DURING: interval is inside other during_a = template.copy() - during_a[start_field] = pd.Timestamp('2023-01-02') - during_a[end_field] = pd.Timestamp('2023-01-04') + during_a[start_field] = pd.Timestamp("2023-01-02") + during_a[end_field] = pd.Timestamp("2023-01-04") during_b = template.copy() - during_b[start_field] = pd.Timestamp('2023-01-01') - during_b[end_field] = pd.Timestamp('2023-01-05') + during_b[start_field] = pd.Timestamp("2023-01-01") + during_b[end_field] = pd.Timestamp("2023-01-05") # CONTAINS: 
interval contains other contains_a = template.copy() - contains_a[start_field] = pd.Timestamp('2023-01-01') - contains_a[end_field] = pd.Timestamp('2023-01-05') + contains_a[start_field] = pd.Timestamp("2023-01-01") + contains_a[end_field] = pd.Timestamp("2023-01-05") contains_b = template.copy() - contains_b[start_field] = pd.Timestamp('2023-01-02') - contains_b[end_field] = pd.Timestamp('2023-01-04') + contains_b[start_field] = pd.Timestamp("2023-01-02") + contains_b[end_field] = pd.Timestamp("2023-01-04") # STARTS: intervals start together, interval ends first starts_a = template.copy() - starts_a[start_field] = pd.Timestamp('2023-01-01') - starts_a[end_field] = pd.Timestamp('2023-01-03') + starts_a[start_field] = pd.Timestamp("2023-01-01") + starts_a[end_field] = pd.Timestamp("2023-01-03") starts_b = template.copy() - starts_b[start_field] = pd.Timestamp('2023-01-01') - starts_b[end_field] = pd.Timestamp('2023-01-05') + starts_b[start_field] = pd.Timestamp("2023-01-01") + starts_b[end_field] = pd.Timestamp("2023-01-05") # STARTED_BY: intervals start together, other ends first started_by_a = template.copy() - started_by_a[start_field] = pd.Timestamp('2023-01-01') - started_by_a[end_field] = pd.Timestamp('2023-01-05') + started_by_a[start_field] = pd.Timestamp("2023-01-01") + started_by_a[end_field] = pd.Timestamp("2023-01-05") started_by_b = template.copy() - started_by_b[start_field] = pd.Timestamp('2023-01-01') - started_by_b[end_field] = pd.Timestamp('2023-01-03') + started_by_b[start_field] = pd.Timestamp("2023-01-01") + started_by_b[end_field] = pd.Timestamp("2023-01-03") # FINISHES: intervals end together, interval starts later finishes_a = template.copy() - finishes_a[start_field] = pd.Timestamp('2023-01-03') - finishes_a[end_field] = pd.Timestamp('2023-01-05') + finishes_a[start_field] = pd.Timestamp("2023-01-03") + finishes_a[end_field] = pd.Timestamp("2023-01-05") finishes_b = template.copy() - finishes_b[start_field] = pd.Timestamp('2023-01-01') - finishes_b[end_field] = pd.Timestamp('2023-01-05') + finishes_b[start_field] = pd.Timestamp("2023-01-01") + finishes_b[end_field] = pd.Timestamp("2023-01-05") # FINISHED_BY: intervals end together, other starts later finished_by_a = template.copy() - finished_by_a[start_field] = pd.Timestamp('2023-01-01') - finished_by_a[end_field] = pd.Timestamp('2023-01-05') + finished_by_a[start_field] = pd.Timestamp("2023-01-01") + finished_by_a[end_field] = pd.Timestamp("2023-01-05") finished_by_b = template.copy() - finished_by_b[start_field] = pd.Timestamp('2023-01-03') - finished_by_b[end_field] = pd.Timestamp('2023-01-05') + finished_by_b[start_field] = pd.Timestamp("2023-01-03") + finished_by_b[end_field] = pd.Timestamp("2023-01-05") # MEETS: interval ends where other starts meets_a = template.copy() - meets_a[start_field] = pd.Timestamp('2023-01-01') - meets_a[end_field] = pd.Timestamp('2023-01-03') + meets_a[start_field] = pd.Timestamp("2023-01-01") + meets_a[end_field] = pd.Timestamp("2023-01-03") meets_b = template.copy() - meets_b[start_field] = pd.Timestamp('2023-01-03') - meets_b[end_field] = pd.Timestamp('2023-01-05') + meets_b[start_field] = pd.Timestamp("2023-01-03") + meets_b[end_field] = pd.Timestamp("2023-01-05") # MET_BY: other ends where interval starts met_by_a = template.copy() - met_by_a[start_field] = pd.Timestamp('2023-01-03') - met_by_a[end_field] = pd.Timestamp('2023-01-05') + met_by_a[start_field] = pd.Timestamp("2023-01-03") + met_by_a[end_field] = pd.Timestamp("2023-01-05") met_by_b = template.copy() - 
met_by_b[start_field] = pd.Timestamp('2023-01-01') - met_by_b[end_field] = pd.Timestamp('2023-01-03') + met_by_b[start_field] = pd.Timestamp("2023-01-01") + met_by_b[end_field] = pd.Timestamp("2023-01-03") # OVERLAPS: interval starts first, overlaps start of other overlaps_a = template.copy() - overlaps_a[start_field] = pd.Timestamp('2023-01-01') - overlaps_a[end_field] = pd.Timestamp('2023-01-04') + overlaps_a[start_field] = pd.Timestamp("2023-01-01") + overlaps_a[end_field] = pd.Timestamp("2023-01-04") overlaps_b = template.copy() - overlaps_b[start_field] = pd.Timestamp('2023-01-03') - overlaps_b[end_field] = pd.Timestamp('2023-01-05') + overlaps_b[start_field] = pd.Timestamp("2023-01-03") + overlaps_b[end_field] = pd.Timestamp("2023-01-05") # OVERLAPPED_BY: other starts first, overlaps start of interval overlapped_by_a = template.copy() - overlapped_by_a[start_field] = pd.Timestamp('2023-01-03') - overlapped_by_a[end_field] = pd.Timestamp('2023-01-05') + overlapped_by_a[start_field] = pd.Timestamp("2023-01-03") + overlapped_by_a[end_field] = pd.Timestamp("2023-01-05") overlapped_by_b = template.copy() - overlapped_by_b[start_field] = pd.Timestamp('2023-01-01') - overlapped_by_b[end_field] = pd.Timestamp('2023-01-04') + overlapped_by_b[start_field] = pd.Timestamp("2023-01-01") + overlapped_by_b[end_field] = pd.Timestamp("2023-01-04") # BEFORE: interval completely before other before_a = template.copy() - before_a[start_field] = pd.Timestamp('2023-01-01') - before_a[end_field] = pd.Timestamp('2023-01-03') + before_a[start_field] = pd.Timestamp("2023-01-01") + before_a[end_field] = pd.Timestamp("2023-01-03") before_b = template.copy() - before_b[start_field] = pd.Timestamp('2023-01-04') - before_b[end_field] = pd.Timestamp('2023-01-06') + before_b[start_field] = pd.Timestamp("2023-01-04") + before_b[end_field] = pd.Timestamp("2023-01-06") # AFTER: interval completely after other after_a = template.copy() - after_a[start_field] = pd.Timestamp('2023-01-04') - after_a[end_field] = pd.Timestamp('2023-01-06') + after_a[start_field] = pd.Timestamp("2023-01-04") + after_a[end_field] = pd.Timestamp("2023-01-06") after_b = template.copy() - after_b[start_field] = pd.Timestamp('2023-01-01') - after_b[end_field] = pd.Timestamp('2023-01-03') + after_b[start_field] = pd.Timestamp("2023-01-01") + after_b[end_field] = pd.Timestamp("2023-01-03") # Create Interval objects from the data intervals = {} for name, (a, b) in { - 'equals': (equals_a, equals_b), - 'during': (during_a, during_b), - 'contains': (contains_a, contains_b), - 'starts': (starts_a, starts_b), - 'started_by': (started_by_a, started_by_b), - 'finishes': (finishes_a, finishes_b), - 'finished_by': (finished_by_a, finished_by_b), - 'meets': (meets_a, meets_b), - 'met_by': (met_by_a, met_by_b), - 'overlaps': (overlaps_a, overlaps_b), - 'overlapped_by': (overlapped_by_a, overlapped_by_b), - 'before': (before_a, before_b), - 'after': (after_a, after_b), + "equals": (equals_a, equals_b), + "during": (during_a, during_b), + "contains": (contains_a, contains_b), + "starts": (starts_a, starts_b), + "started_by": (started_by_a, started_by_b), + "finishes": (finishes_a, finishes_b), + "finished_by": (finished_by_a, finished_by_b), + "meets": (meets_a, meets_b), + "met_by": (met_by_a, met_by_b), + "overlaps": (overlaps_a, overlaps_b), + "overlapped_by": (overlapped_by_a, overlapped_by_b), + "before": (before_a, before_b), + "after": (after_a, after_b), }.items(): intervals[name] = ( Interval.create(a, start_field, end_field, 
metric_fields=metric_fields), - Interval.create(b, start_field, end_field, metric_fields=metric_fields) + Interval.create(b, start_field, end_field, metric_fields=metric_fields), ) return intervals @@ -242,22 +269,24 @@ def interval_pairs(): @pytest.fixture def metrics_equivalent_intervals(): """Create intervals that have equivalent metrics but different time boundaries.""" - template = pd.Series({ - 'metric1': 10, - 'metric2': 20, - }) + template = pd.Series( + { + "metric1": 10, + "metric2": 20, + } + ) a = template.copy() - a['start'] = pd.Timestamp('2023-01-01') - a['end'] = pd.Timestamp('2023-01-05') + a["start"] = pd.Timestamp("2023-01-01") + a["end"] = pd.Timestamp("2023-01-05") b = template.copy() - b['start'] = pd.Timestamp('2023-01-02') - b['end'] = pd.Timestamp('2023-01-06') + b["start"] = pd.Timestamp("2023-01-02") + b["end"] = pd.Timestamp("2023-01-06") return ( - Interval.create(a, 'start', 'end', metric_fields=['metric1', 'metric2']), - Interval.create(b, 'start', 'end', metric_fields=['metric1', 'metric2']) + Interval.create(a, "start", "end", metric_fields=["metric1", "metric2"]), + Interval.create(b, "start", "end", metric_fields=["metric1", "metric2"]), ) @@ -276,12 +305,9 @@ def test_init_orders_intervals_correctly(self, intervals): # Test with interval2 starting earlier # Create a new interval2 that starts before interval1 earlier_data = interval2.data.copy() - earlier_data['start'] = pd.Timestamp('2022-12-31') + earlier_data["start"] = pd.Timestamp("2022-12-31") earlier_interval = Interval.create( - earlier_data, - 'start', - 'end', - metric_fields=['metric1', 'metric2'] + earlier_data, "start", "end", metric_fields=["metric1", "metric2"] ) transformer = IntervalTransformer(interval1, earlier_interval) @@ -290,12 +316,9 @@ def test_init_orders_intervals_correctly(self, intervals): # Test with equal start times - should maintain original order equal_start_data = interval2.data.copy() - equal_start_data['start'] = interval1.data['start'] + equal_start_data["start"] = interval1.data["start"] equal_start_interval = Interval.create( - equal_start_data, - 'start', - 'end', - metric_fields=['metric1', 'metric2'] + equal_start_data, "start", "end", metric_fields=["metric1", "metric2"] ) transformer = IntervalTransformer(interval1, equal_start_interval) @@ -314,7 +337,9 @@ def test_validate_intervals(self, intervals): IntervalTransformer.validate_intervals(interval1, interval3) # Verify the error is about indices not matching - assert "Expected indices of interval elements to be equivalent" in str(excinfo.value) + assert "Expected indices of interval elements to be equivalent" in str( + excinfo.value + ) assert str(interval1.data.index) in str(excinfo.value) assert str(interval3.data.index) in str(excinfo.value) @@ -325,37 +350,48 @@ class TestIntervalTransformerRelationships: @pytest.mark.parametrize( "relationship_name, expected_type", [ - ('equals', OverlapType.EQUALS), - ('during', OverlapType.DURING), - ('contains', OverlapType.CONTAINS), - ('starts', OverlapType.STARTS), - ('started_by', OverlapType.STARTED_BY), - ('finishes', OverlapType.FINISHES), - ('finished_by', OverlapType.FINISHED_BY), - ('meets', OverlapType.MEETS), - ('met_by', OverlapType.MET_BY), - ('overlaps', OverlapType.OVERLAPS), - ('overlapped_by', OverlapType.OVERLAPPED_BY), - ('before', OverlapType.BEFORE), - ('after', OverlapType.AFTER) - ] + ("equals", OverlapType.EQUALS), + ("during", OverlapType.DURING), + ("contains", OverlapType.CONTAINS), + ("starts", OverlapType.STARTS), + ("started_by", 
OverlapType.STARTED_BY), + ("finishes", OverlapType.FINISHES), + ("finished_by", OverlapType.FINISHED_BY), + ("meets", OverlapType.MEETS), + ("met_by", OverlapType.MET_BY), + ("overlaps", OverlapType.OVERLAPS), + ("overlapped_by", OverlapType.OVERLAPPED_BY), + ("before", OverlapType.BEFORE), + ("after", OverlapType.AFTER), + ], ) - def test_detect_relationship(self, interval_pairs, relationship_name, expected_type, monkeypatch): + def test_detect_relationship( + self, interval_pairs, relationship_name, expected_type, monkeypatch + ): """Test detection of interval relationships with actual intervals.""" interval1, interval2 = interval_pairs[relationship_name] # Override the MetricsEquivalentChecker to always return False # This prevents it from taking precedence over other checkers - monkeypatch.setattr(MetricsEquivalentChecker, 'check', lambda self, a, b: False) + monkeypatch.setattr(MetricsEquivalentChecker, "check", lambda self, a, b: False) # Override all checkers to return False by default for checker_class in [ - EqualsChecker, DuringChecker, ContainsChecker, StartsChecker, - StartedByChecker, FinishesChecker, FinishedByChecker, MeetsChecker, - MetByChecker, OverlapsChecker, OverlappedByChecker, - BeforeChecker, AfterChecker + EqualsChecker, + DuringChecker, + ContainsChecker, + StartsChecker, + StartedByChecker, + FinishesChecker, + FinishedByChecker, + MeetsChecker, + MetByChecker, + OverlapsChecker, + OverlappedByChecker, + BeforeChecker, + AfterChecker, ]: - monkeypatch.setattr(checker_class, 'check', lambda self, a, b: False) + monkeypatch.setattr(checker_class, "check", lambda self, a, b: False) # Only make the expected checker return True checker_map = { @@ -371,10 +407,12 @@ def test_detect_relationship(self, interval_pairs, relationship_name, expected_t OverlapType.OVERLAPS: OverlapsChecker, OverlapType.OVERLAPPED_BY: OverlappedByChecker, OverlapType.BEFORE: BeforeChecker, - OverlapType.AFTER: AfterChecker + OverlapType.AFTER: AfterChecker, } - monkeypatch.setattr(checker_map[expected_type], 'check', lambda self, a, b: True) + monkeypatch.setattr( + checker_map[expected_type], "check", lambda self, a, b: True + ) # Test with actual checker implementation transformer = IntervalTransformer(interval1, interval2) @@ -382,7 +420,11 @@ def test_detect_relationship(self, interval_pairs, relationship_name, expected_t assert relationship == expected_type # Test with OverlapResult return type, if the API changes - with patch.object(transformer, 'detect_relationship', return_value=OverlapResult(type=expected_type)): + with patch.object( + transformer, + "detect_relationship", + return_value=OverlapResult(type=expected_type), + ): result = transformer.detect_relationship() assert isinstance(result, OverlapResult) assert result.type == expected_type @@ -392,17 +434,19 @@ def test_detect_metrics_equivalent_relationship(self, metrics_equivalent_interva interval1, interval2 = metrics_equivalent_intervals # Override checkers to simulate metrics equivalence - with patch.object(MetricsEquivalentChecker, 'check', return_value=True): + with patch.object(MetricsEquivalentChecker, "check", return_value=True): transformer = IntervalTransformer(interval1, interval2) relationship = transformer.detect_relationship() assert relationship == OverlapType.METRICS_EQUIVALENT # Also test with OverlapResult for future compatibility - with patch.object(transformer, 'detect_relationship', - return_value=OverlapResult( - type=OverlapType.METRICS_EQUIVALENT, - details={"metrics_match": True} - )): + with 
patch.object( + transformer, + "detect_relationship", + return_value=OverlapResult( + type=OverlapType.METRICS_EQUIVALENT, details={"metrics_match": True} + ), + ): result = transformer.detect_relationship() assert isinstance(result, OverlapResult) assert result.type == OverlapType.METRICS_EQUIVALENT @@ -421,14 +465,20 @@ def test_resolve_overlap(self, intervals, relationship_type): mock_resolver.resolve.return_value = expected_result # Test with direct OverlapType enum - with patch.object(transformer, 'detect_relationship', return_value=relationship_type), \ - patch.object(transformer, '_get_resolver', return_value=mock_resolver): + with ( + patch.object( + transformer, "detect_relationship", return_value=relationship_type + ), + patch.object(transformer, "_get_resolver", return_value=mock_resolver), + ): result = transformer.resolve_overlap() # Verify correct resolver was used transformer._get_resolver.assert_called_once_with(relationship_type) # Verify resolver.resolve was called with correct arguments - mock_resolver.resolve.assert_called_once_with(transformer.interval, transformer.other) + mock_resolver.resolve.assert_called_once_with( + transformer.interval, transformer.other + ) # Verify correct result returned assert result == expected_result @@ -438,18 +488,27 @@ def test_resolve_overlap(self, intervals, relationship_type): # Test with OverlapResult instead overlap_result = OverlapResult(type=relationship_type, details={"test": "data"}) - with patch.object(transformer, 'detect_relationship', return_value=overlap_result), \ - patch.object(transformer, '_get_resolver', return_value=mock_resolver): + with ( + patch.object( + transformer, "detect_relationship", return_value=overlap_result + ), + patch.object(transformer, "_get_resolver", return_value=mock_resolver), + ): # Add compatibility for handling OverlapResult - with patch.object(transformer, '_get_overlap_type', - return_value=relationship_type, - create=True): + with patch.object( + transformer, + "_get_overlap_type", + return_value=relationship_type, + create=True, + ): result = transformer.resolve_overlap() # Verify correct resolver was used (potentially via _get_overlap_type helper) transformer._get_resolver.assert_called_once() # Verify resolver.resolve was called with correct arguments - mock_resolver.resolve.assert_called_once_with(transformer.interval, transformer.other) + mock_resolver.resolve.assert_called_once_with( + transformer.interval, transformer.other + ) # Verify correct result returned assert result == expected_result @@ -459,8 +518,10 @@ def test_resolve_overlap_no_relationship(self, intervals): transformer = IntervalTransformer(interval1, interval2) # Test with direct None return - with patch.object(transformer, 'detect_relationship', return_value=None): - with pytest.raises(NotImplementedError, match="Unable to determine interval relationship"): + with patch.object(transformer, "detect_relationship", return_value=None): + with pytest.raises( + NotImplementedError, match="Unable to determine interval relationship" + ): transformer.resolve_overlap() # Test with OverlapResult that has None type @@ -473,9 +534,16 @@ def mock_resolve_overlap(): raise NotImplementedError("Unable to determine interval relationship") return [] # Return empty list if somehow execution continues - with patch.object(transformer, 'detect_relationship', return_value=OverlapResult(type=None)): - with patch.object(transformer, 'resolve_overlap', side_effect=mock_resolve_overlap): - with pytest.raises(NotImplementedError, match="Unable 
to determine interval relationship"): + with patch.object( + transformer, "detect_relationship", return_value=OverlapResult(type=None) + ): + with patch.object( + transformer, "resolve_overlap", side_effect=mock_resolve_overlap + ): + with pytest.raises( + NotImplementedError, + match="Unable to determine interval relationship", + ): transformer.resolve_overlap() def test_no_resolver_found_error(self): @@ -513,7 +581,9 @@ class CustomOverlapType(Enum): unknown_type = CustomOverlapType.CUSTOM_TYPE # Patch the detect_relationship method to return our custom type - with patch.object(transformer, 'detect_relationship', return_value=unknown_type): + with patch.object( + transformer, "detect_relationship", return_value=unknown_type + ): # When resolve_overlap tries to get a resolver for this type, # it should raise a ValueError with pytest.raises(ValueError) as excinfo: @@ -565,17 +635,22 @@ class TestIntervalTransformerIntegration: @pytest.mark.parametrize( "relationship_name, expected_segments", [ - ('overlaps', 3), # Before overlap, overlap, after overlap - ('equals', 1), # Single merged interval - ('during', 3), # Before contained, contained with merged metrics, after contained - ('contains', 3), # Before other, other with merged metrics, after other - ('before', 2), # Two separate intervals, no merging - ('after', 2), # Two separate intervals, no merging - ('meets', 2), # Two separate intervals, no merging - ('met_by', 2), # Two separate intervals, no merging - ] + ("overlaps", 3), # Before overlap, overlap, after overlap + ("equals", 1), # Single merged interval + ( + "during", + 3, + ), # Before contained, contained with merged metrics, after contained + ("contains", 3), # Before other, other with merged metrics, after other + ("before", 2), # Two separate intervals, no merging + ("after", 2), # Two separate intervals, no merging + ("meets", 2), # Two separate intervals, no merging + ("met_by", 2), # Two separate intervals, no merging + ], ) - def test_interval_resolution(self, interval_pairs, relationship_name, expected_segments): + def test_interval_resolution( + self, interval_pairs, relationship_name, expected_segments + ): """Test resolving different types of interval relationships.""" interval1, interval2 = interval_pairs[relationship_name] @@ -583,27 +658,42 @@ def test_interval_resolution(self, interval_pairs, relationship_name, expected_s transformer = IntervalTransformer(interval1, interval2) # Test with mock resolvers since we don't have the actual implementation - result_series = [pd.Series({'metric1': 10, 'metric2': 20}) for _ in range(expected_segments)] + result_series = [ + pd.Series({"metric1": 10, "metric2": 20}) for _ in range(expected_segments) + ] # Test with direct OverlapType - with patch.object(transformer, 'detect_relationship', - return_value=getattr(OverlapType, relationship_name.upper())): - with patch.object(transformer, 'resolve_overlap', return_value=result_series) as mock_resolve: + with patch.object( + transformer, + "detect_relationship", + return_value=getattr(OverlapType, relationship_name.upper()), + ): + with patch.object( + transformer, "resolve_overlap", return_value=result_series + ) as mock_resolve: result = transformer.resolve_overlap() assert len(result) == expected_segments mock_resolve.reset_mock() # Test with OverlapResult - with patch.object(transformer, 'detect_relationship', - return_value=OverlapResult( - type=getattr(OverlapType, relationship_name.upper()), - details={"test": "data"} - )): + with patch.object( + transformer, + 
"detect_relationship", + return_value=OverlapResult( + type=getattr(OverlapType, relationship_name.upper()), + details={"test": "data"}, + ), + ): # Mock internal handling of OverlapResult if needed - with patch.object(transformer, '_get_overlap_type', - return_value=getattr(OverlapType, relationship_name.upper()), - create=True): - with patch.object(transformer, 'resolve_overlap', return_value=result_series) as mock_resolve: + with patch.object( + transformer, + "_get_overlap_type", + return_value=getattr(OverlapType, relationship_name.upper()), + create=True, + ): + with patch.object( + transformer, "resolve_overlap", return_value=result_series + ) as mock_resolve: result = transformer.resolve_overlap() assert len(result) == expected_segments mock_resolve.reset_mock() @@ -615,41 +705,47 @@ class TestPrecisionEdgeCases: def test_meets_exact_timestamp(self): """Test intervals that share exactly one timestamp (meets/met_by edge case).""" # Create two intervals that meet at exactly one timestamp - precise_timestamp = pd.Timestamp('2023-01-03T12:00:00.000000') - - data1 = pd.Series({ - 'start': pd.Timestamp('2023-01-01'), - 'end': precise_timestamp, # End exactly at this timestamp - 'metric1': 10, - 'metric2': 20, - }) - - data2 = pd.Series({ - 'start': precise_timestamp, # Start exactly at the same timestamp - 'end': pd.Timestamp('2023-01-05'), - 'metric1': 15, - 'metric2': 25, - }) + precise_timestamp = pd.Timestamp("2023-01-03T12:00:00.000000") + + data1 = pd.Series( + { + "start": pd.Timestamp("2023-01-01"), + "end": precise_timestamp, # End exactly at this timestamp + "metric1": 10, + "metric2": 20, + } + ) + + data2 = pd.Series( + { + "start": precise_timestamp, # Start exactly at the same timestamp + "end": pd.Timestamp("2023-01-05"), + "metric1": 15, + "metric2": 25, + } + ) interval1 = Interval.create( data=data1, - start_field='start', - end_field='end', - metric_fields=['metric1', 'metric2'] + start_field="start", + end_field="end", + metric_fields=["metric1", "metric2"], ) interval2 = Interval.create( data=data2, - start_field='start', - end_field='end', - metric_fields=['metric1', 'metric2'] + start_field="start", + end_field="end", + metric_fields=["metric1", "metric2"], ) # Test that the precise timestamp equality is correctly detected as MEETS transformer = IntervalTransformer(interval1, interval2) relationship = transformer.detect_relationship() - assert relationship == OverlapType.MEETS, f"Expected MEETS relationship, got {relationship}" + assert ( + relationship == OverlapType.MEETS + ), f"Expected MEETS relationship, got {relationship}" # Verify resolution behavior resolved = transformer.resolve_overlap() @@ -657,46 +753,44 @@ def test_meets_exact_timestamp(self): # Verify the resolved segments have the correct timestamps # Note: This assumes the resolved segments are ordered chronologically - assert resolved[0]['start'] == interval1.data['start'] - assert resolved[0]['end'] == interval1.data['end'] - assert resolved[1]['start'] == interval2.data['start'] - assert resolved[1]['end'] == interval2.data['end'] + assert resolved[0]["start"] == interval1.data["start"] + assert resolved[0]["end"] == interval1.data["end"] + assert resolved[1]["start"] == interval2.data["start"] + assert resolved[1]["end"] == interval2.data["end"] # Verify that metrics are preserved - assert resolved[0]['metric1'] == interval1.data['metric1'] - assert resolved[1]['metric1'] == interval2.data['metric1'] + assert resolved[0]["metric1"] == interval1.data["metric1"] + assert resolved[1]["metric1"] == 
interval2.data["metric1"] def test_meets_floating_point_precision(self): """Test handling of floating point precision issues in timestamp comparisons.""" # Create timestamps that might cause floating point precision issues # For example, timestamps that differ by less than a nanosecond - t1_end = pd.Timestamp('2023-01-03T12:00:00.000000001') - t2_start = pd.Timestamp('2023-01-03T12:00:00.000000000') + t1_end = pd.Timestamp("2023-01-03T12:00:00.000000001") + t2_start = pd.Timestamp("2023-01-03T12:00:00.000000000") - data1 = pd.Series({ - 'start': pd.Timestamp('2023-01-01'), - 'end': t1_end, - 'metric1': 10, - }) + data1 = pd.Series( + { + "start": pd.Timestamp("2023-01-01"), + "end": t1_end, + "metric1": 10, + } + ) - data2 = pd.Series({ - 'start': t2_start, - 'end': pd.Timestamp('2023-01-05'), - 'metric1': 15, - }) + data2 = pd.Series( + { + "start": t2_start, + "end": pd.Timestamp("2023-01-05"), + "metric1": 15, + } + ) interval1 = Interval.create( - data=data1, - start_field='start', - end_field='end', - metric_fields=['metric1'] + data=data1, start_field="start", end_field="end", metric_fields=["metric1"] ) interval2 = Interval.create( - data=data2, - start_field='start', - end_field='end', - metric_fields=['metric1'] + data=data2, start_field="start", end_field="end", metric_fields=["metric1"] ) transformer = IntervalTransformer(interval1, interval2) @@ -707,7 +801,9 @@ def test_meets_floating_point_precision(self): print(f"Relationship detected for off-by-nanosecond intervals: {relationship}") # We mainly want to verify that some valid relationship is detected - assert relationship is not None, "Should detect a relationship despite precision differences" + assert ( + relationship is not None + ), "Should detect a relationship despite precision differences" # And that resolution works without errors resolved = transformer.resolve_overlap() @@ -719,32 +815,36 @@ class TestNullMetricValues: def test_nan_metric_values(self): """Test handling intervals with NaN metric values.""" - data1 = pd.Series({ - 'start': pd.Timestamp('2023-01-01'), - 'end': pd.Timestamp('2023-01-05'), - 'metric1': 10.0, - 'metric2': np.nan, # NaN value - }) - - data2 = pd.Series({ - 'start': pd.Timestamp('2023-01-03'), - 'end': pd.Timestamp('2023-01-07'), - 'metric1': 15.0, - 'metric2': 25.0, - }) + data1 = pd.Series( + { + "start": pd.Timestamp("2023-01-01"), + "end": pd.Timestamp("2023-01-05"), + "metric1": 10.0, + "metric2": np.nan, # NaN value + } + ) + + data2 = pd.Series( + { + "start": pd.Timestamp("2023-01-03"), + "end": pd.Timestamp("2023-01-07"), + "metric1": 15.0, + "metric2": 25.0, + } + ) interval1 = Interval.create( data=data1, - start_field='start', - end_field='end', - metric_fields=['metric1', 'metric2'] + start_field="start", + end_field="end", + metric_fields=["metric1", "metric2"], ) interval2 = Interval.create( data=data2, - start_field='start', - end_field='end', - metric_fields=['metric1', 'metric2'] + start_field="start", + end_field="end", + metric_fields=["metric1", "metric2"], ) # First test relationship detection @@ -752,7 +852,9 @@ def test_nan_metric_values(self): relationship = transformer.detect_relationship() # Should be OVERLAPS despite NaN value - assert relationship == OverlapType.OVERLAPS, f"Expected OVERLAPS relationship, got {relationship}" + assert ( + relationship == OverlapType.OVERLAPS + ), f"Expected OVERLAPS relationship, got {relationship}" # Test resolution with NaN value resolved = transformer.resolve_overlap() @@ -760,47 +862,53 @@ def test_nan_metric_values(self): # 
Check handling of NaN in the middle (overlapping) segment middle_segment = resolved[1] - assert 'metric1' in middle_segment - assert 'metric2' in middle_segment + assert "metric1" in middle_segment + assert "metric2" in middle_segment # Check that metric1 was properly merged in the overlap - assert not pd.isna(middle_segment['metric1']) + assert not pd.isna(middle_segment["metric1"]) # How metric2 is handled depends on the implementation: # - It could keep the NaN from interval1 # - It could use the value from interval2 # - It could use some other strategy like setting to 0 # Just make sure it doesn't error - print(f"NaN handling in overlapping segment: metric2 = {middle_segment['metric2']}") + print( + f"NaN handling in overlapping segment: metric2 = {middle_segment['metric2']}" + ) def test_missing_metric_fields(self): """Test handling intervals with completely missing metric fields.""" - data1 = pd.Series({ - 'start': pd.Timestamp('2023-01-01'), - 'end': pd.Timestamp('2023-01-05'), - 'metric1': 10, - # metric2 is completely missing - }) - - data2 = pd.Series({ - 'start': pd.Timestamp('2023-01-03'), - 'end': pd.Timestamp('2023-01-07'), - # metric1 is missing - 'metric2': 25, - }) + data1 = pd.Series( + { + "start": pd.Timestamp("2023-01-01"), + "end": pd.Timestamp("2023-01-05"), + "metric1": 10, + # metric2 is completely missing + } + ) + + data2 = pd.Series( + { + "start": pd.Timestamp("2023-01-03"), + "end": pd.Timestamp("2023-01-07"), + # metric1 is missing + "metric2": 25, + } + ) interval1 = Interval.create( data=data1, - start_field='start', - end_field='end', - metric_fields=['metric1'] # Only metric1 + start_field="start", + end_field="end", + metric_fields=["metric1"], # Only metric1 ) interval2 = Interval.create( data=data2, - start_field='start', - end_field='end', - metric_fields=['metric2'] # Only metric2 + start_field="start", + end_field="end", + metric_fields=["metric2"], # Only metric2 ) # Check if validation allows different metric fields @@ -831,32 +939,36 @@ class TestActualResolverImplementations: @pytest.fixture def overlapping_intervals(self): """Create intervals that overlap.""" - data1 = pd.Series({ - 'start': pd.Timestamp('2023-01-01'), - 'end': pd.Timestamp('2023-01-05'), - 'metric1': 10, - 'metric2': 20, - }) - - data2 = pd.Series({ - 'start': pd.Timestamp('2023-01-03'), - 'end': pd.Timestamp('2023-01-07'), - 'metric1': 15, - 'metric2': 25, - }) + data1 = pd.Series( + { + "start": pd.Timestamp("2023-01-01"), + "end": pd.Timestamp("2023-01-05"), + "metric1": 10, + "metric2": 20, + } + ) + + data2 = pd.Series( + { + "start": pd.Timestamp("2023-01-03"), + "end": pd.Timestamp("2023-01-07"), + "metric1": 15, + "metric2": 25, + } + ) interval1 = Interval.create( data=data1, - start_field='start', - end_field='end', - metric_fields=['metric1', 'metric2'] + start_field="start", + end_field="end", + metric_fields=["metric1", "metric2"], ) interval2 = Interval.create( data=data2, - start_field='start', - end_field='end', - metric_fields=['metric1', 'metric2'] + start_field="start", + end_field="end", + metric_fields=["metric1", "metric2"], ) return interval1, interval2 @@ -864,32 +976,36 @@ def overlapping_intervals(self): @pytest.fixture def equal_intervals(self): """Create intervals that are exactly equal.""" - data1 = pd.Series({ - 'start': pd.Timestamp('2023-01-01'), - 'end': pd.Timestamp('2023-01-05'), - 'metric1': 10, - 'metric2': 20, - }) - - data2 = pd.Series({ - 'start': pd.Timestamp('2023-01-01'), - 'end': pd.Timestamp('2023-01-05'), - 'metric1': 15, - 
'metric2': 25, - }) + data1 = pd.Series( + { + "start": pd.Timestamp("2023-01-01"), + "end": pd.Timestamp("2023-01-05"), + "metric1": 10, + "metric2": 20, + } + ) + + data2 = pd.Series( + { + "start": pd.Timestamp("2023-01-01"), + "end": pd.Timestamp("2023-01-05"), + "metric1": 15, + "metric2": 25, + } + ) interval1 = Interval.create( data=data1, - start_field='start', - end_field='end', - metric_fields=['metric1', 'metric2'] + start_field="start", + end_field="end", + metric_fields=["metric1", "metric2"], ) interval2 = Interval.create( data=data2, - start_field='start', - end_field='end', - metric_fields=['metric1', 'metric2'] + start_field="start", + end_field="end", + metric_fields=["metric1", "metric2"], ) return interval1, interval2 @@ -897,32 +1013,36 @@ def equal_intervals(self): @pytest.fixture def containing_intervals(self): """Create intervals where one contains the other.""" - data1 = pd.Series({ - 'start': pd.Timestamp('2023-01-01'), - 'end': pd.Timestamp('2023-01-10'), - 'metric1': 10, - 'metric2': 20, - }) - - data2 = pd.Series({ - 'start': pd.Timestamp('2023-01-03'), - 'end': pd.Timestamp('2023-01-07'), - 'metric1': 15, - 'metric2': 25, - }) + data1 = pd.Series( + { + "start": pd.Timestamp("2023-01-01"), + "end": pd.Timestamp("2023-01-10"), + "metric1": 10, + "metric2": 20, + } + ) + + data2 = pd.Series( + { + "start": pd.Timestamp("2023-01-03"), + "end": pd.Timestamp("2023-01-07"), + "metric1": 15, + "metric2": 25, + } + ) interval1 = Interval.create( data=data1, - start_field='start', - end_field='end', - metric_fields=['metric1', 'metric2'] + start_field="start", + end_field="end", + metric_fields=["metric1", "metric2"], ) interval2 = Interval.create( data=data2, - start_field='start', - end_field='end', - metric_fields=['metric1', 'metric2'] + start_field="start", + end_field="end", + metric_fields=["metric1", "metric2"], ) return interval1, interval2 @@ -935,7 +1055,9 @@ def test_overlaps_resolver(self, overlapping_intervals): relationship = transformer.detect_relationship() # Verify relationship is as expected - assert relationship == OverlapType.OVERLAPS, f"Expected OVERLAPS relationship, got {relationship}" + assert ( + relationship == OverlapType.OVERLAPS + ), f"Expected OVERLAPS relationship, got {relationship}" # Resolve and check result resolved = transformer.resolve_overlap() @@ -945,24 +1067,26 @@ def test_overlaps_resolver(self, overlapping_intervals): # Verify the segments have correct time boundaries # First segment: from interval1 start to interval2 start - assert resolved[0]['start'] == interval1.data['start'] - assert resolved[0]['end'] == interval2.data['start'] + assert resolved[0]["start"] == interval1.data["start"] + assert resolved[0]["end"] == interval2.data["start"] # Middle segment: overlap region - assert resolved[1]['start'] == interval2.data['start'] - assert resolved[1]['end'] == interval1.data['end'] + assert resolved[1]["start"] == interval2.data["start"] + assert resolved[1]["end"] == interval1.data["end"] # Last segment: from interval1 end to interval2 end - assert resolved[2]['start'] == interval1.data['end'] - assert resolved[2]['end'] == interval2.data['end'] + assert resolved[2]["start"] == interval1.data["end"] + assert resolved[2]["end"] == interval2.data["end"] # Check metric values in overlapping region # How metrics are merged depends on implementation, but they should have some value - assert 'metric1' in resolved[1] - assert 'metric2' in resolved[1] + assert "metric1" in resolved[1] + assert "metric2" in resolved[1] # Print 
actual values for reference - print(f"Metrics in overlapping segment: metric1={resolved[1]['metric1']}, metric2={resolved[1]['metric2']}") + print( + f"Metrics in overlapping segment: metric1={resolved[1]['metric1']}, metric2={resolved[1]['metric2']}" + ) def test_equals_resolver(self, equal_intervals): """Test the actual implementation of the EqualsResolver.""" @@ -972,7 +1096,9 @@ def test_equals_resolver(self, equal_intervals): relationship = transformer.detect_relationship() # Verify relationship is as expected - assert relationship == OverlapType.EQUALS, f"Expected EQUALS relationship, got {relationship}" + assert ( + relationship == OverlapType.EQUALS + ), f"Expected EQUALS relationship, got {relationship}" # Resolve and check result resolved = transformer.resolve_overlap() @@ -981,15 +1107,17 @@ def test_equals_resolver(self, equal_intervals): assert len(resolved) == 1, "EQUALS resolution should produce 1 segment" # Check time boundaries - assert resolved[0]['start'] == interval1.data['start'] - assert resolved[0]['end'] == interval1.data['end'] + assert resolved[0]["start"] == interval1.data["start"] + assert resolved[0]["end"] == interval1.data["end"] # Check metrics were merged - assert 'metric1' in resolved[0] - assert 'metric2' in resolved[0] + assert "metric1" in resolved[0] + assert "metric2" in resolved[0] # Print actual values for reference - print(f"Metrics in equals result: metric1={resolved[0]['metric1']}, metric2={resolved[0]['metric2']}") + print( + f"Metrics in equals result: metric1={resolved[0]['metric1']}, metric2={resolved[0]['metric2']}" + ) def test_contains_resolver(self, containing_intervals): """Test the actual implementation of the ContainsResolver.""" @@ -999,7 +1127,9 @@ def test_contains_resolver(self, containing_intervals): relationship = transformer.detect_relationship() # Verify relationship is as expected - assert relationship == OverlapType.CONTAINS, f"Expected CONTAINS relationship, got {relationship}" + assert ( + relationship == OverlapType.CONTAINS + ), f"Expected CONTAINS relationship, got {relationship}" # Resolve and check result resolved = transformer.resolve_overlap() @@ -1011,211 +1141,456 @@ def test_contains_resolver(self, containing_intervals): assert len(resolved) == 3, "CONTAINS resolution should produce 3 segments" # Check time boundaries - assert resolved[0]['start'] == interval1.data['start'] - assert resolved[0]['end'] == interval2.data['start'] + assert resolved[0]["start"] == interval1.data["start"] + assert resolved[0]["end"] == interval2.data["start"] - assert resolved[1]['start'] == interval2.data['start'] - assert resolved[1]['end'] == interval2.data['end'] + assert resolved[1]["start"] == interval2.data["start"] + assert resolved[1]["end"] == interval2.data["end"] - assert resolved[2]['start'] == interval2.data['end'] - assert resolved[2]['end'] == interval1.data['end'] + assert resolved[2]["start"] == interval2.data["end"] + assert resolved[2]["end"] == interval1.data["end"] # Check metrics in the middle segment (should be merged) - assert 'metric1' in resolved[1] - assert 'metric2' in resolved[1] + assert "metric1" in resolved[1] + assert "metric2" in resolved[1] # Print actual values for reference - print(f"Metrics in contained segment: metric1={resolved[1]['metric1']}, metric2={resolved[1]['metric2']}") + print( + f"Metrics in contained segment: metric1={resolved[1]['metric1']}, metric2={resolved[1]['metric2']}" + ) class TestStillValidLegacy: def 
test_resolve_overlap_where_interval_other_have_equivalent_metric_cols(self): - interval = Interval.create(pd.Series( - {"start": "2022-01-02", "end": "2022-01-03", "metric_1": 5, "metric_2": 10} - ), "start", "end", metric_fields=["metric_1", "metric_2"]) + interval = Interval.create( + pd.Series( + { + "start": "2022-01-02", + "end": "2022-01-03", + "metric_1": 5, + "metric_2": 10, + } + ), + "start", + "end", + metric_fields=["metric_1", "metric_2"], + ) - other = Interval.create(pd.Series( - {"start": "2022-01-01", "end": "2022-01-04", "metric_1": 5, "metric_2": 10} - ), "start", "end", metric_fields=["metric_1", "metric_2"]) + other = Interval.create( + pd.Series( + { + "start": "2022-01-01", + "end": "2022-01-04", + "metric_1": 5, + "metric_2": 10, + } + ), + "start", + "end", + metric_fields=["metric_1", "metric_2"], + ) resolver = IntervalTransformer(interval, other) result = resolver.resolve_overlap() assert len(result) == 1 def test_resolve_overlap_where_interval_is_contained_by_other(self): - interval = Interval.create(pd.Series( - {"start": "2022-01-02", "end": "2022-01-03", "metric_1": 5, "metric_2": 10} - ), "start", "end", metric_fields=["metric_1", "metric_2"]) + interval = Interval.create( + pd.Series( + { + "start": "2022-01-02", + "end": "2022-01-03", + "metric_1": 5, + "metric_2": 10, + } + ), + "start", + "end", + metric_fields=["metric_1", "metric_2"], + ) - other = Interval.create(pd.Series( - {"start": "2022-01-01", "end": "2022-01-04", "metric_1": 6, "metric_2": 11} - ), "start", "end", metric_fields=["metric_1", "metric_2"]) + other = Interval.create( + pd.Series( + { + "start": "2022-01-01", + "end": "2022-01-04", + "metric_1": 6, + "metric_2": 11, + } + ), + "start", + "end", + metric_fields=["metric_1", "metric_2"], + ) resolver = IntervalTransformer(interval, other) result = resolver.resolve_overlap() assert len(result) == 3 def test_resolve_overlap_where_shared_start_but_interval_ends_before_other(self): - interval = Interval.create(pd.Series( - {"start": "2022-01-01", "end": "2022-01-03", "metric_1": 5, "metric_2": 10} - ), "start", "end", metric_fields=["metric_1", "metric_2"]) + interval = Interval.create( + pd.Series( + { + "start": "2022-01-01", + "end": "2022-01-03", + "metric_1": 5, + "metric_2": 10, + } + ), + "start", + "end", + metric_fields=["metric_1", "metric_2"], + ) - other = Interval.create(pd.Series( - {"start": "2022-01-01", "end": "2022-01-04", "metric_1": 6, "metric_2": 11} - ), "start", "end", metric_fields=["metric_1", "metric_2"]) + other = Interval.create( + pd.Series( + { + "start": "2022-01-01", + "end": "2022-01-04", + "metric_1": 6, + "metric_2": 11, + } + ), + "start", + "end", + metric_fields=["metric_1", "metric_2"], + ) resolver = IntervalTransformer(interval, other) result = resolver.resolve_overlap() assert len(result) == 2 def test_resolve_overlap_where_shared_start_but_interval_ends_after_other(self): - interval = Interval.create(pd.Series( - {"start": "2022-01-01", "end": "2022-01-04", "metric_1": 5, "metric_2": 10} - ), "start", "end", metric_fields=["metric_1", "metric_2"]) + interval = Interval.create( + pd.Series( + { + "start": "2022-01-01", + "end": "2022-01-04", + "metric_1": 5, + "metric_2": 10, + } + ), + "start", + "end", + metric_fields=["metric_1", "metric_2"], + ) - other = Interval.create(pd.Series( - {"start": "2022-01-01", "end": "2022-01-03", "metric_1": 6, "metric_2": 11} - ), "start", "end", metric_fields=["metric_1", "metric_2"]) + other = Interval.create( + pd.Series( + { + "start": "2022-01-01", + 
"end": "2022-01-03", + "metric_1": 6, + "metric_2": 11, + } + ), + "start", + "end", + metric_fields=["metric_1", "metric_2"], + ) resolver = IntervalTransformer(interval, other) result = resolver.resolve_overlap() assert len(result) == 2 def test_resolve_overlap_where_shared_end_and_interval_starts_before_other(self): - interval = Interval.create(pd.Series( - {"start": "2022-01-01", "end": "2022-01-04", "metric_1": 5, "metric_2": 10} - ), "start", "end", metric_fields=["metric_1", "metric_2"]) + interval = Interval.create( + pd.Series( + { + "start": "2022-01-01", + "end": "2022-01-04", + "metric_1": 5, + "metric_2": 10, + } + ), + "start", + "end", + metric_fields=["metric_1", "metric_2"], + ) - other = Interval.create(pd.Series( - {"start": "2022-01-02", "end": "2022-01-04", "metric_1": 6, "metric_2": 11} - ), "start", "end", metric_fields=["metric_1", "metric_2"]) + other = Interval.create( + pd.Series( + { + "start": "2022-01-02", + "end": "2022-01-04", + "metric_1": 6, + "metric_2": 11, + } + ), + "start", + "end", + metric_fields=["metric_1", "metric_2"], + ) resolver = IntervalTransformer(interval, other) result = resolver.resolve_overlap() assert len(result) == 2 def test_resolve_overlap_where_shared_end_and_interval_starts_after_other(self): - interval = Interval.create(pd.Series( - {"start": "2022-01-02", "end": "2022-01-04", "metric_1": 5, "metric_2": 10} - ), "start", "end", metric_fields=["metric_1", "metric_2"]) + interval = Interval.create( + pd.Series( + { + "start": "2022-01-02", + "end": "2022-01-04", + "metric_1": 5, + "metric_2": 10, + } + ), + "start", + "end", + metric_fields=["metric_1", "metric_2"], + ) - other = Interval.create(pd.Series( - {"start": "2022-01-01", "end": "2022-01-04", "metric_1": 6, "metric_2": 11} - ), "start", "end", metric_fields=["metric_1", "metric_2"]) + other = Interval.create( + pd.Series( + { + "start": "2022-01-01", + "end": "2022-01-04", + "metric_1": 6, + "metric_2": 11, + } + ), + "start", + "end", + metric_fields=["metric_1", "metric_2"], + ) resolver = IntervalTransformer(interval, other) result = resolver.resolve_overlap() assert len(result) == 2 def test_resolve_overlap_shared_start_and_end(self): - interval = Interval.create(pd.Series( - {"start": "2022-01-01", "end": "2022-01-03", "metric_1": 5, "metric_2": 10} - ), "start", "end", metric_fields=["metric_1", "metric_2"]) + interval = Interval.create( + pd.Series( + { + "start": "2022-01-01", + "end": "2022-01-03", + "metric_1": 5, + "metric_2": 10, + } + ), + "start", + "end", + metric_fields=["metric_1", "metric_2"], + ) - other = Interval.create(pd.Series( - {"start": "2022-01-01", "end": "2022-01-03", "metric_1": 6, "metric_2": 11} - ), "start", "end", metric_fields=["metric_1", "metric_2"]) + other = Interval.create( + pd.Series( + { + "start": "2022-01-01", + "end": "2022-01-03", + "metric_1": 6, + "metric_2": 11, + } + ), + "start", + "end", + metric_fields=["metric_1", "metric_2"], + ) resolver = IntervalTransformer(interval, other) result = resolver.resolve_overlap() assert len(result) == 1 - def test_resolve_overlaps_where_interval_starts_first_partially_overlaps_other(self): - interval = Interval.create(pd.Series( - {"start": "2022-01-01", "end": "2022-01-03", "metric_1": 5, "metric_2": 10} - ), "start", "end", metric_fields=["metric_1", "metric_2"]) + def test_resolve_overlaps_where_interval_starts_first_partially_overlaps_other( + self, + ): + interval = Interval.create( + pd.Series( + { + "start": "2022-01-01", + "end": "2022-01-03", + "metric_1": 5, + "metric_2": 
10, + } + ), + "start", + "end", + metric_fields=["metric_1", "metric_2"], + ) - other = Interval.create(pd.Series( - {"start": "2022-01-02", "end": "2022-01-04", "metric_1": 6, "metric_2": 11} - ), "start", "end", metric_fields=["metric_1", "metric_2"]) + other = Interval.create( + pd.Series( + { + "start": "2022-01-02", + "end": "2022-01-04", + "metric_1": 6, + "metric_2": 11, + } + ), + "start", + "end", + metric_fields=["metric_1", "metric_2"], + ) resolver = IntervalTransformer(interval, other) result = resolver.resolve_overlap() assert len(result) == 3 - def test_resolve_overlaps_where_other_starts_first_partially_overlaps_interval(self): - interval = Interval.create(pd.Series( - {"start": "2022-01-01", "end": "2022-01-03", "metric_1": 6, "metric_2": 11} - ), "start", "end", metric_fields=["metric_1", "metric_2"]) - other = Interval.create(pd.Series( - {"start": "2022-01-02", "end": "2022-01-04", "metric_1": 5, "metric_2": 10} - ), "start", "end", metric_fields=["metric_1", "metric_2"]) + def test_resolve_overlaps_where_other_starts_first_partially_overlaps_interval( + self, + ): + interval = Interval.create( + pd.Series( + { + "start": "2022-01-01", + "end": "2022-01-03", + "metric_1": 6, + "metric_2": 11, + } + ), + "start", + "end", + metric_fields=["metric_1", "metric_2"], + ) + other = Interval.create( + pd.Series( + { + "start": "2022-01-02", + "end": "2022-01-04", + "metric_1": 5, + "metric_2": 10, + } + ), + "start", + "end", + metric_fields=["metric_1", "metric_2"], + ) resolver = IntervalTransformer(interval, other) result = resolver.resolve_overlap() assert len(result) == 3 def test_interval_transformer_where_different_series_id_col_names(self): - interval = Interval.create(pd.Series( - { - "start": "2022-01-01", - "end": "2022-01-03", - "series_1": 1, - "metric_1": 5, - "metric_2": 10, - } - ), "start", "end", ["series_1"], ["metric_1", "metric_2"]) + interval = Interval.create( + pd.Series( + { + "start": "2022-01-01", + "end": "2022-01-03", + "series_1": 1, + "metric_1": 5, + "metric_2": 10, + } + ), + "start", + "end", + ["series_1"], + ["metric_1", "metric_2"], + ) - other = Interval.create(pd.Series( - { - "start": "2022-01-02", - "end": "2022-01-04", - "wrong": 1, - "metric_1": 6, - "metric_2": 11, - } - ), "start", "end", ["wrong"], ["metric_1", "metric_2"]) + other = Interval.create( + pd.Series( + { + "start": "2022-01-02", + "end": "2022-01-04", + "wrong": 1, + "metric_1": 6, + "metric_2": 11, + } + ), + "start", + "end", + ["wrong"], + ["metric_1", "metric_2"], + ) with pytest.raises(ValueError): IntervalTransformer(interval, other) def test_interval_transformer_where_different_metric_col_names(self): - interval = Interval.create(pd.Series( - { - "start": "2022-01-01", - "end": "2022-01-03", - "series_1": 1, - "metric_1": 5, - "metric_2": 10, - } - ), "start", "end", ["series_1"], ["metric_1", "metric_2"]) + interval = Interval.create( + pd.Series( + { + "start": "2022-01-01", + "end": "2022-01-03", + "series_1": 1, + "metric_1": 5, + "metric_2": 10, + } + ), + "start", + "end", + ["series_1"], + ["metric_1", "metric_2"], + ) - other = Interval.create(pd.Series( - { - "start": "2022-01-02", - "end": "2022-01-04", - "series_1": 1, - "wrong": 6, - "metric_2": 11, - } - ), "start", "end", ["series_1"], ["wrong", "metric_2"]) + other = Interval.create( + pd.Series( + { + "start": "2022-01-02", + "end": "2022-01-04", + "series_1": 1, + "wrong": 6, + "metric_2": 11, + } + ), + "start", + "end", + ["series_1"], + ["wrong", "metric_2"], + ) with 
pytest.raises(ValueError): IntervalTransformer(interval, other) def test_resolve_overlaps_where_different_series_shapes(self): - interval = Interval.create(pd.Series( - { - "start": "2022-01-01", - "end": "2022-01-03", - "series_1": 1, - "metric_1": 5, - "metric_2": 10, - } - ), "start", "end", ["series_1"], ["metric_1", "metric_2"]) + interval = Interval.create( + pd.Series( + { + "start": "2022-01-01", + "end": "2022-01-03", + "series_1": 1, + "metric_1": 5, + "metric_2": 10, + } + ), + "start", + "end", + ["series_1"], + ["metric_1", "metric_2"], + ) - other = Interval.create(pd.Series( - {"start": "2022-01-02", "end": "2022-01-04", "series_1": 1, "metric_2": 11} - ), "start", "end", ["series_1"], ["metric_2"]) + other = Interval.create( + pd.Series( + { + "start": "2022-01-02", + "end": "2022-01-04", + "series_1": 1, + "metric_2": 11, + } + ), + "start", + "end", + ["series_1"], + ["metric_2"], + ) with pytest.raises(ValueError): IntervalTransformer(interval, other) def test_resolve_overlaps_where_no_overlaps(self): - interval = Interval.create(pd.Series( - {"start": "2022-01-01", "end": "2022-01-02", "metric_1": 5, "metric_2": 10} - ), "start", "end") + interval = Interval.create( + pd.Series( + { + "start": "2022-01-01", + "end": "2022-01-02", + "metric_1": 5, + "metric_2": 10, + } + ), + "start", + "end", + ) - other = Interval.create(pd.Series( - {"start": "2022-01-03", "end": "2022-01-04", "metric_1": 6, "metric_2": 11} - ), "start", "end") + other = Interval.create( + pd.Series( + { + "start": "2022-01-03", + "end": "2022-01-04", + "metric_1": 6, + "metric_2": 11, + } + ), + "start", + "end", + ) resolver = IntervalTransformer(interval, other) result = resolver.resolve_overlap() diff --git a/python/tests/intervals/overlap/types_tests.py b/python/tests/intervals/overlap/types_tests.py index 29a5ab53..8e00bff3 100644 --- a/python/tests/intervals/overlap/types_tests.py +++ b/python/tests/intervals/overlap/types_tests.py @@ -7,9 +7,20 @@ class TestOverlapTypeContract: def test_completeness(self): # Verify all expected types exist expected_types = [ - "METRICS_EQUIVALENT", "BEFORE", "MEETS", "OVERLAPS", "STARTS", - "DURING", "FINISHES", "EQUALS", "CONTAINS", "STARTED_BY", - "FINISHED_BY", "OVERLAPPED_BY", "MET_BY", "AFTER" + "METRICS_EQUIVALENT", + "BEFORE", + "MEETS", + "OVERLAPS", + "STARTS", + "DURING", + "FINISHES", + "EQUALS", + "CONTAINS", + "STARTED_BY", + "FINISHED_BY", + "OVERLAPPED_BY", + "MET_BY", + "AFTER", ] actual_types = [member.name for member in OverlapType] assert set(expected_types) == set(actual_types) diff --git a/python/tests/intervals/spark/functions_tests.py b/python/tests/intervals/spark/functions_tests.py index 350ba166..6b4889c1 100644 --- a/python/tests/intervals/spark/functions_tests.py +++ b/python/tests/intervals/spark/functions_tests.py @@ -3,8 +3,17 @@ import pandas as pd import pytest from pyspark.sql.types import ( - ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, - DecimalType, BooleanType, StringType, StructField, StructType + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + DecimalType, + BooleanType, + StringType, + StructField, + StructType, ) from tempo.intervals.spark.functions import is_metric_col, make_disjoint_wrap @@ -13,19 +22,28 @@ class TestIsMetricCol: """Tests for the is_metric_col function""" - @pytest.mark.parametrize("dtype", [ - ByteType(), ShortType(), IntegerType(), LongType(), - FloatType(), DoubleType(), DecimalType(10, 2), BooleanType() - ]) + @pytest.mark.parametrize( + 
"dtype", + [ + ByteType(), + ShortType(), + IntegerType(), + LongType(), + FloatType(), + DoubleType(), + DecimalType(10, 2), + BooleanType(), + ], + ) def test_numeric_types_return_true(self, dtype): """Test is_metric_col with various numeric types that should return True.""" col = StructField("test", dtype, True) assert is_metric_col(col) is True - @pytest.mark.parametrize("dtype", [ - StringType(), - StructType([StructField("nested", IntegerType(), True)]) - ]) + @pytest.mark.parametrize( + "dtype", + [StringType(), StructType([StructField("nested", IntegerType(), True)])], + ) def test_non_numeric_types_return_false(self, dtype): """Test is_metric_col with non-numeric types that should return False.""" col = StructField("test", dtype, True) @@ -42,17 +60,21 @@ def setup_fields(self): "start_field": "start", "end_field": "end", "series_fields": ["series_id"], - "metric_fields": ["value"] + "metric_fields": ["value"], } def test_empty_dataframe(self, setup_fields): """Test with an empty DataFrame.""" fields = setup_fields - empty_df = pd.DataFrame(columns=[fields["start_field"], fields["end_field"], "series_id", "value"]) + empty_df = pd.DataFrame( + columns=[fields["start_field"], fields["end_field"], "series_id", "value"] + ) disjoint_function = make_disjoint_wrap( - fields["start_field"], fields["end_field"], - fields["series_fields"], fields["metric_fields"] + fields["start_field"], + fields["end_field"], + fields["series_fields"], + fields["metric_fields"], ) result = disjoint_function(empty_df) @@ -66,13 +88,15 @@ def test_single_interval(self, setup_fields): fields["start_field"]: [1], fields["end_field"]: [5], "series_id": ["A"], - "value": [10] + "value": [10], } df = pd.DataFrame(data) disjoint_function = make_disjoint_wrap( - fields["start_field"], fields["end_field"], - fields["series_fields"], fields["metric_fields"] + fields["start_field"], + fields["end_field"], + fields["series_fields"], + fields["metric_fields"], ) result = disjoint_function(df) @@ -90,27 +114,31 @@ def test_non_overlapping_intervals(self, setup_fields): fields["start_field"]: [1, 6, 11], fields["end_field"]: [5, 10, 15], "series_id": ["A", "A", "A"], - "value": [10, 20, 30] + "value": [10, 20, 30], } df = pd.DataFrame(data) disjoint_function = make_disjoint_wrap( - fields["start_field"], fields["end_field"], - fields["series_fields"], fields["metric_fields"] + fields["start_field"], + fields["end_field"], + fields["series_fields"], + fields["metric_fields"], ) result = disjoint_function(df) # Since we're dealing with potentially complex interval transformations # just verify basic properties - assert len(result) == 3 # Number of intervals preserved for non-overlapping case + assert ( + len(result) == 3 + ) # Number of intervals preserved for non-overlapping case # Check all values are present for start in data[fields["start_field"]]: assert start in result[fields["start_field"]].values for end in data[fields["end_field"]]: assert end in result[fields["end_field"]].values - @patch('tempo.intervals.spark.functions.IntervalsUtils') - @patch('tempo.intervals.spark.functions.Interval') + @patch("tempo.intervals.spark.functions.IntervalsUtils") + @patch("tempo.intervals.spark.functions.Interval") def test_overlapping_intervals(self, mock_interval, mock_utils, setup_fields): """Test with overlapping intervals, using mocks to verify correct behavior.""" fields = setup_fields @@ -120,7 +148,7 @@ def test_overlapping_intervals(self, mock_interval, mock_utils, setup_fields): fields["start_field"]: [1, 3, 7], 
fields["end_field"]: [5, 8, 10], "series_id": ["A", "A", "A"], - "value": [10, 20, 30] + "value": [10, 20, 30], } df = pd.DataFrame(data) @@ -145,13 +173,15 @@ def test_overlapping_intervals(self, mock_interval, mock_utils, setup_fields): mock_utils_instance.add_as_disjoint.side_effect = [ pd.DataFrame([df.iloc[0]]), pd.DataFrame([df.iloc[0], df.iloc[1]]), - pd.DataFrame([df.iloc[0], df.iloc[1], df.iloc[2]]) + pd.DataFrame([df.iloc[0], df.iloc[1], df.iloc[2]]), ] # Execute the function disjoint_function = make_disjoint_wrap( - fields["start_field"], fields["end_field"], - fields["series_fields"], fields["metric_fields"] + fields["start_field"], + fields["end_field"], + fields["series_fields"], + fields["metric_fields"], ) result = disjoint_function(df) @@ -170,14 +200,16 @@ def test_dataframe_sorting(self, setup_fields): fields["start_field"]: [5, 1, 3, 3], fields["end_field"]: [10, 4, 8, 6], "series_id": ["A", "A", "A", "A"], - "value": [50, 10, 30, 20] + "value": [50, 10, 30, 20], } df = pd.DataFrame(data) # Create a simplified version of the make_disjoint_wrap function # that only performs the sorting step (copied from the actual implementation) def sort_intervals(pdf): - return pdf.sort_values(by=[fields["start_field"], fields["end_field"]]).reset_index(drop=True) + return pdf.sort_values( + by=[fields["start_field"], fields["end_field"]] + ).reset_index(drop=True) # Apply the sorting sorted_df = sort_intervals(df) @@ -204,7 +236,7 @@ def test_custom_field_names(self): start_field: [100, 200], end_field: [150, 250], "group": ["X", "Y"], - "measurement": [5.5, 7.7] + "measurement": [5.5, 7.7], } df = pd.DataFrame(data) @@ -235,13 +267,15 @@ def test_multiple_series_fields(self, setup_fields): fields["end_field"]: [3, 4], "region": ["North", "South"], "product": ["A", "B"], - "value": [100, 200] + "value": [100, 200], } df = pd.DataFrame(data) disjoint_function = make_disjoint_wrap( - fields["start_field"], fields["end_field"], - series_fields, fields["metric_fields"] + fields["start_field"], + fields["end_field"], + series_fields, + fields["metric_fields"], ) result = disjoint_function(df) @@ -266,13 +300,15 @@ def test_multiple_metric_fields(self, setup_fields): fields["end_field"]: [3, 4], "series_id": ["A", "B"], "sales": [100, 200], - "cost": [50, 100] + "cost": [50, 100], } df = pd.DataFrame(data) disjoint_function = make_disjoint_wrap( - fields["start_field"], fields["end_field"], - fields["series_fields"], metric_fields + fields["start_field"], + fields["end_field"], + fields["series_fields"], + metric_fields, ) result = disjoint_function(df) @@ -286,9 +322,11 @@ def test_multiple_metric_fields(self, setup_fields): for cost_val in data["cost"]: assert cost_val in result["cost"].values - @patch('tempo.intervals.spark.functions.IntervalsUtils') - @patch('tempo.intervals.spark.functions.Interval') - def test_complex_overlapping_scenario(self, mock_interval, mock_utils, setup_fields): + @patch("tempo.intervals.spark.functions.IntervalsUtils") + @patch("tempo.intervals.spark.functions.Interval") + def test_complex_overlapping_scenario( + self, mock_interval, mock_utils, setup_fields + ): """Test a more complex scenario with multiple overlapping intervals.""" fields = setup_fields @@ -297,7 +335,7 @@ def test_complex_overlapping_scenario(self, mock_interval, mock_utils, setup_fie fields["start_field"]: [1, 3, 2, 7, 6], fields["end_field"]: [5, 8, 6, 10, 9], "series_id": ["A", "A", "A", "A", "A"], - "value": [10, 20, 15, 30, 25] + "value": [10, 20, 15, 30, 25], } df = pd.DataFrame(data) 
@@ -315,12 +353,14 @@ def test_complex_overlapping_scenario(self, mock_interval, mock_utils, setup_fie mock_interval.create.side_effect = mock_interval_instances # Create a fake disjoint result that would represent the expected output - expected_disjoint = pd.DataFrame({ - fields["start_field"]: [1, 2, 3, 6, 7], - fields["end_field"]: [2, 3, 5, 7, 10], - "series_id": ["A", "A", "A", "A", "A"], - "value": [10, 15, 20, 25, 30] - }) + expected_disjoint = pd.DataFrame( + { + fields["start_field"]: [1, 2, 3, 6, 7], + fields["end_field"]: [2, 3, 5, 7, 10], + "series_id": ["A", "A", "A", "A", "A"], + "value": [10, 15, 20, 25, 30], + } + ) # Configure the mock to return our expected result progressively mock_utils_instance = MagicMock() @@ -331,8 +371,10 @@ def test_complex_overlapping_scenario(self, mock_interval, mock_utils, setup_fie # Execute disjoint_function = make_disjoint_wrap( - fields["start_field"], fields["end_field"], - fields["series_fields"], fields["metric_fields"] + fields["start_field"], + fields["end_field"], + fields["series_fields"], + fields["metric_fields"], ) result = disjoint_function(df) @@ -358,7 +400,7 @@ def setup_fields(self): "start_field": "start", "end_field": "end", "series_fields": ["series_id"], - "metric_fields": ["value"] + "metric_fields": ["value"], } def test_non_overlapping_intervals_e2e(self, setup_fields): @@ -368,13 +410,15 @@ def test_non_overlapping_intervals_e2e(self, setup_fields): fields["start_field"]: [1, 6, 11], fields["end_field"]: [5, 10, 15], "series_id": ["A", "A", "A"], - "value": [10, 20, 30] + "value": [10, 20, 30], } df = pd.DataFrame(data) disjoint_function = make_disjoint_wrap( - fields["start_field"], fields["end_field"], - fields["series_fields"], fields["metric_fields"] + fields["start_field"], + fields["end_field"], + fields["series_fields"], + fields["metric_fields"], ) # Call the function and check results @@ -395,13 +439,15 @@ def test_simple_overlap_e2e(self, setup_fields): fields["start_field"]: [1, 3], fields["end_field"]: [5, 7], "series_id": ["A", "A"], - "value": [10, 20] + "value": [10, 20], } df = pd.DataFrame(data) disjoint_function = make_disjoint_wrap( - fields["start_field"], fields["end_field"], - fields["series_fields"], fields["metric_fields"] + fields["start_field"], + fields["end_field"], + fields["series_fields"], + fields["metric_fields"], ) result = disjoint_function(df) @@ -424,13 +470,15 @@ def test_complex_overlap_e2e(self, setup_fields): fields["start_field"]: [1, 2, 4, 6], fields["end_field"]: [5, 7, 8, 9], "series_id": ["A", "A", "A", "A"], - "value": [10, 20, 30, 40] + "value": [10, 20, 30, 40], } df = pd.DataFrame(data) disjoint_function = make_disjoint_wrap( - fields["start_field"], fields["end_field"], - fields["series_fields"], fields["metric_fields"] + fields["start_field"], + fields["end_field"], + fields["series_fields"], + fields["metric_fields"], ) result = disjoint_function(df) diff --git a/python/tests/joins/as_of_join_tests.py b/python/tests/joins/as_of_join_tests.py index 466e5375..c8bd7d6f 100644 --- a/python/tests/joins/as_of_join_tests.py +++ b/python/tests/joins/as_of_join_tests.py @@ -14,8 +14,7 @@ def spark(): """Create a Spark session for tests.""" spark = ( - SparkSession.builder - .appName("as_of_join_tests") + SparkSession.builder.appName("as_of_join_tests") .master("local[*]") .config("spark.sql.shuffle.partitions", "2") .config("spark.sql.adaptive.enabled", "false") @@ -35,7 +34,7 @@ def test_data(): "tests", "unit_test_data", "joins", - "as_of_join_tests.json" + 
"as_of_join_tests.json", ) # Get absolute path @@ -66,15 +65,11 @@ def create_tsdf_from_data(spark, data_dict): df, ts_col=tsdf_data["ts_col"], series_ids=tsdf_data["series_ids"], - ts_fmt=tsdf_data.get("ts_fmt") + ts_fmt=tsdf_data.get("ts_fmt"), ) else: # Create TSDF - return TSDF( - df, - ts_col=tsdf_data["ts_col"], - series_ids=tsdf_data["series_ids"] - ) + return TSDF(df, ts_col=tsdf_data["ts_col"], series_ids=tsdf_data["series_ids"]) def create_df_from_data(spark, data_dict): @@ -92,7 +87,12 @@ def create_df_from_data(spark, data_dict): # Handle nested columns if "." in col: parts = col.split(".") - df = df.withColumn(parts[0], df[parts[0]].cast("struct")) + df = df.withColumn( + parts[0], + df[parts[0]].cast( + "struct" + ), + ) else: df = df.withColumn(col, df[col].cast("timestamp")) @@ -180,7 +180,9 @@ class TestUnionSortFilterJoin: def test_simple_ts(self, spark, test_data): """Test union-sort-filter join with simple timestamp data.""" # Get test data - scenario_data = test_data["AsOfJoinTest"]["test_union_sort_filter_join_simple_ts"] + scenario_data = test_data["AsOfJoinTest"][ + "test_union_sort_filter_join_simple_ts" + ] # Set up dataframes left_tsdf = create_tsdf_from_data(spark, scenario_data["left"]) @@ -228,7 +230,9 @@ def test_nanos(self, spark, test_data): def test_null_lead(self, spark, test_data): """Test union-sort-filter join handles NULL lead values correctly.""" # Get test data - scenario_data = test_data["AsOfJoinTest"]["test_union_sort_filter_join_null_lead"] + scenario_data = test_data["AsOfJoinTest"][ + "test_union_sort_filter_join_null_lead" + ] # Set up dataframes left_tsdf = create_tsdf_from_data(spark, scenario_data["left"]) @@ -243,4 +247,4 @@ def test_null_lead(self, spark, test_data): assert joined_df.count() == expected_tsdf.df.count() # Verify all 5 rows are present - assert joined_df.count() == 5 \ No newline at end of file + assert joined_df.count() == 5 diff --git a/python/tests/joins/edge_cases_coverage_tests.py b/python/tests/joins/edge_cases_coverage_tests.py index 1b292a36..c24f7695 100644 --- a/python/tests/joins/edge_cases_coverage_tests.py +++ b/python/tests/joins/edge_cases_coverage_tests.py @@ -23,8 +23,7 @@ def test_broadcast_join_no_series_ids(self): """Test broadcast join with no series_ids (single series case).""" # Line 328: Single series join path left_data = [ - (datetime(2024, 1, 1, 10, i), f"trade_{i}", float(i)) - for i in range(10) + (datetime(2024, 1, 1, 10, i), f"trade_{i}", float(i)) for i in range(10) ] left_df = self.spark.createDataFrame( left_data, ["timestamp", "trade_id", "volume"] @@ -32,9 +31,7 @@ def test_broadcast_join_no_series_ids(self): # No series_ids specified left_tsdf = TSDF(left_df, ts_col="timestamp") - right_data = [ - (datetime(2024, 1, 1, 10, i), 100.0 + i) for i in range(0, 10, 2) - ] + right_data = [(datetime(2024, 1, 1, 10, i), 100.0 + i) for i in range(0, 10, 2)] right_df = self.spark.createDataFrame(right_data, ["timestamp", "price"]) right_tsdf = TSDF(right_df, ts_col="timestamp") @@ -48,9 +45,7 @@ def test_broadcast_join_no_series_ids(self): def test_skew_join_no_series_ids(self): """Test skew join with no series_ids returns empty skewed keys.""" # Line 706: No series keys to be skewed - left_data = [ - (datetime(2024, 1, 1, 10, i), f"trade_{i}") for i in range(50) - ] + left_data = [(datetime(2024, 1, 1, 10, i), f"trade_{i}") for i in range(50)] left_df = self.spark.createDataFrame(left_data, ["timestamp", "trade_id"]) left_tsdf = TSDF(left_df, ts_col="timestamp") @@ -136,9 +131,7 @@ def 
test_skipnulls_and_tolerance_combined(self): right_tsdf = TSDF(right_df, ts_col="timestamp", series_ids=["symbol"]) # Tolerance of 120 seconds (2 minutes), with skipNulls - joiner = SkewAsOfJoiner( - self.spark, skipNulls=True, tolerance=120 - ) + joiner = SkewAsOfJoiner(self.spark, skipNulls=True, tolerance=120) result_df, _ = joiner(left_tsdf, right_tsdf) # Both filters should be applied @@ -165,9 +158,7 @@ def test_strategy_selection_error_fallback(self): # This should not raise an exception even if selection has issues try: - strategy = choose_as_of_join_strategy( - left_tsdf, right_tsdf, self.spark - ) + strategy = choose_as_of_join_strategy(left_tsdf, right_tsdf, self.spark) # Should return some joiner (possibly fallback) self.assertIsNotNone(strategy) except Exception as e: diff --git a/python/tests/joins/helper_functions_tests.py b/python/tests/joins/helper_functions_tests.py index 6322c02a..a9987738 100644 --- a/python/tests/joins/helper_functions_tests.py +++ b/python/tests/joins/helper_functions_tests.py @@ -21,7 +21,7 @@ class TestGetSparkPlan(unittest.TestCase): """Test get_spark_plan helper function.""" - @patch('tempo.joins.strategies.SparkSession') + @patch("tempo.joins.strategies.SparkSession") def test_get_spark_plan_success(self, mock_spark_class): """Test successful Spark plan extraction.""" # Mock SparkSession and DataFrame @@ -42,7 +42,7 @@ def test_get_spark_plan_success(self, mock_spark_class): self.assertIsInstance(result, str) self.assertIn("Statistics", result) - @patch('tempo.joins.strategies.SparkSession') + @patch("tempo.joins.strategies.SparkSession") def test_get_spark_plan_with_temp_view(self, mock_spark_class): """Test that get_spark_plan creates temp view with unique name.""" mock_spark = Mock(spec=SparkSession) @@ -67,7 +67,7 @@ def test_get_spark_plan_with_temp_view(self, mock_spark_class): class TestGetBytesFromPlan(unittest.TestCase): """Test get_bytes_from_plan helper function.""" - @patch('tempo.joins.strategies.get_spark_plan') + @patch("tempo.joins.strategies.get_spark_plan") def test_parse_gib_units(self, mock_get_plan): """Test parsing size in GiB units.""" mock_spark = Mock() @@ -82,7 +82,7 @@ def test_parse_gib_units(self, mock_get_plan): expected = 2.5 * 1024 * 1024 * 1024 self.assertEqual(result, expected) - @patch('tempo.joins.strategies.get_spark_plan') + @patch("tempo.joins.strategies.get_spark_plan") def test_parse_mib_units(self, mock_get_plan): """Test parsing size in MiB units.""" mock_spark = Mock() @@ -96,7 +96,7 @@ def test_parse_mib_units(self, mock_get_plan): expected = 128.0 * 1024 * 1024 self.assertEqual(result, expected) - @patch('tempo.joins.strategies.get_spark_plan') + @patch("tempo.joins.strategies.get_spark_plan") def test_parse_kib_units(self, mock_get_plan): """Test parsing size in KiB units.""" mock_spark = Mock() @@ -110,7 +110,7 @@ def test_parse_kib_units(self, mock_get_plan): expected = 512.0 * 1024 self.assertEqual(result, expected) - @patch('tempo.joins.strategies.get_spark_plan') + @patch("tempo.joins.strategies.get_spark_plan") def test_parse_bytes_units(self, mock_get_plan): """Test parsing size in bytes (no unit suffix).""" mock_spark = Mock() @@ -124,7 +124,7 @@ def test_parse_bytes_units(self, mock_get_plan): expected = 1024.0 self.assertEqual(result, expected) - @patch('tempo.joins.strategies.get_spark_plan') + @patch("tempo.joins.strategies.get_spark_plan") def test_parse_no_size_returns_inf(self, mock_get_plan): """Test that missing sizeInBytes returns infinity.""" mock_spark = Mock() @@ -136,9 +136,9 @@ 
def test_parse_no_size_returns_inf(self, mock_get_plan): result = get_bytes_from_plan(mock_df, mock_spark) # Should return infinity to avoid broadcast - self.assertEqual(result, float('inf')) + self.assertEqual(result, float("inf")) - @patch('tempo.joins.strategies.get_spark_plan') + @patch("tempo.joins.strategies.get_spark_plan") def test_parse_error_returns_inf(self, mock_get_plan): """Test that parsing errors return infinity.""" mock_spark = Mock() @@ -150,9 +150,9 @@ def test_parse_error_returns_inf(self, mock_get_plan): result = get_bytes_from_plan(mock_df, mock_spark) # Should return infinity on error - self.assertEqual(result, float('inf')) + self.assertEqual(result, float("inf")) - @patch('tempo.joins.strategies.get_spark_plan') + @patch("tempo.joins.strategies.get_spark_plan") def test_parse_decimal_sizes(self, mock_get_plan): """Test parsing sizes with decimal points.""" mock_spark = Mock() @@ -166,7 +166,7 @@ def test_parse_decimal_sizes(self, mock_get_plan): expected = 3.14159 * 1024 * 1024 * 1024 self.assertAlmostEqual(result, expected, places=2) - @patch('tempo.joins.strategies.get_spark_plan') + @patch("tempo.joins.strategies.get_spark_plan") def test_parse_size_with_parenthesis(self, mock_get_plan): """Test parsing size when plan ends with parenthesis.""" mock_spark = Mock() @@ -180,7 +180,7 @@ def test_parse_size_with_parenthesis(self, mock_get_plan): expected = 64.0 * 1024 * 1024 self.assertEqual(result, expected) - @patch('tempo.joins.strategies.get_spark_plan') + @patch("tempo.joins.strategies.get_spark_plan") def test_parse_zero_size(self, mock_get_plan): """Test parsing zero size.""" mock_spark = Mock() @@ -192,7 +192,7 @@ def test_parse_zero_size(self, mock_get_plan): self.assertEqual(result, 0.0) - @patch('tempo.joins.strategies.get_spark_plan') + @patch("tempo.joins.strategies.get_spark_plan") def test_parse_very_large_size(self, mock_get_plan): """Test parsing very large size (TiB-scale).""" mock_spark = Mock() @@ -207,7 +207,7 @@ def test_parse_very_large_size(self, mock_get_plan): expected = 1024.0 * 1024 * 1024 * 1024 self.assertEqual(result, expected) - @patch('tempo.joins.strategies.get_spark_plan') + @patch("tempo.joins.strategies.get_spark_plan") def test_parse_integer_size(self, mock_get_plan): """Test parsing size as integer (no decimal point).""" mock_spark = Mock() @@ -224,7 +224,7 @@ def test_parse_integer_size(self, mock_get_plan): class TestHelperFunctionIntegration(unittest.TestCase): """Integration tests for helper functions.""" - @patch('tempo.joins.strategies.get_bytes_from_plan') + @patch("tempo.joins.strategies.get_bytes_from_plan") def test_size_estimation_in_strategy_selection(self, mock_get_bytes): """Test that size estimation is used in strategy selection.""" from tempo.joins.strategies import choose_as_of_join_strategy @@ -241,20 +241,17 @@ def test_size_estimation_in_strategy_selection(self, mock_get_bytes): right_tsdf.df = Mock() # Call strategy selection - strategy = choose_as_of_join_strategy( - left_tsdf, - right_tsdf, - mock_spark - ) + strategy = choose_as_of_join_strategy(left_tsdf, right_tsdf, mock_spark) # Should select BroadcastAsOfJoiner for small data from tempo.joins.strategies import BroadcastAsOfJoiner + self.assertIsInstance(strategy, BroadcastAsOfJoiner) # Verify size estimation was called self.assertEqual(mock_get_bytes.call_count, 2) - @patch('tempo.joins.strategies.get_bytes_from_plan') + @patch("tempo.joins.strategies.get_bytes_from_plan") def test_size_estimation_failure_falls_back(self, mock_get_bytes): """Test that 
strategy selection handles size estimation failure.""" from tempo.joins.strategies import choose_as_of_join_strategy @@ -270,13 +267,10 @@ def test_size_estimation_failure_falls_back(self, mock_get_bytes): right_tsdf.df = Mock() # Should fall back to UnionSortFilterAsOfJoiner - strategy = choose_as_of_join_strategy( - left_tsdf, - right_tsdf, - mock_spark - ) + strategy = choose_as_of_join_strategy(left_tsdf, right_tsdf, mock_spark) from tempo.joins.strategies import UnionSortFilterAsOfJoiner + self.assertIsInstance(strategy, UnionSortFilterAsOfJoiner) diff --git a/python/tests/joins/skew_asof_joiner_tests.py b/python/tests/joins/skew_asof_joiner_tests.py index d8f705ac..76ed9540 100644 --- a/python/tests/joins/skew_asof_joiner_tests.py +++ b/python/tests/joins/skew_asof_joiner_tests.py @@ -11,7 +11,14 @@ import pyspark.sql.functions as F from pyspark.sql import SparkSession, DataFrame -from pyspark.sql.types import StructType, StructField, StringType, TimestampType, DoubleType, IntegerType +from pyspark.sql.types import ( + StructType, + StructField, + StringType, + TimestampType, + DoubleType, + IntegerType, +) from tempo.joins.strategies import SkewAsOfJoiner, _detectSignificantSkew from tempo.tsdf import TSDF @@ -41,10 +48,10 @@ def create_skewed_data(self, skew_type="key", num_records=10000): if skew_type in ["key", "both"]: # 80% of data for key "A", 15% for "B", 5% for others key_distribution = ( - ["A"] * int(num_records * 0.8) + - ["B"] * int(num_records * 0.15) + - ["C"] * int(num_records * 0.03) + - ["D"] * int(num_records * 0.02) + ["A"] * int(num_records * 0.8) + + ["B"] * int(num_records * 0.15) + + ["C"] * int(num_records * 0.03) + + ["D"] * int(num_records * 0.02) ) else: # Even key distribution @@ -74,12 +81,14 @@ def create_skewed_data(self, skew_type="key", num_records=10000): for i in range(num_records) ] - left_schema = StructType([ - StructField("symbol", StringType(), False), - StructField("timestamp", TimestampType(), False), - StructField("left_value", StringType(), True), - StructField("left_metric", DoubleType(), True), - ]) + left_schema = StructType( + [ + StructField("symbol", StringType(), False), + StructField("timestamp", TimestampType(), False), + StructField("left_value", StringType(), True), + StructField("left_metric", DoubleType(), True), + ] + ) left_df = self.spark.createDataFrame(left_data, schema=left_schema) left_tsdf = TSDF(left_df, ts_col="timestamp", series_ids=["symbol"]) @@ -91,12 +100,14 @@ def create_skewed_data(self, skew_type="key", num_records=10000): for i in range(right_records) ] - right_schema = StructType([ - StructField("symbol", StringType(), False), - StructField("timestamp", TimestampType(), False), - StructField("right_value", StringType(), True), - StructField("price", DoubleType(), True), - ]) + right_schema = StructType( + [ + StructField("symbol", StringType(), False), + StructField("timestamp", TimestampType(), False), + StructField("right_value", StringType(), True), + StructField("price", DoubleType(), True), + ] + ) right_df = self.spark.createDataFrame(right_data, schema=right_schema) right_tsdf = TSDF(right_df, ts_col="timestamp", series_ids=["symbol"]) @@ -106,35 +117,37 @@ def create_skewed_data(self, skew_type="key", num_records=10000): def test_skew_detection(self): """Test that skew detection correctly identifies skewed data.""" # Create skewed data - left_skewed, right_skewed = self.create_skewed_data(skew_type="key", num_records=1000) + left_skewed, right_skewed = self.create_skewed_data( + skew_type="key", 
num_records=1000 + ) # Test skew detection self.assertTrue( _detectSignificantSkew(left_skewed, right_skewed, threshold=0.3), - "Should detect key skew" + "Should detect key skew", ) # Create non-skewed data - left_normal, right_normal = self.create_skewed_data(skew_type="none", num_records=1000) + left_normal, right_normal = self.create_skewed_data( + skew_type="none", num_records=1000 + ) # Should not detect skew in even distribution self.assertFalse( _detectSignificantSkew(left_normal, right_normal, threshold=0.3), - "Should not detect skew in even distribution" + "Should not detect skew in even distribution", ) - @patch('tempo.joins.strategies.logger') + @patch("tempo.joins.strategies.logger") def test_aqe_configuration(self, mock_logger): """Test that AQE is properly configured.""" - joiner = SkewAsOfJoiner( - spark=self.spark, - left_prefix="", - right_prefix="right" - ) + joiner = SkewAsOfJoiner(spark=self.spark, left_prefix="", right_prefix="right") # Check that AQE settings were configured self.assertEqual(self.spark.conf.get("spark.sql.adaptive.enabled"), "true") - self.assertEqual(self.spark.conf.get("spark.sql.adaptive.skewJoin.enabled"), "true") + self.assertEqual( + self.spark.conf.get("spark.sql.adaptive.skewJoin.enabled"), "true" + ) # Check that configuration was logged mock_logger.info.assert_any_call("Configured AQE for skew handling") @@ -142,14 +155,16 @@ def test_aqe_configuration(self, mock_logger): def test_key_skewed_join(self): """Test join with heavily skewed keys (80% in one key).""" # Create data with key skew - left_tsdf, right_tsdf = self.create_skewed_data(skew_type="key", num_records=1000) + left_tsdf, right_tsdf = self.create_skewed_data( + skew_type="key", num_records=1000 + ) # Create joiner joiner = SkewAsOfJoiner( spark=self.spark, left_prefix="", right_prefix="right", - skew_threshold=0.2 # 20% threshold + skew_threshold=0.2, # 20% threshold ) # Perform join @@ -157,7 +172,9 @@ def test_key_skewed_join(self): # Verify results self.assertIsNotNone(result_df) - self.assertEqual(result_df.count(), left_tsdf.df.count(), "All left rows should be preserved") + self.assertEqual( + result_df.count(), left_tsdf.df.count(), "All left rows should be preserved" + ) # Check that skewed key "A" was processed correctly key_a_results = result_df.filter(F.col("symbol") == "A") @@ -166,27 +183,24 @@ def test_key_skewed_join(self): # Verify correct as-of semantics # Each left row should get the most recent right row sample_row = result_df.filter( - (F.col("symbol") == "A") & - F.col("timestamp").isNotNull() + (F.col("symbol") == "A") & F.col("timestamp").isNotNull() ).first() if sample_row and "right_timestamp" in result_df.columns: self.assertLessEqual( sample_row["right_timestamp"], sample_row["timestamp"], - "Right timestamp should be <= left timestamp" + "Right timestamp should be <= left timestamp", ) def test_temporal_skewed_join(self): """Test join with temporal skew (90% of data in last 10% of time range).""" # Create data with temporal skew - left_tsdf, right_tsdf = self.create_skewed_data(skew_type="temporal", num_records=1000) + left_tsdf, right_tsdf = self.create_skewed_data( + skew_type="temporal", num_records=1000 + ) # Create joiner - joiner = SkewAsOfJoiner( - spark=self.spark, - left_prefix="", - right_prefix="right" - ) + joiner = SkewAsOfJoiner(spark=self.spark, left_prefix="", right_prefix="right") # Perform join result_df, result_schema = joiner(left_tsdf, right_tsdf) @@ -211,13 +225,15 @@ def test_temporal_skewed_join(self): self.assertGreater( 
dense_period_count / result_df.count(), 0.8, - "Most data should be in the dense time period" + "Most data should be in the dense time period", ) def test_salted_join_for_extreme_skew(self): """Test salted join strategy for extreme skew.""" # Create extremely skewed data (95% in one key) - left_tsdf, right_tsdf = self.create_skewed_data(skew_type="key", num_records=500) + left_tsdf, right_tsdf = self.create_skewed_data( + skew_type="key", num_records=500 + ) # Create joiner with salting enabled joiner = SkewAsOfJoiner( @@ -226,7 +242,7 @@ def test_salted_join_for_extreme_skew(self): right_prefix="right", enable_salting=True, salt_buckets=5, - skew_threshold=0.1 # Low threshold to trigger skew handling + skew_threshold=0.1, # Low threshold to trigger skew handling ) # Perform join @@ -241,14 +257,16 @@ def test_salted_join_for_extreme_skew(self): def test_backward_compatibility_with_tspartitionval(self): """Test backward compatibility with tsPartitionVal parameter.""" - left_tsdf, right_tsdf = self.create_skewed_data(skew_type="key", num_records=100) + left_tsdf, right_tsdf = self.create_skewed_data( + skew_type="key", num_records=100 + ) # Create joiner with deprecated tsPartitionVal joiner = SkewAsOfJoiner( spark=self.spark, left_prefix="", right_prefix="right", - tsPartitionVal=300 # 5 minutes + tsPartitionVal=300, # 5 minutes ) # Should still work @@ -259,21 +277,22 @@ def test_backward_compatibility_with_tspartitionval(self): def test_skip_nulls_with_skewed_data(self): """Test skipNulls functionality with skewed data.""" # Create skewed data and add some nulls - left_tsdf, right_tsdf = self.create_skewed_data(skew_type="key", num_records=100) + left_tsdf, right_tsdf = self.create_skewed_data( + skew_type="key", num_records=100 + ) # Add nulls to right data right_with_nulls = right_tsdf.df.withColumn( "price", - F.when(F.col("symbol") == "B", F.lit(None)).otherwise(F.col("price")) + F.when(F.col("symbol") == "B", F.lit(None)).otherwise(F.col("price")), + ) + right_tsdf_nulls = TSDF( + right_with_nulls, ts_col="timestamp", series_ids=["symbol"] ) - right_tsdf_nulls = TSDF(right_with_nulls, ts_col="timestamp", series_ids=["symbol"]) # Test with skipNulls=True joiner_skip = SkewAsOfJoiner( - spark=self.spark, - left_prefix="", - right_prefix="right", - skipNulls=True + spark=self.spark, left_prefix="", right_prefix="right", skipNulls=True ) result_skip, _ = joiner_skip(left_tsdf, right_tsdf_nulls) @@ -281,18 +300,22 @@ def test_skip_nulls_with_skewed_data(self): # Check that rows with null values are filtered appropriately b_results = result_skip.filter(F.col("symbol") == "B") non_null_b = b_results.filter(F.col("price").isNotNull()) - self.assertEqual(non_null_b.count(), 0, "Symbol B should have no matches due to null prices") + self.assertEqual( + non_null_b.count(), 0, "Symbol B should have no matches due to null prices" + ) def test_tolerance_with_skewed_data(self): """Test tolerance filtering with skewed data.""" - left_tsdf, right_tsdf = self.create_skewed_data(skew_type="key", num_records=100) + left_tsdf, right_tsdf = self.create_skewed_data( + skew_type="key", num_records=100 + ) # Create joiner with tolerance joiner = SkewAsOfJoiner( spark=self.spark, left_prefix="", right_prefix="right", - tolerance=60 # 1 minute tolerance + tolerance=60, # 1 minute tolerance ) result_df, _ = joiner(left_tsdf, right_tsdf) @@ -302,26 +325,32 @@ def test_tolerance_with_skewed_data(self): # Check that matches outside tolerance have null right columns # This would require examining specific 
timestamp differences - @patch('tempo.joins.strategies.get_bytes_from_plan') + @patch("tempo.joins.strategies.get_bytes_from_plan") def test_strategy_selection_for_skewed_keys(self, mock_get_bytes): """Test that different strategies are selected based on data characteristics.""" # Mock size detection for broadcast decision mock_get_bytes.return_value = 50 * 1024 * 1024 # 50MB - too large for broadcast # Create heavily skewed data - left_tsdf, right_tsdf = self.create_skewed_data(skew_type="key", num_records=100) + left_tsdf, right_tsdf = self.create_skewed_data( + skew_type="key", num_records=100 + ) joiner = SkewAsOfJoiner( spark=self.spark, left_prefix="", right_prefix="right", skew_threshold=0.1, # Low threshold - enable_salting=False + enable_salting=False, ) # Mock to track method calls - with patch.object(joiner, '_skewSeparatedJoin', wraps=joiner._skewSeparatedJoin) as mock_separated: - with patch.object(joiner, '_standardAsOfJoin', wraps=joiner._standardAsOfJoin) as mock_standard: + with patch.object( + joiner, "_skewSeparatedJoin", wraps=joiner._skewSeparatedJoin + ) as mock_separated: + with patch.object( + joiner, "_standardAsOfJoin", wraps=joiner._standardAsOfJoin + ) as mock_standard: result_df, _ = joiner(left_tsdf, right_tsdf) # Should use separated join for skewed keys @@ -333,13 +362,12 @@ def test_strategy_selection_for_skewed_keys(self, mock_get_bytes): def test_mixed_key_and_temporal_skew(self): """Test handling of both key and temporal skew simultaneously.""" # Create data with both types of skew - left_tsdf, right_tsdf = self.create_skewed_data(skew_type="both", num_records=500) + left_tsdf, right_tsdf = self.create_skewed_data( + skew_type="both", num_records=500 + ) joiner = SkewAsOfJoiner( - spark=self.spark, - left_prefix="", - right_prefix="right", - skew_threshold=0.15 + spark=self.spark, left_prefix="", right_prefix="right", skew_threshold=0.15 ) result_df, result_schema = joiner(left_tsdf, right_tsdf) @@ -354,8 +382,11 @@ def test_mixed_key_and_temporal_skew(self): # Verify both skew types are handled # Check key skew handling key_a_count = result_df.filter(F.col("symbol") == "A").count() - self.assertGreater(key_a_count, result_df.count() * 0.7, - "Skewed key should have majority of results") + self.assertGreater( + key_a_count, + result_df.count() * 0.7, + "Skewed key should have majority of results", + ) # Check temporal skew handling max_time = result_df.agg(F.max("timestamp")).collect()[0][0] @@ -364,21 +395,22 @@ def test_mixed_key_and_temporal_skew(self): time_range = (max_time - min_time).total_seconds() cutoff_time = min_time + timedelta(seconds=time_range * 0.9) dense_count = result_df.filter(F.col("timestamp") >= cutoff_time).count() - self.assertGreater(dense_count, result_df.count() * 0.8, - "Dense time period should be handled correctly") - + self.assertGreater( + dense_count, + result_df.count() * 0.8, + "Dense time period should be handled correctly", + ) def test_no_skew_baseline(self): """Test baseline case with evenly distributed data (no skew).""" # Create evenly distributed data - left_tsdf, right_tsdf = self.create_skewed_data(skew_type="none", num_records=400) + left_tsdf, right_tsdf = self.create_skewed_data( + skew_type="none", num_records=400 + ) # Joiner should handle non-skewed data efficiently joiner = SkewAsOfJoiner( - spark=self.spark, - left_prefix="", - right_prefix="right", - skew_threshold=0.2 + spark=self.spark, left_prefix="", right_prefix="right", skew_threshold=0.2 ) result_df, _ = joiner(left_tsdf, right_tsdf) @@ -393,8 
+425,9 @@ def test_no_skew_baseline(self): max_count = max(counts) min_count = min(counts) # Ratio should be close to 1 for even distribution - self.assertLess(max_count / min_count, 1.5, - "Keys should have roughly equal counts") + self.assertLess( + max_count / min_count, 1.5, "Keys should have roughly equal counts" + ) def test_extreme_key_skew_95_percent(self): """Test extreme key skew with 95% of data in one key.""" @@ -402,26 +435,32 @@ def test_extreme_key_skew_95_percent(self): base_time = datetime(2024, 1, 1) # 95% for key "EXTREME", 5% for others - left_data = [("EXTREME", base_time + timedelta(minutes=i), f"val_{i}", float(i)) - for i in range(950)] - left_data += [("OTHER", base_time + timedelta(minutes=i), f"val_{i}", float(i)) - for i in range(50)] + left_data = [ + ("EXTREME", base_time + timedelta(minutes=i), f"val_{i}", float(i)) + for i in range(950) + ] + left_data += [ + ("OTHER", base_time + timedelta(minutes=i), f"val_{i}", float(i)) + for i in range(50) + ] left_df = self.spark.createDataFrame( - left_data, - ["symbol", "timestamp", "value", "metric"] + left_data, ["symbol", "timestamp", "value", "metric"] ) left_tsdf = TSDF(left_df, ts_col="timestamp", series_ids=["symbol"]) # Right side also skewed - right_data = [("EXTREME", base_time + timedelta(minutes=i*10), float(i*10)) - for i in range(95)] - right_data += [("OTHER", base_time + timedelta(minutes=i*10), float(i*100)) - for i in range(5)] + right_data = [ + ("EXTREME", base_time + timedelta(minutes=i * 10), float(i * 10)) + for i in range(95) + ] + right_data += [ + ("OTHER", base_time + timedelta(minutes=i * 10), float(i * 100)) + for i in range(5) + ] right_df = self.spark.createDataFrame( - right_data, - ["symbol", "timestamp", "price"] + right_data, ["symbol", "timestamp", "price"] ) right_tsdf = TSDF(right_df, ts_col="timestamp", series_ids=["symbol"]) @@ -432,7 +471,7 @@ def test_extreme_key_skew_95_percent(self): right_prefix="right", skew_threshold=0.1, # 10% threshold enable_salting=True, - salt_buckets=5 + salt_buckets=5, ) result_df, _ = joiner(left_tsdf, right_tsdf) @@ -448,6 +487,7 @@ def test_power_law_distribution_skew(self): """Test power law distribution (realistic for user activity data).""" # Create power law distributed data import random + random.seed(42) base_time = datetime(2024, 1, 1) @@ -465,7 +505,7 @@ def test_power_law_distribution_skew(self): left_df = self.spark.createDataFrame( left_data[:5000], # Limit to 5000 rows - ["user", "timestamp", "action", "value"] + ["user", "timestamp", "action", "value"], ) left_tsdf = TSDF(left_df, ts_col="timestamp", series_ids=["user"]) @@ -475,8 +515,7 @@ def test_power_law_distribution_skew(self): for i, key in enumerate(keys[:50]) # Only some users have profiles ] right_df = self.spark.createDataFrame( - right_data, - ["user", "timestamp", "status"] + right_data, ["user", "timestamp", "status"] ) right_tsdf = TSDF(right_df, ts_col="timestamp", series_ids=["user"]) @@ -484,7 +523,7 @@ def test_power_law_distribution_skew(self): spark=self.spark, left_prefix="", right_prefix="profile", - skew_threshold=0.15 + skew_threshold=0.15, ) result_df, _ = joiner(left_tsdf, right_tsdf) @@ -514,32 +553,29 @@ def test_multikey_skew(self): ) left_df = self.spark.createDataFrame( - left_data, - ["exchange", "symbol", "timestamp", "volume"] + left_data, ["exchange", "symbol", "timestamp", "volume"] ) left_tsdf = TSDF(left_df, ts_col="timestamp", series_ids=["exchange", "symbol"]) # Right side right_data = [ - ("NYSE", "AAPL", base_time + 
timedelta(minutes=i*100), float(i*100)) + ("NYSE", "AAPL", base_time + timedelta(minutes=i * 100), float(i * 100)) for i in range(8) ] right_data += [ ("NYSE", "GOOGL", base_time, 1000.0), - ("NASDAQ", "MSFT", base_time, 2000.0) + ("NASDAQ", "MSFT", base_time, 2000.0), ] right_df = self.spark.createDataFrame( - right_data, - ["exchange", "symbol", "timestamp", "price"] + right_data, ["exchange", "symbol", "timestamp", "price"] + ) + right_tsdf = TSDF( + right_df, ts_col="timestamp", series_ids=["exchange", "symbol"] ) - right_tsdf = TSDF(right_df, ts_col="timestamp", series_ids=["exchange", "symbol"]) joiner = SkewAsOfJoiner( - spark=self.spark, - left_prefix="", - right_prefix="quote", - skew_threshold=0.2 + spark=self.spark, left_prefix="", right_prefix="quote", skew_threshold=0.2 ) result_df, result_schema = joiner(left_tsdf, right_tsdf) @@ -563,17 +599,21 @@ def test_hot_partition_temporal_skew(self): # First 5 minutes: 70% of activity for i in range(700): left_data.append( - ("AAPL", base_time + timedelta(seconds=i*0.4), f"trade_{i}", float(i)) + ("AAPL", base_time + timedelta(seconds=i * 0.4), f"trade_{i}", float(i)) ) # Rest of day: 30% of activity for i in range(300): left_data.append( - ("AAPL", base_time + timedelta(minutes=5 + i), f"trade_{i+700}", float(i)) + ( + "AAPL", + base_time + timedelta(minutes=5 + i), + f"trade_{i+700}", + float(i), + ) ) left_df = self.spark.createDataFrame( - left_data, - ["symbol", "timestamp", "trade_id", "volume"] + left_data, ["symbol", "timestamp", "trade_id", "volume"] ) left_tsdf = TSDF(left_df, ts_col="timestamp", series_ids=["symbol"]) @@ -584,15 +624,12 @@ def test_hot_partition_temporal_skew(self): ] right_df = self.spark.createDataFrame( - right_data, - ["symbol", "timestamp", "price"] + right_data, ["symbol", "timestamp", "price"] ) right_tsdf = TSDF(right_df, ts_col="timestamp", series_ids=["symbol"]) joiner = SkewAsOfJoiner( - spark=self.spark, - left_prefix="trade", - right_prefix="quote" + spark=self.spark, left_prefix="trade", right_prefix="quote" ) result_df, _ = joiner(left_tsdf, right_tsdf) @@ -614,10 +651,11 @@ class SkewDetectionTest(unittest.TestCase): def setUp(self): """Set up test fixtures.""" - self.spark = SparkSession.builder \ - .master("local[*]") \ - .appName("SkewDetectionTest") \ + self.spark = ( + SparkSession.builder.master("local[*]") + .appName("SkewDetectionTest") .getOrCreate() + ) def tearDown(self): """Clean up after tests.""" @@ -642,7 +680,7 @@ def test_detect_significant_skew_small_data(self): # Should return False for tiny datasets self.assertFalse(_detectSignificantSkew(tsdf, tsdf)) - @patch('tempo.joins.strategies.logger') + @patch("tempo.joins.strategies.logger") def test_detect_significant_skew_with_error(self, mock_logger): """Test skew detection handles errors gracefully.""" # Create mock TSDF that will cause an error during count @@ -661,4 +699,4 @@ def test_detect_significant_skew_with_error(self, mock_logger): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/python/tests/joins/strategies_additional_coverage_tests.py b/python/tests/joins/strategies_additional_coverage_tests.py index 94acea5e..3a4c9b24 100644 --- a/python/tests/joins/strategies_additional_coverage_tests.py +++ b/python/tests/joins/strategies_additional_coverage_tests.py @@ -8,7 +8,13 @@ import unittest from datetime import datetime, timedelta from pyspark.sql import functions as F -from pyspark.sql.types import StructType, StructField, TimestampType, StringType, DoubleType 
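The SkewDetectionTest cases above pin down the behavioural contract of `_detectSignificantSkew`: tiny datasets report no skew, estimation failures are handled gracefully, and a key owning more than `threshold` of all rows (80% in key "A" versus a 0.3 threshold in the earlier fixtures) must be flagged. A minimal sketch of that heuristic, as a hypothetical helper rather than tempo's actual implementation, could read:

    import pyspark.sql.functions as F

    def detect_key_skew(tsdf, threshold: float = 0.3, min_rows: int = 100) -> bool:
        """Flag skew when one series key owns more than `threshold` of all rows."""
        try:
            total = tsdf.df.count()
            if total < min_rows:
                # Tiny inputs: skew handling would cost more than it saves,
                # matching test_detect_significant_skew_small_data.
                return False
            top_key_rows = (
                tsdf.df.groupBy(*tsdf.series_ids)
                .count()
                .agg(F.max("count").alias("max_count"))
                .first()["max_count"]
            )
            return top_key_rows / total > threshold
        except Exception:
            # Handle estimation failures gracefully (as exercised by
            # test_detect_significant_skew_with_error): route the join
            # down the non-skewed path instead of aborting.
            return False

The fail-open `except` is the important design choice here: a broken size or frequency estimate should degrade to the standard join, never fail the query.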
+from pyspark.sql.types import ( + StructType, + StructField, + TimestampType, + StringType, + DoubleType, +) from tempo.tsdf import TSDF from tempo.tsschema import TSSchema from tempo.joins.strategies import ( @@ -26,8 +32,7 @@ def test_empty_prefix_handling(self): """Test that empty/None prefix is handled correctly.""" # Lines 108-120: _prefixColumns should skip prefixing when prefix is empty left_data = [ - (datetime(2024, 1, 1, 10, i), f"A", float(100 + i)) - for i in range(5) + (datetime(2024, 1, 1, 10, i), f"A", float(100 + i)) for i in range(5) ] left_df = self.spark.createDataFrame( left_data, ["timestamp", "symbol", "price"] @@ -35,8 +40,7 @@ def test_empty_prefix_handling(self): left_tsdf = TSDF(left_df, ts_col="timestamp", series_ids=["symbol"]) right_data = [ - (datetime(2024, 1, 1, 10, i), f"A", float(200 + i)) - for i in range(0, 5, 2) + (datetime(2024, 1, 1, 10, i), f"A", float(200 + i)) for i in range(0, 5, 2) ] right_df = self.spark.createDataFrame( right_data, ["timestamp", "symbol", "price"] @@ -54,8 +58,7 @@ def test_skipnulls_with_only_series_timestamp(self): """Test skipNulls when right DataFrame has only series+timestamp columns.""" # Lines 506-510: skipNulls fallback when no value columns to check left_data = [ - (datetime(2024, 1, 1, 10, i), f"A", float(100 + i)) - for i in range(5) + (datetime(2024, 1, 1, 10, i), f"A", float(100 + i)) for i in range(5) ] left_df = self.spark.createDataFrame( left_data, ["timestamp", "symbol", "price"] @@ -63,10 +66,7 @@ def test_skipnulls_with_only_series_timestamp(self): left_tsdf = TSDF(left_df, ts_col="timestamp", series_ids=["symbol"]) # Right has ONLY series and timestamp - no value columns - right_data = [ - (datetime(2024, 1, 1, 10, i), f"A") - for i in range(0, 5, 2) - ] + right_data = [(datetime(2024, 1, 1, 10, i), f"A") for i in range(0, 5, 2)] right_df = self.spark.createDataFrame(right_data, ["timestamp", "symbol"]) right_tsdf = TSDF(right_df, ts_col="timestamp", series_ids=["symbol"]) @@ -80,17 +80,14 @@ def test_tolerance_none_early_return(self): """Test that tolerance=None returns immediately without filtering.""" # Lines 556-557: Early return when tolerance is None left_data = [ - (datetime(2024, 1, 1, 10, i), f"A", float(100 + i)) - for i in range(5) + (datetime(2024, 1, 1, 10, i), f"A", float(100 + i)) for i in range(5) ] left_df = self.spark.createDataFrame( left_data, ["timestamp", "symbol", "price"] ) left_tsdf = TSDF(left_df, ts_col="timestamp", series_ids=["symbol"]) - right_data = [ - (datetime(2024, 1, 1, 10, 0), f"A", float(200)) - ] + right_data = [(datetime(2024, 1, 1, 10, 0), f"A", float(200))] right_df = self.spark.createDataFrame( right_data, ["timestamp", "symbol", "bid"] ) @@ -124,7 +121,12 @@ def test_multi_series_skew_separation(self): for product in ["ProductB", "ProductC"]: for i in range(25): left_data.append( - (region, product, base_time + timedelta(minutes=i), float(100 + i)) + ( + region, + product, + base_time + timedelta(minutes=i), + float(100 + i), + ) ) left_df = self.spark.createDataFrame( @@ -134,13 +136,14 @@ def test_multi_series_skew_separation(self): # Create corresponding right data (10% of left) right_data = [ - (d[0], d[1], d[2], float(200 + i)) - for i, d in enumerate(left_data[::10]) + (d[0], d[1], d[2], float(200 + i)) for i, d in enumerate(left_data[::10]) ] right_df = self.spark.createDataFrame( right_data, ["region", "product", "timestamp", "price"] ) - right_tsdf = TSDF(right_df, ts_col="timestamp", series_ids=["region", "product"]) + right_tsdf = TSDF( + right_df, 
ts_col="timestamp", series_ids=["region", "product"] + ) # Use low threshold to trigger skew detection and separation joiner = SkewAsOfJoiner( @@ -148,7 +151,7 @@ def test_multi_series_skew_separation(self): left_prefix="", right_prefix="right", skew_threshold=0.1, # Low enough to detect ("US", "ProductA") - enable_salting=False + enable_salting=False, ) result_df, result_schema = joiner(left_tsdf, right_tsdf) @@ -167,8 +170,7 @@ def test_skipnulls_column_name_patterns(self): base_time = datetime(2024, 1, 1) left_data = [ - (base_time + timedelta(minutes=i), f"A", f"val_{i}") - for i in range(5) + (base_time + timedelta(minutes=i), f"A", f"val_{i}") for i in range(5) ] left_df = self.spark.createDataFrame( left_data, ["timestamp", "symbol", "metric"] @@ -178,10 +180,8 @@ def test_skipnulls_column_name_patterns(self): # Right has columns with "timestamp" in the name (should be excluded from null checking) # Also has regular value columns right_df = self.spark.createDataFrame( - [ - (base_time, "A", base_time, 100.0, "event_1") - ], - ["timestamp", "symbol", "event_timestamp", "price", "event_id"] + [(base_time, "A", base_time, 100.0, "event_1")], + ["timestamp", "symbol", "event_timestamp", "price", "event_id"], ) right_tsdf = TSDF(right_df, ts_col="timestamp", series_ids=["symbol"]) @@ -193,16 +193,15 @@ def test_skipnulls_column_name_patterns(self): self.assertGreater(result_df.count(), 0) # Test that column without "timestamp" in name is checked for nulls - schema2 = StructType([ - StructField("timestamp", TimestampType(), False), - StructField("symbol", StringType(), False), - StructField("ts", StringType(), True) # Nullable column - ]) - right_df2 = self.spark.createDataFrame( + schema2 = StructType( [ - (base_time, "A", None) # Null value in "ts" column - ], - schema=schema2 + StructField("timestamp", TimestampType(), False), + StructField("symbol", StringType(), False), + StructField("ts", StringType(), True), # Nullable column + ] + ) + right_df2 = self.spark.createDataFrame( + [(base_time, "A", None)], schema=schema2 # Null value in "ts" column ) right_tsdf2 = TSDF(right_df2, ts_col="timestamp", series_ids=["symbol"]) @@ -220,15 +219,13 @@ def test_no_series_ids_all_strategies(self): # Create data with NO series columns (global time series) left_data = [ - (base_time + timedelta(minutes=i), float(100 + i)) - for i in range(10) + (base_time + timedelta(minutes=i), float(100 + i)) for i in range(10) ] left_df = self.spark.createDataFrame(left_data, ["timestamp", "value"]) left_tsdf = TSDF(left_df, ts_col="timestamp", series_ids=[]) right_data = [ - (base_time + timedelta(minutes=i), float(200 + i)) - for i in range(0, 10, 2) + (base_time + timedelta(minutes=i), float(200 + i)) for i in range(0, 10, 2) ] right_df = self.spark.createDataFrame(right_data, ["timestamp", "price"]) right_tsdf = TSDF(right_df, ts_col="timestamp", series_ids=[]) @@ -249,7 +246,7 @@ def test_no_series_ids_all_strategies(self): self.spark, skew_threshold=0.2, enable_salting=True, # This triggers line 908-909 (salt on timestamp) - salt_buckets=5 + salt_buckets=5, ) skew_result, _ = skew_joiner(left_tsdf, right_tsdf) self.assertEqual(skew_result.count(), 10) @@ -264,8 +261,7 @@ def test_skew_joiner_skipnulls_no_value_columns(self): base_time = datetime(2024, 1, 1, 10, 0) left_data = [ - (base_time + timedelta(minutes=i), f"A", float(100 + i)) - for i in range(20) + (base_time + timedelta(minutes=i), f"A", float(100 + i)) for i in range(20) ] left_df = self.spark.createDataFrame( left_data, ["timestamp", 
"symbol", "price"] @@ -273,10 +269,7 @@ def test_skew_joiner_skipnulls_no_value_columns(self): left_tsdf = TSDF(left_df, ts_col="timestamp", series_ids=["symbol"]) # Right DataFrame with ONLY series and timestamp columns (no value columns) - right_data = [ - (base_time + timedelta(minutes=i), f"A") - for i in range(0, 20, 4) - ] + right_data = [(base_time + timedelta(minutes=i), f"A") for i in range(0, 20, 4)] right_df = self.spark.createDataFrame(right_data, ["timestamp", "symbol"]) right_tsdf = TSDF(right_df, ts_col="timestamp", series_ids=["symbol"]) @@ -284,7 +277,7 @@ def test_skew_joiner_skipnulls_no_value_columns(self): joiner = SkewAsOfJoiner( self.spark, skipNulls=True, - skew_threshold=1.0 # Disable skew detection for simpler test + skew_threshold=1.0, # Disable skew detection for simpler test ) result_df, _ = joiner(left_tsdf, right_tsdf) @@ -292,11 +285,7 @@ def test_skew_joiner_skipnulls_no_value_columns(self): self.assertEqual(result_df.count(), 20) # Test with skipNulls=False as well - joiner2 = SkewAsOfJoiner( - self.spark, - skipNulls=False, - skew_threshold=1.0 - ) + joiner2 = SkewAsOfJoiner(self.spark, skipNulls=False, skew_threshold=1.0) result_df2, _ = joiner2(left_tsdf, right_tsdf) self.assertEqual(result_df2.count(), 20) @@ -326,15 +315,13 @@ def test_broadcast_join_with_range_bin_size(self): # Test with small bin size (60 seconds - default) joiner_small = BroadcastAsOfJoiner( - self.spark, - range_join_bin_size=60 # 60 seconds bin + self.spark, range_join_bin_size=60 # 60 seconds bin ) result_small, _ = joiner_small(left_tsdf, right_tsdf) # Test with large bin size (300 seconds) joiner_large = BroadcastAsOfJoiner( - self.spark, - range_join_bin_size=300 # 300 seconds bin + self.spark, range_join_bin_size=300 # 300 seconds bin ) result_large, _ = joiner_large(left_tsdf, right_tsdf) @@ -343,8 +330,12 @@ def test_broadcast_join_with_range_bin_size(self): self.assertEqual(result_large.count(), 10) # Both should produce same results (bin size affects performance, not correctness) - small_non_null = result_small.filter(F.col("right_timestamp").isNotNull()).count() - large_non_null = result_large.filter(F.col("right_timestamp").isNotNull()).count() + small_non_null = result_small.filter( + F.col("right_timestamp").isNotNull() + ).count() + large_non_null = result_large.filter( + F.col("right_timestamp").isNotNull() + ).count() # Should have same number of matches regardless of bin size self.assertEqual(small_non_null, large_non_null) @@ -364,20 +355,28 @@ def test_composite_timestamp_index_join(self): ( (double_ts, ts, ts.isoformat()), # timestamp struct f"A", # symbol - float(100 + i) # price + float(100 + i), # price ) ) # Define schema with struct timestamp - left_schema = StructType([ - StructField("ts_idx", StructType([ - StructField("double_ts", DoubleType(), True), - StructField("parsed_ts", TimestampType(), True), - StructField("src_str", StringType(), True), - ]), True), - StructField("symbol", StringType(), True), - StructField("price", DoubleType(), True), - ]) + left_schema = StructType( + [ + StructField( + "ts_idx", + StructType( + [ + StructField("double_ts", DoubleType(), True), + StructField("parsed_ts", TimestampType(), True), + StructField("src_str", StringType(), True), + ] + ), + True, + ), + StructField("symbol", StringType(), True), + StructField("price", DoubleType(), True), + ] + ) left_df = self.spark.createDataFrame(left_data, schema=left_schema) @@ -388,7 +387,7 @@ def test_composite_timestamp_index_join(self): parsed_field="double_ts", 
src_str_field="src_str", secondary_parsed_field="parsed_ts", - series_ids=["symbol"] + series_ids=["symbol"], ) left_tsdf = TSDF(left_df, ts_schema=left_ts_schema) @@ -397,13 +396,7 @@ def test_composite_timestamp_index_join(self): for i in range(0, 10, 2): ts = base_time + timedelta(seconds=i) double_ts = ts.timestamp() - right_data.append( - ( - (double_ts, ts, ts.isoformat()), - f"A", - float(200 + i) - ) - ) + right_data.append(((double_ts, ts, ts.isoformat()), f"A", float(200 + i))) right_df = self.spark.createDataFrame(right_data, schema=left_schema) right_ts_schema = TSSchema.fromParsedTimestamp( @@ -412,7 +405,7 @@ def test_composite_timestamp_index_join(self): parsed_field="double_ts", src_str_field="src_str", secondary_parsed_field="parsed_ts", - series_ids=["symbol"] + series_ids=["symbol"], ) right_tsdf = TSDF(right_df, ts_schema=right_ts_schema) diff --git a/python/tests/joins/strategies_integration_tests.py b/python/tests/joins/strategies_integration_tests.py index c3b6a440..cb35e9b4 100644 --- a/python/tests/joins/strategies_integration_tests.py +++ b/python/tests/joins/strategies_integration_tests.py @@ -24,24 +24,18 @@ def test_broadcast_join_basic(self): # Load test data using function-based pattern left_tsdf = self.get_test_function_df_builder("left").as_tsdf() right_tsdf = self.get_test_function_df_builder("right").as_tsdf() - expected_tsdf = self.get_test_function_df_builder("expected_broadcast").as_tsdf() + expected_tsdf = self.get_test_function_df_builder( + "expected_broadcast" + ).as_tsdf() # Create and execute broadcast join - joiner = BroadcastAsOfJoiner( - self.spark, - left_prefix="", - right_prefix="right" - ) + joiner = BroadcastAsOfJoiner(self.spark, left_prefix="", right_prefix="right") # Execute join - returns (DataFrame, TSSchema) tuple result_df, result_schema = joiner(left_tsdf, right_tsdf) # Compare with expected results - self.assertDataFrameEquality( - result_df, - expected_tsdf.df, - ignore_row_order=True - ) + self.assertDataFrameEquality(result_df, expected_tsdf.df, ignore_row_order=True) def test_union_sort_filter_join_basic(self): """Test UnionSortFilterAsOfJoiner with basic test data.""" @@ -52,45 +46,37 @@ def test_union_sort_filter_join_basic(self): # Create and execute union-sort-filter join joiner = UnionSortFilterAsOfJoiner( - left_prefix="", - right_prefix="right", - skipNulls=True + left_prefix="", right_prefix="right", skipNulls=True ) # Execute join - returns (DataFrame, TSSchema) tuple result_df, result_schema = joiner(left_tsdf, right_tsdf) # Compare with expected results - self.assertDataFrameEquality( - result_df, - expected_tsdf.df, - ignore_row_order=True - ) + self.assertDataFrameEquality(result_df, expected_tsdf.df, ignore_row_order=True) def test_tolerance_filtering(self): """Test tolerance parameter filtering.""" # Load test data using function-based pattern left_tsdf = self.get_test_function_df_builder("left").as_tsdf() right_tsdf = self.get_test_function_df_builder("right").as_tsdf() - expected_tsdf = self.get_test_function_df_builder("expected_tolerance_120").as_tsdf() + expected_tsdf = self.get_test_function_df_builder( + "expected_tolerance_120" + ).as_tsdf() # Create joiner with tolerance joiner = UnionSortFilterAsOfJoiner( left_prefix="", right_prefix="right", skipNulls=True, - tolerance=120 # 2 minutes tolerance + tolerance=120, # 2 minutes tolerance ) # Execute join - returns (DataFrame, TSSchema) tuple result_df, result_schema = joiner(left_tsdf, right_tsdf) # Compare with expected results - 
self.assertDataFrameEquality( - result_df, - expected_tsdf.df, - ignore_row_order=True - ) + self.assertDataFrameEquality(result_df, expected_tsdf.df, ignore_row_order=True) def test_skip_nulls_behavior(self): """Test skipNulls parameter behavior.""" @@ -99,32 +85,24 @@ def test_skip_nulls_behavior(self): right_tsdf = self.get_test_function_df_builder("right").as_tsdf() # Test with skipNulls=True - expected_tsdf = self.get_test_function_df_builder("expected_skip_nulls_true").as_tsdf() + expected_tsdf = self.get_test_function_df_builder( + "expected_skip_nulls_true" + ).as_tsdf() joiner = UnionSortFilterAsOfJoiner( - left_prefix="", - right_prefix="right", - skipNulls=True + left_prefix="", right_prefix="right", skipNulls=True ) result_df, result_schema = joiner(left_tsdf, right_tsdf) - self.assertDataFrameEquality( - result_df, - expected_tsdf.df, - ignore_row_order=True - ) + self.assertDataFrameEquality(result_df, expected_tsdf.df, ignore_row_order=True) # Test with skipNulls=False - expected_tsdf = self.get_test_function_df_builder("expected_skip_nulls_false").as_tsdf() + expected_tsdf = self.get_test_function_df_builder( + "expected_skip_nulls_false" + ).as_tsdf() joiner = UnionSortFilterAsOfJoiner( - left_prefix="", - right_prefix="right", - skipNulls=False + left_prefix="", right_prefix="right", skipNulls=False ) result_df, result_schema = joiner(left_tsdf, right_tsdf) - self.assertDataFrameEquality( - result_df, - expected_tsdf.df, - ignore_row_order=True - ) + self.assertDataFrameEquality(result_df, expected_tsdf.df, ignore_row_order=True) def test_empty_dataframe_handling(self): """Test handling of empty DataFrames.""" @@ -134,9 +112,7 @@ def test_empty_dataframe_handling(self): expected_tsdf = self.get_test_function_df_builder("expected").as_tsdf() joiner = UnionSortFilterAsOfJoiner( - left_prefix="", - right_prefix="right", - skipNulls=True + left_prefix="", right_prefix="right", skipNulls=True ) result_df, result_schema = joiner(left_tsdf, right_tsdf) @@ -156,11 +132,7 @@ def test_null_lead_regression(self): expected_tsdf = self.get_test_function_df_builder("expected").as_tsdf() # Create and execute broadcast join - joiner = BroadcastAsOfJoiner( - self.spark, - left_prefix="", - right_prefix="right" - ) + joiner = BroadcastAsOfJoiner(self.spark, left_prefix="", right_prefix="right") # Execute join - should not fail with NULL lead result_df, result_schema = joiner(left_tsdf, right_tsdf) @@ -169,11 +141,7 @@ def test_null_lead_regression(self): self.assertEqual(result_df.count(), left_tsdf.df.count()) # Compare with expected results - self.assertDataFrameEquality( - result_df, - expected_tsdf.df, - ignore_row_order=True - ) + self.assertDataFrameEquality(result_df, expected_tsdf.df, ignore_row_order=True) def test_strategy_consistency(self): """Test that different strategies produce consistent results for the same data.""" @@ -182,25 +150,21 @@ def test_strategy_consistency(self): # Test broadcast join broadcast_joiner = BroadcastAsOfJoiner( - self.spark, - left_prefix="", - right_prefix="right" + self.spark, left_prefix="", right_prefix="right" + ) + broadcast_result_df, broadcast_result_schema = broadcast_joiner( + left_tsdf, right_tsdf ) - broadcast_result_df, broadcast_result_schema = broadcast_joiner(left_tsdf, right_tsdf) # Test union-sort-filter join union_joiner = UnionSortFilterAsOfJoiner( - left_prefix="", - right_prefix="right", - skipNulls=True + left_prefix="", right_prefix="right", skipNulls=True ) union_result_df, union_result_schema = union_joiner(left_tsdf, 
right_tsdf) # Results should be identical self.assertDataFrameEquality( - broadcast_result_df, - union_result_df, - ignore_row_order=True + broadcast_result_df, union_result_df, ignore_row_order=True ) def test_automatic_strategy_selection(self): @@ -210,11 +174,7 @@ def test_automatic_strategy_selection(self): right_tsdf = self.get_test_function_df_builder("right").as_tsdf() # Test automatic strategy selection (should potentially select broadcast for small data) - strategy = choose_as_of_join_strategy( - left_tsdf, - right_tsdf, - self.spark - ) + strategy = choose_as_of_join_strategy(left_tsdf, right_tsdf, self.spark) # Execute join result_df, result_schema = strategy(left_tsdf, right_tsdf) @@ -225,10 +185,7 @@ def test_automatic_strategy_selection(self): # Test with tsPartitionVal (should select SkewAsOfJoiner) strategy = choose_as_of_join_strategy( - left_tsdf, - right_tsdf, - self.spark, - tsPartitionVal=300 # 5 minutes + left_tsdf, right_tsdf, self.spark, tsPartitionVal=300 # 5 minutes ) - self.assertIsInstance(strategy, SkewAsOfJoiner) \ No newline at end of file + self.assertIsInstance(strategy, SkewAsOfJoiner) diff --git a/python/tests/joins/test_strategy_consistency.py b/python/tests/joins/test_strategy_consistency.py index 1e0e76a3..b2d7048a 100644 --- a/python/tests/joins/test_strategy_consistency.py +++ b/python/tests/joins/test_strategy_consistency.py @@ -23,8 +23,7 @@ def spark(): """Create a Spark session for testing.""" spark = ( - SparkSession.builder - .appName("StrategyConsistencyTests") + SparkSession.builder.appName("StrategyConsistencyTests") .config("spark.sql.shuffle.partitions", "4") .config("spark.default.parallelism", "4") .getOrCreate() @@ -51,12 +50,16 @@ def create_test_data(self, spark, num_left=100, num_right=20): for symbol in ["A", "B", "C"]: for i in range(num_left // 3): left_data.append( - (symbol, base_time + timedelta(minutes=i), f"{symbol}_val_{i}", float(i)) + ( + symbol, + base_time + timedelta(minutes=i), + f"{symbol}_val_{i}", + float(i), + ) ) left_df = spark.createDataFrame( - left_data, - ["symbol", "timestamp", "left_value", "left_metric"] + left_data, ["symbol", "timestamp", "left_value", "left_metric"] ) left_tsdf = TSDF(left_df, ts_col="timestamp", series_ids=["symbol"]) @@ -67,12 +70,16 @@ def create_test_data(self, spark, num_left=100, num_right=20): # Irregular intervals to test as-of logic offset = i * 3.5 right_data.append( - (symbol, base_time + timedelta(minutes=offset), float(100 + i), f"status_{i}") + ( + symbol, + base_time + timedelta(minutes=offset), + float(100 + i), + f"status_{i}", + ) ) right_df = spark.createDataFrame( - right_data, - ["symbol", "timestamp", "price", "status"] + right_data, ["symbol", "timestamp", "price", "status"] ) right_tsdf = TSDF(right_df, ts_col="timestamp", series_ids=["symbol"]) @@ -84,15 +91,11 @@ def test_broadcast_vs_union_consistency(self, spark): # Create joiners with identical parameters broadcast_joiner = BroadcastAsOfJoiner( - spark=spark, - left_prefix="", - right_prefix="right" + spark=spark, left_prefix="", right_prefix="right" ) union_joiner = UnionSortFilterAsOfJoiner( - left_prefix="", - right_prefix="right", - skipNulls=True + left_prefix="", right_prefix="right", skipNulls=True ) # Execute joins @@ -100,7 +103,9 @@ def test_broadcast_vs_union_consistency(self, spark): union_result, _ = union_joiner(left_tsdf, right_tsdf) # Sort for comparison - handle different column naming - ts_col = "timestamp" if "timestamp" in broadcast_result.columns else "left_timestamp" + ts_col = ( + 
"timestamp" if "timestamp" in broadcast_result.columns else "left_timestamp" + ) broadcast_sorted = broadcast_result.orderBy("symbol", ts_col) union_sorted = union_result.orderBy("symbol", ts_col) @@ -108,7 +113,9 @@ def test_broadcast_vs_union_consistency(self, spark): assert broadcast_sorted.count() == union_sorted.count() # Compare actual data values (not column order) - broadcast_data = broadcast_sorted.select(sorted(broadcast_sorted.columns)).collect() + broadcast_data = broadcast_sorted.select( + sorted(broadcast_sorted.columns) + ).collect() union_data = union_sorted.select(sorted(union_sorted.columns)).collect() assert broadcast_data == union_data @@ -122,13 +129,11 @@ def test_skew_vs_union_consistency(self, spark): left_prefix="", right_prefix="right", skipNulls=True, - skew_threshold=1.0 # High threshold to avoid triggering skew handling + skew_threshold=1.0, # High threshold to avoid triggering skew handling ) union_joiner = UnionSortFilterAsOfJoiner( - left_prefix="", - right_prefix="right", - skipNulls=True + left_prefix="", right_prefix="right", skipNulls=True ) # Execute joins @@ -156,14 +161,16 @@ def test_strategies_with_nulls(self, spark): # Add nulls to right data right_with_nulls = right_tsdf.df.withColumn( "price", - F.when(F.col("symbol") == "B", F.lit(None)).otherwise(F.col("price")) + F.when(F.col("symbol") == "B", F.lit(None)).otherwise(F.col("price")), + ) + right_tsdf_nulls = TSDF( + right_with_nulls, ts_col="timestamp", series_ids=["symbol"] ) - right_tsdf_nulls = TSDF(right_with_nulls, ts_col="timestamp", series_ids=["symbol"]) # Test only strategies that support skipNulls strategies = [ UnionSortFilterAsOfJoiner("", "right", skipNulls=True), - SkewAsOfJoiner(spark, "", "right", skipNulls=True, skew_threshold=1.0) + SkewAsOfJoiner(spark, "", "right", skipNulls=True, skew_threshold=1.0), ] for strategy in strategies: @@ -171,8 +178,9 @@ def test_strategies_with_nulls(self, spark): strategy_name = strategy.__class__.__name__ # Verify LEFT JOIN semantics - all left rows preserved - assert result.count() == left_tsdf.df.count(), \ - f"{strategy_name} did not preserve all left rows with nulls" + assert ( + result.count() == left_tsdf.df.count() + ), f"{strategy_name} did not preserve all left rows with nulls" def test_tolerance_consistency(self, spark): """Test strategies that support tolerance produce consistent results.""" @@ -183,7 +191,14 @@ def test_tolerance_consistency(self, spark): # Only test strategies that support tolerance strategies = [ UnionSortFilterAsOfJoiner("", "right", skipNulls=True, tolerance=tolerance), - SkewAsOfJoiner(spark, "", "right", skipNulls=True, tolerance=tolerance, skew_threshold=1.0) + SkewAsOfJoiner( + spark, + "", + "right", + skipNulls=True, + tolerance=tolerance, + skew_threshold=1.0, + ), ] for strategy in strategies: @@ -191,8 +206,9 @@ def test_tolerance_consistency(self, spark): strategy_name = strategy.__class__.__name__ # Verify LEFT JOIN semantics are preserved - assert result.count() == left_tsdf.df.count(), \ - f"{strategy_name} did not preserve all left rows with tolerance" + assert ( + result.count() == left_tsdf.df.count() + ), f"{strategy_name} did not preserve all left rows with tolerance" # Verify tolerance is applied (some joins may be filtered out) # Can't directly compare results as column naming differs @@ -203,20 +219,29 @@ def test_all_strategies_empty_right(self, spark): left_tsdf, _ = self.create_test_data(spark, 50, 10) # Create empty right DataFrame with proper schema - from pyspark.sql.types import 
StructType, StructField, StringType, TimestampType, DoubleType - right_schema = StructType([ - StructField("symbol", StringType(), True), - StructField("timestamp", TimestampType(), True), - StructField("price", DoubleType(), True), - StructField("status", StringType(), True) - ]) + from pyspark.sql.types import ( + StructType, + StructField, + StringType, + TimestampType, + DoubleType, + ) + + right_schema = StructType( + [ + StructField("symbol", StringType(), True), + StructField("timestamp", TimestampType(), True), + StructField("price", DoubleType(), True), + StructField("status", StringType(), True), + ] + ) right_df = spark.createDataFrame([], right_schema) right_tsdf = TSDF(right_df, ts_col="timestamp", series_ids=["symbol"]) strategies = [ BroadcastAsOfJoiner(spark, "", "right"), UnionSortFilterAsOfJoiner("", "right"), - SkewAsOfJoiner(spark, "", "right", skew_threshold=1.0) + SkewAsOfJoiner(spark, "", "right", skew_threshold=1.0), ] results = [] @@ -232,8 +257,8 @@ def test_all_strategies_empty_right(self, spark): # Check that right-side columns exist and are all null # Need to handle different column naming conventions - price_cols = [c for c in result.columns if 'price' in c.lower()] - status_cols = [c for c in result.columns if 'status' in c.lower()] + price_cols = [c for c in result.columns if "price" in c.lower()] + status_cols = [c for c in result.columns if "status" in c.lower()] assert len(price_cols) > 0, "No price column found in result" assert len(status_cols) > 0, "No status column found in result" @@ -252,14 +277,16 @@ def test_all_strategies_single_series(self, spark): left_df = spark.createDataFrame(left_data, ["timestamp", "value"]) left_tsdf = TSDF(left_df, ts_col="timestamp", series_ids=[]) - right_data = [(base_time + timedelta(minutes=i*5), float(i*100)) for i in range(10)] + right_data = [ + (base_time + timedelta(minutes=i * 5), float(i * 100)) for i in range(10) + ] right_df = spark.createDataFrame(right_data, ["timestamp", "price"]) right_tsdf = TSDF(right_df, ts_col="timestamp", series_ids=[]) strategies = [ BroadcastAsOfJoiner(spark, "", "right"), UnionSortFilterAsOfJoiner("", "right"), - SkewAsOfJoiner(spark, "", "right", skew_threshold=1.0) + SkewAsOfJoiner(spark, "", "right", skew_threshold=1.0), ] results = [] @@ -272,8 +299,9 @@ def test_all_strategies_single_series(self, spark): for i in range(1, len(results)): result0_data = results[0].select(sorted(results[0].columns)).collect() resulti_data = results[i].select(sorted(results[i].columns)).collect() - assert result0_data == resulti_data, \ - f"Strategy {i} differs for single series" + assert ( + result0_data == resulti_data + ), f"Strategy {i} differs for single series" def test_join_semantics_left_preservation(self, spark): """Verify all strategies preserve all left rows (LEFT JOIN semantics).""" @@ -282,7 +310,7 @@ def test_join_semantics_left_preservation(self, spark): strategies = [ BroadcastAsOfJoiner(spark, "", "right"), UnionSortFilterAsOfJoiner("", "right"), - SkewAsOfJoiner(spark, "", "right", skew_threshold=1.0) + SkewAsOfJoiner(spark, "", "right", skew_threshold=1.0), ] left_count = left_tsdf.df.count() @@ -292,13 +320,13 @@ def test_join_semantics_left_preservation(self, spark): strategy_name = strategy.__class__.__name__ # Verify all left rows preserved - assert result.count() == left_count, \ - f"{strategy_name} did not preserve all left rows" + assert ( + result.count() == left_count + ), f"{strategy_name} did not preserve all left rows" # Verify left columns are never null 
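        # An as-of join keeps strict LEFT join semantics: every left row
        # appears exactly once, paired with the latest right row at or before
        # its timestamp (or with all-NULL right columns when none exists), so
        # a NULL in any left-side column would indicate a joiner bug.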
left_nulls = result.filter(F.col("left_value").isNull()).count() - assert left_nulls == 0, \ - f"{strategy_name} has null left columns" + assert left_nulls == 0, f"{strategy_name} has null left columns" def test_temporal_ordering_consistency(self, spark): """Verify all strategies respect temporal ordering in as-of joins.""" @@ -324,7 +352,7 @@ def test_temporal_ordering_consistency(self, spark): strategies = [ BroadcastAsOfJoiner(spark, "", "right"), UnionSortFilterAsOfJoiner("", "right"), - SkewAsOfJoiner(spark, "", "right", skew_threshold=1.0) + SkewAsOfJoiner(spark, "", "right", skew_threshold=1.0), ] expected_prices = [100.0, 200.0, 300.0] @@ -341,9 +369,14 @@ def test_temporal_ordering_consistency(self, spark): price_col = "right_price" if "right_price" in result.columns else "price" for i, row in enumerate(rows): - price_val = row[price_col] if hasattr(row, price_col) else row.asDict().get(price_col) - assert price_val == expected_prices[i], \ - f"{strategy_name} incorrect temporal ordering at row {i}" + price_val = ( + row[price_col] + if hasattr(row, price_col) + else row.asDict().get(price_col) + ) + assert ( + price_val == expected_prices[i] + ), f"{strategy_name} incorrect temporal ordering at row {i}" def test_prefix_handling_consistency(self, spark): """Test all strategies handle column prefixes consistently.""" @@ -361,19 +394,22 @@ def test_prefix_handling_consistency(self, spark): strategies = [ BroadcastAsOfJoiner(spark, left_prefix, right_prefix), UnionSortFilterAsOfJoiner(left_prefix, right_prefix), - SkewAsOfJoiner(spark, left_prefix, right_prefix, skew_threshold=1.0) + SkewAsOfJoiner(spark, left_prefix, right_prefix, skew_threshold=1.0), ] results = [] for strategy in strategies: result, _ = strategy(left_tsdf, right_tsdf) - ts_col = "timestamp" if "timestamp" in result.columns else "left_timestamp" + ts_col = ( + "timestamp" if "timestamp" in result.columns else "left_timestamp" + ) results.append(result.orderBy("symbol", ts_col)) # Check column names are consistent for i in range(1, len(results)): - assert sorted(results[0].columns) == sorted(results[i].columns), \ - f"Column names differ with prefixes ({left_prefix}, {right_prefix})" + assert sorted(results[0].columns) == sorted( + results[i].columns + ), f"Column names differ with prefixes ({left_prefix}, {right_prefix})" def test_no_double_prefixing(self, spark): """Verify that column prefixes are never applied twice.""" @@ -382,7 +418,7 @@ def test_no_double_prefixing(self, spark): strategies = [ BroadcastAsOfJoiner(spark, "left", "right"), UnionSortFilterAsOfJoiner("left", "right"), - SkewAsOfJoiner(spark, "left", "right", skew_threshold=1.0) + SkewAsOfJoiner(spark, "left", "right", skew_threshold=1.0), ] for strategy in strategies: @@ -391,10 +427,12 @@ def test_no_double_prefixing(self, spark): # Check no column has double prefix like "left_left_" or "right_right_" for col in result.columns: - assert not col.startswith("left_left_"), \ - f"{strategy_name} has double-prefixed column: {col}" - assert not col.startswith("right_right_"), \ - f"{strategy_name} has double-prefixed column: {col}" + assert not col.startswith( + "left_left_" + ), f"{strategy_name} has double-prefixed column: {col}" + assert not col.startswith( + "right_right_" + ), f"{strategy_name} has double-prefixed column: {col}" def test_overlapping_columns_empty_prefix(self, spark): """Test handling of overlapping non-timestamp columns with empty prefixes.""" @@ -405,7 +443,9 @@ def test_overlapping_columns_empty_prefix(self, spark): left_df = 
spark.createDataFrame(left_data, ["timestamp", "symbol", "value"]) left_tsdf = TSDF(left_df, ts_col="timestamp", series_ids=["symbol"]) - right_data = [(base_time + timedelta(minutes=i*2), "A", i*100) for i in range(5)] + right_data = [ + (base_time + timedelta(minutes=i * 2), "A", i * 100) for i in range(5) + ] right_df = spark.createDataFrame(right_data, ["timestamp", "symbol", "value"]) right_tsdf = TSDF(right_df, ts_col="timestamp", series_ids=["symbol"]) @@ -414,8 +454,9 @@ def test_overlapping_columns_empty_prefix(self, spark): result, _ = strategy(left_tsdf, right_tsdf) # Should have left value, not duplicate "value" columns - assert result.columns.count("value") == 1, \ - "Should have only one 'value' column when prefixes are empty" + assert ( + result.columns.count("value") == 1 + ), "Should have only one 'value' column when prefixes are empty" def test_tolerance_with_prefixes(self, spark): """Test tolerance filtering works correctly with various prefix combinations.""" @@ -430,12 +471,17 @@ def test_tolerance_with_prefixes(self, spark): for left_prefix, right_prefix in prefix_tests: strategy = SkewAsOfJoiner( - spark, left_prefix, right_prefix, - skipNulls=True, tolerance=tolerance, skew_threshold=1.0 + spark, + left_prefix, + right_prefix, + skipNulls=True, + tolerance=tolerance, + skew_threshold=1.0, ) result, _ = strategy(left_tsdf, right_tsdf) # Should not error and should preserve left rows - assert result.count() == left_tsdf.df.count(), \ - f"Failed with prefixes ({left_prefix}, {right_prefix})" \ No newline at end of file + assert ( + result.count() == left_tsdf.df.count() + ), f"Failed with prefixes ({left_prefix}, {right_prefix})" diff --git a/python/tests/joins/timezone_regression_tests.py b/python/tests/joins/timezone_regression_tests.py index a99d2f8d..9ee06b3b 100644 --- a/python/tests/joins/timezone_regression_tests.py +++ b/python/tests/joins/timezone_regression_tests.py @@ -10,7 +10,13 @@ import pytz from pyspark.sql import functions as F -from pyspark.sql.types import StructType, StructField, StringType, DoubleType, TimestampType +from pyspark.sql.types import ( + StructType, + StructField, + StringType, + DoubleType, + TimestampType, +) from tests.base import SparkTest from tempo import TSDF @@ -93,12 +99,12 @@ def test_nanosecond_timestamp_timezone_handling(self): # This simulates the fromStringTimestamp behavior left_df = left_df.withColumn( "timestamp", - F.to_timestamp("timestamp_str", "yyyy-MM-dd HH:mm:ss.SSSSSSSSS") + F.to_timestamp("timestamp_str", "yyyy-MM-dd HH:mm:ss.SSSSSSSSS"), ) right_df = right_df.withColumn( "timestamp", - F.to_timestamp("timestamp_str", "yyyy-MM-dd HH:mm:ss.SSSSSSSSS") + F.to_timestamp("timestamp_str", "yyyy-MM-dd HH:mm:ss.SSSSSSSSS"), ) # Create TSDFs @@ -124,19 +130,18 @@ def test_composite_index_timezone_consistency(self): ("S1", "2022-01-01T11:00:00.123456789Z", 101.0), ] - schema = StructType([ - StructField("symbol", StringType(), True), - StructField("timestamp_str", StringType(), True), - StructField("value", DoubleType(), True) - ]) + schema = StructType( + [ + StructField("symbol", StringType(), True), + StructField("timestamp_str", StringType(), True), + StructField("value", DoubleType(), True), + ] + ) left_df = self.spark.createDataFrame(left_data, schema) # Parse the ISO timestamp string - left_df = left_df.withColumn( - "timestamp", - F.to_timestamp("timestamp_str") - ) + left_df = left_df.withColumn("timestamp", F.to_timestamp("timestamp_str")) right_data = [ ("S1", "2022-01-01T09:30:00.123456789Z", 99.5), @@ 
-144,10 +149,7 @@ test_composite_index_timezone_consistency(self): ] right_df = self.spark.createDataFrame(right_data, schema) - right_df = right_df.withColumn( - "timestamp", - F.to_timestamp("timestamp_str") - ) + right_df = right_df.withColumn("timestamp", F.to_timestamp("timestamp_str")) # Create TSDFs with nanosecond precision # This would create composite indexes internally @@ -172,7 +174,7 @@ def test_composite_index_timezone_consistency(self): def test_different_timezone_conversion(self): """Test joining data from different timezones.""" # Create left data in US/Eastern timezone - eastern = pytz.timezone('US/Eastern') + eastern = pytz.timezone("US/Eastern") left_data = [ ("S1", eastern.localize(datetime(2022, 1, 1, 10, 0, 0)), 100.0), ("S1", eastern.localize(datetime(2022, 1, 1, 11, 0, 0)), 101.0), @@ -182,7 +184,7 @@ def test_different_timezone_conversion(self): ) # Create right data in US/Pacific timezone - pacific = pytz.timezone('US/Pacific') + pacific = pytz.timezone("US/Pacific") right_data = [ # These times are actually simultaneous with left times when converted to UTC ("S1", pacific.localize(datetime(2022, 1, 1, 7, 0, 0)), 99.5), @@ -230,10 +232,16 @@ def test_null_timezone_handling(self): ) # Create TSDFs - should handle nulls gracefully - left_tsdf = TSDF(left_df.filter(F.col("timestamp").isNotNull()), - ts_col="timestamp", series_ids=["symbol"]) - right_tsdf = TSDF(right_df.filter(F.col("timestamp").isNotNull()), - ts_col="timestamp", series_ids=["symbol"]) + left_tsdf = TSDF( + left_df.filter(F.col("timestamp").isNotNull()), + ts_col="timestamp", + series_ids=["symbol"], + ) + right_tsdf = TSDF( + right_df.filter(F.col("timestamp").isNotNull()), + ts_col="timestamp", + series_ids=["symbol"], + ) # Perform join joiner = BroadcastAsOfJoiner(self.spark) @@ -246,7 +254,7 @@ def test_dst_transition_handling(self): """Test handling of daylight saving time transitions.""" # Create data around DST transition (Spring forward in US/Eastern) # March 13, 2022 at 2:00 AM -> 3:00 AM - eastern = pytz.timezone('US/Eastern') + eastern = pytz.timezone("US/Eastern") left_data = [ # Before DST @@ -293,5 +301,5 @@ def test_dst_transition_handling(self): self.assertEqual( b_row["right_timestamp"], u_row["right_timestamp"], - f"Mismatch in DST handling between strategies" - ) \ No newline at end of file + "Mismatch in DST handling between strategies", + ) diff --git a/python/tests/joins/tsdf_asof_join_tests.py b/python/tests/joins/tsdf_asof_join_tests.py index 7e2e90c4..d04df66e 100644 --- a/python/tests/joins/tsdf_asof_join_tests.py +++ b/python/tests/joins/tsdf_asof_join_tests.py @@ -13,7 +13,13 @@ from datetime import datetime, timedelta import pyspark.sql.functions as F -from pyspark.sql.types import StructType, StructField, StringType, TimestampType, DoubleType +from pyspark.sql.types import ( + StructType, + StructField, + StringType, + TimestampType, + DoubleType, +) from tests.base import SparkTest from tempo.tsdf import TSDF @@ -71,9 +77,7 @@ def test_asof_join_default_strategy(self): def test_asof_join_manual_broadcast_strategy(self): """Test asofJoin with manual broadcast strategy selection.""" result = self.left_tsdf.asofJoin( - self.right_tsdf, - strategy="broadcast", - right_prefix="quote" + self.right_tsdf, strategy="broadcast", right_prefix="quote" ) # Verify results @@ -83,9 +87,7 @@ def test_asof_join_manual_broadcast_strategy(self): def test_asof_join_manual_union_strategy(self): """Test asofJoin with manual union strategy selection.""" result = 
self.left_tsdf.asofJoin( - self.right_tsdf, - strategy="union", - right_prefix="quote" + self.right_tsdf, strategy="union", right_prefix="quote" ) # Verify results @@ -95,10 +97,7 @@ def test_asof_join_manual_union_strategy(self): def test_asof_join_manual_skew_strategy(self): """Test asofJoin with manual skew strategy selection.""" result = self.left_tsdf.asofJoin( - self.right_tsdf, - strategy="skew", - right_prefix="quote", - tsPartitionVal=300 + self.right_tsdf, strategy="skew", right_prefix="quote", tsPartitionVal=300 ) # Verify results @@ -108,10 +107,7 @@ def test_asof_join_manual_skew_strategy(self): def test_asof_join_invalid_strategy(self): """Test asofJoin with invalid strategy raises ValueError.""" with self.assertRaises(ValueError) as cm: - self.left_tsdf.asofJoin( - self.right_tsdf, - strategy="invalid" - ) + self.left_tsdf.asofJoin(self.right_tsdf, strategy="invalid") self.assertIn("Unknown strategy", str(cm.exception)) @@ -137,11 +133,7 @@ def test_asof_join_with_tolerance(self): right_tsdf = TSDF(right_df, ts_col="timestamp", series_ids=["symbol"]) # Join with tolerance=300 seconds (5 minutes) - result = left_tsdf.asofJoin( - right_tsdf, - tolerance=300, - right_prefix="quote" - ) + result = left_tsdf.asofJoin(right_tsdf, tolerance=300, right_prefix="quote") # FIXME: Tolerance implementation issue - all rows are matching despite being beyond tolerance # Expected: First row (t=0) should match, rows at t>=10min should NOT match (beyond 5min tolerance) @@ -168,9 +160,7 @@ def test_asof_join_with_skip_nulls(self): # Test with skipNulls=True (default) result_skip = self.left_tsdf.asofJoin( - right_tsdf, - skipNulls=True, - right_prefix="quote" + right_tsdf, skipNulls=True, right_prefix="quote" ) # Rows between t=2 and t=4 should skip the NULL and get t=0 price @@ -180,9 +170,7 @@ def test_asof_join_with_skip_nulls(self): def test_asof_join_with_prefixes(self): """Test asofJoin with custom prefixes.""" result = self.left_tsdf.asofJoin( - self.right_tsdf, - left_prefix="trade", - right_prefix="quote" + self.right_tsdf, left_prefix="trade", right_prefix="quote" ) # Timestamp is overlapping, so both should be prefixed @@ -194,20 +182,28 @@ def test_asof_join_with_prefixes(self): def test_asof_join_empty_right_dataframe(self): """Test asofJoin when right DataFrame is empty.""" # Create empty right DataFrame with explicit schema - from pyspark.sql.types import StructType, StructField, StringType, TimestampType, DoubleType - empty_schema = StructType([ - StructField("symbol", StringType(), True), - StructField("timestamp", TimestampType(), True), - StructField("price", DoubleType(), True), - ]) - empty_right_df = self.spark.createDataFrame([], schema=empty_schema) - empty_right_tsdf = TSDF(empty_right_df, ts_col="timestamp", series_ids=["symbol"]) + from pyspark.sql.types import ( + StructType, + StructField, + StringType, + TimestampType, + DoubleType, + ) - result = self.left_tsdf.asofJoin( - empty_right_tsdf, - right_prefix="quote" + empty_schema = StructType( + [ + StructField("symbol", StringType(), True), + StructField("timestamp", TimestampType(), True), + StructField("price", DoubleType(), True), + ] + ) + empty_right_df = self.spark.createDataFrame([], schema=empty_schema) + empty_right_tsdf = TSDF( + empty_right_df, ts_col="timestamp", series_ids=["symbol"] ) + result = self.left_tsdf.asofJoin(empty_right_tsdf, right_prefix="quote") + # Should preserve all left rows with NULL right values self.assertEqual(result.df.count(), 10) # With empty right DataFrame, price column is 
prefixed as quote_price @@ -220,46 +216,38 @@ def test_asof_join_with_partition_val_selects_skew(self): result = self.left_tsdf.asofJoin( self.right_tsdf, tsPartitionVal=300, # Should trigger SkewAsOfJoiner - right_prefix="quote" + right_prefix="quote", ) self.assertEqual(result.df.count(), 10) - @patch('tempo.joins.strategies.get_bytes_from_plan') + @patch("tempo.joins.strategies.get_bytes_from_plan") def test_asof_join_automatic_broadcast_selection(self, mock_get_bytes): """Test that small data automatically selects BroadcastAsOfJoiner.""" # Mock size estimation to return small sizes (< 30MB) mock_get_bytes.return_value = 10 * 1024 * 1024 # 10MB # Without strategy parameter, should auto-select - result = self.left_tsdf.asofJoin( - self.right_tsdf, - right_prefix="quote" - ) + result = self.left_tsdf.asofJoin(self.right_tsdf, right_prefix="quote") self.assertEqual(result.df.count(), 10) # Verify get_bytes_from_plan was called for size estimation self.assertGreater(mock_get_bytes.call_count, 0) - @patch('tempo.joins.strategies.get_bytes_from_plan') + @patch("tempo.joins.strategies.get_bytes_from_plan") def test_asof_join_automatic_union_selection(self, mock_get_bytes): """Test that large data automatically selects UnionSortFilterAsOfJoiner.""" # Mock size estimation to return large sizes (> 30MB) mock_get_bytes.return_value = 100 * 1024 * 1024 # 100MB - result = self.left_tsdf.asofJoin( - self.right_tsdf, - right_prefix="quote" - ) + result = self.left_tsdf.asofJoin(self.right_tsdf, right_prefix="quote") self.assertEqual(result.df.count(), 10) def test_asof_join_preserves_schema(self): """Test that asofJoin preserves TSDF schema correctly.""" result = self.left_tsdf.asofJoin( - self.right_tsdf, - left_prefix="", - right_prefix="" + self.right_tsdf, left_prefix="", right_prefix="" ) # Schema uses prefixed timestamp column even with empty prefix args @@ -295,13 +283,12 @@ def test_asof_join_multiple_series_ids(self): right_df = self.spark.createDataFrame( right_data, ["exchange", "symbol", "timestamp", "price"] ) - right_tsdf = TSDF(right_df, ts_col="timestamp", series_ids=["exchange", "symbol"]) - - result = left_tsdf.asofJoin( - right_tsdf, - right_prefix="quote" + right_tsdf = TSDF( + right_df, ts_col="timestamp", series_ids=["exchange", "symbol"] ) + result = left_tsdf.asofJoin(right_tsdf, right_prefix="quote") + # Verify multi-key join worked self.assertEqual(result.df.count(), 10) self.assertEqual(result.series_ids, ["exchange", "symbol"]) @@ -339,23 +326,17 @@ def test_strategy_consistency_broadcast_vs_union(self): """Test that broadcast and union strategies produce identical results.""" # Execute with broadcast broadcast_result = self.left_tsdf.asofJoin( - self.right_tsdf, - strategy="broadcast", - right_prefix="quote" + self.right_tsdf, strategy="broadcast", right_prefix="quote" ) # Execute with union union_result = self.left_tsdf.asofJoin( - self.right_tsdf, - strategy="union", - right_prefix="quote" + self.right_tsdf, strategy="union", right_prefix="quote" ) # Results should be identical self.assertDataFrameEquality( - broadcast_result.df, - union_result.df, - ignore_row_order=True + broadcast_result.df, union_result.df, ignore_row_order=True ) def test_strategy_consistency_with_tolerance(self): @@ -364,21 +345,16 @@ def test_strategy_consistency_with_tolerance(self): self.right_tsdf, strategy="broadcast", tolerance=300, # 5 minutes - right_prefix="quote" + right_prefix="quote", ) union_result = self.left_tsdf.asofJoin( - self.right_tsdf, - strategy="union", - tolerance=300, - 
right_prefix="quote" + self.right_tsdf, strategy="union", tolerance=300, right_prefix="quote" ) # Results should be identical self.assertDataFrameEquality( - broadcast_result.df, - union_result.df, - ignore_row_order=True + broadcast_result.df, union_result.df, ignore_row_order=True ) diff --git a/python/tests/ml_tests.py b/python/tests/ml_tests.py index dd65711a..26a7ddcd 100644 --- a/python/tests/ml_tests.py +++ b/python/tests/ml_tests.py @@ -31,14 +31,14 @@ def test_empty_constructor(self): def test_estim_eval_constructor(self): # set up estimator and evaluator estimator = GBTRegressor(labelCol="close", featuresCol="features") - evaluator = RegressionEvaluator(labelCol="close", - predictionCol="prediction", - metricName="rmse") + evaluator = RegressionEvaluator( + labelCol="close", predictionCol="prediction", metricName="rmse" + ) parm_grid = ParamGridBuilder().build() # construct with default parameters - tscv = TimeSeriesCrossValidator(estimator=estimator, - evaluator=evaluator, - estimatorParamMaps=parm_grid) + tscv = TimeSeriesCrossValidator( + estimator=estimator, evaluator=evaluator, estimatorParamMaps=parm_grid + ) # test the parameters self.assertEqual(tscv.getEstimator(), estimator) self.assertEqual(tscv.getEvaluator(), evaluator) @@ -88,9 +88,9 @@ def test_estimator_param(self): def test_evaluator_param(self): # set up estimator and evaluator - evaluator = RegressionEvaluator(labelCol="close", - predictionCol="prediction", - metricName="rmse") + evaluator = RegressionEvaluator( + labelCol="close", predictionCol="prediction", metricName="rmse" + ) # construct with default parameters tscv = TimeSeriesCrossValidator() # set the evaluator @@ -136,9 +136,9 @@ def test_kfolds(self): # load test data trades_df = self.get_test_df_builder("trades").as_sdf() # construct with default parameters - tscv = TimeSeriesCrossValidator(timeSeriesCol='event_ts', - seriesIdCols=['symbol'], - gap=0) + tscv = TimeSeriesCrossValidator( + timeSeriesCol="event_ts", seriesIdCols=["symbol"], gap=0 + ) # test the k-folds k_folds = tscv._kFold(trades_df) # check the number of folds diff --git a/python/tests/resample_tests.py b/python/tests/resample_tests.py index 74563a95..9016a4ca 100644 --- a/python/tests/resample_tests.py +++ b/python/tests/resample_tests.py @@ -3,6 +3,7 @@ import pyspark.sql.functions as sfn from tempo import TSDF +from tempo.resample_result import ResampledTSDF from tempo.resample import _appendAggKey, aggregate, resample from tempo.resample_utils import checkAllowableFreq, validateFuncExists from tempo.stats import calc_bars @@ -23,9 +24,9 @@ def test_resample(self): # 1 minute aggregation featured_df = resample(tsdf_input, freq="min", func="floor", prefix="floor").df # 30 minute aggregation - resample_30m = resample(tsdf_input, freq="5 minutes", func="mean").df.withColumn( - "trade_pr", sfn.round(sfn.col("trade_pr"), 2) - ) + resample_30m = resample( + tsdf_input, freq="5 minutes", func="mean" + ).df.withColumn("trade_pr", sfn.round(sfn.col("trade_pr"), 2)) bars = calc_bars( tsdf_input, freq="min", metric_cols=["trade_pr", "trade_pr_2"] @@ -238,6 +239,183 @@ def test_validate_func_exists_type_error(self): def test_validate_func_exists_value_error(self): self.assertRaises(ValueError, validateFuncExists, "non-existent") + def test_resample_returns_resampled_tsdf(self): + """Verify resample() returns ResampledTSDF, and as_tsdf() returns TSDF""" + # Reuse existing test_resample's input_data + tsdf_input = self.get_test_df_builder( + "ResampleUnitTests", "test_resample", "input_data" + 
).as_tsdf() + + result = resample(tsdf_input, freq="min", func="floor") + + self.assertIsInstance(result, ResampledTSDF) + self.assertEqual(result.resample_freq, "min") + self.assertEqual(result.resample_func, "floor") + self.assertIsNotNone(result.df) + self.assertEqual(result.ts_col, tsdf_input.ts_col) + self.assertEqual(result.series_ids, tsdf_input.series_ids) + + # as_tsdf() should return a plain TSDF + plain = result.as_tsdf() + self.assertIsInstance(plain, TSDF) + self.assertNotIsInstance(plain, ResampledTSDF) + + def test_tsdf_resample_returns_resampled_tsdf(self): + """Verify TSDF.resample() also returns ResampledTSDF""" + tsdf_input = self.get_test_df_builder( + "ResampleUnitTests", "test_resample", "input_data" + ).as_tsdf() + + result = tsdf_input.resample(freq="min", func="floor") + + self.assertIsInstance(result, ResampledTSDF) + self.assertEqual(result.resample_freq, "min") + + def test_resampled_tsdf_blocks_invalid_operations(self): + """Verify that ResampledTSDF does not expose filter, withColumn, etc.""" + tsdf_input = self.get_test_df_builder( + "ResampleUnitTests", "test_resample", "input_data" + ).as_tsdf() + resampled = resample(tsdf_input, freq="min", func="floor") + + self.assertFalse(hasattr(resampled, "filter")) + self.assertFalse(hasattr(resampled, "withColumn")) + self.assertFalse(hasattr(resampled, "where")) + self.assertFalse(hasattr(resampled, "select")) + self.assertFalse(hasattr(resampled, "resample")) + + def test_resampled_tsdf_repr(self): + """Verify __repr__ returns a descriptive string""" + tsdf_input = self.get_test_df_builder( + "ResampleUnitTests", "test_resample", "input_data" + ).as_tsdf() + resampled = resample(tsdf_input, freq="min", func="floor") + + repr_str = repr(resampled) + self.assertIn("ResampledTSDF", repr_str) + self.assertIn("min", repr_str) + self.assertIn("floor", repr_str) + + def test_resampled_tsdf_properties(self): + """Verify all property accessors return correct values""" + tsdf_input = self.get_test_df_builder( + "ResampleUnitTests", "test_resample", "input_data" + ).as_tsdf() + resampled = resample(tsdf_input, freq="min", func="floor") + + self.assertEqual(resampled.ts_col, tsdf_input.ts_col) + self.assertEqual(resampled.series_ids, tsdf_input.series_ids) + self.assertEqual(resampled.ts_schema, tsdf_input.ts_schema) + self.assertEqual(resampled.columns, resampled.df.columns) + self.assertEqual(resampled.resample_freq, "min") + self.assertEqual(resampled.resample_func, "floor") + + def test_resampled_tsdf_show(self): + """Verify show() delegates without error""" + tsdf_input = self.get_test_df_builder( + "ResampleUnitTests", "test_resample", "input_data" + ).as_tsdf() + resampled = resample(tsdf_input, freq="min", func="floor") + + # Should not raise + resampled.show() + resampled.show(n=5, truncate=False) + + def test_resampled_tsdf_repr_contains_all_fields(self): + """Verify repr includes ts_col and series_ids values""" + tsdf_input = self.get_test_df_builder( + "ResampleUnitTests", "test_resample", "input_data" + ).as_tsdf() + resampled = resample(tsdf_input, freq="min", func="floor") + + repr_str = repr(resampled) + self.assertIn(f"ts_col={tsdf_input.ts_col!r}", repr_str) + self.assertIn(f"series_ids={tsdf_input.series_ids!r}", repr_str) + + def _get_interpol_resampled(self): + """Helper: build a ResampledTSDF from shared interpolation test data.""" + tsdf = self.get_test_df_builder("__SharedData", "interpol_data").as_tsdf() + return tsdf.resample(freq="30 min", func="mean") + + def 
test_resampled_tsdf_interpolate_zero(self): + """Exercise method='zero' interpolation path""" + resampled = self._get_interpol_resampled() + result = resampled.interpolate(method="zero") + self.assertIsInstance(result, TSDF) + self.assertGreater(result.df.count(), 0) + + def test_resampled_tsdf_interpolate_ffill(self): + """Exercise method='ffill' interpolation path""" + resampled = self._get_interpol_resampled() + result = resampled.interpolate(method="ffill") + self.assertIsInstance(result, TSDF) + self.assertGreater(result.df.count(), 0) + + def test_resampled_tsdf_interpolate_bfill(self): + """Exercise method='bfill' interpolation path""" + resampled = self._get_interpol_resampled() + result = resampled.interpolate(method="bfill") + self.assertIsInstance(result, TSDF) + self.assertGreater(result.df.count(), 0) + + def test_resampled_tsdf_interpolate_null_returns_tsdf(self): + """method='null' returns the underlying TSDF unchanged""" + resampled = self._get_interpol_resampled() + result = resampled.interpolate(method="null") + self.assertIsInstance(result, TSDF) + # null should return the underlying TSDF as-is + self.assertEqual(result.df.collect(), resampled.as_tsdf().df.collect()) + + def test_resampled_tsdf_interpolate_with_target_cols(self): + """Pass explicit target_cols list to interpolate""" + resampled = self._get_interpol_resampled() + result = resampled.interpolate(method="zero", target_cols=["value_a"]) + self.assertIsInstance(result, TSDF) + self.assertGreater(result.df.count(), 0) + + def test_resampled_tsdf_interpolate_show_interpolated_warns(self): + """show_interpolated=True logs a warning without error""" + import logging + + resampled = self._get_interpol_resampled() + with self.assertLogs("tempo.resample_result", level=logging.WARNING) as cm: + result = resampled.interpolate(method="zero", show_interpolated=True) + self.assertIsInstance(result, TSDF) + self.assertTrue( + any("show_interpolated" in msg for msg in cm.output), + f"Expected warning about show_interpolated, got: {cm.output}", + ) + + def test_resampled_tsdf_interpolate_linear(self): + """Exercise method='linear' interpolation path""" + resampled = self._get_interpol_resampled() + result = resampled.interpolate(method="linear") + self.assertIsInstance(result, TSDF) + self.assertGreater(result.df.count(), 0) + + def test_resampled_tsdf_interpolate_unknown_method_falls_through(self): + """Exercise the else branch — unknown method string is passed through as-is""" + resampled = self._get_interpol_resampled() + # An unrecognized method falls through to the final else branch and is forwarded + # to interpol_func which validates it, so we expect a ValueError. 
+ with self.assertRaises(ValueError): + resampled.interpolate(method="not_a_real_method") + + def test_resampled_tsdf_as_tsdf_allows_normal_operations(self): + """Verify the TSDF from as_tsdf() supports normal DataFrame operations""" + tsdf_input = self.get_test_df_builder( + "ResampleUnitTests", "test_resample", "input_data" + ).as_tsdf() + resampled = resample(tsdf_input, freq="min", func="floor") + + plain_tsdf = resampled.as_tsdf() + # Should support standard Spark DataFrame operations + filtered = plain_tsdf.df.filter(sfn.col(plain_tsdf.ts_col).isNotNull()) + self.assertGreater(filtered.count(), 0) + + selected = plain_tsdf.df.select(plain_tsdf.ts_col) + self.assertEqual(len(selected.columns), 1) + # MAIN if __name__ == "__main__": diff --git a/python/tests/tsdf_basic_methods_tests.py b/python/tests/tsdf_basic_methods_tests.py index 2a0dea19..4c009dad 100644 --- a/python/tests/tsdf_basic_methods_tests.py +++ b/python/tests/tsdf_basic_methods_tests.py @@ -142,6 +142,5 @@ def test_with_column_type_changed(self): self.assertIn("int", value_type.lower()) - if __name__ == "__main__": unittest.main() diff --git a/python/tests/tsdf_tests.py b/python/tests/tsdf_tests.py index 39fc0935..6a4e8ce9 100644 --- a/python/tests/tsdf_tests.py +++ b/python/tests/tsdf_tests.py @@ -13,15 +13,17 @@ class TSDFBaseTests(SparkTest): - @parameterized.expand([ - ("simple_ts_idx", SimpleTimestampIndex), - ("simple_ts_no_series", SimpleTimestampIndex), - ("simple_date_idx", SimpleDateIndex), - ("ordinal_double_index", OrdinalTSIndex), - ("ordinal_int_index", OrdinalTSIndex), - ("parsed_ts_idx", ParsedTimestampIndex), - ("parsed_date_idx", ParsedDateIndex), - ]) + @parameterized.expand( + [ + ("simple_ts_idx", SimpleTimestampIndex), + ("simple_ts_no_series", SimpleTimestampIndex), + ("simple_date_idx", SimpleDateIndex), + ("ordinal_double_index", OrdinalTSIndex), + ("ordinal_int_index", OrdinalTSIndex), + ("parsed_ts_idx", ParsedTimestampIndex), + ("parsed_date_idx", ParsedDateIndex), + ] + ) def test_tsdf_constructor(self, init_tsdf_id, expected_idx_class): # create TSDF init_tsdf = self.get_test_function_df_builder(init_tsdf_id).as_tsdf() @@ -35,60 +37,68 @@ def test_tsdf_constructor(self, init_tsdf_id, expected_idx_class): self.assertIsNotNone(init_tsdf.ts_index) self.assertIsInstance(init_tsdf.ts_index, expected_idx_class) - @parameterized.expand([ - ("simple_ts_idx", ["symbol"]), - ("simple_ts_no_series", []), - ("simple_date_idx", ["station"]), - ("ordinal_double_index", ["symbol"]), - ("ordinal_int_index", ["symbol"]), - ("parsed_ts_idx", ["symbol"]), - ("parsed_date_idx", ["station"]), - ]) + @parameterized.expand( + [ + ("simple_ts_idx", ["symbol"]), + ("simple_ts_no_series", []), + ("simple_date_idx", ["station"]), + ("ordinal_double_index", ["symbol"]), + ("ordinal_int_index", ["symbol"]), + ("parsed_ts_idx", ["symbol"]), + ("parsed_date_idx", ["station"]), + ] + ) def test_series_ids(self, init_tsdf_id, expected_series_ids): # load TSDF tsdf = self.get_test_function_df_builder(init_tsdf_id).as_tsdf() # validate series ids self.assertEqual(set(tsdf.series_ids), set(expected_series_ids)) - @parameterized.expand([ - ("simple_ts_idx", ["event_ts", "symbol"]), - ("simple_ts_no_series", ["event_ts"]), - ("simple_date_idx", ["date", "station"]), - ("ordinal_double_index", ["event_ts_dbl", "symbol"]), - ("ordinal_int_index", ["order", "symbol"]), - ("parsed_ts_idx", ["ts_idx", "symbol"]), - ("parsed_date_idx", ["ts_idx", "station"]), - ]) + @parameterized.expand( + [ + ("simple_ts_idx", ["event_ts", 
"symbol"]), + ("simple_ts_no_series", ["event_ts"]), + ("simple_date_idx", ["date", "station"]), + ("ordinal_double_index", ["event_ts_dbl", "symbol"]), + ("ordinal_int_index", ["order", "symbol"]), + ("parsed_ts_idx", ["ts_idx", "symbol"]), + ("parsed_date_idx", ["ts_idx", "station"]), + ] + ) def test_structural_cols(self, init_tsdf_id, expected_structural_cols): # load TSDF tsdf = self.get_test_function_df_builder(init_tsdf_id).as_tsdf() # validate structural cols self.assertEqual(set(tsdf.structural_cols), set(expected_structural_cols)) - @parameterized.expand([ - ("simple_ts_idx", ["trade_pr"]), - ("simple_ts_no_series", ["trade_pr"]), - ("simple_date_idx", ["temp"]), - ("ordinal_double_index", ["trade_pr"]), - ("ordinal_int_index", ["trade_pr"]), - ("parsed_ts_idx", ["trade_pr"]), - ("parsed_date_idx", ["temp"]), - ]) + @parameterized.expand( + [ + ("simple_ts_idx", ["trade_pr"]), + ("simple_ts_no_series", ["trade_pr"]), + ("simple_date_idx", ["temp"]), + ("ordinal_double_index", ["trade_pr"]), + ("ordinal_int_index", ["trade_pr"]), + ("parsed_ts_idx", ["trade_pr"]), + ("parsed_date_idx", ["temp"]), + ] + ) def test_obs_cols(self, init_tsdf_id, expected_obs_cols): # load TSDF tsdf = self.get_test_function_df_builder(init_tsdf_id).as_tsdf() # validate obs cols self.assertEqual(set(tsdf.observational_cols), set(expected_obs_cols)) - @parameterized.expand([ - ("simple_ts_idx", ["trade_pr"]), - ("simple_ts_no_series", ["trade_pr"]), - ("simple_date_idx", ["temp"]), - ("ordinal_double_index", ["trade_pr"]), - ("ordinal_int_index", ["trade_pr"]), - ("parsed_ts_idx", ["trade_pr"]), - ("parsed_date_idx", ["temp"]), - ]) + @parameterized.expand( + [ + ("simple_ts_idx", ["trade_pr"]), + ("simple_ts_no_series", ["trade_pr"]), + ("simple_date_idx", ["temp"]), + ("ordinal_double_index", ["trade_pr"]), + ("ordinal_int_index", ["trade_pr"]), + ("parsed_ts_idx", ["trade_pr"]), + ("parsed_date_idx", ["temp"]), + ] + ) def test_metric_cols(self, init_tsdf_id, expected_metric_cols): # load TSDF tsdf = self.get_test_function_df_builder(init_tsdf_id).as_tsdf() @@ -97,117 +107,119 @@ def test_metric_cols(self, init_tsdf_id, expected_metric_cols): class TimeSlicingTests(SparkTest): - @parameterized.expand([ - ( - "simple_ts_idx", - "2020-09-01 00:02:10", - { - "tsdf": {"ts_col": "event_ts", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_convert": ["event_ts"], - "data": [ - ["S1", "2020-09-01 00:02:10", 361.1], - ["S2", "2020-09-01 00:02:10", 761.10], - ], - }, - }, - ), - ( - "simple_ts_no_series", - "2020-09-01 00:19:12", - { - "tsdf": {"ts_col": "event_ts", "series_ids": []}, - "df": { - "schema": "event_ts string, trade_pr float", - "ts_convert": ["event_ts"], - "data": [ - ["2020-09-01 00:19:12", 362.1], - ], - }, - }, - ), - ( - "simple_date_idx", - "2020-08-02", - { - "tsdf": {"ts_col": "date", "series_ids": ["station"]}, - "df": { - "schema": "station string, date string, temp float", - "date_convert": ["date"], - "data": [ - ["LGA", "2020-08-02", 28.79], - ["YYZ", "2020-08-02", 22.25], - ], - }, - }, - ), - ( - "ordinal_double_index", - 10.0, - { - "tsdf": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, event_ts_dbl double, trade_pr float", - "data": [ - ["S1", 10.0, 361.1], - ["S2", 10.0, 762.33], - ], - }, - }, - ), - ( - "ordinal_int_index", - 1, - { - "tsdf": {"ts_col": "order", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, order int, trade_pr float", - "data": [ - ["S1", 1, 
349.21], - ["S2", 1, 751.92], - ], - }, - }, - ), - ( - "parsed_ts_idx", - "2020-09-01 00:02:10.032", - { - "tsdf": { - "ts_col": "event_ts", - "series_ids": ["symbol"], - "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", - }, - "tsdf_constructor": "fromStringTimestamp", - "df": { - "schema": "symbol string, event_ts string, trade_pr float", - "data": [ - ["S1", "2020-09-01 00:02:10.032", 361.1], - ], - }, - }, - ), - ( - "parsed_date_idx", - "2020-08-04", - { - "tsdf": { - "ts_col": "date", - "series_ids": ["station"], - "ts_fmt": "yyyy-MM-dd", - }, - "tsdf_constructor": "fromStringTimestamp", - "df": { - "schema": "station string, date string, temp float", - "data": [ - ["LGA", "2020-08-04", 25.57], - ["YYZ", "2020-08-04", 20.65], - ], - }, - }, - ), - ]) + @parameterized.expand( + [ + ( + "simple_ts_idx", + "2020-09-01 00:02:10", + { + "tsdf": {"ts_col": "event_ts", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["S1", "2020-09-01 00:02:10", 361.1], + ["S2", "2020-09-01 00:02:10", 761.10], + ], + }, + }, + ), + ( + "simple_ts_no_series", + "2020-09-01 00:19:12", + { + "tsdf": {"ts_col": "event_ts", "series_ids": []}, + "df": { + "schema": "event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["2020-09-01 00:19:12", 362.1], + ], + }, + }, + ), + ( + "simple_date_idx", + "2020-08-02", + { + "tsdf": {"ts_col": "date", "series_ids": ["station"]}, + "df": { + "schema": "station string, date string, temp float", + "date_convert": ["date"], + "data": [ + ["LGA", "2020-08-02", 28.79], + ["YYZ", "2020-08-02", 22.25], + ], + }, + }, + ), + ( + "ordinal_double_index", + 10.0, + { + "tsdf": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts_dbl double, trade_pr float", + "data": [ + ["S1", 10.0, 361.1], + ["S2", 10.0, 762.33], + ], + }, + }, + ), + ( + "ordinal_int_index", + 1, + { + "tsdf": {"ts_col": "order", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, order int, trade_pr float", + "data": [ + ["S1", 1, 349.21], + ["S2", 1, 751.92], + ], + }, + }, + ), + ( + "parsed_ts_idx", + "2020-09-01 00:02:10.032", + { + "tsdf": { + "ts_col": "event_ts", + "series_ids": ["symbol"], + "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "data": [ + ["S1", "2020-09-01 00:02:10.032", 361.1], + ], + }, + }, + ), + ( + "parsed_date_idx", + "2020-08-04", + { + "tsdf": { + "ts_col": "date", + "series_ids": ["station"], + "ts_fmt": "yyyy-MM-dd", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "station string, date string, temp float", + "data": [ + ["LGA", "2020-08-04", 25.57], + ["YYZ", "2020-08-04", 20.65], + ], + }, + }, + ), + ] + ) def test_at(self, init_tsdf_id, ts, expected_tsdf_dict): # load TSDF tsdf = self.get_test_function_df_builder(init_tsdf_id).as_tsdf() @@ -217,130 +229,132 @@ def test_at(self, init_tsdf_id, ts, expected_tsdf_dict): expected_tsdf = TestDataFrameBuilder(self.spark, expected_tsdf_dict).as_tsdf() self.assertDataFrameEquality(at_tsdf, expected_tsdf) - @parameterized.expand([ - ( - "simple_ts_idx", - "2020-09-01 00:02:10", - { - "tsdf": {"ts_col": "event_ts", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_convert": ["event_ts"], - "data": [ - ["S1", "2020-08-01 00:00:10", 349.21], - ["S1", "2020-08-01 00:01:12", 351.32], - 
["S2", "2020-08-01 00:01:10", 743.01], - ["S2", "2020-08-01 00:01:24", 751.92], - ], - }, - }, - ), - ( - "simple_ts_no_series", - "2020-09-01 00:19:12", - { - "tsdf": {"ts_col": "event_ts", "series_ids": []}, - "df": { - "schema": "event_ts string, trade_pr float", - "ts_convert": ["event_ts"], - "data": [ - ["2020-08-01 00:00:10", 349.21], - ["2020-08-01 00:01:10", 743.01], - ["2020-08-01 00:01:12", 351.32], - ["2020-08-01 00:01:24", 751.92], - ["2020-09-01 00:02:10", 361.1], - ], - }, - }, - ), - ( - "simple_date_idx", - "2020-08-03", - { - "tsdf": {"ts_col": "date", "series_ids": ["station"]}, - "df": { - "schema": "station string, date string, temp float", - "date_convert": ["date"], - "data": [ - ["LGA", "2020-08-01", 27.58], - ["LGA", "2020-08-02", 28.79], - ["YYZ", "2020-08-01", 24.16], - ["YYZ", "2020-08-02", 22.25], - ], - }, - }, - ), - ( - "ordinal_double_index", - 10.0, - { - "tsdf": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, event_ts_dbl double, trade_pr float", - "data": [ - ["S1", 0.13, 349.21], - ["S1", 1.207, 351.32], - ["S2", 0.005, 743.01], - ["S2", 0.1, 751.92], - ["S2", 1.0, 761.10], - ], - }, - }, - ), - ( - "ordinal_int_index", - 1, - { - "tsdf": {"ts_col": "order", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, order int, trade_pr float", - "data": [["S2", 0, 743.01]], - }, - }, - ), - ( - "parsed_ts_idx", - "2020-09-01 00:02:10.000", - { - "tsdf": { - "ts_col": "event_ts", - "series_ids": ["symbol"], - "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", - }, - "tsdf_constructor": "fromStringTimestamp", - "df": { - "schema": "symbol string, event_ts string, trade_pr float", - "data": [ - ["S1", "2020-08-01 00:00:10.010", 349.21], - ["S1", "2020-08-01 00:01:12.021", 351.32], - ["S2", "2020-08-01 00:01:10.054", 743.01], - ["S2", "2020-08-01 00:01:24.065", 751.92], - ], - }, - }, - ), - ( - "parsed_date_idx", - "2020-08-03", - { - "tsdf": { - "ts_col": "date", - "series_ids": ["station"], - "ts_fmt": "yyyy-MM-dd", - }, - "tsdf_constructor": "fromStringTimestamp", - "df": { - "schema": "station string, date string, temp float", - "data": [ - ["LGA", "2020-08-01", 27.58], - ["LGA", "2020-08-02", 28.79], - ["YYZ", "2020-08-01", 24.16], - ["YYZ", "2020-08-02", 22.25], - ], - }, - }, - ), - ]) + @parameterized.expand( + [ + ( + "simple_ts_idx", + "2020-09-01 00:02:10", + { + "tsdf": {"ts_col": "event_ts", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["S1", "2020-08-01 00:00:10", 349.21], + ["S1", "2020-08-01 00:01:12", 351.32], + ["S2", "2020-08-01 00:01:10", 743.01], + ["S2", "2020-08-01 00:01:24", 751.92], + ], + }, + }, + ), + ( + "simple_ts_no_series", + "2020-09-01 00:19:12", + { + "tsdf": {"ts_col": "event_ts", "series_ids": []}, + "df": { + "schema": "event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["2020-08-01 00:00:10", 349.21], + ["2020-08-01 00:01:10", 743.01], + ["2020-08-01 00:01:12", 351.32], + ["2020-08-01 00:01:24", 751.92], + ["2020-09-01 00:02:10", 361.1], + ], + }, + }, + ), + ( + "simple_date_idx", + "2020-08-03", + { + "tsdf": {"ts_col": "date", "series_ids": ["station"]}, + "df": { + "schema": "station string, date string, temp float", + "date_convert": ["date"], + "data": [ + ["LGA", "2020-08-01", 27.58], + ["LGA", "2020-08-02", 28.79], + ["YYZ", "2020-08-01", 24.16], + ["YYZ", "2020-08-02", 22.25], + ], + }, + }, + ), + ( + "ordinal_double_index", + 10.0, + { + "tsdf": 
{"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts_dbl double, trade_pr float", + "data": [ + ["S1", 0.13, 349.21], + ["S1", 1.207, 351.32], + ["S2", 0.005, 743.01], + ["S2", 0.1, 751.92], + ["S2", 1.0, 761.10], + ], + }, + }, + ), + ( + "ordinal_int_index", + 1, + { + "tsdf": {"ts_col": "order", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, order int, trade_pr float", + "data": [["S2", 0, 743.01]], + }, + }, + ), + ( + "parsed_ts_idx", + "2020-09-01 00:02:10.000", + { + "tsdf": { + "ts_col": "event_ts", + "series_ids": ["symbol"], + "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "data": [ + ["S1", "2020-08-01 00:00:10.010", 349.21], + ["S1", "2020-08-01 00:01:12.021", 351.32], + ["S2", "2020-08-01 00:01:10.054", 743.01], + ["S2", "2020-08-01 00:01:24.065", 751.92], + ], + }, + }, + ), + ( + "parsed_date_idx", + "2020-08-03", + { + "tsdf": { + "ts_col": "date", + "series_ids": ["station"], + "ts_fmt": "yyyy-MM-dd", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "station string, date string, temp float", + "data": [ + ["LGA", "2020-08-01", 27.58], + ["LGA", "2020-08-02", 28.79], + ["YYZ", "2020-08-01", 24.16], + ["YYZ", "2020-08-02", 22.25], + ], + }, + }, + ), + ] + ) def test_before(self, init_tsdf_id, ts, expected_tsdf_dict): # load TSDF tsdf = self.get_test_function_df_builder(init_tsdf_id).as_tsdf() @@ -350,139 +364,145 @@ def test_before(self, init_tsdf_id, ts, expected_tsdf_dict): expected_tsdf = TestDataFrameBuilder(self.spark, expected_tsdf_dict).as_tsdf() self.assertDataFrameEquality(before_tsdf, expected_tsdf) - @parameterized.expand([ - ( - "simple_ts_idx", - "2020-09-01 00:02:10", - { - "tsdf": {"ts_col": "event_ts", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_convert": ["event_ts"], - "data": [ - ["S1", "2020-08-01 00:00:10", 349.21], - ["S1", "2020-08-01 00:01:12", 351.32], - ["S1", "2020-09-01 00:02:10", 361.1], - ["S2", "2020-08-01 00:01:10", 743.01], - ["S2", "2020-08-01 00:01:24", 751.92], - ["S2", "2020-09-01 00:02:10", 761.10], - ], - }, - }, - ), - ( - "simple_ts_no_series", - "2020-09-01 00:19:12", - { - "tsdf": {"ts_col": "event_ts", "series_ids": []}, - "df": { - "schema": "event_ts string, trade_pr float", - "ts_convert": ["event_ts"], - "data": [ - ["2020-08-01 00:00:10", 349.21], - ["2020-08-01 00:01:10", 743.01], - ["2020-08-01 00:01:12", 351.32], - ["2020-08-01 00:01:24", 751.92], - ["2020-09-01 00:02:10", 361.1], - ["2020-09-01 00:19:12", 362.1], - ], - }, - }, - ), - ( - "simple_date_idx", - "2020-08-03", - { - "tsdf": {"ts_col": "date", "series_ids": ["station"]}, - "df": { - "schema": "station string, date string, temp float", - "date_convert": ["date"], - "data": [ - ["LGA", "2020-08-01", 27.58], - ["LGA", "2020-08-02", 28.79], - ["LGA", "2020-08-03", 28.53], - ["YYZ", "2020-08-01", 24.16], - ["YYZ", "2020-08-02", 22.25], - ["YYZ", "2020-08-03", 20.62], - ], - }, - }, - ), - ( - "ordinal_double_index", - 10.0, - { - "tsdf": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, event_ts_dbl double, trade_pr float", - "data": [ - ["S1", 0.13, 349.21], - ["S1", 1.207, 351.32], - ["S1", 10.0, 361.1], - ["S2", 0.005, 743.01], - ["S2", 0.1, 751.92], - ["S2", 1.0, 761.10], - ["S2", 10.0, 762.33], - ], - }, - }, - ), - ( - "ordinal_int_index", - 1, - { - "tsdf": 
{"ts_col": "order", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, order int, trade_pr float", - "data": [["S1", 1, 349.21], ["S2", 0, 743.01], ["S2", 1, 751.92]], - }, - }, - ), - ( - "parsed_ts_idx", - "2020-09-01 00:02:10.000", - { - "tsdf": { - "ts_col": "event_ts", - "series_ids": ["symbol"], - "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", - }, - "tsdf_constructor": "fromStringTimestamp", - "df": { - "schema": "symbol string, event_ts string, trade_pr float", - "data": [ - ["S1", "2020-08-01 00:00:10.010", 349.21], - ["S1", "2020-08-01 00:01:12.021", 351.32], - ["S2", "2020-08-01 00:01:10.054", 743.01], - ["S2", "2020-08-01 00:01:24.065", 751.92], - ], - }, - }, - ), - ( - "parsed_date_idx", - "2020-08-03", - { - "tsdf": { - "ts_col": "date", - "series_ids": ["station"], - "ts_fmt": "yyyy-MM-dd", - }, - "tsdf_constructor": "fromStringTimestamp", - "df": { - "schema": "station string, date string, temp float", - "data": [ - ["LGA", "2020-08-01", 27.58], - ["LGA", "2020-08-02", 28.79], - ["LGA", "2020-08-03", 28.53], - ["YYZ", "2020-08-01", 24.16], - ["YYZ", "2020-08-02", 22.25], - ["YYZ", "2020-08-03", 20.62], - ], - }, - }, - ), - ]) + @parameterized.expand( + [ + ( + "simple_ts_idx", + "2020-09-01 00:02:10", + { + "tsdf": {"ts_col": "event_ts", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["S1", "2020-08-01 00:00:10", 349.21], + ["S1", "2020-08-01 00:01:12", 351.32], + ["S1", "2020-09-01 00:02:10", 361.1], + ["S2", "2020-08-01 00:01:10", 743.01], + ["S2", "2020-08-01 00:01:24", 751.92], + ["S2", "2020-09-01 00:02:10", 761.10], + ], + }, + }, + ), + ( + "simple_ts_no_series", + "2020-09-01 00:19:12", + { + "tsdf": {"ts_col": "event_ts", "series_ids": []}, + "df": { + "schema": "event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["2020-08-01 00:00:10", 349.21], + ["2020-08-01 00:01:10", 743.01], + ["2020-08-01 00:01:12", 351.32], + ["2020-08-01 00:01:24", 751.92], + ["2020-09-01 00:02:10", 361.1], + ["2020-09-01 00:19:12", 362.1], + ], + }, + }, + ), + ( + "simple_date_idx", + "2020-08-03", + { + "tsdf": {"ts_col": "date", "series_ids": ["station"]}, + "df": { + "schema": "station string, date string, temp float", + "date_convert": ["date"], + "data": [ + ["LGA", "2020-08-01", 27.58], + ["LGA", "2020-08-02", 28.79], + ["LGA", "2020-08-03", 28.53], + ["YYZ", "2020-08-01", 24.16], + ["YYZ", "2020-08-02", 22.25], + ["YYZ", "2020-08-03", 20.62], + ], + }, + }, + ), + ( + "ordinal_double_index", + 10.0, + { + "tsdf": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts_dbl double, trade_pr float", + "data": [ + ["S1", 0.13, 349.21], + ["S1", 1.207, 351.32], + ["S1", 10.0, 361.1], + ["S2", 0.005, 743.01], + ["S2", 0.1, 751.92], + ["S2", 1.0, 761.10], + ["S2", 10.0, 762.33], + ], + }, + }, + ), + ( + "ordinal_int_index", + 1, + { + "tsdf": {"ts_col": "order", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, order int, trade_pr float", + "data": [ + ["S1", 1, 349.21], + ["S2", 0, 743.01], + ["S2", 1, 751.92], + ], + }, + }, + ), + ( + "parsed_ts_idx", + "2020-09-01 00:02:10.000", + { + "tsdf": { + "ts_col": "event_ts", + "series_ids": ["symbol"], + "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "data": [ + ["S1", "2020-08-01 00:00:10.010", 349.21], + ["S1", "2020-08-01 00:01:12.021", 
351.32], + ["S2", "2020-08-01 00:01:10.054", 743.01], + ["S2", "2020-08-01 00:01:24.065", 751.92], + ], + }, + }, + ), + ( + "parsed_date_idx", + "2020-08-03", + { + "tsdf": { + "ts_col": "date", + "series_ids": ["station"], + "ts_fmt": "yyyy-MM-dd", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "station string, date string, temp float", + "data": [ + ["LGA", "2020-08-01", 27.58], + ["LGA", "2020-08-02", 28.79], + ["LGA", "2020-08-03", 28.53], + ["YYZ", "2020-08-01", 24.16], + ["YYZ", "2020-08-02", 22.25], + ["YYZ", "2020-08-03", 20.62], + ], + }, + }, + ), + ] + ) def test_atOrBefore(self, init_tsdf_id, ts, expected_tsdf_dict): # load TSDF tsdf = self.get_test_function_df_builder(init_tsdf_id).as_tsdf() @@ -492,127 +512,129 @@ def test_atOrBefore(self, init_tsdf_id, ts, expected_tsdf_dict): expected_tsdf = TestDataFrameBuilder(self.spark, expected_tsdf_dict).as_tsdf() self.assertDataFrameEquality(at_before_tsdf, expected_tsdf) - @parameterized.expand([ - ( - "simple_ts_idx", - "2020-09-01 00:02:10", - { - "tsdf": {"ts_col": "event_ts", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_convert": ["event_ts"], - "data": [ - ["S1", "2020-09-01 00:19:12", 362.1], - ["S2", "2020-09-01 00:20:42", 762.33], - ], - }, - }, - ), - ( - "simple_ts_no_series", - "2020-09-01 00:08:12", - { - "tsdf": {"ts_col": "event_ts", "series_ids": []}, - "df": { - "schema": "event_ts string, trade_pr float", - "ts_convert": ["event_ts"], - "data": [ - ["2020-09-01 00:19:12", 362.1], - ["2020-09-01 00:20:42", 762.33], - ], - }, - }, - ), - ( - "simple_date_idx", - "2020-08-02", - { - "tsdf": {"ts_col": "date", "series_ids": ["station"]}, - "df": { - "schema": "station string, date string, temp float", - "date_convert": ["date"], - "data": [ - ["LGA", "2020-08-03", 28.53], - ["LGA", "2020-08-04", 25.57], - ["YYZ", "2020-08-03", 20.62], - ["YYZ", "2020-08-04", 20.65], - ], - }, - }, - ), - ( - "ordinal_double_index", - 1.0, - { - "tsdf": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, event_ts_dbl double, trade_pr float", - "data": [ - ["S1", 1.207, 351.32], - ["S1", 10.0, 361.1], - ["S1", 24.357, 362.1], - ["S2", 10.0, 762.33], - ], - }, - }, - ), - ( - "ordinal_int_index", - 10, - { - "tsdf": {"ts_col": "order", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, order int, trade_pr float", - "data": [ - ["S1", 20, 351.32], - ["S1", 127, 361.1], - ["S1", 243, 362.1], - ["S2", 100, 762.33], - ], - }, - }, - ), - ( - "parsed_ts_idx", - "2020-09-01 00:02:10.000", - { - "tsdf": { - "ts_col": "event_ts", - "series_ids": ["symbol"], - "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", - }, - "tsdf_constructor": "fromStringTimestamp", - "df": { - "schema": "symbol string, event_ts string, trade_pr float", - "data": [ - ["S1", "2020-09-01 00:02:10.032", 361.1], - ["S1", "2020-09-01 00:19:12.043", 362.1], - ["S2", "2020-09-01 00:02:10.076", 761.10], - ["S2", "2020-09-01 00:20:42.087", 762.33], - ], - }, - }, - ), - ( - "parsed_date_idx", - "2020-08-03", - { - "tsdf": { - "ts_col": "date", - "series_ids": ["station"], - "ts_fmt": "yyyy-MM-dd", - }, - "tsdf_constructor": "fromStringTimestamp", - "df": { - "schema": "station string, date string, temp float", - "data": [ - ["LGA", "2020-08-04", 25.57], - ["YYZ", "2020-08-04", 20.65], - ], - }, - }, - ), - ]) + @parameterized.expand( + [ + ( + "simple_ts_idx", + "2020-09-01 00:02:10", + { + "tsdf": {"ts_col": "event_ts", "series_ids": ["symbol"]}, + "df": { 
+ "schema": "symbol string, event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["S1", "2020-09-01 00:19:12", 362.1], + ["S2", "2020-09-01 00:20:42", 762.33], + ], + }, + }, + ), + ( + "simple_ts_no_series", + "2020-09-01 00:08:12", + { + "tsdf": {"ts_col": "event_ts", "series_ids": []}, + "df": { + "schema": "event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["2020-09-01 00:19:12", 362.1], + ["2020-09-01 00:20:42", 762.33], + ], + }, + }, + ), + ( + "simple_date_idx", + "2020-08-02", + { + "tsdf": {"ts_col": "date", "series_ids": ["station"]}, + "df": { + "schema": "station string, date string, temp float", + "date_convert": ["date"], + "data": [ + ["LGA", "2020-08-03", 28.53], + ["LGA", "2020-08-04", 25.57], + ["YYZ", "2020-08-03", 20.62], + ["YYZ", "2020-08-04", 20.65], + ], + }, + }, + ), + ( + "ordinal_double_index", + 1.0, + { + "tsdf": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts_dbl double, trade_pr float", + "data": [ + ["S1", 1.207, 351.32], + ["S1", 10.0, 361.1], + ["S1", 24.357, 362.1], + ["S2", 10.0, 762.33], + ], + }, + }, + ), + ( + "ordinal_int_index", + 10, + { + "tsdf": {"ts_col": "order", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, order int, trade_pr float", + "data": [ + ["S1", 20, 351.32], + ["S1", 127, 361.1], + ["S1", 243, 362.1], + ["S2", 100, 762.33], + ], + }, + }, + ), + ( + "parsed_ts_idx", + "2020-09-01 00:02:10.000", + { + "tsdf": { + "ts_col": "event_ts", + "series_ids": ["symbol"], + "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "data": [ + ["S1", "2020-09-01 00:02:10.032", 361.1], + ["S1", "2020-09-01 00:19:12.043", 362.1], + ["S2", "2020-09-01 00:02:10.076", 761.10], + ["S2", "2020-09-01 00:20:42.087", 762.33], + ], + }, + }, + ), + ( + "parsed_date_idx", + "2020-08-03", + { + "tsdf": { + "ts_col": "date", + "series_ids": ["station"], + "ts_fmt": "yyyy-MM-dd", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "station string, date string, temp float", + "data": [ + ["LGA", "2020-08-04", 25.57], + ["YYZ", "2020-08-04", 20.65], + ], + }, + }, + ), + ] + ) def test_after(self, init_tsdf_id, ts, expected_tsdf_dict): # load TSDF tsdf = self.get_test_function_df_builder(init_tsdf_id).as_tsdf() @@ -622,133 +644,135 @@ def test_after(self, init_tsdf_id, ts, expected_tsdf_dict): expected_tsdf = TestDataFrameBuilder(self.spark, expected_tsdf_dict).as_tsdf() self.assertDataFrameEquality(after_tsdf, expected_tsdf) - @parameterized.expand([ - ( - "simple_ts_idx", - "2020-09-01 00:02:10", - { - "tsdf": {"ts_col": "event_ts", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_convert": ["event_ts"], - "data": [ - ["S1", "2020-09-01 00:02:10", 361.1], - ["S1", "2020-09-01 00:19:12", 362.1], - ["S2", "2020-09-01 00:02:10", 761.10], - ["S2", "2020-09-01 00:20:42", 762.33], - ], - }, - }, - ), - ( - "simple_ts_no_series", - "2020-08-01 00:01:24", - { - "tsdf": {"ts_col": "event_ts", "series_ids": []}, - "df": { - "schema": "event_ts string, trade_pr float", - "ts_convert": ["event_ts"], - "data": [ - ["2020-08-01 00:01:24", 751.92], - ["2020-09-01 00:02:10", 361.1], - ["2020-09-01 00:19:12", 362.1], - ["2020-09-01 00:20:42", 762.33], - ], - }, - }, - ), - ( - "simple_date_idx", - "2020-08-03", - { - "tsdf": {"ts_col": "date", "series_ids": ["station"]}, - 
"df": { - "schema": "station string, date string, temp float", - "date_convert": ["date"], - "data": [ - ["LGA", "2020-08-03", 28.53], - ["LGA", "2020-08-04", 25.57], - ["YYZ", "2020-08-03", 20.62], - ["YYZ", "2020-08-04", 20.65], - ], - }, - }, - ), - ( - "ordinal_double_index", - 10.0, - { - "tsdf": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, event_ts_dbl double, trade_pr float", - "data": [ - ["S1", 10.0, 361.1], - ["S1", 24.357, 362.1], - ["S2", 10.0, 762.33], - ], - }, - }, - ), - ( - "ordinal_int_index", - 10, - { - "tsdf": {"ts_col": "order", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, order int, trade_pr float", - "data": [ - ["S1", 20, 351.32], - ["S1", 127, 361.1], - ["S1", 243, 362.1], - ["S2", 10, 761.10], - ["S2", 100, 762.33], - ], - }, - }, - ), - ( - "parsed_ts_idx", - "2020-09-01 00:02:10.000", - { - "tsdf": { - "ts_col": "event_ts", - "series_ids": ["symbol"], - "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", - }, - "tsdf_constructor": "fromStringTimestamp", - "df": { - "schema": "symbol string, event_ts string, trade_pr float", - "data": [ - ["S1", "2020-09-01 00:02:10.032", 361.1], - ["S1", "2020-09-01 00:19:12.043", 362.1], - ["S2", "2020-09-01 00:02:10.076", 761.10], - ["S2", "2020-09-01 00:20:42.087", 762.33], - ], - }, - }, - ), - ( - "parsed_date_idx", - "2020-08-03", - { - "tsdf": { - "ts_col": "date", - "series_ids": ["station"], - "ts_fmt": "yyyy-MM-dd", - }, - "tsdf_constructor": "fromStringTimestamp", - "df": { - "schema": "station string, date string, temp float", - "data": [ - ["LGA", "2020-08-03", 28.53], - ["LGA", "2020-08-04", 25.57], - ["YYZ", "2020-08-03", 20.62], - ["YYZ", "2020-08-04", 20.65], - ], - }, - }, - ), - ]) + @parameterized.expand( + [ + ( + "simple_ts_idx", + "2020-09-01 00:02:10", + { + "tsdf": {"ts_col": "event_ts", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["S1", "2020-09-01 00:02:10", 361.1], + ["S1", "2020-09-01 00:19:12", 362.1], + ["S2", "2020-09-01 00:02:10", 761.10], + ["S2", "2020-09-01 00:20:42", 762.33], + ], + }, + }, + ), + ( + "simple_ts_no_series", + "2020-08-01 00:01:24", + { + "tsdf": {"ts_col": "event_ts", "series_ids": []}, + "df": { + "schema": "event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["2020-08-01 00:01:24", 751.92], + ["2020-09-01 00:02:10", 361.1], + ["2020-09-01 00:19:12", 362.1], + ["2020-09-01 00:20:42", 762.33], + ], + }, + }, + ), + ( + "simple_date_idx", + "2020-08-03", + { + "tsdf": {"ts_col": "date", "series_ids": ["station"]}, + "df": { + "schema": "station string, date string, temp float", + "date_convert": ["date"], + "data": [ + ["LGA", "2020-08-03", 28.53], + ["LGA", "2020-08-04", 25.57], + ["YYZ", "2020-08-03", 20.62], + ["YYZ", "2020-08-04", 20.65], + ], + }, + }, + ), + ( + "ordinal_double_index", + 10.0, + { + "tsdf": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts_dbl double, trade_pr float", + "data": [ + ["S1", 10.0, 361.1], + ["S1", 24.357, 362.1], + ["S2", 10.0, 762.33], + ], + }, + }, + ), + ( + "ordinal_int_index", + 10, + { + "tsdf": {"ts_col": "order", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, order int, trade_pr float", + "data": [ + ["S1", 20, 351.32], + ["S1", 127, 361.1], + ["S1", 243, 362.1], + ["S2", 10, 761.10], + ["S2", 100, 762.33], + ], + }, + }, + ), + ( + "parsed_ts_idx", + "2020-09-01 00:02:10.000", + { + 
"tsdf": { + "ts_col": "event_ts", + "series_ids": ["symbol"], + "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "data": [ + ["S1", "2020-09-01 00:02:10.032", 361.1], + ["S1", "2020-09-01 00:19:12.043", 362.1], + ["S2", "2020-09-01 00:02:10.076", 761.10], + ["S2", "2020-09-01 00:20:42.087", 762.33], + ], + }, + }, + ), + ( + "parsed_date_idx", + "2020-08-03", + { + "tsdf": { + "ts_col": "date", + "series_ids": ["station"], + "ts_fmt": "yyyy-MM-dd", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "station string, date string, temp float", + "data": [ + ["LGA", "2020-08-03", 28.53], + ["LGA", "2020-08-04", 25.57], + ["YYZ", "2020-08-03", 20.62], + ["YYZ", "2020-08-04", 20.65], + ], + }, + }, + ), + ] + ) def test_atOrAfter(self, init_tsdf_id, ts, expected_tsdf_dict): # load TSDF tsdf = self.get_test_function_df_builder(init_tsdf_id).as_tsdf() @@ -758,126 +782,128 @@ def test_atOrAfter(self, init_tsdf_id, ts, expected_tsdf_dict): expected_tsdf = TestDataFrameBuilder(self.spark, expected_tsdf_dict).as_tsdf() self.assertDataFrameEquality(at_after_tsdf, expected_tsdf) - @parameterized.expand([ - ( - "simple_ts_idx", - "2020-08-01 00:01:10", - "2020-09-01 00:02:10", - { - "tsdf": {"ts_col": "event_ts", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_convert": ["event_ts"], - "data": [ - ["S1", "2020-08-01 00:01:12", 351.32], - ["S2", "2020-08-01 00:01:24", 751.92], - ], - }, - }, - ), - ( - "simple_ts_no_series", - "2020-08-01 00:01:10", - "2020-09-01 00:02:10", - { - "tsdf": {"ts_col": "event_ts", "series_ids": []}, - "df": { - "schema": "event_ts string, trade_pr float", - "ts_convert": ["event_ts"], - "data": [ - ["2020-08-01 00:01:12", 351.32], - ["2020-08-01 00:01:24", 751.92], - ], - }, - }, - ), - ( - "simple_date_idx", - "2020-08-01", - "2020-08-03", - { - "tsdf": {"ts_col": "date", "series_ids": ["station"]}, - "df": { - "schema": "station string, date string, temp float", - "date_convert": ["date"], - "data": [ - ["LGA", "2020-08-02", 28.79], - ["YYZ", "2020-08-02", 22.25], - ], - }, - }, - ), - ( - "ordinal_double_index", - 0.1, - 10.0, - { - "tsdf": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, event_ts_dbl double, trade_pr float", - "data": [ - ["S1", 0.13, 349.21], - ["S1", 1.207, 351.32], - ["S2", 1.0, 761.10], - ], - }, - }, - ), - ( - "ordinal_int_index", - 1, - 100, - { - "tsdf": {"ts_col": "order", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, order int, trade_pr float", - "data": [["S1", 20, 351.32], ["S2", 10, 761.10]], - }, - }, - ), - ( - "parsed_ts_idx", - "2020-08-01 00:00:10.010", - "2020-09-01 00:02:10.076", - { - "tsdf": { - "ts_col": "event_ts", - "series_ids": ["symbol"], - "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", - }, - "tsdf_constructor": "fromStringTimestamp", - "df": { - "schema": "symbol string, event_ts string, trade_pr float", - "data": [ - ["S1", "2020-08-01 00:01:12.021", 351.32], - ["S1", "2020-09-01 00:02:10.032", 361.1], - ["S2", "2020-08-01 00:01:10.054", 743.01], - ["S2", "2020-08-01 00:01:24.065", 751.92], - ], - }, - }, - ), - ( - "parsed_date_idx", - "2020-08-01", - "2020-08-03", - { - "tsdf": { - "ts_col": "date", - "series_ids": ["station"], - "ts_fmt": "yyyy-MM-dd", - }, - "tsdf_constructor": "fromStringTimestamp", - "df": { - "schema": "station string, date string, temp float", - "data": [ - 
["LGA", "2020-08-02", 28.79], - ["YYZ", "2020-08-02", 22.25], - ], - }, - }, - ), - ]) + @parameterized.expand( + [ + ( + "simple_ts_idx", + "2020-08-01 00:01:10", + "2020-09-01 00:02:10", + { + "tsdf": {"ts_col": "event_ts", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["S1", "2020-08-01 00:01:12", 351.32], + ["S2", "2020-08-01 00:01:24", 751.92], + ], + }, + }, + ), + ( + "simple_ts_no_series", + "2020-08-01 00:01:10", + "2020-09-01 00:02:10", + { + "tsdf": {"ts_col": "event_ts", "series_ids": []}, + "df": { + "schema": "event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["2020-08-01 00:01:12", 351.32], + ["2020-08-01 00:01:24", 751.92], + ], + }, + }, + ), + ( + "simple_date_idx", + "2020-08-01", + "2020-08-03", + { + "tsdf": {"ts_col": "date", "series_ids": ["station"]}, + "df": { + "schema": "station string, date string, temp float", + "date_convert": ["date"], + "data": [ + ["LGA", "2020-08-02", 28.79], + ["YYZ", "2020-08-02", 22.25], + ], + }, + }, + ), + ( + "ordinal_double_index", + 0.1, + 10.0, + { + "tsdf": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts_dbl double, trade_pr float", + "data": [ + ["S1", 0.13, 349.21], + ["S1", 1.207, 351.32], + ["S2", 1.0, 761.10], + ], + }, + }, + ), + ( + "ordinal_int_index", + 1, + 100, + { + "tsdf": {"ts_col": "order", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, order int, trade_pr float", + "data": [["S1", 20, 351.32], ["S2", 10, 761.10]], + }, + }, + ), + ( + "parsed_ts_idx", + "2020-08-01 00:00:10.010", + "2020-09-01 00:02:10.076", + { + "tsdf": { + "ts_col": "event_ts", + "series_ids": ["symbol"], + "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "data": [ + ["S1", "2020-08-01 00:01:12.021", 351.32], + ["S1", "2020-09-01 00:02:10.032", 361.1], + ["S2", "2020-08-01 00:01:10.054", 743.01], + ["S2", "2020-08-01 00:01:24.065", 751.92], + ], + }, + }, + ), + ( + "parsed_date_idx", + "2020-08-01", + "2020-08-03", + { + "tsdf": { + "ts_col": "date", + "series_ids": ["station"], + "ts_fmt": "yyyy-MM-dd", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "station string, date string, temp float", + "data": [ + ["LGA", "2020-08-02", 28.79], + ["YYZ", "2020-08-02", 22.25], + ], + }, + }, + ), + ] + ) def test_between_non_inclusive( self, init_tsdf_id, start_ts, end_ts, expected_tsdf_dict ): @@ -889,150 +915,152 @@ def test_between_non_inclusive( expected_tsdf = TestDataFrameBuilder(self.spark, expected_tsdf_dict).as_tsdf() self.assertDataFrameEquality(between_tsdf, expected_tsdf) - @parameterized.expand([ - ( - "simple_ts_idx", - "2020-08-01 00:01:10", - "2020-09-01 00:02:10", - { - "tsdf": {"ts_col": "event_ts", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_convert": ["event_ts"], - "data": [ - ["S1", "2020-08-01 00:01:12", 351.32], - ["S1", "2020-09-01 00:02:10", 361.1], - ["S2", "2020-08-01 00:01:10", 743.01], - ["S2", "2020-08-01 00:01:24", 751.92], - ["S2", "2020-09-01 00:02:10", 761.10], - ], - }, - }, - ), - ( - "simple_ts_no_series", - "2020-08-01 00:01:10", - "2020-09-01 00:02:10", - { - "tsdf": {"ts_col": "event_ts", "series_ids": []}, - "df": { - "schema": "event_ts string, trade_pr float", - "ts_convert": ["event_ts"], - "data": [ - ["2020-08-01 
00:01:10", 743.01], - ["2020-08-01 00:01:12", 351.32], - ["2020-08-01 00:01:24", 751.92], - ["2020-09-01 00:02:10", 361.1], - ], - }, - }, - ), - ( - "simple_date_idx", - "2020-08-01", - "2020-08-03", - { - "tsdf": {"ts_col": "date", "series_ids": ["station"]}, - "df": { - "schema": "station string, date string, temp float", - "date_convert": ["date"], - "data": [ - ["LGA", "2020-08-01", 27.58], - ["LGA", "2020-08-02", 28.79], - ["LGA", "2020-08-03", 28.53], - ["YYZ", "2020-08-01", 24.16], - ["YYZ", "2020-08-02", 22.25], - ["YYZ", "2020-08-03", 20.62], - ], - }, - }, - ), - ( - "ordinal_double_index", - 0.1, - 10.0, - { - "tsdf": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, event_ts_dbl double, trade_pr float", - "data": [ - ["S1", 0.13, 349.21], - ["S1", 1.207, 351.32], - ["S1", 10.0, 361.1], - ["S2", 0.1, 751.92], - ["S2", 1.0, 761.10], - ["S2", 10.0, 762.33], - ], - }, - }, - ), - ( - "ordinal_int_index", - 1, - 100, - { - "tsdf": {"ts_col": "order", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, order int, trade_pr float", - "data": [ - ["S1", 1, 349.21], - ["S1", 20, 351.32], - ["S2", 1, 751.92], - ["S2", 10, 761.10], - ["S2", 100, 762.33], - ], - }, - }, - ), - ( - "parsed_ts_idx", - "2020-08-01 00:00:10.010", - "2020-09-01 00:02:10.076", - { - "tsdf": { - "ts_col": "event_ts", - "series_ids": ["symbol"], - "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", - }, - "tsdf_constructor": "fromStringTimestamp", - "df": { - "schema": "symbol string, event_ts string, trade_pr float", - "data": [ - ["S1", "2020-08-01 00:00:10.010", 349.21], - ["S1", "2020-08-01 00:01:12.021", 351.32], - ["S1", "2020-09-01 00:02:10.032", 361.1], - ["S2", "2020-08-01 00:01:10.054", 743.01], - ["S2", "2020-08-01 00:01:24.065", 751.92], - ["S2", "2020-09-01 00:02:10.076", 761.10], - ], - }, - }, - ), - ( - "parsed_date_idx", - "2020-08-01", - "2020-08-03", - { - "tsdf": { - "ts_col": "date", - "series_ids": ["station"], - "ts_fmt": "yyyy-MM-dd", - }, - "tsdf_constructor": "fromStringTimestamp", - "df": { - "schema": "station string, date string, temp float", - "data": [ - ["LGA", "2020-08-01", 27.58], - ["LGA", "2020-08-02", 28.79], - ["LGA", "2020-08-03", 28.53], - ["YYZ", "2020-08-01", 24.16], - ["YYZ", "2020-08-02", 22.25], - ["YYZ", "2020-08-03", 20.62], - ], - }, - }, - ), - ]) + @parameterized.expand( + [ + ( + "simple_ts_idx", + "2020-08-01 00:01:10", + "2020-09-01 00:02:10", + { + "tsdf": {"ts_col": "event_ts", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["S1", "2020-08-01 00:01:12", 351.32], + ["S1", "2020-09-01 00:02:10", 361.1], + ["S2", "2020-08-01 00:01:10", 743.01], + ["S2", "2020-08-01 00:01:24", 751.92], + ["S2", "2020-09-01 00:02:10", 761.10], + ], + }, + }, + ), + ( + "simple_ts_no_series", + "2020-08-01 00:01:10", + "2020-09-01 00:02:10", + { + "tsdf": {"ts_col": "event_ts", "series_ids": []}, + "df": { + "schema": "event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["2020-08-01 00:01:10", 743.01], + ["2020-08-01 00:01:12", 351.32], + ["2020-08-01 00:01:24", 751.92], + ["2020-09-01 00:02:10", 361.1], + ], + }, + }, + ), + ( + "simple_date_idx", + "2020-08-01", + "2020-08-03", + { + "tsdf": {"ts_col": "date", "series_ids": ["station"]}, + "df": { + "schema": "station string, date string, temp float", + "date_convert": ["date"], + "data": [ + ["LGA", "2020-08-01", 27.58], + ["LGA", "2020-08-02", 28.79], + ["LGA", 
"2020-08-03", 28.53], + ["YYZ", "2020-08-01", 24.16], + ["YYZ", "2020-08-02", 22.25], + ["YYZ", "2020-08-03", 20.62], + ], + }, + }, + ), + ( + "ordinal_double_index", + 0.1, + 10.0, + { + "tsdf": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts_dbl double, trade_pr float", + "data": [ + ["S1", 0.13, 349.21], + ["S1", 1.207, 351.32], + ["S1", 10.0, 361.1], + ["S2", 0.1, 751.92], + ["S2", 1.0, 761.10], + ["S2", 10.0, 762.33], + ], + }, + }, + ), + ( + "ordinal_int_index", + 1, + 100, + { + "tsdf": {"ts_col": "order", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, order int, trade_pr float", + "data": [ + ["S1", 1, 349.21], + ["S1", 20, 351.32], + ["S2", 1, 751.92], + ["S2", 10, 761.10], + ["S2", 100, 762.33], + ], + }, + }, + ), + ( + "parsed_ts_idx", + "2020-08-01 00:00:10.010", + "2020-09-01 00:02:10.076", + { + "tsdf": { + "ts_col": "event_ts", + "series_ids": ["symbol"], + "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "data": [ + ["S1", "2020-08-01 00:00:10.010", 349.21], + ["S1", "2020-08-01 00:01:12.021", 351.32], + ["S1", "2020-09-01 00:02:10.032", 361.1], + ["S2", "2020-08-01 00:01:10.054", 743.01], + ["S2", "2020-08-01 00:01:24.065", 751.92], + ["S2", "2020-09-01 00:02:10.076", 761.10], + ], + }, + }, + ), + ( + "parsed_date_idx", + "2020-08-01", + "2020-08-03", + { + "tsdf": { + "ts_col": "date", + "series_ids": ["station"], + "ts_fmt": "yyyy-MM-dd", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "station string, date string, temp float", + "data": [ + ["LGA", "2020-08-01", 27.58], + ["LGA", "2020-08-02", 28.79], + ["LGA", "2020-08-03", 28.53], + ["YYZ", "2020-08-01", 24.16], + ["YYZ", "2020-08-02", 22.25], + ["YYZ", "2020-08-03", 20.62], + ], + }, + }, + ), + ] + ) def test_between_inclusive( self, init_tsdf_id, start_ts, end_ts, expected_tsdf_dict ): @@ -1044,131 +1072,133 @@ def test_between_inclusive( expected_tsdf = TestDataFrameBuilder(self.spark, expected_tsdf_dict).as_tsdf() self.assertDataFrameEquality(between_tsdf, expected_tsdf) - @parameterized.expand([ - ( - "simple_ts_idx", - 2, - { - "tsdf": {"ts_col": "event_ts", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_convert": ["event_ts"], - "data": [ - ["S1", "2020-08-01 00:00:10", 349.21], - ["S1", "2020-08-01 00:01:12", 351.32], - ["S2", "2020-08-01 00:01:10", 743.01], - ["S2", "2020-08-01 00:01:24", 751.92], - ], - }, - }, - ), - ( - "simple_ts_no_series", - 2, - { - "tsdf": {"ts_col": "event_ts", "series_ids": []}, - "df": { - "schema": "event_ts string, trade_pr float", - "ts_convert": ["event_ts"], - "data": [ - ["2020-08-01 00:00:10", 349.21], - ["2020-08-01 00:01:10", 743.01], - ], - }, - }, - ), - ( - "simple_date_idx", - 2, - { - "tsdf": {"ts_col": "date", "series_ids": ["station"]}, - "df": { - "schema": "station string, date string, temp float", - "date_convert": ["date"], - "data": [ - ["LGA", "2020-08-01", 27.58], - ["LGA", "2020-08-02", 28.79], - ["YYZ", "2020-08-01", 24.16], - ["YYZ", "2020-08-02", 22.25], - ], - }, - }, - ), - ( - "ordinal_double_index", - 2, - { - "tsdf": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, event_ts_dbl double, trade_pr float", - "data": [ - ["S1", 0.13, 349.21], - ["S1", 1.207, 351.32], - ["S2", 0.005, 743.01], - ["S2", 0.1, 751.92], - ], - }, - }, - ), - ( - 
"ordinal_int_index", - 2, - { - "tsdf": {"ts_col": "order", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, order int, trade_pr float", - "data": [ - ["S1", 1, 349.21], - ["S1", 20, 351.32], - ["S2", 0, 743.01], - ["S2", 1, 751.92], - ], - }, - }, - ), - ( - "parsed_ts_idx", - 2, - { - "tsdf": { - "ts_col": "event_ts", - "series_ids": ["symbol"], - "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", - }, - "tsdf_constructor": "fromStringTimestamp", - "df": { - "schema": "symbol string, event_ts string, trade_pr float", - "data": [ - ["S1", "2020-08-01 00:00:10.010", 349.21], - ["S1", "2020-08-01 00:01:12.021", 351.32], - ["S2", "2020-08-01 00:01:10.054", 743.01], - ["S2", "2020-08-01 00:01:24.065", 751.92], - ], - }, - }, - ), - ( - "parsed_date_idx", - 2, - { - "tsdf": { - "ts_col": "date", - "series_ids": ["station"], - "ts_fmt": "yyyy-MM-dd", - }, - "tsdf_constructor": "fromStringTimestamp", - "df": { - "schema": "station string, date string, temp float", - "data": [ - ["LGA", "2020-08-01", 27.58], - ["LGA", "2020-08-02", 28.79], - ["YYZ", "2020-08-01", 24.16], - ["YYZ", "2020-08-02", 22.25], - ], - }, - }, - ), - ]) + @parameterized.expand( + [ + ( + "simple_ts_idx", + 2, + { + "tsdf": {"ts_col": "event_ts", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["S1", "2020-08-01 00:00:10", 349.21], + ["S1", "2020-08-01 00:01:12", 351.32], + ["S2", "2020-08-01 00:01:10", 743.01], + ["S2", "2020-08-01 00:01:24", 751.92], + ], + }, + }, + ), + ( + "simple_ts_no_series", + 2, + { + "tsdf": {"ts_col": "event_ts", "series_ids": []}, + "df": { + "schema": "event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["2020-08-01 00:00:10", 349.21], + ["2020-08-01 00:01:10", 743.01], + ], + }, + }, + ), + ( + "simple_date_idx", + 2, + { + "tsdf": {"ts_col": "date", "series_ids": ["station"]}, + "df": { + "schema": "station string, date string, temp float", + "date_convert": ["date"], + "data": [ + ["LGA", "2020-08-01", 27.58], + ["LGA", "2020-08-02", 28.79], + ["YYZ", "2020-08-01", 24.16], + ["YYZ", "2020-08-02", 22.25], + ], + }, + }, + ), + ( + "ordinal_double_index", + 2, + { + "tsdf": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts_dbl double, trade_pr float", + "data": [ + ["S1", 0.13, 349.21], + ["S1", 1.207, 351.32], + ["S2", 0.005, 743.01], + ["S2", 0.1, 751.92], + ], + }, + }, + ), + ( + "ordinal_int_index", + 2, + { + "tsdf": {"ts_col": "order", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, order int, trade_pr float", + "data": [ + ["S1", 1, 349.21], + ["S1", 20, 351.32], + ["S2", 0, 743.01], + ["S2", 1, 751.92], + ], + }, + }, + ), + ( + "parsed_ts_idx", + 2, + { + "tsdf": { + "ts_col": "event_ts", + "series_ids": ["symbol"], + "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "data": [ + ["S1", "2020-08-01 00:00:10.010", 349.21], + ["S1", "2020-08-01 00:01:12.021", 351.32], + ["S2", "2020-08-01 00:01:10.054", 743.01], + ["S2", "2020-08-01 00:01:24.065", 751.92], + ], + }, + }, + ), + ( + "parsed_date_idx", + 2, + { + "tsdf": { + "ts_col": "date", + "series_ids": ["station"], + "ts_fmt": "yyyy-MM-dd", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "station string, date string, temp float", + "data": [ + ["LGA", "2020-08-01", 27.58], + ["LGA", "2020-08-02", 28.79], + 
["YYZ", "2020-08-01", 24.16], + ["YYZ", "2020-08-02", 22.25], + ], + }, + }, + ), + ] + ) def test_earliest(self, init_tsdf_id, num_records, expected_tsdf_dict): # load TSDF tsdf = self.get_test_function_df_builder(init_tsdf_id).as_tsdf() @@ -1178,132 +1208,134 @@ def test_earliest(self, init_tsdf_id, num_records, expected_tsdf_dict): expected_tsdf = TestDataFrameBuilder(self.spark, expected_tsdf_dict).as_tsdf() self.assertDataFrameEquality(earliest_ts, expected_tsdf) - @parameterized.expand([ - ( - "simple_ts_idx", - 2, - { - "tsdf": {"ts_col": "event_ts", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_convert": ["event_ts"], - "data": [ - ["S1", "2020-09-01 00:19:12", 362.1], - ["S1", "2020-09-01 00:02:10", 361.1], - ["S2", "2020-09-01 00:20:42", 762.33], - ["S2", "2020-09-01 00:02:10", 761.10], - ], - }, - }, - ), - ( - "simple_ts_no_series", - 4, - { - "tsdf": {"ts_col": "event_ts", "series_ids": []}, - "df": { - "schema": "event_ts string, trade_pr float", - "ts_convert": ["event_ts"], - "data": [ - ["2020-09-01 00:20:42", 762.33], - ["2020-09-01 00:19:12", 362.1], - ["2020-09-01 00:02:10", 361.1], - ["2020-08-01 00:01:24", 751.92], - ], - }, - }, - ), - ( - "simple_date_idx", - 3, - { - "tsdf": {"ts_col": "date", "series_ids": ["station"]}, - "df": { - "schema": "station string, date string, temp float", - "date_convert": ["date"], - "data": [ - ["LGA", "2020-08-04", 25.57], - ["LGA", "2020-08-03", 28.53], - ["LGA", "2020-08-02", 28.79], - ["YYZ", "2020-08-04", 20.65], - ["YYZ", "2020-08-03", 20.62], - ["YYZ", "2020-08-02", 22.25], - ], - }, - }, - ), - ( - "ordinal_double_index", - 1, - { - "tsdf": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, event_ts_dbl double, trade_pr float", - "data": [["S1", 24.357, 362.1], ["S2", 10.0, 762.33]], - }, - }, - ), - ( - "ordinal_int_index", - 3, - { - "tsdf": {"ts_col": "order", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, order int, trade_pr float", - "data": [ - ["S1", 243, 362.1], - ["S1", 127, 361.1], - ["S1", 20, 351.32], - ["S2", 100, 762.33], - ["S2", 10, 761.10], - ["S2", 1, 751.92], - ], - }, - }, - ), - ( - "parsed_ts_idx", - 3, - { - "tsdf": { - "ts_col": "event_ts", - "series_ids": ["symbol"], - "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", - }, - "tsdf_constructor": "fromStringTimestamp", - "df": { - "schema": "symbol string, event_ts string, trade_pr float", - "data": [ - ["S1", "2020-09-01 00:19:12.043", 362.1], - ["S1", "2020-09-01 00:02:10.032", 361.1], - ["S1", "2020-08-01 00:01:12.021", 351.32], - ["S2", "2020-09-01 00:20:42.087", 762.33], - ["S2", "2020-09-01 00:02:10.076", 761.10], - ["S2", "2020-08-01 00:01:24.065", 751.92], - ], - }, - }, - ), - ( - "parsed_date_idx", - 1, - { - "tsdf": { - "ts_col": "date", - "series_ids": ["station"], - "ts_fmt": "yyyy-MM-dd", - }, - "tsdf_constructor": "fromStringTimestamp", - "df": { - "schema": "station string, date string, temp float", - "data": [ - ["LGA", "2020-08-04", 25.57], - ["YYZ", "2020-08-04", 20.65], - ], - }, - }, - ), - ]) + @parameterized.expand( + [ + ( + "simple_ts_idx", + 2, + { + "tsdf": {"ts_col": "event_ts", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["S1", "2020-09-01 00:19:12", 362.1], + ["S1", "2020-09-01 00:02:10", 361.1], + ["S2", "2020-09-01 00:20:42", 762.33], + ["S2", "2020-09-01 00:02:10", 761.10], + ], + }, + }, + ), + ( + 
"simple_ts_no_series", + 4, + { + "tsdf": {"ts_col": "event_ts", "series_ids": []}, + "df": { + "schema": "event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["2020-09-01 00:20:42", 762.33], + ["2020-09-01 00:19:12", 362.1], + ["2020-09-01 00:02:10", 361.1], + ["2020-08-01 00:01:24", 751.92], + ], + }, + }, + ), + ( + "simple_date_idx", + 3, + { + "tsdf": {"ts_col": "date", "series_ids": ["station"]}, + "df": { + "schema": "station string, date string, temp float", + "date_convert": ["date"], + "data": [ + ["LGA", "2020-08-04", 25.57], + ["LGA", "2020-08-03", 28.53], + ["LGA", "2020-08-02", 28.79], + ["YYZ", "2020-08-04", 20.65], + ["YYZ", "2020-08-03", 20.62], + ["YYZ", "2020-08-02", 22.25], + ], + }, + }, + ), + ( + "ordinal_double_index", + 1, + { + "tsdf": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts_dbl double, trade_pr float", + "data": [["S1", 24.357, 362.1], ["S2", 10.0, 762.33]], + }, + }, + ), + ( + "ordinal_int_index", + 3, + { + "tsdf": {"ts_col": "order", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, order int, trade_pr float", + "data": [ + ["S1", 243, 362.1], + ["S1", 127, 361.1], + ["S1", 20, 351.32], + ["S2", 100, 762.33], + ["S2", 10, 761.10], + ["S2", 1, 751.92], + ], + }, + }, + ), + ( + "parsed_ts_idx", + 3, + { + "tsdf": { + "ts_col": "event_ts", + "series_ids": ["symbol"], + "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "data": [ + ["S1", "2020-09-01 00:19:12.043", 362.1], + ["S1", "2020-09-01 00:02:10.032", 361.1], + ["S1", "2020-08-01 00:01:12.021", 351.32], + ["S2", "2020-09-01 00:20:42.087", 762.33], + ["S2", "2020-09-01 00:02:10.076", 761.10], + ["S2", "2020-08-01 00:01:24.065", 751.92], + ], + }, + }, + ), + ( + "parsed_date_idx", + 1, + { + "tsdf": { + "ts_col": "date", + "series_ids": ["station"], + "ts_fmt": "yyyy-MM-dd", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "station string, date string, temp float", + "data": [ + ["LGA", "2020-08-04", 25.57], + ["YYZ", "2020-08-04", 20.65], + ], + }, + }, + ), + ] + ) def test_latest(self, init_tsdf_id, num_records, expected_tsdf_dict): # load TSDF tsdf = self.get_test_function_df_builder(init_tsdf_id).as_tsdf() @@ -1313,139 +1345,141 @@ def test_latest(self, init_tsdf_id, num_records, expected_tsdf_dict): expected_tsdf = TestDataFrameBuilder(self.spark, expected_tsdf_dict).as_tsdf() self.assertDataFrameEquality(latest_ts, expected_tsdf) - @parameterized.expand([ - ( - "simple_ts_idx", - "2020-09-01 00:02:10", - 2, - { - "tsdf": {"ts_col": "event_ts", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_convert": ["event_ts"], - "data": [ - ["S1", "2020-09-01 00:02:10", 361.1], - ["S1", "2020-08-01 00:01:12", 351.32], - ["S2", "2020-09-01 00:02:10", 761.10], - ["S2", "2020-08-01 00:01:24", 751.92], - ], - }, - }, - ), - ( - "simple_ts_no_series", - "2020-09-01 00:19:12", - 3, - { - "tsdf": {"ts_col": "event_ts", "series_ids": []}, - "df": { - "schema": "event_ts string, trade_pr float", - "ts_convert": ["event_ts"], - "data": [ - ["2020-09-01 00:19:12", 362.1], - ["2020-09-01 00:02:10", 361.1], - ["2020-08-01 00:01:24", 751.92], - ], - }, - }, - ), - ( - "simple_date_idx", - "2020-08-03", - 2, - { - "tsdf": {"ts_col": "date", "series_ids": ["station"]}, - "df": { - "schema": "station string, date string, temp float", - 
"date_convert": ["date"], - "data": [ - ["LGA", "2020-08-03", 28.53], - ["LGA", "2020-08-02", 28.79], - ["YYZ", "2020-08-03", 20.62], - ["YYZ", "2020-08-02", 22.25], - ], - }, - }, - ), - ( - "ordinal_double_index", - 10.0, - 4, - { - "tsdf": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, event_ts_dbl double, trade_pr float", - "data": [ - ["S1", 10.0, 361.1], - ["S1", 1.207, 351.32], - ["S1", 0.13, 349.21], - ["S2", 10.0, 762.33], - ["S2", 1.0, 761.10], - ["S2", 0.1, 751.92], - ["S2", 0.005, 743.01], - ], - }, - }, - ), - ( - "ordinal_int_index", - 1, - 1, - { - "tsdf": {"ts_col": "order", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, order int, trade_pr float", - "data": [["S1", 1, 349.21], ["S2", 1, 751.92]], - }, - }, - ), - ( - "parsed_ts_idx", - "2020-09-01 00:02:10.000", - 2, - { - "tsdf": { - "ts_col": "event_ts", - "series_ids": ["symbol"], - "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", - }, - "tsdf_constructor": "fromStringTimestamp", - "df": { - "schema": "symbol string, event_ts string, trade_pr float", - "data": [ - ["S1", "2020-08-01 00:01:12.021", 351.32], - ["S1", "2020-08-01 00:00:10.010", 349.21], - ["S2", "2020-08-01 00:01:24.065", 751.92], - ["S2", "2020-08-01 00:01:10.054", 743.01], - ], - }, - }, - ), - ( - "parsed_date_idx", - "2020-08-03", - 3, - { - "tsdf": { - "ts_col": "date", - "series_ids": ["station"], - "ts_fmt": "yyyy-MM-dd", - }, - "tsdf_constructor": "fromStringTimestamp", - "df": { - "schema": "station string, date string, temp float", - "data": [ - ["LGA", "2020-08-03", 28.53], - ["LGA", "2020-08-02", 28.79], - ["LGA", "2020-08-01", 27.58], - ["YYZ", "2020-08-03", 20.62], - ["YYZ", "2020-08-02", 22.25], - ["YYZ", "2020-08-01", 24.16], - ], - }, - }, - ), - ]) + @parameterized.expand( + [ + ( + "simple_ts_idx", + "2020-09-01 00:02:10", + 2, + { + "tsdf": {"ts_col": "event_ts", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["S1", "2020-09-01 00:02:10", 361.1], + ["S1", "2020-08-01 00:01:12", 351.32], + ["S2", "2020-09-01 00:02:10", 761.10], + ["S2", "2020-08-01 00:01:24", 751.92], + ], + }, + }, + ), + ( + "simple_ts_no_series", + "2020-09-01 00:19:12", + 3, + { + "tsdf": {"ts_col": "event_ts", "series_ids": []}, + "df": { + "schema": "event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["2020-09-01 00:19:12", 362.1], + ["2020-09-01 00:02:10", 361.1], + ["2020-08-01 00:01:24", 751.92], + ], + }, + }, + ), + ( + "simple_date_idx", + "2020-08-03", + 2, + { + "tsdf": {"ts_col": "date", "series_ids": ["station"]}, + "df": { + "schema": "station string, date string, temp float", + "date_convert": ["date"], + "data": [ + ["LGA", "2020-08-03", 28.53], + ["LGA", "2020-08-02", 28.79], + ["YYZ", "2020-08-03", 20.62], + ["YYZ", "2020-08-02", 22.25], + ], + }, + }, + ), + ( + "ordinal_double_index", + 10.0, + 4, + { + "tsdf": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts_dbl double, trade_pr float", + "data": [ + ["S1", 10.0, 361.1], + ["S1", 1.207, 351.32], + ["S1", 0.13, 349.21], + ["S2", 10.0, 762.33], + ["S2", 1.0, 761.10], + ["S2", 0.1, 751.92], + ["S2", 0.005, 743.01], + ], + }, + }, + ), + ( + "ordinal_int_index", + 1, + 1, + { + "tsdf": {"ts_col": "order", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, order int, trade_pr float", + "data": [["S1", 1, 349.21], ["S2", 1, 751.92]], + }, + }, + ), + ( + 
"parsed_ts_idx", + "2020-09-01 00:02:10.000", + 2, + { + "tsdf": { + "ts_col": "event_ts", + "series_ids": ["symbol"], + "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "data": [ + ["S1", "2020-08-01 00:01:12.021", 351.32], + ["S1", "2020-08-01 00:00:10.010", 349.21], + ["S2", "2020-08-01 00:01:24.065", 751.92], + ["S2", "2020-08-01 00:01:10.054", 743.01], + ], + }, + }, + ), + ( + "parsed_date_idx", + "2020-08-03", + 3, + { + "tsdf": { + "ts_col": "date", + "series_ids": ["station"], + "ts_fmt": "yyyy-MM-dd", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "station string, date string, temp float", + "data": [ + ["LGA", "2020-08-03", 28.53], + ["LGA", "2020-08-02", 28.79], + ["LGA", "2020-08-01", 27.58], + ["YYZ", "2020-08-03", 20.62], + ["YYZ", "2020-08-02", 22.25], + ["YYZ", "2020-08-01", 24.16], + ], + }, + }, + ), + ] + ) def test_priorTo(self, init_tsdf_id, ts, n, expected_tsdf_dict): # load TSDF tsdf = self.get_test_function_df_builder(init_tsdf_id).as_tsdf() @@ -1455,138 +1489,140 @@ def test_priorTo(self, init_tsdf_id, ts, n, expected_tsdf_dict): expected_tsdf = TestDataFrameBuilder(self.spark, expected_tsdf_dict).as_tsdf() self.assertDataFrameEquality(prior_tsdf, expected_tsdf) - @parameterized.expand([ - ( - "simple_ts_idx", - "2020-09-01 00:02:10", - 1, - { - "tsdf": {"ts_col": "event_ts", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, event_ts string, trade_pr float", - "ts_convert": ["event_ts"], - "data": [ - ["S1", "2020-09-01 00:02:10", 361.1], - ["S2", "2020-09-01 00:02:10", 761.10], - ], - }, - }, - ), - ( - "simple_ts_no_series", - "2020-08-01 00:01:24", - 3, - { - "tsdf": {"ts_col": "event_ts", "series_ids": []}, - "df": { - "schema": "event_ts string, trade_pr float", - "ts_convert": ["event_ts"], - "data": [ - ["2020-08-01 00:01:24", 751.92], - ["2020-09-01 00:02:10", 361.1], - ["2020-09-01 00:19:12", 362.1], - ], - }, - }, - ), - ( - "simple_date_idx", - "2020-08-02", - 5, - { - "tsdf": {"ts_col": "date", "series_ids": ["station"]}, - "df": { - "schema": "station string, date string, temp float", - "date_convert": ["date"], - "data": [ - ["LGA", "2020-08-02", 28.79], - ["LGA", "2020-08-03", 28.53], - ["LGA", "2020-08-04", 25.57], - ["YYZ", "2020-08-02", 22.25], - ["YYZ", "2020-08-03", 20.62], - ["YYZ", "2020-08-04", 20.65], - ], - }, - }, - ), - ( - "ordinal_double_index", - 10.0, - 2, - { - "tsdf": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, event_ts_dbl double, trade_pr float", - "data": [ - ["S1", 10.0, 361.1], - ["S1", 24.357, 362.1], - ["S2", 10.0, 762.33], - ], - }, - }, - ), - ( - "ordinal_int_index", - 10, - 2, - { - "tsdf": {"ts_col": "order", "series_ids": ["symbol"]}, - "df": { - "schema": "symbol string, order int, trade_pr float", - "data": [ - ["S1", 20, 351.32], - ["S1", 127, 361.1], - ["S2", 10, 761.10], - ["S2", 100, 762.33], - ], - }, - }, - ), - ( - "parsed_ts_idx", - "2020-09-01 00:02:10", - 3, - { - "tsdf": { - "ts_col": "event_ts", - "series_ids": ["symbol"], - "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", - }, - "tsdf_constructor": "fromStringTimestamp", - "df": { - "schema": "symbol string, event_ts string, trade_pr float", - "data": [ - ["S1", "2020-09-01 00:02:10.032", 361.1], - ["S1", "2020-09-01 00:19:12.043", 362.1], - ["S2", "2020-09-01 00:02:10.076", 761.10], - ["S2", "2020-09-01 00:20:42.087", 762.33], - ], - }, - }, - ), - ( - 
"parsed_date_idx", - "2020-08-03", - 2, - { - "tsdf": { - "ts_col": "date", - "series_ids": ["station"], - "ts_fmt": "yyyy-MM-dd", - }, - "tsdf_constructor": "fromStringTimestamp", - "df": { - "schema": "station string, date string, temp float", - "data": [ - ["LGA", "2020-08-03", 28.53], - ["LGA", "2020-08-04", 25.57], - ["YYZ", "2020-08-03", 20.62], - ["YYZ", "2020-08-04", 20.65], - ], - }, - }, - ), - ]) + @parameterized.expand( + [ + ( + "simple_ts_idx", + "2020-09-01 00:02:10", + 1, + { + "tsdf": {"ts_col": "event_ts", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["S1", "2020-09-01 00:02:10", 361.1], + ["S2", "2020-09-01 00:02:10", 761.10], + ], + }, + }, + ), + ( + "simple_ts_no_series", + "2020-08-01 00:01:24", + 3, + { + "tsdf": {"ts_col": "event_ts", "series_ids": []}, + "df": { + "schema": "event_ts string, trade_pr float", + "ts_convert": ["event_ts"], + "data": [ + ["2020-08-01 00:01:24", 751.92], + ["2020-09-01 00:02:10", 361.1], + ["2020-09-01 00:19:12", 362.1], + ], + }, + }, + ), + ( + "simple_date_idx", + "2020-08-02", + 5, + { + "tsdf": {"ts_col": "date", "series_ids": ["station"]}, + "df": { + "schema": "station string, date string, temp float", + "date_convert": ["date"], + "data": [ + ["LGA", "2020-08-02", 28.79], + ["LGA", "2020-08-03", 28.53], + ["LGA", "2020-08-04", 25.57], + ["YYZ", "2020-08-02", 22.25], + ["YYZ", "2020-08-03", 20.62], + ["YYZ", "2020-08-04", 20.65], + ], + }, + }, + ), + ( + "ordinal_double_index", + 10.0, + 2, + { + "tsdf": {"ts_col": "event_ts_dbl", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, event_ts_dbl double, trade_pr float", + "data": [ + ["S1", 10.0, 361.1], + ["S1", 24.357, 362.1], + ["S2", 10.0, 762.33], + ], + }, + }, + ), + ( + "ordinal_int_index", + 10, + 2, + { + "tsdf": {"ts_col": "order", "series_ids": ["symbol"]}, + "df": { + "schema": "symbol string, order int, trade_pr float", + "data": [ + ["S1", 20, 351.32], + ["S1", 127, 361.1], + ["S2", 10, 761.10], + ["S2", 100, 762.33], + ], + }, + }, + ), + ( + "parsed_ts_idx", + "2020-09-01 00:02:10", + 3, + { + "tsdf": { + "ts_col": "event_ts", + "series_ids": ["symbol"], + "ts_fmt": "yyyy-MM-dd HH:mm:ss.SSS", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "symbol string, event_ts string, trade_pr float", + "data": [ + ["S1", "2020-09-01 00:02:10.032", 361.1], + ["S1", "2020-09-01 00:19:12.043", 362.1], + ["S2", "2020-09-01 00:02:10.076", 761.10], + ["S2", "2020-09-01 00:20:42.087", 762.33], + ], + }, + }, + ), + ( + "parsed_date_idx", + "2020-08-03", + 2, + { + "tsdf": { + "ts_col": "date", + "series_ids": ["station"], + "ts_fmt": "yyyy-MM-dd", + }, + "tsdf_constructor": "fromStringTimestamp", + "df": { + "schema": "station string, date string, temp float", + "data": [ + ["LGA", "2020-08-03", 28.53], + ["LGA", "2020-08-04", 25.57], + ["YYZ", "2020-08-03", 20.62], + ["YYZ", "2020-08-04", 20.65], + ], + }, + }, + ), + ] + ) def test_subsequentTo(self, init_tsdf_id, ts, n, expected_tsdf_dict): # load TSDF tsdf = self.get_test_function_df_builder(init_tsdf_id).as_tsdf() diff --git a/python/tests/unit_test_data/resample_tests.json b/python/tests/unit_test_data/resample_tests.json index eb7c2d11..2c3e2702 100644 --- a/python/tests/unit_test_data/resample_tests.json +++ b/python/tests/unit_test_data/resample_tests.json @@ -64,6 +64,23 @@ "symbol" ] } + }, + "interpol_data": { + "df": { + "schema": "partition string, event_ts string, value_a double, 
value_b double", + "data": [ + ["A", "2020-01-01 00:00:00", 1.0, 10.0], + ["A", "2020-01-01 00:30:00", null, null], + ["A", "2020-01-01 01:00:00", 2.0, 20.0], + ["A", "2020-01-01 01:30:00", null, null], + ["A", "2020-01-01 02:00:00", 3.0, 30.0] + ], + "ts_convert": ["event_ts"] + }, + "tsdf": { + "ts_col": "event_ts", + "series_ids": ["partition"] + } } }, "ResampleUnitTests": { diff --git a/python/tests/utils_tests.py b/python/tests/utils_tests.py index 5212ee95..033d8769 100644 --- a/python/tests/utils_tests.py +++ b/python/tests/utils_tests.py @@ -30,10 +30,7 @@ def test_calculate_time_horizon(self): tsdf = self.get_test_function_df_builder("init").as_tsdf() with warnings.catch_warnings(record=True) as w: - calculate_time_horizon( - tsdf, - "30 seconds" - ) + calculate_time_horizon(tsdf, "30 seconds") warning_message = """ Resample Metrics Warning: Earliest Timestamp: 2020-01-01 00:00:10 @@ -192,8 +189,11 @@ def test_time_range_with_custom_column_name(self): num_intervals = 3 result = time_range( - self.spark, start, step_size=step, num_intervals=num_intervals, - ts_colname="custom_ts" + self.spark, + start, + step_size=step, + num_intervals=num_intervals, + ts_colname="custom_ts", ) self.assertEqual(result.count(), 3) @@ -212,8 +212,11 @@ def test_time_range_with_interval_ends(self): num_intervals = 3 result = time_range( - self.spark, start, step_size=step, num_intervals=num_intervals, - include_interval_ends=True + self.spark, + start, + step_size=step, + num_intervals=num_intervals, + include_interval_ends=True, ) self.assertEqual(result.count(), 3)