# limitations under the License.

import logging
+import warnings
from pathlib import Path
from typing import Any, Dict, List

@@ -292,6 +293,7 @@ def __init__(
        overwrite_existing: bool = True,
        compression_level: int = 3,
        compression_method: str = "zstd",
+        chunk_size_mb: float = 1.0,
    ):
        """Initialize the Zarr data source.

@@ -301,18 +303,34 @@ def __init__(
            overwrite_existing: Whether to overwrite existing files
            compression_level: Compression level (1-9, higher = more compression)
            compression_method: Blosc compressor name (e.g., "zstd")
+            chunk_size_mb: Target chunk size in MB (default: 1.0)
        """
        super().__init__(cfg)
        self.output_dir = Path(output_dir)
        self.overwrite_existing = overwrite_existing
        self.compression_level = compression_level
        self.compression_method = compression_method
+        self.chunk_size_mb = chunk_size_mb

        # Set up compressor
        self.compressor = Blosc(
            cname=compression_method, clevel=compression_level, shuffle=Blosc.SHUFFLE
        )

+        # Warn if chunk size might be problematic
+        if chunk_size_mb < 0.1:
+            warnings.warn(
+                f"Chunk size of {chunk_size_mb} MB is very small. "
+                "This could lead to poor performance due to overhead.",
+                UserWarning,
+            )
+        elif chunk_size_mb > 100.0:
+            warnings.warn(
+                f"Chunk size of {chunk_size_mb} MB is very large. "
+                "This could lead to memory issues and poor random access performance.",
+                UserWarning,
+            )
+
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def get_file_list(self) -> List[str]:
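
A hedged sketch of how the new bounds behave, assuming the class can be constructed directly as the signature above suggests (the cfg object and output directory below are illustrative stand-ins, not defined in this PR):

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # 0.05 MB is below the 0.1 MB threshold, so a UserWarning is expected
    CrashZarrDataSource(cfg, output_dir="/tmp/zarr_out", chunk_size_mb=0.05)

assert any("very small" in str(w.message) for w in caught)
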
@@ -323,6 +341,61 @@ def read_file(self, filename: str) -> Dict[str, Any]:
323341 """Not implemented - this sink only writes."""
324342 raise NotImplementedError ("CrashZarrDataSource only supports writing" )
325343
+    def _calculate_chunks(self, array: np.ndarray) -> tuple:
+        """Calculate optimal chunk sizes based on target chunk size in MB.
+
+        Args:
+            array: Array to calculate chunks for
+
+        Returns:
+            Tuple of chunk dimensions
+        """
+        target_chunk_size = int(self.chunk_size_mb * 1024 * 1024)  # Convert MB to bytes
+        item_size = array.itemsize
+        shape = array.shape
+
+        if len(shape) == 1:
+            # 1D array: chunk along the single dimension
+            chunk_size = min(shape[0], target_chunk_size // item_size)
+            return (max(1, chunk_size),)
+        elif len(shape) == 2:
+            # 2D array: try to keep rows together
+            chunk_rows = min(
+                shape[0], max(1, target_chunk_size // (item_size * shape[1]))
+            )
+            return (max(1, chunk_rows), shape[1])
+        elif len(shape) == 3:
+            # 3D array (e.g., mesh_pos with shape [T, N, 3]):
+            # try to balance between timesteps and nodes, keeping
+            # the last dimension (3 for coordinates) intact
+            elements_per_slice = shape[1] * shape[2]
+            chunk_timesteps = max(
+                1, min(shape[0], target_chunk_size // (item_size * elements_per_slice))
+            )
+
+            # If all timesteps fit in a single chunk, chunk along nodes instead
+            if chunk_timesteps >= shape[0]:
+                # All timesteps fit; chunk along the node dimension
+                chunk_nodes = min(
+                    shape[1],
+                    max(1, target_chunk_size // (item_size * shape[0] * shape[2])),
+                )
+                return (shape[0], max(1, chunk_nodes), shape[2])
+            else:
+                # Chunk along timesteps, keep reasonable node chunks
+                remaining_size = target_chunk_size // (
+                    item_size * chunk_timesteps * shape[2]
+                )
+                chunk_nodes = min(shape[1], max(1, remaining_size))
+                return (chunk_timesteps, max(1, chunk_nodes), shape[2])
+        else:
+            # For higher-dimensional arrays, use a simple heuristic:
+            # chunk the first dimension, keep the others intact
+            chunk_first = max(
+                1, min(shape[0], target_chunk_size // (item_size * np.prod(shape[1:])))
+            )
+            return (chunk_first,) + shape[1:]
+
    def _get_output_path(self, filename: str) -> Path:
        """Get the output path for the Zarr store.

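
For intuition, a worked sketch of what the heuristic returns at the default 1 MB target. The shapes are illustrative, and source stands for a hypothetical instance with chunk_size_mb=1.0; the arithmetic follows the code above (float32 itemsize 4, int64 itemsize 8):

import numpy as np

mesh_pos = np.zeros((100, 5000, 3), dtype=np.float32)
source._calculate_chunks(mesh_pos)
# -> (17, 5000, 3): 1048576 // (4 * 5000 * 3) = 17 timesteps per chunk (~0.97 MB)

thickness = np.zeros((5000,), dtype=np.float32)
source._calculate_chunks(thickness)
# -> (5000,): the whole array (~20 KB) is far below 1 MB, so one chunk

edges = np.zeros((20000, 2), dtype=np.int64)
source._calculate_chunks(edges)
# -> (20000, 2): rows stay together; 1048576 // (8 * 2) = 65536 rows would fit
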
@@ -366,36 +439,42 @@ def _write_impl_temp_file(
        root.attrs["num_edges"] = len(data.edges)
        root.attrs["compression"] = self.compression_method
        root.attrs["compression_level"] = self.compression_level
+        root.attrs["chunk_size_mb"] = self.chunk_size_mb

-        # Calculate optimal chunks for temporal data
+        # Convert data to appropriate dtypes
        num_timesteps, num_nodes, _ = data.filtered_pos_raw.shape
-        chunk_timesteps = min(10, num_timesteps)  # Chunk along time dimension
-        chunk_nodes = min(1000, num_nodes)  # Chunk along node dimension
+        mesh_pos_data = data.filtered_pos_raw.astype(np.float32)
+        thickness_data = data.filtered_node_thickness.astype(np.float32)
+        edges_array = np.array(list(data.edges), dtype=np.int64)
+
+        # Calculate optimal chunks for each array
+        mesh_pos_chunks = self._calculate_chunks(mesh_pos_data)
+        thickness_chunks = self._calculate_chunks(thickness_data)
+        edges_chunks = self._calculate_chunks(edges_array)

        # Write temporal position data
        root.create_dataset(
            "mesh_pos",
-            data=data.filtered_pos_raw.astype(np.float32),
-            chunks=(chunk_timesteps, chunk_nodes, 3),
+            data=mesh_pos_data,
+            chunks=mesh_pos_chunks,
            compressor=self.compressor,
            dtype=np.float32,
        )

        # Write node thickness (static per node)
        root.create_dataset(
            "thickness",
-            data=data.filtered_node_thickness.astype(np.float32),
-            chunks=(chunk_nodes,),
+            data=thickness_data,
+            chunks=thickness_chunks,
            compressor=self.compressor,
            dtype=np.float32,
        )

        # Write edges connectivity
-        edges_array = np.array(list(data.edges), dtype=np.int64)
        root.create_dataset(
            "edges",
            data=edges_array,
-            chunks=(min(10000, len(edges_array)), 2),
+            chunks=edges_chunks,
            compressor=self.compressor,
            dtype=np.int64,
        )
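
After a write, the chosen layout can be verified by reopening the store. A sketch using the zarr-python 2.x API; the store path is illustrative, since the actual naming lives in _get_output_path:

import zarr

root = zarr.open("/tmp/zarr_out/sample.zarr", mode="r")  # illustrative path
print(root["mesh_pos"].chunks)      # layout picked by _calculate_chunks
print(root.attrs["chunk_size_mb"])  # target recorded in the store attrs, e.g. 1.0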