From af02ab9d906c9dae3c734eab07ee8ef00b683bbf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jimmy=20Charit=C3=A9?= Date: Mon, 23 Mar 2026 08:53:16 -0400 Subject: [PATCH] Expose CATE estimation and inference methods on DRPolicyForest/Tree MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit DRPolicyForest and DRPolicyTree internally fit a DRLearner but do not expose its CATE estimation and inference API. Users who need both optimal policy assignments and CATE confidence intervals may consider strategies such as a 3-way data split: 1. Data for learning CATEs (ForestDRLearner) 2. Data for learning policies (DRPolicyForest using out-of-sample rewards) 3. Data for out-of-sample evaluation with CIs While statistically valid, this approach can be less data-efficient. EconML's cross-fitting within DRPolicyForest already provides noise separation between CATE estimation and policy learning, which could allow users to consolidate the first two splits. Exposing the underlying CATE inference methods enables this more data-efficient workflow. Changes: - Add delegation methods on _BaseDRPolicyLearner for effect(), effect_interval(), effect_inference(), const_marginal_effect(), const_marginal_effect_interval(), const_marginal_effect_inference(), marginal_effect(), ate(), ate_interval(), ate_inference(), shap_values(), score(), and model_final_ - Pass inference parameter through _BaseDRPolicyLearner.fit() to the underlying DRLearner - Fit per-treatment RegressionForest CATE models alongside the policy model in _PolicyModelFinal to support GenericModelFinalInferenceDiscrete - Override _get_inference_options in _DRLearnerWrapper to enable automatic inference when CATE models are available Co-Authored-By: Claude Opus 4.6 Signed-off-by: Jimmy Charité --- econml/policy/_drlearner.py | 294 +++++++++++++++++++++++++++++++++++- 1 file changed, 288 insertions(+), 6 deletions(-) diff --git a/econml/policy/_drlearner.py b/econml/policy/_drlearner.py index 8f9ef7082..01a6b0947 100644 --- a/econml/policy/_drlearner.py +++ b/econml/policy/_drlearner.py @@ -7,33 +7,50 @@ from ..utilities import filter_none_kwargs, check_input_arrays from ..dr import DRLearner from ..dr._drlearner import _ModelFinal +from ..inference import GenericModelFinalInferenceDiscrete +from ..grf import RegressionForest from ._base import PolicyLearner from . import PolicyTree, PolicyForest class _PolicyModelFinal(_ModelFinal): + def __init__(self, model_final, featurizer, multitask_model_final, cate_model=None): + super().__init__(model_final, featurizer, multitask_model_final) + self._cate_model = cate_model + def fit(self, Y, T, X=None, W=None, *, nuisances, sample_weight=None, freq_weight=None, sample_var=None, groups=None): if sample_var is not None: warn('Parameter `sample_var` is ignored by the final estimator') sample_var = None - Y_pred, _, _ = nuisances + Y_pred = nuisances[0] self.d_y = Y_pred.shape[1:-1] # track whether there's a Y dimension (must be a singleton) + self.d_t = Y_pred.shape[-1] - 1 if (X is not None) and (self._featurizer is not None): X = self._featurizer.fit_transform(X) filtered_kwargs = filter_none_kwargs(sample_weight=sample_weight, sample_var=sample_var) ys = Y_pred[..., 1:] - Y_pred[..., [0]] # subtract control results from each other arm if self.d_y: # need to squeeze out singleton so that we fit on 2D array ys = ys.squeeze(1) - ys = np.hstack([np.zeros((ys.shape[0], 1)), ys]) - self.model_cate = self._model_final.fit(X, ys, **filtered_kwargs) + ys_with_control = np.hstack([np.zeros((ys.shape[0], 1)), ys]) + self.model_cate = self._model_final.fit(X, ys_with_control, **filtered_kwargs) + + # Also fit per-treatment CATE models for inference support + if self._cate_model is not None: + self.models_cate = [clone(self._cate_model, safe=False).fit(X, ys[..., t], **filtered_kwargs) + for t in range(self.d_t)] + return self def predict(self, X=None): if (X is not None) and (self._featurizer is not None): X = self._featurizer.transform(X) + # Use per-treatment CATE models for prediction when available (supports predict_interval) + if hasattr(self, 'models_cate') and self.models_cate is not None: + preds = np.array([mdl.predict(X).reshape((-1,) + self.d_y) for mdl in self.models_cate]) + return np.moveaxis(preds, 0, -1) # move treatment dim to end pred = self.model_cate.predict_value(X)[:, 1:] if self.d_y: # need to reintroduce singleton Y dimension return pred[:, np.newaxis, :] @@ -45,8 +62,19 @@ def score(self, Y, T, X=None, W=None, *, nuisances, sample_weight=None, groups=N class _DRLearnerWrapper(DRLearner): + def __init__(self, *args, cate_model=None, **kwargs): + super().__init__(*args, **kwargs) + self._cate_model = cate_model + def _gen_ortho_learner_model_final(self): - return _PolicyModelFinal(self._gen_model_final(), self._gen_featurizer(), self.multitask_model_final) + return _PolicyModelFinal(self._gen_model_final(), self._gen_featurizer(), + self.multitask_model_final, cate_model=self._cate_model) + + def _get_inference_options(self): + options = super()._get_inference_options() + if self._cate_model is not None: + options.update(auto=GenericModelFinalInferenceDiscrete) + return options class _BaseDRPolicyLearner(PolicyLearner): @@ -54,7 +82,7 @@ class _BaseDRPolicyLearner(PolicyLearner): def _gen_drpolicy_learner(self): pass - def fit(self, Y, T, *, X=None, W=None, sample_weight=None, groups=None): + def fit(self, Y, T, *, X=None, W=None, sample_weight=None, groups=None, inference='auto'): """ Estimate a policy model from data. @@ -74,13 +102,18 @@ def fit(self, Y, T, *, X=None, W=None, sample_weight=None, groups=None): All rows corresponding to the same group will be kept together during splitting. If groups is not None, the `cv` argument passed to this class's initializer must support a 'groups' argument to its split method. + inference: str or :class:`.Inference` instance, optional + Method for performing inference. All estimators support ``'bootstrap'`` + (or an instance of :class:`.BootstrapInference`). The default is ``'auto'``, + which uses the built-in inference method of the underlying DRLearner. Returns ------- self: object instance """ self.drlearner_ = self._gen_drpolicy_learner() - self.drlearner_.fit(Y, T, X=X, W=W, sample_weight=sample_weight, groups=groups) + self.drlearner_.fit(Y, T, X=X, W=W, sample_weight=sample_weight, groups=groups, + inference=inference) return self def predict_value(self, X): @@ -99,6 +132,245 @@ def predict_value(self, X): """ return self.drlearner_.const_marginal_effect(X) + # ── CATE estimation and inference (delegated to internal DRLearner) ── + + def effect(self, X=None, *, T0=0, T1=1): + """Calculate the heterogeneous treatment effect τ(X) = E[Y(T1) - Y(T0) | X]. + + Parameters + ---------- + X : array_like of shape (n_samples, n_features), optional + Features for each sample. + T0 : int or array_like, default 0 + Baseline treatment. + T1 : int or array_like, default 1 + Target treatment. + + Returns + ------- + effect : array_like of shape (n_samples, n_outcomes) + The heterogeneous treatment effect for each sample. + """ + return self.drlearner_.effect(X, T0=T0, T1=T1) + + def effect_interval(self, X=None, *, T0=0, T1=1, alpha=0.05): + """Get confidence interval for the heterogeneous treatment effect τ(X). + + Parameters + ---------- + X : array_like of shape (n_samples, n_features), optional + Features for each sample. + T0 : int or array_like, default 0 + Baseline treatment. + T1 : int or array_like, default 1 + Target treatment. + alpha : float, default 0.05 + The significance level. The confidence interval is (1 - alpha)%. + + Returns + ------- + lower, upper : tuple of array_like of shape (n_samples, n_outcomes) + Lower and upper bounds of the confidence interval. + """ + return self.drlearner_.effect_interval(X, T0=T0, T1=T1, alpha=alpha) + + def effect_inference(self, X=None, *, T0=0, T1=1): + """Get inference results for the heterogeneous treatment effect τ(X). + + Parameters + ---------- + X : array_like of shape (n_samples, n_features), optional + Features for each sample. + T0 : int or array_like, default 0 + Baseline treatment. + T1 : int or array_like, default 1 + Target treatment. + + Returns + ------- + inference_results : :class:`~econml.inference.NormalInferenceResults` + Inference results including point estimates, confidence intervals, and p-values. + """ + return self.drlearner_.effect_inference(X, T0=T0, T1=T1) + + def const_marginal_effect(self, X=None): + """Calculate the constant marginal CATE θ(X) for each non-baseline treatment. + + Parameters + ---------- + X : array_like of shape (n_samples, n_features), optional + Features for each sample. + + Returns + ------- + theta : array_like of shape (n_samples, n_treatments - 1) + The constant marginal effect for each sample and treatment. + """ + return self.drlearner_.const_marginal_effect(X) + + def const_marginal_effect_interval(self, X=None, *, alpha=0.05): + """Get confidence interval for the constant marginal CATE θ(X). + + Parameters + ---------- + X : array_like of shape (n_samples, n_features), optional + Features for each sample. + alpha : float, default 0.05 + The significance level. The confidence interval is (1 - alpha)%. + + Returns + ------- + lower, upper : tuple of array_like + Lower and upper bounds of the confidence interval. + """ + return self.drlearner_.const_marginal_effect_interval(X, alpha=alpha) + + def const_marginal_effect_inference(self, X=None): + """Get inference results for the constant marginal CATE θ(X). + + Parameters + ---------- + X : array_like of shape (n_samples, n_features), optional + Features for each sample. + + Returns + ------- + inference_results : :class:`~econml.inference.NormalInferenceResults` + Inference results including point estimates, confidence intervals, and p-values. + """ + return self.drlearner_.const_marginal_effect_inference(X) + + def marginal_effect(self, T, X=None): + """Calculate the heterogeneous marginal effect ∂τ(T, X). + + Parameters + ---------- + T : array_like of shape (n_samples,) + Treatment values at which to evaluate the marginal effect. + X : array_like of shape (n_samples, n_features), optional + Features for each sample. + + Returns + ------- + marginal_effect : array_like + The marginal effect for each sample. + """ + return self.drlearner_.marginal_effect(T, X) + + def ate(self, X=None, *, T0=0, T1=1): + """Calculate the average treatment effect E_X[τ(X)]. + + Parameters + ---------- + X : array_like of shape (n_samples, n_features), optional + Features for each sample. + T0 : int or array_like, default 0 + Baseline treatment. + T1 : int or array_like, default 1 + Target treatment. + + Returns + ------- + ate : scalar or array_like + The average treatment effect. + """ + return self.drlearner_.ate(X, T0=T0, T1=T1) + + def ate_interval(self, X=None, *, T0=0, T1=1, alpha=0.05): + """Get confidence interval for the average treatment effect. + + Parameters + ---------- + X : array_like of shape (n_samples, n_features), optional + Features for each sample. + T0 : int or array_like, default 0 + Baseline treatment. + T1 : int or array_like, default 1 + Target treatment. + alpha : float, default 0.05 + The significance level. + + Returns + ------- + lower, upper : tuple of scalars or array_like + Lower and upper bounds of the confidence interval. + """ + return self.drlearner_.ate_interval(X, T0=T0, T1=T1, alpha=alpha) + + def ate_inference(self, X=None, *, T0=0, T1=1): + """Get inference results for the average treatment effect. + + Parameters + ---------- + X : array_like of shape (n_samples, n_features), optional + Features for each sample. + T0 : int or array_like, default 0 + Baseline treatment. + T1 : int or array_like, default 1 + Target treatment. + + Returns + ------- + inference_results : :class:`~econml.inference.NormalInferenceResults` + Inference results including point estimates, confidence intervals, and p-values. + """ + return self.drlearner_.ate_inference(X, T0=T0, T1=T1) + + def shap_values(self, X, *, feature_names=None, treatment_names=None, + output_names=None, background_samples=100): + """Get SHAP values for the CATE model. + + Parameters + ---------- + X : array_like of shape (n_samples, n_features) + Features for each sample. + feature_names : list of str, optional + The names of the input features. + treatment_names : list of str, optional + The names of the treatments. + output_names : list of str, optional + The names of the outputs. + background_samples : int, default 100 + Number of background samples for SHAP. + + Returns + ------- + shap_values : object + SHAP values for the CATE model. + """ + return self.drlearner_.shap_values(X, feature_names=feature_names, + treatment_names=treatment_names, + output_names=output_names, + background_samples=background_samples) + + def score(self, Y, T, X=None, W=None, *, sample_weight=None): + """Score the fitted CATE model on new data. + + Parameters + ---------- + Y : array_like of shape (n_samples,) + Outcomes for each sample. + T : array_like of shape (n_samples,) + Treatments for each sample. + X : array_like of shape (n_samples, n_features), optional + Features for each sample. + W : array_like of shape (n_samples, n_controls), optional + Controls for each sample. + sample_weight : array_like of shape (n_samples,), optional + Weights for each sample. + + Returns + ------- + score : float + The score of the CATE model. + """ + return self.drlearner_.score(Y, T, X=X, W=W, sample_weight=sample_weight) + + @property + def model_final_(self): + """The fitted final model of the underlying DRLearner.""" + return self.drlearner_.model_final_ + def predict_proba(self, X): """Predict the probability of recommending each treatment. @@ -436,6 +708,10 @@ def _gen_drpolicy_learner(self): honest=self.honest, random_state=self.random_state), multitask_model_final=True, + cate_model=RegressionForest( + min_samples_leaf=self.min_samples_leaf, + honest=self.honest, + random_state=self.random_state), random_state=self.random_state) def plot(self, *, feature_names=None, treatment_names=None, ax=None, title=None, @@ -868,6 +1144,12 @@ def _gen_drpolicy_learner(self): verbose=self.verbose, random_state=self.random_state), multitask_model_final=True, + cate_model=RegressionForest( + n_estimators=max(4, 4 * (self.n_estimators // 4)), + min_samples_leaf=self.min_samples_leaf, + honest=self.honest, + n_jobs=self.n_jobs, + random_state=self.random_state), random_state=self.random_state) def plot(self, tree_id, *, feature_names=None, treatment_names=None,