@@ -119,6 +119,7 @@ def __init__(
119119 parametrization : torch .nn .Module | None = None ,
120120 cap_border_dict : CapBorderDict = None ,
121121 cap_outside_of_unitcube = False ,
122+ geometric_dim = 3 ,
122123 ):
123124 self ._deformation_spline = deformation_spline
124125
@@ -130,6 +131,7 @@ def __init__(
130131
131132 self ._cap_border_dict = cap_border_dict
132133 self .cap_outside_of_unitcube = cap_outside_of_unitcube
134+ self .geometric_dim = geometric_dim
133135
134136 @property
135137 def deformation_spline (self ):
@@ -186,7 +188,9 @@ def __call__(self, queries: torch.Tensor) -> torch.Tensor:
186188 # cap everything outside of the unit cube
187189 # k and d are y = k*(x-dx) + dy
188190 if self .cap_outside_of_unitcube :
189- sdf_values = _cap_outside_of_unitcube (queries , sdf_values )
191+ sdf_values = _cap_outside_of_unitcube (
192+ queries , sdf_values , max_dim = self .geometric_dim
193+ )
190194 return sdf_values
191195
192196 def _validate_input (self , queries : torch .Tensor ):
@@ -216,6 +220,46 @@ def plot_slice(self, *args, **kwargs):
216220 def __add__ (self , other ):
217221 return SummedSDF (self , other )
218222
223+ def to2D (self , axes : list [int ], offset = 0.0 ):
224+ """
225+ Converts SDF to 2D
226+
227+ :param axis: list of axes that will be used for the 2D
228+ """
229+ sdf2D = SDF2D (self , axes , offset = offset )
230+ sdf2D .deformation_spline = self .deformation_spline
231+ sdf2D .parametrization = self .parametrization
232+ return sdf2D
233+
234+
235+ class SDF2D (SDFBase ):
236+ def __init__ (self , obj : SDFBase , axes : list [int ], offset = 0.0 ):
237+ super ().__init__ ()
238+ self .obj = obj
239+ assert (
240+ len (axes ) == 2
241+ ), "List of axes must be of size 2 and needs to correspond to the 2D plane"
242+ self .axes = axes
243+ self .offset = offset
244+
245+ def _compute (self , queries ):
246+ queries_3D = (
247+ torch .zeros (
248+ (queries .shape [0 ], 3 ), dtype = queries .dtype , device = queries .device
249+ )
250+ + self .offset
251+ )
252+ queries_3D [:, self .axes [0 ]] = queries [:, 0 ]
253+ queries_3D [:, self .axes [1 ]] = queries [:, 1 ]
254+ result = self .obj ._compute (queries_3D )
255+ return result
256+
257+ def _get_domain_bounds (self ):
258+ return self .obj ._get_domain_bounds ()
259+
260+ def _set_param (self , parameter ):
261+ return self .obj ._set_param (parameter )
262+
219263
220264class SummedSDF (SDFBase ):
221265 def __init__ (self , obj1 : SDFBase , obj2 : SDFBase ):
@@ -267,7 +311,26 @@ def _compute(self, queries: torch.tensor) -> torch.tensor:
267311 return output .reshape (- 1 , 1 )
268312
269313
270- def union (D , k = 0 ):
314+ def union_torch (D , k = 0 ):
315+ """
316+ D: np.array of shape (num_points, num_geometries)
317+ k: smoothness parameter
318+ """
319+ if k == 0 :
320+ return torch .min (D , axis = 1 )[0 ]
321+ # Start with the first column as d1
322+ d1 = D [:, 0 ].copy ()
323+
324+ # Loop over remaining columns
325+ for i in range (1 , D .shape [1 ]):
326+ d2 = D [:, i ]
327+ h = torch .clip (0.5 + 0.5 * (d2 - d1 ) / k , 0 , 1 )
328+ d1 = d2 + (d1 - d2 ) * h - k * h * (1 - h )
329+
330+ return d1
331+
332+
333+ def union_numpy (D , k = 0 ):
271334 """
272335 D: np.array of shape (num_points, num_geometries)
273336 k: smoothness parameter
@@ -428,7 +491,9 @@ def _compute(self, queries: torch.Tensor | np.ndarray):
428491 sdf = point_segment_distance (lines [:, 0 ], lines [:, 1 ], queries_np ) - self .t / 2
429492 if is_tensor :
430493 sdf = torch .tensor (sdf , dtype = orig_dtype , device = orig_device )
431- return union (sdf , k = self .smoothness )
494+ return union_torch (sdf , k = self .smoothness )
495+ else :
496+ return union_numpy (sdf , k = self .smoothness )
432497
433498
434499class SDFfromDeepSDF (SDFBase ):
@@ -507,13 +572,12 @@ def _compute(self, queries: torch.Tensor) -> torch.Tensor:
507572 return sdf_values .to (orig_device ).reshape (- 1 , 1 )
508573
509574
510- def _cap_outside_of_unitcube (samples , sdf_values ):
511- dy = 0
512- for dim , k , dx in zip (
513- [0 , 0 , 1 , 1 , 2 , 2 ], [1 , - 1 , 1 , - 1 , 1 , - 1 ], [0 , 1 , 0 , 1 , 0 , 1 ]
514- ):
515- x = samples [:, dim ]
516- border_sdf = k * (x - dx ) + dy
575+ def _cap_outside_of_unitcube (samples , sdf_values , max_dim = 3 ):
576+
577+ for dim in range (max_dim ):
578+ border_sdf = samples [:, dim ]
579+ sdf_values = torch .maximum (sdf_values , - border_sdf .reshape (- 1 , 1 ))
580+ border_sdf = 1 - samples [:, dim ]
517581 sdf_values = torch .maximum (sdf_values , - border_sdf .reshape (- 1 , 1 ))
518582 return sdf_values
519583
0 commit comments