Skip to content

Commit 9521743

Browse files
authored
Merge pull request #740 from martinfleis/graph-xarray
ENH: add xarray interface to Graph
2 parents 576a076 + 3049970 commit 9521743

File tree

5 files changed

+433
-50
lines changed

5 files changed

+433
-50
lines changed

libpysal/graph/_raster.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
from warnings import warn
2+
3+
import numpy as np
4+
import pandas as pd
5+
6+
from ..weights.raster import _da2wsp
7+
from ._utils import (
8+
_sparse_to_arrays,
9+
)
10+
11+
12+
def _raster_contiguity(
    da,
    criterion="queen",
    z_value=None,
    coords_labels=None,
    k=1,
    include_nodata=False,
    n_jobs=1,
):
    """
    Create an input for Graph from xarray.DataArray.

    Parameters
    ----------
    da : xarray.DataArray
        Input 2D or 3D DataArray with shape=(z, y, x)
    criterion : {"rook", "queen"}
        Type of contiguity. Default is queen.
    z_value : int/string/float
        Select the z_value of 3D DataArray with multiple layers.
    coords_labels : dictionary
        Pass dimension labels for coordinates and layers if they do not
        belong to default dimensions, which are (band/time, y/lat, x/lon)
        e.g. coords_labels = {"y_label": "latitude",
        "x_label": "longitude", "z_label": "year"}
        When None, an empty dictionary is used.
    k : int
        Order of contiguity, this will select all neighbors up to kth order.
        Default is 1.
    include_nodata : boolean
        If True, missing values will be assumed as non-missing when
        selecting higher_order neighbors, Default is False. The option is
        switched off (with a warning) when numba cannot be imported.
    n_jobs : int
        Number of cores to be used in the sparse weight construction. If -1,
        all available cores are used. Default is 1.

    Returns
    -------
    tuple
        ``(head, tail, weight, index)`` where ``head`` and ``tail`` hold the
        labels of the focal and neighbouring cells, ``weight`` the edge
        weights and ``index`` is the index of the series returned by
        ``_da2wsp`` (callers read its ``names`` and values).
    """
    try:
        import numba  # noqa: F401
    except (ModuleNotFoundError, ImportError):
        warn(
            "numba cannot be imported, parallel processing "
            "and include_nodata functionality will be disabled. "
            "falling back to slower method",
            stacklevel=2,
        )
        use_numba = False
        # include_nodata is only supported by the numba code path, so it is
        # disabled here, exactly as announced by the warning above.
        include_nodata = False
    else:
        use_numba = True

    if coords_labels is None:
        coords_labels = {}

    result = _da2wsp(
        da=da,
        criterion=criterion,
        z_value=z_value,
        coords_labels=coords_labels,
        k=k,
        include_nodata=include_nodata,
        n_jobs=n_jobs,
        use_numba=use_numba,
    )

    if use_numba:
        # The numba path hands back raw COO-style arrays; the third element
        # of the result is not needed here.
        (weight, (head, tail)), ser, _ = result

        # np.lexsort treats the last key as the primary one, so edges end up
        # ordered by head first and by tail second.
        order = np.lexsort((tail, head))
        head = head[order]
        tail = tail[order]
        weight = weight[order]

        # Translate positional indices into the labels stored in the index.
        labels = ser.index.to_numpy()
        head = labels[head]
        tail = labels[tail]
    else:
        # The fallback path returns a sparse matrix together with the series.
        sw, ser = result
        head, tail, weight = _sparse_to_arrays(sw, ser.index.to_numpy())

    return (
        head,
        tail,
        weight,
        ser.index,
    )
108+
109+
110+
def _generate_da(g, y):
111+
"""Creates xarray.DataArray object from passed data aligned with the Graph.
112+
113+
Parameters
114+
----------
115+
g : Graph
116+
Graph, ideally generated using _raster_contiguity builder to ensure it
117+
contains _xarray_index_names attribute.
118+
y : array_like
119+
flat array that shall be reshaped into a DataArray with dimensionality
120+
conforming to Graph
121+
122+
Returns
123+
-------
124+
xarray.DataArray
125+
instance of xarray.DataArray that can be aligned with the DataArray from which
126+
Graph was built
127+
"""
128+
if hasattr(g, "_xarray_index_names"):
129+
names = g._xarray_index_names
130+
else:
131+
warn(
132+
UserWarning,
133+
"Graph does not store xarray index names."
134+
"The output may not align with the original DataArray.",
135+
stacklevel=3,
136+
)
137+
names = None
138+
return pd.Series(
139+
y,
140+
index=pd.MultiIndex.from_tuples(g.unique_ids, names=names),
141+
).to_xarray()

libpysal/graph/base.py

Lines changed: 126 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from ._kernel import _distance_band, _kernel
2121
from ._matching import _spatial_matching
2222
from ._plotting import _explore_graph, _plot
23+
from ._raster import _generate_da, _raster_contiguity
2324
from ._set_ops import SetOpsMixin
2425
from ._spatial_lag import _lag_spatial
2526
from ._summary import GraphSummary
@@ -938,6 +939,79 @@ def build_fuzzy_contiguity(
938939

939940
return cls.from_arrays(heads, tails, weights)
940941

942+
@classmethod
def build_raster_contiguity(
    cls,
    da,
    rook=False,
    z_value=None,
    coords_labels=None,
    k=1,
    include_nodata=False,
    n_jobs=1,
):
    """Generate Graph from ``xarray.DataArray`` raster object

    Create Graph object encoding contiguity of raster cells from
    ``xarray.DataArray`` object. The coordinates are flattened to tuples
    representing the location of each cell within the raster.

    Parameters
    ----------
    da : xarray.DataArray
        Input 2D or 3D DataArray with shape=(z, y, x)
    rook : bool, optional
        Contiguity method. If True, two cells are considered neighbours if
        they share at least one edge (rook). If False, two cells are
        considered neighbours if they share at least one vertex (queen).
        By default False
    z_value : {int, str, float}, optional
        Select the z_value of 3D DataArray with multiple layers. By default None
    coords_labels : dict, optional
        Pass dimension labels for coordinates and layers if they do not
        belong to default dimensions, which are (band/time, y/lat, x/lon)
        e.g. ``coords_labels = {"y_label": "latitude", "x_label": "longitude",
        "z_label": "year"}``
        When None, defaults to empty dictionary.
    k : int, optional
        Order of contiguity, this will select all neighbors up to k-th order.
        Default is 1.
    include_nodata : bool, optional
        If True, missing values will be assumed as non-missing when
        selecting higher_order neighbors, Default is False
    n_jobs : int, optional
        Number of cores to be used in the sparse weight construction. If -1,
        all available cores are used. Default is 1. Requires ``joblib``.

    Returns
    -------
    Graph
        libpysal.graph.Graph encoding raster contiguity. The names of the
        raster index levels are kept on the ``_xarray_index_names``
        attribute of the result.
    """

    if coords_labels is None:
        coords_labels = {}
    criterion = "rook" if rook else "queen"

    heads, tails, weights, xarray_index = _raster_contiguity(
        da=da,
        criterion=criterion,
        z_value=z_value,
        coords_labels=coords_labels,
        k=k,
        include_nodata=include_nodata,
        n_jobs=n_jobs,
    )
    # Make sure every cell of the raster index is represented in the arrays,
    # including the ones without any neighbour.
    heads, tails, weights = _resolve_islands(
        heads, tails, xarray_index.to_numpy(), weights
    )
    contig = cls.from_arrays(heads, tails, weights)
    # Remember the index level names so the Graph can be mapped back to a
    # DataArray later on.
    contig._xarray_index_names = xarray_index.names

    # Without include_nodata the higher-order neighbours are derived on the
    # Graph itself, keeping all lower orders as well.
    if k > 1 and not include_nodata:
        contig = contig.higher_order(k, lower_order=True)

    return contig
1014+
9411015
@classmethod
9421016
def build_kernel(
9431017
cls,
@@ -1511,7 +1585,11 @@ def transform(self, transformation):
15111585
standardized_adjacency = pd.Series(
15121586
standardized, name="weight", index=self._adjacency.index
15131587
)
1514-
return Graph(standardized_adjacency, transformation, is_sorted=True)
1588+
transformed = Graph(standardized_adjacency, transformation, is_sorted=True)
1589+
1590+
if hasattr(self, "_xarray_index_names"):
1591+
transformed._xarray_index_names = self._xarray_index_names
1592+
return transformed
15151593

15161594
@cached_property
15171595
def _components(self):
@@ -1598,7 +1676,7 @@ def pct_nonzero(self):
15981676
@cached_property
def nonzero(self):
    """Count the adjacency entries whose weight is strictly greater than zero."""
    positive_mask = self._adjacency > 0
    return positive_mask.sum()
16021680

16031681
@cached_property
16041682
def index_pairs(self):
@@ -1855,7 +1933,7 @@ def higher_order(self, k=2, shortest_path=True, diagonal=False, lower_order=Fals
18551933
if not diagonal:
18561934
sk = {(i, j) for i, j in sk if i != j}
18571935

1858-
return Graph.from_sparse(
1936+
higher = Graph.from_sparse(
18591937
sparse.coo_array(
18601938
(
18611939
np.ones(len(sk), dtype=np.int8),
@@ -1865,6 +1943,10 @@ def higher_order(self, k=2, shortest_path=True, diagonal=False, lower_order=Fals
18651943
),
18661944
ids=self.unique_ids,
18671945
)
1946+
if hasattr(self, "_xarray_index_names"):
1947+
higher._xarray_index_names = self._xarray_index_names
1948+
1949+
return higher
18681950

18691951
def lag(self, y, categorical=False, ties="raise"):
18701952
"""Spatial lag operator
@@ -2261,11 +2343,11 @@ def subgraph(self, ids):
22612343
Unlike the implementation in ``networkx``, this creates a copy since
22622344
Graphs in ``libpysal`` are immutable.
22632345
"""
2264-
masked_adj = self._adjacency[ids]
2346+
masked_adj = self._adjacency.loc[ids, :]
22652347
filtered_adj = masked_adj[
22662348
masked_adj.index.get_level_values("neighbor").isin(ids)
22672349
]
2268-
return Graph.from_arrays(
2350+
sub = Graph.from_arrays(
22692351
*_resolve_islands(
22702352
filtered_adj.index.get_level_values("focal"),
22712353
filtered_adj.index.get_level_values("neighbor"),
@@ -2274,6 +2356,11 @@ def subgraph(self, ids):
22742356
)
22752357
)
22762358

2359+
if hasattr(self, "_xarray_index_names"):
2360+
sub._xarray_index_names = self._xarray_index_names
2361+
2362+
return sub
2363+
22772364
def eliminate_zeros(self):
22782365
"""Remove graph edges with zero weight
22792366
@@ -2290,7 +2377,12 @@ def eliminate_zeros(self):
22902377
zeros = (self._adjacency == 0) != np.isin(
22912378
self._adjacency.index.get_level_values(0), self.isolates
22922379
)
2293-
return Graph(self._adjacency[~zeros], is_sorted=True)
2380+
2381+
eliminated = Graph(self._adjacency[~zeros], is_sorted=True)
2382+
if hasattr(self, "_xarray_index_names"):
2383+
eliminated._xarray_index_names = self._xarray_index_names
2384+
2385+
return eliminated
22942386

22952387
def assign_self_weight(self, weight=1):
22962388
"""Assign values to edges representing self-weight.
@@ -2364,7 +2456,12 @@ def assign_self_weight(self, weight=1):
23642456
.reindex(self.unique_ids, level=0)
23652457
.reindex(self.unique_ids, level=1)
23662458
)
2367-
return Graph(adj, is_sorted=True)
2459+
assigned = Graph(adj, is_sorted=True)
2460+
2461+
if hasattr(self, "_xarray_index_names"):
2462+
assigned._xarray_index_names = self._xarray_index_names
2463+
2464+
return assigned
23682465

23692466
def apply(self, y, func, **kwargs):
23702467
"""Apply a reduction across the neighbor sets
@@ -2479,6 +2576,23 @@ def describe(
24792576
stat_.loc[self.isolates] = np.nan
24802577
return stat_
24812578

2579+
def generate_da(self, y):
    """Build an ``xarray.DataArray`` from values aligned with this Graph.

    Thin public wrapper that delegates to ``_generate_da`` with this Graph
    as the first argument.

    Parameters
    ----------
    y : array_like
        flat array of values that is reshaped into a DataArray whose
        dimensionality conforms to the Graph

    Returns
    -------
    xarray.DataArray
        DataArray that can be aligned with the DataArray the Graph was
        built from
    """
    data_array = _generate_da(g=self, y=y)
    return data_array
2595+
24822596

24832597
def _arrange_arrays(heads, tails, weights, ids=None):
24842598
"""
@@ -2530,8 +2644,11 @@ def read_parquet(path, **kwargs):
25302644
--------
25312645
>>> graph.read_parquet("contiguity.parquet")
25322646
"""
2533-
adjacency, transformation = _read_parquet(path, **kwargs)
2534-
return Graph(adjacency, transformation, is_sorted=True)
2647+
adjacency, transformation, xarray_index_names = _read_parquet(path, **kwargs)
2648+
graph_obj = Graph(adjacency, transformation, is_sorted=True)
2649+
if xarray_index_names is not None:
2650+
graph_obj._xarray_index_names = xarray_index_names
2651+
return graph_obj
25352652

25362653

25372654
def read_gal(path):

libpysal/graph/io/_parquet.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ def _to_parquet(graph_obj, destination, **kwargs):
2828

2929
meta = table.schema.metadata
3030
d = {"transformation": graph_obj.transformation, "version": libpysal.__version__}
31+
if hasattr(graph_obj, "_xarray_index_names"):
32+
d["_xarray_index_names"] = list(graph_obj._xarray_index_names)
3133
meta[b"libpysal"] = json.dumps(d).encode("utf-8")
3234
schema = table.schema.with_metadata(meta)
3335

@@ -47,7 +49,7 @@ def _read_parquet(source, **kwargs):
4749
Returns
4850
-------
4951
tuple
50-
tuple of adjacency table and transformation
52+
tuple of adjacency table, transformation, and xarray_index_names
5153
"""
5254
try:
5355
import pyarrow.parquet as pq
@@ -61,4 +63,10 @@ def _read_parquet(source, **kwargs):
6163
else:
6264
transformation = "O"
6365

64-
return table.to_pandas()["weight"], transformation
66+
# The writer stores the index names inside the JSON payload kept under the
# ``b"libpysal"`` metadata key (next to the transformation), so they have to
# be read back from that payload rather than from a top-level key.
if b"libpysal" in table.schema.metadata:
    libpysal_meta = json.loads(table.schema.metadata[b"libpysal"])
    # Files written from graphs without raster origin lack the entry.
    xarray_index_names = libpysal_meta.get("_xarray_index_names")
else:
    xarray_index_names = None
71+
72+
return table.to_pandas()["weight"], transformation, xarray_index_names

0 commit comments

Comments
 (0)