-
Notifications
You must be signed in to change notification settings - Fork 47
Allow custom classifiers via ipw model parameter #177
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
neuralsorcerer
commented
Nov 26, 2025
- Related to [FEATURE] balance 0.13.0 - missing steps #175 and [FEATURE] Support more models from sklearn (other than logistic regression) #139
|
@talgalili has imported this pull request. If you are a Meta employee, you can view this in D87926504. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR enhances the ipw() function to accept custom sklearn classifiers directly via the model parameter, replacing the previous sklearn_model parameter. Users can now pass any sklearn classifier implementing fit and predict_proba (e.g., RandomForestClassifier) or use the default logistic regression by specifying model="sklearn".
Key Changes:
- The
modelparameter now accepts sklearn classifiers in addition to string identifiers - The
sklearn_modelparameter is deprecated in favor ofmodel - Added comprehensive test coverage for the new parameter behavior
- Created a new tutorial notebook demonstrating both default and custom classifier usage
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| balance/weighting_methods/ipw.py | Updated ipw() signature and logic to accept classifiers via model parameter, deprecated sklearn_model, improved error messages and documentation |
| tests/test_ipw.py | Added tests for new model parameter, conflicting arguments validation, and error handling |
| tutorials/balance_quickstart_ipw.ipynb | New tutorial demonstrating IPW with default logistic regression and custom RandomForestClassifier |
| CHANGELOG.md | Updated documentation to reflect the API change from sklearn_model to model parameter |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
You can also share your feedback on Copilot code review for a chance to win a $100 gift card. Take the survey.
|
@neuralsorcerer has updated the pull request. You must reimport the pull request before landing. |
talgalili
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking good. Please see my comments
| Defaults to None. | ||
| Examples: | ||
| >>> from sklearn.ensemble import RandomForestClassifier |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
best to add an example that uses the simulated data.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated in 2a8d6d6
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks.
Notice that in the tutorials, you get the processed output of the code command. In the examples, they are not executed on the website - so it's worth adding here also the output.
|
@neuralsorcerer has updated the pull request. You must reimport the pull request before landing. |
|
@neuralsorcerer has updated the pull request. You must reimport the pull request before landing. |
|
@neuralsorcerer has updated the pull request. You must reimport the pull request before landing. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
You can also share your feedback on Copilot code review for a chance to win a $100 gift card. Take the survey.
| logger.warning( | ||
| "penalty_factor is ignored when using a custom sklearn_model." | ||
| ) | ||
| logger.warning("penalty_factor is ignored when using a custom model.") |
Copilot
AI
Nov 26, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The warning message should clarify that penalty_factor is only supported for the default logistic regression. Consider: 'penalty_factor is only supported with the default logistic regression model and will be ignored when using a custom classifier.'
| logger.warning("penalty_factor is ignored when using a custom model.") | |
| logger.warning("penalty_factor is only supported with the default logistic regression model and will be ignored when using a custom classifier.") |
| model_name = model | ||
| else: | ||
| raise TypeError( | ||
| "model must be 'sklearn', an sklearn classifier implementing predict_proba, or None" |
Copilot
AI
Nov 26, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The error message mentions None as a valid option, but None is effectively treated as 'sklearn'. Consider clarifying: 'model must be "sklearn" (string), an sklearn classifier implementing predict_proba, or None (defaults to logistic regression)'
| "model must be 'sklearn', an sklearn classifier implementing predict_proba, or None" | |
| "model must be 'sklearn' (string), an sklearn classifier implementing predict_proba, or None (defaults to logistic regression)" |
| if not hasattr(custom_model, "predict_proba"): | ||
| raise ValueError( | ||
| "The provided sklearn_model must implement predict_proba for propensity estimation." | ||
| "The provided custom model must implement predict_proba for propensity estimation." |
Copilot
AI
Nov 26, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This error message could be more actionable. Consider: 'The provided custom model must implement the predict_proba method for propensity estimation. Ensure your classifier inherits from sklearn.base.ClassifierMixin and defines predict_proba.'
| "The provided custom model must implement predict_proba for propensity estimation." | |
| "The provided custom model must implement the predict_proba method for propensity estimation. " | |
| "Ensure your classifier inherits from sklearn.base.ClassifierMixin and defines predict_proba." |
| def test_ipw_supports_custom_model_parameter(self) -> None: | ||
| """The ``model`` parameter accepts sklearn classifiers directly.""" |
Copilot
AI
Nov 26, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test duplicates the coverage already provided by test_ipw_supports_custom_sklearn_model. Consider removing this test or expanding it to verify distinct behavior not covered by the existing test, such as testing with a different classifier or edge case.
talgalili
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great update. I made a bunch of comments - please review.
| Defaults to None. | ||
| Examples: | ||
| >>> from sklearn.ensemble import RandomForestClassifier |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks.
Notice that in the tutorials, you get the processed output of the code command. In the examples, they are not executed on the website - so it's worth adding here also the output.
balance/weighting_methods/ipw.py
Outdated
| using_default_logistic = sklearn_model is None | ||
| using_default_logistic = custom_model is None | ||
|
|
||
| if using_default_logistic: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no need to keep this as a variable. Just use
custom_model is None
And add as a comment # using_default_logistic
| return self.args.weight_trimming_mean_ratio | ||
|
|
||
| def logistic_regression_kwargs(self) -> Dict[str, Any] | None: | ||
| raw_kwargs = self.args.ipw_logistic_regression_kwargs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh no, please don't remove this. It's a great addition!
Just change it so that it works and uses the ipw_logistic_regression_kwargs as input to train a logisticregression model inside this function.
| @@ -0,0 +1,495 @@ | |||
| { | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since
tutorials/balance_quickstart_ipw.ipynb
is identical to
tutorials/balance_quickstart.ipynb, just with a few more examples - I suggest you just add them to
tutorials/balance_quickstart.ipynb
| - Added `logistic_regression_kwargs` parameter to `ipw()` for customizing | ||
| sklearn LogisticRegression settings | ||
| ([#138](https://github.com/facebookresearch/balance/pull/138)). | ||
| - CLI now supports `--ipw_logistic_regression_kwargs` for passing custom |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As written above - I think keeping this in the CLI is a good idea.
|
@neuralsorcerer has updated the pull request. You must reimport the pull request before landing. |
|
@neuralsorcerer has updated the pull request. You must reimport the pull request before landing. |
|
@neuralsorcerer has updated the pull request. You must reimport the pull request before landing. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
You can also share your feedback on Copilot code review for a chance to win a $100 gift card. Take the survey.
talgalili
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
copilot suggestions are good - please address them.
There are some linter issues from meta, but these are not things for you to fix (they deal with internal files). So I'll deal with that before landing.
Since there are no github workflows yet, I'll also update if there are any test failures (but I'll know only a bit later once it finishes running internally).
p.s.: thanks for all the commits, very cool work (and I have more ideas for stuff to do moving forward - but let's close 0.13.0 first)
|
@neuralsorcerer has updated the pull request. You must reimport the pull request before landing. |
talgalili
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI: this works seems good, THANKS @neuralsorcerer !
I'll pull this and make some changes (some internal linter issues, and also I'll update the examples to work).
Since this is a 'heavy' PR, it might take my other metamates friends a few days to review/accept/land.
|
Thank you for all the help @talgalili :) |
|
FYI: |
|
Happy weekend bro :) |
|
@talgalili merged this pull request in 4e22220. |