Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion physicsnemo/models/figconvnet/figconvunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def __init__(
aabb_min: Tuple[float, float, float] = (0.0, 0.0, 0.0),
voxel_size: Optional[float] = None,
resolution_memory_format_pairs: List[
Tuple[GridFeaturesMemoryFormat, Tuple[int, int, int]]
Tuple[GridFeaturesMemoryFormat | str, Tuple[int, int, int]]
] = [
(GridFeaturesMemoryFormat.b_xc_y_z, (2, 128, 128)),
(GridFeaturesMemoryFormat.b_yc_x_z, (128, 2, 128)),
Expand Down Expand Up @@ -163,7 +163,10 @@ def __init__(
self.point_feature_to_grids = nn.ModuleList()
self.aabb_length = torch.tensor(aabb_max) - torch.tensor(aabb_min)
self.min_voxel_edge_length = torch.tensor([np.inf, np.inf, np.inf])

for mem_fmt, res in resolution_memory_format_pairs:
if isinstance(mem_fmt, str):
mem_fmt = GridFeaturesMemoryFormat[mem_fmt]
compressed_axis = memory_format_to_axis_index[mem_fmt]
compressed_spatial_dims.append(res[compressed_axis])
to_grid = nn.Sequential(
Expand Down