Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
294 changes: 288 additions & 6 deletions econml/policy/_drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,50 @@
from ..utilities import filter_none_kwargs, check_input_arrays
from ..dr import DRLearner
from ..dr._drlearner import _ModelFinal
from ..inference import GenericModelFinalInferenceDiscrete
from ..grf import RegressionForest
from ._base import PolicyLearner
from . import PolicyTree, PolicyForest


class _PolicyModelFinal(_ModelFinal):

def __init__(self, model_final, featurizer, multitask_model_final, cate_model=None):
super().__init__(model_final, featurizer, multitask_model_final)
self._cate_model = cate_model

def fit(self, Y, T, X=None, W=None, *, nuisances,
sample_weight=None, freq_weight=None, sample_var=None, groups=None):
if sample_var is not None:
warn('Parameter `sample_var` is ignored by the final estimator')
sample_var = None
Y_pred, _, _ = nuisances
Y_pred = nuisances[0]

self.d_y = Y_pred.shape[1:-1] # track whether there's a Y dimension (must be a singleton)
self.d_t = Y_pred.shape[-1] - 1
if (X is not None) and (self._featurizer is not None):
X = self._featurizer.fit_transform(X)
filtered_kwargs = filter_none_kwargs(sample_weight=sample_weight, sample_var=sample_var)
ys = Y_pred[..., 1:] - Y_pred[..., [0]] # subtract control results from each other arm
if self.d_y: # need to squeeze out singleton so that we fit on 2D array
ys = ys.squeeze(1)
ys = np.hstack([np.zeros((ys.shape[0], 1)), ys])
self.model_cate = self._model_final.fit(X, ys, **filtered_kwargs)
ys_with_control = np.hstack([np.zeros((ys.shape[0], 1)), ys])
self.model_cate = self._model_final.fit(X, ys_with_control, **filtered_kwargs)

# Also fit per-treatment CATE models for inference support
if self._cate_model is not None:
self.models_cate = [clone(self._cate_model, safe=False).fit(X, ys[..., t], **filtered_kwargs)
for t in range(self.d_t)]

return self

def predict(self, X=None):
if (X is not None) and (self._featurizer is not None):
X = self._featurizer.transform(X)
# Use per-treatment CATE models for prediction when available (supports predict_interval)
if hasattr(self, 'models_cate') and self.models_cate is not None:
preds = np.array([mdl.predict(X).reshape((-1,) + self.d_y) for mdl in self.models_cate])
return np.moveaxis(preds, 0, -1) # move treatment dim to end
pred = self.model_cate.predict_value(X)[:, 1:]
if self.d_y: # need to reintroduce singleton Y dimension
return pred[:, np.newaxis, :]
Expand All @@ -45,16 +62,27 @@ def score(self, Y, T, X=None, W=None, *, nuisances, sample_weight=None, groups=N

class _DRLearnerWrapper(DRLearner):
    """DRLearner variant whose final stage is a policy model, optionally with a CATE model for inference."""

    def __init__(self, *args, cate_model=None, **kwargs):
        super().__init__(*args, **kwargs)
        self._cate_model = cate_model

    def _gen_ortho_learner_model_final(self):
        # Thread the cate_model through so the final stage can fit per-treatment models
        return _PolicyModelFinal(self._gen_model_final(), self._gen_featurizer(),
                                 self.multitask_model_final, cate_model=self._cate_model)

    def _get_inference_options(self):
        options = super()._get_inference_options()
        # Only advertise model-final inference when per-treatment CATE models exist
        if self._cate_model is not None:
            options.update(auto=GenericModelFinalInferenceDiscrete)
        return options


class _BaseDRPolicyLearner(PolicyLearner):

def _gen_drpolicy_learner(self):
pass

def fit(self, Y, T, *, X=None, W=None, sample_weight=None, groups=None):
def fit(self, Y, T, *, X=None, W=None, sample_weight=None, groups=None, inference='auto'):
"""
Estimate a policy model from data.

Expand All @@ -74,13 +102,18 @@ def fit(self, Y, T, *, X=None, W=None, sample_weight=None, groups=None):
All rows corresponding to the same group will be kept together during splitting.
If groups is not None, the `cv` argument passed to this class's initializer
must support a 'groups' argument to its split method.
inference: str or :class:`.Inference` instance, optional
Method for performing inference. All estimators support ``'bootstrap'``
(or an instance of :class:`.BootstrapInference`). The default is ``'auto'``,
which uses the built-in inference method of the underlying DRLearner.

Returns
-------
self: object instance
"""
self.drlearner_ = self._gen_drpolicy_learner()
self.drlearner_.fit(Y, T, X=X, W=W, sample_weight=sample_weight, groups=groups)
self.drlearner_.fit(Y, T, X=X, W=W, sample_weight=sample_weight, groups=groups,
inference=inference)
return self

def predict_value(self, X):
Expand All @@ -99,6 +132,245 @@ def predict_value(self, X):
"""
return self.drlearner_.const_marginal_effect(X)

# ── CATE estimation and inference (delegated to internal DRLearner) ──

def effect(self, X=None, *, T0=0, T1=1):
"""Calculate the heterogeneous treatment effect τ(X) = E[Y(T1) - Y(T0) | X].

Parameters
----------
X : array_like of shape (n_samples, n_features), optional
Features for each sample.
T0 : int or array_like, default 0
Baseline treatment.
T1 : int or array_like, default 1
Target treatment.

Returns
-------
effect : array_like of shape (n_samples, n_outcomes)
The heterogeneous treatment effect for each sample.
"""
return self.drlearner_.effect(X, T0=T0, T1=T1)

def effect_interval(self, X=None, *, T0=0, T1=1, alpha=0.05):
"""Get confidence interval for the heterogeneous treatment effect τ(X).

Parameters
----------
X : array_like of shape (n_samples, n_features), optional
Features for each sample.
T0 : int or array_like, default 0
Baseline treatment.
T1 : int or array_like, default 1
Target treatment.
alpha : float, default 0.05
The significance level. The confidence interval is (1 - alpha)%.

Returns
-------
lower, upper : tuple of array_like of shape (n_samples, n_outcomes)
Lower and upper bounds of the confidence interval.
"""
return self.drlearner_.effect_interval(X, T0=T0, T1=T1, alpha=alpha)

def effect_inference(self, X=None, *, T0=0, T1=1):
"""Get inference results for the heterogeneous treatment effect τ(X).

Parameters
----------
X : array_like of shape (n_samples, n_features), optional
Features for each sample.
T0 : int or array_like, default 0
Baseline treatment.
T1 : int or array_like, default 1
Target treatment.

Returns
-------
inference_results : :class:`~econml.inference.NormalInferenceResults`
Inference results including point estimates, confidence intervals, and p-values.
"""
return self.drlearner_.effect_inference(X, T0=T0, T1=T1)

def const_marginal_effect(self, X=None):
"""Calculate the constant marginal CATE θ(X) for each non-baseline treatment.

Parameters
----------
X : array_like of shape (n_samples, n_features), optional
Features for each sample.

Returns
-------
theta : array_like of shape (n_samples, n_treatments - 1)
The constant marginal effect for each sample and treatment.
"""
return self.drlearner_.const_marginal_effect(X)

def const_marginal_effect_interval(self, X=None, *, alpha=0.05):
"""Get confidence interval for the constant marginal CATE θ(X).

Parameters
----------
X : array_like of shape (n_samples, n_features), optional
Features for each sample.
alpha : float, default 0.05
The significance level. The confidence interval is (1 - alpha)%.

Returns
-------
lower, upper : tuple of array_like
Lower and upper bounds of the confidence interval.
"""
return self.drlearner_.const_marginal_effect_interval(X, alpha=alpha)

def const_marginal_effect_inference(self, X=None):
"""Get inference results for the constant marginal CATE θ(X).

Parameters
----------
X : array_like of shape (n_samples, n_features), optional
Features for each sample.

Returns
-------
inference_results : :class:`~econml.inference.NormalInferenceResults`
Inference results including point estimates, confidence intervals, and p-values.
"""
return self.drlearner_.const_marginal_effect_inference(X)

def marginal_effect(self, T, X=None):
"""Calculate the heterogeneous marginal effect ∂τ(T, X).

Parameters
----------
T : array_like of shape (n_samples,)
Treatment values at which to evaluate the marginal effect.
X : array_like of shape (n_samples, n_features), optional
Features for each sample.

Returns
-------
marginal_effect : array_like
The marginal effect for each sample.
"""
return self.drlearner_.marginal_effect(T, X)

def ate(self, X=None, *, T0=0, T1=1):
"""Calculate the average treatment effect E_X[τ(X)].

Parameters
----------
X : array_like of shape (n_samples, n_features), optional
Features for each sample.
T0 : int or array_like, default 0
Baseline treatment.
T1 : int or array_like, default 1
Target treatment.

Returns
-------
ate : scalar or array_like
The average treatment effect.
"""
return self.drlearner_.ate(X, T0=T0, T1=T1)

def ate_interval(self, X=None, *, T0=0, T1=1, alpha=0.05):
"""Get confidence interval for the average treatment effect.

Parameters
----------
X : array_like of shape (n_samples, n_features), optional
Features for each sample.
T0 : int or array_like, default 0
Baseline treatment.
T1 : int or array_like, default 1
Target treatment.
alpha : float, default 0.05
The significance level.

Returns
-------
lower, upper : tuple of scalars or array_like
Lower and upper bounds of the confidence interval.
"""
return self.drlearner_.ate_interval(X, T0=T0, T1=T1, alpha=alpha)

def ate_inference(self, X=None, *, T0=0, T1=1):
"""Get inference results for the average treatment effect.

Parameters
----------
X : array_like of shape (n_samples, n_features), optional
Features for each sample.
T0 : int or array_like, default 0
Baseline treatment.
T1 : int or array_like, default 1
Target treatment.

Returns
-------
inference_results : :class:`~econml.inference.NormalInferenceResults`
Inference results including point estimates, confidence intervals, and p-values.
"""
return self.drlearner_.ate_inference(X, T0=T0, T1=T1)

def shap_values(self, X, *, feature_names=None, treatment_names=None,
output_names=None, background_samples=100):
"""Get SHAP values for the CATE model.

Parameters
----------
X : array_like of shape (n_samples, n_features)
Features for each sample.
feature_names : list of str, optional
The names of the input features.
treatment_names : list of str, optional
The names of the treatments.
output_names : list of str, optional
The names of the outputs.
background_samples : int, default 100
Number of background samples for SHAP.

Returns
-------
shap_values : object
SHAP values for the CATE model.
"""
return self.drlearner_.shap_values(X, feature_names=feature_names,
treatment_names=treatment_names,
output_names=output_names,
background_samples=background_samples)

def score(self, Y, T, X=None, W=None, *, sample_weight=None):
"""Score the fitted CATE model on new data.

Parameters
----------
Y : array_like of shape (n_samples,)
Outcomes for each sample.
T : array_like of shape (n_samples,)
Treatments for each sample.
X : array_like of shape (n_samples, n_features), optional
Features for each sample.
W : array_like of shape (n_samples, n_controls), optional
Controls for each sample.
sample_weight : array_like of shape (n_samples,), optional
Weights for each sample.

Returns
-------
score : float
The score of the CATE model.
"""
return self.drlearner_.score(Y, T, X=X, W=W, sample_weight=sample_weight)

@property
def model_final_(self):
"""The fitted final model of the underlying DRLearner."""
return self.drlearner_.model_final_

def predict_proba(self, X):
"""Predict the probability of recommending each treatment.

Expand Down Expand Up @@ -436,6 +708,10 @@ def _gen_drpolicy_learner(self):
honest=self.honest,
random_state=self.random_state),
multitask_model_final=True,
cate_model=RegressionForest(
min_samples_leaf=self.min_samples_leaf,
honest=self.honest,
random_state=self.random_state),
random_state=self.random_state)

def plot(self, *, feature_names=None, treatment_names=None, ax=None, title=None,
Expand Down Expand Up @@ -868,6 +1144,12 @@ def _gen_drpolicy_learner(self):
verbose=self.verbose,
random_state=self.random_state),
multitask_model_final=True,
cate_model=RegressionForest(
n_estimators=max(4, 4 * (self.n_estimators // 4)),
min_samples_leaf=self.min_samples_leaf,
honest=self.honest,
n_jobs=self.n_jobs,
random_state=self.random_state),
random_state=self.random_state)

def plot(self, tree_id, *, feature_names=None, treatment_names=None,
Expand Down