@@ -135,7 +135,7 @@ def __init__(
135135 aabb_min : Tuple [float , float , float ] = (0.0 , 0.0 , 0.0 ),
136136 voxel_size : Optional [float ] = None ,
137137 resolution_memory_format_pairs : List [
138- Tuple [GridFeaturesMemoryFormat , Tuple [int , int , int ]]
138+ Tuple [GridFeaturesMemoryFormat | str , Tuple [int , int , int ]]
139139 ] = [
140140 (GridFeaturesMemoryFormat .b_xc_y_z , (2 , 128 , 128 )),
141141 (GridFeaturesMemoryFormat .b_yc_x_z , (128 , 2 , 128 )),
@@ -163,7 +163,10 @@ def __init__(
163163 self .point_feature_to_grids = nn .ModuleList ()
164164 self .aabb_length = torch .tensor (aabb_max ) - torch .tensor (aabb_min )
165165 self .min_voxel_edge_length = torch .tensor ([np .inf , np .inf , np .inf ])
166+
166167 for mem_fmt , res in resolution_memory_format_pairs :
168+ if isinstance (mem_fmt , str ):
169+ mem_fmt = GridFeaturesMemoryFormat [mem_fmt ]
167170 compressed_axis = memory_format_to_axis_index [mem_fmt ]
168171 compressed_spatial_dims .append (res [compressed_axis ])
169172 to_grid = nn .Sequential (
0 commit comments