1616from flaml import AutoML
1717from sklearn import base
1818
19+ from palma .base .splitting_strategy import ValidationStrategy
20+
1921try :
2022 from autosklearn .classification import AutoSklearnClassifier
2123 from autosklearn .regression import AutoSklearnRegressor
@@ -29,7 +31,8 @@ def __init__(self, engine_parameters: dict) -> None:
2931 self .__engine_parameters = engine_parameters
3032
3133 @abstractmethod
32- def optimize (self , X : pd .DataFrame , y : pd .Series , splitter = None
34+ def optimize (self , X : pd .DataFrame , y : pd .Series ,
35+ splitter : "ValidationStrategy" = None
3336 ) -> None :
3437 ...
3538
@@ -87,13 +90,13 @@ def optimize(self, X: pd.DataFrame, y: pd.Series, splitter=None) -> None:
8790 @property
8891 def optimizer (self ) -> Union [
8992 'AutoSklearnClassifier' ,
90- 'AutoSklearnRegressor' ]:
93+ 'AutoSklearnRegressor' ]:
9194 return self .__optimizer
9295
9396 @property
9497 def estimator_ (self ) -> Union [
9598 'AutoSklearnClassifier' ,
96- 'AutoSklearnRegressor' ]:
99+ 'AutoSklearnRegressor' ]:
97100 return self .__optimizer .get_models_with_weights ()
98101
99102 @property
@@ -111,14 +114,20 @@ def __init__(self, problem: str, engine_parameters: dict) -> None:
111114 )
112115 engine_parameters ["task" ] = problem
113116
114- def optimize (self , X : pd .DataFrame , y : pd .DataFrame , splitter = None
117+ def optimize (self , X : pd .DataFrame , y : pd .DataFrame ,
118+ splitter : ValidationStrategy = None
115119 ) -> None :
120+ split_type = None if splitter is None else splitter .splitter
121+ groups = None if splitter is None else splitter .groups
122+ groups = groups if groups is None else groups [splitter .train_index ]
123+
116124 self .allowing_splitter (splitter )
117125 self .__optimizer = AutoML ()
118126 self .__optimizer .fit (
119127 X_train = pd .DataFrame (X .values , index = range (len (X ))),
120128 y_train = pd .Series (y .values , index = range (len (X ))),
121- split_type = splitter , mlflow_logging = False ,
129+ split_type = split_type , groups = groups ,
130+ mlflow_logging = False ,
122131 ** self .engine_parameters
123132 )
124133
0 commit comments