Skip to content

Commit 7b947d6

Browse files
committed
added possibility to generate training data from obj files
1 parent 0b95ffb commit 7b947d6

File tree

6 files changed

+692
-77
lines changed

6 files changed

+692
-77
lines changed

DeepSDFStruct/SDF.py

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -311,14 +311,7 @@ def __init__(
311311
if scale:
312312
# scales from [0,1] to [-1,1]
313313
# https://www.brainvoyager.com/bv/doc/UsersGuide/CoordsAndTransforms/SpatialTransformationMatrices.html
314-
rescale = 2.0
315-
tform = [-1.0 for i in range(3)]
316-
matrix = np.eye(4)
317-
matrix[:3, :3] *= rescale
318-
mesh.apply_transform(matrix)
319-
matrix = np.eye(4)
320-
matrix[:3, 3] = tform
321-
mesh.apply_transform(matrix)
314+
mesh = normalize_mesh_to_unit_cube(mesh)
322315
self.mesh = mesh
323316
self.dtype = dtype
324317
self.flip_sign = flip_sign
@@ -366,6 +359,41 @@ def _compute(self, queries: torch.Tensor | np.ndarray):
366359
return result
367360

368361

362+
def normalize_mesh_to_unit_cube(mesh: trimesh.Trimesh):
363+
"""
364+
Transform mesh coordinates uniformly to [-1, 1] in all axes.
365+
Keeps aspect ratio of original mesh.
366+
"""
367+
logger.debug(f"Scaling mesh from {mesh.bounds.flatten()}")
368+
# --- Compute bounding box ---
369+
bbox_min = mesh.bounds[0] # [x_min, y_min, z_min]
370+
bbox_max = mesh.bounds[1] # [x_max, y_max, z_max]
371+
372+
# Center of the mesh
373+
center = (bbox_max + bbox_min) / 2.0
374+
375+
# Largest extent
376+
scale = (
377+
np.max(bbox_max - bbox_min) / 2.0
378+
) # divide by 2 because [-1,1] spans 2 units
379+
380+
# --- Build transformation matrix ---
381+
matrix = np.eye(4)
382+
383+
# Translate to origin
384+
matrix[:3, 3] = -center
385+
386+
# Apply translation
387+
mesh.apply_transform(matrix)
388+
389+
# --- Apply uniform scaling ---
390+
scale_matrix = np.eye(4)
391+
scale_matrix[:3, :3] *= 1.0 / scale
392+
mesh.apply_transform(scale_matrix)
393+
logger.debug(f"to {mesh.bounds.flatten()}")
394+
return mesh
395+
396+
369397
class SDFfromLineMesh(SDFBase):
370398
line_mesh: gustaf.Edges
371399

@@ -438,20 +466,20 @@ def _compute(self, queries: torch.Tensor) -> torch.Tensor:
438466
queries = queries.to(self.model.device) * 2 - 1
439467
n_queries = queries.shape[0]
440468

441-
if self.latvec is None:
442-
latvec = self.parametrization(queries).to(self.model.device)
443-
else:
444-
latvec = self.latvec.to(self.model.device)
445-
446469
sdf_values = torch.zeros(n_queries, device=self.model.device)
447470

448471
head = 0
449472
while head < n_queries:
450473
end = min(head + self.max_batch, n_queries)
451474
query_batch = queries[head:end]
452475

476+
if self.latvec is None:
477+
latvec = self.parametrization(query_batch).to(self.model.device)
478+
else:
479+
latvec = self.latvec.to(self.model.device)[head:end]
480+
453481
sdf_values[head:end] = (
454-
self.model._decode_sdf(latvec[head:end], query_batch).squeeze(1)
482+
self.model._decode_sdf(latvec, query_batch).squeeze(1)
455483
# .detach()
456484
)
457485

DeepSDFStruct/sampling.py

Lines changed: 122 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import vtk
23
import numpy as np
34
import json
45
import pathlib
@@ -74,7 +75,6 @@ def process_single_geometry(args):
7475
outdir,
7576
dataset_name,
7677
unify_multipatches,
77-
compute_mechanical_properties,
7878
n_faces,
7979
n_samples,
8080
sampling_strategy,
@@ -101,25 +101,24 @@ def process_single_geometry(args):
101101
sdf, show=show, n_samples=n_samples, sampling_strategy=sampling_strategy
102102
)
103103

104-
if compute_mechanical_properties:
105-
mesh_file_name = f"{instance_id}.mesh"
106-
raise NotImplementedError("Compute homogenized material not available yet.")
107-
mesh_file_path = folder_name / "homogenization" / instance_id / mesh_file_name
108-
# E = computeHomogenizedMaterialProperties(
109-
# sdf, mesh_file_path=mesh_file_path, mirror=True
110-
# )
111-
np.savez(fname, neg=neg.stacked, pos=pos.stacked, E=E)
112-
else:
113-
np.savez(fname, neg=neg.stacked, pos=pos.stacked)
104+
np.savez(fname, neg=neg.stacked, pos=pos.stacked)
114105

115106

116107
class SDFSampler:
117-
def __init__(self, outdir, splitdir, dataset_name, unify_multipatches=True) -> None:
108+
def __init__(
109+
self,
110+
outdir,
111+
splitdir,
112+
dataset_name,
113+
unify_multipatches=True,
114+
stds=[0.05, 0.025],
115+
) -> None:
118116
self.outdir = outdir
119117
self.splitdir = splitdir
120118
self.dataset_name = dataset_name
121119
self.unify_multipatches = unify_multipatches
122120
self.geometries = {}
121+
self.stds = stds
123122

124123
def add_class(self, geom_list: list, class_name: str) -> None:
125124
instances = {}
@@ -147,8 +146,8 @@ def process_geometries(
147146
n_faces=100,
148147
n_samples: int = 1e5,
149148
unify_multipatches=True,
150-
compute_mechanical_properties=True,
151-
show=False,
149+
add_surface_samples=True,
150+
also_save_vtk=False,
152151
):
153152
for class_name, instance_list in self.geometries.items():
154153
logger.info(f"processing geometry list {class_name}")
@@ -162,68 +161,48 @@ def process_geometries(
162161
fname = folder_name / file_name
163162
if not os.path.exists(folder_name):
164163
os.makedirs(folder_name)
165-
if os.path.isfile(fname) and show == False:
164+
if os.path.isfile(fname):
166165
logger.warning(f"File {fname} already exists")
167166
continue
168167
sdf = self.get_sdf_from_geometry(
169168
geometry, n_faces, self.unify_multipatches
170169
)
171-
pos, neg = self.sample_sdf(
170+
sampled_sdf = random_sample_sdf(
172171
sdf,
173-
show=show,
174-
n_samples=n_samples,
175-
sampling_strategy=sampling_strategy,
172+
bounds=(-1, 1),
173+
n_samples=int(n_samples),
174+
type=sampling_strategy,
176175
)
177-
if compute_mechanical_properties:
178-
mesh_file_name = f"{instance_id}.mesh"
179-
mesh_file_path = (
180-
folder_name / "homogenization" / instance_id / mesh_file_name
181-
)
182-
raise NotImplementedError(
183-
"Compute homogenized material not available yet."
176+
if add_surface_samples:
177+
if not isinstance(geometry, trimesh.Trimesh):
178+
logger.warning(
179+
"Add surface samples was specified, but geometry"
180+
f"is not given as a trimesh.Trimesh but as {type(geometry)}"
181+
)
182+
surf_samples = sample_mesh_surface(
183+
sdf,
184+
sdf.mesh,
185+
int(n_samples // 2),
186+
self.stds,
187+
device="cpu",
188+
dtype=torch.float32,
184189
)
185-
E = computeHomogenizedMaterialProperties(
186-
sdf, mesh_file_path=mesh_file_path, mirror=True
187-
)
188-
np.savez(fname, neg=neg.stacked, pos=pos.stacked, E=E)
189-
else:
190-
np.savez(fname, neg=neg.stacked, pos=pos.stacked)
190+
sampled_sdf += surf_samples
191+
pos, neg = sampled_sdf.split_pos_neg()
191192

192-
def sample_sdf(
193-
self,
194-
sdf,
195-
show=False,
196-
n_samples: int = 1e5,
197-
sampling_strategy="uniform",
198-
box_size=None,
199-
stds=[0.0025, 0.00025],
200-
):
201-
202-
sampled_sdf = random_sample_sdf(
203-
sdf, bounds=(-1, 1), n_samples=int(n_samples), type=sampling_strategy
204-
)
205-
206-
pos, neg = sampled_sdf.split_pos_neg()
207-
208-
if show:
209-
vp_pos = pos.create_gus_plottable()
210-
vp_neg = neg.create_gus_plottable()
211-
vp_pos.show_options["cmap"] = "coolwarm"
212-
vp_neg.show_options["cmap"] = "coolwarm"
213-
vp_pos.show_options["vmin"] = -0.1
214-
vp_pos.show_options["vmax"] = 0.1
215-
vp_neg.show_options["vmin"] = -0.1
216-
vp_neg.show_options["vmax"] = 0.1
217-
gus.show(vp_neg, vp_pos)
218-
return pos, neg
193+
np.savez(fname, neg=neg.stacked, pos=pos.stacked)
194+
if also_save_vtk:
195+
save_points_to_vtp(
196+
fname.with_suffix(".vtp"), neg=neg.stacked, pos=pos.stacked
197+
)
219198

220199
def get_sdf_from_geometry(
221200
self,
222201
geometry,
223202
n_faces: int,
224203
unify_multipatches: bool = True,
225204
threshold: float = 1e-5,
226-
) -> SDFBase:
205+
) -> SDFfromMesh:
227206
if isinstance(geometry, splinepy.Multipatch):
228207
if unify_multipatches:
229208
patch_meshs = []
@@ -241,6 +220,8 @@ def get_sdf_from_geometry(
241220
sdf_geom = SDFfromMesh(
242221
geometry.extract.faces(n_faces), threshold=threshold
243222
)
223+
elif isinstance(geometry, trimesh.Trimesh):
224+
sdf_geom = SDFfromMesh(geometry, threshold=threshold)
244225

245226
else:
246227
raise NotImplementedError(
@@ -249,6 +230,44 @@ def get_sdf_from_geometry(
249230

250231
return sdf_geom
251232

233+
def get_meshs_from_folder(self, foldername, mesh_type) -> list:
234+
"""
235+
Reads all mesh files of a given type (extension) from a folder using meshio.
236+
237+
Parameters
238+
----------
239+
foldername : str
240+
Path to the folder containing the mesh files.
241+
mesh_type : str
242+
Mesh file extension (e.g., 'vtk', 'obj', 'stl', 'msh', 'xdmf').
243+
244+
Returns
245+
-------
246+
list[trimesh.Trimesh]
247+
A list of trimesh.Trimesh objects loaded from the folder.
248+
"""
249+
meshes = []
250+
251+
# Normalize extension (remove dot if present)
252+
mesh_type = mesh_type.lstrip(".")
253+
254+
# Iterate through all files in the folder
255+
for filename in os.listdir(foldername):
256+
if filename.lower().endswith("." + mesh_type.lower()):
257+
filepath = os.path.join(foldername, filename)
258+
try:
259+
faces = gus.io.meshio.load(filepath)
260+
trim = trimesh.Trimesh(faces.vertices, faces.elements)
261+
meshes.append(trim)
262+
logger.info(f"Loaded mesh: {filename}")
263+
except ValueError as e:
264+
logger.warning(f"Could not read {filename}: {e}")
265+
266+
if not meshes:
267+
print(f"No .{mesh_type} meshes found in {foldername}.")
268+
269+
return meshes
270+
252271
def write_json(self, json_fname):
253272
json_content = defaultdict(lambda: defaultdict(list))
254273
for class_name, instance_list in self.geometries.items():
@@ -364,3 +383,46 @@ def sample_mesh_surface(
364383
distances = sdf(queries)
365384

366385
return SampledSDF(samples=queries, distances=distances)
386+
387+
388+
def save_points_to_vtp(filename, neg, pos):
389+
"""
390+
Save pos/neg SDF sample points as a VTU point cloud using vtkPolyData.
391+
Each point has an SDF scalar value.
392+
"""
393+
# Combine points
394+
all_points = np.vstack((pos, neg))
395+
coords = all_points[:, :3]
396+
sdf_vals = all_points[:, 3]
397+
398+
# --- Create vtkPoints ---
399+
vtk_points = vtk.vtkPoints()
400+
for pt in coords:
401+
vtk_points.InsertNextPoint(pt)
402+
403+
# --- Create PolyData ---
404+
polydata = vtk.vtkPolyData()
405+
polydata.SetPoints(vtk_points)
406+
407+
# Add vertex cells (required for points in PolyData)
408+
verts = vtk.vtkCellArray()
409+
for i in range(len(coords)):
410+
verts.InsertNextCell(1)
411+
verts.InsertCellPoint(i)
412+
polydata.SetVerts(verts)
413+
414+
# --- Add SDF scalar values ---
415+
vtk_array = vtk.vtkDoubleArray()
416+
vtk_array.SetName("SDF")
417+
vtk_array.SetNumberOfValues(len(sdf_vals))
418+
for i, val in enumerate(sdf_vals):
419+
vtk_array.SetValue(i, val)
420+
polydata.GetPointData().SetScalars(vtk_array)
421+
422+
# --- Write to VTU ---
423+
writer = vtk.vtkXMLPolyDataWriter()
424+
writer.SetFileName(filename)
425+
writer.SetInputData(polydata)
426+
writer.Write()
427+
428+
logger.debug(f"Saved {len(coords)} points with SDF to '{filename}'")

0 commit comments

Comments
 (0)