Skip to content

ENH: allow custom Graph object #13

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions gwlearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def __init__(
]
| Callable = "bisquare",
include_focal: bool = False,
graph: graph.Graph = None,
n_jobs: int = -1,
fit_global_model: bool = True,
measure_performance: bool = True,
Expand All @@ -108,6 +109,7 @@ def __init__(
self.bandwidth = bandwidth
self.kernel = kernel
self.include_focal = include_focal
self.graph = graph
self.fixed = fixed
self._model_kwargs = kwargs
self.n_jobs = n_jobs
Expand Down Expand Up @@ -380,6 +382,10 @@ class BaseClassifier(_BaseModel, ClassifierMixin):
futher spatial analysis of the model performance (and generalises to models
that do not support OOB scoring). However, it leaves out the most representative
sample. By default False
graph : Graph, optional
Custom libpysal.graph.Graph object encoding the spatial interaction between
observations. If given, it is used directly and `bandwidth`, `fixed`, `kernel`,
and `include_focal` keywords are ignored.
n_jobs : int, optional
The number of jobs to run in parallel. ``-1`` means using all processors
by default ``-1``
Expand Down Expand Up @@ -471,6 +477,7 @@ def __init__(
]
| Callable = "bisquare",
include_focal: bool = False,
graph: graph.Graph = None,
n_jobs: int = -1,
fit_global_model: bool = True,
measure_performance: bool = True,
Expand All @@ -490,6 +497,7 @@ def __init__(
fixed=fixed,
kernel=kernel,
include_focal=include_focal,
graph=graph,
n_jobs=n_jobs,
fit_global_model=fit_global_model,
measure_performance=measure_performance,
Expand Down Expand Up @@ -548,7 +556,11 @@ def _is_binary(series: pd.Series) -> bool:

if self.verbose:
print(f"{(time() - self._start):.2f}s: Building weights")
weights = self._build_weights(geometry)

if self.graph is not None:
weights = self.graph
else:
weights = self._build_weights(geometry)
if self.verbose:
print(f"{(time() - self._start):.2f}s: Weights ready")
self._setup_model_storage()
Expand Down Expand Up @@ -877,6 +889,10 @@ class BaseRegressor(_BaseModel, RegressorMixin):
futher spatial analysis of the model performance (and generalises to models
that do not support OOB scoring). However, it leaves out the most representative
sample. By default False
graph : Graph, optional
Custom libpysal.graph.Graph object encoding the spatial interaction between
observations. If given, it is used directly and `bandwidth`, `fixed`, `kernel`,
and `include_focal` keywords are ignored.
n_jobs : int, optional
The number of jobs to run in parallel. ``-1`` means using all processors
by default ``-1``
Expand Down Expand Up @@ -925,7 +941,10 @@ def fit(
"""
self._validate_geometry(geometry)

weights = self._build_weights(geometry)
if self.graph is not None:
weights = self.graph
else:
weights = self._build_weights(geometry)
self._setup_model_storage()

# fit the models
Expand Down
13 changes: 13 additions & 0 deletions gwlearn/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import geopandas as gpd
import numpy as np
import pandas as pd
from libpysal import graph
from sklearn import metrics
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier

Expand All @@ -31,6 +32,10 @@ class GWRandomForestClassifier(BaseClassifier):
analysis of the model performance (and generalises to models that do not support
OOB scoring). However, it leaves out the most representative sample. By default
False
graph : Graph, optional
Custom libpysal.graph.Graph object encoding the spatial interaction between
observations. If given, it is used directly and `bandwidth`, `fixed`, `kernel`,
and `include_focal` keywords are ignored.
n_jobs : int, optional
The number of jobs to run in parallel. ``-1`` means using all processors by
default ``-1``
Expand Down Expand Up @@ -153,6 +158,7 @@ def __init__(
]
| Callable = "bisquare",
include_focal: bool = False,
graph: graph.Graph = None,
n_jobs: int = -1,
fit_global_model: bool = True,
measure_performance: bool = True,
Expand All @@ -172,6 +178,7 @@ def __init__(
fixed=fixed,
kernel=kernel,
include_focal=include_focal,
graph=graph,
n_jobs=n_jobs,
fit_global_model=fit_global_model,
measure_performance=measure_performance,
Expand Down Expand Up @@ -303,6 +310,10 @@ class GWGradientBoostingClassifier(BaseClassifier):
futher spatial analysis of the model performance (and generalises to models
that do not support OOB scoring). However, it leaves out the most representative
sample. By default False
graph : Graph, optional
Custom libpysal.graph.Graph object encoding the spatial interaction between
observations. If given, it is used directly and `bandwidth`, `fixed`, `kernel`,
and `include_focal` keywords are ignored.
n_jobs : int, optional
The number of jobs to run in parallel. ``-1`` means using all processors
by default ``-1``
Expand Down Expand Up @@ -393,6 +404,7 @@ def __init__(
]
| Callable = "bisquare",
include_focal: bool = False,
graph: graph.Graph = None,
n_jobs: int = -1,
fit_global_model: bool = True,
measure_performance: bool = True,
Expand All @@ -408,6 +420,7 @@ def __init__(
fixed=fixed,
kernel=kernel,
include_focal=include_focal,
graph=graph,
n_jobs=n_jobs,
fit_global_model=fit_global_model,
measure_performance=measure_performance,
Expand Down
13 changes: 13 additions & 0 deletions gwlearn/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import geopandas as gpd
import numpy as np
import pandas as pd
from libpysal import graph
from sklearn import metrics
from sklearn.linear_model import LinearRegression, LogisticRegression

Expand All @@ -30,6 +31,10 @@ class GWLogisticRegression(BaseClassifier):
futher spatial analysis of the model performance (and generalises to models
that do not support OOB scoring). However, it leaves out the most representative
sample. By default True
graph : Graph, optional
Custom libpysal.graph.Graph object encoding the spatial interaction between
observations. If given, it is used directly and `bandwidth`, `fixed`, `kernel`,
and `include_focal` keywords are ignored.
n_jobs : int, optional
The number of jobs to run in parallel. ``-1`` means using all processors
by default ``-1``
Expand Down Expand Up @@ -158,6 +163,7 @@ def __init__(
]
| Callable = "bisquare",
include_focal: bool = True,
graph: graph.Graph = None,
n_jobs: int = -1,
fit_global_model: bool = True,
measure_performance: bool = True,
Expand All @@ -174,6 +180,7 @@ def __init__(
fixed=fixed,
kernel=kernel,
include_focal=include_focal,
graph=graph,
n_jobs=n_jobs,
fit_global_model=fit_global_model,
measure_performance=measure_performance,
Expand Down Expand Up @@ -306,6 +313,10 @@ class GWLinearRegression(BaseRegressor):
futher spatial analysis of the model performance (and generalises to models
that do not support OOB scoring). However, it leaves out the most representative
sample. By default True
graph : Graph, optional
Custom libpysal.graph.Graph object encoding the spatial interaction between
observations. If given, it is used directly and `bandwidth`, `fixed`, `kernel`,
and `include_focal` keywords are ignored.
n_jobs : int, optional
The number of jobs to run in parallel. ``-1`` means using all processors
by default ``-1``
Expand Down Expand Up @@ -353,6 +364,7 @@ def __init__(
]
| Callable = "bisquare",
include_focal: bool = True,
graph: graph.Graph = None,
n_jobs: int = -1,
fit_global_model: bool = True,
measure_performance: bool = True,
Expand All @@ -367,6 +379,7 @@ def __init__(
fixed=fixed,
kernel=kernel,
include_focal=include_focal,
graph=graph,
n_jobs=n_jobs,
fit_global_model=fit_global_model,
measure_performance=measure_performance,
Expand Down
47 changes: 47 additions & 0 deletions gwlearn/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pandas as pd
import pytest
from geodatasets import get_path
from libpysal.graph import Graph
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.linear_model import LinearRegression, LogisticRegression

Expand Down Expand Up @@ -1440,3 +1441,49 @@ def test_regressor_fit_focal_inclusion(sample_regression_data):

# RF should 'remember' focal point when included
assert (y - no_focal.pred_).mean() > (y - focal.pred_).mean()


def test_custom_graph_baseregressor(sample_regression_data):
"""Test BaseRegressor with a custom graph object."""
X, y, geometry = sample_regression_data

# Create a fixed distance weights graph
g = Graph.build_distance_band(geometry, threshold=150000, binary=False)

# Create regressor with custom graph
reg = BaseRegressor(
LinearRegression,
bandwidth=100, # This should be ignored when custom graph is provided
fixed=False, # This should be ignored when custom graph is provided
graph=g,
)

# Fit the model
reg.fit(X, y, geometry)

# Check that the model was fit successfully
assert hasattr(reg, "pred_")
assert hasattr(reg, "local_r2_")


def test_custom_graph_baseclassifier(sample_data):
"""Test BaseClassifier with a custom graph object."""
X, y, geometry = sample_data

# Create a fixed distance weights graph
g = Graph.build_distance_band(geometry, threshold=150000, binary=False)
# Create classifier with custom graph
clf = BaseClassifier(
LogisticRegression,
bandwidth=100, # This should be ignored when custom graph is provided
fixed=True, # This should be ignored when custom graph is provided
graph=g,
max_iter=500,
)

# Fit the model
clf.fit(X, y, geometry)

# Check that the model was fit successfully
assert hasattr(clf, "proba_")
assert hasattr(clf, "score_")