Commit 3314796

combine single fragment per chunk with inserted skipped nodes
1 parent cf97ff1 commit 3314796

File tree

1 file changed: +126 -112 lines changed

pychunkedgraph/meshing/manifest/multiscale.py

Lines changed: 126 additions & 112 deletions
@@ -1,10 +1,8 @@
 # pylint: disable=invalid-name, missing-docstring, line-too-long, no-member
 
-import json
-import pickle
 import time
 import functools
-from collections import defaultdict, deque
+from collections import defaultdict, deque, namedtuple
 
 import numpy as np
 from cloudvolume import CloudVolume
@@ -20,6 +18,19 @@
 OCTREE_NODE_SIZE = 5
 
 
+HierarchyInfo = namedtuple(
+    "HierarchyInfo",
+    [
+        "children_map",
+        "children_chunks_map",
+        "chunk_nodes_map",
+        "node_chunk_id_map",
+        "coords_map",
+        "layers_map",
+    ],
+)
+
+
 def _morton_sort(cg: ChunkedGraph, children: np.ndarray):
     """
     Sort children by their morton code.
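
Note: the HierarchyInfo namedtuple added in this hunk bundles the six lookup maps that the hierarchy builder previously returned as separate values, so callers can access them by name. A minimal sketch of the idea, mirroring the definition above (the toy instance and values below are invented for illustration, not taken from the repository):

from collections import namedtuple

HierarchyInfo = namedtuple(
    "HierarchyInfo",
    ["children_map", "children_chunks_map", "chunk_nodes_map",
     "node_chunk_id_map", "coords_map", "layers_map"],
)

# Toy instance with empty dicts; the real maps are keyed by np.uint64 node/chunk IDs.
hinfo = HierarchyInfo({}, {}, {}, {}, {}, {})
print(hinfo.coords_map)   # fields are accessed by name instead of tuple position
print(hinfo._fields)      # ('children_map', 'children_chunks_map', ..., 'layers_map')
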
@@ -52,7 +63,71 @@ def less_msb(x: int, y: int) -> bool:
     return np.array(children, dtype=NODE_ID)
 
 
-def _get_hierarchy(cg: ChunkedGraph, node_id: np.uint64) -> dict:
+def _get_node_coords_and_layers_map(
+    cg: ChunkedGraph, children_map: dict
+) -> tuple[dict, dict]:
+    node_ids = np.fromiter(children_map.keys(), dtype=NODE_ID)
+    coords_map = {}
+    node_layers = cg.get_chunk_layers(node_ids)
+    for layer in set(node_layers):
+        layer_mask = node_layers == layer
+        coords = cg.get_chunk_coordinates_multiple(node_ids[layer_mask])
+        _node_coords = dict(zip(node_ids[layer_mask], coords))
+        coords_map.update(_node_coords)
+
+    chunk_id_coords_map = {}
+    chunk_ids = cg.get_chunk_ids_from_node_ids(node_ids)
+    node_chunk_id_map = dict(zip(node_ids, chunk_ids))
+    for k, v in coords_map.items():
+        chunk_id_coords_map[node_chunk_id_map[k]] = v
+    coords_map.update(chunk_id_coords_map)
+    return coords_map, dict(zip(node_ids, node_layers))
+
+
+def _get_hierarchy(cg: ChunkedGraph, node_id: np.uint64) -> HierarchyInfo:
+    def _insert_skipped_nodes(cg: ChunkedGraph):
+        new_children_map = {}
+        for node, children in children_map.items():
+            nl = layers_map[node]
+            if len(children) > 1 or nl == 2:
+                new_children_map[node] = children
+            else:
+                assert (
+                    len(children) == 1
+                ), f"Skipped hierarchy must have exactly 1 child: {node} - {children}."
+                cl = layers_map[children[0]]
+                height = nl - cl
+                if height == 1:
+                    new_children_map[node] = children
+                    continue
+
+                cx, cy, cz = coords_map[children[0]]
+                skipped_hierarchy = [node]
+                count = 1
+                height -= 1
+                while height:
+                    x, y, z = cx >> height, cy >> height, cz >> height
+                    skipped_layer = nl - count
+                    skipped_chunk = cg.get_chunk_id(layer=skipped_layer, x=x, y=y, z=z)
+                    limit = cg.get_segment_id_limit(skipped_chunk)
+                    skipped_child = skipped_chunk + (limit - np.uint64(1))
+                    while skipped_child in new_children_map:
+                        skipped_child = skipped_child - np.uint64(1)
+
+                    skipped_hierarchy.append(skipped_child)
+                    coords_map[skipped_child] = np.array((x, y, z), dtype=int)
+                    layers_map[skipped_child] = skipped_layer
+                    node_chunk_id_map[skipped_child] = skipped_chunk
+                    count += 1
+                    height -= 1
+                skipped_hierarchy.append(children[0])
+
+                for i in range(len(skipped_hierarchy) - 1):
+                    node = skipped_hierarchy[i]
+                    child = skipped_hierarchy[i + 1]
+                    new_children_map[node] = np.array([child], dtype=NODE_ID)
+        return new_children_map
+
     node_chunk_id_map = {node_id: cg.get_chunk_id(node_id)}
     children_map = {}
     children_chunks_map = {}
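
Note: the nested _insert_skipped_nodes above fills in virtual nodes wherever a parent sits more than one layer above its only child. The chunk coordinates for those virtual intermediate levels come from right-shifting the child's chunk coordinates, since each layer up appears to cover 2**height child chunks per axis (hence the shifts in the code above). A small illustration of that shift arithmetic with hypothetical coordinates:

# Hypothetical layer-2 chunk coordinates of the single child.
cx, cy, cz = 12, 5, 9
# Two skipped levels: the chunk `height` layers up has the child's
# coordinates integer-divided by 2**height, i.e. shifted right.
for height in (2, 1):
    x, y, z = cx >> height, cy >> height, cz >> height
    print(height, (x, y, z))   # prints 2 (3, 1, 2) then 1 (6, 2, 4)
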
@@ -81,79 +156,24 @@ def _get_hierarchy(cg: ChunkedGraph, node_id: np.uint64) -> dict:
         for l2id in _ids[node_layers == 2]:
             children_map[l2id] = empty_1d.copy()
 
-    for k, v in children_map.items():
-        chunk_ids = np.array([node_chunk_id_map[i] for i in v], dtype=NODE_ID)
+    coords_map, layers_map = _get_node_coords_and_layers_map(cg, children_map)
+    new_children_map = _insert_skipped_nodes(cg)
+    coords_map, layers_map = _get_node_coords_and_layers_map(cg, new_children_map)
+
+    for node, children in new_children_map.items():
+        chunk_ids = np.array([node_chunk_id_map[i] for i in children], dtype=NODE_ID)
         uchunk_ids = np.unique(chunk_ids)
-        children_chunks_map[k] = uchunk_ids
+        children_chunks_map[node] = uchunk_ids
         for c in uchunk_ids:
-            chunk_nodes_map[c] = v[chunk_ids == c]
-    return children_map, children_chunks_map, chunk_nodes_map, node_chunk_id_map
-
-
-def _get_node_coords_and_layers_map(
-    cg: ChunkedGraph, children_map: dict
-) -> tuple[dict, dict]:
-    node_ids = np.fromiter(children_map.keys(), dtype=NODE_ID)
-    coords_map = {}
-    node_layers = cg.get_chunk_layers(node_ids)
-    for layer in set(node_layers):
-        layer_mask = node_layers == layer
-        coords = cg.get_chunk_coordinates_multiple(node_ids[layer_mask])
-        _node_coords = dict(zip(node_ids[layer_mask], coords))
-        coords_map.update(_node_coords)
-
-    chunk_id_coords_map = {}
-    chunk_ids = cg.get_chunk_ids_from_node_ids(node_ids)
-    node_chunk_id_map = dict(zip(node_ids, chunk_ids))
-    for k, v in coords_map.items():
-        chunk_id_coords_map[node_chunk_id_map[k]] = v
-    coords_map.update(chunk_id_coords_map)
-    return coords_map, dict(zip(node_ids, node_layers))
-
-
-def _insert_skipped_nodes(
-    cg: ChunkedGraph, children_map: dict, coords_map: dict, layers_map: dict
-):
-    new_children_map = {}
-    for node, children in children_map.items():
-        nl = layers_map[node]
-        if len(children) > 1 or nl == 2:
-            new_children_map[node] = children
-        else:
-            assert (
-                len(children) == 1
-            ), f"Skipped hierarchy must have exactly 1 child: {node} - {children}."
-            cl = layers_map[children[0]]
-            height = nl - cl
-            if height == 1:
-                new_children_map[node] = children
-                continue
-
-            cx, cy, cz = coords_map[children[0]]
-            skipped_hierarchy = [node]
-            count = 1
-            height -= 1
-            while height:
-                x, y, z = cx >> height, cy >> height, cz >> height
-                skipped_layer = nl - count
-                skipped_child = cg.get_chunk_id(layer=skipped_layer, x=x, y=y, z=z)
-                limit = cg.get_segment_id_limit(skipped_child)
-                skipped_child += limit - np.uint64(1)
-                while skipped_child in new_children_map:
-                    skipped_child = skipped_child - np.uint64(1)
-
-                skipped_hierarchy.append(skipped_child)
-                coords_map[skipped_child] = np.array((x, y, z), dtype=int)
-                layers_map[skipped_child] = skipped_layer
-                count += 1
-                height -= 1
-            skipped_hierarchy.append(children[0])
-
-            for i in range(len(skipped_hierarchy) - 1):
-                node = skipped_hierarchy[i]
-                child = skipped_hierarchy[i + 1]
-                new_children_map[node] = np.array([child], dtype=NODE_ID)
-    return new_children_map, coords_map, layers_map
+            chunk_nodes_map[c] = children[chunk_ids == c]
+    return HierarchyInfo(
+        new_children_map,
+        children_chunks_map,
+        chunk_nodes_map,
+        node_chunk_id_map,
+        coords_map,
+        layers_map,
+    )
 
 
 def _validate_octree(octree: np.ndarray, octree_node_ids: np.ndarray):
@@ -199,13 +219,7 @@ def _explore_node(node: int):
 
 
 def build_octree(
-    cg: ChunkedGraph,
-    node_id: np.uint64,
-    children_map: dict,
-    children_chunks_map: dict,
-    chunk_nodes_map: dict,
-    node_chunk_id_map: dict,
-    mesh_fragments: dict,
+    cg: ChunkedGraph, node_id: np.uint64, hinfo: HierarchyInfo, mesh_fragments: dict
 ):
     """
     From neuroglancer multiscale specification:
@@ -219,26 +233,27 @@ def build_octree(
     requested/rendered.
     """
     node_q = deque()
-    node_q.append(node_chunk_id_map[node_id])
-    coords_map, _ = _get_node_coords_and_layers_map(cg, children_map)
+    node_q.append(hinfo.node_chunk_id_map[node_id])
 
-    all_chunks = np.concatenate(list(children_chunks_map.values()))
+    all_chunks = np.concatenate(list(hinfo.children_chunks_map.values()))
     all_chunks = np.unique(all_chunks)
 
     ROW_TOTAL = all_chunks.size + 1
     row_counter = all_chunks.size + 1
     octree_size = OCTREE_NODE_SIZE * ROW_TOTAL
     octree = np.zeros(octree_size, dtype=np.uint32)
 
-    octree_node_ids = ROW_TOTAL * [0]
+    octree_chunks = ROW_TOTAL * [0]
     octree_fragments = defaultdict(list)
     rows_used = 1
+    virtual_chunk_hierarchy = {}
 
     while len(node_q) > 0:
         frags = []
         row_counter -= 1
         current_chunk = node_q.popleft()
-        chunk_nodes = chunk_nodes_map[current_chunk]
+        chunk_nodes = hinfo.chunk_nodes_map[current_chunk]
+        octree_chunks[row_counter] = current_chunk
 
         for k in chunk_nodes:
             if k in mesh_fragments:
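
Note: as I read the multiscale format referenced in the docstring, each octree row is five uint32 values per chunk: x, y, z grid coordinates, the index of the node's first child row, and the end index whose high bit doubles as an "empty" flag; the offsets in this function fill exactly those five slots. The sketch below indexes such a flat row-major array the same way; the row count and values are invented:

import numpy as np

OCTREE_NODE_SIZE = 5                       # [x, y, z, child_begin, child_end_and_empty]
octree = np.zeros(OCTREE_NODE_SIZE * 3, dtype=np.uint32)   # three hypothetical rows

row = 2                                    # fill the last row with made-up values
offset = OCTREE_NODE_SIZE * row
octree[offset + 0 : offset + 3] = (1, 2, 3)   # chunk grid coordinates
octree[offset + 3] = 0                        # first child row index
octree[offset + 4] = 2                        # one past the last child row
octree[offset + 4] |= 1 << 31                 # high bit marks the node as empty
print(octree.reshape(-1, OCTREE_NODE_SIZE))
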
@@ -247,17 +262,15 @@
 
         children_chunks = set()
         for k in chunk_nodes:
-            children_chunks.update(children_chunks_map[k])
+            children_chunks.update(hinfo.children_chunks_map[k])
 
         children_chunks = np.array(list(children_chunks), dtype=NODE_ID)
         children_chunks = _morton_sort(cg, children_chunks)
         for child_chunk in children_chunks:
             node_q.append(child_chunk)
 
-        octree_node_ids[row_counter] = current_chunk
-
         offset = OCTREE_NODE_SIZE * row_counter
-        x, y, z = coords_map[current_chunk]
+        x, y, z = hinfo.coords_map[current_chunk]
         octree[offset + 0] = x
         octree[offset + 1] = y
         octree[offset + 2] = z
@@ -269,26 +282,36 @@
         octree[offset + 3] = start
         octree[offset + 4] = end_empty
 
-        if children_chunks.size == 1:
-            octree[offset + 3] |= 1 << 31
+        if len(octree_fragments[int(current_chunk)]) == 0:
+            virtual_chunk_hierarchy[current_chunk] = children_chunks[0]
         if children_chunks.size == 0:
             octree[offset + 4] |= 1 << 31
 
-    octree[5 * (ROW_TOTAL - 1) + 3] |= 1 << 31
     # _validate_octree(octree, octree_node_ids)
     fragments = []
-    for node in octree_node_ids:
-        fragments.append(octree_fragments[int(node)])
-    return octree, octree_node_ids, fragments
+    for chunk in octree_chunks:
+        if chunk in virtual_chunk_hierarchy:
+            frags = []
+            while True:
+                child = virtual_chunk_hierarchy[chunk]
+                if child not in virtual_chunk_hierarchy:
+                    break
+                chunk = child
+            fragments.append(octree_fragments[int(child)])
+        else:
+            fragments.append(octree_fragments[int(chunk)])
+
+    for k in hinfo.children_chunks_map[node_id]:
+        fragments[-1].extend(octree_fragments[int(k)])
+    octree[5 * (ROW_TOTAL - 1) + 3] &= ~(1 << 31)
+    return octree, octree_chunks, fragments
 
 
 def get_manifest(cg: ChunkedGraph, node_id: np.uint64) -> dict:
     start = time.time()
-    children_map, children_chunks_map, chunk_nodes_map, node_chunk_id_map = (
-        _get_hierarchy(cg, node_id)
-    )
+    hierarchy_info = _get_hierarchy(cg, node_id)
 
-    node_ids = np.fromiter(children_map.keys(), dtype=NODE_ID)
+    node_ids = np.fromiter(hierarchy_info.children_map.keys(), dtype=NODE_ID)
     manifest_cache = ManifestCache(cg.graph_id, initial=True)
 
     cv = CloudVolume(
@@ -304,24 +327,15 @@ def get_manifest(cg: ChunkedGraph, node_id: np.uint64) -> dict:
         manifest_cache.set_fragments(_fragments_d)
         fragments_d.update(_fragments_d)
 
-    octree, node_ids, fragments = build_octree(
-        cg,
-        node_id,
-        children_map,
-        children_chunks_map,
-        chunk_nodes_map,
-        node_chunk_id_map,
-        fragments_d,
-    )
-
+    octree, node_ids, fragments = build_octree(cg, node_id, hierarchy_info, fragments_d)
     max_layer = min(cg.get_chunk_layer(node_id) + 1, cg.meta.layer_count)
     chunk_shape = np.array(cg.meta.graph_config.CHUNK_SIZE, dtype=np.dtype("<f4"))
     chunk_shape *= cg.meta.resolution
     clip_bounds = cg.meta.voxel_bounds.T * cg.meta.resolution
     response = {
         "chunkShape": chunk_shape,
         "chunkGridSpatialOrigin": np.array([0, 0, 0], dtype=np.dtype("<f4")),
-        "lodScales": np.arange(2, max_layer, dtype=np.dtype("<f4")) * 1,
+        "lodScales": np.arange(2, max_layer + 1, dtype=np.dtype("<f4")) * 1,
         "fragments": fragments,
         "octree": octree,
         "clipLowerBound": np.array(clip_bounds[0], dtype=np.dtype("<f4")),
