
Commit 6fec7d5

[DEV] correct logging and various bugs
1 parent ac69a1e commit 6fec7d5

File tree

5 files changed: +94 additions, -161 deletions


examples/classification.ipynb

Lines changed: 41 additions & 129 deletions
Large diffs are not rendered by default.

palma/components/performance.py

Lines changed: 6 additions & 3 deletions
@@ -110,9 +110,12 @@ def compute_metrics(self, metric: dict):
         from palma import logger
         for name, fun in metric.items():
             self._compute_metric(name, fun)
-        logger.logger.log_metrics(
-            {k: str(v) for k, v in self.get_test_metrics().to_dict().items()},
-            path="metrics")
+
+        for m_name, metric_fold in self.get_test_metrics().to_dict().items():
+            for k, v in metric_fold.items():
+                if isinstance(v, float) or isinstance(v, int):
+                    logger.logger.log_metrics(
+                        {f"{m_name}_fold{k}": v}, path="metrics")
 
     def _compute_metric(self, name: str, fun: typing.Callable):
         """

tests/conftest.py

Lines changed: 0 additions & 26 deletions
@@ -72,29 +72,6 @@ def learning_data(classification_project, classification_data):
     return classification_project, learn, X, y
 
 
-@pytest.fixture(scope='module')
-def get_scoring_analyser(learning_data):
-    project, model, X, y = learning_data
-    perf = performance.ScoringAnalysis(on="indexes_train_test")
-    perf._add(project, model)
-
-    perf.compute_metrics(metric={
-        metrics.roc_auc_score.__name__: metrics.roc_auc_score,
-        metrics.roc_curve.__name__: metrics.roc_curve
-    })
-    return perf
-
-
-@pytest.fixture(scope='module')
-def get_shap_analyser(learning_data):
-    project, model, X, y = learning_data
-    perf = performance.ShapAnalysis(on="indexes_val", n_shap=100,
-                                    compute_interaction=True)
-    perf(project, model)
-
-    return perf
-
-
 @pytest.fixture(scope='module')
 def learning_data_regression(regression_data):
     from palma import set_logger
@@ -129,9 +106,6 @@ def get_regression_analyser(learning_data_regression):
     return perf
 
 
-
-
-
 @pytest.fixture(scope='module')
 def build_classification_project(unbuilt_classification_project,
                                  classification_data):

tests/test_component/test_logger.py

Lines changed: 0 additions & 1 deletion
@@ -146,4 +146,3 @@ def test_artifact_logging():
     logger.logger.log_metrics({'a': 1}, "metric")
     logger.logger.log_artifact(fig, "figure")
 
-

tests/test_component/test_performance.py

Lines changed: 47 additions & 2 deletions
@@ -9,14 +9,59 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import matplotlib
+import numpy as np
 import pytest
-from sklearn import metrics
+import pandas as pd
+from sklearn import metrics, model_selection
 
-from palma.components import performance
+from palma.components import performance, FileSystemLogger, MLFlowLogger
+from sklearn.ensemble import RandomForestClassifier
+import tempfile
+from palma import ModelEvaluation, Project
+from palma import set_logger
 
 matplotlib.use("agg")
 
 
+@pytest.fixture(scope='module')
+def get_scoring_analyser(classification_data):
+    set_logger(FileSystemLogger(tempfile.gettempdir() + "/logger"))
+
+    X, y = classification_data
+    X = pd.DataFrame(X)
+    y = pd.Series(y)
+    project = Project(problem="classification",
+                      project_name=str(np.random.uniform()))
+
+    project.start(
+        X, y,
+        splitter=model_selection.ShuffleSplit(n_splits=4, random_state=42))
+    estimator = RandomForestClassifier()
+
+    learn = ModelEvaluation(estimator)
+    learn.fit(project)
+
+    perf = performance.ScoringAnalysis(on="indexes_val")
+    perf(project, learn)
+
+    perf.compute_metrics(metric={
+        metrics.roc_auc_score.__name__: metrics.roc_auc_score,
+        metrics.roc_curve.__name__: metrics.roc_curve
+    })
+    return perf
+
+
+@pytest.fixture(scope='module')
+def get_shap_analyser(learning_data):
+    project, model, X, y = learning_data
+    perf = performance.ShapAnalysis(on="indexes_val", n_shap=100,
+                                    compute_interaction=True)
+
+    perf(project, model)
+
+    return perf
+
+
 def test_classification_perf(get_scoring_analyser):
     performance.plot.figure(figsize=(6, 6), dpi=200)
     get_scoring_analyser.plot_roc_curve(
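
The relocated fixture now builds its own Project (with a random project name, presumably to avoid logger-path collisions across test runs) instead of reusing the shared learning_data fixture from conftest.py. Combined with the compute_metrics change above, running it with ShuffleSplit(n_splits=4) should leave the FileSystemLogger with one numeric entry per fold for roc_auc_score, while the tuple-valued roc_curve results are filtered out by the isinstance check. A rough sketch of the expected payload (key format taken from the diff; values purely illustrative):

# Hypothetical metrics written under <tempdir>/logger by this fixture:
# one key per numeric metric and fold; roc_curve tuples are skipped.
expected_metrics = {
    "roc_auc_score_fold0": 0.93,  # illustrative values only
    "roc_auc_score_fold1": 0.90,
    "roc_auc_score_fold2": 0.91,
    "roc_auc_score_fold3": 0.92,
}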
