diff --git a/physicsnemo/models/figconvnet/figconvunet.py b/physicsnemo/models/figconvnet/figconvunet.py index a027aced26..78c87f5699 100644 --- a/physicsnemo/models/figconvnet/figconvunet.py +++ b/physicsnemo/models/figconvnet/figconvunet.py @@ -135,7 +135,7 @@ def __init__( aabb_min: Tuple[float, float, float] = (0.0, 0.0, 0.0), voxel_size: Optional[float] = None, resolution_memory_format_pairs: List[ - Tuple[GridFeaturesMemoryFormat, Tuple[int, int, int]] + Tuple[GridFeaturesMemoryFormat | str, Tuple[int, int, int]] ] = [ (GridFeaturesMemoryFormat.b_xc_y_z, (2, 128, 128)), (GridFeaturesMemoryFormat.b_yc_x_z, (128, 2, 128)), @@ -163,7 +163,10 @@ def __init__( self.point_feature_to_grids = nn.ModuleList() self.aabb_length = torch.tensor(aabb_max) - torch.tensor(aabb_min) self.min_voxel_edge_length = torch.tensor([np.inf, np.inf, np.inf]) + for mem_fmt, res in resolution_memory_format_pairs: + if isinstance(mem_fmt, str): + mem_fmt = GridFeaturesMemoryFormat[mem_fmt] compressed_axis = memory_format_to_axis_index[mem_fmt] compressed_spatial_dims.append(res[compressed_axis]) to_grid = nn.Sequential(