
Commit 0f221c9

Better chunking strategy for Crash Zarr (#36)
- The initial implementation had hardcoded chunking
- The new implementation makes chunking configurable via a chunk_size_mb parameter, similar to external aero
- Also updated tests and docs
1 parent e937467 commit 0f221c9

File tree

4 files changed: +313 −10 lines

examples/structural_mechanics/crash/README.md

Lines changed: 6 additions & 1 deletion

@@ -118,7 +118,8 @@ physicsnemo-curator-etl \
   serialization_format=zarr \
   etl.source.input_dir=/data/crash_sims \
   serialization_format.sink.output_dir=/data/crash_processed_zarr \
-  serialization_format.sink.compression_level=5
+  serialization_format.sink.compression_level=5 \
+  serialization_format.sink.chunk_size_mb=2.0
 ```

 **Config:** See [`config/serialization_format/zarr.yaml`](./config/serialization_format/zarr.yaml)
@@ -127,6 +128,10 @@ physicsnemo-curator-etl \

 - `compression_level`: Compression level (1-9, higher = more compression, default: 3)
 - `compression_method`: Compression codec (default: "zstd")
+- `chunk_size_mb`: Target chunk size in MB for automatic chunking (default: 1.0)
+  - Smaller values: Better for random access, more metadata overhead
+  - Larger values: Better for sequential reads, less metadata overhead
+  - Warnings are issued for very small (<0.1 MB) or very large (>100 MB) values
 - `overwrite_existing`: Whether to overwrite existing output stores (default: true)

 Output structure:

examples/structural_mechanics/crash/config/serialization_format/zarr.yaml

Lines changed: 1 addition & 0 deletions

@@ -21,3 +21,4 @@ sink:
   overwrite_existing: true
   compression_level: 3
   compression_method: "zstd"
+  chunk_size_mb: 1.0  # Target chunk size in MB (adjust based on data size)

examples/structural_mechanics/crash/data_sources.py

Lines changed: 88 additions & 9 deletions

@@ -15,6 +15,7 @@
 # limitations under the License.

 import logging
+import warnings
 from pathlib import Path
 from typing import Any, Dict, List

@@ -292,6 +293,7 @@ def __init__(
         overwrite_existing: bool = True,
         compression_level: int = 3,
         compression_method: str = "zstd",
+        chunk_size_mb: float = 1.0,
     ):
         """Initialize the Zarr data source.

@@ -301,18 +303,34 @@
             overwrite_existing: Whether to overwrite existing files
             compression_level: Compression level (1-9, higher = more compression)
             compression_method: Compression method
+            chunk_size_mb: Target chunk size in MB (default: 1.0)
         """
         super().__init__(cfg)
         self.output_dir = Path(output_dir)
         self.overwrite_existing = overwrite_existing
         self.compression_level = compression_level
         self.compression_method = compression_method
+        self.chunk_size_mb = chunk_size_mb

         # Set up compressor
         self.compressor = Blosc(
             cname=compression_method, clevel=compression_level, shuffle=Blosc.SHUFFLE
         )

+        # Warn if chunk size might be problematic
+        if chunk_size_mb < 0.1:
+            warnings.warn(
+                f"Chunk size of {chunk_size_mb}MB is very small. "
+                "This could lead to poor performance due to overhead.",
+                UserWarning,
+            )
+        elif chunk_size_mb > 100.0:
+            warnings.warn(
+                f"Chunk size of {chunk_size_mb}MB is very large. "
+                "This could lead to memory issues and poor random access performance.",
+                UserWarning,
+            )
+
         self.output_dir.mkdir(parents=True, exist_ok=True)

     def get_file_list(self) -> List[str]:
@@ -323,6 +341,61 @@ def read_file(self, filename: str) -> Dict[str, Any]:
         """Not implemented - this sink only writes."""
         raise NotImplementedError("CrashZarrDataSource only supports writing")

+    def _calculate_chunks(self, array: np.ndarray) -> tuple:
+        """Calculate optimal chunk sizes based on target chunk size in MB.
+
+        Args:
+            array: Array to calculate chunks for
+
+        Returns:
+            Tuple of chunk dimensions
+        """
+        target_chunk_size = int(self.chunk_size_mb * 1024 * 1024)  # Convert MB to bytes
+        item_size = array.itemsize
+        shape = array.shape
+
+        if len(shape) == 1:
+            # 1D array: chunk along the single dimension
+            chunk_size = min(shape[0], target_chunk_size // item_size)
+            return (max(1, chunk_size),)
+        elif len(shape) == 2:
+            # 2D array: try to keep rows together
+            chunk_rows = min(
+                shape[0], max(1, target_chunk_size // (item_size * shape[1]))
+            )
+            return (max(1, chunk_rows), shape[1])
+        elif len(shape) == 3:
+            # 3D array (e.g., mesh_pos with shape [T, N, 3]):
+            # balance between timesteps and nodes, keeping the
+            # last dimension (3 for coordinates) intact
+            elements_per_slice = shape[1] * shape[2]
+            chunk_timesteps = max(
+                1, min(shape[0], target_chunk_size // (item_size * elements_per_slice))
+            )
+
+            # If all timesteps fit within the target, chunk along nodes instead
+            if chunk_timesteps >= shape[0]:
+                # All timesteps fit, so split the node dimension
+                chunk_nodes = min(
+                    shape[1],
+                    max(1, target_chunk_size // (item_size * shape[0] * shape[2])),
+                )
+                return (shape[0], max(1, chunk_nodes), shape[2])
+            else:
+                # Chunk along timesteps, keep reasonable node chunks
+                remaining_size = target_chunk_size // (
+                    item_size * chunk_timesteps * shape[2]
+                )
+                chunk_nodes = min(shape[1], max(1, remaining_size))
+                return (chunk_timesteps, max(1, chunk_nodes), shape[2])
+        else:
+            # Higher-dimensional arrays: simple heuristic of chunking
+            # the first dimension and keeping the others intact
+            chunk_first = max(
+                1, min(shape[0], target_chunk_size // (item_size * np.prod(shape[1:])))
+            )
+            return (chunk_first,) + shape[1:]
+
     def _get_output_path(self, filename: str) -> Path:
         """Get the output path for the Zarr store.
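For a concrete feel of the 3D branch, here is a standalone sketch that mirrors the logic above outside the class (shapes and dtype are illustrative, not from the commit):

```python
import numpy as np

def chunks_3d(shape, itemsize, chunk_size_mb=1.0):
    """Mirror of the 3D branch of _calculate_chunks, for illustration."""
    target = int(chunk_size_mb * 1024 * 1024)
    t, n, c = shape
    # Timesteps that fit within the byte budget at full node width
    chunk_t = max(1, min(t, target // (itemsize * n * c)))
    if chunk_t >= t:
        # All timesteps fit; split the node dimension instead
        return (t, max(1, min(n, target // (itemsize * t * c))), c)
    return (chunk_t, min(n, max(1, target // (itemsize * chunk_t * c))), c)

itemsize = np.dtype(np.float32).itemsize  # 4 bytes
print(chunks_3d((100, 50_000, 3), itemsize, 1.0))   # -> (1, 50000, 3)
print(chunks_3d((100, 50_000, 3), itemsize, 16.0))  # -> (27, 50000, 3)
```

At the 1.0 MB default, a single timestep slice of 50,000 float32 nodes (~0.57 MB) already fills most of the budget, so reads aligned to whole timesteps touch exactly one chunk.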
@@ -366,36 +439,42 @@ def _write_impl_temp_file(
         root.attrs["num_edges"] = len(data.edges)
         root.attrs["compression"] = self.compression_method
         root.attrs["compression_level"] = self.compression_level
+        root.attrs["chunk_size_mb"] = self.chunk_size_mb

-        # Calculate optimal chunks for temporal data
+        # Convert data to appropriate dtypes
         num_timesteps, num_nodes, _ = data.filtered_pos_raw.shape
-        chunk_timesteps = min(10, num_timesteps)  # Chunk along time dimension
-        chunk_nodes = min(1000, num_nodes)  # Chunk along node dimension
+        mesh_pos_data = data.filtered_pos_raw.astype(np.float32)
+        thickness_data = data.filtered_node_thickness.astype(np.float32)
+        edges_array = np.array(list(data.edges), dtype=np.int64)
+
+        # Calculate optimal chunks for each array
+        mesh_pos_chunks = self._calculate_chunks(mesh_pos_data)
+        thickness_chunks = self._calculate_chunks(thickness_data)
+        edges_chunks = self._calculate_chunks(edges_array)

         # Write temporal position data
         root.create_dataset(
             "mesh_pos",
-            data=data.filtered_pos_raw.astype(np.float32),
-            chunks=(chunk_timesteps, chunk_nodes, 3),
+            data=mesh_pos_data,
+            chunks=mesh_pos_chunks,
             compressor=self.compressor,
             dtype=np.float32,
         )

         # Write node thickness (static per node)
         root.create_dataset(
             "thickness",
-            data=data.filtered_node_thickness.astype(np.float32),
-            chunks=(chunk_nodes,),
+            data=thickness_data,
+            chunks=thickness_chunks,
             compressor=self.compressor,
             dtype=np.float32,
         )

         # Write edges connectivity
-        edges_array = np.array(list(data.edges), dtype=np.int64)
         root.create_dataset(
             "edges",
             data=edges_array,
-            chunks=(min(10000, len(edges_array)), 2),
+            chunks=edges_chunks,
             compressor=self.compressor,
             dtype=np.int64,
         )
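One practical consequence of the new layout: with the default `chunk_size_mb`, `mesh_pos` is typically chunked per timestep, so time-sliced reads decompress exactly one chunk each. A minimal reading sketch under that assumption (hypothetical store path, `zarr` v2 API):

```python
import zarr

root = zarr.open("/data/crash_processed_zarr/example_run.zarr", mode="r")
mesh_pos = root["mesh_pos"]  # shape (T, N, 3), float32

# With chunks like (1, N, 3), each timestep read decompresses one chunk,
# so iterating over time never materializes the full array in memory.
for t in range(mesh_pos.shape[0]):
    frame = mesh_pos[t]  # ndarray of shape (N, 3)
```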
