20
20
from ._kernel import _distance_band , _kernel
21
21
from ._matching import _spatial_matching
22
22
from ._plotting import _explore_graph , _plot
23
+ from ._raster import _generate_da , _raster_contiguity
23
24
from ._set_ops import SetOpsMixin
24
25
from ._spatial_lag import _lag_spatial
25
26
from ._summary import GraphSummary
@@ -938,6 +939,79 @@ def build_fuzzy_contiguity(
938
939
939
940
return cls .from_arrays (heads , tails , weights )
940
941
942
+ @classmethod
943
+ def build_raster_contiguity (
944
+ cls ,
945
+ da ,
946
+ rook = False ,
947
+ z_value = None ,
948
+ coords_labels = None ,
949
+ k = 1 ,
950
+ include_nodata = False ,
951
+ n_jobs = 1 ,
952
+ ):
953
+ """Generate Graph from ``xarray.DataArray`` raster object
954
+
955
+ Create Graph object encoding contiguity of raster cells from
956
+ ``xarray.DataArray`` object. The coordinates are flatten to tuples representing
957
+ the location of each cell within the raster.
958
+
959
+ Parameters
960
+ ----------
961
+ da : xarray.DataArray
962
+ Input 2D or 3D DataArray with shape=(z, y, x)
963
+ rook : bool, optional
964
+ Contiguity method. If True, two cells are considered neighbours if
965
+ they share at least one edge. If False, two geometries are considered
966
+ neighbours if they share at least one vertex. By default True
967
+ z_value : {int, str, float}, optional
968
+ Select the z_value of 3D DataArray with multiple layers. By default None
969
+ coords_labels : dict, optional
970
+ Pass dimension labels for coordinates and layers if they do not
971
+ belong to default dimensions, which are (band/time, y/lat, x/lon)
972
+ e.g. ``coords_labels = {"y_label": "latitude", "x_label": "longitude",
973
+ "z_label": "year"}``
974
+ When None, defaults to empty dictionary.
975
+ k : int, optional
976
+ Order of contiguity, this will select all neighbors up to k-th order.
977
+ Default is 1.
978
+ include_nodata : bool, optional
979
+ If True, missing values will be assumed as non-missing when
980
+ selecting higher_order neighbors, Default is False
981
+ n_jobs : int, optional
982
+ Number of cores to be used in the sparse weight construction. If -1,
983
+ all available cores are used. Default is 1. Requires ``joblib``.
984
+
985
+ Returns
986
+ -------
987
+ Graph
988
+ libpysal.graph.Graph encoding raster contiguity
989
+ """
990
+
991
+ if coords_labels is None :
992
+ coords_labels = {}
993
+ criterion = "rook" if rook else "queen"
994
+
995
+ heads , tails , weights , xarray_index = _raster_contiguity (
996
+ da = da ,
997
+ criterion = criterion ,
998
+ z_value = z_value ,
999
+ coords_labels = coords_labels ,
1000
+ k = k ,
1001
+ include_nodata = include_nodata ,
1002
+ n_jobs = n_jobs ,
1003
+ )
1004
+ heads , tails , weights = _resolve_islands (
1005
+ heads , tails , xarray_index .to_numpy (), weights
1006
+ )
1007
+ contig = cls .from_arrays (heads , tails , weights )
1008
+ contig ._xarray_index_names = xarray_index .names
1009
+
1010
+ if k > 1 and not include_nodata :
1011
+ contig = contig .higher_order (k , lower_order = True )
1012
+
1013
+ return contig
1014
+
941
1015
@classmethod
942
1016
def build_kernel (
943
1017
cls ,
@@ -1511,7 +1585,11 @@ def transform(self, transformation):
1511
1585
standardized_adjacency = pd .Series (
1512
1586
standardized , name = "weight" , index = self ._adjacency .index
1513
1587
)
1514
- return Graph (standardized_adjacency , transformation , is_sorted = True )
1588
+ transformed = Graph (standardized_adjacency , transformation , is_sorted = True )
1589
+
1590
+ if hasattr (self , "_xarray_index_names" ):
1591
+ transformed ._xarray_index_names = self ._xarray_index_names
1592
+ return transformed
1515
1593
1516
1594
@cached_property
1517
1595
def _components (self ):
@@ -1598,7 +1676,7 @@ def pct_nonzero(self):
1598
1676
@cached_property
1599
1677
def nonzero (self ):
1600
1678
"""Number of nonzero weights."""
1601
- return (self ._adjacency . drop ( self . isolates ) > 0 ).sum ()
1679
+ return (self ._adjacency > 0 ).sum ()
1602
1680
1603
1681
@cached_property
1604
1682
def index_pairs (self ):
@@ -1855,7 +1933,7 @@ def higher_order(self, k=2, shortest_path=True, diagonal=False, lower_order=Fals
1855
1933
if not diagonal :
1856
1934
sk = {(i , j ) for i , j in sk if i != j }
1857
1935
1858
- return Graph .from_sparse (
1936
+ higher = Graph .from_sparse (
1859
1937
sparse .coo_array (
1860
1938
(
1861
1939
np .ones (len (sk ), dtype = np .int8 ),
@@ -1865,6 +1943,10 @@ def higher_order(self, k=2, shortest_path=True, diagonal=False, lower_order=Fals
1865
1943
),
1866
1944
ids = self .unique_ids ,
1867
1945
)
1946
+ if hasattr (self , "_xarray_index_names" ):
1947
+ higher ._xarray_index_names = self ._xarray_index_names
1948
+
1949
+ return higher
1868
1950
1869
1951
def lag (self , y , categorical = False , ties = "raise" ):
1870
1952
"""Spatial lag operator
@@ -2261,11 +2343,11 @@ def subgraph(self, ids):
2261
2343
Unlike the implementation in ``networkx``, this creates a copy since
2262
2344
Graphs in ``libpysal`` are immutable.
2263
2345
"""
2264
- masked_adj = self ._adjacency [ids ]
2346
+ masked_adj = self ._adjacency . loc [ids , : ]
2265
2347
filtered_adj = masked_adj [
2266
2348
masked_adj .index .get_level_values ("neighbor" ).isin (ids )
2267
2349
]
2268
- return Graph .from_arrays (
2350
+ sub = Graph .from_arrays (
2269
2351
* _resolve_islands (
2270
2352
filtered_adj .index .get_level_values ("focal" ),
2271
2353
filtered_adj .index .get_level_values ("neighbor" ),
@@ -2274,6 +2356,11 @@ def subgraph(self, ids):
2274
2356
)
2275
2357
)
2276
2358
2359
+ if hasattr (self , "_xarray_index_names" ):
2360
+ sub ._xarray_index_names = self ._xarray_index_names
2361
+
2362
+ return sub
2363
+
2277
2364
def eliminate_zeros (self ):
2278
2365
"""Remove graph edges with zero weight
2279
2366
@@ -2290,7 +2377,12 @@ def eliminate_zeros(self):
2290
2377
zeros = (self ._adjacency == 0 ) != np .isin (
2291
2378
self ._adjacency .index .get_level_values (0 ), self .isolates
2292
2379
)
2293
- return Graph (self ._adjacency [~ zeros ], is_sorted = True )
2380
+
2381
+ eliminated = Graph (self ._adjacency [~ zeros ], is_sorted = True )
2382
+ if hasattr (self , "_xarray_index_names" ):
2383
+ eliminated ._xarray_index_names = self ._xarray_index_names
2384
+
2385
+ return eliminated
2294
2386
2295
2387
def assign_self_weight (self , weight = 1 ):
2296
2388
"""Assign values to edges representing self-weight.
@@ -2364,7 +2456,12 @@ def assign_self_weight(self, weight=1):
2364
2456
.reindex (self .unique_ids , level = 0 )
2365
2457
.reindex (self .unique_ids , level = 1 )
2366
2458
)
2367
- return Graph (adj , is_sorted = True )
2459
+ assigned = Graph (adj , is_sorted = True )
2460
+
2461
+ if hasattr (self , "_xarray_index_names" ):
2462
+ assigned ._xarray_index_names = self ._xarray_index_names
2463
+
2464
+ return assigned
2368
2465
2369
2466
def apply (self , y , func , ** kwargs ):
2370
2467
"""Apply a reduction across the neighbor sets
@@ -2479,6 +2576,23 @@ def describe(
2479
2576
stat_ .loc [self .isolates ] = np .nan
2480
2577
return stat_
2481
2578
2579
+ def generate_da (self , y ):
2580
+ """Creates xarray.DataArray object from passed data aligned with the Graph.
2581
+
2582
+ Parameters
2583
+ ----------
2584
+ y : array_like
2585
+ flat array that shall be reshaped into a DataArray with dimensionality
2586
+ conforming to Graph
2587
+
2588
+ Returns
2589
+ -------
2590
+ xarray.DataArray
2591
+ instance of xarray.DataArray that can be aligned with the DataArray from
2592
+ which Graph was built
2593
+ """
2594
+ return _generate_da (self , y )
2595
+
2482
2596
2483
2597
def _arrange_arrays (heads , tails , weights , ids = None ):
2484
2598
"""
@@ -2530,8 +2644,11 @@ def read_parquet(path, **kwargs):
2530
2644
--------
2531
2645
>>> graph.read_parquet("contiguity.parquet")
2532
2646
"""
2533
- adjacency , transformation = _read_parquet (path , ** kwargs )
2534
- return Graph (adjacency , transformation , is_sorted = True )
2647
+ adjacency , transformation , xarray_index_names = _read_parquet (path , ** kwargs )
2648
+ graph_obj = Graph (adjacency , transformation , is_sorted = True )
2649
+ if xarray_index_names is not None :
2650
+ graph_obj ._xarray_index_names = xarray_index_names
2651
+ return graph_obj
2535
2652
2536
2653
2537
2654
def read_gal (path ):
0 commit comments