Skip to content

Commit 2975ce8

Browse files
Add support for str version of memory format param in FIGConvNet (#1200)
1 parent d1b6b7b commit 2975ce8

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

physicsnemo/models/figconvnet/figconvunet.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def __init__(
135135
aabb_min: Tuple[float, float, float] = (0.0, 0.0, 0.0),
136136
voxel_size: Optional[float] = None,
137137
resolution_memory_format_pairs: List[
138-
Tuple[GridFeaturesMemoryFormat, Tuple[int, int, int]]
138+
Tuple[GridFeaturesMemoryFormat | str, Tuple[int, int, int]]
139139
] = [
140140
(GridFeaturesMemoryFormat.b_xc_y_z, (2, 128, 128)),
141141
(GridFeaturesMemoryFormat.b_yc_x_z, (128, 2, 128)),
@@ -163,7 +163,10 @@ def __init__(
163163
self.point_feature_to_grids = nn.ModuleList()
164164
self.aabb_length = torch.tensor(aabb_max) - torch.tensor(aabb_min)
165165
self.min_voxel_edge_length = torch.tensor([np.inf, np.inf, np.inf])
166+
166167
for mem_fmt, res in resolution_memory_format_pairs:
168+
if isinstance(mem_fmt, str):
169+
mem_fmt = GridFeaturesMemoryFormat[mem_fmt]
167170
compressed_axis = memory_format_to_axis_index[mem_fmt]
168171
compressed_spatial_dims.append(res[compressed_axis])
169172
to_grid = nn.Sequential(

0 commit comments

Comments
 (0)