Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/tempo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from tempo.resample_result import ResampledTSDF # noqa: F401
from tempo.tsdf import TSDF # noqa: F401
from tempo.utils import display # noqa: F401
14 changes: 9 additions & 5 deletions python/tempo/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import warnings
from typing import (
TYPE_CHECKING,
Any,
Callable,
List,
Expand All @@ -10,6 +11,9 @@
Union,
)

if TYPE_CHECKING:
from tempo.resample_result import ResampledTSDF

import pyspark.sql.functions as sfn
from pyspark.sql import DataFrame

Expand Down Expand Up @@ -419,7 +423,7 @@ def resample(
prefix: Optional[str] = None,
fill: Optional[bool] = None,
perform_checks: bool = True,
) -> t_tsdf.TSDF:
) -> ResampledTSDF:
"""
function to upsample based on frequency and aggregate function similar to pandas

Expand All @@ -446,13 +450,13 @@ def resample(

enriched_df: DataFrame = aggregate(tsdf, freq, func, metricCols, prefix, fill)

# Import TSDF here to avoid circular import
# Import TSDF and ResampledTSDF here to avoid circular import
from tempo.resample_result import ResampledTSDF
from tempo.tsdf import TSDF

return TSDF(
plain_tsdf = TSDF(
enriched_df,
ts_col=tsdf.ts_col,
series_ids=tsdf.series_ids,
resample_freq=freq,
resample_func=func,
)
return ResampledTSDF(plain_tsdf, resample_freq=freq, resample_func=func)
144 changes: 144 additions & 0 deletions python/tempo/resample_result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Callable, List, Optional, Union

from pyspark.sql import DataFrame

from tempo.tsschema import TSSchema

if TYPE_CHECKING:
from tempo.tsdf import TSDF


class ResampledTSDF:
    """
    Restricted wrapper around a TSDF that has been resampled.

    Like Spark's GroupedData, this object only exposes operations that are
    valid after a resample: interpolate() and as_tsdf(), plus read-only
    accessors, show() and repr(). Arbitrary DataFrame transformations
    (filter, withColumn, etc.) are intentionally not exposed, to prevent
    silent invalidation of resample context.
    """

    def __init__(
        self,
        tsdf: TSDF,
        resample_freq: str,
        resample_func: Union[Callable, str],
    ) -> None:
        """
        :param tsdf: the TSDF holding the already-resampled data
        :param resample_freq: frequency that was used for the resample
        :param resample_func: aggregate function (callable or name) that was used
        """
        self._tsdf = tsdf
        self._resample_freq = resample_freq
        self._resample_func = resample_func

    # ------------------------------------------------------------------
    # Read-only properties
    # ------------------------------------------------------------------

    @property
    def df(self) -> DataFrame:
        """The underlying Spark DataFrame."""
        return self._tsdf.df

    @property
    def ts_col(self) -> str:
        """The ts_col of the wrapped TSDF."""
        return self._tsdf.ts_col

    @property
    def series_ids(self) -> list:
        """The series_ids of the wrapped TSDF."""
        return self._tsdf.series_ids

    @property
    def ts_schema(self) -> TSSchema:
        """The ts_schema of the wrapped TSDF."""
        return self._tsdf.ts_schema

    @property
    def columns(self) -> list:
        """The columns of the wrapped TSDF."""
        return self._tsdf.columns

    @property
    def resample_freq(self) -> str:
        """The frequency this object was resampled with."""
        return self._resample_freq

    @property
    def resample_func(self) -> Union[Callable, str]:
        """The aggregate function this object was resampled with."""
        return self._resample_func

    # ------------------------------------------------------------------
    # Valid post-resample operations
    # ------------------------------------------------------------------

    def interpolate(
        self,
        method: str,
        target_cols: Optional[List[str]] = None,
        show_interpolated: bool = False,
    ) -> TSDF:
        """
        Interpolate missing values produced by the resample step.

        :param method: interpolation method — "linear", "zero", "null", "bfill", or "ffill";
            any other value is handed to tempo.interpol.interpolate unchanged
        :param target_cols: columns to interpolate (default: all non-key columns)
        :param show_interpolated: not implemented yet; passing True only logs a warning
        :return: a plain TSDF with interpolated data; for "null" the wrapped
            TSDF itself is returned untouched
        """
        import logging

        tsdf = self._tsdf

        if show_interpolated:
            # Warn before any early return so the unsupported flag is never
            # dropped silently (previously the "null" path skipped this).
            logging.getLogger(__name__).warning(
                "show_interpolated=True is not yet implemented in the refactored version"
            )

        # "null" means "leave the gaps as they are": nothing to compute, so
        # return before pulling in the interpolation machinery.
        if method == "null":
            return tsdf

        # Function-scope import, as in the original code
        # NOTE(review): presumably kept local to avoid an import cycle - confirm
        from tempo.interpol import backward_fill, forward_fill, zero_fill
        from tempo.interpol import interpolate as interpol_func

        # Resolve target columns: everything except the series ids and the
        # timestamp column. Unpacking accepts any collection of series ids,
        # not only a list.
        if target_cols is None:
            key_cols: List[str] = [*tsdf.series_ids, tsdf.ts_col]
            target_cols = [col for col in tsdf.df.columns if col not in key_cols]

        # Map method name to interpolation function; names without a dedicated
        # fill function (e.g. "linear") are passed through as-is.
        named_fills = {
            "zero": zero_fill,
            "bfill": backward_fill,
            "ffill": forward_fill,
        }
        fn: Union[str, Callable] = named_fills.get(method, method)

        return interpol_func(
            tsdf=tsdf,
            cols=target_cols,
            fn=fn,
            leading_margin=2,
            lagging_margin=2,
        )

    def as_tsdf(self) -> TSDF:
        """Return the underlying TSDF without resample metadata (explicit escape hatch)."""
        return self._tsdf

    def show(self, n: int = 20, truncate: bool = True) -> None:
        """
        Print rows of the underlying Spark DataFrame via DataFrame.show().

        :param n: forwarded to DataFrame.show as the row count
        :param truncate: forwarded to DataFrame.show as the truncate flag
        """
        self._tsdf.df.show(n, truncate)

    def __repr__(self) -> str:
        return (
            f"ResampledTSDF(freq={self._resample_freq!r}, "
            f"func={self._resample_func!r}, "
            f"ts_col={self.ts_col!r}, "
            f"series_ids={self.series_ids!r})"
        )
8 changes: 4 additions & 4 deletions python/tempo/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,16 +294,16 @@ def calc_bars(
# - min/max compute column-wise (may mix values from different rows)
resample_open = tsdf.resample(
freq=freq, func="floor", metricCols=metric_cols, prefix="open", fill=fill
)
).as_tsdf()
resample_low = tsdf.resample(
freq=freq, func="min", metricCols=metric_cols, prefix="low", fill=fill
)
).as_tsdf()
resample_high = tsdf.resample(
freq=freq, func="max", metricCols=metric_cols, prefix="high", fill=fill
)
).as_tsdf()
resample_close = tsdf.resample(
freq=freq, func="ceil", metricCols=metric_cols, prefix="close", fill=fill
)
).as_tsdf()

join_cols = resample_open.series_ids + [resample_open.ts_col]
bars = (
Expand Down
81 changes: 18 additions & 63 deletions python/tempo/tsdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
is_time_format,
sub_seconds_precision_digits,
)
from tempo.resample_result import ResampledTSDF
from tempo.typing import ColumnOrName, PandasGroupedMapFunction, PandasMapIterFunction

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -124,8 +125,6 @@ def __init__(
ts_schema: Optional[TSSchema] = None,
ts_col: Optional[str] = None,
series_ids: Optional[Collection[str]] = None,
resample_freq: Optional[str] = None,
resample_func: Optional[Union[Callable, str]] = None,
) -> None:
self.df = df
# construct schema if we don't already have one
Expand All @@ -137,10 +136,6 @@ def __init__(
# validate that this schema works for this DataFrame
self.ts_schema.validate(df.schema)

# Optional resample metadata (used when this TSDF is created from resample())
self.resample_freq = resample_freq
self.resample_func = resample_func

def __repr__(self) -> str:
return f"{self.__class__.__name__}(df={self.df}, ts_schema={self.ts_schema})"

Expand All @@ -160,8 +155,6 @@ def __withTransformedDF(self, new_df: DataFrame) -> TSDF:
return TSDF(
new_df,
ts_schema=copy.deepcopy(self.ts_schema),
resample_freq=self.resample_freq,
resample_func=self.resample_func,
)

def __withStandardizedColOrder(self) -> TSDF:
Expand Down Expand Up @@ -1466,7 +1459,7 @@ def resample(
prefix: Optional[str] = None,
fill: Optional[bool] = None,
perform_checks: bool = True,
) -> TSDF:
) -> ResampledTSDF:
"""
function to upsample based on frequency and aggregate function similar to pandas
:param freq: frequency for upsample - valid inputs are "hr", "min", "sec" corresponding to hour, minute, or second
Expand All @@ -1475,7 +1468,7 @@ def resample(
:param prefix - supply a prefix for the newly sampled columns
:param fill - Boolean - set to True if the desired output should contain filled in gaps (with 0s currently)
:param perform_checks: calculate time horizon and warnings if True (default is True)
:return: TSDF object with sample data using aggregate function
:return: ResampledTSDF object with sample data using aggregate function
"""
t_resample_utils.validateFuncExists(func)

Expand All @@ -1486,12 +1479,11 @@ def resample(
enriched_df: DataFrame = t_resample.aggregate(
self, freq, func, metricCols, prefix, fill
)
return TSDF(
plain_tsdf = TSDF(
enriched_df,
ts_schema=copy.deepcopy(self.ts_schema),
resample_freq=freq,
resample_func=func,
)
return ResampledTSDF(plain_tsdf, resample_freq=freq, resample_func=func)

def interpolate(
self,
Expand All @@ -1518,70 +1510,33 @@ def interpolate(
:param perform_checks: calculate time horizon and warnings if True (default is True)
:return: new TSDF object containing interpolated data
"""

# Set defaults for target columns, timestamp column and partition columns when not provided
if freq is None:
raise ValueError("freq must be provided")
if func is None:
raise ValueError("func must be provided")

# Resolve target columns using the same defaults as before
if ts_col is None:
ts_col = self.ts_col
if partition_cols is None:
partition_cols = self.series_ids
if target_cols is None:
prohibited_cols: List[str] = partition_cols + [ts_col]
# Don't filter by data type - allow all columns for PR-421 compatibility
target_cols = [col for col in self.df.columns if col not in prohibited_cols]

# First resample the data
# Don't fill with zeros - let interpolation handle the nulls
resampled_tsdf = self.resample(
# Delegate through resample().interpolate()
return self.resample(
freq=freq,
func=func,
metricCols=target_cols,
fill=False, # Don't fill - interpolation will handle nulls
fill=False,
perform_checks=perform_checks,
).interpolate(
method=method,
target_cols=target_cols,
show_interpolated=show_interpolated,
)

# Import interpolation function and pre-defined fill functions
from tempo.interpol import backward_fill, forward_fill, zero_fill
from tempo.interpol import interpolate as interpol_func

# Map method names to interpolation functions (no lambdas)
fn: Union[str, Callable[[pd.Series], pd.Series]]
if method == "linear":
fn = "linear" # String method for pandas interpolation
elif method == "null":
# For null method, we don't fill - just return the resampled data
return resampled_tsdf
elif method == "zero":
fn = zero_fill
elif method == "bfill":
fn = backward_fill
elif method == "ffill":
fn = forward_fill
else:
# Assume it's a valid pandas interpolation method string
fn = method

# Apply interpolation to the resampled data
interpolated_tsdf = interpol_func(
tsdf=resampled_tsdf,
cols=target_cols,
fn=fn,
leading_margin=2,
lagging_margin=2,
)

if show_interpolated:
# Add a column indicating which rows were interpolated
# This would require tracking which rows had nulls before interpolation
logger.warning(
"show_interpolated=True is not yet implemented in the refactored version"
)

return interpolated_tsdf

def calc_bars(
tsdf,
freq: str,
Expand All @@ -1590,16 +1545,16 @@ def calc_bars(
) -> TSDF:
resample_open = tsdf.resample(
freq=freq, func="floor", metricCols=metricCols, prefix="open", fill=fill
)
).as_tsdf()
resample_low = tsdf.resample(
freq=freq, func="min", metricCols=metricCols, prefix="low", fill=fill
)
).as_tsdf()
resample_high = tsdf.resample(
freq=freq, func="max", metricCols=metricCols, prefix="high", fill=fill
)
).as_tsdf()
resample_close = tsdf.resample(
freq=freq, func="ceil", metricCols=metricCols, prefix="close", fill=fill
)
).as_tsdf()

join_cols = resample_open.series_ids + [resample_open.ts_col]
bars = (
Expand Down
Loading