From cd6ed53c94ddf614a1b24ab4d8064c3a2d90121a Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 12 Sep 2025 16:39:40 +0200 Subject: [PATCH 01/39] Add `CollidableShapeType` enum --- src/jaxsim/utils/__init__.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/jaxsim/utils/__init__.py b/src/jaxsim/utils/__init__.py index d0b881ceb..b0287d409 100644 --- a/src/jaxsim/utils/__init__.py +++ b/src/jaxsim/utils/__init__.py @@ -3,3 +3,17 @@ from .jaxsim_dataclass import JaxsimDataclass from .tracing import not_tracing, tracing from .wrappers import HashedNumpyArray, HashlessObject + +from typing import ClassVar + + +# TODO (flferretti): Definetely not the best place for this +class CollidableShapeType: + """ + Enum representing the types of collidable shapes. + """ + + Sphere: ClassVar[int] = 0 + Box: ClassVar[int] = 1 + Cylinder: ClassVar[int] = 2 + Unsupported: ClassVar[int] = -1 From d8bae760c2330da7822545b940f6cdca52ab8f14 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 12 Sep 2025 16:40:46 +0200 Subject: [PATCH 02/39] Transition to `CollidableShape` and define collidables classes --- src/jaxsim/parsers/descriptions/collision.py | 169 +++++-------------- 1 file changed, 42 insertions(+), 127 deletions(-) diff --git a/src/jaxsim/parsers/descriptions/collision.py b/src/jaxsim/parsers/descriptions/collision.py index 719c92d2b..369285fb4 100644 --- a/src/jaxsim/parsers/descriptions/collision.py +++ b/src/jaxsim/parsers/descriptions/collision.py @@ -1,178 +1,93 @@ from __future__ import annotations -import abc import dataclasses - -import jax.numpy as jnp -import numpy as np -import numpy.typing as npt +from abc import ABC import jaxsim.typing as jtp -from jaxsim import logging - -from .link import LinkDescription @dataclasses.dataclass -class CollidablePoint: +class CollisionShape(ABC): """ - Represents a collidable point associated with a parent link. + Base class for collision shapes. - Attributes: - parent_link: The parent link to which the collidable point is attached. - position: The position of the collidable point relative to the parent link. - enabled: A flag indicating whether the collidable point is enabled for collision detection. + This class serves as a base for specific collision shapes like BoxCollision and SphereCollision. + It is not intended to be instantiated directly. """ - parent_link: LinkDescription - position: npt.NDArray = dataclasses.field(default_factory=lambda: np.zeros(3)) - enabled: bool = True - - def change_link( - self, new_link: LinkDescription, new_H_old: npt.NDArray - ) -> CollidablePoint: - """ - Move the collidable point to a new parent link. - - Args: - new_link (LinkDescription): The new parent link to which the collidable point is moved. - new_H_old (npt.NDArray): The transformation matrix from the new link's frame to the old link's frame. - - Returns: - CollidablePoint: A new collidable point associated with the new parent link. - """ - - msg = f"Moving collidable point: {self.parent_link.name} -> {new_link.name}" - logging.debug(msg=msg) - - return CollidablePoint( - parent_link=new_link, - position=(new_H_old @ jnp.hstack([self.position, 1.0])).squeeze()[0:3], - enabled=self.enabled, - ) + center: jtp.VectorLike + size: jtp.VectorLike + parent_link: str def __hash__(self) -> int: - return hash( ( + hash(tuple(self.center.tolist())), + hash(tuple(self.size.tolist())), hash(self.parent_link), - hash(tuple(self.position.tolist())), - hash(self.enabled), ) ) - def __eq__(self, other: CollidablePoint) -> bool: + def __eq__(self, other: CollisionShape) -> bool: - if not isinstance(other, CollidablePoint): + if not isinstance(other, CollisionShape): return False return hash(self) == hash(other) - def __str__(self) -> str: - return ( - f"{self.__class__.__name__}(" - + f"parent_link={self.parent_link.name}" - + f", position={self.position}" - + f", enabled={self.enabled}" - + ")" - ) - - -@dataclasses.dataclass -class CollisionShape(abc.ABC): - """ - Abstract base class for representing collision shapes. - - Attributes: - collidable_points: A list of collidable points associated with the collision shape. - """ - - collidable_points: tuple[CollidablePoint] - - def __str__(self): - return ( - f"{self.__class__.__name__}(" - + "collidable_points=[\n " - + ",\n ".join(str(cp) for cp in self.collidable_points) - + "\n])" - ) - @dataclasses.dataclass class BoxCollision(CollisionShape): """ Represents a box-shaped collision shape. - - Attributes: - center: The center of the box in the local frame of the collision shape. """ - center: jtp.VectorLike + @property + def x(self) -> float: + return self.size[0] - def __hash__(self) -> int: - return hash( - ( - hash(super()), - hash(tuple(self.center.tolist())), - ) - ) + @property + def y(self) -> float: + return self.size[1] - def __eq__(self, other: BoxCollision) -> bool: + @property + def z(self) -> float: + return self.size[2] - if not isinstance(other, BoxCollision): - return False + @x.setter + def x(self, value: float) -> None: + self.size[0] = value - return hash(self) == hash(other) + @y.setter + def y(self, value: float) -> None: + self.size[1] = value + + @z.setter + def z(self, value: float) -> None: + self.size[2] = value @dataclasses.dataclass class SphereCollision(CollisionShape): """ Represents a spherical collision shape. - - Attributes: - center: The center of the sphere in the local frame of the collision shape. """ - center: jtp.VectorLike - - def __hash__(self) -> int: - return hash( - ( - hash(super()), - hash(tuple(self.center.tolist())), - ) - ) - - def __eq__(self, other: BoxCollision) -> bool: - - if not isinstance(other, BoxCollision): - return False - - return hash(self) == hash(other) + @property + def radius(self) -> float: + return self.size[0] @dataclasses.dataclass -class MeshCollision(CollisionShape): +class CylinderCollision(CollisionShape): """ - Represents a mesh-shaped collision shape. - - Attributes: - center: The center of the mesh in the local frame of the collision shape. + Represents a cylindrical collision shape. """ - center: jtp.VectorLike - - def __hash__(self) -> int: - return hash( - ( - hash(tuple(self.center.tolist())), - hash(self.collidable_points), - ) - ) - - def __eq__(self, other: MeshCollision) -> bool: - if not isinstance(other, MeshCollision): - return False + @property + def radius(self) -> float: + return self.size[0] - return hash(self) == hash(other) + @property + def height(self) -> float: + return self.size[1] From 6fe426a1504f938e4450990690daa765b926ba4c Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 12 Sep 2025 16:41:32 +0200 Subject: [PATCH 03/39] Update collidable shape builders --- src/jaxsim/parsers/descriptions/__init__.py | 4 +- src/jaxsim/parsers/descriptions/model.py | 128 +++++------------- src/jaxsim/parsers/rod/parser.py | 25 ++-- src/jaxsim/parsers/rod/utils.py | 137 +++----------------- 4 files changed, 60 insertions(+), 234 deletions(-) diff --git a/src/jaxsim/parsers/descriptions/__init__.py b/src/jaxsim/parsers/descriptions/__init__.py index ff3bf631d..0be19de64 100644 --- a/src/jaxsim/parsers/descriptions/__init__.py +++ b/src/jaxsim/parsers/descriptions/__init__.py @@ -1,9 +1,7 @@ from .collision import ( BoxCollision, - CollidablePoint, - CollisionShape, - MeshCollision, SphereCollision, + CylinderCollision, ) from .joint import JointDescription, JointGenericAxis, JointType from .link import LinkDescription diff --git a/src/jaxsim/parsers/descriptions/model.py b/src/jaxsim/parsers/descriptions/model.py index 714df5791..08779c5bd 100644 --- a/src/jaxsim/parsers/descriptions/model.py +++ b/src/jaxsim/parsers/descriptions/model.py @@ -8,7 +8,6 @@ from jaxsim.logging import jaxsim_warn from ..kinematic_graph import KinematicGraph, KinematicGraphTransforms, RootPose -from .collision import CollidablePoint, CollisionShape from .joint import JointDescription from .link import LinkDescription @@ -28,7 +27,7 @@ class ModelDescription(KinematicGraph): fixed_base: bool = True - collision_shapes: tuple[CollisionShape, ...] = dataclasses.field( + collision_shapes: tuple = dataclasses.field( default_factory=list, repr=False ) @@ -38,7 +37,7 @@ def build_model_from( links: list[LinkDescription], joints: list[JointDescription], frames: list[LinkDescription] | None = None, - collisions: tuple[CollisionShape, ...] = (), + collisions: tuple = (), fixed_base: bool = False, base_link_name: str | None = None, considered_joints: Sequence[str] | None = None, @@ -81,61 +80,53 @@ def build_model_from( fk = KinematicGraphTransforms(graph=kinematic_graph) # Container of the final model's collision shapes. - final_collisions: list[CollisionShape] = [] + final_collisions: list = [] # Move and express the collision shapes of removed links to the resulting # lumped link that replace the combination of the removed link and its parent. for collision_shape in collisions: - # Get all the collidable points of the shape - coll_points = tuple(collision_shape.collidable_points) - - # Assume they have an unique parent link - if not len(set({cp.parent_link.name for cp in coll_points})) == 1: - msg = "Collision shape not currently supported (multiple parent links)" - raise RuntimeError(msg) - # Get the parent link of the collision shape. # Note that this link could have been lumped and we need to find the # link in which it was lumped into. - parent_link_of_shape = collision_shape.collidable_points[0].parent_link + parent_link_of_shape = collision_shape.parent_link # If it is part of the (reduced) graph, add it as it is... - if parent_link_of_shape.name in kinematic_graph.link_names(): + if parent_link_of_shape in kinematic_graph.link_names(): final_collisions.append(collision_shape) continue # ... otherwise look for the frame - if parent_link_of_shape.name not in kinematic_graph.frame_names(): + if parent_link_of_shape not in kinematic_graph.frame_names(): msg = "Parent frame '{}' of collision shape not found, ignoring shape" - logging.info(msg.format(parent_link_of_shape.name)) + logging.info(msg.format(parent_link_of_shape)) continue # Create a new collision shape - new_collision_shape = CollisionShape(collidable_points=()) - final_collisions.append(new_collision_shape) - - # If the frame was found, update the collidable points' pose and add them - # to the new collision shape. - for cp in collision_shape.collidable_points: - # Find the link that is part of the (reduced) model in which the - # collision shape's parent was lumped into - real_parent_link_name = kinematic_graph.frames_dict[ - parent_link_of_shape.name - ].parent_name - - # Change the link associated to the collidable point, updating their - # relative pose - moved_cp = cp.change_link( - new_link=kinematic_graph.links_dict[real_parent_link_name], - new_H_old=fk.relative_transform( - relative_to=real_parent_link_name, - name=cp.parent_link.name, - ), - ) - - # Store the updated collision. - new_collision_shape.collidable_points += (moved_cp,) + # new_collision_shape = CollisionShape(collidable_points=()) + # final_collisions.append(new_collision_shape) + + # # If the frame was found, update the collidable points' pose and add them + # # to the new collision shape. + # for cp in collision_shape.collidable_points: + # # Find the link that is part of the (reduced) model in which the + # # collision shape's parent was lumped into + # real_parent_link_name = kinematic_graph.frames_dict[ + # parent_link_of_shape.name + # ].parent_name + + # # Change the link associated to the collidable point, updating their + # # relative pose + # moved_cp = cp.change_link( + # new_link=kinematic_graph.links_dict[real_parent_link_name], + # new_H_old=fk.relative_transform( + # relative_to=real_parent_link_name, + # name=cp.parent_link.name, + # ), + # ) + + # # Store the updated collision. + # new_collision_shape.collidable_points += (moved_cp,) # Build the model model = ModelDescription( @@ -194,63 +185,6 @@ def reduce(self, considered_joints: Sequence[str]) -> ModelDescription: return reduced_model_description - def update_collision_shape_of_link(self, link_name: str, enabled: bool) -> None: - """ - Enable or disable collision shapes associated with a link. - - Args: - link_name: The name of the link. - enabled: Enable or disable collision shapes associated with the link. - """ - - if link_name not in self.link_names(): - raise ValueError(link_name) - - for point in self.collision_shape_of_link( - link_name=link_name - ).collidable_points: - point.enabled = enabled - - def collision_shape_of_link(self, link_name: str) -> CollisionShape: - """ - Get the collision shape associated with a specific link. - - Args: - link_name: The name of the link. - - Returns: - The collision shape associated with the link. - """ - - if link_name not in self.link_names(): - raise ValueError(link_name) - - return CollisionShape( - collidable_points=[ - point - for shape in self.collision_shapes - for point in shape.collidable_points - if point.parent_link.name == link_name - ] - ) - - def all_enabled_collidable_points(self) -> list[CollidablePoint]: - """ - Get all enabled collidable points in the model. - - Returns: - The list of all enabled collidable points. - - """ - - # Get iterator of all collidable points - all_collidable_points = itertools.chain.from_iterable( - [shape.collidable_points for shape in self.collision_shapes] - ) - - # Return enabled collidable points - return [cp for cp in all_collidable_points if cp.enabled] - def __eq__(self, other: ModelDescription) -> bool: if not isinstance(other, ModelDescription): diff --git a/src/jaxsim/parsers/rod/parser.py b/src/jaxsim/parsers/rod/parser.py index fcb0ad178..a4bfd45ed 100644 --- a/src/jaxsim/parsers/rod/parser.py +++ b/src/jaxsim/parsers/rod/parser.py @@ -27,7 +27,7 @@ class SDFData(NamedTuple): link_descriptions: list[descriptions.LinkDescription] joint_descriptions: list[descriptions.JointDescription] frame_descriptions: list[descriptions.LinkDescription] - collision_shapes: list[descriptions.CollisionShape] + collision_shapes: list sdf_model: rod.Model | None = None model_pose: kinematic_graph.RootPose = kinematic_graph.RootPose() @@ -308,7 +308,7 @@ def extract_model_data( # ================ # Initialize the collision shapes - collisions: list[descriptions.CollisionShape] = [] + collisions = [] # Parse the collisions for link in sdf_model.links(): @@ -331,22 +331,13 @@ def extract_model_data( collisions.append(sphere_collision) continue - if collision.geometry.mesh is not None: - if int(os.environ.get("JAXSIM_COLLISION_MESH_ENABLED", "0")): - logging.warning("Mesh collision support is still experimental.") - mesh_collision = utils.create_mesh_collision( - collision=collision, - link_description=links_dict[link.name], - method=utils.meshes.extract_points_vertices, - ) - - collisions.append(mesh_collision) - - else: - logging.warning( - f"Skipping collision shape 'mesh' in link '{link.name}' because mesh collisions are disabled." - ) + if collision.geometry.cylinder is not None: + cylinder_collision = utils.create_cylinder_collision( + collision=collision, + link_description=links_dict[link.name], + ) + collisions.append(cylinder_collision) continue # Check any remaining non-None geometry types. diff --git a/src/jaxsim/parsers/rod/utils.py b/src/jaxsim/parsers/rod/utils.py index a295b7fab..85a3d0f20 100644 --- a/src/jaxsim/parsers/rod/utils.py +++ b/src/jaxsim/parsers/rod/utils.py @@ -117,41 +117,14 @@ def create_box_collision( center = np.array([x / 2, y / 2, z / 2]) - # Define the bottom corners. - bottom_corners = np.array([[0, 0, 0], [x, 0, 0], [x, y, 0], [0, y, 0]]) - - # Conditionally add the top corners based on the environment variable. - top_corners = ( - np.array([[0, 0, z], [x, 0, z], [x, y, z], [0, y, z]]) - if os.environ.get("JAXSIM_COLLISION_USE_BOTTOM_ONLY", "0").lower() - in { - "false", - "0", - } - else [] - ) - - # Combine and shift by the center - box_corners = np.vstack([bottom_corners, *top_corners]) - center - H = collision.pose.transform() if collision.pose is not None else np.eye(4) center_wrt_link = (H @ np.hstack([center, 1.0]))[0:-1] - box_corners_wrt_link = ( - H @ np.hstack([box_corners, np.vstack([1.0] * box_corners.shape[0])]).T - )[0:3, :] - - collidable_points = [ - descriptions.CollidablePoint( - parent_link=link_description, - position=np.array(corner), - enabled=True, - ) - for corner in box_corners_wrt_link.T - ] return descriptions.BoxCollision( - collidable_points=collidable_points, center=center_wrt_link + size=np.array([x, y, z]), + center=center_wrt_link, + parent_link=link_description.name, ) @@ -169,112 +142,42 @@ def create_sphere_collision( The sphere collision description. """ - # From https://stackoverflow.com/a/26127012 - def fibonacci_sphere(samples: int) -> npt.NDArray: - # Get the golden ratio in radians. - phi = np.pi * (3.0 - np.sqrt(5.0)) - - # Generate the points. - points = [ - np.array( - [ - np.cos(phi * i) - * np.sqrt(1 - (y := 1 - 2 * i / (samples - 1)) ** 2), - y, - np.sin(phi * i) * np.sqrt(1 - y**2), - ] - ) - for i in range(samples) - ] - - # Filter to keep only the bottom half if required. - if os.environ.get("JAXSIM_COLLISION_USE_BOTTOM_ONLY", "0").lower() in { - "true", - "1", - }: - # Keep only the points with z <= 0. - points = [point for point in points if point[2] <= 0] - - return np.vstack(points) - r = collision.geometry.sphere.radius - sphere_points = r * fibonacci_sphere( - samples=int(os.getenv(key="JAXSIM_COLLISION_SPHERE_POINTS", default="50")) - ) - H = collision.pose.transform() if collision.pose is not None else np.eye(4) center_wrt_link = (H @ np.hstack([0, 0, 0, 1.0]))[0:-1] - sphere_points_wrt_link = ( - H @ np.hstack([sphere_points, np.vstack([1.0] * sphere_points.shape[0])]).T - )[0:3, :] - - collidable_points = [ - descriptions.CollidablePoint( - parent_link=link_description, - position=np.array(point), - enabled=True, - ) - for point in sphere_points_wrt_link.T - ] - return descriptions.SphereCollision( - collidable_points=collidable_points, center=center_wrt_link + size=np.array([r] * 3), + center=center_wrt_link, + parent_link=link_description.name, ) -def create_mesh_collision( - collision: rod.Collision, - link_description: descriptions.LinkDescription, - method: MeshMappingMethod = None, -) -> descriptions.MeshCollision: +def create_cylinder_collision( + collision: rod.Collision, link_description: descriptions.LinkDescription +) -> descriptions.CylinderCollision: """ - Create a mesh collision from an SDF collision element. + Create a cylinder collision from an SDF collision element. Args: collision: The SDF collision element. link_description: The link description. - method: The method to use for mesh wrapping. Returns: - The mesh collision description. + The cylinder collision description. """ - file = pathlib.Path(resolve_local_uri(uri=collision.geometry.mesh.uri)) - file_type = file.suffix.replace(".", "") - mesh = trimesh.load_mesh(file, file_type=file_type) + r = collision.geometry.cylinder.radius + l = collision.geometry.cylinder.length - if mesh.is_empty: - raise RuntimeError(f"Failed to process '{file}' with trimesh") + H = collision.pose.transform() if collision.pose is not None else np.eye(4) - mesh.apply_scale(collision.geometry.mesh.scale) - logging.info( - msg=f"Loading mesh {collision.geometry.mesh.uri} with scale {collision.geometry.mesh.scale}, file type '{file_type}'" - ) + center_wrt_link = (H @ np.hstack([0, 0, 0, 1.0]))[0:-1] - if method is None: - method = meshes.VertexExtraction() - logging.debug("Using default Vertex Extraction method for mesh wrapping") - else: - logging.debug(f"Using method {method} for mesh wrapping") - - points = method(mesh=mesh) - logging.debug(f"Extracted {len(points)} points from mesh") - - W_H_L = collision.pose.transform() if collision.pose is not None else np.eye(4) - - # Extract translation from transformation matrix - W_p_L = W_H_L[:3, 3] - mesh_points_wrt_link = points @ W_H_L[:3, :3].T + W_p_L - collidable_points = [ - descriptions.CollidablePoint( - parent_link=link_description, - position=point, - enabled=True, - ) - for point in mesh_points_wrt_link - ] - - return descriptions.MeshCollision(collidable_points=collidable_points, center=W_p_L) + return descriptions.CylinderCollision( + size=np.array([r, l, 0]), + center=center_wrt_link, + parent_link=link_description.name, + ) From 740b8cf514b449c8d502128bec0a84222f628d3e Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 12 Sep 2025 16:46:28 +0200 Subject: [PATCH 04/39] Transition to collidable shapes naming --- src/jaxsim/api/contact.py | 78 +++++++++++------------ src/jaxsim/api/data.py | 2 +- src/jaxsim/rbda/__init__.py | 2 +- src/jaxsim/rbda/collidable_points.py | 65 ------------------- src/jaxsim/rbda/collidable_shapes.py | 58 +++++++++++++++++ src/jaxsim/rbda/contacts/__init__.py | 2 +- src/jaxsim/rbda/contacts/common.py | 8 +-- src/jaxsim/rbda/contacts/relaxed_rigid.py | 6 +- tests/test_api_contact.py | 59 ++++++++--------- tests/test_simulations.py | 12 +--- 10 files changed, 131 insertions(+), 161 deletions(-) delete mode 100644 src/jaxsim/rbda/collidable_points.py create mode 100644 src/jaxsim/rbda/collidable_shapes.py diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index b56c75fab..95ebee444 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -17,7 +17,7 @@ @jax.jit @js.common.named_scope -def collidable_point_kinematics( +def collidable_shape_kinematics( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> tuple[jtp.Matrix, jtp.Matrix]: """ @@ -36,7 +36,7 @@ def collidable_point_kinematics( the linear component of the mixed 6D frame velocity. """ - W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel( + W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_shapes.collidable_shapes_pos_vel( model=model, link_transforms=data._link_transforms, link_velocities=data._link_velocities, @@ -47,7 +47,7 @@ def collidable_point_kinematics( @jax.jit @js.common.named_scope -def collidable_point_positions( +def collidable_shape_positions( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Matrix: """ @@ -61,14 +61,14 @@ def collidable_point_positions( The position of the collidable points in the world frame. """ - W_p_Ci, _ = collidable_point_kinematics(model=model, data=data) + W_p_Ci, _ = collidable_shape_kinematics(model=model, data=data) return W_p_Ci @jax.jit @js.common.named_scope -def collidable_point_velocities( +def collidable_shape_velocities( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Matrix: """ @@ -82,7 +82,7 @@ def collidable_point_velocities( The 3D velocity of the collidable points. """ - _, W_ṗ_Ci = collidable_point_kinematics(model=model, data=data) + _, W_ṗ_Ci = collidable_shape_kinematics(model=model, data=data) return W_ṗ_Ci @@ -112,15 +112,15 @@ def in_contact( raise ValueError("One or more link names are not part of the model") # Get the indices of the enabled collidable points. - indices_of_enabled_collidable_points = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + indices_of_enabled_collidable_shapes = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_shapes ) - parent_link_idx_of_enabled_collidable_points = jnp.array( + parent_link_idx_of_enabled_collidable_shapes = jnp.array( model.kin_dyn_parameters.contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_points] + )[indices_of_enabled_collidable_shapes] - W_p_Ci = collidable_point_positions(model=model, data=data) + W_p_Ci = collidable_shape_positions(model=model, data=data) terrain_height = jax.vmap(lambda x, y: model.terrain.height(x=x, y=y))( W_p_Ci[:, 0], W_p_Ci[:, 1] @@ -136,7 +136,7 @@ def in_contact( links_in_contact = jax.vmap( lambda link_index: jnp.where( - parent_link_idx_of_enabled_collidable_points == link_index, + parent_link_idx_of_enabled_collidable_shapes == link_index, below_terrain, jnp.zeros_like(below_terrain, dtype=bool), ).any() @@ -162,7 +162,7 @@ def estimate_good_contact_parameters( *, standard_gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY, static_friction_coefficient: jtp.FloatLike = 0.5, - number_of_active_collidable_points_steady_state: jtp.IntLike = 1, + number_of_active_collidable_shapes_steady_state: jtp.IntLike = 1, damping_ratio: jtp.FloatLike = 1.0, max_penetration: jtp.FloatLike | None = None, ) -> jaxsim.rbda.contacts.ContactParamsTypes: @@ -173,7 +173,7 @@ def estimate_good_contact_parameters( model: The model to consider. standard_gravity: The standard gravity acceleration. static_friction_coefficient: The static friction coefficient. - number_of_active_collidable_points_steady_state: + number_of_active_collidable_shapes_steady_state: The number of active collidable points in steady state. damping_ratio: The damping ratio. max_penetration: The maximum penetration allowed. @@ -194,19 +194,19 @@ def estimate_good_contact_parameters( zero_data = js.data.JaxSimModelData.build(model=model) W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2] if model.floating_base(): - W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1] + W_pz_C = collidable_shape_positions(model=model, data=zero_data)[:, -1] W_pz_CoM = W_pz_CoM - W_pz_C.min() # Consider as default a 1% of the model center of mass height. max_penetration = 0.01 * W_pz_CoM - nc = number_of_active_collidable_points_steady_state + nc = number_of_active_collidable_shapes_steady_state return model.contact_model._parameters_class().build_default_from_jaxsim_model( model=model, standard_gravity=standard_gravity, static_friction_coefficient=static_friction_coefficient, max_penetration=max_penetration, - number_of_active_collidable_points_steady_state=nc, + number_of_active_collidable_shapes_steady_state=nc, damping_ratio=damping_ratio, ) @@ -232,19 +232,19 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt """ # Get the indices of the enabled collidable points. - indices_of_enabled_collidable_points = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + indices_of_enabled_collidable_shapes = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_shapes ) - parent_link_idx_of_enabled_collidable_points = jnp.array( + parent_link_idx_of_enabled_collidable_shapes = jnp.array( model.kin_dyn_parameters.contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_points] + )[indices_of_enabled_collidable_shapes] # Get the transforms of the parent link of all collidable points. - W_H_L = data._link_transforms[parent_link_idx_of_enabled_collidable_points] + W_H_L = data._link_transforms[parent_link_idx_of_enabled_collidable_shapes] L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[ - indices_of_enabled_collidable_points + indices_of_enabled_collidable_shapes ] # Build the link-to-point transform from the displacement between the link frame L @@ -288,13 +288,13 @@ def jacobian( ) # Get the indices of the enabled collidable points. - indices_of_enabled_collidable_points = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + indices_of_enabled_collidable_shapes = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_shapes ) - parent_link_idx_of_enabled_collidable_points = jnp.array( + parent_link_idx_of_enabled_collidable_shapes = jnp.array( model.kin_dyn_parameters.contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_points] + )[indices_of_enabled_collidable_shapes] # Compute the Jacobians of all links. W_J_WL = js.model.generalized_free_floating_jacobian( @@ -304,7 +304,7 @@ def jacobian( # Compute the contact Jacobian. # In inertial-fixed output representation, the Jacobian of the parent link is also # the Jacobian of the frame C implicitly associated with the collidable point. - W_J_WC = W_J_WL[parent_link_idx_of_enabled_collidable_points] + W_J_WC = W_J_WL[parent_link_idx_of_enabled_collidable_shapes] # Adjust the output representation. match output_vel_repr: @@ -377,17 +377,17 @@ def jacobian_derivative( output_vel_repr if output_vel_repr is not None else data.velocity_representation ) - indices_of_enabled_collidable_points = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + indices_of_enabled_collidable_shapes = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_shapes ) # Get the index of the parent link and the position of the collidable point. - parent_link_idx_of_enabled_collidable_points = jnp.array( + parent_link_idx_of_enabled_collidable_shapes = jnp.array( model.kin_dyn_parameters.contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_points] + )[indices_of_enabled_collidable_shapes] L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[ - indices_of_enabled_collidable_points + indices_of_enabled_collidable_shapes ] # Get the transforms of all the parent links. @@ -505,7 +505,7 @@ def compute_O_J̇_WC_I( return O_J̇_WC_I O_J̇_WC = jax.vmap(compute_O_J̇_WC_I, in_axes=(0, 0, None))( - L_p_Ci, parent_link_idx_of_enabled_collidable_points, W_H_Li + L_p_Ci, parent_link_idx_of_enabled_collidable_shapes, W_H_Li ) return O_J̇_WC @@ -575,8 +575,8 @@ def link_forces_from_contact_forces( contact_parameters = model.kin_dyn_parameters.contact_parameters # Extract the indices corresponding to the enabled collidable points. - indices_of_enabled_collidable_points = ( - contact_parameters.indices_of_enabled_collidable_points + indices_of_enabled_collidable_shapes = ( + contact_parameters.indices_of_enabled_collidable_shapes ) # Convert the contact forces to a JAX array. @@ -585,13 +585,13 @@ def link_forces_from_contact_forces( # Construct the vector defining the parent link index of each collidable point. # We use this vector to sum the 6D forces of all collidable points rigidly # attached to the same link. - parent_link_index_of_collidable_points = jnp.array( + parent_link_index_of_collidable_shapes = jnp.array( contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_points] + )[indices_of_enabled_collidable_shapes] # Create the mask that associate each collidable point to their parent link. # We use this mask to sum the collidable points to the right link. - mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange( + mask = parent_link_index_of_collidable_shapes[:, jnp.newaxis] == jnp.arange( model.number_of_links() ) diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index 113620f89..cbae3d145 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -176,7 +176,7 @@ def build( if isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts): contact_state["tangential_deformation"] = contact_state.get( "tangential_deformation", - jnp.zeros_like(model.kin_dyn_parameters.contact_parameters.point), + jnp.zeros_like(model.kin_dyn_parameters.contact_parameters.center), ) model_data = JaxSimModelData( diff --git a/src/jaxsim/rbda/__init__.py b/src/jaxsim/rbda/__init__.py index 5e0af2a66..177aff6ee 100644 --- a/src/jaxsim/rbda/__init__.py +++ b/src/jaxsim/rbda/__init__.py @@ -1,6 +1,6 @@ from . import actuation, contacts from .aba import aba -from .collidable_points import collidable_points_pos_vel +from .collidable_shapes import collidable_shapes_pos_vel from .crba import crba from .forward_kinematics import forward_kinematics_model from .jacobian import ( diff --git a/src/jaxsim/rbda/collidable_points.py b/src/jaxsim/rbda/collidable_points.py deleted file mode 100644 index 179126bb6..000000000 --- a/src/jaxsim/rbda/collidable_points.py +++ /dev/null @@ -1,65 +0,0 @@ -import jax -import jax.numpy as jnp - -import jaxsim.api as js -import jaxsim.typing as jtp -from jaxsim.math import Skew - - -def collidable_points_pos_vel( - model: js.model.JaxSimModel, - *, - link_transforms: jtp.Matrix, - link_velocities: jtp.Matrix, -) -> tuple[jtp.Matrix, jtp.Matrix]: - """ - - Compute the position and linear velocity of the enabled collidable points in the world frame. - - Args: - model: The model to consider. - link_transforms: The transforms from the world frame to each link. - link_velocities: The linear and angular velocities of each link. - - Returns: - A tuple containing the position and linear velocity of the enabled collidable points. - """ - - # Get the indices of the enabled collidable points. - indices_of_enabled_collidable_points = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points - ) - - parent_link_idx_of_enabled_collidable_points = jnp.array( - model.kin_dyn_parameters.contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_points] - - L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[ - indices_of_enabled_collidable_points - ] - - if len(indices_of_enabled_collidable_points) == 0: - return jnp.array(0).astype(float), jnp.empty(0).astype(float) - - def process_point_kinematics( - Li_p_C: jtp.Vector, parent_body: jtp.Int - ) -> tuple[jtp.Vector, jtp.Vector]: - - # Compute the position of the collidable point. - W_p_Ci = (link_transforms[parent_body] @ jnp.hstack([Li_p_C, 1]))[0:3] - - # Compute the linear part of the mixed velocity Ci[W]_v_{W,Ci}. - CW_vl_WCi = ( - jnp.block([jnp.eye(3), -Skew.wedge(vector=W_p_Ci).squeeze()]) - @ link_velocities[parent_body].squeeze() - ) - - return W_p_Ci, CW_vl_WCi - - # Process all the collidable points in parallel. - W_p_Ci, CW_vl_WC = jax.vmap(process_point_kinematics)( - L_p_Ci, - parent_link_idx_of_enabled_collidable_points, - ) - - return W_p_Ci, CW_vl_WC diff --git a/src/jaxsim/rbda/collidable_shapes.py b/src/jaxsim/rbda/collidable_shapes.py new file mode 100644 index 000000000..1cb0b2a66 --- /dev/null +++ b/src/jaxsim/rbda/collidable_shapes.py @@ -0,0 +1,58 @@ +import jax +import jax.numpy as jnp + +import jaxsim.api as js +import jaxsim.typing as jtp + + +def collidable_shapes_pos_vel( + model: js.model.JaxSimModel, + *, + link_transforms: jtp.Matrix, + link_velocities: jtp.Matrix, +) -> tuple[jtp.Matrix, jtp.Matrix]: + """ + + Compute the position and linear velocity of the enabled collidable shapes in the world frame. + + Args: + model: The model to consider. + link_transforms: The transforms from the world frame to each link. + link_velocities: The linear and angular velocities of each link. + + Returns: + A tuple containing the position and linear velocity of the enabled collidable shapes. + """ + + # Get the indices of the enabled collidable shapes. + indices_of_enabled_collidable_shapes = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_shapes + ) + + parent_link_idx_of_enabled_collidable_shapes = jnp.array( + model.kin_dyn_parameters.contact_parameters.body, dtype=int + )[indices_of_enabled_collidable_shapes] + + L_p_Ci = model.kin_dyn_parameters.contact_parameters.shape[ + indices_of_enabled_collidable_shapes + ] + + if len(indices_of_enabled_collidable_shapes) == 0: + return jnp.array(0).astype(float), jnp.empty(0).astype(float) + + def process_shape_kinematics( + Li_p_C: jtp.Vector, parent_body: jtp.Int + ) -> tuple[jtp.Vector, jtp.Vector]: + + # Compute the position of the collidable shape. + W_p_Ci = (link_transforms[parent_body] @ jnp.hstack([Li_p_C, 1]))[0:3] + + return W_p_Ci + + # Process all the collidable shapes in parallel. + W_p_Ci = jax.vmap(process_shape_kinematics)( + L_p_Ci, + parent_link_idx_of_enabled_collidable_shapes, + ) + + return W_p_Ci, link_velocities[:, :3] diff --git a/src/jaxsim/rbda/contacts/__init__.py b/src/jaxsim/rbda/contacts/__init__.py index 32f05e229..6500afbb8 100644 --- a/src/jaxsim/rbda/contacts/__init__.py +++ b/src/jaxsim/rbda/contacts/__init__.py @@ -1,5 +1,5 @@ from . import relaxed_rigid, rigid, soft -from .common import ContactModel, ContactsParams +from .common import ContactModel, ContactsParams, CollidableShapeType from .relaxed_rigid import RelaxedRigidContacts, RelaxedRigidContactsParams from .rigid import RigidContacts, RigidContactsParams from .soft import SoftContacts, SoftContactsParams diff --git a/src/jaxsim/rbda/contacts/common.py b/src/jaxsim/rbda/contacts/common.py index cc772f033..0f770bea3 100644 --- a/src/jaxsim/rbda/contacts/common.py +++ b/src/jaxsim/rbda/contacts/common.py @@ -10,7 +10,7 @@ import jaxsim.terrain import jaxsim.typing as jtp from jaxsim.math import STANDARD_GRAVITY -from jaxsim.utils import JaxsimDataclass +from jaxsim.utils import JaxsimDataclass, CollidableShapeType try: from typing import Self @@ -94,7 +94,6 @@ def build_default_from_jaxsim_model( standard_gravity: jtp.FloatLike = STANDARD_GRAVITY, static_friction_coefficient: jtp.FloatLike = 0.5, max_penetration: jtp.FloatLike = 0.001, - number_of_active_collidable_points_steady_state: jtp.IntLike = 1, damping_ratio: jtp.FloatLike = 1.0, p: jtp.FloatLike = 0.5, q: jtp.FloatLike = 0.5, @@ -110,8 +109,6 @@ def build_default_from_jaxsim_model( standard_gravity: The standard gravity acceleration. static_friction_coefficient: The static friction coefficient. max_penetration: The maximum penetration depth. - number_of_active_collidable_points_steady_state: - The number of active collidable points in steady state. damping_ratio: The damping ratio. p: The first parameter of the contact model. q: The second parameter of the contact model. @@ -137,7 +134,6 @@ def build_default_from_jaxsim_model( ξ = damping_ratio δ_max = max_penetration μc = static_friction_coefficient - nc = number_of_active_collidable_points_steady_state # Compute the total mass of the model. m = jnp.array(model.kin_dyn_parameters.link_parameters.mass).sum() @@ -147,7 +143,7 @@ def build_default_from_jaxsim_model( # the damping term of the Hunt/Crossley model. if stiffness is None: # Compute the average support force on each collidable point. - f_average = m * standard_gravity / nc + f_average = m * standard_gravity stiffness = f_average / jnp.power(δ_max, 1 + p) stiffness = jnp.clip(stiffness, 0, MAX_STIFFNESS) diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index 0b08082ce..903473b04 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -327,14 +327,12 @@ def compute_contact_forces( # Compute the position and linear velocities (mixed representation) of # all collidable points belonging to the robot. - position, velocity = js.contact.collidable_point_kinematics( - model=model, data=data - ) + position, velocity = data._link_transforms[:3, 3], data._link_velocities[:3] # Compute the penetration depth and velocity of the collidable points. # Note that this function considers the penetration in the normal direction. δ, _, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))( - position, velocity, model.terrain + position, velocity, model.terrain, data.contact_parameters ) # Compute the position in the constraint frame. diff --git a/tests/test_api_contact.py b/tests/test_api_contact.py index e4fe0bbdf..cae9e5070 100644 --- a/tests/test_api_contact.py +++ b/tests/test_api_contact.py @@ -23,28 +23,21 @@ def test_contact_kinematics( velocity_representation=velocity_representation, ) - # Get the indices of the enabled collidable points. - indices_of_enabled_collidable_points = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points - ) - - parent_link_idx_of_enabled_collidable_points = jnp.array( - model.kin_dyn_parameters.contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_points] + parent_link_idx_of_collidable_shapes = # ===== # Tests # ===== - # Compute the pose of the implicit contact frame associated to the collidable points + # Compute the pose of the implicit contact frame associated to the collidable shapes # and the transforms of all links. W_H_C = js.contact.transforms(model=model, data=data) W_H_L = data._link_transforms # Check that the orientation of the implicit contact frame matches with the - # orientation of the link to which the contact point is attached. + # orientation of the link to which the contact shape is attached. for contact_idx, index_of_parent_link in enumerate( - parent_link_idx_of_enabled_collidable_points + parent_link_idx_of_collidable_shapes ): assert_allclose( W_H_C[contact_idx, 0:3, 0:3], W_H_L[index_of_parent_link][0:3, 0:3] @@ -52,16 +45,16 @@ def test_contact_kinematics( # Check that the origin of the implicit contact frame is located over the # collidable point. - W_p_C = js.contact.collidable_point_positions(model=model, data=data) + W_p_C = js.contact.collidable_shape_positions(model=model, data=data) assert_allclose(W_p_C, W_H_C[:, 0:3, 3]) - # Compute the velocity of the collidable point. + # Compute the velocity of the collidable shape. # This quantity always matches with the linear component of the mixed 6D velocity - # of the implicit frame associated to the collidable point. - W_ṗ_C = js.contact.collidable_point_velocities(model=model, data=data) + # of the implicit frame associated to the collidable shape. + W_ṗ_C = js.contact.collidable_shape_velocities(model=model, data=data) - # Compute the velocity of the collidable point using the contact Jacobian. + # Compute the velocity of the collidable shape using the contact Jacobian. ν = data.generalized_velocity CW_J_WC = js.contact.jacobian(model=model, data=data, output_vel_repr=VelRepr.Mixed) CW_vl_WC = jnp.einsum("c6g,g->c6", CW_J_WC, ν)[:, 0:3] @@ -70,7 +63,7 @@ def test_contact_kinematics( assert_allclose(W_ṗ_C, CW_vl_WC) -def test_collidable_point_jacobians( +def test_collidable_shape_jacobians( jaxsim_models_types: js.model.JaxSimModel, velocity_representation: VelRepr, prng_key: jax.Array, @@ -87,16 +80,16 @@ def test_collidable_point_jacobians( # Tests # ===== - # Compute the velocity of the collidable points with a RBDA. + # Compute the velocity of the collidable shapes with a RBDA. # This function always returns the linear part of the mixed velocity of the - # implicit frame C corresponding to the collidable point. - W_ṗ_C = js.contact.collidable_point_velocities(model=model, data=data) + # implicit frame C corresponding to the collidable shape. + W_ṗ_C = js.contact.collidable_shape_velocities(model=model, data=data) # Compute the generalized velocity and the free-floating Jacobian of the frame C. ν = data.generalized_velocity CW_J_WC = js.contact.jacobian(model=model, data=data, output_vel_repr=VelRepr.Mixed) - # Compute the velocity of the collidable points using the Jacobians. + # Compute the velocity of the collidable shapes using the Jacobians. v_WC_from_jax = jax.vmap(lambda J, ν: J @ ν, in_axes=(0, None))(CW_J_WC, ν) assert_allclose(W_ṗ_C, v_WC_from_jax[:, 0:3]) @@ -117,21 +110,21 @@ def test_contact_jacobian_derivative( velocity_representation=velocity_representation, ) - # Get the indices of the enabled collidable points. - indices_of_enabled_collidable_points = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + # Get the indices of the enabled collidable shapes. + indices_of_enabled_collidable_shapes = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_shapes ) - # Extract the parent link names and the poses of the contact points. + # Extract the parent link names and the poses of the contact shapes. parent_link_names = js.link.idxs_to_names( model=model, link_indices=jnp.array( model.kin_dyn_parameters.contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_points], + )[indices_of_enabled_collidable_shapes], ) - L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[ - indices_of_enabled_collidable_points + L_p_Ci = model.kin_dyn_parameters.contact_parameters.shape[ + indices_of_enabled_collidable_shapes ] # ===== @@ -141,13 +134,13 @@ def test_contact_jacobian_derivative( # Load the model in ROD. rod_model = rod.Sdf.load(sdf=model.built_from).model - # Add dummy frames on the contact points. + # Add dummy frames on the contact shapes. for idx, link_name, L_p_C in zip( - indices_of_enabled_collidable_points, parent_link_names, L_p_Ci, strict=True + indices_of_enabled_collidable_shapes, parent_link_names, L_p_Ci, strict=True ): rod_model.add_frame( frame=rod.Frame( - name=f"contact_point_{idx}", + name=f"contact_shape_{idx}", attached_to=link_name, pose=rod.Pose( relative_to=link_name, pose=jnp.zeros(shape=(6,)).at[0:3].set(L_p_C) @@ -175,11 +168,11 @@ def test_contact_jacobian_derivative( velocity_representation=velocity_representation, ) - # Extract the indexes of the frames attached to the contact points. + # Extract the indexes of the frames attached to the contact shapes. frame_idxs = js.frame.names_to_idxs( model=model_with_frames, frame_names=( - f"contact_point_{idx}" for idx in indices_of_enabled_collidable_points + f"contact_shape_{idx}" for idx in indices_of_enabled_collidable_shapes ), ) diff --git a/tests/test_simulations.py b/tests/test_simulations.py index 8b30a4f27..262e9ab51 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -197,7 +197,7 @@ def test_simulation_with_soft_contacts( model = jaxsim_model_box - # Define the maximum penetration of each collidable point at steady state. + # Define the maximum penetration at steady state. max_penetration = 0.001 with model.editable(validate=False) as model: @@ -205,21 +205,11 @@ def test_simulation_with_soft_contacts( model.contact_model = jaxsim.rbda.contacts.SoftContacts.build() model.contact_params = js.contact.estimate_good_contact_parameters( model=model, - number_of_active_collidable_points_steady_state=4, static_friction_coefficient=1.0, damping_ratio=1.0, max_penetration=max_penetration, ) - # Enable a subset of the collidable points. - enabled_collidable_points_mask = np.zeros( - len(model.kin_dyn_parameters.contact_parameters.body), dtype=bool - ) - enabled_collidable_points_mask[[0, 1, 2, 3]] = True - model.kin_dyn_parameters.contact_parameters.enabled = tuple( - enabled_collidable_points_mask.tolist() - ) - assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4 # Check jaxsim_model_box@conftest.py. From 3c78f861f6671b9fe4872481b79410444ed73545 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 12 Sep 2025 16:48:03 +0200 Subject: [PATCH 05/39] Add SDF for spheres, boxes and cylinders --- src/jaxsim/rbda/contacts/detection.py | 206 ++++++++++++++++++++++++++ 1 file changed, 206 insertions(+) create mode 100644 src/jaxsim/rbda/contacts/detection.py diff --git a/src/jaxsim/rbda/contacts/detection.py b/src/jaxsim/rbda/contacts/detection.py new file mode 100644 index 000000000..22ddca84e --- /dev/null +++ b/src/jaxsim/rbda/contacts/detection.py @@ -0,0 +1,206 @@ +import jaxsim +import jaxsim.typing as jtp +import jax.numpy as jnp +import jax + + +def _contact_frame(normal, position): + """Create a contact frame with z-axis aligned with the contact normal.""" + n = normal / jaxsim.math.safe_norm(normal) + + t1_initial = jnp.array([1.0, 0.0, 0.0]) + + t1 = t1_initial - jnp.dot(t1_initial, n) * n + t1 = t1 / jaxsim.math.safe_norm(t1) + t2 = jnp.cross(n, t1) + + R = jnp.stack([t1, t2, n], axis=1) + + return jaxsim.math.Transform.from_rotation_and_translation( + rotation=R, + translation=position, + ) + + +def sphere_plane(terrain: jaxsim.terrain.Terrain, size: jtp.Vector, W_H_L: jtp.Matrix): + """ + Detect contacts between a sphere and a plane terrain. + + Args: + terrain: The terrain object. + size: The size of the sphere. + W_H_L: The collision shape transform in world coordinates. + + Returns: + A tuple containing the distance from the sphere to the plane and the pose transform + of the contact frame. + """ + center = W_H_L[0:3, 3] + normal = terrain.normal(x=center[0], y=center[1]) + distance = jnp.dot(center - terrain._height, normal) - size[0] + position = normal * (size[0] + 0.5 * distance) - center + W_H_C = jaxsim.math.Transform.from_rotation_and_translation( + rotation=jaxsim.math.Rotation.from_axis_angle(normal), + translation=-position, + ) + return distance, W_H_C + + +# TODO (flferretti): Keep only the SDF version? +def box_plane_sdf(terrain, size, W_H_L): + """ + Return distances and contact frames of the 3 deepest corners of a box on terrain using SDF. + Fully vectorized, works for any box orientation. + """ + half_size = size.squeeze() / 2 + + R = W_H_L[:3, :3] + t = W_H_L[:3, 3] + + # Generate all 8 corners using meshgrid + sx = jnp.array([-half_size[0], half_size[0]]) + sy = jnp.array([-half_size[1], half_size[1]]) + sz = jnp.array([-half_size[2], half_size[2]]) + xs, ys, zs = jnp.meshgrid(sx, sy, sz, indexing="ij") + corners_local = jnp.stack( + [xs.ravel(), ys.ravel(), zs.ravel()], axis=1 + ) # shape (8,3) + + box_z_world = R[:, 2] + flip_sign = jnp.sign(box_z_world) + R_corrected = R.at[:, 2].set(R[:, 2] * flip_sign) # flip z-axis if needed + + # Transform to world frame + corners_world = t + (R_corrected @ corners_local.T).T # shape (8,3) + + # Vectorized terrain height and normal using vmap + terrain_height_vmap = jax.vmap(lambda p: terrain.height(p[0], p[1])) + terrain_normal_vmap = jax.vmap(lambda p: terrain.normal(p[0], p[1])) + + terrain_heights = terrain_height_vmap(corners_world) + terrain_points = jnp.stack( + [corners_world[:, 0], corners_world[:, 1], terrain_heights], axis=1 + ) + + normals = terrain_normal_vmap(corners_world) + + # Distances along terrain normal + distances = jnp.einsum("ij,ij->i", corners_world - terrain_points, normals) + + # Pick 3 closest points using top_k + _, topk_idx = jax.lax.top_k(-distances, 3) + contact_points = corners_world[topk_idx] + contact_normals = normals[topk_idx] + + # Compute contact frames using vmap + W_H_C = jax.vmap(lambda p, n: _contact_frame(n, p))(contact_points, contact_normals) + + # Distances along terrain normal for the selected points + distances_top3 = distances[topk_idx] + + return distances_top3, W_H_C + + +def box_plane( + terrain: jaxsim.terrain.Terrain, + size: jtp.Vector, + W_H_L: jtp.Matrix, +): + """ + Detect contacts between a box and a plane terrain. + Finds the actual contact point on the box surface (vertex, edge, or face). + + Args: + terrain: The terrain object with _height(x, y) method and normal(x, y) method. + size: A 3D vector [width, height, depth] representing the box dimensions from center. + W_H_L: The collision shape transform in world coordinates. + + Returns: + A tuple containing the distance from the box to the plane and the pose transform + of the contact frame. + """ + half_size = size.squeeze() / 2 + center = W_H_L[:3, 3] + R = W_H_L[:3, :3] + + # Transform terrain normal at box center into world coordinates + normal = terrain.normal(center[0], center[1]) + + # Find the box vertex furthest in the opposite direction of terrain normal + local_normal = R.T @ normal + support_local = -half_size * jnp.sign(local_normal) + + # Vertex in world coordinates + support_world = center + R @ support_local + + # Terrain point and distance + terrain_z = terrain.height(support_world[0], support_world[1]) + terrain_point = jnp.array([support_world[0], support_world[1], terrain_z]) + distance = jnp.dot(support_world - terrain_point, normal) + + # Contact frame + contact_point = support_world - distance * normal + W_H_C = _contact_frame(normal, contact_point) + + return distance, W_H_C + + +def cylinder_plane( + terrain: jaxsim.terrain.Terrain, + size: jtp.Vector, + W_H_L: jtp.Matrix, +): + """ + Detect contacts between a cylinder and a plane terrain. + Finds the actual contact point on the cylinder surface (vertex, edge, or face). + + Args: + terrain: The terrain object with _height(x, y) method and normal(x, y) method. + size: A 3D vector [width, height, depth] representing the cylinder dimensions from center. + W_H_L: The collision shape transform in world coordinates. + + Returns: + A tuple containing the distance from the cylinder to the plane, the contact point position + and the contact frame. + """ + radius = size[0] + half_length = size[1] / 2.0 + + center = W_H_L[0:3, 3] + axis = W_H_L[0:3, 2] / jnp.linalg.norm(W_H_L[0:3, 2]) + + x, y = center[0], center[1] + n = terrain.normal(x, y) + h = terrain.height(x, y) + p0 = jnp.array([x, y, h]) + + d0 = jnp.dot(n, center - p0) + proj = jnp.dot(n, axis) + side_term = radius * jnp.sqrt(jnp.maximum(0.0, 1.0 - proj**2)) + cap_term = half_length * jnp.abs(proj) + distance = d0 - cap_term - side_term + + # contact point + use_side = jnp.abs(proj) < 1.0 - 1e-6 + radial = n - proj * axis + radial /= jnp.linalg.norm(radial) + 1e-12 + side_pt = center + half_length * jnp.sign(proj) * axis + radius * radial + cap_pt = center + half_length * jnp.sign(proj) * axis + support = jnp.where(use_side, side_pt, cap_pt) + contact_point = support - n * distance + + # --- contact frame --- + z_axis = n / (jnp.linalg.norm(n) + 1e-12) + cand = jnp.where( + jnp.abs(jnp.dot(axis, z_axis)) < 0.9, axis, jnp.array([1.0, 0.0, 0.0]) + ) + x_axis = cand - jnp.dot(cand, z_axis) * z_axis + x_axis = x_axis / (jnp.linalg.norm(x_axis) + 1e-12) + y_axis = jnp.cross(z_axis, x_axis) + R = jnp.stack([x_axis, y_axis, z_axis], axis=1) + + W_H_C = jnp.vstack( + [jnp.hstack([R, contact_point[:, None]]), jnp.array([0.0, 0.0, 0.0, 1.0])], + ) + + return distance, W_H_C From a5ad4d8c603b62ae1c9709c90b3e58a7d25ca1f0 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 12 Sep 2025 16:49:08 +0200 Subject: [PATCH 06/39] Update computation of penetration API --- src/jaxsim/rbda/contacts/common.py | 32 ++++++++++++++++-------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/src/jaxsim/rbda/contacts/common.py b/src/jaxsim/rbda/contacts/common.py index 0f770bea3..1f9452f56 100644 --- a/src/jaxsim/rbda/contacts/common.py +++ b/src/jaxsim/rbda/contacts/common.py @@ -17,16 +17,25 @@ except ImportError: from typing_extensions import Self +from .detection import sphere_plane, box_plane, cylinder_plane MAX_STIFFNESS = 1e6 MAX_DAMPING = 1e4 +# Define a mapping from collidable shape types to distance functions. +_COLLISION_MAP = { + CollidableShapeType.Sphere: sphere_plane, + CollidableShapeType.Box: box_plane, + CollidableShapeType.Cylinder: cylinder_plane, +} + @functools.partial(jax.jit, static_argnames=("terrain",)) def compute_penetration_data( p: jtp.VectorLike, v: jtp.VectorLike, terrain: jaxsim.terrain.Terrain, + contact_parameters: js.kin_dyn_parameters.ContactParameters | None = None, ) -> tuple[jtp.Float, jtp.Float, jtp.Vector]: """ Compute the penetration data (depth, rate, and terrain normal) of a collidable point. @@ -37,6 +46,7 @@ def compute_penetration_data( The linear velocity of the point (linear component of the mixed 6D velocity of the implicit frame `C = (W_p_C, [W])` associated to the point). terrain: The considered terrain. + contact_parameters: The parameters of the collidable shapes. Returns: A tuple containing the penetration depth, the penetration velocity, @@ -44,23 +54,15 @@ def compute_penetration_data( """ # Pre-process the position and the linear velocity of the collidable point. - W_ṗ_C = jnp.array(v).squeeze() - px, py, pz = jnp.array(p).squeeze() - - # Compute the terrain normal and the contact depth. - n̂ = terrain.normal(x=px, y=py).squeeze() - h = jnp.array([0, 0, terrain.height(x=px, y=py) - pz]) - - # Compute the penetration depth normal to the terrain. - δ = jnp.maximum(0.0, jnp.dot(h, n̂)) - - # Compute the penetration normal velocity. - δ_dot = -jnp.dot(W_ṗ_C, n̂) + distance_fn = _COLLISION_MAP[contact_parameters.shape_type] - # Enforce the penetration rate to be zero when the penetration depth is zero. - δ_dot = jnp.where(δ > 0, δ_dot, 0.0) + δ, W_H_C = distance_fn( + terrain=terrain, + size=contact_parameters.shape_size, + center=contact_parameters.center, + ) - return δ, δ_dot, n̂ + return δ, W_H_C class ContactsParams(JaxsimDataclass): From 28d35689b02068baa3d645512b9980f16f9bbd66 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 12 Sep 2025 16:49:23 +0200 Subject: [PATCH 07/39] Update Hunt-Crossley model to work with collidable shapes --- src/jaxsim/rbda/contacts/soft.py | 61 +++++++++++++++----------------- 1 file changed, 29 insertions(+), 32 deletions(-) diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index 4a2f37f53..5fc1b82a0 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -193,12 +193,12 @@ def update_velocity_after_impact( return data @staticmethod - @functools.partial(jax.jit, static_argnames=("terrain",)) + @jax.jit def hunt_crossley_contact_model( - position: jtp.VectorLike, velocity: jtp.VectorLike, tangential_deformation: jtp.VectorLike, - terrain: Terrain, + distance: jtp.VectorLike, + normal: jtp.VectorLike, K: jtp.FloatLike, D: jtp.FloatLike, mu: jtp.FloatLike, @@ -209,10 +209,10 @@ def hunt_crossley_contact_model( Compute the contact force using the Hunt/Crossley model. Args: - position: The position of the collidable point. - velocity: The velocity of the collidable point. - tangential_deformation: The material deformation of the collidable point. - terrain: The terrain model. + velocity: The velocity of the collidable shape. + tangential_deformation: The material deformation of the collidable shape. + distance: The distance from the collidable shape to the terrain. + normal: The terrain normal at the contact point. K: The stiffness parameter. D: The damping parameter of the soft contacts model. mu: The static friction coefficient. @@ -229,15 +229,20 @@ def hunt_crossley_contact_model( """ # Convert the input vectors to arrays. - W_p_C = jnp.array(position, dtype=float).squeeze() W_ṗ_C = jnp.array(velocity, dtype=float).squeeze() m = jnp.array(tangential_deformation, dtype=float).squeeze() # Use symbol for the static friction. μ = mu - # Compute the penetration depth, its rate, and the considered terrain normal. - δ, δ̇, n̂ = common.compute_penetration_data(p=W_p_C, v=W_ṗ_C, terrain=terrain) + # Ensure non-negative distance. + δ = jnp.maximum(0.0, -distance) + + # Extract the contact velocity. + δ̇ = -W_ṗ_C[2] + + # Compute the normal. + n̂ = normal / jaxsim.math.safe_norm(normal) # There are few operations like computing the norm of a vector with zero length # or computing the square root of zero that are problematic in an AD context. @@ -351,9 +356,9 @@ def compute_contact_force( Compute the contact force. Args: - position: The position of the collidable point. - velocity: The velocity of the collidable point. - tangential_deformation: The material deformation of the collidable point. + position: The position of the collidable shape. + velocity: The velocity of the collidable shape. + tangential_deformation: The material deformation of the collidable shape. parameters: The parameters of the soft contacts model. terrain: The terrain model. @@ -405,40 +410,32 @@ def compute_contact_forces( second element a dictionary with derivative of the material deformation. """ - # Get the indices of the enabled collidable points. - indices_of_enabled_collidable_points = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points - ) - # Compute the position and linear velocities (mixed representation) of - # all the collidable points belonging to the robot and extract the ones - # for the enabled collidable points. - W_p_C, W_ṗ_C = js.contact.collidable_point_kinematics(model=model, data=data) + # all the collidable shapes belonging to the robot and extract the ones + # for the enabled collidable shapes. + W_p_C, W_ṗ_C = js.contact.collidable_shape_kinematics(model=model, data=data) - # Extract the material deformation corresponding to the collidable points. + # Extract the material deformation corresponding to the collidable shapes. m = ( data.contact_state["tangential_deformation"] if "tangential_deformation" in data.contact_state else jnp.zeros_like(W_p_C) ) - m_enabled = m[indices_of_enabled_collidable_points] - - # Initialize the tangential deformation rate array for every collidable point. + # Initialize the tangential deformation rate array for every collidable shape. ṁ = jnp.zeros_like(m) - # Compute the contact forces only for the enabled collidable points. + # Compute the contact forces only for the enabled collidable shapes. # Since we treat them as independent, we can vmap the computation. - W_f, ṁ_enabled = jax.vmap( + W_f, ṁ = jax.vmap( lambda p, v, m: SoftContacts.compute_contact_force( - position=p, - velocity=v, + center_position=p, + center_velocity=v, tangential_deformation=m, parameters=model.contact_params, terrain=model.terrain, + size=model.kin_dyn_parameters.contact_parameters.shape_size, ) - )(W_p_C, W_ṗ_C, m_enabled) - - ṁ = ṁ.at[indices_of_enabled_collidable_points].set(ṁ_enabled) + )(W_p_C, W_ṗ_C, m) return W_f, {"m_dot": ṁ} From 93c16cbbff9da81d7ad45098e680f6dc659bfe4d Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 12 Sep 2025 16:51:00 +0200 Subject: [PATCH 08/39] Update `ContactParameters` class and add collision shape map --- src/jaxsim/api/kin_dyn_parameters.py | 71 ++++++++++++++-------------- 1 file changed, 35 insertions(+), 36 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index e111354c2..5f856782d 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -14,8 +14,15 @@ import jaxsim.typing as jtp from jaxsim.math import Inertia, JointModel, supported_joint_motion from jaxsim.math.adjoint import Adjoint -from jaxsim.parsers.descriptions import JointDescription, JointType, ModelDescription -from jaxsim.utils import HashedNumpyArray, JaxsimDataclass +from jaxsim.parsers.descriptions import ( + JointDescription, + JointType, + ModelDescription, + SphereCollision, + BoxCollision, + CylinderCollision, +) +from jaxsim.utils import HashedNumpyArray, JaxsimDataclass, CollidableShapeType @jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False) @@ -762,6 +769,13 @@ def unflatten_inertia_tensor(inertia_elements: jtp.Vector) -> jtp.Matrix: return jnp.atleast_2d(jnp.where(I, I, I.T)).astype(float) +_COLLISION_SHAPE_MAP = { + SphereCollision: CollidableShapeType.Sphere, + BoxCollision: CollidableShapeType.Box, + CylinderCollision: CollidableShapeType.Cylinder, +} + + @jax_dataclasses.pytree_dataclass class ContactParameters(JaxsimDataclass): """ @@ -783,18 +797,9 @@ class ContactParameters(JaxsimDataclass): to be created with vmap. This is because the `body` attribute must be `Static`. """ - body: Static[tuple[int, ...]] = dataclasses.field(default_factory=tuple) - - point: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.array([])) - - enabled: Static[tuple[bool, ...]] = dataclasses.field(default_factory=tuple) - - @property - def indices_of_enabled_collidable_points(self) -> npt.NDArray: - """ - Return the indices of the enabled collidable points. - """ - return np.where(np.array(self.enabled))[0] + center: jtp.Vector = dataclasses.field(default_factory=lambda: jnp.array([])) + shape_size: jtp.Vector = dataclasses.field(default_factory=lambda: jnp.array([])) + shape_type: jtp.Vector = dataclasses.field(default_factory=lambda: jnp.array([])) @staticmethod def build_from(model_description: ModelDescription) -> ContactParameters: @@ -811,33 +816,27 @@ def build_from(model_description: ModelDescription) -> ContactParameters: if len(model_description.collision_shapes) == 0: return ContactParameters() - # Get all the links so that we can take their updated index. - links_dict = {link.name: link for link in model_description} - - # Get all the enabled collidable points of the model. - collidable_points = model_description.all_enabled_collidable_points() - - # Extract the positions L_p_C of the collidable points w.r.t. the link frames - # they are rigidly attached to. - points = jnp.vstack([cp.position for cp in collidable_points]) - - # Extract the indices of the links to which the collidable points are rigidly - # attached to. - link_index_of_points = tuple( - links_dict[cp.parent_link.name].index for cp in collidable_points + # Assume the link_parameters and the collision_shapes are in the same order. + centers = jnp.array( + [shape.center for shape in model_description.collision_shapes] ) - # Build the ContactParameters object. - cp = ContactParameters( - point=points, - body=link_index_of_points, - enabled=tuple(True for _ in link_index_of_points), + shape_size = jnp.array( + [shape.size.squeeze() for shape in model_description.collision_shapes] ) - assert cp.point.shape[1] == 3, cp.point.shape[1] - assert cp.point.shape[0] == len(cp.body), cp.point.shape[0] + shape_type = [ + _COLLISION_SHAPE_MAP[type(shape)] + for shape in model_description.collision_shapes + ] + shape_type = jnp.array(shape_type, dtype=int) - return cp + # Build the ContactParameters object. + return ContactParameters( + center=centers, + shape_type=shape_type, + shape_size=shape_size, + ) @jax_dataclasses.pytree_dataclass From 592cfc8a0ff886d8707c361de75cdfb26486b92b Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 15 Sep 2025 11:43:29 +0200 Subject: [PATCH 09/39] Fix contact force orientation --- src/jaxsim/rbda/contacts/detection.py | 118 ++++++++++++++------------ src/jaxsim/rbda/contacts/soft.py | 4 +- 2 files changed, 68 insertions(+), 54 deletions(-) diff --git a/src/jaxsim/rbda/contacts/detection.py b/src/jaxsim/rbda/contacts/detection.py index 22ddca84e..0444060a5 100644 --- a/src/jaxsim/rbda/contacts/detection.py +++ b/src/jaxsim/rbda/contacts/detection.py @@ -35,14 +35,22 @@ def sphere_plane(terrain: jaxsim.terrain.Terrain, size: jtp.Vector, W_H_L: jtp.M A tuple containing the distance from the sphere to the plane and the pose transform of the contact frame. """ + # Extract sphere center and radius. center = W_H_L[0:3, 3] - normal = terrain.normal(x=center[0], y=center[1]) - distance = jnp.dot(center - terrain._height, normal) - size[0] - position = normal * (size[0] + 0.5 * distance) - center - W_H_C = jaxsim.math.Transform.from_rotation_and_translation( - rotation=jaxsim.math.Rotation.from_axis_angle(normal), - translation=-position, - ) + radius = size[0] + + # Extract terrain properties at sphere center. + x, y = center[0], center[1] + + normal = terrain.normal(x=x, y=y) + height = terrain.height(x=x, y=y) + + distance = jnp.dot(center - height, normal) - radius + + position = center - radius * normal + + W_H_C = _contact_frame(normal, position) + return distance, W_H_C @@ -55,7 +63,7 @@ def box_plane_sdf(terrain, size, W_H_L): half_size = size.squeeze() / 2 R = W_H_L[:3, :3] - t = W_H_L[:3, 3] + center = W_H_L[:3, 3] # Generate all 8 corners using meshgrid sx = jnp.array([-half_size[0], half_size[0]]) @@ -71,7 +79,7 @@ def box_plane_sdf(terrain, size, W_H_L): R_corrected = R.at[:, 2].set(R[:, 2] * flip_sign) # flip z-axis if needed # Transform to world frame - corners_world = t + (R_corrected @ corners_local.T).T # shape (8,3) + corners_world = center + corners_local @ R_corrected.T # Vectorized terrain height and normal using vmap terrain_height_vmap = jax.vmap(lambda p: terrain.height(p[0], p[1])) @@ -151,56 +159,62 @@ def cylinder_plane( W_H_L: jtp.Matrix, ): """ - Detect contacts between a cylinder and a plane terrain. - Finds the actual contact point on the cylinder surface (vertex, edge, or face). - - Args: - terrain: The terrain object with _height(x, y) method and normal(x, y) method. - size: A 3D vector [width, height, depth] representing the cylinder dimensions from center. - W_H_L: The collision shape transform in world coordinates. - - Returns: - A tuple containing the distance from the cylinder to the plane, the contact point position - and the contact frame. + Compact cylinder-plane contact detection returning distance and SE(3) contact frame. """ - radius = size[0] - half_length = size[1] / 2.0 + size = size.squeeze() + radius, half_height = size[0], size[1] / 2.0 + # Cylinder center and axis center = W_H_L[0:3, 3] axis = W_H_L[0:3, 2] / jnp.linalg.norm(W_H_L[0:3, 2]) - x, y = center[0], center[1] - n = terrain.normal(x, y) - h = terrain.height(x, y) - p0 = jnp.array([x, y, h]) - - d0 = jnp.dot(n, center - p0) - proj = jnp.dot(n, axis) - side_term = radius * jnp.sqrt(jnp.maximum(0.0, 1.0 - proj**2)) - cap_term = half_length * jnp.abs(proj) - distance = d0 - cap_term - side_term - - # contact point - use_side = jnp.abs(proj) < 1.0 - 1e-6 - radial = n - proj * axis - radial /= jnp.linalg.norm(radial) + 1e-12 - side_pt = center + half_length * jnp.sign(proj) * axis + radius * radial - cap_pt = center + half_length * jnp.sign(proj) * axis - support = jnp.where(use_side, side_pt, cap_pt) - contact_point = support - n * distance - - # --- contact frame --- - z_axis = n / (jnp.linalg.norm(n) + 1e-12) - cand = jnp.where( - jnp.abs(jnp.dot(axis, z_axis)) < 0.9, axis, jnp.array([1.0, 0.0, 0.0]) + # Terrain properties at center + n = terrain.normal(center[0], center[1]) + + # Axis projection and perpendicular direction + axis_dot_n = jnp.dot(axis, n) + perp = jnp.cross(axis, n) + perp_norm = jnp.linalg.norm(perp) + perp = jnp.where(perp_norm > 1e-6, perp / perp_norm, W_H_L[0:3, 0]) + + # Three potential contact points + cap_offset = axis * half_height * jnp.sign(axis_dot_n) + edge_offset = perp * radius + + contacts = jnp.array( + [ + center + cap_offset, + center + edge_offset + axis * half_height, + center + edge_offset - axis * half_height, + ] ) - x_axis = cand - jnp.dot(cand, z_axis) * z_axis - x_axis = x_axis / (jnp.linalg.norm(x_axis) + 1e-12) - y_axis = jnp.cross(z_axis, x_axis) - R = jnp.stack([x_axis, y_axis, z_axis], axis=1) - W_H_C = jnp.vstack( - [jnp.hstack([R, contact_point[:, None]]), jnp.array([0.0, 0.0, 0.0, 1.0])], + # Vectorized terrain height computation + terrain_heights = jax.vmap(terrain.height)(contacts[:, 0], contacts[:, 1]) + terrain_points = jnp.column_stack([contacts[:, :2], terrain_heights]) + + # Compute distances + distances = jnp.sum((contacts - terrain_points) * n, axis=1) + + # Select best contact based on axis alignment and minimum distance + abs_dot = jnp.abs(axis_dot_n) + weights = jnp.where( + abs_dot > 0.8, # Face contact - use cap + jnp.array([1.0, 0.0, 0.0]), + jnp.where( + abs_dot < 0.3, # Edge contact - use minimum edge distance + jnp.array([0.0, 1.0, 1.0]), + jnp.array([1.0, 0.0, 0.0]), # Corner contact - use cap + ), ) + # Find minimum valid distance + valid_distances = jnp.where(weights > 0, distances, jnp.inf) + min_idx = jnp.argmin(valid_distances) + + distance = distances[min_idx] + contact_pos = contacts[min_idx] + + W_H_C = _contact_frame(n, contact_pos) + return distance, W_H_C diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index 5fc1b82a0..0a370e7e6 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -239,10 +239,10 @@ def hunt_crossley_contact_model( δ = jnp.maximum(0.0, -distance) # Extract the contact velocity. - δ̇ = -W_ṗ_C[2] + δ̇ = -W_ṗ_C.dot(normal) # Compute the normal. - n̂ = normal / jaxsim.math.safe_norm(normal) + n̂ = normal # There are few operations like computing the norm of a vector with zero length # or computing the square root of zero that are problematic in an AD context. From c447ba6c59afee4155c971e203f0cc01b7a93959 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 16 Sep 2025 17:36:20 +0200 Subject: [PATCH 10/39] Update arguments of Hunt-Crossley model --- src/jaxsim/rbda/contacts/soft.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index 0a370e7e6..9b492dd65 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -195,10 +195,11 @@ def update_velocity_after_impact( @staticmethod @jax.jit def hunt_crossley_contact_model( + penetration: jtp.VectorLike, + penetration_rate: jtp.VectorLike, velocity: jtp.VectorLike, - tangential_deformation: jtp.VectorLike, - distance: jtp.VectorLike, normal: jtp.VectorLike, + tangential_deformation: jtp.VectorLike, K: jtp.FloatLike, D: jtp.FloatLike, mu: jtp.FloatLike, @@ -209,10 +210,11 @@ def hunt_crossley_contact_model( Compute the contact force using the Hunt/Crossley model. Args: - velocity: The velocity of the collidable shape. - tangential_deformation: The material deformation of the collidable shape. - distance: The distance from the collidable shape to the terrain. + penetration: The penetration of the collision point. + penetration_rate: The penetration rate of the collision point. + velocity: The velocity of the contact point. normal: The terrain normal at the contact point. + tangential_deformation: The material deformation of the collidable shape. K: The stiffness parameter. D: The damping parameter of the soft contacts model. mu: The static friction coefficient. @@ -228,6 +230,10 @@ def hunt_crossley_contact_model( material deformation. """ + δ = penetration + δ̇ = penetration_rate + n̂ = normal + # Convert the input vectors to arrays. W_ṗ_C = jnp.array(velocity, dtype=float).squeeze() m = jnp.array(tangential_deformation, dtype=float).squeeze() @@ -235,15 +241,6 @@ def hunt_crossley_contact_model( # Use symbol for the static friction. μ = mu - # Ensure non-negative distance. - δ = jnp.maximum(0.0, -distance) - - # Extract the contact velocity. - δ̇ = -W_ṗ_C.dot(normal) - - # Compute the normal. - n̂ = normal - # There are few operations like computing the norm of a vector with zero length # or computing the square root of zero that are problematic in an AD context. # To avoid these issues, we introduce a small tolerance ε to their arguments From 3edf36c80cbcf36ab1e702bcad4b683d82e83c38 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 16 Sep 2025 17:37:28 +0200 Subject: [PATCH 11/39] Remove `collidable_shapes` module --- src/jaxsim/rbda/__init__.py | 1 - src/jaxsim/rbda/collidable_shapes.py | 58 ---------------------------- 2 files changed, 59 deletions(-) delete mode 100644 src/jaxsim/rbda/collidable_shapes.py diff --git a/src/jaxsim/rbda/__init__.py b/src/jaxsim/rbda/__init__.py index 177aff6ee..022260e1b 100644 --- a/src/jaxsim/rbda/__init__.py +++ b/src/jaxsim/rbda/__init__.py @@ -1,6 +1,5 @@ from . import actuation, contacts from .aba import aba -from .collidable_shapes import collidable_shapes_pos_vel from .crba import crba from .forward_kinematics import forward_kinematics_model from .jacobian import ( diff --git a/src/jaxsim/rbda/collidable_shapes.py b/src/jaxsim/rbda/collidable_shapes.py deleted file mode 100644 index 1cb0b2a66..000000000 --- a/src/jaxsim/rbda/collidable_shapes.py +++ /dev/null @@ -1,58 +0,0 @@ -import jax -import jax.numpy as jnp - -import jaxsim.api as js -import jaxsim.typing as jtp - - -def collidable_shapes_pos_vel( - model: js.model.JaxSimModel, - *, - link_transforms: jtp.Matrix, - link_velocities: jtp.Matrix, -) -> tuple[jtp.Matrix, jtp.Matrix]: - """ - - Compute the position and linear velocity of the enabled collidable shapes in the world frame. - - Args: - model: The model to consider. - link_transforms: The transforms from the world frame to each link. - link_velocities: The linear and angular velocities of each link. - - Returns: - A tuple containing the position and linear velocity of the enabled collidable shapes. - """ - - # Get the indices of the enabled collidable shapes. - indices_of_enabled_collidable_shapes = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_shapes - ) - - parent_link_idx_of_enabled_collidable_shapes = jnp.array( - model.kin_dyn_parameters.contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_shapes] - - L_p_Ci = model.kin_dyn_parameters.contact_parameters.shape[ - indices_of_enabled_collidable_shapes - ] - - if len(indices_of_enabled_collidable_shapes) == 0: - return jnp.array(0).astype(float), jnp.empty(0).astype(float) - - def process_shape_kinematics( - Li_p_C: jtp.Vector, parent_body: jtp.Int - ) -> tuple[jtp.Vector, jtp.Vector]: - - # Compute the position of the collidable shape. - W_p_Ci = (link_transforms[parent_body] @ jnp.hstack([Li_p_C, 1]))[0:3] - - return W_p_Ci - - # Process all the collidable shapes in parallel. - W_p_Ci = jax.vmap(process_shape_kinematics)( - L_p_Ci, - parent_link_idx_of_enabled_collidable_shapes, - ) - - return W_p_Ci, link_velocities[:, :3] From d650ca5b848d271d7d9c76600ebc7fd6bafffa3a Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 16 Sep 2025 17:38:39 +0200 Subject: [PATCH 12/39] Update link force summation --- src/jaxsim/api/contact.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 95ebee444..5bdfcfb19 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -537,7 +537,7 @@ def link_contact_forces( """ # Compute the contact forces for each collidable point with the active contact model. - W_f_C, aux_dict = model.contact_model.compute_contact_forces( + W_f_L, aux_dict = model.contact_model.compute_contact_forces( model=model, data=data, **( @@ -549,7 +549,7 @@ def link_contact_forces( # Compute the 6D forces applied to the links equivalent to the forces applied # to the frames associated to the collidable points. - W_f_L = link_forces_from_contact_forces(model=model, contact_forces=W_f_C) + # W_f_L = link_forces_from_contact_forces(model=model, contact_forces=W_f_C) return W_f_L, aux_dict From 7edec4712f6a244b474032103779a17d686aa246 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 16 Sep 2025 17:45:10 +0200 Subject: [PATCH 13/39] Fix SDF functions for boxes and cylinders --- src/jaxsim/api/kin_dyn_parameters.py | 6 +- src/jaxsim/parsers/descriptions/__init__.py | 2 +- src/jaxsim/rbda/contacts/__init__.py | 2 +- src/jaxsim/rbda/contacts/detection.py | 196 ++++++++++---------- src/jaxsim/utils/__init__.py | 6 +- 5 files changed, 102 insertions(+), 110 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index 5f856782d..4bdec21ea 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -15,14 +15,14 @@ from jaxsim.math import Inertia, JointModel, supported_joint_motion from jaxsim.math.adjoint import Adjoint from jaxsim.parsers.descriptions import ( + BoxCollision, + CylinderCollision, JointDescription, JointType, ModelDescription, SphereCollision, - BoxCollision, - CylinderCollision, ) -from jaxsim.utils import HashedNumpyArray, JaxsimDataclass, CollidableShapeType +from jaxsim.utils import CollidableShapeType, HashedNumpyArray, JaxsimDataclass @jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False) diff --git a/src/jaxsim/parsers/descriptions/__init__.py b/src/jaxsim/parsers/descriptions/__init__.py index 0be19de64..6a08ae6f3 100644 --- a/src/jaxsim/parsers/descriptions/__init__.py +++ b/src/jaxsim/parsers/descriptions/__init__.py @@ -1,7 +1,7 @@ from .collision import ( BoxCollision, - SphereCollision, CylinderCollision, + SphereCollision, ) from .joint import JointDescription, JointGenericAxis, JointType from .link import LinkDescription diff --git a/src/jaxsim/rbda/contacts/__init__.py b/src/jaxsim/rbda/contacts/__init__.py index 6500afbb8..af6e1d00e 100644 --- a/src/jaxsim/rbda/contacts/__init__.py +++ b/src/jaxsim/rbda/contacts/__init__.py @@ -1,5 +1,5 @@ from . import relaxed_rigid, rigid, soft -from .common import ContactModel, ContactsParams, CollidableShapeType +from .common import CollidableShapeType, ContactModel, ContactsParams from .relaxed_rigid import RelaxedRigidContacts, RelaxedRigidContactsParams from .rigid import RigidContacts, RigidContactsParams from .soft import SoftContacts, SoftContactsParams diff --git a/src/jaxsim/rbda/contacts/detection.py b/src/jaxsim/rbda/contacts/detection.py index 0444060a5..ae6f938c6 100644 --- a/src/jaxsim/rbda/contacts/detection.py +++ b/src/jaxsim/rbda/contacts/detection.py @@ -1,10 +1,11 @@ +import jax +import jax.numpy as jnp + import jaxsim import jaxsim.typing as jtp -import jax.numpy as jnp -import jax -def _contact_frame(normal, position): +def _contact_frame(normal: jtp.Vector, position: jtp.Vector) -> jtp.Matrix: """Create a contact frame with z-axis aligned with the contact normal.""" n = normal / jaxsim.math.safe_norm(normal) @@ -22,7 +23,9 @@ def _contact_frame(normal, position): ) -def sphere_plane(terrain: jaxsim.terrain.Terrain, size: jtp.Vector, W_H_L: jtp.Matrix): +def sphere_plane( + terrain: jaxsim.terrain.Terrain, size: jtp.Vector, W_H_L: jtp.Matrix +) -> tuple[jtp.Float, jtp.Matrix]: """ Detect contacts between a sphere and a plane terrain. @@ -51,11 +54,16 @@ def sphere_plane(terrain: jaxsim.terrain.Terrain, size: jtp.Vector, W_H_L: jtp.M W_H_C = _contact_frame(normal, position) + # Pad distance and transform to match expected output shapes. + distance = jnp.pad(jnp.array([distance]), (0, 2), mode="empty") + W_H_C = jnp.pad(W_H_C[jnp.newaxis, ...], ((0, 2), (0, 0), (0, 0)), mode="empty") + return distance, W_H_C -# TODO (flferretti): Keep only the SDF version? -def box_plane_sdf(terrain, size, W_H_L): +def box_plane( + terrain: jaxsim.terrain.Terrain, size: jtp.Vector, W_H_L: jtp.Matrix +) -> tuple[jtp.Vector, jtp.Matrix]: """ Return distances and contact frames of the 3 deepest corners of a box on terrain using SDF. Fully vectorized, works for any box orientation. @@ -74,9 +82,9 @@ def box_plane_sdf(terrain, size, W_H_L): [xs.ravel(), ys.ravel(), zs.ravel()], axis=1 ) # shape (8,3) - box_z_world = R[:, 2] - flip_sign = jnp.sign(box_z_world) - R_corrected = R.at[:, 2].set(R[:, 2] * flip_sign) # flip z-axis if needed + # Project box z-axis on terrain normal and ensure direction away from plane + sign = jnp.sign(R[:, 2] + 1e-12) + R_corrected = R.at[:, 2].set(R[:, 2] * sign) # Transform to world frame corners_world = center + corners_local @ R_corrected.T @@ -109,112 +117,96 @@ def box_plane_sdf(terrain, size, W_H_L): return distances_top3, W_H_C -def box_plane( - terrain: jaxsim.terrain.Terrain, - size: jtp.Vector, - W_H_L: jtp.Matrix, -): +def cylinder_plane( + terrain: jaxsim.terrain.Terrain, size: jtp.Vector, W_H_L: jtp.Matrix +) -> tuple[jtp.Vector, jtp.Matrix]: """ - Detect contacts between a box and a plane terrain. - Finds the actual contact point on the box surface (vertex, edge, or face). + Return distances and contact frames of the 3 deepest points of a cylinder on terrain. Args: - terrain: The terrain object with _height(x, y) method and normal(x, y) method. - size: A 3D vector [width, height, depth] representing the box dimensions from center. + terrain: The terrain object. + size: The size of the cylinder (radius, height). W_H_L: The collision shape transform in world coordinates. Returns: - A tuple containing the distance from the box to the plane and the pose transform - of the contact frame. + A tuple containing the distances from the cylinder to the plane and the pose transforms + of the contact frames. """ - half_size = size.squeeze() / 2 - center = W_H_L[:3, 3] - R = W_H_L[:3, :3] - - # Transform terrain normal at box center into world coordinates - normal = terrain.normal(center[0], center[1]) - # Find the box vertex furthest in the opposite direction of terrain normal - local_normal = R.T @ normal - support_local = -half_size * jnp.sign(local_normal) - - # Vertex in world coordinates - support_world = center + R @ support_local - - # Terrain point and distance - terrain_z = terrain.height(support_world[0], support_world[1]) - terrain_point = jnp.array([support_world[0], support_world[1], terrain_z]) - distance = jnp.dot(support_world - terrain_point, normal) - - # Contact frame - contact_point = support_world - distance * normal - W_H_C = _contact_frame(normal, contact_point) - - return distance, W_H_C - - -def cylinder_plane( - terrain: jaxsim.terrain.Terrain, - size: jtp.Vector, - W_H_L: jtp.Matrix, -): - """ - Compact cylinder-plane contact detection returning distance and SE(3) contact frame. - """ size = size.squeeze() - radius, half_height = size[0], size[1] / 2.0 + r, half_h = size[0], size[1] * 0.5 - # Cylinder center and axis - center = W_H_L[0:3, 3] - axis = W_H_L[0:3, 2] / jnp.linalg.norm(W_H_L[0:3, 2]) - - # Terrain properties at center - n = terrain.normal(center[0], center[1]) - - # Axis projection and perpendicular direction - axis_dot_n = jnp.dot(axis, n) - perp = jnp.cross(axis, n) - perp_norm = jnp.linalg.norm(perp) - perp = jnp.where(perp_norm > 1e-6, perp / perp_norm, W_H_L[0:3, 0]) - - # Three potential contact points - cap_offset = axis * half_height * jnp.sign(axis_dot_n) - edge_offset = perp * radius - - contacts = jnp.array( - [ - center + cap_offset, - center + edge_offset + axis * half_height, - center + edge_offset - axis * half_height, - ] + # Cylinder pose + position = W_H_L[:3, 3] + R = W_H_L[:3, :3] + axis = R[:, 2] + + # Terrain data at cylinder XY + h = terrain.height(position[0], position[1]) + n = terrain.normal(position[0], position[1]) + plane_position = jnp.array([position[0], position[1], h]) + + # Project axis on normal and ensure direction away from plane + prjaxis = jnp.dot(n, axis) + sign = -jnp.sign(prjaxis + 1e-12) + axis, prjaxis = axis * sign, prjaxis * sign + + # Distance from cylinder centre to plane along normal + dist0 = jnp.dot(position - plane_position, n) + + # Remove component along normal from axis + vec = axis * prjaxis - n + len_vec = jnp.linalg.norm(vec) + vec = jnp.where( + len_vec < 1e-12, + R[:, 0] * r, # disk parallel to plane + vec / len_vec * r, # general case ) - # Vectorized terrain height computation - terrain_heights = jax.vmap(terrain.height)(contacts[:, 0], contacts[:, 1]) - terrain_points = jnp.column_stack([contacts[:, :2], terrain_heights]) - - # Compute distances - distances = jnp.sum((contacts - terrain_points) * n, axis=1) - - # Select best contact based on axis alignment and minimum distance - abs_dot = jnp.abs(axis_dot_n) - weights = jnp.where( - abs_dot > 0.8, # Face contact - use cap - jnp.array([1.0, 0.0, 0.0]), - jnp.where( - abs_dot < 0.3, # Edge contact - use minimum edge distance - jnp.array([0.0, 1.0, 1.0]), - jnp.array([1.0, 0.0, 0.0]), # Corner contact - use cap - ), + # Project vec along normal + prjvec = jnp.dot(vec, n) + + # Scale axis by half length + ax_scaled = axis * half_h + prjaxis_h = prjaxis * half_h + + # Sideways vector for 3-point support + prjvec1 = -0.5 * prjvec + vec1 = jnp.cross(vec, ax_scaled) + vec1 = vec1 / (jnp.linalg.norm(vec1) + 1e-12) * r * jnp.sqrt(3.0) * 0.5 + + # Distances of three candidate contacts: + # d1 = top + # d2 = side + top + # d3 = side - top + d1 = dist0 + prjaxis_h + prjvec + d2 = dist0 + prjaxis_h + prjvec1 + dist = jnp.array([d1, d2, d2]) + + # World position of candidates + position_c = ( + position + + ax_scaled + + jnp.array( + [ + vec - n * d1 * 0.5, + vec1 + vec * -0.5 - n * d2 * 0.5, + -vec1 + vec * -0.5 - n * d2 * 0.5, + ] + ) ) - # Find minimum valid distance - valid_distances = jnp.where(weights > 0, distances, jnp.inf) - min_idx = jnp.argmin(valid_distances) - - distance = distances[min_idx] - contact_pos = contacts[min_idx] + # Handle case in which the cylinder lies on the disks + condition = jnp.abs(prjaxis) < 1e-3 + d3 = dist0 - prjaxis_h + prjvec + dist = jnp.where(condition, dist.at[1].set(d3), dist) + position_c = jnp.where( + condition, + position_c.at[1].set(position + vec - ax_scaled - n * d3 * 0.5), + position_c, + ) - W_H_C = _contact_frame(n, contact_pos) + # Build contact frames on the three candidate points + W_H_C = jax.vmap(lambda p: _contact_frame(n, p))(position_c) - return distance, W_H_C + return dist, W_H_C diff --git a/src/jaxsim/utils/__init__.py b/src/jaxsim/utils/__init__.py index b0287d409..647c64747 100644 --- a/src/jaxsim/utils/__init__.py +++ b/src/jaxsim/utils/__init__.py @@ -1,13 +1,13 @@ +from typing import ClassVar + from jax_dataclasses._copy_and_mutate import _Mutability as Mutability from .jaxsim_dataclass import JaxsimDataclass from .tracing import not_tracing, tracing from .wrappers import HashedNumpyArray, HashlessObject -from typing import ClassVar - -# TODO (flferretti): Definetely not the best place for this +# TODO (flferretti): Definitely not the best place for this class CollidableShapeType: """ Enum representing the types of collidable shapes. From da43c63b6d2ee1a5ffe085130b7ae2d567293cb2 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 17 Sep 2025 12:19:55 +0200 Subject: [PATCH 14/39] Update `compute_penetration_data` return variables --- src/jaxsim/rbda/contacts/common.py | 70 ++++++++++++++++++++---------- 1 file changed, 46 insertions(+), 24 deletions(-) diff --git a/src/jaxsim/rbda/contacts/common.py b/src/jaxsim/rbda/contacts/common.py index 1f9452f56..b9d49273a 100644 --- a/src/jaxsim/rbda/contacts/common.py +++ b/src/jaxsim/rbda/contacts/common.py @@ -1,23 +1,21 @@ from __future__ import annotations import abc -import functools import jax import jax.numpy as jnp import jaxsim.api as js -import jaxsim.terrain import jaxsim.typing as jtp -from jaxsim.math import STANDARD_GRAVITY -from jaxsim.utils import JaxsimDataclass, CollidableShapeType +from jaxsim.math import STANDARD_GRAVITY, Skew +from jaxsim.utils import CollidableShapeType, JaxsimDataclass try: from typing import Self except ImportError: from typing_extensions import Self -from .detection import sphere_plane, box_plane, cylinder_plane +from .detection import box_plane, cylinder_plane, sphere_plane MAX_STIFFNESS = 1e6 MAX_DAMPING = 1e4 @@ -30,39 +28,63 @@ } -@functools.partial(jax.jit, static_argnames=("terrain",)) +@jax.jit def compute_penetration_data( - p: jtp.VectorLike, - v: jtp.VectorLike, - terrain: jaxsim.terrain.Terrain, - contact_parameters: js.kin_dyn_parameters.ContactParameters | None = None, + model: js.model.JaxSimModel, + *, + shape_type: CollidableShapeType, + shape_size: jtp.Vector, + link_transforms: jtp.Matrix, + link_velocities: jtp.Matrix, ) -> tuple[jtp.Float, jtp.Float, jtp.Vector]: """ Compute the penetration data (depth, rate, and terrain normal) of a collidable point. Args: - p: The position of the collidable point. - v: - The linear velocity of the point (linear component of the mixed 6D velocity - of the implicit frame `C = (W_p_C, [W])` associated to the point). - terrain: The considered terrain. - contact_parameters: The parameters of the collidable shapes. + model: The model to consider. + shape_type: The type of the collidable shape. + shape_size: The size parameters of the collidable shape. + link_transforms: The transforms from the world frame to each link. + link_velocities: The linear and angular velocities of each link. Returns: A tuple containing the penetration depth, the penetration velocity, - and the considered terrain normal. + the terrain normal, the contact point position, and the contact point velocity + expressed in mixed representation. """ - # Pre-process the position and the linear velocity of the collidable point. - distance_fn = _COLLISION_MAP[contact_parameters.shape_type] + W_H_L, W_ṗ_L = link_transforms, link_velocities - δ, W_H_C = distance_fn( - terrain=terrain, - size=contact_parameters.shape_size, - center=contact_parameters.center, + # Pre-process the position and the linear velocity of the collidable point. + # Note that we consider 3 candidate contact points also for spherical shapes, + # in which the output is padded with zeros. + # This is to allow parallel evaluation of the collision types. + δ, W_H_C = jax.lax.switch( + shape_type, + (sphere_plane, box_plane, cylinder_plane), + model.terrain, + shape_size, + W_H_L, ) - return δ, W_H_C + W_p_C = W_H_C[:, :3, 3] + n̂ = W_H_C[:, :3, 2] + + def process_shape_kinematics(W_p_Ci: jtp.Vector) -> jtp.Vector: + + # Compute the velocity of the contact points. + CW_ṗ_Ci = jnp.block([jnp.eye(3), -Skew.wedge(vector=W_p_Ci).squeeze()]) @ W_ṗ_L + + return CW_ṗ_Ci + + CW_ṗ_C = jax.vmap(process_shape_kinematics)(W_p_C) + + δ = jnp.maximum(0.0, -δ) + + δ̇ = -jax.vmap(jnp.dot)(CW_ṗ_C, n̂) + δ̇ = jnp.where(δ > 0, δ̇, 0.0) + + return δ, δ̇, n̂, W_p_C, CW_ṗ_C class ContactsParams(JaxsimDataclass): From 2d822a2e8d021145577be38e9df8cf65dd468dc5 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 17 Sep 2025 12:23:40 +0200 Subject: [PATCH 15/39] Update logic for enabling contact forces computation --- src/jaxsim/api/ode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index c4ea9a56d..25f14dd20 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -54,7 +54,7 @@ def system_acceleration( W_f_L_terrain = jnp.zeros_like(f_L) contact_state_derivative = {} - if len(model.kin_dyn_parameters.contact_parameters.body) > 0: + if len(model.kin_dyn_parameters.contact_parameters.center) > 0: # Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact # with the terrain. From 6fefe15ad1e48458c85bf139be3c82520ecf73b2 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 17 Sep 2025 12:27:30 +0200 Subject: [PATCH 16/39] Polish SDF implementations --- src/jaxsim/rbda/contacts/detection.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/jaxsim/rbda/contacts/detection.py b/src/jaxsim/rbda/contacts/detection.py index ae6f938c6..0c6b2180b 100644 --- a/src/jaxsim/rbda/contacts/detection.py +++ b/src/jaxsim/rbda/contacts/detection.py @@ -55,6 +55,7 @@ def sphere_plane( W_H_C = _contact_frame(normal, position) # Pad distance and transform to match expected output shapes. + # and allow parallel evaluation of the collision types. distance = jnp.pad(jnp.array([distance]), (0, 2), mode="empty") W_H_C = jnp.pad(W_H_C[jnp.newaxis, ...], ((0, 2), (0, 0), (0, 0)), mode="empty") @@ -83,7 +84,7 @@ def box_plane( ) # shape (8,3) # Project box z-axis on terrain normal and ensure direction away from plane - sign = jnp.sign(R[:, 2] + 1e-12) + sign = jnp.sign(R[:, 2]) R_corrected = R.at[:, 2].set(R[:, 2] * sign) # Transform to world frame @@ -176,9 +177,6 @@ def cylinder_plane( vec1 = vec1 / (jnp.linalg.norm(vec1) + 1e-12) * r * jnp.sqrt(3.0) * 0.5 # Distances of three candidate contacts: - # d1 = top - # d2 = side + top - # d3 = side - top d1 = dist0 + prjaxis_h + prjvec d2 = dist0 + prjaxis_h + prjvec1 dist = jnp.array([d1, d2, d2]) @@ -190,8 +188,8 @@ def cylinder_plane( + jnp.array( [ vec - n * d1 * 0.5, - vec1 + vec * -0.5 - n * d2 * 0.5, - -vec1 + vec * -0.5 - n * d2 * 0.5, + vec1 + vec * 0.5 + n * d2 * 0.5, + -vec1 + vec * 0.5 + n * d2 * 0.5, ] ) ) From 444d9157a53fc18f33034e4363604b607e0c14ba Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 17 Sep 2025 12:28:29 +0200 Subject: [PATCH 17/39] Vectorize Hunt-Crossley model over shape types --- src/jaxsim/rbda/contacts/soft.py | 102 ++++++++++++++++--------------- 1 file changed, 54 insertions(+), 48 deletions(-) diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index 9b492dd65..db182f486 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -1,7 +1,6 @@ from __future__ import annotations import dataclasses -import functools import jax import jax.numpy as jnp @@ -11,7 +10,6 @@ import jaxsim.math import jaxsim.typing as jtp from jaxsim import logging -from jaxsim.terrain import Terrain from . import common @@ -230,15 +228,12 @@ def hunt_crossley_contact_model( material deformation. """ + # Use symbols for input parameters. + W_ṗ_C = velocity + m = tangential_deformation δ = penetration δ̇ = penetration_rate n̂ = normal - - # Convert the input vectors to arrays. - W_ṗ_C = jnp.array(velocity, dtype=float).squeeze() - m = jnp.array(tangential_deformation, dtype=float).squeeze() - - # Use symbol for the static friction. μ = mu # There are few operations like computing the norm of a vector with zero length @@ -341,53 +336,63 @@ def hunt_crossley_contact_model( return CW_fl, ṁ @staticmethod - @functools.partial(jax.jit, static_argnames=("terrain",)) + @jax.jit def compute_contact_force( - position: jtp.VectorLike, - velocity: jtp.VectorLike, - tangential_deformation: jtp.VectorLike, + penetration: jtp.Float, + penetration_rate: jtp.Float, + position: jtp.Vector, + velocity: jtp.Vector, + normal: jtp.Vector, + tangential_deformation: jtp.Vector, parameters: SoftContactsParams, - terrain: Terrain, ) -> tuple[jtp.Vector, jtp.Vector]: """ Compute the contact force. Args: - position: The position of the collidable shape. - velocity: The velocity of the collidable shape. + penetration: The penetration of the collision point. + penetration_rate: The penetration rate of the collision point. + position: The position of the contact point. + velocity: The velocity of the contact point. + normal: The terrain normal at the contact point. tangential_deformation: The material deformation of the collidable shape. parameters: The parameters of the soft contacts model. - terrain: The terrain model. Returns: A tuple containing the computed contact force and the derivative of the material deformation. """ - CW_fl, ṁ = SoftContacts.hunt_crossley_contact_model( - position=position, - velocity=velocity, - tangential_deformation=tangential_deformation, - terrain=terrain, - K=parameters.K, - D=parameters.D, - mu=parameters.mu, - p=parameters.p, - q=parameters.q, + CW_fl, ṁ = jax.vmap( + SoftContacts.hunt_crossley_contact_model, + in_axes=(0, 0, 0, 0, None, None, None, None, None, None), + )( + penetration, + penetration_rate, + velocity, + normal, + tangential_deformation, + parameters.K, + parameters.D, + parameters.mu, + parameters.p, + parameters.q, ) # Pack a mixed 6D force. - CW_f = jnp.hstack([CW_fl, jnp.zeros(3)]) + CW_f = jax.vmap(lambda f: jnp.hstack([f, jnp.zeros(3)]))(f=CW_fl) # Compute the 6D force transform from the mixed to the inertial-fixed frame. - W_Xf_CW = jaxsim.math.Adjoint.from_quaternion_and_translation( - translation=jnp.array(position), inverse=True - ).T + W_Xf_CW = jax.vmap( + lambda W_p_C: jaxsim.math.Adjoint.from_quaternion_and_translation( + translation=jnp.array(W_p_C), inverse=True + ).T + )(W_p_C=position) # Compute the 6D force in the inertial-fixed frame. - W_f = W_Xf_CW @ CW_f + W_f = jnp.einsum("...ij,...j->...i", W_Xf_CW, CW_f) - return W_f, ṁ + return jnp.sum(W_f, axis=0), jnp.mean(ṁ, axis=0) @staticmethod @jax.jit @@ -410,29 +415,30 @@ def compute_contact_forces( # Compute the position and linear velocities (mixed representation) of # all the collidable shapes belonging to the robot and extract the ones # for the enabled collidable shapes. - W_p_C, W_ṗ_C = js.contact.collidable_shape_kinematics(model=model, data=data) + δ, δ̇, n̂, W_p_C, CW_ṗ_C = jax.vmap( + common.compute_penetration_data, in_axes=(None,) + )( + model, + shape_type=model.kin_dyn_parameters.contact_parameters.shape_type, + shape_size=model.kin_dyn_parameters.contact_parameters.shape_size, + link_transforms=data._link_transforms, + link_velocities=data._link_velocities, + ) # Extract the material deformation corresponding to the collidable shapes. - m = ( - data.contact_state["tangential_deformation"] - if "tangential_deformation" in data.contact_state - else jnp.zeros_like(W_p_C) - ) + m = data.contact_state["tangential_deformation"] # Initialize the tangential deformation rate array for every collidable shape. ṁ = jnp.zeros_like(m) - # Compute the contact forces only for the enabled collidable shapes. + # Compute the contact forces for all the collidable shapes. # Since we treat them as independent, we can vmap the computation. + # We exploit two levels of vmap to vectorize over both the shapes and the points. + # The outer vmap vectorizes over the shapes, while the inner vmap vectorizes + # over the maximum points (3) belonging to each shape. W_f, ṁ = jax.vmap( - lambda p, v, m: SoftContacts.compute_contact_force( - center_position=p, - center_velocity=v, - tangential_deformation=m, - parameters=model.contact_params, - terrain=model.terrain, - size=model.kin_dyn_parameters.contact_parameters.shape_size, - ) - )(W_p_C, W_ṗ_C, m) + SoftContacts.compute_contact_force, + in_axes=(0, 0, 0, 0, 0, 0, None), # vectorize over shapes + )(δ, δ̇, W_p_C, CW_ṗ_C, n̂, m, model.contact_params) return W_f, {"m_dot": ṁ} From 26b3c8864bea165d78a163e7cf808261c2a42bdd Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 17 Sep 2025 12:54:12 +0200 Subject: [PATCH 18/39] Match `CollidableShapeType` enum with `LinkParametrizableShape` --- src/jaxsim/api/kin_dyn_parameters.py | 2 +- src/jaxsim/rbda/contacts/common.py | 2 +- src/jaxsim/utils/__init__.py | 8 +++++--- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index 4bdec21ea..41059bfd4 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -826,7 +826,7 @@ def build_from(model_description: ModelDescription) -> ContactParameters: ) shape_type = [ - _COLLISION_SHAPE_MAP[type(shape)] + _COLLISION_SHAPE_MAP.get(type(shape), CollidableShapeType.Unsupported) for shape in model_description.collision_shapes ] shape_type = jnp.array(shape_type, dtype=int) diff --git a/src/jaxsim/rbda/contacts/common.py b/src/jaxsim/rbda/contacts/common.py index b9d49273a..9ba880274 100644 --- a/src/jaxsim/rbda/contacts/common.py +++ b/src/jaxsim/rbda/contacts/common.py @@ -61,7 +61,7 @@ def compute_penetration_data( # This is to allow parallel evaluation of the collision types. δ, W_H_C = jax.lax.switch( shape_type, - (sphere_plane, box_plane, cylinder_plane), + (box_plane, cylinder_plane, sphere_plane), model.terrain, shape_size, W_H_L, diff --git a/src/jaxsim/utils/__init__.py b/src/jaxsim/utils/__init__.py index 647c64747..67c7bcb98 100644 --- a/src/jaxsim/utils/__init__.py +++ b/src/jaxsim/utils/__init__.py @@ -1,3 +1,4 @@ +import dataclasses from typing import ClassVar from jax_dataclasses._copy_and_mutate import _Mutability as Mutability @@ -8,12 +9,13 @@ # TODO (flferretti): Definitely not the best place for this +@dataclasses.dataclass(frozen=True) class CollidableShapeType: """ Enum representing the types of collidable shapes. """ - Sphere: ClassVar[int] = 0 - Box: ClassVar[int] = 1 - Cylinder: ClassVar[int] = 2 Unsupported: ClassVar[int] = -1 + Box: ClassVar[int] = 0 + Cylinder: ClassVar[int] = 1 + Sphere: ClassVar[int] = 2 From cec6d50cd2da6aa95f176f028c9e0a6cc19038fe Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 17 Sep 2025 17:30:51 +0200 Subject: [PATCH 19/39] Update `contact.transforms` for collidable shapes --- src/jaxsim/api/contact.py | 39 +++++++++++++++++++-------------------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 5bdfcfb19..82de74549 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -10,7 +10,7 @@ import jaxsim.typing as jtp from jaxsim import logging from jaxsim.math import Adjoint, Cross, Transform -from jaxsim.rbda.contacts import SoftContacts +from jaxsim.rbda.contacts import SoftContacts, detection from .common import VelRepr @@ -225,34 +225,33 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt The stacked SE(3) matrices of all enabled collidable points. Note: + The output shape is (nL, 3, 4, 4), where nL is the number of links. + Three candidate contact points are considered for each collidable shape. Each collidable point is implicitly associated with a frame :math:`C = ({}^W p_C, [L])`, where :math:`{}^W p_C` is the position of the collidable point and :math:`[L]` is the orientation frame of the link it is rigidly attached to. """ - # Get the indices of the enabled collidable points. - indices_of_enabled_collidable_shapes = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_shapes - ) - - parent_link_idx_of_enabled_collidable_shapes = jnp.array( - model.kin_dyn_parameters.contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_shapes] - # Get the transforms of the parent link of all collidable points. - W_H_L = data._link_transforms[parent_link_idx_of_enabled_collidable_shapes] - - L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[ - indices_of_enabled_collidable_shapes - ] + W_H_L = data._link_transforms + + def _process_single_shape(shape_type, shape_size, W_H_Li): + _, W_H_C = jax.lax.switch( + shape_type, + (detection.box_plane, detection.cylinder_plane, detection.sphere_plane), + model.terrain, + shape_size, + W_H_Li, + ) - # Build the link-to-point transform from the displacement between the link frame L - # and the implicit contact frame C. - L_H_C = jax.vmap(jnp.eye(4).at[0:3, 3].set)(L_p_Ci) + return W_H_C - # Compose the work-to-link and link-to-point transforms. - return jax.vmap(lambda W_H_Li, L_H_Ci: W_H_Li @ L_H_Ci)(W_H_L, L_H_C) + return jax.vmap(_process_single_shape)( + model.kin_dyn_parameters.contact_parameters.shape_type, + model.kin_dyn_parameters.contact_parameters.shape_size, + W_H_L, + ) @functools.partial(jax.jit, static_argnames=["output_vel_repr"]) From 54cbb67133a59cac4ef221771688a6ae11afa025 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 17 Sep 2025 17:31:21 +0200 Subject: [PATCH 20/39] Update `contact.jacobians` for collidable shapes --- src/jaxsim/api/contact.py | 65 ++++++++++++--------------------------- 1 file changed, 20 insertions(+), 45 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 82de74549..492638924 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -282,69 +282,44 @@ def jacobian( rigidly attached to. """ - output_vel_repr = ( - output_vel_repr if output_vel_repr is not None else data.velocity_representation - ) - - # Get the indices of the enabled collidable points. - indices_of_enabled_collidable_shapes = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_shapes - ) - - parent_link_idx_of_enabled_collidable_shapes = jnp.array( - model.kin_dyn_parameters.contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_shapes] + output_vel_repr = output_vel_repr or data.velocity_representation - # Compute the Jacobians of all links. + # Compute link-level Jacobians (n_links, 6, 6+n) W_J_WL = js.model.generalized_free_floating_jacobian( model=model, data=data, output_vel_repr=VelRepr.Inertial ) - # Compute the contact Jacobian. - # In inertial-fixed output representation, the Jacobian of the parent link is also - # the Jacobian of the frame C implicitly associated with the collidable point. - W_J_WC = W_J_WL[parent_link_idx_of_enabled_collidable_shapes] + # Compute contact transforms (n_links, n_contacts, 4, 4) + W_H_C = transforms(model=model, data=data) - # Adjust the output representation. + # Flatten link × contact axes for single-batch processing (n_links*n_contacts, 6, 6+n) + W_J_WC_flat = jnp.repeat(W_J_WL, 3, axis=0) + + # Flatten contact transforms (n_links*n_contacts, 4, 4) + W_H_C_flat = W_H_C.reshape(-1, 4, 4) + + # Transform Jacobian based on velocity representation match output_vel_repr: case VelRepr.Inertial: - O_J_WC = W_J_WC + return W_J_WC_flat case VelRepr.Body: - W_H_C = transforms(model=model, data=data) - - def body_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: - C_X_W = jaxsim.math.Adjoint.from_transform( - transform=W_H_C, inverse=True - ) - C_J_WC = C_X_W @ W_J_WC - return C_J_WC - - O_J_WC = jax.vmap(body_jacobian)(W_H_C, W_J_WC) + def transform_jacobian(H_C, J_WC): + return jaxsim.math.Adjoint.from_transform(H_C, inverse=True) @ J_WC case VelRepr.Mixed: - W_H_C = transforms(model=model, data=data) - - def mixed_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: - - W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3)) - - CW_X_W = jaxsim.math.Adjoint.from_transform( - transform=W_H_CW, inverse=True - ) - - CW_J_WC = CW_X_W @ W_J_WC - return CW_J_WC - - O_J_WC = jax.vmap(mixed_jacobian)(W_H_C, W_J_WC) + def transform_jacobian(H_C, J_WC): + H_CW = H_C.at[0:3, 0:3].set(jnp.eye(3)) + return jaxsim.math.Adjoint.from_transform(H_CW, inverse=True) @ J_WC case _: - raise ValueError(output_vel_repr) + raise ValueError(f"Unsupported velocity representation: {output_vel_repr}") - return O_J_WC + # Single vmap over all contact points + return jax.vmap(transform_jacobian)(W_H_C_flat, W_J_WC_flat) @functools.partial(jax.jit, static_argnames=["output_vel_repr"]) From bad1f693848400c8a5c399e6542f35bb244f1206 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 17 Sep 2025 17:31:39 +0200 Subject: [PATCH 21/39] Update `contact.jacobian_derivatives` for collidable shapes --- src/jaxsim/api/contact.py | 122 ++++++++++++-------------------------- 1 file changed, 39 insertions(+), 83 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 492638924..eef4be8a8 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -347,39 +347,24 @@ def jacobian_derivative( velocity representation. """ - output_vel_repr = ( - output_vel_repr if output_vel_repr is not None else data.velocity_representation - ) - - indices_of_enabled_collidable_shapes = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_shapes - ) - - # Get the index of the parent link and the position of the collidable point. - parent_link_idx_of_enabled_collidable_shapes = jnp.array( - model.kin_dyn_parameters.contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_shapes] - - L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[ - indices_of_enabled_collidable_shapes - ] - - # Get the transforms of all the parent links. - W_H_Li = data._link_transforms + output_vel_repr = output_vel_repr or data.velocity_representation # Get the link velocities. - W_v_WLi = data._link_velocities + W_v_WL = data._link_velocities + + # Compute the contact transforms (n_links, n_contacts, 4, 4) + W_H_C = transforms(model=model, data=data) # ===================================================== # Compute quantities to adjust the input representation # ===================================================== - def compute_T(model: js.model.JaxSimModel, X: jtp.Matrix) -> jtp.Matrix: + def compute_T(X: jtp.Matrix) -> jtp.Matrix: In = jnp.eye(model.dofs()) T = jax.scipy.linalg.block_diag(X, In) return T - def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: + def compute_Ṫ(Ẋ: jtp.Matrix) -> jtp.Matrix: On = jnp.zeros(shape=(model.dofs(), model.dofs())) Ṫ = jax.scipy.linalg.block_diag(Ẋ, On) return Ṫ @@ -388,38 +373,23 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: # time derivative. match data.velocity_representation: case VelRepr.Inertial: - W_H_W = jnp.eye(4) - W_X_W = Adjoint.from_transform(transform=W_H_W) - W_Ẋ_W = jnp.zeros((6, 6)) - - T = compute_T(model=model, X=W_X_W) - Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_W) - + W_X = Adjoint.from_transform(jnp.eye(4)) + W_Ẋ = jnp.zeros((6, 6)) case VelRepr.Body: - W_H_B = data._base_transform - W_X_B = Adjoint.from_transform(transform=W_H_B) - B_v_WB = data.base_velocity - B_vx_WB = Cross.vx(B_v_WB) - W_Ẋ_B = W_X_B @ B_vx_WB - - T = compute_T(model=model, X=W_X_B) - Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_B) - + W_X = Adjoint.from_transform(data.base_transform) + W_Ẋ = W_X @ Cross.vx(data.base_velocity) case VelRepr.Mixed: - W_H_B = data._base_transform - W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) - W_X_BW = Adjoint.from_transform(transform=W_H_BW) - BW_v_WB = data.base_velocity - BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) - BW_vx_W_BW = Cross.vx(BW_v_W_BW) - W_Ẋ_BW = W_X_BW @ BW_vx_W_BW - - T = compute_T(model=model, X=W_X_BW) - Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_BW) - + H_BW = data.base_transform.at[0:3, 0:3].set(jnp.eye(3)) + X_BW = Adjoint.from_transform(H_BW) + v_BW = data.base_velocity.at[3:6].set(0) + W_X = X_BW + W_Ẋ = X_BW @ Cross.vx(v_BW) case _: raise ValueError(data.velocity_representation) + T = compute_T(W_X) + Ṫ = compute_Ṫ(W_Ẋ) + # ===================================================== # Compute quantities to adjust the output representation # ===================================================== @@ -436,51 +406,37 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: data=data, ) - def compute_O_J̇_WC_I( - L_p_C: jtp.Vector, - parent_link_idx: jtp.Int, - W_H_L: jtp.Matrix, - ) -> jtp.Matrix: - + def compute_O_J̇_WC_I(W_H_C, W_v_WL, W_J_WL_W, W_J̇_WL_W) -> jtp.Matrix: match output_vel_repr: case VelRepr.Inertial: - O_X_W = W_X_W = Adjoint.from_transform( # noqa: F841 - transform=jnp.eye(4) - ) - O_Ẋ_W = W_Ẋ_W = jnp.zeros((6, 6)) # noqa: F841 - + O_X_W = jnp.eye(6) + O_Ẋ_W = jnp.zeros((6, 6)) case VelRepr.Body: - L_H_C = Transform.from_rotation_and_translation(translation=L_p_C) - W_H_C = W_H_L[parent_link_idx] @ L_H_C - O_X_W = C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True) - W_v_WC = W_v_WLi[parent_link_idx] - W_vx_WC = Cross.vx(W_v_WC) - O_Ẋ_W = C_Ẋ_W = -C_X_W @ W_vx_WC # noqa: F841 - + O_X_W = Adjoint.from_transform(W_H_C, inverse=True) + O_Ẋ_W = -O_X_W @ Cross.vx(W_v_WL) case VelRepr.Mixed: - L_H_C = Transform.from_rotation_and_translation(translation=L_p_C) - W_H_C = W_H_L[parent_link_idx] @ L_H_C W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3)) - CW_H_W = Transform.inverse(W_H_CW) - O_X_W = CW_X_W = Adjoint.from_transform(transform=CW_H_W) - CW_v_WC = CW_X_W @ W_v_WLi[parent_link_idx] - W_v_W_CW = jnp.zeros(6).at[0:3].set(CW_v_WC[0:3]) - W_vx_W_CW = Cross.vx(W_v_W_CW) - O_Ẋ_W = CW_Ẋ_W = -CW_X_W @ W_vx_W_CW # noqa: F841 - + O_X_W = Adjoint.from_transform(Transform.inverse(W_H_CW)) + v_CW = O_X_W @ W_v_WL + O_Ẋ_W = -O_X_W @ Cross.vx(v_CW.at[:3].set(v_CW[:3])) case _: raise ValueError(output_vel_repr) - O_J̇_WC_I = jnp.zeros(shape=(6, 6 + model.dofs())) - O_J̇_WC_I += O_Ẋ_W @ W_J_WL_W[parent_link_idx] @ T - O_J̇_WC_I += O_X_W @ W_J̇_WL_W[parent_link_idx] @ T - O_J̇_WC_I += O_X_W @ W_J_WL_W[parent_link_idx] @ Ṫ + O_J̇_WC_I = O_Ẋ_W @ W_J_WL_W @ T + O_J̇_WC_I += O_X_W @ W_J̇_WL_W @ T + O_J̇_WC_I += O_X_W @ W_J_WL_W @ Ṫ return O_J̇_WC_I - O_J̇_WC = jax.vmap(compute_O_J̇_WC_I, in_axes=(0, 0, None))( - L_p_Ci, parent_link_idx_of_enabled_collidable_shapes, W_H_Li - ) + O_J̇_per_link = jax.vmap( + lambda H_C_link, v_WL_link, J_WL_link, J̇_WL_link: jax.vmap( + compute_O_J̇_WC_I, + in_axes=(0, None, None, None), # Map over contacts for H_C only + )(H_C_link, v_WL_link, J_WL_link, J̇_WL_link), + in_axes=(0, 0, 0, 0), # Map over links + )(W_H_C, W_v_WL, W_J_WL_W, W_J̇_WL_W) + + O_J̇_WC = O_J̇_per_link.reshape(-1, 6, 6 + model.dofs()) return O_J̇_WC From 43524fe1fe94cc456ab2fdd8f063119d1ba1665a Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 17 Sep 2025 17:32:16 +0200 Subject: [PATCH 22/39] Update `RelaxedRigid` contact model for collidable shapes --- src/jaxsim/rbda/contacts/relaxed_rigid.py | 122 +++++++++++----------- 1 file changed, 60 insertions(+), 62 deletions(-) diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index 903473b04..1f4ca4e80 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -1,6 +1,7 @@ from __future__ import annotations import dataclasses +import functools from collections.abc import Callable from typing import Any @@ -14,7 +15,7 @@ import jaxsim.typing as jtp from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr -from . import common, soft +from . import common, detection, soft try: from typing import Self @@ -325,14 +326,16 @@ def compute_contact_forces( joint_force_references=joint_force_references, ) - # Compute the position and linear velocities (mixed representation) of - # all collidable points belonging to the robot. - position, velocity = data._link_transforms[:3, 3], data._link_velocities[:3] - # Compute the penetration depth and velocity of the collidable points. # Note that this function considers the penetration in the normal direction. - δ, _, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))( - position, velocity, model.terrain, data.contact_parameters + δ, δ̇, n̂, W_p_C, CW_ṗ_C = jax.vmap( + common.compute_penetration_data, in_axes=(None,) + )( + model, + shape_type=model.kin_dyn_parameters.contact_parameters.shape_type, + shape_size=model.kin_dyn_parameters.contact_parameters.shape_size, + link_transforms=data._link_transforms, + link_velocities=data._link_velocities, ) # Compute the position in the constraint frame. @@ -342,13 +345,16 @@ def compute_contact_forces( a_ref, r, *_ = self._regularizers( model=model, position_constraint=position_constraint, - velocity_constraint=velocity, + velocity_constraint=CW_ṗ_C, parameters=model.contact_params, ) # Compute the transforms of the implicit frames corresponding to the # collidable points. - W_H_C = js.contact.transforms(model=model, data=data) + # The final shape will be (n_links, 3 (max_contact_points), 4, 4). + W_H_C = jax.vmap( + lambda n, p: jax.vmap(detection._contact_frame)(n, p), + )(n̂, W_p_C) with ( data.switch_velocity_representation(VelRepr.Mixed), @@ -370,15 +376,17 @@ def compute_contact_forces( # Compute the linear part of the Jacobian of the collidable points Jl_WC = jnp.vstack( jax.vmap(lambda J, δ: J * (δ > 0))( - js.contact.jacobian(model=model, data=data)[:, :3, :], δ + js.contact.jacobian(model=model, data=data)[:, :3], + jnp.concatenate(δ), ) ) # Compute the linear part of the Jacobian derivative of the collidable points J̇l_WC = jnp.vstack( jax.vmap(lambda J̇, δ: J̇ * (δ > 0))( - js.contact.jacobian_derivative(model=model, data=data)[:, :3], δ - ), + js.contact.jacobian_derivative(model=model, data=data)[:, :3], + jnp.concatenate(δ), + ) ) # Compute the Delassus matrix for contacts (mixed representation). @@ -466,20 +474,20 @@ def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool: # ====================================== # Initialize the optimized forces with a linear Hunt/Crossley model. + hunt_crossley_closure = functools.partial( + soft.SoftContacts.hunt_crossley_contact_model, + K=1e6, + D=2e3, + p=0.5, + q=0.5, + mu=0.0, + tangential_deformation=jnp.zeros(3), + ) + init_params = jax.vmap( - lambda p, v: soft.SoftContacts.hunt_crossley_contact_model( - position=p, - velocity=v, - terrain=model.terrain, - K=1e6, - D=2e3, - p=0.5, - q=0.5, - # No tangential initial forces. - mu=0.0, - tangential_deformation=jnp.zeros(3), - )[0] - )(position, velocity).flatten() + jax.vmap(hunt_crossley_closure, in_axes=(0, 0, 0, 0)), # map over contacts + in_axes=(0, 0, 0, 0), # map over links + )(δ, δ̇, CW_ṗ_C, n̂)[0].flatten() # Get the solver options. solver_options = self.solver_options @@ -507,21 +515,26 @@ def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool: ) # Reshape the optimized solution to be a matrix of 3D contact forces. - CW_fl_C = solution.reshape(-1, 3) - - # Convert the contact forces from mixed to inertial-fixed representation. - W_f_C = jax.vmap( - lambda CW_fl_C, W_H_C: ( - ModelDataWithVelocityRepresentation.other_representation_to_inertial( - array=jnp.zeros(6).at[0:3].set(CW_fl_C), - transform=W_H_C, - other_representation=VelRepr.Mixed, - is_force=True, - ) - ), - )(CW_fl_C, W_H_C) + CW_fl_per_link = solution.reshape(-1, 3, 3) + + # Transform each contact force to inertial frame + def to_inertial(force, H_C): + return ModelDataWithVelocityRepresentation.other_representation_to_inertial( + array=jnp.zeros(6).at[0:3].set(force), + transform=H_C, + other_representation=VelRepr.Mixed, + is_force=True, + ) - return W_f_C, {} + # Compute the contact forces in inertial representation for + # each link and contact point. + # Nested vmap: inner over contacts, outer over links + W_f_C = jax.vmap(lambda f_link, H_link: jax.vmap(to_inertial)(f_link, H_link))( + CW_fl_per_link, W_H_C + ) + + # Sum over contacts for each link + return W_f_C.sum(axis=1), {} @staticmethod def _regularizers( @@ -561,17 +574,8 @@ def _regularizers( ) ) - # Get the indices of the enabled collidable points. - indices_of_enabled_collidable_points = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points - ) - - parent_link_idx_of_enabled_collidable_points = jnp.array( - model.kin_dyn_parameters.contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_points] - # Compute the 6D inertia matrices of all links. - M_L = js.model.link_spatial_inertia_matrices(model=model) + M_L = js.model.link_spatial_inertia_matrices(model=model)[:, :3, :3] def imp_aref( pos: jtp.Vector, @@ -619,7 +623,7 @@ def imp_aref( def compute_row( *, - link_idx: jtp.Int, + M_Li: jtp.Matrix, pos: jtp.Vector, vel: jtp.Vector, ) -> tuple[jtp.Vector, jtp.Matrix, jtp.Vector, jtp.Vector]: @@ -628,11 +632,7 @@ def compute_row( ξ, a_ref, K, D = imp_aref(pos=pos, vel=vel) # Compute the regularization term. - R = ( - (2 * μ**2 * (1 - ξ) / (ξ + 1e-12)) - * (1 + μ**2) - @ jnp.linalg.inv(M_L[link_idx, :3, :3]) - ) + R = (2 * μ**2 * (1 - ξ) / (ξ + 1e-12)) * (1 + μ**2) @ jnp.linalg.inv(M_Li) # Return the computed values, setting them to zero in case of no contact. is_active = (pos.dot(pos) > 0).astype(float) @@ -641,13 +641,11 @@ def compute_row( ) a_ref, R, K, D = jax.tree.map( - f=jnp.concatenate, - tree=( - *jax.vmap(compute_row)( - link_idx=parent_link_idx_of_enabled_collidable_points, - pos=position_constraint, - vel=velocity_constraint, - ), + jnp.ravel, + jax.vmap(compute_row)( + M_Li=M_L, + pos=position_constraint, + vel=velocity_constraint, ), ) From 5fd741d3480ca4ab421559e0e073388280bb8e35 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 17 Sep 2025 17:41:20 +0200 Subject: [PATCH 23/39] Remove mesh collision support --- docs/guide/configuration.rst | 4 -- environment.yml | 1 - pixi.lock | 4 +- pyproject.toml | 4 +- src/jaxsim/parsers/rod/meshes.py | 104 ------------------------------- src/jaxsim/parsers/rod/utils.py | 12 ---- tests/test_meshes.py | 103 ------------------------------ 7 files changed, 3 insertions(+), 229 deletions(-) delete mode 100644 src/jaxsim/parsers/rod/meshes.py delete mode 100644 tests/test_meshes.py diff --git a/docs/guide/configuration.rst b/docs/guide/configuration.rst index 4061b30ff..993160f73 100644 --- a/docs/guide/configuration.rst +++ b/docs/guide/configuration.rst @@ -13,10 +13,6 @@ Environment variables starting with ``JAXSIM_COLLISION_`` are used to configure *Default:* ``50``. -- ``JAXSIM_COLLISION_MESH_ENABLED``: Enables or disables mesh-based collision detection. - - *Default:* ``False``. - - ``JAXSIM_COLLISION_USE_BOTTOM_ONLY``: Limits collision detection to only the bottom half of the box or sphere. *Default:* ``False``. diff --git a/environment.yml b/environment.yml index 6cec4bb33..446c77766 100644 --- a/environment.yml +++ b/environment.yml @@ -15,7 +15,6 @@ dependencies: - pptree - qpax - rod >= 0.3.3 - - trimesh - typing_extensions # python<3.12 # ==================================== # Optional dependencies from setup.cfg diff --git a/pixi.lock b/pixi.lock index 7f3c7f097..0ea6d36c1 100644 --- a/pixi.lock +++ b/pixi.lock @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4d9f48d5ba0b3f575a7b54563023b9a20cb3e483691a2db4c92689b2fc1ba862 -size 513426 +oid sha256:8785a62069c16d3c4f74af043c2bf75ffd743b7630bc24ce283c4b5c7454f2a3 +size 513965 diff --git a/pyproject.toml b/pyproject.toml index 55f5123b4..de19003b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,6 @@ dependencies = [ "qpax", "rod >= 0.4.1", "typing_extensions ; python_version < '3.12'", - "trimesh", ] [project.optional-dependencies] @@ -224,8 +223,7 @@ jax-dataclasses = "*" pptree = "*" optax = "*" qpax = "*" -rod = ">=0.4.1" -trimesh = "*" +rod = "*" typing_extensions = "*" # # Optional dependencies. diff --git a/src/jaxsim/parsers/rod/meshes.py b/src/jaxsim/parsers/rod/meshes.py deleted file mode 100644 index 3679597e8..000000000 --- a/src/jaxsim/parsers/rod/meshes.py +++ /dev/null @@ -1,104 +0,0 @@ -import numpy as np -import trimesh - -VALID_AXIS = {"x": 0, "y": 1, "z": 2} - - -def extract_points_vertices(mesh: trimesh.Trimesh) -> np.ndarray: - """ - Extract the vertices of a mesh as points. - """ - return mesh.vertices - - -def extract_points_random_surface_sampling(mesh: trimesh.Trimesh, n) -> np.ndarray: - """ - Extract N random points from the surface of a mesh. - - Args: - mesh: The mesh from which to extract points. - n: The number of points to extract. - - Returns: - The extracted points (N x 3 array). - """ - - return mesh.sample(n) - - -def extract_points_uniform_surface_sampling( - mesh: trimesh.Trimesh, n: int -) -> np.ndarray: - """ - Extract N uniformly sampled points from the surface of a mesh. - - Args: - mesh: The mesh from which to extract points. - n: The number of points to extract. - - Returns: - The extracted points (N x 3 array). - """ - - return trimesh.sample.sample_surface_even(mesh=mesh, count=n)[0] - - -def extract_points_select_points_over_axis( - mesh: trimesh.Trimesh, axis: str, direction: str, n: int -) -> np.ndarray: - """ - Extract N points from a mesh along a specified axis. The points are selected based on their position along the axis. - - Args: - mesh: The mesh from which to extract points. - axis: The axis along which to extract points. - direction: The direction along the axis from which to extract points. Valid values are "higher" and "lower". - n: The number of points to extract. - - Returns: - The extracted points (N x 3 array). - """ - - dirs = {"higher": np.s_[-n:], "lower": np.s_[:n]} - arr = mesh.vertices - - # Sort rows lexicographically first, then columnar. - arr.sort(axis=0) - sorted_arr = arr[dirs[direction]] - return sorted_arr - - -def extract_points_aap( - mesh: trimesh.Trimesh, - axis: str, - upper: float | None = None, - lower: float | None = None, -) -> np.ndarray: - """ - Extract points from a mesh along a specified axis within a specified range. The points are selected based on their position along the axis. - - Args: - mesh: The mesh from which to extract points. - axis: The axis along which to extract points. - upper: The upper bound of the range. - lower: The lower bound of the range. - - Returns: - The extracted points (N x 3 array). - - Raises: - AssertionError: If the lower bound is greater than the upper bound. - """ - - # Check bounds. - upper = upper if upper is not None else np.inf - lower = lower if lower is not None else -np.inf - assert lower < upper, "Invalid bounds for axis-aligned plane" - - # Logic. - points = mesh.vertices[ - (mesh.vertices[:, VALID_AXIS[axis]] >= lower) - & (mesh.vertices[:, VALID_AXIS[axis]] <= upper) - ] - - return points diff --git a/src/jaxsim/parsers/rod/utils.py b/src/jaxsim/parsers/rod/utils.py index 85a3d0f20..851385cb7 100644 --- a/src/jaxsim/parsers/rod/utils.py +++ b/src/jaxsim/parsers/rod/utils.py @@ -1,21 +1,9 @@ -import os -import pathlib -from collections.abc import Callable -from typing import TypeVar - import numpy as np -import numpy.typing as npt import rod -import trimesh -from rod.utils.resolve_uris import resolve_local_uri import jaxsim.typing as jtp -from jaxsim import logging from jaxsim.math import Adjoint, Inertia from jaxsim.parsers import descriptions -from jaxsim.parsers.rod import meshes - -MeshMappingMethod = TypeVar("MeshMappingMethod", bound=Callable[..., npt.NDArray]) def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix: diff --git a/tests/test_meshes.py b/tests/test_meshes.py deleted file mode 100644 index d9bd66dcc..000000000 --- a/tests/test_meshes.py +++ /dev/null @@ -1,103 +0,0 @@ -import trimesh - -from jaxsim.parsers.rod import meshes - - -def test_mesh_wrapping_vertex_extraction(): - """ - Test the vertex extraction method on different meshes. - - 1. A simple box. - 2. A sphere. - """ - - # Test 1: A simple box. - # First, create a box with origin at (0,0,0) and extents (3,3,3), - # i.e. points span from -1.5 to 1.5 on the axis. - mesh = trimesh.creation.box( - extents=[3.0, 3.0, 3.0], - ) - points = meshes.extract_points_vertices(mesh=mesh) - assert len(points) == len(mesh.vertices) - - # Test 2: A sphere. - # The sphere is centered at the origin and has a radius of 1.0. - mesh = trimesh.creation.icosphere(subdivisions=4, radius=1.0) - points = meshes.extract_points_vertices(mesh=mesh) - assert len(points) == len(mesh.vertices) - - -def test_mesh_wrapping_aap(): - """ - Test the AAP wrapping method on different meshes. - - 1. A simple box - 1.1: Remove all points above x=0.0 - 1.2: Remove all points below y=0.0 - 2. A sphere - """ - - # Test 1.1: Remove all points above x=0.0. - # The expected result is that the number of points is halved. - # First, create a box with origin at (0,0,0) and extents (3,3,3), - # i.e. points span from -1.5 to 1.5 on the axis. - mesh = trimesh.creation.box(extents=[3.0, 3.0, 3.0]) - points = meshes.extract_points_aap(mesh=mesh, axis="x", lower=0.0) - assert len(points) == len(mesh.vertices) // 2 - assert all(points[:, 0] > 0.0) - - # Test 1.2: Remove all points below y=0.0. - # The expected result is that the number of points is halved. - points = meshes.extract_points_aap(mesh=mesh, axis="y", upper=0.0) - assert len(points) == len(mesh.vertices) // 2 - assert all(points[:, 1] < 0.0) - - # Test 2: A sphere. - # The sphere is centered at the origin and has a radius of 1.0. - # Points are expected to be halved. - mesh = trimesh.creation.icosphere(subdivisions=4, radius=1.0) - - # Remove all points above y=0.0. - points = meshes.extract_points_aap(mesh=mesh, axis="y", lower=0.0) - assert all(points[:, 1] >= 0.0) - assert len(points) < len(mesh.vertices) - - -def test_mesh_wrapping_points_over_axis(): - """ - Test the points over axis method on different meshes. - - 1. A simple box - 1.1: Select 10 points from the lower end of the x-axis - 1.2: Select 10 points from the higher end of the y-axis - 2. A sphere - """ - - # Test 1.1: Remove 10 points from the lower end of the x-axis. - # First, create a box with origin at (0,0,0) and extents (3,3,3), - # i.e. points span from -1.5 to 1.5 on the axis. - mesh = trimesh.creation.box(extents=[3.0, 3.0, 3.0]) - points = meshes.extract_points_select_points_over_axis( - mesh=mesh, axis="x", direction="lower", n=4 - ) - assert len(points) == 4 - assert all(points[:, 0] < 0.0) - - # Test 1.2: Select 10 points from the higher end of the y-axis. - points = meshes.extract_points_select_points_over_axis( - mesh=mesh, axis="y", direction="higher", n=4 - ) - assert len(points) == 4 - assert all(points[:, 1] > 0.0) - - # Test 2: A sphere. - # The sphere is centered at the origin and has a radius of 1.0. - mesh = trimesh.creation.icosphere(subdivisions=4, radius=1.0) - sphere_n_vertices = len(mesh.vertices) - - # Select 10 points from the higher end of the z-axis. - points = meshes.extract_points_select_points_over_axis( - mesh=mesh, axis="z", direction="higher", n=sphere_n_vertices // 2 - ) - assert len(points) == sphere_n_vertices // 2 - assert all(points[:, 2] >= 0.0) From b1e054a678a2d9dae51485d26c675f4dea5c289e Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 18 Sep 2025 19:22:44 +0200 Subject: [PATCH 24/39] Fix contact API test --- tests/test_api_contact.py | 87 +++++++++++++++++---------------------- 1 file changed, 38 insertions(+), 49 deletions(-) diff --git a/tests/test_api_contact.py b/tests/test_api_contact.py index cae9e5070..d341fe21b 100644 --- a/tests/test_api_contact.py +++ b/tests/test_api_contact.py @@ -1,5 +1,10 @@ import jax import jax.numpy as jnp +<<<<<<< HEAD +======= +import numpy as np +import pytest +>>>>>>> 91f80b4 (Fix contact API test) import rod import jaxsim.api as js @@ -23,8 +28,6 @@ def test_contact_kinematics( velocity_representation=velocity_representation, ) - parent_link_idx_of_collidable_shapes = - # ===== # Tests # ===== @@ -32,27 +35,16 @@ def test_contact_kinematics( # Compute the pose of the implicit contact frame associated to the collidable shapes # and the transforms of all links. W_H_C = js.contact.transforms(model=model, data=data) - W_H_L = data._link_transforms - - # Check that the orientation of the implicit contact frame matches with the - # orientation of the link to which the contact shape is attached. - for contact_idx, index_of_parent_link in enumerate( - parent_link_idx_of_collidable_shapes - ): - assert_allclose( - W_H_C[contact_idx, 0:3, 0:3], W_H_L[index_of_parent_link][0:3, 0:3] - ) - + # Check that the origin of the implicit contact frame is located over the - # collidable point. - W_p_C = js.contact.collidable_shape_positions(model=model, data=data) - - assert_allclose(W_p_C, W_H_C[:, 0:3, 3]) + # collidable shape. + W_p_C = js.contact.contact_point_positions(model=model, data=data) + assert_allclose(W_p_C, W_H_C[:, :, 0:3, 3]) # Compute the velocity of the collidable shape. # This quantity always matches with the linear component of the mixed 6D velocity # of the implicit frame associated to the collidable shape. - W_ṗ_C = js.contact.collidable_shape_velocities(model=model, data=data) + W_ṗ_C = js.contact.contact_point_velocities(model=model, data=data) # Compute the velocity of the collidable shape using the contact Jacobian. ν = data.generalized_velocity @@ -60,10 +52,10 @@ def test_contact_kinematics( CW_vl_WC = jnp.einsum("c6g,g->c6", CW_J_WC, ν)[:, 0:3] # Compare the two velocities. - assert_allclose(W_ṗ_C, CW_vl_WC) + assert_allclose(jnp.contatenate(W_ṗ_C), CW_vl_WC) -def test_collidable_shape_jacobians( +def test_contact_point_jacobians( jaxsim_models_types: js.model.JaxSimModel, velocity_representation: VelRepr, prng_key: jax.Array, @@ -83,7 +75,7 @@ def test_collidable_shape_jacobians( # Compute the velocity of the collidable shapes with a RBDA. # This function always returns the linear part of the mixed velocity of the # implicit frame C corresponding to the collidable shape. - W_ṗ_C = js.contact.collidable_shape_velocities(model=model, data=data) + W_ṗ_C = js.contact.contact_point_velocities(model=model, data=data) # Compute the generalized velocity and the free-floating Jacobian of the frame C. ν = data.generalized_velocity @@ -92,7 +84,7 @@ def test_collidable_shape_jacobians( # Compute the velocity of the collidable shapes using the Jacobians. v_WC_from_jax = jax.vmap(lambda J, ν: J @ ν, in_axes=(0, None))(CW_J_WC, ν) - assert_allclose(W_ṗ_C, v_WC_from_jax[:, 0:3]) + assert_allclose(jnp.concatenate(W_ṗ_C), v_WC_from_jax[:, 0:3]) def test_contact_jacobian_derivative( @@ -110,22 +102,16 @@ def test_contact_jacobian_derivative( velocity_representation=velocity_representation, ) - # Get the indices of the enabled collidable shapes. - indices_of_enabled_collidable_shapes = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_shapes - ) + W_H_L = data._link_transforms + W_p_C = js.contact.contact_point_positions(model=model, data=data) - # Extract the parent link names and the poses of the contact shapes. - parent_link_names = js.link.idxs_to_names( - model=model, - link_indices=jnp.array( - model.kin_dyn_parameters.contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_shapes], + # Vectorize over the 3 points for one link + transform_points = jax.vmap( + lambda H, p: H @ jnp.hstack([p, 1.0]), in_axes=(None, 0) ) - L_p_Ci = model.kin_dyn_parameters.contact_parameters.shape[ - indices_of_enabled_collidable_shapes - ] + # Vectorize over the links + L_p_Ci = jax.vmap(transform_points, in_axes=(0, 0))(W_H_L, W_p_C)[..., :3] # ===== # Tests @@ -135,18 +121,22 @@ def test_contact_jacobian_derivative( rod_model = rod.Sdf.load(sdf=model.built_from).model # Add dummy frames on the contact shapes. - for idx, link_name, L_p_C in zip( - indices_of_enabled_collidable_shapes, parent_link_names, L_p_Ci, strict=True + + for idx, link_name, points in zip( + np.arange(model.number_of_links()), model.link_names(), L_p_Ci, strict=True ): - rod_model.add_frame( - frame=rod.Frame( - name=f"contact_shape_{idx}", - attached_to=link_name, - pose=rod.Pose( - relative_to=link_name, pose=jnp.zeros(shape=(6,)).at[0:3].set(L_p_C) + # points: shape (3, 3) for this link + for j, p in enumerate(points): + rod_model.add_frame( + frame=rod.Frame( + name=f"contact_shape_{idx}_{j}", + attached_to=link_name, + pose=rod.Pose( + relative_to=link_name, + pose=jnp.zeros((6,)).at[0:3].set(p), + ), ), - ), - ) + ) # Rebuild the JaxSim model. model_with_frames = js.model.JaxSimModel.build_from_model_description( @@ -172,13 +162,12 @@ def test_contact_jacobian_derivative( frame_idxs = js.frame.names_to_idxs( model=model_with_frames, frame_names=( - f"contact_shape_{idx}" for idx in indices_of_enabled_collidable_shapes + f"contact_shape_{idx}_{j}" + for idx in np.arange(model.number_of_links()) + for j in range(3) ), ) - # Check that the number of frames is correct. - assert len(frame_idxs) == len(parent_link_names) - # Compute the contact Jacobian derivative. J̇_WC = js.contact.jacobian_derivative( model=model_with_frames, data=data_with_frames From d0380955f3d383169b300f3ae6b4c4e640d13d41 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 18 Sep 2025 19:24:29 +0200 Subject: [PATCH 25/39] Fix collisions parsing contact API --- src/jaxsim/api/contact.py | 52 +++++++++++++++------------- src/jaxsim/api/kin_dyn_parameters.py | 27 +++++++-------- src/jaxsim/parsers/rod/parser.py | 18 ++++++++-- 3 files changed, 56 insertions(+), 41 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index eef4be8a8..354f7b997 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -17,7 +17,7 @@ @jax.jit @js.common.named_scope -def collidable_shape_kinematics( +def contact_point_kinematics( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> tuple[jtp.Matrix, jtp.Matrix]: """ @@ -36,8 +36,12 @@ def collidable_shape_kinematics( the linear component of the mixed 6D frame velocity. """ - W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_shapes.collidable_shapes_pos_vel( - model=model, + _, _, _, W_p_Ci, W_ṗ_Ci = jax.vmap( + jaxsim.rbda.contacts.common.compute_penetration_data, in_axes=(None,) + )( + model, + shape_type=model.kin_dyn_parameters.contact_parameters.shape_type, + shape_size=model.kin_dyn_parameters.contact_parameters.shape_size, link_transforms=data._link_transforms, link_velocities=data._link_velocities, ) @@ -47,7 +51,7 @@ def collidable_shape_kinematics( @jax.jit @js.common.named_scope -def collidable_shape_positions( +def contact_point_positions( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Matrix: """ @@ -61,14 +65,14 @@ def collidable_shape_positions( The position of the collidable points in the world frame. """ - W_p_Ci, _ = collidable_shape_kinematics(model=model, data=data) + W_p_Ci, _ = contact_point_kinematics(model=model, data=data) return W_p_Ci @jax.jit @js.common.named_scope -def collidable_shape_velocities( +def contact_point_velocities( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Matrix: """ @@ -82,7 +86,7 @@ def collidable_shape_velocities( The 3D velocity of the collidable points. """ - _, W_ṗ_Ci = collidable_shape_kinematics(model=model, data=data) + _, W_ṗ_Ci = contact_point_kinematics(model=model, data=data) return W_ṗ_Ci @@ -112,15 +116,15 @@ def in_contact( raise ValueError("One or more link names are not part of the model") # Get the indices of the enabled collidable points. - indices_of_enabled_collidable_shapes = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_shapes + indices_of_enabled_contact_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_contact_points ) - parent_link_idx_of_enabled_collidable_shapes = jnp.array( + parent_link_idx_of_enabled_contact_points = jnp.array( model.kin_dyn_parameters.contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_shapes] + )[indices_of_enabled_contact_points] - W_p_Ci = collidable_shape_positions(model=model, data=data) + W_p_Ci = contact_point_positions(model=model, data=data) terrain_height = jax.vmap(lambda x, y: model.terrain.height(x=x, y=y))( W_p_Ci[:, 0], W_p_Ci[:, 1] @@ -136,7 +140,7 @@ def in_contact( links_in_contact = jax.vmap( lambda link_index: jnp.where( - parent_link_idx_of_enabled_collidable_shapes == link_index, + parent_link_idx_of_enabled_contact_points == link_index, below_terrain, jnp.zeros_like(below_terrain, dtype=bool), ).any() @@ -162,7 +166,7 @@ def estimate_good_contact_parameters( *, standard_gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY, static_friction_coefficient: jtp.FloatLike = 0.5, - number_of_active_collidable_shapes_steady_state: jtp.IntLike = 1, + number_of_active_contact_points_steady_state: jtp.IntLike = 1, damping_ratio: jtp.FloatLike = 1.0, max_penetration: jtp.FloatLike | None = None, ) -> jaxsim.rbda.contacts.ContactParamsTypes: @@ -173,7 +177,7 @@ def estimate_good_contact_parameters( model: The model to consider. standard_gravity: The standard gravity acceleration. static_friction_coefficient: The static friction coefficient. - number_of_active_collidable_shapes_steady_state: + number_of_active_contact_points_steady_state: The number of active collidable points in steady state. damping_ratio: The damping ratio. max_penetration: The maximum penetration allowed. @@ -194,19 +198,19 @@ def estimate_good_contact_parameters( zero_data = js.data.JaxSimModelData.build(model=model) W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2] if model.floating_base(): - W_pz_C = collidable_shape_positions(model=model, data=zero_data)[:, -1] + W_pz_C = contact_point_positions(model=model, data=zero_data)[:, -1] W_pz_CoM = W_pz_CoM - W_pz_C.min() # Consider as default a 1% of the model center of mass height. max_penetration = 0.01 * W_pz_CoM - nc = number_of_active_collidable_shapes_steady_state + nc = number_of_active_contact_points_steady_state return model.contact_model._parameters_class().build_default_from_jaxsim_model( model=model, standard_gravity=standard_gravity, static_friction_coefficient=static_friction_coefficient, max_penetration=max_penetration, - number_of_active_collidable_shapes_steady_state=nc, + number_of_active_contact_points_steady_state=nc, damping_ratio=damping_ratio, ) @@ -505,8 +509,8 @@ def link_forces_from_contact_forces( contact_parameters = model.kin_dyn_parameters.contact_parameters # Extract the indices corresponding to the enabled collidable points. - indices_of_enabled_collidable_shapes = ( - contact_parameters.indices_of_enabled_collidable_shapes + indices_of_enabled_contact_points = ( + contact_parameters.indices_of_enabled_contact_points ) # Convert the contact forces to a JAX array. @@ -515,13 +519,13 @@ def link_forces_from_contact_forces( # Construct the vector defining the parent link index of each collidable point. # We use this vector to sum the 6D forces of all collidable points rigidly # attached to the same link. - parent_link_index_of_collidable_shapes = jnp.array( - contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_shapes] + parent_link_index_of_contact_points = jnp.array(contact_parameters.body, dtype=int)[ + indices_of_enabled_contact_points + ] # Create the mask that associate each collidable point to their parent link. # We use this mask to sum the collidable points to the right link. - mask = parent_link_index_of_collidable_shapes[:, jnp.newaxis] == jnp.arange( + mask = parent_link_index_of_contact_points[:, jnp.newaxis] == jnp.arange( model.number_of_links() ) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index 41059bfd4..fb0630b5a 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -816,26 +816,25 @@ def build_from(model_description: ModelDescription) -> ContactParameters: if len(model_description.collision_shapes) == 0: return ContactParameters() + shape_types, shape_sizes, centers = [], [], [] + # Assume the link_parameters and the collision_shapes are in the same order. - centers = jnp.array( - [shape.center for shape in model_description.collision_shapes] - ) + for collision in model_description.collision_shapes: + shape_types.append( + _COLLISION_SHAPE_MAP.get( + type(collision), CollidableShapeType.Unsupported + ) + ) - shape_size = jnp.array( - [shape.size.squeeze() for shape in model_description.collision_shapes] - ) + shape_sizes.append(collision.size.squeeze()) - shape_type = [ - _COLLISION_SHAPE_MAP.get(type(shape), CollidableShapeType.Unsupported) - for shape in model_description.collision_shapes - ] - shape_type = jnp.array(shape_type, dtype=int) + centers.append(collision.center) # Build the ContactParameters object. return ContactParameters( - center=centers, - shape_type=shape_type, - shape_size=shape_size, + center=jnp.array(centers, dtype=float), + shape_type=jnp.array(shape_types, dtype=int), + shape_size=jnp.array(shape_sizes, dtype=float), ) diff --git a/src/jaxsim/parsers/rod/parser.py b/src/jaxsim/parsers/rod/parser.py index a4bfd45ed..87e710461 100644 --- a/src/jaxsim/parsers/rod/parser.py +++ b/src/jaxsim/parsers/rod/parser.py @@ -312,6 +312,8 @@ def extract_model_data( # Parse the collisions for link in sdf_model.links(): + # If a link has multiple collision shapes, we consider only the first + # supported one with priority box > sphere > cylinder. for collision in link.collisions(): if collision.geometry.box is not None: box_collision = utils.create_box_collision( @@ -320,7 +322,7 @@ def extract_model_data( ) collisions.append(box_collision) - continue + break if collision.geometry.sphere is not None: sphere_collision = utils.create_sphere_collision( @@ -329,7 +331,7 @@ def extract_model_data( ) collisions.append(sphere_collision) - continue + break if collision.geometry.cylinder is not None: cylinder_collision = utils.create_cylinder_collision( @@ -338,7 +340,7 @@ def extract_model_data( ) collisions.append(cylinder_collision) - continue + break # Check any remaining non-None geometry types. for attr_name in collision.geometry.__dict__: @@ -347,6 +349,16 @@ def extract_model_data( f"Skipping collision shape '{attr_name}' in link '{link.name}' as not supported." ) + else: + # Fill with unsupported collision shape + collisions.append( + descriptions.collision.CollisionShape( + center=jnp.array([0.0, 0.0, 0.0]), + size=jnp.array([0.0, 0.0, 0.0]), + parent_link=link.name, + ) + ) + return SDFData( model_name=sdf_model.name, link_descriptions=links, From f0de993fd17a1fd67e7f8888d9e2e247e99b088c Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 19 Sep 2025 11:30:27 +0200 Subject: [PATCH 26/39] Fix contact Jacobian derivative --- src/jaxsim/api/contact.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 354f7b997..a9ede1364 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -383,11 +383,11 @@ def compute_Ṫ(Ẋ: jtp.Matrix) -> jtp.Matrix: W_X = Adjoint.from_transform(data.base_transform) W_Ẋ = W_X @ Cross.vx(data.base_velocity) case VelRepr.Mixed: - H_BW = data.base_transform.at[0:3, 0:3].set(jnp.eye(3)) - X_BW = Adjoint.from_transform(H_BW) - v_BW = data.base_velocity.at[3:6].set(0) - W_X = X_BW - W_Ẋ = X_BW @ Cross.vx(v_BW) + W_H_BW = data.base_transform.at[0:3, 0:3].set(jnp.eye(3)) + W_X_BW = Adjoint.from_transform(W_H_BW) + BW_v_W_BW = data.base_velocity.at[3:6].set(0) + W_X = W_X_BW + W_Ẋ = W_X_BW @ Cross.vx(BW_v_W_BW) case _: raise ValueError(data.velocity_representation) @@ -400,14 +400,11 @@ def compute_Ṫ(Ẋ: jtp.Matrix) -> jtp.Matrix: with data.switch_velocity_representation(VelRepr.Inertial): # Compute the Jacobian of the parent link in inertial representation. - W_J_WL_W = js.model.generalized_free_floating_jacobian( - model=model, - data=data, - ) + W_J_WL_W = js.model.generalized_free_floating_jacobian(model=model, data=data) + # Compute the Jacobian derivative of the parent link in inertial representation. W_J̇_WL_W = js.model.generalized_free_floating_jacobian_derivative( - model=model, - data=data, + model=model, data=data ) def compute_O_J̇_WC_I(W_H_C, W_v_WL, W_J_WL_W, W_J̇_WL_W) -> jtp.Matrix: @@ -435,7 +432,7 @@ def compute_O_J̇_WC_I(W_H_C, W_v_WL, W_J_WL_W, W_J̇_WL_W) -> jtp.Matrix: O_J̇_per_link = jax.vmap( lambda H_C_link, v_WL_link, J_WL_link, J̇_WL_link: jax.vmap( compute_O_J̇_WC_I, - in_axes=(0, None, None, None), # Map over contacts for H_C only + in_axes=(0, None, None, None), # Map over contacts for W_H_C only )(H_C_link, v_WL_link, J_WL_link, J̇_WL_link), in_axes=(0, 0, 0, 0), # Map over links )(W_H_C, W_v_WL, W_J_WL_W, W_J̇_WL_W) From c0840d02c5a758820284da8e4b106b585f435d9e Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 19 Sep 2025 19:55:27 +0200 Subject: [PATCH 27/39] Consider collision shape offset wrt link center --- src/jaxsim/api/contact.py | 1 + src/jaxsim/parsers/rod/utils.py | 6 ++---- src/jaxsim/rbda/contacts/common.py | 5 +++++ src/jaxsim/rbda/contacts/relaxed_rigid.py | 1 + src/jaxsim/rbda/contacts/soft.py | 1 + 5 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index a9ede1364..278936784 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -40,6 +40,7 @@ def contact_point_kinematics( jaxsim.rbda.contacts.common.compute_penetration_data, in_axes=(None,) )( model, + shape_offset=model.kin_dyn_parameters.contact_parameters.center, shape_type=model.kin_dyn_parameters.contact_parameters.shape_type, shape_size=model.kin_dyn_parameters.contact_parameters.shape_size, link_transforms=data._link_transforms, diff --git a/src/jaxsim/parsers/rod/utils.py b/src/jaxsim/parsers/rod/utils.py index 851385cb7..e3b004ee0 100644 --- a/src/jaxsim/parsers/rod/utils.py +++ b/src/jaxsim/parsers/rod/utils.py @@ -103,15 +103,13 @@ def create_box_collision( x, y, z = collision.geometry.box.size - center = np.array([x / 2, y / 2, z / 2]) - H = collision.pose.transform() if collision.pose is not None else np.eye(4) - center_wrt_link = (H @ np.hstack([center, 1.0]))[0:-1] + center = H[:3, 3] return descriptions.BoxCollision( size=np.array([x, y, z]), - center=center_wrt_link, + center=center, parent_link=link_description.name, ) diff --git a/src/jaxsim/rbda/contacts/common.py b/src/jaxsim/rbda/contacts/common.py index 9ba880274..eb82f135c 100644 --- a/src/jaxsim/rbda/contacts/common.py +++ b/src/jaxsim/rbda/contacts/common.py @@ -32,6 +32,7 @@ def compute_penetration_data( model: js.model.JaxSimModel, *, + shape_offset: jtp.Vector, shape_type: CollidableShapeType, shape_size: jtp.Vector, link_transforms: jtp.Matrix, @@ -42,6 +43,7 @@ def compute_penetration_data( Args: model: The model to consider. + shape_offset: The offset of the collidable shape with respect to the link frame. shape_type: The type of the collidable shape. shape_size: The size parameters of the collidable shape. link_transforms: The transforms from the world frame to each link. @@ -55,6 +57,9 @@ def compute_penetration_data( W_H_L, W_ṗ_L = link_transforms, link_velocities + # Offset the collision shape origin. + W_H_L = W_H_L.at[:3, 3].set(W_H_L[:3, 3] + shape_offset @ W_H_L[:3, :3].T) + # Pre-process the position and the linear velocity of the collidable point. # Note that we consider 3 candidate contact points also for spherical shapes, # in which the output is padded with zeros. diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index 1f4ca4e80..d5f744f4b 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -332,6 +332,7 @@ def compute_contact_forces( common.compute_penetration_data, in_axes=(None,) )( model, + shape_offset=model.kin_dyn_parameters.contact_parameters.center, shape_type=model.kin_dyn_parameters.contact_parameters.shape_type, shape_size=model.kin_dyn_parameters.contact_parameters.shape_size, link_transforms=data._link_transforms, diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index db182f486..669d0a531 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -419,6 +419,7 @@ def compute_contact_forces( common.compute_penetration_data, in_axes=(None,) )( model, + shape_offset=model.kin_dyn_parameters.contact_parameters.center, shape_type=model.kin_dyn_parameters.contact_parameters.shape_type, shape_size=model.kin_dyn_parameters.contact_parameters.shape_size, link_transforms=data._link_transforms, From 76f72a7e516dca274a038001d18ca2619d0f663a Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 22 Oct 2025 15:27:50 +0200 Subject: [PATCH 28/39] Add `.DS_Store` to `.gitignore` --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 05520e402..1448c0216 100644 --- a/.gitignore +++ b/.gitignore @@ -151,3 +151,6 @@ src/jaxsim/_version.py # data .mp4 .png + +# macOS +.DS_Store From c32f2d610d5b543c5ee0f11df10600b3ab9675e4 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 24 Oct 2025 15:25:43 +0200 Subject: [PATCH 29/39] Save collision shape to link transform --- src/jaxsim/api/kin_dyn_parameters.py | 52 +++++++++++++++----- src/jaxsim/parsers/descriptions/collision.py | 52 ++++++-------------- src/jaxsim/parsers/rod/parser.py | 2 +- src/jaxsim/parsers/rod/utils.py | 12 ++--- src/jaxsim/rbda/contacts/common.py | 13 ++--- tests/test_api_contact.py | 38 ++++++++------ 6 files changed, 88 insertions(+), 81 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index fb0630b5a..d7dfe907d 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -783,24 +783,41 @@ class ContactParameters(JaxsimDataclass): Attributes: body: - A tuple of integers representing, for each collidable point, the index of - the body (link) to which it is rigidly attached to. - point: - The translations between the link frame and the collidable point, expressed - in the coordinates of the parent link frame. - enabled: - A tuple of booleans representing, for each collidable point, whether it is - enabled or not in contact models. + A tuple of integers representing, for each collision shape, the index of + the link to which it is rigidly attached to. + transform: + The 4x4 homogeneous transformation matrices representing the pose of each + collision shape with respect to the parent link frame. + shape_size: + The size parameters of each collidable shape. + shape_type: + The type of each collidable shape (sphere, box, cylinder, etc.). Note: Contrarily to LinkParameters and JointParameters, this class is not meant to be created with vmap. This is because the `body` attribute must be `Static`. """ - center: jtp.Vector = dataclasses.field(default_factory=lambda: jnp.array([])) + body: Static[tuple[int, ...]] = dataclasses.field(default_factory=tuple) + + transform: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.array([])) shape_size: jtp.Vector = dataclasses.field(default_factory=lambda: jnp.array([])) shape_type: jtp.Vector = dataclasses.field(default_factory=lambda: jnp.array([])) + @property + def center(self) -> jtp.Array: + """Extract translation vectors from transformation matrices.""" + if self.transform.size == 0: + return jnp.array([]) + return self.transform[:, :3, 3] + + @property + def orientation(self) -> jtp.Array: + """Extract rotation matrices from transformation matrices.""" + if self.transform.size == 0: + return jnp.array([]) + return self.transform[:, :3, :3] + @staticmethod def build_from(model_description: ModelDescription) -> ContactParameters: """ @@ -816,7 +833,12 @@ def build_from(model_description: ModelDescription) -> ContactParameters: if len(model_description.collision_shapes) == 0: return ContactParameters() - shape_types, shape_sizes, centers = [], [], [] + shape_types, shape_sizes, transforms, parent_link_indices = ( + [], + [], + [], + [], + ) # Assume the link_parameters and the collision_shapes are in the same order. for collision in model_description.collision_shapes: @@ -828,11 +850,17 @@ def build_from(model_description: ModelDescription) -> ContactParameters: shape_sizes.append(collision.size.squeeze()) - centers.append(collision.center) + transforms.append(collision.transform) + + # Get the parent link index for this collision shape. + parent_link_indices.append( + model_description.links_dict[collision.parent_link].index + ) # Build the ContactParameters object. return ContactParameters( - center=jnp.array(centers, dtype=float), + body=tuple(parent_link_indices), + transform=jnp.array(transforms, dtype=float), shape_type=jnp.array(shape_types, dtype=int), shape_size=jnp.array(shape_sizes, dtype=float), ) diff --git a/src/jaxsim/parsers/descriptions/collision.py b/src/jaxsim/parsers/descriptions/collision.py index 369285fb4..cf63f71e7 100644 --- a/src/jaxsim/parsers/descriptions/collision.py +++ b/src/jaxsim/parsers/descriptions/collision.py @@ -3,6 +3,8 @@ import dataclasses from abc import ABC +import numpy as np + import jaxsim.typing as jtp @@ -15,16 +17,16 @@ class CollisionShape(ABC): It is not intended to be instantiated directly. """ - center: jtp.VectorLike size: jtp.VectorLike parent_link: str + transform: jtp.MatrixLike = dataclasses.field(default_factory=lambda: np.eye(4)) def __hash__(self) -> int: return hash( ( - hash(tuple(self.center.tolist())), hash(tuple(self.size.tolist())), hash(self.parent_link), + hash(tuple(self.transform.flatten().tolist())), ) ) @@ -35,6 +37,16 @@ def __eq__(self, other: CollisionShape) -> bool: return hash(self) == hash(other) + @property + def center(self) -> jtp.Vector: + """Extract the translation from the transformation matrix.""" + return self.transform[:3, 3] + + @property + def orientation(self) -> jtp.Matrix: + """Extract the rotation matrix from the transformation matrix.""" + return self.transform[:3, :3] + @dataclasses.dataclass class BoxCollision(CollisionShape): @@ -42,30 +54,6 @@ class BoxCollision(CollisionShape): Represents a box-shaped collision shape. """ - @property - def x(self) -> float: - return self.size[0] - - @property - def y(self) -> float: - return self.size[1] - - @property - def z(self) -> float: - return self.size[2] - - @x.setter - def x(self, value: float) -> None: - self.size[0] = value - - @y.setter - def y(self, value: float) -> None: - self.size[1] = value - - @z.setter - def z(self, value: float) -> None: - self.size[2] = value - @dataclasses.dataclass class SphereCollision(CollisionShape): @@ -73,21 +61,9 @@ class SphereCollision(CollisionShape): Represents a spherical collision shape. """ - @property - def radius(self) -> float: - return self.size[0] - @dataclasses.dataclass class CylinderCollision(CollisionShape): """ Represents a cylindrical collision shape. """ - - @property - def radius(self) -> float: - return self.size[0] - - @property - def height(self) -> float: - return self.size[1] diff --git a/src/jaxsim/parsers/rod/parser.py b/src/jaxsim/parsers/rod/parser.py index 87e710461..719dd78d5 100644 --- a/src/jaxsim/parsers/rod/parser.py +++ b/src/jaxsim/parsers/rod/parser.py @@ -353,7 +353,7 @@ def extract_model_data( # Fill with unsupported collision shape collisions.append( descriptions.collision.CollisionShape( - center=jnp.array([0.0, 0.0, 0.0]), + transform=jnp.eye(4), size=jnp.array([0.0, 0.0, 0.0]), parent_link=link.name, ) diff --git a/src/jaxsim/parsers/rod/utils.py b/src/jaxsim/parsers/rod/utils.py index e3b004ee0..2a2eb0d63 100644 --- a/src/jaxsim/parsers/rod/utils.py +++ b/src/jaxsim/parsers/rod/utils.py @@ -105,11 +105,9 @@ def create_box_collision( H = collision.pose.transform() if collision.pose is not None else np.eye(4) - center = H[:3, 3] - return descriptions.BoxCollision( size=np.array([x, y, z]), - center=center, + transform=H, parent_link=link_description.name, ) @@ -132,11 +130,9 @@ def create_sphere_collision( H = collision.pose.transform() if collision.pose is not None else np.eye(4) - center_wrt_link = (H @ np.hstack([0, 0, 0, 1.0]))[0:-1] - return descriptions.SphereCollision( size=np.array([r] * 3), - center=center_wrt_link, + transform=H, parent_link=link_description.name, ) @@ -160,10 +156,8 @@ def create_cylinder_collision( H = collision.pose.transform() if collision.pose is not None else np.eye(4) - center_wrt_link = (H @ np.hstack([0, 0, 0, 1.0]))[0:-1] - return descriptions.CylinderCollision( size=np.array([r, l, 0]), - center=center_wrt_link, + transform=H, parent_link=link_description.name, ) diff --git a/src/jaxsim/rbda/contacts/common.py b/src/jaxsim/rbda/contacts/common.py index eb82f135c..daae5b2ba 100644 --- a/src/jaxsim/rbda/contacts/common.py +++ b/src/jaxsim/rbda/contacts/common.py @@ -32,7 +32,7 @@ def compute_penetration_data( model: js.model.JaxSimModel, *, - shape_offset: jtp.Vector, + shape_transform: jtp.Matrix, shape_type: CollidableShapeType, shape_size: jtp.Vector, link_transforms: jtp.Matrix, @@ -43,7 +43,7 @@ def compute_penetration_data( Args: model: The model to consider. - shape_offset: The offset of the collidable shape with respect to the link frame. + shape_transform: The 4x4 transform of the collidable shape with respect to the link frame. shape_type: The type of the collidable shape. shape_size: The size parameters of the collidable shape. link_transforms: The transforms from the world frame to each link. @@ -55,10 +55,11 @@ def compute_penetration_data( expressed in mixed representation. """ - W_H_L, W_ṗ_L = link_transforms, link_velocities + W_H_L, W_ṗ_L = link_transforms, link_velocities - # Offset the collision shape origin. - W_H_L = W_H_L.at[:3, 3].set(W_H_L[:3, 3] + shape_offset @ W_H_L[:3, :3].T) + # Apply the collision shape transform. + # This computes W_H_S where S is the collision shape frame. + W_H_S = W_H_L @ shape_transform # Pre-process the position and the linear velocity of the collidable point. # Note that we consider 3 candidate contact points also for spherical shapes, @@ -69,7 +70,7 @@ def compute_penetration_data( (box_plane, cylinder_plane, sphere_plane), model.terrain, shape_size, - W_H_L, + W_H_S, ) W_p_C = W_H_C[:, :3, 3] diff --git a/tests/test_api_contact.py b/tests/test_api_contact.py index d341fe21b..0fcc1997f 100644 --- a/tests/test_api_contact.py +++ b/tests/test_api_contact.py @@ -102,16 +102,24 @@ def test_contact_jacobian_derivative( velocity_representation=velocity_representation, ) - W_H_L = data._link_transforms + body_indices = np.array(model.kin_dyn_parameters.contact_parameters.body) + + # Get link transforms for each collision shape + W_H_L = data._link_transforms[body_indices] + + # Get contact point positions (shape: num_collision_shapes, 3, 3) W_p_C = js.contact.contact_point_positions(model=model, data=data) - # Vectorize over the 3 points for one link - transform_points = jax.vmap( - lambda H, p: H @ jnp.hstack([p, 1.0]), in_axes=(None, 0) - ) + # Transform contact points from world to link frame + # For each collision shape, transform its 3 contact points + def transform_to_link_frame(W_H_L_i, W_p_Ci): + """Transform 3 contact points from world to link frame.""" - # Vectorize over the links - L_p_Ci = jax.vmap(transform_points, in_axes=(0, 0))(W_H_L, W_p_C)[..., :3] + L_H_W = jnp.linalg.inv(W_H_L_i) + return jax.vmap(lambda p: (L_H_W @ jnp.hstack([p, 1.0]))[:3])(W_p_Ci) + + # Apply to all collision shapes: shape (num_collision_shapes, 3, 3) + L_p_Ci = jax.vmap(transform_to_link_frame)(W_H_L, W_p_C) # ===== # Tests @@ -120,16 +128,15 @@ def test_contact_jacobian_derivative( # Load the model in ROD. rod_model = rod.Sdf.load(sdf=model.built_from).model - # Add dummy frames on the contact shapes. - - for idx, link_name, points in zip( - np.arange(model.number_of_links()), model.link_names(), L_p_Ci, strict=True + for shape_idx, (link_idx, points) in enumerate( + zip(body_indices, L_p_Ci, strict=True) ): - # points: shape (3, 3) for this link + link_name = model.link_names()[link_idx] + for j, p in enumerate(points): rod_model.add_frame( frame=rod.Frame( - name=f"contact_shape_{idx}_{j}", + name=f"contact_shape_{shape_idx}_{j}", attached_to=link_name, pose=rod.Pose( relative_to=link_name, @@ -159,11 +166,12 @@ def test_contact_jacobian_derivative( ) # Extract the indexes of the frames attached to the contact shapes. + num_collision_shapes = len(model.kin_dyn_parameters.contact_parameters.body) frame_idxs = js.frame.names_to_idxs( model=model_with_frames, frame_names=( - f"contact_shape_{idx}_{j}" - for idx in np.arange(model.number_of_links()) + f"contact_shape_{shape_idx}_{j}" + for shape_idx in range(num_collision_shapes) for j in range(3) ), ) From 5078049591f3c53f1628d2fae320ace4044bb192 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 24 Oct 2025 15:25:58 +0200 Subject: [PATCH 30/39] Refactor contact models to use shape transforms --- src/jaxsim/api/contact.py | 76 ++++++++++++++++------- src/jaxsim/rbda/contacts/relaxed_rigid.py | 52 +++++++++++----- src/jaxsim/rbda/contacts/soft.py | 38 ++++++++---- 3 files changed, 117 insertions(+), 49 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 278936784..0f359dd9b 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -36,18 +36,28 @@ def contact_point_kinematics( the linear component of the mixed 6D frame velocity. """ - _, _, _, W_p_Ci, W_ṗ_Ci = jax.vmap( - jaxsim.rbda.contacts.common.compute_penetration_data, in_axes=(None,) + _, _, _, W_p_Ci, W_ṗ_Ci = jax.vmap( + lambda shape_transform, shape_type, shape_size, link_transform, link_velocity: jaxsim.rbda.contacts.common.compute_penetration_data( + model, + shape_transform=shape_transform, + shape_type=shape_type, + shape_size=shape_size, + link_transforms=link_transform, + link_velocities=link_velocity, + ) )( - model, - shape_offset=model.kin_dyn_parameters.contact_parameters.center, - shape_type=model.kin_dyn_parameters.contact_parameters.shape_type, - shape_size=model.kin_dyn_parameters.contact_parameters.shape_size, - link_transforms=data._link_transforms, - link_velocities=data._link_velocities, + model.kin_dyn_parameters.contact_parameters.transform, + model.kin_dyn_parameters.contact_parameters.shape_type, + model.kin_dyn_parameters.contact_parameters.shape_size, + data._link_transforms[ + jnp.array(model.kin_dyn_parameters.contact_parameters.body) + ], + data._link_velocities[ + jnp.array(model.kin_dyn_parameters.contact_parameters.body) + ], ) - return W_p_Ci, W_ṗ_Ci + return W_p_Ci, W_ṗ_Ci @jax.jit @@ -241,13 +251,20 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt # Get the transforms of the parent link of all collidable points. W_H_L = data._link_transforms - def _process_single_shape(shape_type, shape_size, W_H_Li): + # Index transforms by the body (parent link) of each collision shape + body_indices = jnp.array(model.kin_dyn_parameters.contact_parameters.body) + W_H_L_indexed = W_H_L[body_indices] + + def _process_single_shape(shape_type, shape_size, shape_transform, W_H_Li): + # Apply the collision shape transform to get W_H_S + W_H_S = W_H_Li @ shape_transform + _, W_H_C = jax.lax.switch( shape_type, (detection.box_plane, detection.cylinder_plane, detection.sphere_plane), model.terrain, shape_size, - W_H_Li, + W_H_S, ) return W_H_C @@ -255,7 +272,8 @@ def _process_single_shape(shape_type, shape_size, W_H_Li): return jax.vmap(_process_single_shape)( model.kin_dyn_parameters.contact_parameters.shape_type, model.kin_dyn_parameters.contact_parameters.shape_size, - W_H_L, + model.kin_dyn_parameters.contact_parameters.transform, + W_H_L_indexed, ) @@ -294,13 +312,17 @@ def jacobian( model=model, data=data, output_vel_repr=VelRepr.Inertial ) - # Compute contact transforms (n_links, n_contacts, 4, 4) + # Compute contact transforms (n_shapes, n_contacts_per_shape, 4, 4) W_H_C = transforms(model=model, data=data) - # Flatten link × contact axes for single-batch processing (n_links*n_contacts, 6, 6+n) - W_J_WC_flat = jnp.repeat(W_J_WL, 3, axis=0) + # Index Jacobians by the body (parent link) of each collision shape + body_indices = jnp.array(model.kin_dyn_parameters.contact_parameters.body) + W_J_WL_indexed = W_J_WL[body_indices] # (n_shapes, 6, 6+n) + + # Repeat for each contact point per shape: (n_shapes*n_contacts_per_shape, 6, 6+n) + W_J_WC_flat = jnp.repeat(W_J_WL_indexed, 3, axis=0) - # Flatten contact transforms (n_links*n_contacts, 4, 4) + # Flatten contact transforms (n_shapes*n_contacts_per_shape, 4, 4) W_H_C_flat = W_H_C.reshape(-1, 4, 4) # Transform Jacobian based on velocity representation @@ -357,7 +379,11 @@ def jacobian_derivative( # Get the link velocities. W_v_WL = data._link_velocities - # Compute the contact transforms (n_links, n_contacts, 4, 4) + # Index link velocities by body (parent link) of each collision shape + body_indices = jnp.array(model.kin_dyn_parameters.contact_parameters.body) + W_v_WL_indexed = W_v_WL[body_indices] # (n_shapes, 6) + + # Compute the contact transforms (n_shapes, n_contacts, 4, 4) W_H_C = transforms(model=model, data=data) # ===================================================== @@ -408,6 +434,10 @@ def compute_Ṫ(Ẋ: jtp.Matrix) -> jtp.Matrix: model=model, data=data ) + # Index Jacobians by body (parent link) of each collision shape + W_J_WL_W_indexed = W_J_WL_W[body_indices] # (n_shapes, 6, 6+n) + W_J̇_WL_W_indexed = W_J̇_WL_W[body_indices] # (n_shapes, 6, 6+n) + def compute_O_J̇_WC_I(W_H_C, W_v_WL, W_J_WL_W, W_J̇_WL_W) -> jtp.Matrix: match output_vel_repr: case VelRepr.Inertial: @@ -430,15 +460,15 @@ def compute_O_J̇_WC_I(W_H_C, W_v_WL, W_J_WL_W, W_J̇_WL_W) -> jtp.Matrix: return O_J̇_WC_I - O_J̇_per_link = jax.vmap( - lambda H_C_link, v_WL_link, J_WL_link, J̇_WL_link: jax.vmap( + O_J̇_per_shape = jax.vmap( + lambda H_C_shape, v_WL_shape, J_WL_shape, J̇_WL_shape: jax.vmap( compute_O_J̇_WC_I, in_axes=(0, None, None, None), # Map over contacts for W_H_C only - )(H_C_link, v_WL_link, J_WL_link, J̇_WL_link), - in_axes=(0, 0, 0, 0), # Map over links - )(W_H_C, W_v_WL, W_J_WL_W, W_J̇_WL_W) + )(H_C_shape, v_WL_shape, J_WL_shape, J̇_WL_shape), + in_axes=(0, 0, 0, 0), # Map over shapes + )(W_H_C, W_v_WL_indexed, W_J_WL_W_indexed, W_J̇_WL_W_indexed) - O_J̇_WC = O_J̇_per_link.reshape(-1, 6, 6 + model.dofs()) + O_J̇_WC = O_J̇_per_shape.reshape(-1, 6, 6 + model.dofs()) return O_J̇_WC diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index d5f744f4b..7334cea94 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -328,15 +328,25 @@ def compute_contact_forces( # Compute the penetration depth and velocity of the collidable points. # Note that this function considers the penetration in the normal direction. - δ, δ̇, n̂, W_p_C, CW_ṗ_C = jax.vmap( - common.compute_penetration_data, in_axes=(None,) + δ, δ̇, n̂, W_p_C, CW_ṗ_C = jax.vmap( + lambda shape_transform, shape_type, shape_size, link_transform, link_velocity: common.compute_penetration_data( + model, + shape_transform=shape_transform, + shape_type=shape_type, + shape_size=shape_size, + link_transforms=link_transform, + link_velocities=link_velocity, + ) )( - model, - shape_offset=model.kin_dyn_parameters.contact_parameters.center, - shape_type=model.kin_dyn_parameters.contact_parameters.shape_type, - shape_size=model.kin_dyn_parameters.contact_parameters.shape_size, - link_transforms=data._link_transforms, - link_velocities=data._link_velocities, + model.kin_dyn_parameters.contact_parameters.transform, + model.kin_dyn_parameters.contact_parameters.shape_type, + model.kin_dyn_parameters.contact_parameters.shape_size, + data._link_transforms[ + jnp.array(model.kin_dyn_parameters.contact_parameters.body) + ], + data._link_velocities[ + jnp.array(model.kin_dyn_parameters.contact_parameters.body) + ], ) # Compute the position in the constraint frame. @@ -346,7 +356,7 @@ def compute_contact_forces( a_ref, r, *_ = self._regularizers( model=model, position_constraint=position_constraint, - velocity_constraint=CW_ṗ_C, + velocity_constraint=CW_ṗ_C, parameters=model.contact_params, ) @@ -529,13 +539,21 @@ def to_inertial(force, H_C): # Compute the contact forces in inertial representation for # each link and contact point. - # Nested vmap: inner over contacts, outer over links - W_f_C = jax.vmap(lambda f_link, H_link: jax.vmap(to_inertial)(f_link, H_link))( - CW_fl_per_link, W_H_C + # Nested vmap: inner over contacts, outer over shapes + W_f_C = jax.vmap( + lambda f_shape, H_shape: jax.vmap(to_inertial)(f_shape, H_shape) + )(CW_fl_per_link, W_H_C) + + # Sum over contacts for each shape: (n_shapes, 6) + W_f_per_shape = W_f_C.sum(axis=1) + + # Accumulate forces by parent link using segment_sum + body_indices = jnp.array(model.kin_dyn_parameters.contact_parameters.body) + W_f_per_link = jax.ops.segment_sum( + W_f_per_shape, body_indices, num_segments=model.number_of_links() ) - # Sum over contacts for each link - return W_f_C.sum(axis=1), {} + return W_f_per_link, {} @staticmethod def _regularizers( @@ -576,7 +594,11 @@ def _regularizers( ) # Compute the 6D inertia matrices of all links. - M_L = js.model.link_spatial_inertia_matrices(model=model)[:, :3, :3] + M_L_all = js.model.link_spatial_inertia_matrices(model=model)[:, :3, :3] + + # Index M_L by the body (parent link) of each collision shape + body_indices = jnp.array(model.kin_dyn_parameters.contact_parameters.body) + M_L = M_L_all[body_indices] def imp_aref( pos: jtp.Vector, diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index 669d0a531..ac248912f 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -415,15 +415,25 @@ def compute_contact_forces( # Compute the position and linear velocities (mixed representation) of # all the collidable shapes belonging to the robot and extract the ones # for the enabled collidable shapes. - δ, δ̇, n̂, W_p_C, CW_ṗ_C = jax.vmap( - common.compute_penetration_data, in_axes=(None,) + δ, δ̇, n̂, W_p_C, CW_ṗ_C = jax.vmap( + lambda shape_transform, shape_type, shape_size, link_transform, link_velocity: common.compute_penetration_data( + model, + shape_transform=shape_transform, + shape_type=shape_type, + shape_size=shape_size, + link_transforms=link_transform, + link_velocities=link_velocity, + ) )( - model, - shape_offset=model.kin_dyn_parameters.contact_parameters.center, - shape_type=model.kin_dyn_parameters.contact_parameters.shape_type, - shape_size=model.kin_dyn_parameters.contact_parameters.shape_size, - link_transforms=data._link_transforms, - link_velocities=data._link_velocities, + model.kin_dyn_parameters.contact_parameters.transform, + model.kin_dyn_parameters.contact_parameters.shape_type, + model.kin_dyn_parameters.contact_parameters.shape_size, + data._link_transforms[ + jnp.array(model.kin_dyn_parameters.contact_parameters.body) + ], + data._link_velocities[ + jnp.array(model.kin_dyn_parameters.contact_parameters.body) + ], ) # Extract the material deformation corresponding to the collidable shapes. @@ -437,9 +447,15 @@ def compute_contact_forces( # We exploit two levels of vmap to vectorize over both the shapes and the points. # The outer vmap vectorizes over the shapes, while the inner vmap vectorizes # over the maximum points (3) belonging to each shape. - W_f, ṁ = jax.vmap( + W_f_per_shape, ṁ = jax.vmap( SoftContacts.compute_contact_force, in_axes=(0, 0, 0, 0, 0, 0, None), # vectorize over shapes - )(δ, δ̇, W_p_C, CW_ṗ_C, n̂, m, model.contact_params) + )(δ, δ̇, W_p_C, CW_ṗ_C, n̂, m, model.contact_params) + + # Accumulate forces by parent link using segment_sum + body_indices = jnp.array(model.kin_dyn_parameters.contact_parameters.body) + W_f = jax.ops.segment_sum( + W_f_per_shape, body_indices, num_segments=model.number_of_links() + ) - return W_f, {"m_dot": ṁ} + return W_f, {"m_dot": ṁ} From 33bd9f9bcb9a3042006e2a75d3b89df396fd5765 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 24 Oct 2025 15:34:45 +0200 Subject: [PATCH 31/39] Handle collision lumping when reducing models --- src/jaxsim/parsers/descriptions/model.py | 57 ++++++++++++------------ 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/src/jaxsim/parsers/descriptions/model.py b/src/jaxsim/parsers/descriptions/model.py index 08779c5bd..659e738ea 100644 --- a/src/jaxsim/parsers/descriptions/model.py +++ b/src/jaxsim/parsers/descriptions/model.py @@ -1,7 +1,6 @@ from __future__ import annotations import dataclasses -import itertools from collections.abc import Sequence from jaxsim import logging @@ -27,9 +26,7 @@ class ModelDescription(KinematicGraph): fixed_base: bool = True - collision_shapes: tuple = dataclasses.field( - default_factory=list, repr=False - ) + collision_shapes: tuple = dataclasses.field(default_factory=list, repr=False) @staticmethod def build_model_from( @@ -102,31 +99,33 @@ def build_model_from( logging.info(msg.format(parent_link_of_shape)) continue - # Create a new collision shape - # new_collision_shape = CollisionShape(collidable_points=()) - # final_collisions.append(new_collision_shape) - - # # If the frame was found, update the collidable points' pose and add them - # # to the new collision shape. - # for cp in collision_shape.collidable_points: - # # Find the link that is part of the (reduced) model in which the - # # collision shape's parent was lumped into - # real_parent_link_name = kinematic_graph.frames_dict[ - # parent_link_of_shape.name - # ].parent_name - - # # Change the link associated to the collidable point, updating their - # # relative pose - # moved_cp = cp.change_link( - # new_link=kinematic_graph.links_dict[real_parent_link_name], - # new_H_old=fk.relative_transform( - # relative_to=real_parent_link_name, - # name=cp.parent_link.name, - # ), - # ) - - # # Store the updated collision. - # new_collision_shape.collidable_points += (moved_cp,) + # Find the link that is part of the (reduced) model in which the + # collision shape's parent was lumped into. + real_parent_link_name = kinematic_graph.frames_dict[ + parent_link_of_shape + ].parent_name + + # Get the transform from the real parent link to the removed link + # that still exists as a frame. + parent_H_frame = fk.relative_transform( + relative_to=real_parent_link_name, + name=parent_link_of_shape, + ) + + # Transform the collision shape's pose to the new parent link frame. + # The collision shape was defined w.r.t. the removed link (now a frame). + # Now we need to express it w.r.t. the link that absorbed the removed link. + # Compose the transforms: parent_H_shape = parent_H_frame @ frame_H_shape + parent_H_shape = parent_H_frame @ collision_shape.transform + + # Create a new collision shape with updated pose and parent link + new_collision_shape = dataclasses.replace( + collision_shape, + transform=parent_H_shape, + parent_link=real_parent_link_name, + ) + + final_collisions.append(new_collision_shape) # Build the model model = ModelDescription( From 81927809d8f754abb33355de570abec2143b3bf3 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 28 Oct 2025 09:37:43 +0100 Subject: [PATCH 32/39] Avoid to save unsupported collision attributes --- src/jaxsim/api/kin_dyn_parameters.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index d7dfe907d..981a3d2c5 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -840,14 +840,17 @@ def build_from(model_description: ModelDescription) -> ContactParameters: [], ) - # Assume the link_parameters and the collision_shapes are in the same order. for collision in model_description.collision_shapes: - shape_types.append( - _COLLISION_SHAPE_MAP.get( - type(collision), CollidableShapeType.Unsupported - ) + shape_type = _COLLISION_SHAPE_MAP.get( + type(collision), CollidableShapeType.Unsupported ) + # Skip unsupported collision shapes + if shape_type == CollidableShapeType.Unsupported: + continue + + shape_types.append(shape_type) + shape_sizes.append(collision.size.squeeze()) transforms.append(collision.transform) From 38698ec9203db016c22b647f5f480af728edeed9 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 28 Oct 2025 09:40:09 +0100 Subject: [PATCH 33/39] Update cylinder collision configuration - Removed sphere collision points options in docs. - Added env var to disable cylinder collisions. --- docs/guide/configuration.rst | 9 +-------- examples/jaxsim_as_physics_engine_advanced.ipynb | 5 +---- src/jaxsim/parsers/rod/parser.py | 4 +++- 3 files changed, 5 insertions(+), 13 deletions(-) diff --git a/docs/guide/configuration.rst b/docs/guide/configuration.rst index 993160f73..d1ea88041 100644 --- a/docs/guide/configuration.rst +++ b/docs/guide/configuration.rst @@ -9,17 +9,10 @@ Collision Dynamics Environment variables starting with ``JAXSIM_COLLISION_`` are used to configure collision dynamics. The available variables are: -- ``JAXSIM_COLLISION_SPHERE_POINTS``: Specifies the number of collision points to approximate the sphere. - - *Default:* ``50``. - -- ``JAXSIM_COLLISION_USE_BOTTOM_ONLY``: Limits collision detection to only the bottom half of the box or sphere. +- ``JAXSIM_COLLISION_ENABLE_CYLINDER``: Enables collision dynamics for cylindrical geometries. *Default:* ``False``. -.. note:: - The bottom half is defined as the half of the box or sphere with the lowest z-coordinate in the collision link frame. - Testing ~~~~~~~ diff --git a/examples/jaxsim_as_physics_engine_advanced.ipynb b/examples/jaxsim_as_physics_engine_advanced.ipynb index e74dc6e23..0faa12f23 100644 --- a/examples/jaxsim_as_physics_engine_advanced.ipynb +++ b/examples/jaxsim_as_physics_engine_advanced.ipynb @@ -130,10 +130,7 @@ "# JaxSim currently only supports collisions between points attached to bodies\n", "# and a ground surface modeled as a heightmap sampled from a smooth function.\n", "# While this approach is universal as it applies to generic meshes, the number\n", - "# of considered points greatly affects the performance. Spheres, by default,\n", - "# are discretized with 250 points. It's too much for this simple example.\n", - "# This number can be decreased with the following environment variable.\n", - "os.environ[\"JAXSIM_COLLISION_SPHERE_POINTS\"] = \"50\"" + "# of considered points greatly affects the performance." ] }, { diff --git a/src/jaxsim/parsers/rod/parser.py b/src/jaxsim/parsers/rod/parser.py index 719dd78d5..19aba0e2d 100644 --- a/src/jaxsim/parsers/rod/parser.py +++ b/src/jaxsim/parsers/rod/parser.py @@ -333,7 +333,9 @@ def extract_model_data( collisions.append(sphere_collision) break - if collision.geometry.cylinder is not None: + if collision.geometry.cylinder is not None and int( + os.environ.get("JAXSIM_ENABLE_CYLINDER_COLLISION", 0) + ): cylinder_collision = utils.create_cylinder_collision( collision=collision, link_description=links_dict[link.name], From 99971265f314062baf8310ac6b8a74027eb699ed Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 29 Oct 2025 17:10:45 +0100 Subject: [PATCH 34/39] Use JAX NumPy to build contact frame --- src/jaxsim/rbda/contacts/detection.py | 10 +++++++--- src/jaxsim/rbda/contacts/soft.py | 5 +++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/jaxsim/rbda/contacts/detection.py b/src/jaxsim/rbda/contacts/detection.py index 0c6b2180b..41c0dda4b 100644 --- a/src/jaxsim/rbda/contacts/detection.py +++ b/src/jaxsim/rbda/contacts/detection.py @@ -17,9 +17,13 @@ def _contact_frame(normal: jtp.Vector, position: jtp.Vector) -> jtp.Matrix: R = jnp.stack([t1, t2, n], axis=1) - return jaxsim.math.Transform.from_rotation_and_translation( - rotation=R, - translation=position, + return jnp.block( + [ + [R[0, 0], R[0, 1], R[0, 2], position[0]], + [R[1, 0], R[1, 1], R[1, 2], position[1]], + [R[2, 0], R[2, 1], R[2, 2], position[2]], + [0.0, 0.0, 0.0, 1.0], + ] ) diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index ac248912f..e40fb75d0 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -5,6 +5,7 @@ import jax import jax.numpy as jnp import jax_dataclasses +import numpy as np import jaxsim.api as js import jaxsim.math @@ -429,10 +430,10 @@ def compute_contact_forces( model.kin_dyn_parameters.contact_parameters.shape_type, model.kin_dyn_parameters.contact_parameters.shape_size, data._link_transforms[ - jnp.array(model.kin_dyn_parameters.contact_parameters.body) + np.array(model.kin_dyn_parameters.contact_parameters.body) ], data._link_velocities[ - jnp.array(model.kin_dyn_parameters.contact_parameters.body) + np.array(model.kin_dyn_parameters.contact_parameters.body) ], ) From 67bf77cbc2e225103ffd7929b482eedbdec9d3f0 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 19 Nov 2025 11:48:27 +0100 Subject: [PATCH 35/39] Fix contact API tests --- tests/test_api_contact.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/test_api_contact.py b/tests/test_api_contact.py index 0fcc1997f..0c0314201 100644 --- a/tests/test_api_contact.py +++ b/tests/test_api_contact.py @@ -1,10 +1,6 @@ import jax import jax.numpy as jnp -<<<<<<< HEAD -======= import numpy as np -import pytest ->>>>>>> 91f80b4 (Fix contact API test) import rod import jaxsim.api as js @@ -35,7 +31,7 @@ def test_contact_kinematics( # Compute the pose of the implicit contact frame associated to the collidable shapes # and the transforms of all links. W_H_C = js.contact.transforms(model=model, data=data) - + # Check that the origin of the implicit contact frame is located over the # collidable shape. W_p_C = js.contact.contact_point_positions(model=model, data=data) @@ -52,7 +48,7 @@ def test_contact_kinematics( CW_vl_WC = jnp.einsum("c6g,g->c6", CW_J_WC, ν)[:, 0:3] # Compare the two velocities. - assert_allclose(jnp.contatenate(W_ṗ_C), CW_vl_WC) + assert_allclose(jnp.concatenate(W_ṗ_C), CW_vl_WC) def test_contact_point_jacobians( From d40634ba2feee6e722853f0a78d837d7089e39bb Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 19 Nov 2025 12:30:23 +0100 Subject: [PATCH 36/39] Allow selecting collidable links --- src/jaxsim/api/kin_dyn_parameters.py | 65 +++++++++++++++++----------- src/jaxsim/api/model.py | 11 ++++- src/jaxsim/rbda/contacts/rigid.py | 13 ++---- tests/test_simulations.py | 44 +++++++------------ 4 files changed, 68 insertions(+), 65 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index 981a3d2c5..fb1d354cb 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -92,7 +92,9 @@ def support_body_array_bool(self) -> jtp.Matrix: @staticmethod def build( - model_description: ModelDescription, constraints: ConstraintMap | None + model_description: ModelDescription, + constraints: ConstraintMap | None, + indices_of_enabled_collidable_links: set[int] | None = None, ) -> KinDynParameters: """ Construct the kinematic and dynamic parameters of the model. @@ -100,6 +102,8 @@ def build( Args: model_description: The parsed model description to consider. constraints: An object of type ConstraintMap specifying the kinematic constraint of the model. + indices_of_enabled_collidable_links: + The set of link indices for which collision shapes should be considered. If None, all links with collision shapes are considered. Returns: The kinematic and dynamic parameters of the model. @@ -175,7 +179,8 @@ def build( # must be Static for JIT-related reasons, and tree_map would not consider it # as a leaf. contact_parameters = ContactParameters.build_from( - model_description=model_description + model_description=model_description, + indices_of_enabled_collidable_links=indices_of_enabled_collidable_links, ) # ================= @@ -819,53 +824,63 @@ def orientation(self) -> jtp.Array: return self.transform[:, :3, :3] @staticmethod - def build_from(model_description: ModelDescription) -> ContactParameters: + def build_from( + model_description: ModelDescription, + indices_of_enabled_collidable_links: set[int] | None = None, + ) -> ContactParameters: """ Build a ContactParameters object from a model description. Args: model_description: The model description to consider. + indices_of_enabled_collidable_links: + An optional set of link indices for which to include collision shapes. + If None, all collision shapes are included. Returns: The ContactParameters object. """ - if len(model_description.collision_shapes) == 0: + if not (collisions := model_description.collision_shapes): return ContactParameters() - shape_types, shape_sizes, transforms, parent_link_indices = ( - [], - [], - [], - [], + links_dict = model_description.links_dict + shape_map = _COLLISION_SHAPE_MAP + + enabled = ( + indices_of_enabled_collidable_links + if indices_of_enabled_collidable_links is not None + else set(range(len(links_dict))) ) - for collision in model_description.collision_shapes: - shape_type = _COLLISION_SHAPE_MAP.get( - type(collision), CollidableShapeType.Unsupported - ) + shape_types = [] + shape_sizes = [] + transforms = [] + parent_indices = [] + + for collision in collisions: + + shape_type = shape_map.get(type(collision), CollidableShapeType.Unsupported) + + parent_idx = links_dict[collision.parent_link].index # Skip unsupported collision shapes - if shape_type == CollidableShapeType.Unsupported: + if shape_type == CollidableShapeType.Unsupported or ( + parent_idx not in enabled + ): continue shape_types.append(shape_type) - shape_sizes.append(collision.size.squeeze()) - transforms.append(collision.transform) - - # Get the parent link index for this collision shape. - parent_link_indices.append( - model_description.links_dict[collision.parent_link].index - ) + parent_indices.append(parent_idx) # Build the ContactParameters object. return ContactParameters( - body=tuple(parent_link_indices), - transform=jnp.array(transforms, dtype=float), - shape_type=jnp.array(shape_types, dtype=int), - shape_size=jnp.array(shape_sizes, dtype=float), + body=tuple(parent_indices), + transform=jnp.asarray(transforms, dtype=float), + shape_type=jnp.asarray(shape_types, dtype=int), + shape_size=jnp.asarray(shape_sizes, dtype=float), ) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 766f922cf..971036ef0 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -141,6 +141,7 @@ def build_from_model_description( considered_joints: Sequence[str] | None = None, gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY, constraints: jaxsim.rbda.kinematic_constraints.ConstraintMap | None = None, + indices_of_enabled_collidable_links: set[int] | None = None, ) -> JaxSimModel: """ Build a Model object from a model description. @@ -170,6 +171,8 @@ def build_from_model_description( constraints: An object of type ConstraintMap containing the kinematic constraints to consider. If None, no constraints are considered. Note that constraints can be used only with RelaxedRigidContacts. + indices_of_enabled_collidable_links: + The set of link indices for which collision shapes should be considered. If None, all links with collision shapes are considered. Returns: The built Model object. @@ -202,6 +205,7 @@ def build_from_model_description( integrator=integrator, gravity=-gravity, constraints=constraints, + indices_of_enabled_collidable_links=indices_of_enabled_collidable_links, ) # Store the origin of the model, in case downstream logic needs it. @@ -230,6 +234,7 @@ def build( integrator: IntegratorType | None = None, gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY, constraints: jaxsim.rbda.kinematic_constraints.ConstraintMap | None = None, + indices_of_enabled_collidable_links: set[int] | None = None, ) -> JaxSimModel: """ Build a Model object from an intermediate model description. @@ -254,6 +259,8 @@ def build( gravity: The gravity constant. constraints: An object of type ConstraintMap containing the kinematic constraints to consider. If None, no constraints are considered. + indices_of_enabled_collidable_links: + The set of link indices for which collision shapes should be considered. If None, all links with collision shapes are considered. Returns: The built Model object. @@ -302,7 +309,9 @@ def build( model = cls( model_name=model_name, kin_dyn_parameters=js.kin_dyn_parameters.KinDynParameters.build( - model_description=model_description, constraints=constraints + model_description=model_description, + constraints=constraints, + indices_of_enabled_collidable_links=indices_of_enabled_collidable_links, ), time_step=time_step, terrain=terrain, diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index 568b6ec32..e03692265 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -245,13 +245,6 @@ def compute_contact_forces( A tuple containing as first element the computed contact forces. """ - # Get the indices of the enabled collidable points. - indices_of_enabled_collidable_points = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points - ) - - n_collidable_points = len(indices_of_enabled_collidable_points) - link_forces = jnp.atleast_2d( jnp.array(link_forces, dtype=float).squeeze() if link_forces is not None @@ -275,7 +268,7 @@ def compute_contact_forces( # Compute the position and linear velocities (mixed representation) of # all enabled collidable points belonging to the robot. - position, velocity = js.contact.collidable_point_kinematics( + position, velocity = js.contact.contact_point_kinematics( model=model, data=data ) @@ -343,10 +336,10 @@ def compute_contact_forces( G = _compute_ineq_constraint_matrix( inactive_collidable_points=(δ <= 0), mu=model.contact_params.mu ) - h_bounds = jnp.zeros(shape=(n_collidable_points * 6,)) + h_bounds = jnp.zeros(shape=(G.shape[0],)) # Construct the equality constraints. - A = jnp.zeros((0, 3 * n_collidable_points)) + A = jnp.zeros((0, 3 * G.shape[0] // 6)) b = jnp.zeros((0,)) # Solve the following optimization problem with qpax: diff --git a/tests/test_simulations.py b/tests/test_simulations.py index 262e9ab51..18faa5bb0 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -192,10 +192,11 @@ def run_simulation( def test_simulation_with_soft_contacts( - jaxsim_model_box: js.model.JaxSimModel, integrator + jaxsim_model_box: js.model.JaxSimModel, integrator, prng_key: jax.Array, ): model = jaxsim_model_box + _, subkey = jax.random.split(prng_key, num=2) # Define the maximum penetration at steady state. max_penetration = 0.001 @@ -210,15 +211,15 @@ def test_simulation_with_soft_contacts( max_penetration=max_penetration, ) - assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4 - # Check jaxsim_model_box@conftest.py. box_height = 0.1 + rnd_x, rnd_y = jax.random.uniform(subkey, shape=(2,), minval=-0.5, maxval=0.5) + # Build the data of the model. data_t0 = js.data.JaxSimModelData.build( model=model, - base_position=jnp.array([0.0, 0.0, box_height * 2]), + base_position=jnp.array([rnd_x, rnd_y, box_height * 2]), velocity_representation=VelRepr.Inertial, ) @@ -233,10 +234,11 @@ def test_simulation_with_soft_contacts( def test_simulation_with_rigid_contacts( - jaxsim_model_box: js.model.JaxSimModel, integrator + jaxsim_model_box: js.model.JaxSimModel, integrator, prng_key: jax.Array, ): model = jaxsim_model_box + _, subkey = jax.random.split(prng_key, num=2) with model.editable(validate=False) as model: @@ -247,17 +249,6 @@ def test_simulation_with_rigid_contacts( ) model.contact_params = model.contact_model._parameters_class(K=1e5) - # Enable a subset of the collidable points. - enabled_collidable_points_mask = np.zeros( - len(model.kin_dyn_parameters.contact_parameters.body), dtype=bool - ) - enabled_collidable_points_mask[[0, 1, 2, 3]] = True - model.kin_dyn_parameters.contact_parameters.enabled = tuple( - enabled_collidable_points_mask.tolist() - ) - - assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4 - # Initialize the maximum penetration of each collidable point at steady state. # This model is rigid, so we expect (almost) no penetration. max_penetration = 0.000 @@ -265,10 +256,12 @@ def test_simulation_with_rigid_contacts( # Check jaxsim_model_box@conftest.py. box_height = 0.1 + rnd_x, rnd_y = jax.random.uniform(subkey, shape=(2,), minval=-0.5, maxval=0.5) + # Build the data of the model. data_t0 = js.data.JaxSimModelData.build( model=model, - base_position=jnp.array([0.0, 0.0, box_height * 2]), + base_position=jnp.array([rnd_x, rnd_y, box_height * 2]), velocity_representation=VelRepr.Inertial, ) @@ -283,10 +276,11 @@ def test_simulation_with_rigid_contacts( def test_simulation_with_relaxed_rigid_contacts( - jaxsim_model_box: js.model.JaxSimModel, integrator + jaxsim_model_box: js.model.JaxSimModel, integrator, prng_key: jax.Array, ): model = jaxsim_model_box + _, subkey = jax.random.split(prng_key, num=2) with model.editable(validate=False) as model: @@ -295,18 +289,8 @@ def test_simulation_with_relaxed_rigid_contacts( ) model.contact_params = model.contact_model._parameters_class() - # Enable a subset of the collidable points. - enabled_collidable_points_mask = np.zeros( - len(model.kin_dyn_parameters.contact_parameters.body), dtype=bool - ) - enabled_collidable_points_mask[[0, 1, 2, 3]] = True - model.kin_dyn_parameters.contact_parameters.enabled = tuple( - enabled_collidable_points_mask.tolist() - ) model.integrator = integrator - assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4 - # Initialize the maximum penetration of each collidable point at steady state. # This model is quasi-rigid, so we expect (almost) no penetration. max_penetration = 0.000 @@ -314,10 +298,12 @@ def test_simulation_with_relaxed_rigid_contacts( # Check jaxsim_model_box@conftest.py. box_height = 0.1 + rnd_x, rnd_y = jax.random.uniform(subkey, shape=(2,), minval=-0.5, maxval=0.5) + # Build the data of the model. data_t0 = js.data.JaxSimModelData.build( model=model, - base_position=jnp.array([0.0, 0.0, box_height * 2]), + base_position=jnp.array([rnd_x, rnd_y, box_height * 2]), velocity_representation=VelRepr.Inertial, ) From 3025fee91aa1ebc6a8a3ebe50014085811969800 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 24 Nov 2025 17:41:48 +0100 Subject: [PATCH 37/39] Fix automatic differentiation tests --- src/jaxsim/api/kin_dyn_parameters.py | 12 +++++++----- src/jaxsim/api/model.py | 6 ++++-- tests/test_automatic_differentiation.py | 17 +++++++++++++---- 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index fb1d354cb..263e80e19 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -1109,7 +1109,7 @@ def _convert_scaling_to_3d_vector( return scaling_factors.dims[per_link_indices.squeeze()] @staticmethod - def compute_contact_points( + def compute_contact_transforms( original_contact_params: jtp.Vector, link_shapes: jtp.Vector, original_com_positions: jtp.Vector, @@ -1117,7 +1117,7 @@ def compute_contact_points( scaling_factors: ScalingFactors, ) -> jtp.Matrix: """ - Compute the new contact points based on the original contact parameters and + Compute the new contact transforms based on the original contact parameters and the scaling factors. Args: @@ -1128,7 +1128,7 @@ def compute_contact_points( scaling_factors: The scaling factors for the link dimensions. Returns: - The new contact points positions in the parent link frame. + The new contact transforms. """ parent_link_indices = np.array(original_contact_params.body) @@ -1136,7 +1136,7 @@ def compute_contact_points( # Translate the original contact point positions in the origin, so # that we can apply the scaling factors. L_p_Ci = ( - original_contact_params.point - original_com_positions[parent_link_indices] + original_contact_params.center - original_com_positions[parent_link_indices] ) # Extract the shape types of the parent links. @@ -1170,7 +1170,9 @@ def box(parent_idx, L_p_C): L_p_Ci, ) - return new_positions + updated_com_positions[parent_link_indices] + centers = new_positions + updated_com_positions[parent_link_indices] + + return original_contact_params.transform.at[:, :3, 3].set(centers) @staticmethod def compute_inertia_link(I_com, L_H_G) -> jtp.Matrix: diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 971036ef0..d1ce39781 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -2449,7 +2449,7 @@ def update_hw_parameters( ) # Compute the contact parameters - points = HwLinkMetadata.compute_contact_points( + transforms = HwLinkMetadata.compute_contact_transforms( original_contact_params=kin_dyn_params.contact_parameters, link_shapes=updated_hw_link_metadata.link_shape, original_com_positions=link_parameters.center_of_mass, @@ -2458,7 +2458,9 @@ def update_hw_parameters( ) # Update contact parameters - updated_contact_parameters = kin_dyn_params.contact_parameters.replace(point=points) + updated_contact_parameters = kin_dyn_params.contact_parameters.replace( + transform=transforms + ) # Update joint model transforms (λ_H_pre) def update_λ_H_pre(joint_index): diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index 05e949e91..0419e3e42 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -302,9 +302,13 @@ def test_ad_soft_contacts( model.contact_model = jaxsim.rbda.contacts.SoftContacts.build() _, subkey1, subkey2, subkey3 = jax.random.split(prng_key, num=4) - p = jax.random.uniform(subkey1, shape=(3,), minval=-1) - v = jax.random.uniform(subkey2, shape=(3,), minval=-1) + p = jax.random.uniform(subkey1, shape=(1, 3), minval=-1) + v = jax.random.uniform(subkey2, shape=(1, 3), minval=-1) m = jax.random.uniform(subkey3, shape=(3,), minval=-1) + n = jax.random.uniform(subkey3, shape=(1, 3), minval=-1) + n = n / jnp.linalg.norm(n) + delta = jax.random.uniform(subkey1, shape=(1, 3), minval=-0.1, maxval=0.1) + delta_dot = jax.random.uniform(subkey2, shape=(1, 3), minval=-1, maxval=1) # Get the soft contacts parameters. parameters = js.contact.estimate_good_contact_parameters(model=model) @@ -315,18 +319,23 @@ def test_ad_soft_contacts( # Get a closure exposing only the parameters to be differentiated. def close_over_inputs_and_parameters( + delta: jtp.VectorLike, + delta_dot: jtp.VectorLike, p: jtp.VectorLike, v: jtp.VectorLike, + n: jtp.VectorLike, m: jtp.VectorLike, params: SoftContactsParams, ) -> tuple[jtp.Vector, jtp.Vector]: W_f_Ci, CW_ṁ = SoftContacts.compute_contact_force( + penetration=delta, + penetration_rate=delta_dot, position=p, velocity=v, + normal=n, tangential_deformation=m, parameters=params, - terrain=model.terrain, ) return W_f_Ci, CW_ṁ @@ -334,7 +343,7 @@ def close_over_inputs_and_parameters( # Check derivatives against finite differences. check_grads( f=close_over_inputs_and_parameters, - args=(p, v, m, parameters), + args=(delta, delta_dot, p, v, n, m, parameters), order=AD_ORDER, modes=["rev", "fwd"], eps=ε, From 3b7e0f4c714fcd12dec19ee8fbcff64046066868 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 24 Nov 2025 18:05:00 +0100 Subject: [PATCH 38/39] Fix mixed representation in contact Jacobian derivative --- src/jaxsim/api/contact.py | 8 ++++---- src/jaxsim/rbda/contacts/rigid.py | 4 +--- tests/test_simulations.py | 12 +++++++++--- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 0f359dd9b..1f33631d3 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -9,7 +9,7 @@ import jaxsim.exceptions import jaxsim.typing as jtp from jaxsim import logging -from jaxsim.math import Adjoint, Cross, Transform +from jaxsim.math import Adjoint, Cross from jaxsim.rbda.contacts import SoftContacts, detection from .common import VelRepr @@ -448,9 +448,9 @@ def compute_O_J̇_WC_I(W_H_C, W_v_WL, W_J_WL_W, W_J̇_WL_W) -> jtp.Matrix: O_Ẋ_W = -O_X_W @ Cross.vx(W_v_WL) case VelRepr.Mixed: W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3)) - O_X_W = Adjoint.from_transform(Transform.inverse(W_H_CW)) - v_CW = O_X_W @ W_v_WL - O_Ẋ_W = -O_X_W @ Cross.vx(v_CW.at[:3].set(v_CW[:3])) + O_X_W = Adjoint.from_transform(W_H_CW, inverse=True) + O_v_CW = (O_X_W @ W_v_WL).at[3:6].set(0.0) + O_Ẋ_W = -O_X_W @ Cross.vx(O_v_CW) case _: raise ValueError(output_vel_repr) diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index e03692265..0b089c0bb 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -268,9 +268,7 @@ def compute_contact_forces( # Compute the position and linear velocities (mixed representation) of # all enabled collidable points belonging to the robot. - position, velocity = js.contact.contact_point_kinematics( - model=model, data=data - ) + position, velocity = js.contact.contact_point_kinematics(model=model, data=data) # Compute the penetration depth and velocity of the collidable points. # Note that this function considers the penetration in the normal direction. diff --git a/tests/test_simulations.py b/tests/test_simulations.py index 18faa5bb0..453c08b1e 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -192,7 +192,9 @@ def run_simulation( def test_simulation_with_soft_contacts( - jaxsim_model_box: js.model.JaxSimModel, integrator, prng_key: jax.Array, + jaxsim_model_box: js.model.JaxSimModel, + integrator, + prng_key: jax.Array, ): model = jaxsim_model_box @@ -234,7 +236,9 @@ def test_simulation_with_soft_contacts( def test_simulation_with_rigid_contacts( - jaxsim_model_box: js.model.JaxSimModel, integrator, prng_key: jax.Array, + jaxsim_model_box: js.model.JaxSimModel, + integrator, + prng_key: jax.Array, ): model = jaxsim_model_box @@ -276,7 +280,9 @@ def test_simulation_with_rigid_contacts( def test_simulation_with_relaxed_rigid_contacts( - jaxsim_model_box: js.model.JaxSimModel, integrator, prng_key: jax.Array, + jaxsim_model_box: js.model.JaxSimModel, + integrator, + prng_key: jax.Array, ): model = jaxsim_model_box From a70564e29234b68197b38f079eb44f36daf646d4 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 24 Nov 2025 21:09:15 +0100 Subject: [PATCH 39/39] Fix rigid contact model --- src/jaxsim/api/contact.py | 24 +++-- src/jaxsim/rbda/contacts/rigid.py | 110 +++++++++++++-------- tests/test_api_model.py | 4 +- tests/test_api_model_hw_parametrization.py | 8 +- 4 files changed, 92 insertions(+), 54 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 1f33631d3..9e3b3d87b 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -382,6 +382,7 @@ def jacobian_derivative( # Index link velocities by body (parent link) of each collision shape body_indices = jnp.array(model.kin_dyn_parameters.contact_parameters.body) W_v_WL_indexed = W_v_WL[body_indices] # (n_shapes, 6) + W_H_L_indexed = data._link_transforms[body_indices] # (n_shapes, 4, 4) # Compute the contact transforms (n_shapes, n_contacts, 4, 4) W_H_C = transforms(model=model, data=data) @@ -438,13 +439,15 @@ def compute_Ṫ(Ẋ: jtp.Matrix) -> jtp.Matrix: W_J_WL_W_indexed = W_J_WL_W[body_indices] # (n_shapes, 6, 6+n) W_J̇_WL_W_indexed = W_J̇_WL_W[body_indices] # (n_shapes, 6, 6+n) - def compute_O_J̇_WC_I(W_H_C, W_v_WL, W_J_WL_W, W_J̇_WL_W) -> jtp.Matrix: + def compute_O_J̇_WC_I(W_H_C, W_H_L, W_v_WL, W_J_WL_W, W_J̇_WL_W) -> jtp.Matrix: match output_vel_repr: case VelRepr.Inertial: O_X_W = jnp.eye(6) O_Ẋ_W = jnp.zeros((6, 6)) case VelRepr.Body: - O_X_W = Adjoint.from_transform(W_H_C, inverse=True) + O_X_W = Adjoint.from_transform( + W_H_C.at[0:3, 0:3].set(W_H_L[0:3, 0:3]), inverse=True + ) O_Ẋ_W = -O_X_W @ Cross.vx(W_v_WL) case VelRepr.Mixed: W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3)) @@ -461,13 +464,18 @@ def compute_O_J̇_WC_I(W_H_C, W_v_WL, W_J_WL_W, W_J̇_WL_W) -> jtp.Matrix: return O_J̇_WC_I O_J̇_per_shape = jax.vmap( - lambda H_C_shape, v_WL_shape, J_WL_shape, J̇_WL_shape: jax.vmap( + lambda H_C_shape, H_L_shape, v_WL_shape, J_WL_shape, J̇_WL_shape: jax.vmap( compute_O_J̇_WC_I, - in_axes=(0, None, None, None), # Map over contacts for W_H_C only - )(H_C_shape, v_WL_shape, J_WL_shape, J̇_WL_shape), - in_axes=(0, 0, 0, 0), # Map over shapes - )(W_H_C, W_v_WL_indexed, W_J_WL_W_indexed, W_J̇_WL_W_indexed) - + in_axes=(0, None, None, None, None), # map over contacts in H_C + )(H_C_shape, H_L_shape, v_WL_shape, J_WL_shape, J̇_WL_shape), + in_axes=(0, 0, 0, 0, 0), # map over shapes + )( + W_H_C, + W_H_L_indexed, + W_v_WL_indexed, + W_J_WL_W_indexed, + W_J̇_WL_W_indexed, + ) O_J̇_WC = O_J̇_per_shape.reshape(-1, 6, 6 + model.dofs()) return O_J̇_WC diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index 0b089c0bb..40b334ded 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -266,15 +266,31 @@ def compute_contact_forces( joint_force_references=joint_force_references, ) - # Compute the position and linear velocities (mixed representation) of - # all enabled collidable points belonging to the robot. - position, velocity = js.contact.contact_point_kinematics(model=model, data=data) - - # Compute the penetration depth and velocity of the collidable points. + # Compute the position and linear velocities (mixed representation) and + # the penetration depth and velocity of the collidable points. # Note that this function considers the penetration in the normal direction. - δ, δ_dot, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))( - position, velocity, model.terrain + δ, δ_dot, n̂, *_ = jax.vmap( + lambda shape_transform, shape_type, shape_size, link_transform, link_velocity: common.compute_penetration_data( + model, + shape_transform=shape_transform, + shape_type=shape_type, + shape_size=shape_size, + link_transforms=link_transform, + link_velocities=link_velocity, + ) + )( + model.kin_dyn_parameters.contact_parameters.transform, + model.kin_dyn_parameters.contact_parameters.shape_type, + model.kin_dyn_parameters.contact_parameters.shape_size, + data._link_transforms[ + jnp.array(model.kin_dyn_parameters.contact_parameters.body) + ], + data._link_velocities[ + jnp.array(model.kin_dyn_parameters.contact_parameters.body) + ], ) + δ = δ.flatten() + δ_dot = δ_dot.flatten() W_H_C = js.contact.transforms(model=model, data=data) @@ -353,21 +369,34 @@ def compute_contact_forces( ) # Reshape the optimized solution to be a matrix of 3D contact forces. - CW_fl_C = solution.reshape(-1, 3) + CW_fl_C_per_link = solution.reshape(-1, 3, 3) + + # Transform each contact force to inertial frame + def to_inertial(force, H_C): + return ModelDataWithVelocityRepresentation.other_representation_to_inertial( + array=jnp.zeros(6).at[0:3].set(force), + transform=H_C, + other_representation=VelRepr.Mixed, + is_force=True, + ) - # Convert the contact forces from mixed to inertial-fixed representation. + # Compute the contact forces in inertial representation for + # each link and contact point. + # Nested vmap: inner over contacts, outer over shapes W_f_C = jax.vmap( - lambda CW_fl_C, W_H_C: ( - ModelDataWithVelocityRepresentation.other_representation_to_inertial( - array=jnp.zeros(6).at[0:3].set(CW_fl_C), - transform=W_H_C, - other_representation=VelRepr.Mixed, - is_force=True, - ) - ), - )(CW_fl_C, W_H_C) + lambda f_shape, H_shape: jax.vmap(to_inertial)(f_shape, H_shape) + )(CW_fl_C_per_link, W_H_C) + + # Sum over contacts for each shape: (n_shapes, 6) + W_f_per_shape = W_f_C.sum(axis=1) + + # Accumulate forces by parent link using segment_sum + body_indices = jnp.array(model.kin_dyn_parameters.contact_parameters.body) + W_f_per_link = jax.ops.segment_sum( + W_f_per_shape, body_indices, num_segments=model.number_of_links() + ) - return W_f_C, {} + return W_f_per_link, {} @jax.jit @js.common.named_scope @@ -385,25 +414,32 @@ def update_velocity_after_impact( The updated data of the considered model. """ - # Extract the indices corresponding to the enabled collidable points. - indices_of_enabled_collidable_points = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points - ) - - W_p_C = js.contact.collidable_point_positions(model, data)[ - indices_of_enabled_collidable_points - ] - # Compute the penetration depth of the collidable points. δ, *_ = jax.vmap( - common.compute_penetration_data, - in_axes=(0, 0, None), - )(W_p_C, jnp.zeros_like(W_p_C), model.terrain) + lambda shape_transform, shape_type, shape_size, link_transform, link_velocity: common.compute_penetration_data( + model, + shape_transform=shape_transform, + shape_type=shape_type, + shape_size=shape_size, + link_transforms=link_transform, + link_velocities=link_velocity, + ) + )( + model.kin_dyn_parameters.contact_parameters.transform, + model.kin_dyn_parameters.contact_parameters.shape_type, + model.kin_dyn_parameters.contact_parameters.shape_size, + data._link_transforms[ + jnp.array(model.kin_dyn_parameters.contact_parameters.body) + ], + data._link_velocities[ + jnp.array(model.kin_dyn_parameters.contact_parameters.body) + ], + ) + + δ = δ.flatten() with data.switch_velocity_representation(VelRepr.Mixed): - J_WC = js.contact.jacobian(model, data)[ - indices_of_enabled_collidable_points - ] + J_WC = js.contact.jacobian(model, data) M = js.model.free_floating_mass_matrix(model, data) BW_ν_pre_impact = data.generalized_velocity @@ -523,8 +559,4 @@ def _compute_baumgarte_stabilization_term( D: jtp.FloatLike, ) -> jtp.Array: - return jnp.where( - inactive_collidable_points[:, jnp.newaxis], - jnp.zeros_like(n), - (K * δ + D * δ_dot)[:, jnp.newaxis] * n, - ) + return (K * δ + D * δ_dot) * n diff --git a/tests/test_api_model.py b/tests/test_api_model.py index 65415fafb..d59cf6126 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -156,8 +156,8 @@ def test_model_creation_and_reduction( # Check that collidable point positions are preserved. assert_allclose( - js.contact.collidable_point_positions(model=model_full, data=data_full), - js.contact.collidable_point_positions(model=model_reduced, data=data_reduced), + js.contact.contact_point_positions(model=model_full, data=data_full), + js.contact.contact_point_positions(model=model_reduced, data=data_reduced), ) # ===================== diff --git a/tests/test_api_model_hw_parametrization.py b/tests/test_api_model_hw_parametrization.py index 945715283..cdf0fbb36 100644 --- a/tests/test_api_model_hw_parametrization.py +++ b/tests/test_api_model_hw_parametrization.py @@ -131,8 +131,8 @@ def test_model_scaling_against_rod( # Compare collidable points positions assert_allclose( - jaxsim_model_garpez_scaled.kin_dyn_parameters.contact_parameters.point, - updated_model.kin_dyn_parameters.contact_parameters.point, + jaxsim_model_garpez_scaled.kin_dyn_parameters.contact_parameters.transform, + updated_model.kin_dyn_parameters.contact_parameters.transform, atol=1e-6, ) @@ -485,9 +485,7 @@ def test_hw_parameters_collision_scaling( updated_base_height = data.base_position[2] # Assert that the box settles at the expected height - assert jnp.isclose( - updated_base_height, expected_height, atol=1e-3 - ), f"model base height mismatch: expected {expected_height}, got {updated_base_height}" + assert_allclose(updated_base_height, expected_height, atol=1e-3) def test_unsupported_link_cases():