Skip to content
24 changes: 12 additions & 12 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,17 @@
implementation (`_run_ipf_numpy`) for iterative proportional fitting,
resulting in significant performance improvements and eliminating external
dependency ([#135](https://github.com/facebookresearch/balance/pull/135)).
- **IPW method enhancements**
- Added `logistic_regression_kwargs` parameter to `ipw()` for customizing
sklearn LogisticRegression settings
([#138](https://github.com/facebookresearch/balance/pull/138)).
- CLI now supports `--ipw_logistic_regression_kwargs` for passing custom
> Reviewer note: As written above — I think keeping this in the CLI is a good idea.
LogisticRegression parameters via JSON
([#138](https://github.com/facebookresearch/balance/pull/138)).
- **Propensity modeling flexibility**
- `ipw()` now accepts any sklearn classifier via the new `sklearn_model`
argument, enabling the use of models like random forests while preserving
all existing trimming and diagnostic workflows. Dense-only estimators and
models without linear coefficients are fully supported, and propensity
probabilities are stabilized to avoid numerical issues.
- `ipw()` now accepts any sklearn classifier via the `model` argument and
deprecates the old `sklearn_model` alias, enabling the use of models like
random forests while preserving all existing trimming and diagnostic
workflows. Dense-only estimators and models without linear coefficients are
fully supported, and propensity probabilities are stabilized to avoid
numerical issues.
- Implemented logistic regression customization by passing a configured
:class:`~sklearn.linear_model.LogisticRegression` instance through the
`model` argument; the CLI now accepts `--ipw_logistic_regression_kwargs`
JSON to build that estimator directly for command-line workflows.
- **Covariate diagnostics**
- Added KL divergence calculations for covariate comparisons (numeric and
one-hot categorical), exposed via `BalanceDF.kld()` alongside linked-sample
Expand All @@ -40,6 +38,8 @@
- Added project badges to README for build status, Python version support, and
release tracking
([#145](https://github.com/facebookresearch/balance/pull/145)).
- Added IPW quickstart tutorial showcasing default logistic regression and
  custom sklearn classifier usage (`balance_quickstart.ipynb`).

## Code Quality & Refactoring

Expand Down
61 changes: 31 additions & 30 deletions balance/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

from balance import __version__ # @manual
from balance.sample_class import Sample as balance_sample_cls # @manual
from sklearn.base import ClassifierMixin
from sklearn.linear_model import LogisticRegression

logger: logging.Logger = logging.getLogger(__package__)

Expand All @@ -40,7 +42,6 @@ def __init__(self, args: Namespace) -> None:
self._lambda_max: float | None = None
self._num_lambdas: int | None = None
self._weight_trimming_mean_ratio: float = 20.0
self._logistic_regression_kwargs: Dict[str, Any] | None = None
self._sample_cls: Type[balance_sample_cls] = balance_sample_cls
self._sample_package_name: str = __package__
self._sample_package_version: str = __version__
Expand Down Expand Up @@ -156,6 +157,12 @@ def logistic_regression_kwargs(self) -> Dict[str, Any] | None:
)
return parsed

def logistic_regression_model(self) -> ClassifierMixin | None:
    """Build a ``LogisticRegression`` estimator from the CLI-supplied kwargs.

    Returns:
        A configured :class:`sklearn.linear_model.LogisticRegression` when
        ``--ipw_logistic_regression_kwargs`` was provided, otherwise ``None``
        so callers fall back to the default ipw() estimator.
    """
    parsed_kwargs = self.logistic_regression_kwargs()
    return None if parsed_kwargs is None else LogisticRegression(**parsed_kwargs)

def split_sample(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
in_sample = df[self.sample_column()] == 1
sample_df = df[in_sample]
Expand All @@ -175,7 +182,6 @@ def process_batch(
lambda_max: float | None = 10,
num_lambdas: int | None = 250,
weight_trimming_mean_ratio: float = 20,
logistic_regression_kwargs: Dict[str, Any] | None = None,
sample_cls: Type[balance_sample_cls] = balance_sample_cls,
sample_package_name: str = __package__,
) -> Dict[str, pd.DataFrame]:
Expand Down Expand Up @@ -226,21 +232,26 @@ def process_batch(
logger.info("%s target object: %s" % (sample_package_name, str(target)))

try:
adjusted = sample.set_target(
target
).adjust(
method=self.method(), # pyre-ignore[6] it gets str, but the function will verify internally if it's the str it should be.
transformations=transformations,
formula=formula,
penalty_factor=penalty_factor,
one_hot_encoding=one_hot_encoding,
max_de=max_de,
lambda_min=lambda_min,
lambda_max=lambda_max,
num_lambdas=num_lambdas,
weight_trimming_mean_ratio=weight_trimming_mean_ratio,
logistic_regression_kwargs=logistic_regression_kwargs,
)
method = self.method()
model = self.logistic_regression_model() if method == "ipw" else None

adjusted_kwargs: Dict[str, Any] = {
"method": method, # pyre-ignore[6] it gets str, but the function will verify internally if it's the str it should be.
"transformations": transformations,
"formula": formula,
"penalty_factor": penalty_factor,
"one_hot_encoding": one_hot_encoding,
"max_de": max_de,
"lambda_min": lambda_min,
"lambda_max": lambda_max,
"num_lambdas": num_lambdas,
"weight_trimming_mean_ratio": weight_trimming_mean_ratio,
}

if model is not None:
adjusted_kwargs["model"] = model

adjusted = sample.set_target(target).adjust(**adjusted_kwargs)
logger.info("Succeeded with adjusting sample to target")
logger.info("%s adjusted object: %s" % (sample_package_name, str(adjusted)))

Expand Down Expand Up @@ -383,7 +394,6 @@ def update_attributes_for_main_used_by_adjust(self) -> None:
)
max_de = self.max_de()
weight_trimming_mean_ratio = self.weight_trimming_mean_ratio()
logistic_regression_kwargs = self.logistic_regression_kwargs()
sample_cls, sample_package_name, sample_package_version = (
balance_sample_cls,
__package__,
Expand All @@ -401,7 +411,6 @@ def update_attributes_for_main_used_by_adjust(self) -> None:
self._lambda_max,
self._num_lambdas,
self._weight_trimming_mean_ratio,
self._logistic_regression_kwargs,
self._sample_cls,
self._sample_package_name,
self._sample_package_version,
Expand All @@ -415,7 +424,6 @@ def update_attributes_for_main_used_by_adjust(self) -> None:
lambda_max,
num_lambdas,
weight_trimming_mean_ratio,
logistic_regression_kwargs,
sample_cls,
sample_package_name,
sample_package_version,
Expand All @@ -433,7 +441,6 @@ def main(self) -> None:
lambda_max,
num_lambdas,
weight_trimming_mean_ratio,
logistic_regression_kwargs,
sample_cls,
sample_package_name,
sample_package_version,
Expand All @@ -447,7 +454,6 @@ def main(self) -> None:
self._lambda_max,
self._num_lambdas,
self._weight_trimming_mean_ratio,
self._logistic_regression_kwargs,
self._sample_cls,
self._sample_package_name,
self._sample_package_version,
Expand All @@ -469,7 +475,6 @@ def main(self) -> None:
"lambda_max",
"num_lambdas",
"weight_trimming_mean_ratio",
"logistic_regression_kwargs",
"sample_cls",
"sample_package_name",
"sample_package_version",
Expand All @@ -484,7 +489,6 @@ def main(self) -> None:
lambda_max,
num_lambdas,
weight_trimming_mean_ratio,
logistic_regression_kwargs,
sample_cls,
sample_package_name,
sample_package_version,
Expand Down Expand Up @@ -513,7 +517,6 @@ def main(self) -> None:
lambda_max,
num_lambdas,
weight_trimming_mean_ratio,
logistic_regression_kwargs,
sample_cls,
sample_package_name,
)
Expand All @@ -539,7 +542,6 @@ def main(self) -> None:
lambda_max,
num_lambdas,
weight_trimming_mean_ratio,
logistic_regression_kwargs,
sample_cls,
sample_package_name,
)
Expand Down Expand Up @@ -727,12 +729,11 @@ def add_arguments_to_parser(parser: ArgumentParser) -> ArgumentParser:
)
parser.add_argument(
"--ipw_logistic_regression_kwargs",
type=str,
required=False,
default=None,
help=(
"JSON object string with additional keyword arguments passed to "
"sklearn.linear_model.LogisticRegression when method is ipw."
"A valid JSON object string of keyword arguments forwarded to sklearn.linear_model.LogisticRegression "
"when using the ipw method. For example: '{\"solver\": \"liblinear\", \"max_iter\": 500}'. "
"Ignored for other methods."
),
)
parser.add_argument(
Expand Down
Loading