diff --git a/README.md b/README.md index 72ff3e6..2bbf2ed 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,12 @@ K8s StatefulSet (N replicas) - Static partition assignment via pod ordinal — no consumer groups - If a pod dies, its partitions stop being consumed until K8s restarts it +## Record Filtering + +Millpond can optionally filter records by a field value, keeping only records where a specified column matches a given string. Set `FILTER_FIELD` and `FILTER_VALUE` (both required together). If the configured column is missing from a batch's schema, that batch is kept unchanged and a warning is logged. + +Filtered records are tracked via the `millpond_records_skipped_total{reason="filter"}` metric. + ## Performance The hot path is all C/C++: librdkafka → orjson → PyArrow → DuckDB (zero-copy Arrow scan). Python is glue. @@ -135,6 +141,8 @@ All configuration via environment variables: | `FETCH_MAX_WAIT_MS` | no | `500` | Max broker wait when `fetch.min.bytes` not yet satisfied | | `STATS_INTERVAL_MS` | no | `5000` | librdkafka internal stats emission interval (0 to disable) | | `LOG_LEVEL` | no | `INFO` | Python log level (DEBUG, INFO, WARNING, ERROR) | +| `FILTER_FIELD` | no | | Column name to filter on. Must be set with `FILTER_VALUE`. | +| `FILTER_VALUE` | no | | Value to match in `FILTER_FIELD`. Only records where `FILTER_FIELD == FILTER_VALUE` (string comparison) are written. All others are discarded after parsing. | ## Releases diff --git a/millpond/config.py b/millpond/config.py index c09a6a1..656ede3 100644 --- a/millpond/config.py +++ b/millpond/config.py @@ -47,6 +47,10 @@ class Config: consume_batch_size: int stats_interval_ms: int + # Record filter (optional) — only keep records where filter_field == filter_value + filter_field: str | None + filter_value: str | None + # Extra librdkafka config (from KAFKA_CONSUMER_* env vars) kafka_config_overrides: tuple[tuple[str, str], ...] 
@@ -123,9 +127,14 @@ def load() -> Config: fetch_max_wait_ms=int(os.environ.get("FETCH_MAX_WAIT_MS", "500")), consume_batch_size=int(os.environ.get("CONSUME_BATCH_SIZE", "1000")), stats_interval_ms=int(os.environ.get("STATS_INTERVAL_MS", "5000")), + filter_field=os.environ.get("FILTER_FIELD", "").strip() or None, + filter_value=os.environ.get("FILTER_VALUE", "").strip() or None, kafka_config_overrides=kafka_overrides, ) + if bool(cfg.filter_field) != bool(cfg.filter_value): + raise RuntimeError("FILTER_FIELD and FILTER_VALUE must both be set or both be unset") + log.info( "Config: topic=%s table=%s ordinal=%d/%d group_id=%s", topic, @@ -134,4 +143,6 @@ def load() -> Config: replica_count, cfg.group_id, ) + if cfg.filter_field: + log.info("Filter: %s=%s", cfg.filter_field, cfg.filter_value) return cfg diff --git a/millpond/main.py b/millpond/main.py index 8a633fb..1dc2934 100644 --- a/millpond/main.py +++ b/millpond/main.py @@ -4,6 +4,7 @@ import time import pyarrow as pa +import pyarrow.compute as pc from confluent_kafka import TopicPartition from millpond import arrow_converter, config, consumer, ducklake, logging_config, metrics, schema, server @@ -27,6 +28,21 @@ def _convert_batch(values: list[bytes]) -> pa.Table | None: return table +def _apply_filter(table: pa.Table, cfg: config.Config) -> pa.Table: + """Filter table by field/value if configured. 
Returns the (possibly smaller) table.""" + if cfg.filter_field is None: + return table + if cfg.filter_field not in table.column_names: + log.warning("Filter field %r not in table schema, keeping all records", cfg.filter_field) + return table + column = pc.cast(table[cfg.filter_field], pa.string()) + filtered = table.filter(pc.equal(column, cfg.filter_value)) + filtered_out = len(table) - len(filtered) + if filtered_out > 0: + metrics.records_skipped_total.labels(reason="filter").inc(filtered_out) + return filtered + + def _write_with_retry(db, table_name, consolidated, schema_mgr, partition_by=None): """Write to DuckLake with exponential backoff on transient failures.""" for attempt in range(_WRITE_MAX_RETRIES): @@ -53,13 +69,8 @@ def _write_with_retry(db, table_name, consolidated, schema_mgr, partition_by=Non time.sleep(delay) -def _flush(db, cfg, kafka, consolidated, pending_bytes, pending_records, offsets, elapsed, schema_mgr, trigger="time"): - """Write to DuckLake, commit offsets, update metrics.""" - t0 = time.monotonic() - _write_with_retry(db, cfg.ducklake_table, consolidated, schema_mgr, cfg.partition_by) - write_duration = time.monotonic() - t0 - - # Commit offsets synchronously — at-least-once requires knowing commit succeeded +def _commit_offsets(kafka, offsets): + """Commit Kafka offsets with retries. 
Returns the TopicPartition list committed.""" tp_offsets = [ TopicPartition(topic, partition, offset + 1) # +1: committed offset is next-to-fetch for (topic, partition), offset in offsets.items() @@ -67,7 +78,7 @@ for attempt in range(_COMMIT_MAX_RETRIES): try: kafka.commit(offsets=tp_offsets, asynchronous=False) - break + return tp_offsets except Exception: metrics.errors_total.labels(type="offset_commit").inc() if attempt == _COMMIT_MAX_RETRIES - 1: @@ -87,6 +98,15 @@ ) time.sleep(delay) + +def _flush(db, cfg, kafka, consolidated, pending_bytes, pending_records, offsets, elapsed, schema_mgr, trigger="time"): + """Write to DuckLake, commit offsets, update metrics.""" + t0 = time.monotonic() + _write_with_retry(db, cfg.ducklake_table, consolidated, schema_mgr, cfg.partition_by) + write_duration = time.monotonic() - t0 + + tp_offsets = _commit_offsets(kafka, offsets) + log.info( "Flush: %d records, %d bytes, %d columns, write=%.2fs, elapsed=%.1fs", len(consolidated), @@ -178,9 +198,11 @@ def on_signal(signum, _frame): table = _convert_batch(values) if table is not None: skipped = len(values) - len(table) - pending.append(table) - pending_bytes += table.nbytes - pending_records += len(table) + table = _apply_filter(table, cfg) + if len(table) > 0: + pending.append(table) + pending_bytes += table.nbytes + pending_records += len(table) metrics.pending_bytes.set(pending_bytes) else: skipped = len(values) @@ -226,6 +248,13 @@ def on_signal(signum, _frame): metrics.pending_bytes.set(0) last_flush = time.monotonic() + elif time_triggered and offsets and not pending: + # No pending records but offsets advanced (e.g. all records filtered out). + # Commit offsets so restarts don't replay already-processed data. 
+ _commit_offsets(kafka, offsets) + offsets.clear() + last_flush = time.monotonic() + except Exception: log.exception("Fatal error in main loop") raise diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 5659c8b..3fcdaa9 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -129,3 +129,25 @@ def test_kafka_consumer_overrides(self, monkeypatch): def test_kafka_consumer_overrides_default_empty(self): cfg = load() assert cfg.kafka_config_overrides == () + + def test_filter_default_none(self): + cfg = load() + assert cfg.filter_field is None + assert cfg.filter_value is None + + def test_filter_both_set(self, monkeypatch): + monkeypatch.setenv("FILTER_FIELD", "team_id") + monkeypatch.setenv("FILTER_VALUE", "42") + cfg = load() + assert cfg.filter_field == "team_id" + assert cfg.filter_value == "42" + + def test_filter_field_without_value_rejected(self, monkeypatch): + monkeypatch.setenv("FILTER_FIELD", "team_id") + with pytest.raises(RuntimeError, match="FILTER_FIELD and FILTER_VALUE must both be set"): + load() + + def test_filter_value_without_field_rejected(self, monkeypatch): + monkeypatch.setenv("FILTER_VALUE", "42") + with pytest.raises(RuntimeError, match="FILTER_FIELD and FILTER_VALUE must both be set"): + load() diff --git a/tests/unit/test_consumer.py b/tests/unit/test_consumer.py index ee475e9..ba7ee00 100644 --- a/tests/unit/test_consumer.py +++ b/tests/unit/test_consumer.py @@ -138,6 +138,8 @@ def _make_cfg(**overrides) -> Config: flush_size=100, flush_interval_ms=1000, partition_by=None, + filter_field=None, + filter_value=None, fetch_min_bytes=1, fetch_max_wait_ms=500, consume_batch_size=1000, diff --git a/tests/unit/test_main.py b/tests/unit/test_main.py index f2682b7..5f82847 100644 --- a/tests/unit/test_main.py +++ b/tests/unit/test_main.py @@ -4,7 +4,8 @@ import pytest from confluent_kafka import KafkaException -from millpond.main import _convert_batch, _flush, _write_with_retry +from millpond.main import 
_apply_filter, _convert_batch, _flush, _write_with_retry +from millpond.config import Config class TestWriteWithRetry: @@ -144,6 +145,65 @@ def test_successful_flush_records_write_metrics(self, mock_dl, mock_metrics, moc mock_metrics.batches_flushed_total.labels.return_value.inc.assert_called_once() +def _make_filter_cfg(filter_field=None, filter_value=None): + return MagicMock(spec=Config, filter_field=filter_field, filter_value=filter_value) + + +class TestApplyFilter: + def test_no_filter_returns_table_unchanged(self): + table = pa.table({"team_id": [42, 99], "event": ["click", "view"]}) + cfg = _make_filter_cfg() + result = _apply_filter(table, cfg) + assert len(result) == 2 + + def test_filters_by_string_field(self): + table = pa.table({"region": ["us", "eu", "us"], "event": ["a", "b", "c"]}) + cfg = _make_filter_cfg(filter_field="region", filter_value="us") + result = _apply_filter(table, cfg) + assert len(result) == 2 + assert result.column("event").to_pylist() == ["a", "c"] + + def test_filters_by_numeric_field(self): + table = pa.table({"team_id": [42, 99, 42], "event": ["a", "b", "c"]}) + cfg = _make_filter_cfg(filter_field="team_id", filter_value="42") + result = _apply_filter(table, cfg) + assert len(result) == 2 + assert result.column("event").to_pylist() == ["a", "c"] + + def test_filter_removes_all(self): + table = pa.table({"team_id": [99, 100], "event": ["a", "b"]}) + cfg = _make_filter_cfg(filter_field="team_id", filter_value="42") + result = _apply_filter(table, cfg) + assert len(result) == 0 + + def test_filter_keeps_all(self): + table = pa.table({"team_id": [42, 42], "event": ["a", "b"]}) + cfg = _make_filter_cfg(filter_field="team_id", filter_value="42") + result = _apply_filter(table, cfg) + assert len(result) == 2 + + def test_missing_field_keeps_all(self): + table = pa.table({"event": ["a", "b"]}) + cfg = _make_filter_cfg(filter_field="team_id", filter_value="42") + result = _apply_filter(table, cfg) + assert len(result) == 2 + + 
@patch("millpond.main.metrics") + def test_filter_increments_skip_metric(self, mock_metrics): + table = pa.table({"team_id": [42, 99, 100], "event": ["a", "b", "c"]}) + cfg = _make_filter_cfg(filter_field="team_id", filter_value="42") + _apply_filter(table, cfg) + mock_metrics.records_skipped_total.labels.assert_called_with(reason="filter") + mock_metrics.records_skipped_total.labels.return_value.inc.assert_called_once_with(2) + + @patch("millpond.main.metrics") + def test_no_metric_when_nothing_filtered(self, mock_metrics): + table = pa.table({"team_id": [42], "event": ["a"]}) + cfg = _make_filter_cfg(filter_field="team_id", filter_value="42") + _apply_filter(table, cfg) + mock_metrics.records_skipped_total.labels.assert_not_called() + + class TestArrowConversionTiming: """Arrow conversion time should be tracked via a histogram metric."""