From adb78f75f597cf08020aa95b3a19d7deb01b3d6d Mon Sep 17 00:00:00 2001 From: Martin Fleischmann Date: Tue, 17 Jun 2025 19:13:36 +0200 Subject: [PATCH 1/2] ENH: allow custom Graph object --- gwlearn/base.py | 23 +++++++++++++++++-- gwlearn/tests/test_base.py | 47 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/gwlearn/base.py b/gwlearn/base.py index 6e3ab40..ac7e44b 100644 --- a/gwlearn/base.py +++ b/gwlearn/base.py @@ -94,6 +94,7 @@ def __init__( ] | Callable = "bisquare", include_focal: bool = False, + graph: graph.Graph = None, n_jobs: int = -1, fit_global_model: bool = True, measure_performance: bool = True, @@ -108,6 +109,7 @@ def __init__( self.bandwidth = bandwidth self.kernel = kernel self.include_focal = include_focal + self.graph = graph self.fixed = fixed self._model_kwargs = kwargs self.n_jobs = n_jobs @@ -380,6 +382,10 @@ class BaseClassifier(_BaseModel, ClassifierMixin): futher spatial analysis of the model performance (and generalises to models that do not support OOB scoring). However, it leaves out the most representative sample. By default False + graph : Graph, optional + Custom libpysal.graph.Graph object encoding the spatial interaction between + observations. If given, it is used directly and `bandwidth`, `fixed`, `kernel`, + and `include_focal` keywords are ignored. n_jobs : int, optional The number of jobs to run in parallel. ``-1`` means using all processors by default ``-1`` @@ -471,6 +477,7 @@ def __init__( ] | Callable = "bisquare", include_focal: bool = False, + graph: graph.Graph = None, n_jobs: int = -1, fit_global_model: bool = True, measure_performance: bool = True, @@ -490,6 +497,7 @@ def __init__( fixed=fixed, kernel=kernel, include_focal=include_focal, + graph=graph, n_jobs=n_jobs, fit_global_model=fit_global_model, measure_performance=measure_performance, @@ -548,7 +556,11 @@ def _is_binary(series: pd.Series) -> bool: if self.verbose: print(f"{(time() - self._start):.2f}s: Building weights") - weights = self._build_weights(geometry) + + if self.graph is not None: + weights = self.graph + else: + weights = self._build_weights(geometry) if self.verbose: print(f"{(time() - self._start):.2f}s: Weights ready") self._setup_model_storage() @@ -877,6 +889,10 @@ class BaseRegressor(_BaseModel, RegressorMixin): futher spatial analysis of the model performance (and generalises to models that do not support OOB scoring). However, it leaves out the most representative sample. By default False + graph : Graph, optional + Custom libpysal.graph.Graph object encoding the spatial interaction between + observations. If given, it is used directly and `bandwidth`, `fixed`, `kernel`, + and `include_focal` keywords are ignored. n_jobs : int, optional The number of jobs to run in parallel. ``-1`` means using all processors by default ``-1`` @@ -925,7 +941,10 @@ def fit( """ self._validate_geometry(geometry) - weights = self._build_weights(geometry) + if self.graph is not None: + weights = self.graph + else: + weights = self._build_weights(geometry) self._setup_model_storage() # fit the models diff --git a/gwlearn/tests/test_base.py b/gwlearn/tests/test_base.py index 3d30bbc..0c59523 100644 --- a/gwlearn/tests/test_base.py +++ b/gwlearn/tests/test_base.py @@ -8,6 +8,7 @@ import pandas as pd import pytest from geodatasets import get_path +from libpysal.graph import Graph from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor from sklearn.linear_model import LinearRegression, LogisticRegression @@ -1440,3 +1441,49 @@ def test_regressor_fit_focal_inclusion(sample_regression_data): # RF should 'remember' focal point when included assert (y - no_focal.pred_).mean() > (y - focal.pred_).mean() + + +def test_custom_graph_baseregressor(sample_regression_data): + """Test BaseRegressor with a custom graph object.""" + X, y, geometry = sample_regression_data + + # Create a fixed distance weights graph + g = Graph.build_distance_band(geometry, threshold=150000, binary=False) + + # Create regressor with custom graph + reg = BaseRegressor( + LinearRegression, + bandwidth=100, # This should be ignored when custom graph is provided + fixed=False, # This should be ignored when custom graph is provided + graph=g, + ) + + # Fit the model + reg.fit(X, y, geometry) + + # Check that the model was fit successfully + assert hasattr(reg, "pred_") + assert hasattr(reg, "local_r2_") + + +def test_custom_graph_baseclassifier(sample_data): + """Test BaseClassifier with a custom graph object.""" + X, y, geometry = sample_data + + # Create a fixed distance weights graph + g = Graph.build_distance_band(geometry, threshold=150000, binary=False) + # Create classifier with custom graph + clf = BaseClassifier( + LogisticRegression, + bandwidth=100, # This should be ignored when custom graph is provided + fixed=True, # This should be ignored when custom graph is provided + graph=g, + max_iter=500, + ) + + # Fit the model + clf.fit(X, y, geometry) + + # Check that the model was fit successfully + assert hasattr(clf, "proba_") + assert hasattr(clf, "score_") From aa08d4355e2858829c35d9718aa5e9958a970493 Mon Sep 17 00:00:00 2001 From: Martin Fleischmann Date: Wed, 18 Jun 2025 09:12:30 +0200 Subject: [PATCH 2/2] expose in models --- gwlearn/ensemble.py | 13 +++++++++++++ gwlearn/linear_model.py | 13 +++++++++++++ 2 files changed, 26 insertions(+) diff --git a/gwlearn/ensemble.py b/gwlearn/ensemble.py index 12cf787..2e731ad 100644 --- a/gwlearn/ensemble.py +++ b/gwlearn/ensemble.py @@ -6,6 +6,7 @@ import geopandas as gpd import numpy as np import pandas as pd +from libpysal import graph from sklearn import metrics from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier @@ -31,6 +32,10 @@ class GWRandomForestClassifier(BaseClassifier): analysis of the model performance (and generalises to models that do not support OOB scoring). However, it leaves out the most representative sample. By default False + graph : Graph, optional + Custom libpysal.graph.Graph object encoding the spatial interaction between + observations. If given, it is used directly and `bandwidth`, `fixed`, `kernel`, + and `include_focal` keywords are ignored. n_jobs : int, optional The number of jobs to run in parallel. ``-1`` means using all processors by default ``-1`` @@ -153,6 +158,7 @@ def __init__( ] | Callable = "bisquare", include_focal: bool = False, + graph: graph.Graph = None, n_jobs: int = -1, fit_global_model: bool = True, measure_performance: bool = True, @@ -172,6 +178,7 @@ def __init__( fixed=fixed, kernel=kernel, include_focal=include_focal, + graph=graph, n_jobs=n_jobs, fit_global_model=fit_global_model, measure_performance=measure_performance, @@ -303,6 +310,10 @@ class GWGradientBoostingClassifier(BaseClassifier): futher spatial analysis of the model performance (and generalises to models that do not support OOB scoring). However, it leaves out the most representative sample. By default False + graph : Graph, optional + Custom libpysal.graph.Graph object encoding the spatial interaction between + observations. If given, it is used directly and `bandwidth`, `fixed`, `kernel`, + and `include_focal` keywords are ignored. n_jobs : int, optional The number of jobs to run in parallel. ``-1`` means using all processors by default ``-1`` @@ -393,6 +404,7 @@ def __init__( ] | Callable = "bisquare", include_focal: bool = False, + graph: graph.Graph = None, n_jobs: int = -1, fit_global_model: bool = True, measure_performance: bool = True, @@ -408,6 +420,7 @@ def __init__( fixed=fixed, kernel=kernel, include_focal=include_focal, + graph=graph, n_jobs=n_jobs, fit_global_model=fit_global_model, measure_performance=measure_performance, diff --git a/gwlearn/linear_model.py b/gwlearn/linear_model.py index 240240f..ecaca1b 100644 --- a/gwlearn/linear_model.py +++ b/gwlearn/linear_model.py @@ -5,6 +5,7 @@ import geopandas as gpd import numpy as np import pandas as pd +from libpysal import graph from sklearn import metrics from sklearn.linear_model import LinearRegression, LogisticRegression @@ -30,6 +31,10 @@ class GWLogisticRegression(BaseClassifier): futher spatial analysis of the model performance (and generalises to models that do not support OOB scoring). However, it leaves out the most representative sample. By default True + graph : Graph, optional + Custom libpysal.graph.Graph object encoding the spatial interaction between + observations. If given, it is used directly and `bandwidth`, `fixed`, `kernel`, + and `include_focal` keywords are ignored. n_jobs : int, optional The number of jobs to run in parallel. ``-1`` means using all processors by default ``-1`` @@ -158,6 +163,7 @@ def __init__( ] | Callable = "bisquare", include_focal: bool = True, + graph: graph.Graph = None, n_jobs: int = -1, fit_global_model: bool = True, measure_performance: bool = True, @@ -174,6 +180,7 @@ def __init__( fixed=fixed, kernel=kernel, include_focal=include_focal, + graph=graph, n_jobs=n_jobs, fit_global_model=fit_global_model, measure_performance=measure_performance, @@ -306,6 +313,10 @@ class GWLinearRegression(BaseRegressor): futher spatial analysis of the model performance (and generalises to models that do not support OOB scoring). However, it leaves out the most representative sample. By default True + graph : Graph, optional + Custom libpysal.graph.Graph object encoding the spatial interaction between + observations. If given, it is used directly and `bandwidth`, `fixed`, `kernel`, + and `include_focal` keywords are ignored. n_jobs : int, optional The number of jobs to run in parallel. ``-1`` means using all processors by default ``-1`` @@ -353,6 +364,7 @@ def __init__( ] | Callable = "bisquare", include_focal: bool = True, + graph: graph.Graph = None, n_jobs: int = -1, fit_global_model: bool = True, measure_performance: bool = True, @@ -367,6 +379,7 @@ def __init__( fixed=fixed, kernel=kernel, include_focal=include_focal, + graph=graph, n_jobs=n_jobs, fit_global_model=fit_global_model, measure_performance=measure_performance,