From 00ad0a8c1acbc722e49c9b7a9938b6ed735170d4 Mon Sep 17 00:00:00 2001 From: Andrew Moreland Date: Sat, 5 Aug 2023 15:44:29 -0700 Subject: [PATCH 1/4] Working on underwriting example --- underwriting/underwriting.py | 185 +++++++++++++++++++++++++++++++++++ 1 file changed, 185 insertions(+) create mode 100644 underwriting/underwriting.py diff --git a/underwriting/underwriting.py b/underwriting/underwriting.py new file mode 100644 index 0000000..dd4d735 --- /dev/null +++ b/underwriting/underwriting.py @@ -0,0 +1,185 @@ +import logging +from datetime import datetime, timedelta + +import duckdb +import polars as pl +import pytest +from chalk import DataFrame, Now, has_many, offline +from chalk.features import features +from chalk.sql._internal.integrations.duckdb import DuckDbSourceImpl +from polars.testing import assert_frame_equal + +# _db = DuckDbSourceImpl(database=":memory:") +# _engine = _db.get_engine() +_db = DuckDbSourceImpl(database="foo5.duck") + +_logger = logging.getLogger(__name__) + + +@features +class RApplication: + id: int + + most_recent_revenue: float + + records: "DataFrame[RQuickbooksRecord]" = has_many(lambda: RApplication.id == RQuickbooksRecord.app_id) + + +@features +class RQuickbooksRecord: + id: int + + app_id: int + + created_at: datetime + start_date: datetime + end_date: datetime + + revenue: float + current_assets: float + current_liabilities: float + + +now = datetime.now() + + +records_db = [ + RQuickbooksRecord( + id=j * 100 + i, + app_id=j, + created_at=now - timedelta(days=i * 30), + start_date=now - timedelta(days=i * 30), + end_date=now - timedelta(days=i * 30), + revenue=i, + current_assets=i, + current_liabilities=-i, + ) + for i in range(12) + for j in range(100) +] + + +@offline(tags="scalar-records") +def get_records_scalar(id: RApplication.id) -> DataFrame[RQuickbooksRecord]: + return _db.query_string( + "select * from records where app_id = :app_id", + args={ + "app_id": id, + }, + fields={ + "idi": RQuickbooksRecord.id, + "app_id": RQuickbooksRecord.app_id, + "created_at": RQuickbooksRecord.created_at, + "start_date": RQuickbooksRecord.start_date, + "end_date": RQuickbooksRecord.end_date, + "revenue": RQuickbooksRecord.revenue, + "current_assets": RQuickbooksRecord.current_assets, + "current_liabilities": RQuickbooksRecord.current_liabilities, + }, + ).all() + + +@offline(tags="batch-records") +def get_records() -> DataFrame[RQuickbooksRecord]: + return _db.query_string( + "select * from records", + fields={ + "id": RQuickbooksRecord.id, + "app_id": RQuickbooksRecord.app_id, + "created_at": RQuickbooksRecord.created_at, + "start_date": RQuickbooksRecord.start_date, + "end_date": RQuickbooksRecord.end_date, + "revenue": RQuickbooksRecord.revenue, + "current_assets": RQuickbooksRecord.current_assets, + "current_liabilities": RQuickbooksRecord.current_liabilities, + }, + ).all() + + +@offline(tags="scalar-most-recent-revenue") +def most_recent_revenue( + records: RApplication.records[RQuickbooksRecord.revenue, RQuickbooksRecord.created_at], now: Now +) -> RApplication.most_recent_revenue: + df: pl.LazyFrame = records.to_polars() + + x = ( + df.sort(by=str(RQuickbooksRecord.created_at), descending=True) + .filter(pl.col(str(RQuickbooksRecord.created_at)) <= now) + .head(1) + .select(pl.col(str(RQuickbooksRecord.revenue))) + .collect() + ) + + return x + + +@offline(tags="batch-most-recent-revenue") +def batch_most_recent_revenue( + records: DataFrame[RApplication.id, RApplication.records, Now] +) -> DataFrame[RApplication.id, RApplication.most_recent_revenue]: + df: pl.LazyFrame = records.to_polars() + + x = ( + df.sort(by=str(RQuickbooksRecord.created_at), descending=True) + .filter(pl.col(str(RQuickbooksRecord.created_at)) <= now) + .head(1) + .select(pl.col(str(RQuickbooksRecord.revenue))) + .collect() + ) + + return x + + +@pytest.fixture() +def db_fixture(): + try: + _db.get_engine().execute("select count(*) from records") + except: + pa_table = DataFrame(records_db).to_pyarrow(prefixed=False) + _db.get_engine().execute("CREATE TABLE IF NOT EXISTS records AS SELECT * FROM pa_table") + + +def test_simple(local_chalk_client, db_fixture): + ds = local_chalk_client.offline_query( + input={RQuickbooksRecord.id: [1]}, + output=[RQuickbooksRecord.revenue], + recompute_features=True, + tags=["batch-records", "scalar-most-recent-revenue"], + ) + + assert_frame_equal( + ds.get_data_as_polars().collect(), + pl.DataFrame({str(RQuickbooksRecord.id): [1], str(RQuickbooksRecord.revenue): [1.0]}), + check_column_order=False, + ) + + +def test_aggregation_time_oblivious(local_chalk_client, db_fixture): + ds = local_chalk_client.offline_query( + input={RApplication.id: [1]}, + output=[RApplication.most_recent_revenue], + recompute_features=True, + tags=["batch-records", "scalar-most-recent-revenue"], + ) + + assert_frame_equal( + ds.get_data_as_polars().collect(), + pl.DataFrame({str(RApplication.id): [1], str(RApplication.most_recent_revenue): [0.0]}), + check_column_order=False, + ) + + +def test_aggregation_time_aware(local_chalk_client, db_fixture): + ds = local_chalk_client.offline_query( + input={RApplication.id: [1, 1]}, + input_times=[now, now - timedelta(days=60)], + output=[RApplication.most_recent_revenue], + recompute_features=True, + tags=["batch-records", "scalar-most-recent-revenue"], + ) + + assert_frame_equal( + ds.get_data_as_polars().collect(), + pl.DataFrame({str(RApplication.id): [1, 1], str(RApplication.most_recent_revenue): [0.0, 2.0]}), + check_column_order=False, + ) From 05069a40270390d18ea5882bf0527b7131a0eaef Mon Sep 17 00:00:00 2001 From: Andrew Moreland Date: Sat, 5 Aug 2023 15:45:34 -0700 Subject: [PATCH 2/4] Fix typo --- underwriting/underwriting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/underwriting/underwriting.py b/underwriting/underwriting.py index dd4d735..ade229a 100644 --- a/underwriting/underwriting.py +++ b/underwriting/underwriting.py @@ -67,7 +67,7 @@ def get_records_scalar(id: RApplication.id) -> DataFrame[RQuickbooksRecord]: "app_id": id, }, fields={ - "idi": RQuickbooksRecord.id, + "id": RQuickbooksRecord.id, "app_id": RQuickbooksRecord.app_id, "created_at": RQuickbooksRecord.created_at, "start_date": RQuickbooksRecord.start_date, From 79cd7e320c201bc4f8eba624456538f17a12fb6a Mon Sep 17 00:00:00 2001 From: Andrew Moreland Date: Tue, 8 Aug 2023 00:19:23 -0700 Subject: [PATCH 3/4] Parameterize tests, and add more sample points --- underwriting/underwriting.py | 74 ++++++++++++++++++++++++++++-------- 1 file changed, 59 insertions(+), 15 deletions(-) diff --git a/underwriting/underwriting.py b/underwriting/underwriting.py index ade229a..888f714 100644 --- a/underwriting/underwriting.py +++ b/underwriting/underwriting.py @@ -9,9 +9,7 @@ from chalk.sql._internal.integrations.duckdb import DuckDbSourceImpl from polars.testing import assert_frame_equal -# _db = DuckDbSourceImpl(database=":memory:") -# _engine = _db.get_engine() -_db = DuckDbSourceImpl(database="foo5.duck") +_db = DuckDbSourceImpl(database="foo7.duck") _logger = logging.getLogger(__name__) @@ -54,8 +52,8 @@ class RQuickbooksRecord: current_assets=i, current_liabilities=-i, ) - for i in range(12) - for j in range(100) + for i in range(36) + for j in range(100000) ] @@ -98,16 +96,16 @@ def get_records() -> DataFrame[RQuickbooksRecord]: @offline(tags="scalar-most-recent-revenue") def most_recent_revenue( - records: RApplication.records[RQuickbooksRecord.revenue, RQuickbooksRecord.created_at], now: Now + records: RApplication.records[RQuickbooksRecord.revenue, RQuickbooksRecord.created_at], now: Now ) -> RApplication.most_recent_revenue: df: pl.LazyFrame = records.to_polars() x = ( df.sort(by=str(RQuickbooksRecord.created_at), descending=True) - .filter(pl.col(str(RQuickbooksRecord.created_at)) <= now) - .head(1) - .select(pl.col(str(RQuickbooksRecord.revenue))) - .collect() + .filter(pl.col(str(RQuickbooksRecord.created_at)) <= now) + .head(1) + .select(pl.col(str(RQuickbooksRecord.revenue))) + .collect() ) return x @@ -115,16 +113,16 @@ def most_recent_revenue( @offline(tags="batch-most-recent-revenue") def batch_most_recent_revenue( - records: DataFrame[RApplication.id, RApplication.records, Now] + records: DataFrame[RApplication.id, RApplication.records, Now] ) -> DataFrame[RApplication.id, RApplication.most_recent_revenue]: df: pl.LazyFrame = records.to_polars() x = ( df.sort(by=str(RQuickbooksRecord.created_at), descending=True) - .filter(pl.col(str(RQuickbooksRecord.created_at)) <= now) - .head(1) - .select(pl.col(str(RQuickbooksRecord.revenue))) - .collect() + .filter(pl.col(str(RQuickbooksRecord.created_at)) <= now) + .head(1) + .select(pl.col(str(RQuickbooksRecord.revenue))) + .collect() ) return x @@ -183,3 +181,49 @@ def test_aggregation_time_aware(local_chalk_client, db_fixture): pl.DataFrame({str(RApplication.id): [1, 1], str(RApplication.most_recent_revenue): [0.0, 2.0]}), check_column_order=False, ) + + +def test_aggregation_time_aware_batch(local_chalk_client, db_fixture): + ds = local_chalk_client.offline_query( + input={RApplication.id: [1, 1]}, + input_times=[now, now - timedelta(days=60)], + output=[RApplication.most_recent_revenue], + recompute_features=True, + tags=["batch-records", "batch-most-recent-revenue"], + ) + + assert_frame_equal( + ds.get_data_as_polars().collect(), + pl.DataFrame({str(RApplication.id): [1, 1], str(RApplication.most_recent_revenue): [0.0, 2.0]}), + check_column_order=False, + ) + + +records_tags = ["batch-records", "scalar-records"] +revenue_tags = ["scalar-most-recent-revenue", "batch-most-recent-revenue"] + +all_ids = range(10000) + + +params = [ + ([record_tag, revenue_tag], size) + for record_tag in records_tags + for revenue_tag in revenue_tags + for size in [1, 10, 100, 1000, 10000] +] + + +@pytest.mark.parametrize("tags,size", params) +def test_benchmark(tags, size, local_chalk_client, db_fixture): + + ids = all_ids[:size] + + ds = local_chalk_client.offline_query( + input={RApplication.id: ids}, + input_times=[now] * size, + output=[RApplication.most_recent_revenue], + recompute_features=True, + tags=tags, + ) + + print(ds.get_data_as_polars().collect()) From 354e396cd48f452b2b788571db03c52101769000 Mon Sep 17 00:00:00 2001 From: Andrew Moreland Date: Tue, 8 Aug 2023 19:51:55 -0700 Subject: [PATCH 4/4] Update with fixes for errors in resolver definitions --- underwriting/underwriting.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/underwriting/underwriting.py b/underwriting/underwriting.py index 888f714..2d84379 100644 --- a/underwriting/underwriting.py +++ b/underwriting/underwriting.py @@ -5,11 +5,13 @@ import polars as pl import pytest from chalk import DataFrame, Now, has_many, offline -from chalk.features import features +from chalk.features import Features, features from chalk.sql._internal.integrations.duckdb import DuckDbSourceImpl from polars.testing import assert_frame_equal -_db = DuckDbSourceImpl(database="foo7.duck") +# _db = DuckDbSourceImpl(database=":memory:") +# _engine = _db.get_engine() +_db = DuckDbSourceImpl(database="foo8.duck") _logger = logging.getLogger(__name__) @@ -53,7 +55,7 @@ class RQuickbooksRecord: current_liabilities=-i, ) for i in range(36) - for j in range(100000) + for j in range(100) ] @@ -95,11 +97,14 @@ def get_records() -> DataFrame[RQuickbooksRecord]: @offline(tags="scalar-most-recent-revenue") -def most_recent_revenue( +def scalar_most_recent_revenue( records: RApplication.records[RQuickbooksRecord.revenue, RQuickbooksRecord.created_at], now: Now -) -> RApplication.most_recent_revenue: +) -> Features[RApplication.most_recent_revenue]: df: pl.LazyFrame = records.to_polars() + if len(records) == 0: + return 0 + x = ( df.sort(by=str(RQuickbooksRecord.created_at), descending=True) .filter(pl.col(str(RQuickbooksRecord.created_at)) <= now) @@ -117,15 +122,18 @@ def batch_most_recent_revenue( ) -> DataFrame[RApplication.id, RApplication.most_recent_revenue]: df: pl.LazyFrame = records.to_polars() - x = ( - df.sort(by=str(RQuickbooksRecord.created_at), descending=True) - .filter(pl.col(str(RQuickbooksRecord.created_at)) <= now) - .head(1) - .select(pl.col(str(RQuickbooksRecord.revenue))) - .collect() + exploded = df.explode(pl.col(str(RApplication.records))).unnest(str(RApplication.records)) + + time_consistent = exploded.filter(pl.col(str(RQuickbooksRecord.created_at)) <= pl.col(str(Now))) + + most_recent_rev_output = ( + time_consistent.sort(by=str(RQuickbooksRecord.created_at), descending=True) + .groupby(str(RApplication.id), str(Now), maintain_order=True) + .agg(pl.first(str(RQuickbooksRecord.revenue)).alias(str(RApplication.most_recent_revenue))) + .select(str(RApplication.id), str(RApplication.most_recent_revenue)) ) - return x + return most_recent_rev_output @pytest.fixture()