diff --git a/docs/source/tutorial/grasp_generator.rst b/docs/source/tutorial/grasp_generator.rst new file mode 100644 index 00000000..51e802e8 --- /dev/null +++ b/docs/source/tutorial/grasp_generator.rst @@ -0,0 +1,77 @@ +Generating and Executing Robot Grasps +====================================== + +.. currentmodule:: embodichain.lab.sim + +This tutorial demonstrates how to generate antipodal grasp poses for a target object and execute a full grasp trajectory with a robot arm. It covers scene initialization, robot and object creation, interactive grasp region annotation, grasp pose computation, and trajectory execution in the simulation loop. + +The Code +~~~~~~~~ + +The tutorial corresponds to the ``grasp_generator.py`` script in the ``scripts/tutorials/grasp`` directory. + +.. dropdown:: Code for grasp_generator.py + :icon: code + + .. literalinclude:: ../../../scripts/tutorials/grasp/grasp_generator.py + :language: python + :linenos: + + +The Code Explained +~~~~~~~~~~~~~~~~~~ + +Configuring the simulation +-------------------------- + +Command-line arguments are parsed with ``argparse`` to select the number of parallel environments, the compute device, and optional rendering features such as ray tracing and headless mode. + +.. literalinclude:: ../../../scripts/tutorials/grasp/grasp_generator.py + :language: python + :start-at: def parse_arguments(): + :end-at: return parser.parse_args() + +The parsed arguments are passed to ``initialize_simulation``, which builds a :class:`SimulationManagerCfg` and creates the :class:`SimulationManager` instance. When ray tracing is enabled a directional :class:`cfg.LightCfg` is also added to the scene. + +.. literalinclude:: ../../../scripts/tutorials/grasp/grasp_generator.py + :language: python + :start-at: def initialize_simulation(args) -> SimulationManager: + :end-at: return sim + +Annotating and computing grasp poses +------------------------------------- + +Grasp generation is performed by :meth:`objects.RigidObject.get_grasp_pose`, which internally runs an antipodal sampler on the object mesh. A :class:`toolkits.graspkit.pg_grasp.GraspAnnotatorCfg` controls sampler parameters (sample count, gripper jaw limits) and the interactive annotation workflow: + +1. Open the visualization in a browser at the reported port (e.g. ``http://localhost:11801``). +2. Use *Rect Select Region* to highlight the area of the object that should be grasped. +3. Click *Confirm Selection* to finalize the region. + +The function returns a batch of ``(N_envs, 4, 4)`` homogeneous transformation matrices representing candidate grasp frames in the world coordinate system. + +For each grasp pose, gripper approach direction in world coordinate is required to compute the antipodal grasp. In this tutorial, we use a fixed approach direction (straight down in world frame) for simplicity, but it can be customized based on the task or object geometry. + +.. literalinclude:: ../../../scripts/tutorials/grasp/grasp_generator.py + :language: python + :start-at: # get mug grasp pose + :end-at: logger.log_info(f"Get grasp pose cost time: {cost_time:.2f} seconds") + + +The Code Execution +~~~~~~~~~~~~~~~~~~ + +To run the script, execute the following command from the project root: + +.. code-block:: bash + + python scripts/tutorials/grasp/grasp_generator.py + +A simulation window will open showing the robot and the mug. A browser-based visualizer will also launch (default port ``11801``) for interactive grasp region annotation. + +You can customize the run with additional arguments: + +.. code-block:: bash + + python scripts/tutorials/grasp/grasp_generator.py --num_envs --device --enable_rt --headless + +After confirming the grasp region in the browser, the script will compute a grasp pose, print the elapsed time, and then wait for you to press **Enter** before executing the full grasp trajectory in the simulation. Press **Enter** again to exit once the motion is complete. diff --git a/docs/source/tutorial/index.rst b/docs/source/tutorial/index.rst index ef6efe79..0a28a8e9 100644 --- a/docs/source/tutorial/index.rst +++ b/docs/source/tutorial/index.rst @@ -15,6 +15,7 @@ Tutorials sensor motion_gen gizmo + grasp_generator basic_env modular_env rl diff --git a/embodichain/lab/sim/objects/rigid_object.py b/embodichain/lab/sim/objects/rigid_object.py index 0c8477e2..a9cde20a 100644 --- a/embodichain/lab/sim/objects/rigid_object.py +++ b/embodichain/lab/sim/objects/rigid_object.py @@ -35,6 +35,11 @@ from embodichain.utils.math import convert_quat from embodichain.utils.math import matrix_from_quat, quat_from_matrix, matrix_from_euler from embodichain.utils import logger +from embodichain.toolkits.graspkit.pg_grasp.antipodal_annotator import ( + GraspAnnotator, + GraspAnnotatorCfg, +) +import torch.nn.functional as F @dataclass @@ -1122,3 +1127,63 @@ def destroy(self) -> None: arenas = [env] for i, entity in enumerate(self._entities): arenas[i].remove_actor(entity) + + def get_grasp_pose( + self, + cfg: GraspAnnotatorCfg, + approach_direction: torch.Tensor = None, + is_visual: bool = False, + ) -> torch.Tensor: + if approach_direction is None: + approach_direction = torch.tensor( + [0, 0, -1], dtype=torch.float32, device=self.device + ) + approach_direction = F.normalize(approach_direction, dim=-1) + if hasattr(self, "_grasp_annotator") is False: + vertices = torch.tensor( + self._entities[0].get_vertices(), + dtype=torch.float32, + device=self.device, + ) + triangles = torch.tensor( + self._entities[0].get_triangles(), dtype=torch.int32, device=self.device + ) + scale = torch.tensor( + self._entities[0].get_body_scale(), + dtype=torch.float32, + device=self.device, + ) + vertices = vertices * scale + self._grasp_annotator = GraspAnnotator( + vertices=vertices, triangles=triangles, cfg=cfg + ) + + # Annotate antipodal point pairs + if hasattr(self, "_hit_point_pairs") is False or cfg.force_regenerate: + self._hit_point_pairs = self._grasp_annotator.annotate() + + poses = self.get_local_pose(to_matrix=True) + poses = torch.as_tensor(poses, dtype=torch.float32, device=self.device) + grasp_poses: tuple[torch.Tensor] = [] + open_lengths: tuple[torch.Tensor] = [] + for pose in poses: + grasp_pose, open_length = self._grasp_annotator.get_grasp_poses( + self._hit_point_pairs, pose, approach_direction, is_visual=False + ) + grasp_poses.append(grasp_pose) + open_lengths.append(open_length) + grasp_poses = torch.cat( + [grasp_pose.unsqueeze(0) for grasp_pose in grasp_poses], dim=0 + ) + + if is_visual: + vertices = self._entities[0].get_vertices() + triangles = self._entities[0].get_triangles() + scale = self._entities[0].get_body_scale() + vertices = vertices * scale + self._grasp_annotator.visualize_grasp_pose( + obj_pose=poses[0], + grasp_pose=grasp_poses[0], + open_length=open_lengths[0].item(), + ) + return grasp_poses diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py new file mode 100644 index 00000000..54cac47d --- /dev/null +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_annotator.py @@ -0,0 +1,574 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os +import argparse +import open3d as o3d +import time +from pathlib import Path +from typing import Any, cast +import torch +import numpy as np +import trimesh + +import viser +import viser.transforms as tf +from embodichain.utils import logger +from dataclasses import dataclass +from embodichain.toolkits.graspkit.pg_grasp.antipodal_sampler import ( + AntipodalSampler, + AntipodalSamplerCfg, +) +from .gripper_collision_checker import ( + SimpleGripperCollisionChecker, + SimpleGripperCollisionCfg, +) +import hashlib +import torch.nn.functional as F +import tempfile + + +@dataclass +class GraspAnnotatorCfg: + viser_port: int = 15531 + use_largest_connected_component: bool = False + antipodal_sampler_cfg: AntipodalSamplerCfg = AntipodalSamplerCfg() + force_regenerate: bool = False + max_deviation_angle: float = np.pi / 12 + + +@dataclass +class SelectResult: + vertex_indices: np.ndarray | None = None + face_indices: np.ndarray | None = None + vertices: np.ndarray | None = None + faces: np.ndarray | None = None + + +class GraspAnnotator: + """GraspAnnotator provides functionality to annotate antipodal grasp regions on a given object mesh. It allows users to interactively select regions on the mesh and generates antipodal point pairs for grasping based on the selected region. The annotator also includes a collision checker to filter out infeasible grasp poses and can visualize the generated grasp poses in a 3D viewer. + """ + def __init__( + self, + vertices: torch.Tensor, + triangles: torch.Tensor, + cfg: GraspAnnotatorCfg = GraspAnnotatorCfg(), + ) -> None: + """Initialize the GraspAnnotator with the given mesh vertices, triangles, and configuration. + Args: + vertices (torch.Tensor): A tensor of shape (V, 3) representing the vertex positions of the mesh. + triangles (torch.Tensor): A tensor of shape (F, 3) representing the triangle indices of the mesh. + cfg (GraspAnnotatorCfg, optional): Configuration for the grasp annotator. Defaults to GraspAnnotatorCfg(). + """ + self.device = vertices.device + self.vertices = vertices + self.triangles = triangles + self.mesh = trimesh.Trimesh( + vertices=vertices.to("cpu").numpy(), + faces=triangles.to("cpu").numpy(), + process=False, + force="mesh", + ) + self._collision_checker = SimpleGripperCollisionChecker( + object_mesh_verts=vertices, + object_mesh_faces=triangles, + cfg=SimpleGripperCollisionCfg(), + ) + self.cfg = cfg + self.antipodal_sampler = AntipodalSampler(cfg=cfg.antipodal_sampler_cfg) + + def annotate(self) -> torch.Tensor: + """Annotate antipodal grasp region on the mesh and return sampled antipodal point pairs. + Returns: + torch.Tensor: A tensor of shape (N, 2, 3) representing N antipodal point pairs. Each pair consists of a hit point and its corresponding surface point. + """ + cache_path = self._get_cache_dir(self.vertices, self.triangles) + if os.path.exists(cache_path) and not self.cfg.force_regenerate: + logger.log_info( + f"Found existing antipodal retult. Loading cached antipodal pairs from {cache_path}" + ) + hit_point_pairs = torch.tensor( + np.load(cache_path), dtype=torch.float32, device=self.device + ) + return hit_point_pairs + else: + logger.log_info( + f"[Viser] *****Annotate grasp region in http://localhost:{self.cfg.viser_port}" + ) + + server = viser.ViserServer(port=self.cfg.viser_port) + server.gui.configure_theme(brand_color=(130, 0, 150)) + server.scene.set_up_direction("+z") + + mesh_handle = server.scene.add_mesh_trimesh(name="/mesh", mesh=self.mesh) + selected_overlay: viser.GlbHandle | None = None + selection: SelectResult = SelectResult() + + hit_point_pairs = None + return_flag = False + + @server.on_client_connect + def _(client: viser.ClientHandle) -> None: + nonlocal mesh_handle + nonlocal selected_overlay + nonlocal selection + + # client.camera.position = np.array([0.0, 0.0, -0.5]) + # client.camera.wxyz = np.array([1.0, 0.0, 0.0, 0.0]) + + select_button = client.gui.add_button( + "Rect Select Region", icon=viser.Icon.PAINT + ) + confirm_button = client.gui.add_button("Confirm Selection") + + @select_button.on_click + def _(_evt: viser.GuiEvent) -> None: + select_button.disabled = True + + @client.scene.on_pointer_event(event_type="rect-select") + def _(event: viser.ScenePointerEvent) -> None: + nonlocal mesh_handle + nonlocal selected_overlay + nonlocal selection + nonlocal hit_point_pairs + client.scene.remove_pointer_callback() + + proj, depth = GraspAnnotator._project_vertices_to_screen( + cast(np.ndarray, self.mesh.vertices), + mesh_handle, + event.client.camera, + ) + + lower = np.minimum( + np.array(event.screen_pos[0]), np.array(event.screen_pos[1]) + ) + upper = np.maximum( + np.array(event.screen_pos[0]), np.array(event.screen_pos[1]) + ) + vertex_mask = ((proj >= lower) & (proj <= upper)).all(axis=1) & ( + depth > 1e-6 + ) + + selection = GraspAnnotator._extract_selection( + self.mesh, vertex_mask, self.cfg.use_largest_connected_component + ) + if selection.vertices is None: + logger.log_warning("[Selection] No vertices selected.") + return + + color_mesh = self.mesh.copy() + used_vertex_indices = selection.vertex_indices + vertex_colors = np.tile( + np.array([[0.85, 0.85, 0.85, 1.0]]), + (self.mesh.vertices.shape[0], 1), + ) + vertex_colors[used_vertex_indices] = np.array( + [0.56, 0.17, 0.92, 1.0] + ) + color_mesh.visual.vertex_colors = vertex_colors # type: ignore + mesh_handle = server.scene.add_mesh_trimesh( + name="/mesh", mesh=color_mesh + ) + + if selected_overlay is not None: + selected_overlay.remove() + selected_mesh = trimesh.Trimesh( + vertices=selection.vertices, + faces=selection.faces, + process=False, + ) + selected_mesh.visual.face_colors = (0.9, 0.2, 0.2, 0.65) # type: ignore + selected_overlay = server.scene.add_mesh_trimesh( + name="/selected", mesh=selected_mesh + ) + logger.log_info( + f"[Selection] Selected {selection.vertex_indices.size} vertices and {selection.face_indices.size} faces." + ) + + hit_point_pairs = self.antipodal_sampler.sample( + torch.tensor(selection.vertices, device=self.device), + torch.tensor(selection.faces, device=self.device), + ) + extended_hit_point_pairs = GraspAnnotator._extend_hit_point_pairs( + hit_point_pairs + ) + server.scene.add_line_segments( + name="/antipodal_pairs", + points=extended_hit_point_pairs.to("cpu").numpy(), + colors=(20, 200, 200), + line_width=1.5, + ) + + @client.scene.on_pointer_callback_removed + def _() -> None: + select_button.disabled = False + + @confirm_button.on_click + def _(_evt: viser.GuiEvent) -> None: + nonlocal return_flag + if selection.vertices is None: + logger.log_warning("[Selection] No vertex selected.") + return + else: + logger.log_info( + f"[Selection] {selection.vertices.shape[0]}vertices selected. Generating antipodal point pairs." + ) + return_flag = True + + while True: + if return_flag: + # save result to cache + if hit_point_pairs is not None: + self._save_cache(cache_path, hit_point_pairs) + break + time.sleep(0.5) + return hit_point_pairs + + def _get_cache_dir(self, vertices: torch.Tensor, triangles: torch.Tensor): + vert_bytes = vertices.to("cpu").numpy().tobytes() + face_bytes = triangles.to("cpu").numpy().tobytes() + md5_hash = hashlib.md5(vert_bytes + face_bytes).hexdigest() + cache_path = os.path.join( + tempfile.gettempdir(), f"antipodal_cache_{md5_hash}.npy" + ) + return cache_path + + def _save_cache(self, cache_path: str, hit_point_pairs: torch.Tensor): + np.save(cache_path, hit_point_pairs.cpu().numpy().astype(np.float32)) + + @staticmethod + def _extend_hit_point_pairs(hit_point_pairs: torch.Tensor): + origin_points = hit_point_pairs[:, 0, :] + hit_points = hit_point_pairs[:, 1, :] + mid_points = (origin_points + hit_points) / 2 + point_diff = hit_points - origin_points + extended_origin = mid_points - 0.8 * point_diff + extended_hit = mid_points + 0.8 * point_diff + extended_point_pairs = torch.cat( + [extended_origin[:, None, :], extended_hit[:, None, :]], dim=1 + ) + return extended_point_pairs + + @staticmethod + def _project_vertices_to_screen( + vertices_mesh: np.ndarray, + mesh_handle: viser.GlbHandle, + camera: Any, + ) -> tuple[np.ndarray, np.ndarray]: + T_world_mesh = tf.SE3.from_rotation_and_translation( + tf.SO3(np.asarray(mesh_handle.wxyz)), + np.asarray(mesh_handle.position), + ) + vertices_world_h = ( + T_world_mesh.as_matrix() + @ np.hstack([vertices_mesh, np.ones((vertices_mesh.shape[0], 1))]).T + ).T + vertices_world = vertices_world_h[:, :3] + + T_camera_world = tf.SE3.from_rotation_and_translation( + tf.SO3(np.asarray(camera.wxyz)), + np.asarray(camera.position), + ).inverse() + vertices_camera_h = ( + T_camera_world.as_matrix() + @ np.hstack([vertices_world, np.ones((vertices_world.shape[0], 1))]).T + ).T + vertices_camera = vertices_camera_h[:, :3] + + fov = float(camera.fov) + aspect = float(camera.aspect) + projected = vertices_camera[:, :2] / np.maximum(vertices_camera[:, 2:3], 1e-8) + projected /= np.tan(fov / 2.0) + projected[:, 0] /= aspect + projected = (1.0 + projected) / 2.0 + return projected, vertices_camera[:, 2] + + def _extract_selection( + mesh: trimesh.Trimesh, + vertex_mask: np.ndarray, + largest_component: bool, + ) -> SelectResult: + def _largest_connected_face_component(face_ids: np.ndarray) -> np.ndarray: + if face_ids.size <= 1: + return face_ids + + face_id_set = set(face_ids.tolist()) + parent: dict[int, int] = { + int(face_id): int(face_id) for face_id in face_ids + } + + def find(x: int) -> int: + root = x + while parent[root] != root: + root = parent[root] + while parent[x] != x: + x_parent = parent[x] + parent[x] = root + x = x_parent + return root + + def union(a: int, b: int) -> None: + ra, rb = find(a), find(b) + if ra != rb: + parent[rb] = ra + + face_adjacency = cast(np.ndarray, mesh.face_adjacency) + for face_a, face_b in face_adjacency: + if int(face_a) in face_id_set and int(face_b) in face_id_set: + union(int(face_a), int(face_b)) + + groups: dict[int, list[int]] = {} + for face_id in face_ids: + root = find(int(face_id)) + groups.setdefault(root, []).append(int(face_id)) + + largest_group = max(groups.values(), key=len) + return np.array(largest_group, dtype=np.int32) + + faces = cast(np.ndarray, mesh.faces) + face_mask = np.all(vertex_mask[faces], axis=1) + + face_indices = np.flatnonzero(face_mask) + if face_indices.size == 0: + return SelectResult() + if largest_component: + face_indices = _largest_connected_face_component(face_indices) + if face_indices.size == 0: + return SelectResult() + + selected_face_vertices = faces[face_indices] + vertex_indices = np.unique(selected_face_vertices.reshape(-1)) + + old_to_new = np.full(mesh.vertices.shape[0], -1, dtype=np.int32) + old_to_new[vertex_indices] = np.arange(vertex_indices.size, dtype=np.int32) + + sub_vertices = np.asarray(mesh.vertices)[vertex_indices] + sub_faces = np.asarray(old_to_new)[selected_face_vertices] + + return SelectResult( + vertex_indices=vertex_indices, + face_indices=face_indices, + vertices=sub_vertices, + faces=sub_faces, + ) + + @staticmethod + def _apply_transform(points: torch.Tensor, transform: torch.Tensor) -> torch.Tensor: + r = transform[:3, :3] + t = transform[:3, 3] + return points @ r.T + t + + def get_grasp_poses( + self, + hit_point_pairs: torch.Tensor, + object_pose: torch.Tensor, + approach_direction: torch.Tensor, + is_visual: bool = False, + ) -> torch.Tensor: + """Get grasp pose given approach direction + + Args: + hit_point_pairs (torch.Tensor): (N, 2, 3) tensor of N antipodal point pairs. Each pair consists of a hit point and its corresponding surface point. + object_pose (torch.Tensor): (4, 4) homogeneous transformation matrix representing the pose of the object in the world frame. + approach_direction (torch.Tensor): (3,) unit vector representing the desired approach direction of the gripper in the world frame. + + Returns: + torch.Tensor: (4, 4) homogeneous transformation matrix representing the grasp pose in the world frame that aligns the gripper's approach direction with the given approach_direction. Returns None if no valid grasp pose can be found. + """ + origin_points = hit_point_pairs[:, 0, :] + hit_points = hit_point_pairs[:, 1, :] + origin_points_ = self._apply_transform(origin_points, object_pose) + hit_points_ = self._apply_transform(hit_points, object_pose) + centers = (origin_points_ + hit_points_) / 2 + + mesh_vert_transformed = self._apply_transform(self.vertices, object_pose) + mesh_center = mesh_vert_transformed.mean(dim=0) + + # filter perpendicular antipodal point + grasp_x = F.normalize(hit_points_ - origin_points_, dim=-1) + cos_angle = torch.clamp((grasp_x * approach_direction).sum(dim=-1), -1.0, 1.0) + positive_angle = torch.abs(torch.acos(cos_angle)) + valid_mask = ( + positive_angle - torch.pi / 2 + ).abs() <= self.cfg.max_deviation_angle + valid_grasp_x = grasp_x[valid_mask] + valid_centers = centers[valid_mask] + + # compute grasp poses using antipodal point pairs and approach direction + valid_grasp_poses = GraspAnnotator._grasp_pose_from_approach_direction( + valid_grasp_x, approach_direction, valid_centers + ) + valid_open_lengths = torch.norm( + origin_points_[valid_mask] - hit_points_[valid_mask], dim=-1 + ) + # select non-collide grasp poses + is_colliding, max_penetration = self._collision_checker.query( + object_pose, + valid_grasp_poses, + valid_open_lengths, + is_visual=is_visual, + collision_threshold=0.0, + ) + # get best grasp pose + valid_grasp_poses = valid_grasp_poses[~is_colliding] + valid_open_lengths = valid_open_lengths[~is_colliding] + valid_centers = valid_centers[~is_colliding] + valid_grasp_x = F.normalize(valid_grasp_poses[:, :3, 0], dim=-1) + + cos_angle = torch.clamp( + (valid_grasp_x * approach_direction).sum(dim=-1), -1.0, 1.0 + ) + positive_angle = torch.abs(torch.acos(cos_angle)) + angle_cost = torch.abs(positive_angle - 0.5 * torch.pi) / (0.5 * torch.pi) + center_distance = torch.norm(valid_centers - mesh_center, dim=-1) + center_cost = center_distance / center_distance.max() + length_cost = 1 - valid_open_lengths / valid_open_lengths.max() + total_cost = 0.3 * angle_cost + 0.3 * length_cost + 0.4 * center_cost + best_idx = torch.argmin(total_cost) + best_grasp_pose = valid_grasp_poses[best_idx] + best_open_length = valid_open_lengths[best_idx] + return best_grasp_pose, best_open_length + + @staticmethod + def _grasp_pose_from_approach_direction( + grasp_x: torch.Tensor, approach_direction: torch.Tensor, center: torch.Tensor + ): + approach_direction_repeat = approach_direction[None, :].repeat( + grasp_x.shape[0], 1 + ) + grasp_y = torch.cross(approach_direction_repeat, grasp_x, dim=-1) + grasp_y = F.normalize(grasp_y, dim=-1) + grasp_z = torch.cross(grasp_x, grasp_y, dim=-1) + grasp_z = F.normalize(grasp_z, dim=-1) + grasp_poses = ( + torch.eye(4, device=grasp_x.device, dtype=torch.float32) + .unsqueeze(0) + .repeat(grasp_x.shape[0], 1, 1) + ) + grasp_poses[:, :3, 0] = grasp_x + grasp_poses[:, :3, 1] = grasp_y + grasp_poses[:, :3, 2] = grasp_z + grasp_poses[:, :3, 3] = center + return grasp_poses + + def visualize_grasp_pose( + self, + obj_pose: torch.Tensor, + grasp_pose: torch.Tensor, + open_length: float, + ): + mesh = o3d.geometry.TriangleMesh( + vertices=o3d.utility.Vector3dVector(self.vertices.to("cpu").numpy()), + triangles=o3d.utility.Vector3iVector(self.triangles.to("cpu").numpy()), + ) + mesh.compute_vertex_normals() + mesh.paint_uniform_color([0.3, 0.6, 0.3]) + mesh.transform(obj_pose.to("cpu").numpy()) + vertices_ = torch.tensor( + np.asarray(mesh.vertices), + device=self.vertices.device, + dtype=self.vertices.dtype, + ) + mesh_scale = (vertices_.max(dim=0)[0] - vertices_.min(dim=0)[0]).max().item() + groud_plane = o3d.geometry.TriangleMesh.create_cylinder( + radius=mesh_scale, height=0.01 * mesh_scale + ) + groud_plane.compute_vertex_normals() + center = vertices_.mean(dim=0) + z_sim = vertices_.min(dim=0)[0][2].item() + groud_plane.translate( + (center[0].item(), center[1].item(), z_sim - 0.005 * mesh_scale) + ) + + draw_thickness = 0.02 * mesh_scale + draw_length = 0.3 * mesh_scale + grasp_finger1 = o3d.geometry.TriangleMesh.create_box( + draw_thickness, draw_thickness, draw_length + ) + grasp_finger1.translate( + (-0.5 * draw_thickness, -0.5 * draw_thickness, -0.5 * draw_length) + ) + grasp_finger2 = o3d.geometry.TriangleMesh.create_box( + draw_thickness, draw_thickness, draw_length + ) + grasp_finger2.translate( + (-0.5 * draw_thickness, -0.5 * draw_thickness, -0.5 * draw_length) + ) + grasp_finger1.translate((-open_length / 2, 0, -0.25 * draw_length)) + grasp_finger2.translate((open_length / 2, 0, -0.25 * draw_length)) + grasp_root1 = o3d.geometry.TriangleMesh.create_box( + open_length, draw_thickness, draw_thickness + ) + grasp_root1.translate( + (-open_length / 2, -0.5 * draw_thickness, -0.5 * draw_thickness) + ) + grasp_root1.translate((0, 0, -0.75 * draw_length)) + grasp_root2 = o3d.geometry.TriangleMesh.create_box( + draw_thickness, draw_thickness, draw_length + ) + grasp_root2.translate( + (-0.5 * draw_thickness, -0.5 * draw_thickness, -0.5 * draw_length) + ) + grasp_root2.translate((0, 0, -1.25 * draw_length)) + + grasp_visual = grasp_finger1 + grasp_finger2 + grasp_root1 + grasp_root2 + grasp_visual.paint_uniform_color([0.8, 0.2, 0.8]) + grasp_visual.transform(grasp_pose.to("cpu").numpy()) + o3d.visualization.draw_geometries( + [grasp_visual, mesh, groud_plane], + window_name="Grasp Pose Visualization", + mesh_show_back_face=True, + ) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Viser mesh 标注工具:框选并导出对应顶点与三角面" + ) + parser.add_argument( + "--mesh", type=Path, required=True, help="输入 mesh 文件路径,例如 mug.obj" + ) + parser.add_argument("--scale", type=float, default=1.0, help="加载后整体缩放系数") + parser.add_argument("--port", type=int, default=12151, help="viser 服务端口") + parser.add_argument( + "--output-dir", + type=Path, + default=Path("outputs/mesh_annotations"), + help="标注结果导出目录", + ) + parser.add_argument( + "--largest-component", + action="store_true", + help="只保留框选结果中的最大连通块(常用于稳定提取把手等局部)", + ) + args = parser.parse_args() + + mesh = trimesh.load(args.mesh, process=False, force="mesh") + vertices = mesh.vertices * args.scale + triangles = mesh.faces + cfg = GraspAnnotatorCfg( + force_regenerate=True, + ) + tool = GraspAnnotator(cfg=cfg) + hit_point_pairs = tool.annotate( + vertices=torch.from_numpy(vertices).float(), + triangles=torch.from_numpy(triangles).long(), + ) + logger.log_info(f"Sample {hit_point_pairs.shape[0]} antipodal point pairs.") + + +if __name__ == "__main__": + main() diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py new file mode 100644 index 00000000..cebcafde --- /dev/null +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal_sampler.py @@ -0,0 +1,253 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch +import torch.nn.functional as F +import numpy as np +import open3d as o3d +import open3d.core as o3c +from dataclasses import dataclass +from embodichain.utils import logger + + +@dataclass +class AntipodalSamplerCfg: + """ Configuration for AntipodalSampler.""" + n_sample: int = 20000 + """surface point sample number""" + max_angle: float = np.pi / 12 + """maximum angle (in radians) to randomly disturb the ray direction for antipodal point sampling, used to increase the diversity of sampled antipodal points. Note that setting max_angle to 0 will disable the random disturbance and sample antipodal points strictly along the surface normals, which may result in less diverse antipodal points and may not be ideal for all objects or grasping scenarios.""" + max_length: float = 0.1 + """maximum gripper open width, used to filter out antipodal points that are too far apart to be grasped""" + min_length: float = 0.001 + """minimum gripper open width, used to filter out antipodal points that are too close to be grasped""" + + +class AntipodalSampler: + """ AntipodalSampler samples antipodal point pairs on a given mesh. It uses Open3D's raycasting functionality to find points on the mesh that are visible along the negative normal direction from uniformly sampled points on the mesh surface. The sampler can also apply a random disturbance to the ray direction to increase the diversity of sampled antipodal points. The resulting antipodal point pairs can be used for grasp generation and annotation tasks.""" + def __init__( + self, + cfg: AntipodalSamplerCfg = AntipodalSamplerCfg(), + ): + self.mesh: o3d.t.geometry.TriangleMesh | None = None + self.cfg = cfg + + def sample(self, vertices: torch.Tensor, faces: torch.Tensor) -> torch.Tensor: + """Get sample Antipodal point pair + + Args: + vertices: [V, 3] vertex positions of the mesh + faces: [F, 3] triangle indices of the mesh + + Returns: + hit_point_pairs: [N, 2, 3] tensor of N antipodal point pairs. Each pair consists of a hit point and its corresponding surface point. + """ + # update mesh + self.mesh = o3d.t.geometry.TriangleMesh() + self.mesh.vertex.positions = o3c.Tensor( + vertices.to("cpu").numpy(), dtype=o3c.float32 + ) + self.mesh.triangle.indices = o3c.Tensor( + faces.to("cpu").numpy(), dtype=o3c.int32 + ) + self.mesh.compute_vertex_normals() + # sample points and normals + sample_pcd = self.mesh.sample_points_uniformly( + number_of_points=self.cfg.n_sample + ) + sample_points = torch.tensor( + sample_pcd.point.positions.numpy(), + device=vertices.device, + dtype=vertices.dtype, + ) + sample_normals = torch.tensor( + sample_pcd.point.normals.numpy(), + device=vertices.device, + dtype=vertices.dtype, + ) + # generate rays + ray_direc = -sample_normals + ray_origin = ( + sample_points + 1e-3 * ray_direc + ) # Offset ray origin slightly along the normal to avoid self-intersection + disturb_direc = AntipodalSampler._random_rotate_unit_vectors( + ray_direc, max_angle=self.cfg.max_angle + ) + ray_origin = torch.vstack([ray_origin, ray_origin]) + ray_direc = torch.vstack([ray_direc, disturb_direc]) + # casting + return self._get_raycast_result( + ray_origin, + ray_direc, + surface_origin=torch.vstack([sample_points, sample_points]), + ) + + def _get_raycast_result( + self, + ray_origin: torch.Tensor, + ray_direc: torch.Tensor, + surface_origin: torch.Tensor, + ): + if ray_origin.ndim != 2 or ray_origin.shape[-1] != 3: + raise ValueError("ray_origin must have shape [N, 3]") + if ray_direc.ndim != 2 or ray_direc.shape[-1] != 3: + raise ValueError("ray_direc must have shape [N, 3]") + if ray_origin.shape[0] != ray_direc.shape[0]: + raise ValueError( + "ray_origin and ray_direc must have the same number of rays" + ) + if ray_origin.shape[0] != surface_origin.shape[0]: + raise ValueError( + "ray_origin and surface_origin must have the same number of rays" + ) + + scene = o3d.t.geometry.RaycastingScene() + scene.add_triangles(self.mesh) + + rays = torch.cat([ray_origin, ray_direc], dim=-1) + rays_o3d = o3c.Tensor(rays.detach().to("cpu").numpy(), dtype=o3c.float32) + + ans = scene.cast_rays(rays_o3d) + t_hit = torch.from_numpy(ans["t_hit"].numpy()).to( + device=ray_origin.device, dtype=ray_origin.dtype + ) + hit_mask = torch.logical_and( + t_hit > self.cfg.min_length, t_hit < self.cfg.max_length + ) + hit_points = ray_origin[hit_mask] + t_hit[hit_mask, None] * ray_direc[hit_mask] + hit_origins = surface_origin[hit_mask] + hit_point_pairs = torch.cat( + [hit_points[:, None, :], hit_origins[:, None, :]], dim=1 + ) + hit_point_pairs = hit_point_pairs.to(dtype=torch.float32) + return hit_point_pairs + + @staticmethod + def _random_rotate_unit_vectors( + vectors: torch.Tensor, + max_angle: float, + degrees: bool = False, + eps: float = 1e-8, + ) -> torch.Tensor: + """ + Apply random small rotations to a batch of unit vectors [N, 3]. + + Args: + vectors: [N, 3], unit vectors + max_angle: Maximum rotation angle + degrees: If True, `max_angle` is given in degrees + eps: Numerical stability constant + + Returns: + rotated: [N, 3], rotated unit vectors + """ + assert vectors.ndim == 2 and vectors.shape[-1] == 3, "vectors must be [N, 3]" + + v = F.normalize(vectors, dim=-1) + + if degrees: + max_angle = torch.deg2rad( + torch.tensor(max_angle, dtype=v.dtype, device=v.device) + ).item() + + n = v.shape[0] + + # 1) Generate a random direction for each vector + # then project it onto the plane perpendicular to v to get the rotation axis k + rand_dir = torch.randn_like(v) + eps + proj = (rand_dir * v).sum(dim=-1, keepdim=True) * v + k = rand_dir - proj + k = F.normalize(k, dim=-1) + + # 2) Sample rotation angles in the range [eps, max_angle] + theta = ( + torch.rand(n, 1, device=v.device, dtype=v.dtype) * (max_angle - eps) + eps + ) + + # 3) Rodrigues' rotation formula + # R(v) = v*cosθ + (k×v)*sinθ + k*(k·v)*(1-cosθ) + # Since k ⟂ v, the last term is theoretically 0, but keeping the general formula is more robust + cos_t = torch.cos(theta) + sin_t = torch.sin(theta) + + kv = (k * v).sum(dim=-1, keepdim=True) + rotated = v * cos_t + torch.cross(k, v, dim=-1) * sin_t + k * kv * (1.0 - cos_t) + + return F.normalize(rotated, dim=-1) + + def visualize(self, hit_point_pairs: torch.Tensor): + if self.mesh is None: + logger.log_warning("Mesh is not initialized. Cannot visualize.") + return + + if hit_point_pairs.shape[0] == 0: + raise ValueError("No point pairs to visualize") + origin_points = hit_point_pairs[:, 0, :] + hit_points = hit_point_pairs[:, 1, :] + + origin_points_np = origin_points.to("cpu").numpy() + hit_points_np = hit_points.detach().to("cpu").numpy() + + n_pairs = hit_point_pairs.shape[0] + line_indices = np.stack( + [np.arange(n_pairs), np.arange(n_pairs) + n_pairs], axis=1 + ) + + mesh_legacy = self.mesh.to_legacy() + mesh_legacy.compute_vertex_normals() + mesh_legacy.paint_uniform_color([0.8, 0.8, 0.8]) + + origin_pcd = o3d.geometry.PointCloud() + origin_pcd.points = o3d.utility.Vector3dVector(origin_points_np) + origin_pcd.colors = o3d.utility.Vector3dVector( + np.tile(np.array([[0.1, 0.4, 1.0]]), (n_pairs, 1)) + ) + + hit_pcd = o3d.geometry.PointCloud() + hit_pcd.points = o3d.utility.Vector3dVector(hit_points_np) + hit_pcd.colors = o3d.utility.Vector3dVector( + np.tile(np.array([[1.0, 0.2, 0.2]]), (n_pairs, 1)) + ) + + line_set = o3d.geometry.LineSet() + mid_points = (origin_points_np + hit_points_np) / 2 + point_diff = hit_points_np - origin_points_np + draw_origin = mid_points - 0.6 * point_diff + draw_end = mid_points + 0.6 * point_diff + draw_pointpair = np.concatenate([draw_origin, draw_end], axis=0) + line_set.points = o3d.utility.Vector3dVector(draw_pointpair) + line_set.lines = o3d.utility.Vector2iVector(line_indices) + line_set.colors = o3d.utility.Vector3dVector( + np.tile(np.array([[0.2, 0.9, 0.2]]), (n_pairs, 1)) + ) + + o3d.visualization.draw_geometries( + [mesh_legacy, origin_pcd, hit_pcd, line_set], + window_name="Antipodal Point Pairs", + mesh_show_back_face=True, + ) + + +if __name__ == "__main__": + mesh_path = "/media/chenjian/_abc/project/grasp_annotator/dustpan_saa.ply" + mesh = o3d.t.io.read_triangle_mesh(mesh_path) + vertices = torch.from_numpy(mesh.vertex.positions.cpu().numpy()) + faces = torch.from_numpy(mesh.triangle.indices.cpu().numpy()) + + sampler = AntipodalSampler() + hit_point_pairs = sampler.sample(vertices, faces) + sampler.visualize(hit_point_pairs) + print(f"Sampled {hit_point_pairs.shape[0]} antipodal points") diff --git a/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py new file mode 100644 index 00000000..a488108b --- /dev/null +++ b/embodichain/toolkits/graspkit/pg_grasp/batch_collision_checker.py @@ -0,0 +1,472 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import trimesh +import numpy as np +import torch +import time +from typing import List, Tuple, Union +from dexsim.kit.meshproc import convex_decomposition_coacd +import hashlib +from dataclasses import dataclass +import os +import pickle +import open3d as o3d +from embodichain.utils import logger +from embodichain.utils.warp import convex_signed_distance_kernel +import warp as wp +from embodichain.utils.device_utils import standardize_device_string + +CONVEX_CACHE_DIR = os.path.join( + os.path.expanduser("~"), ".cache", "embodichain_cache", "convex_decomposition" +) + + +@dataclass +class BatchConvexCollisionCheckerCfg: + """ Configuration for BatchConvexCollisionChecker.""" + + collsion_threshold: float = 0.0 + """ Collision threshold in meters. A point is considered colliding if its signed distance to the hull interior is <= this threshold. This allows for a margin of error in collision checking, where a small positive threshold can be used to consider points near the surface as colliding, and a small negative threshold can be used to allow for slight penetration without considering it a collision.""" + n_query_mesh_samples: int = 4096 + """ Number of points to sample from the query mesh surface for collision checking. A higher number of samples can provide a more accurate collision check at the cost of increased computation time. The optimal number may depend on the complexity of the mesh and the required precision of collision detection.""" + debug: bool = False + """ Whether to visualize the collision checking results for debugging purposes. If set to True, the code will generate visualizations of the query points colored by their collision status (e.g., red for colliding points and green for non-colliding points) along with the original mesh. This can help in understanding and verifying the collision checking process, especially during development and testing.""" + + +class BatchConvexCollisionChecker: + """ BatchConvexCollisionChecker performs efficient collision checking between a batch of query point clouds and a convex decomposition of a mesh. The convex decomposition is represented by plane equations of the convex hulls, which are precomputed and cached for efficiency. The collision checking is done by computing the signed distance from each query point to the convex hulls using the plane equations, and determining if any points are colliding based on a specified collision threshold. This class can be used""" + + def __init__( + self, + base_mesh_verts: torch.Tensor, + base_mesh_faces: torch.Tensor, + max_decomposition_hulls: int = 32, + ): + """ Initialize the BatchConvexCollisionChecker by performing convex decomposition on the input mesh and extracting plane equations for the convex hulls. The plane equations are cached to disk to avoid redundant computation in future runs. + Args: + base_mesh_verts: [N, 3] vertex positions of the input mesh. + base_mesh_faces: [M, 3] triangle indices of the input mesh. + max_decomposition_hulls: maximum number of convex hulls to decompose into. A higher number allows for a more accurate approximation of the original mesh but increases computation time and memory usage. The optimal number may depend on the complexity of the mesh and the required precision of collision checking. + """ + if not os.path.isdir(CONVEX_CACHE_DIR): + os.makedirs(CONVEX_CACHE_DIR, exist_ok=True) + self.device = base_mesh_verts.device + base_mesh_verts_np = base_mesh_verts.cpu().numpy() + base_mesh_faces_np = base_mesh_faces.cpu().numpy() + mesh_hash = hashlib.md5( + (base_mesh_verts_np.tobytes() + base_mesh_faces_np.tobytes()) + ).hexdigest() + + # for visualization + self.mesh = o3d.geometry.TriangleMesh( + vertices=o3d.utility.Vector3dVector(base_mesh_verts_np), + triangles=o3d.utility.Vector3iVector(base_mesh_faces_np), + ) + self.mesh.compute_vertex_normals() + + self.cache_path = os.path.join( + CONVEX_CACHE_DIR, f"{mesh_hash}_{max_decomposition_hulls}.pkl" + ) + + if not os.path.isfile(self.cache_path): + # [n_convex, n_max_faces, 4]: plane equations, normals(3) and offsets(1), padded with zeros if a hull has less than n_max_faces + # [n_convex, ]: number of faces for each convex hull + + # generate convex hulls and extract plane equations, then cache to disk + plane_equations_np = BatchConvexCollisionChecker._compute_plane_equations( + base_mesh_verts_np, base_mesh_faces_np, max_decomposition_hulls + ) + # pack as a single tensor + n_convex = len(plane_equations_np) + n_max_equation = max(len(normals) for normals, _ in plane_equations_np) + plane_equations = torch.zeros( + size=(n_convex, n_max_equation, 4), + dtype=torch.float32, + device=self.device, + ) + plane_equations_counts = torch.zeros( + n_convex, dtype=torch.int32, device=self.device + ) + for i in range(n_convex): + n_equation = plane_equations_np[i][0].shape[0] + # plane normals + plane_equations[i, :n_equation, :3] = torch.tensor( + plane_equations_np[i][0], device=self.device + ) + # plane offsets + plane_equations[i, :n_equation, 3] = torch.tensor( + plane_equations_np[i][1], device=self.device + ) + plane_equations_counts[i] = n_equation + self.plane_equations = { + "plane_equations": plane_equations, + "plane_equation_counts": plane_equations_counts, + } + pickle.dump(self.plane_equations, open(self.cache_path, "wb")) + else: + self.plane_equations = pickle.load(open(self.cache_path, "rb")) + self.plane_equations["plane_equations"] = self.plane_equations[ + "plane_equations" + ].to(self.device) + self.plane_equations["plane_equation_counts"] = self.plane_equations[ + "plane_equation_counts" + ].to(self.device) + + @staticmethod + def batch_point_convex_query( + plane_equations: torch.Tensor, + plane_equation_counts: torch.Tensor, + batch_points: torch.Tensor, + device: torch.device, + collision_threshold: float = -0.003, + ): + plane_equations_wp = wp.from_torch(plane_equations) + plane_equation_counts_wp = wp.from_torch(plane_equation_counts) + batch_points_wp = wp.from_torch(batch_points) + + n_pose = batch_points.shape[0] + n_point = batch_points.shape[1] + n_convex = plane_equations.shape[0] + point_convex_signed_distance_wp = wp.full( + shape=(n_pose, n_point, n_convex), + value=-float("inf"), + dtype=float, + device=standardize_device_string(device), + ) # [n_pose, n_point, n_convex] + wp.launch( + kernel=convex_signed_distance_kernel, + dim=(n_pose, n_point, n_convex), + inputs=(batch_points_wp, plane_equations_wp, plane_equation_counts_wp), + outputs=(point_convex_signed_distance_wp,), + device=standardize_device_string(device), + ) + point_convex_signed_distance = wp.to_torch(point_convex_signed_distance_wp) + # import ipdb; ipdb.set_trace() + point_signed_distance = point_convex_signed_distance.min( + dim=-1 + ).values # [n_pose, n_point] + is_point_collide = point_signed_distance <= collision_threshold + return point_signed_distance, is_point_collide + + def query_batch_points( + self, + batch_points: torch.Tensor, + collision_threshold: float = 0.0, + is_visual: bool = False, + ) -> torch.Tensor: + """ Query collision status for a batch of point clouds. The collision status is determined by checking if the signed distance from any point in the cloud to the convex hulls is less than or equal to the specified collision threshold. + Args: + batch_points: [B, n_point, 3] batch of point clouds to query for collision status. + collision_threshold: Collision threshold in meters. A point is considered colliding if its signed distance to the hull interior is <= this threshold. This allows for a margin of error in collision checking, where a small positive threshold can be used to consider points near the surface as colliding, and a small negative threshold can be used to allow for slight penetration without considering it a collision. + is_visual: Whether to visualize the collision checking results for debugging purposes. If set to True, the code will generate visualizations of the query points colored by their collision status (e.g., red for colliding points and green for non-colliding points) along with the original mesh. This can help in understanding and verifying the collision checking process, especially during development and testing. + Returns: + is_pose_collide: [B, ] boolean tensor indicating whether each point cloud in the""" + n_batch = batch_points.shape[0] + point_signed_distance, is_point_collide = ( + BatchConvexCollisionChecker.batch_point_convex_query( + self.plane_equations["plane_equations"], + self.plane_equations["plane_equation_counts"], + batch_points, + device=self.device, + collision_threshold=collision_threshold, + ) + ) + is_pose_collide = is_point_collide.any(dim=-1) # [B] + pose_surface_distance = point_signed_distance.min(dim=-1).values # [B] + if is_visual: + # visualize result + frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1) + for i in range(n_batch): + query_points_o3d = o3d.geometry.PointCloud() + query_points_np = batch_points[i].cpu().numpy() + query_points_o3d.points = o3d.utility.Vector3dVector(query_points_np) + query_points_color = np.zeros_like(query_points_np) + query_points_color[is_point_collide[i].cpu().numpy()] = [ + 1.0, + 0, + 0, + ] # red for colliding points + query_points_color[~is_point_collide[i].cpu().numpy()] = [ + 0, + 1.0, + 0, + ] # green for non-colliding points + query_points_o3d.colors = o3d.utility.Vector3dVector(query_points_color) + o3d.visualization.draw_geometries( + [self.mesh, query_points_o3d, frame], mesh_show_back_face=True + ) + return is_pose_collide, pose_surface_distance + + def query( + self, + query_mesh_verts: torch.Tensor, + query_mesh_faces: torch.Tensor, + poses: torch.Tensor, + cfg: BatchConvexCollisionCheckerCfg = BatchConvexCollisionCheckerCfg(), + ) -> Tuple[torch.Tensor, torch.Tensor]: + query_mesh = trimesh.Trimesh( + vertices=query_mesh_verts.to("cpu").numpy(), + faces=query_mesh_faces.to("cpu").numpy(), + ) + n_query = cfg.n_query_mesh_samples + n_batch = poses.shape[0] + query_points_np = query_mesh.sample(n_query).astype(np.float32) + query_points = torch.tensor( + query_points_np, device=poses.device + ) # [n_query, 3] + penetration_result = torch.zeros(size=(n_batch, n_query), device=poses.device) + penetration_result.fill_(-float("inf")) + collision_result = torch.zeros( + size=(n_batch, n_query), dtype=torch.bool, device=poses.device + ) + collision_result.fill_(False) + for normals, offsets in self.plane_equations: + normals_torch = torch.tensor(normals, device=poses.device) + offsets_torch = torch.tensor(offsets, device=poses.device) + penetration, collides = check_collision_single_hull( + normals_torch, + offsets_torch, + transform_points_batch(query_points, poses), + cfg.collsion_threshold, + ) + penetration_result = torch.max(penetration_result, penetration) + collision_result = torch.logical_or(collision_result, collides) + is_colliding = collision_result.any(dim=-1) # [B] + max_penetration = penetration_result.max(dim=-1)[0] # [B] + + if cfg.debug: + # visualize result + for i in range(n_batch): + query_points_o3d = o3d.geometry.PointCloud() + query_points_o3d.points = o3d.utility.Vector3dVector(query_points_np) + query_points_o3d.transform(poses[i].to("cpu").numpy()) + query_points_color = np.zeros_like(query_points_np) + query_points_color[collision_result[i].cpu().numpy()] = [ + 1.0, + 0, + 0, + ] # red for colliding points + query_points_color[~collision_result[i].cpu().numpy()] = [ + 0, + 1.0, + 0, + ] # green for non-colliding points + query_points_o3d.colors = o3d.utility.Vector3dVector(query_points_color) + o3d.visualization.draw_geometries( + [self.mesh, query_points_o3d], mesh_show_back_face=True + ) + return is_colliding, max_penetration + + @staticmethod + def _compute_plane_equations( + vertices: np.ndarray, faces: np.ndarray, max_decomposition_hulls: int + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Convex decomposition and extract plane equations given mesh vertices and triangles. + Each convex hull is represented by its outward-facing face normals and offsets. + No padding is applied; each hull can have a different number of faces. + + Args: + vertices: [N, 3] vertex positions of the input mesh. + faces: [M, 3] triangle indices of the input mesh. + max_decomposition_hulls: maximum number of convex hulls to decompose into. + + Returns: + List of (normals_i [Ki, 3], offsets_i [Ki]) tuples, one per convex hull. + Ki is the number of faces of the i-th hull and can differ across hulls. + """ + mesh = o3d.t.geometry.TriangleMesh() + mesh.vertex.positions = o3d.core.Tensor(vertices, dtype=o3d.core.Dtype.Float32) + mesh.triangle.indices = o3d.core.Tensor(faces, dtype=o3d.core.Dtype.Int32) + is_success, out_mesh_list = convex_decomposition_coacd( + mesh, max_convex_hull_num=max_decomposition_hulls + ) + convex_vert_face_list = [] + for out_mesh in out_mesh_list: + verts = out_mesh.vertex.positions.numpy() + faces = out_mesh.triangle.indices.numpy() + convex_vert_face_list.append((verts, faces)) + return extract_plane_equations(convex_vert_face_list) + + +def extract_plane_equations( + convex_meshes: List[Tuple[np.ndarray, np.ndarray]], +) -> List[Tuple[np.ndarray, np.ndarray]]: + """ + Extract plane equations from a list of convex hull meshes. + Each convex hull is represented by its outward-facing face normals and offsets. + No padding is applied; each hull can have a different number of faces. + + Args: + convex_meshes: List of convex hull data. + - tuple of (vertices [N,3], faces [M,3]) + + Returns: + List of (normals_i [Ki, 3], offsets_i [Ki]) tuples, one per convex hull. + Ki is the number of faces of the i-th hull and can differ across hulls. + """ + convex_plane_data = [] + + for i, convex_mesh_data in enumerate(convex_meshes): + vertices, faces = convex_mesh_data + hull = trimesh.Trimesh( + vertices=vertices, + faces=faces, + ) + # Outward-facing face normals [Ki, 3] + face_normals = hull.face_normals + # One vertex per face to compute offset [Ki, 3] + face_origins = hull.triangles[:, 0, :] + # Plane equation: n · x + d = 0 => d = -(n · p) + offsets_i = -np.sum(face_normals * face_origins, axis=1) + + convex_plane_data.append( + (face_normals.astype(np.float32), offsets_i.astype(np.float32)) + ) + return convex_plane_data + + +def sample_surface_points(mesh_path: str, num_points: int = 4096) -> np.ndarray: + """ + Sample surface points from a mesh file. + + Args: + mesh_path: Path to the mesh file. + num_points: Number of surface points to sample. + + Returns: + points: [P, 3] numpy array of sampled surface points. + """ + mesh = trimesh.load(mesh_path, force="mesh") + points = mesh.sample(num_points) + return points.astype(np.float32) + + +def check_collision_single_hull( + normals: torch.Tensor, # [K, 3] + offsets: torch.Tensor, # [K] + transformed_points: torch.Tensor, # [B, P, 3] + threshold: float = 0.0, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Check collision between a batch of transformed point clouds and a single convex hull. + + A point p is inside the convex hull iff: + max_k (n_k · p + d_k) <= 0 + + Penetration depth for a point is defined as: + penetration = -(max_k (n_k · p + d_k)) + Positive penetration means the point is inside the hull. + + Args: + normals: [K, 3] outward face normals of the convex hull. + offsets: [K] plane offsets of the convex hull. + transformed_points: [B, P, 3] point cloud already transformed by batch poses. + threshold: collision threshold. A point is considered colliding if + its signed distance to the hull interior is <= threshold. + + Returns: + penetration: [B, P] penetration depth for each point. + Positive values indicate the point is inside the hull. + collides: [B, P] boolean mask, True if the point collides with this hull. + """ + # signed_dist: [B, P, K] = einsum([B,P,3], [K,3]) + [K] + signed_dist = torch.einsum("bpj, kj -> bpk", transformed_points, normals) + offsets + + # For each point, the maximum signed distance across all planes + # If max <= 0, the point satisfies all half-plane constraints => inside the hull + max_over_planes, _ = signed_dist.max(dim=-1) # [B, P] + + # Penetration depth: negate so that positive = inside + penetration = -max_over_planes # [B, P] + + # A point collides if its penetration exceeds the threshold + collides = penetration > threshold # [B, P] + + return penetration, collides + + +def transform_points_batch( + points: torch.Tensor, poses: torch.Tensor # [P, 3] # [B, 4, 4] +) -> torch.Tensor: + """ + Apply a batch of rigid transforms to a point cloud. + + Args: + points: [P, 3] source point cloud. + poses: [B, 4, 4] batch of homogeneous transformation matrices. + + Returns: + transformed: [B, P, 3] transformed point cloud for each pose. + """ + R = poses[:, :3, :3] # [B, 3, 3] + t = poses[:, :3, 3] # [B, 3] + transformed = torch.einsum("bij, pj -> bpi", R, points) + t.unsqueeze(1) + return transformed + + +if __name__ == "__main__": + from embodichain.data import get_data_path + + mug_path = get_data_path("CoffeeCup/cup.ply") + mug_path = get_data_path("ScannedBottle/moliwulong_processed.ply") + mug_mesh = trimesh.load(mug_path, force="mesh", process=False) + verts = torch.tensor(mug_mesh.vertices, dtype=torch.float32) + faces = torch.tensor(mug_mesh.faces, dtype=torch.int32) + collision_checker = BatchConvexCollisionChecker( + verts, faces, max_decomposition_hulls=16 + ) + + poses = torch.tensor( + [ + [ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0.05], + [0, 0, 0, 1], + ], + [ + [1, 0, 0, 0.05], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, 1], + ], + ] + ) + from scipy.spatial.transform import Rotation + + rot = Rotation.from_euler("xyz", [12, 3, 32], degrees=True).as_matrix() + poses[0, :3, :3] = torch.tensor(rot, dtype=torch.float32) + poses[1, :3, :3] = torch.tensor(rot, dtype=torch.float32) + + obj_path = get_data_path("ScannedBottle/yibao_processed.ply") + obj_mesh = trimesh.load(obj_path, force="mesh", process=False) + obj_verts = torch.tensor(obj_mesh.vertices, dtype=torch.float32) + obj_faces = torch.tensor(obj_mesh.faces, dtype=torch.int32) + test_pc = transform_points_batch(obj_verts, poses) + + collision_checker.query_batch_points( + test_pc, collision_threshold=0.003, is_visual=True + ) + collision_checker.query( + obj_verts, + obj_faces, + poses, + cfg=BatchConvexCollisionCheckerCfg( + debug=True, n_query_mesh_samples=32768, collsion_threshold=0.000 + ), + ) diff --git a/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py b/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py new file mode 100644 index 00000000..bacd6037 --- /dev/null +++ b/embodichain/toolkits/graspkit/pg_grasp/gripper_collision_checker.py @@ -0,0 +1,250 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Sequence +from .batch_collision_checker import BatchConvexCollisionChecker +import torch + + +@dataclass +class SimpleGripperCollisionCfg: + """ Configuration for the SimpleGripperCollisionChecker. This class defines various parameters related to the gripper geometry, point cloud generation, and collision checking process. Users can customize these parameters based on the specific gripper being modeled and the requirements of the application.""" + + max_open_length: float = 0.1 + """ Maximum opening length of the gripper fingers. This should be set according to the specific gripper being modeled, and it defines the maximum distance between the two fingers when fully open.""" + finger_length: float = 0.16 + """ Length of the gripper fingers from the root to the tip. This should be set according to the specific gripper being modeled, and it defines how far the fingers extend from the gripper root frame.""" + y_thickness: float = 0.03 + """ Thickness of the gripper along the Y-axis (the axis perpendicular to the finger opening direction). This should be set according to the specific gripper being modeled, and it defines the width of the gripper's main body and fingers in the Y direction.""" + x_thickness: float = 0.01 + """ Thickness of the gripper along the X-axis (the axis parallel to the finger opening direction). This should be set according to the specific gripper being modeled, and it defines the thickness of the fingers and the root in the X direction.""" + root_z_width: float = 0.06 + """ Width of the gripper root along the Z-axis (the axis along the finger length direction). This should be set according to the specific gripper being modeled, and it defines how far the root extends along the Z direction.""" + device = torch.device("cpu") + """ Device on which the gripper point cloud will be generated and processed. This should be set according to the computational resources available and the requirements of the application. For example, if using a GPU for collision checking, this should be set to torch.device('cuda'). """ + rough_dense: float = 0.015 + """ Approximate number of points per unit length for the gripper point cloud. Higher values will yield denser point clouds, which can improve collision checking accuracy but also increase computational cost. This should be set based on the desired balance between accuracy and efficiency for the specific application.""" + max_decomposition_hulls: int = 16 + """ Maximum number of convex hulls to decompose the object mesh into for collision checking. This should be set based on the complexity of the object geometry and the desired accuracy of collision checking. More hulls can provide a tighter approximation of the object shape but will increase computational cost.""" + open_check_margin: float = 0.01 + """ Additional margin added to the gripper open length when checking for collisions. This can help account for uncertainties in the gripper pose or object geometry, and can be set based on the specific requirements of the application.""" + + +class SimpleGripperCollisionChecker: + def __init__( + self, + object_mesh_verts: torch.Tensor, + object_mesh_faces: torch.Tensor, + cfg: SimpleGripperCollisionCfg = SimpleGripperCollisionCfg(), + ): + self._checker = BatchConvexCollisionChecker( + base_mesh_verts=object_mesh_verts, + base_mesh_faces=object_mesh_faces, + max_decomposition_hulls=cfg.max_decomposition_hulls, + ) + self.device = object_mesh_verts.device + self.cfg = cfg + self._init_pc_template() + + def _init_pc_template(self): + self.root_template = box_surface_grid( + size=( + self.cfg.max_open_length, + self.cfg.y_thickness, + self.cfg.root_z_width, + ), + dense=self.cfg.rough_dense, + device=self.device, + ) + self.left_template = box_surface_grid( + size=(self.cfg.x_thickness, self.cfg.y_thickness, self.cfg.finger_length), + dense=self.cfg.rough_dense, + device=self.device, + ) + self.right_template = box_surface_grid( + size=(self.cfg.x_thickness, self.cfg.y_thickness, self.cfg.finger_length), + dense=self.cfg.rough_dense, + device=self.device, + ) + + def _get_gripper_pc( + self, grasp_poses: torch.Tensor, open_lengths: torch.Tensor + ) -> torch.Tensor: + """ + Args: + grasp_poses: [B, 4, 4] homogeneous transformation matrix of the gripper root frame. + open_lengths: [B] opening length of the gripper fingers. + Returns: + gripper_pc: [B, P, 3] point cloud of the gripper in the world frame. + """ + + root_grasp_poses = grasp_poses.clone() + root_grasp_poses[:, :3, 3] -= ( + root_grasp_poses[:, :3, 2] + * 0.5 + * (self.cfg.finger_length + self.cfg.root_z_width) + ) + open_lengths_repeat = ( + open_lengths[:, None] + self.cfg.open_check_margin + ).repeat(1, 3) + left_finger_poses = grasp_poses.clone() + left_finger_poses[:, :3, 3] -= left_finger_poses[:, :3, 0] * open_lengths_repeat + + right_finger_poses = grasp_poses.clone() + right_finger_poses[:, :3, 3] += ( + right_finger_poses[:, :3, 0] * open_lengths_repeat + ) + + root_pc = transform_points_batch(self.root_template, root_grasp_poses) + left_pc = transform_points_batch(self.left_template, left_finger_poses) + right_pc = transform_points_batch(self.right_template, right_finger_poses) + gripper_pc = torch.cat([root_pc, left_pc, right_pc], dim=1) + return gripper_pc + + def query( + self, + obj_pose: torch.Tensor, + grasp_poses: torch.Tensor, + open_lengths: torch.Tensor, + collision_threshold: float = 0.0, + is_visual: bool = False, + ) -> torch.Tensor: + inv_obj_pose = obj_pose.clone() + inv_obj_pose[:3, :3] = obj_pose[:3, :3].T + inv_obj_pose[:3, 3] = -obj_pose[:3, 3] @ obj_pose[:3, :3] + inv_obj_poses = inv_obj_pose[None, :, :].repeat(grasp_poses.shape[0], 1, 1) + grasp_relative_pose = torch.bmm(inv_obj_poses, grasp_poses) + gripper_pc = self._get_gripper_pc(grasp_relative_pose, open_lengths) + return self._checker.query_batch_points( + gripper_pc, collision_threshold=collision_threshold, is_visual=is_visual + ) + + +def transform_points_batch( + points: torch.Tensor, poses: torch.Tensor # [P, 3] # [B, 4, 4] +) -> torch.Tensor: + """ + Apply a batch of rigid transforms to a point cloud. + + Args: + points: [P, 3] source point cloud. + poses: [B, 4, 4] batch of homogeneous transformation matrices. + + Returns: + transformed: [B, P, 3] transformed point cloud for each pose. + """ + R = poses[:, :3, :3] # [B, 3, 3] + t = poses[:, :3, 3] # [B, 3] + transformed = torch.einsum("bij, pj -> bpi", R, points) + t.unsqueeze(1) + return transformed + + +def box_surface_grid( + size: Sequence[float] | torch.Tensor, + dense: float, + device: torch.device | str = "cpu", +) -> torch.Tensor: + """Generate grid-sampled points on the surface of an axis-aligned box. + + Six faces of the box are each sampled independently on a regular 2-D grid. + Grid resolution per face is derived automatically from ``dense``: + the number of sample points along an edge of length *L* is + ``max(2, round(L * dense) + 1)``, so ``dense`` behaves as + *approximate samples per unit length*. + + Edge and corner points are shared across adjacent faces and are included + exactly once (no duplicates). + + Args: + size: Box dimensions ``(sx, sy, sz)``. Accepts a sequence of three + floats or a 1-D :class:`torch.Tensor` of length 3. + dense: Approximate number of grid sample points per unit length along + each edge. Higher values yield denser point clouds. + device: Target PyTorch device for the returned tensor. + + Returns: + Float tensor of shape ``(N, 3)`` containing surface points expressed + in the box's local frame (origin at the box centre). + + Example: + >>> pts = box_surface_grid((0.1, 0.06, 0.03), dense=200.0) + >>> pts.shape + torch.Size([..., 3]) + """ + if isinstance(size, torch.Tensor): + sx, sy, sz = size[0].item(), size[1].item(), size[2].item() + else: + sx, sy, sz = float(size[0]), float(size[1]), float(size[2]) + + hx, hy, hz = sx / 2.0, sy / 2.0, sz / 2.0 + + # ── grid resolution per axis (at least 2 points to span the full edge) ── + nx = max(2, round(sx / dense) + 1) + ny = max(2, round(sy / dense) + 1) + nz = max(2, round(sz / dense) + 1) + + xs = torch.linspace(-hx, hx, nx, device=device) + ys = torch.linspace(-hy, hy, ny, device=device) + zs = torch.linspace(-hz, hz, nz, device=device) + + # Interior slices (exclude first and last to avoid duplicate edges) + xs_inner = xs[1:-1] # length nx-2 + ys_inner = ys[1:-1] # length ny-2 + + def _grid( + u: torch.Tensor, v: torch.Tensor, axis: int, offset: float + ) -> torch.Tensor: + """Build a flat (M, 3) tensor for one face grid. + + Args: + u: 1-D tensor of coordinates along the first in-plane axis. + v: 1-D tensor of coordinates along the second in-plane axis. + axis: Normal axis of the face — 0 (±X), 1 (±Y), or 2 (±Z). + offset: Signed half-extent along ``axis``. + + Returns: + Tensor of shape ``(len(u) * len(v), 3)``. + """ + uu, vv = torch.meshgrid(u, v, indexing="ij") + uu = uu.reshape(-1) + vv = vv.reshape(-1) + cc = torch.full_like(uu, offset) + if axis == 0: + return torch.stack([cc, uu, vv], dim=-1) + elif axis == 1: + return torch.stack([uu, cc, vv], dim=-1) + else: + return torch.stack([uu, vv, cc], dim=-1) + + # ───────────────────────────────────────────────────────────────────────── + # Build 6 faces. To avoid duplicate points on shared edges/corners: + # ±X faces → full NY × NZ grids + # ±Y faces → (NX-2) × NZ grids (x-edges owned by ±X faces) + # ±Z faces → (NX-2) × (NY-2) grids (x- and y-edges owned above) + # ───────────────────────────────────────────────────────────────────────── + faces: list[torch.Tensor] = [ + _grid(ys, zs, axis=0, offset=-hx), # −X face (NY × NZ) + _grid(ys, zs, axis=0, offset=+hx), # +X face (NY × NZ) + _grid(xs_inner, zs, axis=1, offset=-hy), # −Y face ((NX-2) × NZ) + _grid(xs_inner, zs, axis=1, offset=+hy), # +Y face ((NX-2) × NZ) + _grid(xs_inner, ys_inner, axis=2, offset=-hz), # −Z face + _grid(xs_inner, ys_inner, axis=2, offset=+hz), # +Z face + ] + + return torch.cat(faces, dim=0) diff --git a/embodichain/utils/warp/__init__.py b/embodichain/utils/warp/__init__.py index 905bc9e7..e0fac57a 100644 --- a/embodichain/utils/warp/__init__.py +++ b/embodichain/utils/warp/__init__.py @@ -30,3 +30,5 @@ repeat_first_point, interpolate_along_distance, ) + +from .collision_checker.convex_query import convex_signed_distance_kernel diff --git a/embodichain/utils/warp/collision_checker/__init__.py b/embodichain/utils/warp/collision_checker/__init__.py new file mode 100644 index 00000000..d7e19801 --- /dev/null +++ b/embodichain/utils/warp/collision_checker/__init__.py @@ -0,0 +1,17 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from . import convex_query diff --git a/embodichain/utils/warp/collision_checker/convex_query.py b/embodichain/utils/warp/collision_checker/convex_query.py new file mode 100644 index 00000000..f321e462 --- /dev/null +++ b/embodichain/utils/warp/collision_checker/convex_query.py @@ -0,0 +1,55 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import warp as wp +from typing import Any + + +@wp.kernel(enable_backward=False) +def convex_signed_distance_kernel( + query_points: wp.array(dtype=wp.float32, ndim=3), + plane_equations: wp.array(dtype=wp.float32, ndim=3), + plane_equation_counts: wp.array(dtype=wp.int32, ndim=1), + signed_distances: wp.array(dtype=wp.float32, ndim=3), +): + """ + Compute the signed distance from query points to convex hulls defined by plane equations. + + Args: + query_points: [n_pose, n_point, 3] coordinates of query points. + plane_equations: [n_convex, n_max_equation, 4] plane equations of convex hulls, where each plane equation is represented as (normal_x, normal_y, normal_z, offset). + plane_equation_counts: [n_convex, ] number of valid plane equations for each convex hull. + + Returns: + signed_distances: [n_pose, n_point, n_convex] output signed distances from query points to convex hulls. Should be initialized as +inf before calling this kernel. + """ + pose_id, point_id, convex_id = wp.tid() + n_equation = plane_equation_counts[convex_id] + for i in range(n_equation): + normal_x = plane_equations[convex_id, i, 0] + normal_y = plane_equations[convex_id, i, 1] + normal_z = plane_equations[convex_id, i, 2] + offset = plane_equations[convex_id, i, 3] + signed_distance = ( + query_points[pose_id, point_id, 0] * normal_x + + query_points[pose_id, point_id, 1] * normal_y + + query_points[pose_id, point_id, 2] * normal_z + + offset + ) + # should initialize as -inf + signed_distances[pose_id, point_id, convex_id] = max( + signed_distance, signed_distances[pose_id, point_id, convex_id] + ) diff --git a/pyproject.toml b/pyproject.toml index 60a12496..25b15290 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,8 @@ dependencies = [ "black==24.3.0", "fvcore", "h5py", - "tensordict" + "tensordict", + "viser==1.0.21" ] [project.optional-dependencies] diff --git a/scripts/tutorials/grasp/grasp_generator.py b/scripts/tutorials/grasp/grasp_generator.py new file mode 100644 index 00000000..9f4450d0 --- /dev/null +++ b/scripts/tutorials/grasp/grasp_generator.py @@ -0,0 +1,268 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +""" +This script demonstrates the creation and simulation of a robot with a soft object, +and performs a pressing task in a simulated environment. +""" + +import argparse +import numpy as np +import time +import torch + +from dexsim.utility.path import get_resources_data_path + +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.objects import Robot, RigidObject +from embodichain.lab.sim.utility.action_utils import interpolate_with_distance +from embodichain.lab.sim.shapes import MeshCfg +from embodichain.lab.sim.solvers import PytorchSolverCfg +from embodichain.data import get_data_path +from embodichain.utils import logger +from embodichain.lab.sim.cfg import ( + JointDrivePropertiesCfg, + RobotCfg, + LightCfg, + RigidBodyAttributesCfg, + RigidObjectCfg, + URDFCfg, +) +from embodichain.lab.sim.shapes import MeshCfg +from embodichain.toolkits.graspkit.pg_grasp.antipodal_annotator import ( + GraspAnnotatorCfg, + AntipodalSamplerCfg, +) + + +def parse_arguments(): + """ + Parse command-line arguments to configure the simulation. + + Returns: + argparse.Namespace: Parsed arguments including number of environments and rendering options. + """ + parser = argparse.ArgumentParser( + description="Create and simulate a robot in SimulationManager" + ) + parser.add_argument( + "--num_envs", type=int, default=1, help="Number of parallel environments" + ) + parser.add_argument( + "--enable_rt", action="store_true", help="Enable ray tracing rendering" + ) + parser.add_argument("--headless", action="store_true", help="Enable headless mode") + parser.add_argument( + "--device", + type=str, + default="cuda", + help="device to run the environment on, e.g., 'cpu' or 'cuda'", + ) + return parser.parse_args() + + +def initialize_simulation(args) -> SimulationManager: + """ + Initialize the simulation environment based on the provided arguments. + + Args: + args (argparse.Namespace): Parsed command-line arguments. + + Returns: + SimulationManager: Configured simulation manager instance. + """ + config = SimulationManagerCfg( + headless=True, + sim_device=args.device, + enable_rt=args.enable_rt, + physics_dt=1.0 / 100.0, + num_envs=args.num_envs, + arena_space=2.5, + ) + sim = SimulationManager(config) + + if args.enable_rt: + light = sim.add_light( + cfg=LightCfg( + uid="main_light", + color=(0.6, 0.6, 0.6), + intensity=30.0, + init_pos=(1.0, 0, 3.0), + ) + ) + + return sim + + +def create_robot(sim: SimulationManager, position=[0.0, 0.0, 0.0]) -> Robot: + """ + Create and configure a robot with an arm and a dexterous hand in the simulation. + + Args: + sim (SimulationManager): The simulation manager instance. + + Returns: + Robot: The configured robot instance added to the simulation. + """ + # Retrieve URDF paths for the robot arm and hand + ur10_urdf_path = get_data_path("UniversalRobots/UR10/UR10.urdf") + gripper_urdf_path = get_data_path("DH_PGC_140_50_M/DH_PGC_140_50_M.urdf") + # Configure the robot with its components and control properties + cfg = RobotCfg( + uid="UR10", + urdf_cfg=URDFCfg( + components=[ + {"component_type": "arm", "urdf_path": ur10_urdf_path}, + {"component_type": "hand", "urdf_path": gripper_urdf_path}, + ] + ), + drive_pros=JointDrivePropertiesCfg( + stiffness={"JOINT[0-9]": 1e4, "FINGER[1-2]": 1e3}, + damping={"JOINT[0-9]": 1e3, "FINGER[1-2]": 1e2}, + max_effort={"JOINT[0-9]": 1e5, "FINGER[1-2]": 1e4}, + drive_type="force", + ), + control_parts={ + "arm": ["JOINT[0-9]"], + "hand": ["FINGER[1-2]"], + }, + solver_cfg={ + "arm": PytorchSolverCfg( + end_link_name="ee_link", + root_link_name="base_link", + tcp=[ + [0.0, 1.0, 0.0, 0.0], + [-1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.12], + [0.0, 0.0, 0.0, 1.0], + ], + ) + }, + init_qpos=[0.0, -np.pi / 2, -np.pi / 2, np.pi / 2, -np.pi / 2, 0.0, 0.0, 0.0], + init_pos=position, + ) + return sim.add_robot(cfg=cfg) + + +def create_mug(sim: SimulationManager): + mug_cfg = RigidObjectCfg( + uid="table", + shape=MeshCfg( + fpath=get_data_path("CoffeeCup/cup.ply"), + ), + attrs=RigidBodyAttributesCfg( + mass=0.01, + dynamic_friction=0.97, + static_friction=0.99, + ), + max_convex_hull_num=16, + init_pos=[0.55, 0.0, 0.01], + init_rot=[0.0, 0.0, -90], + body_scale=(4, 4, 4), + ) + mug = sim.add_rigid_object(cfg=mug_cfg) + return mug + + +def get_grasp_traj(sim: SimulationManager, robot: Robot, grasp_xpos: torch.Tensor): + n_envs = sim.num_envs + rest_arm_qpos = robot.get_qpos("arm") + + approach_xpos = grasp_xpos.clone() + approach_xpos[:, 2, 3] += 0.1 + + _, qpos_approach = robot.compute_ik( + pose=approach_xpos, joint_seed=rest_arm_qpos, name="arm" + ) + _, qpos_grasp = robot.compute_ik( + pose=grasp_xpos, joint_seed=qpos_approach, name="arm" + ) + hand_open_qpos = torch.tensor([0.00, 0.00], dtype=torch.float32, device=sim.device) + hand_close_qpos = torch.tensor( + [0.025, 0.025], dtype=torch.float32, device=sim.device + ) + + arm_trajectory = torch.cat( + [ + rest_arm_qpos[:, None, :], + qpos_approach[:, None, :], + qpos_grasp[:, None, :], + qpos_grasp[:, None, :], + qpos_approach[:, None, :], + rest_arm_qpos[:, None, :], + ], + dim=1, + ) + hand_trajectory = torch.cat( + [ + hand_open_qpos[None, None, :].repeat(n_envs, 1, 1), + hand_open_qpos[None, None, :].repeat(n_envs, 1, 1), + hand_open_qpos[None, None, :].repeat(n_envs, 1, 1), + hand_close_qpos[None, None, :].repeat(n_envs, 1, 1), + hand_close_qpos[None, None, :].repeat(n_envs, 1, 1), + hand_close_qpos[None, None, :].repeat(n_envs, 1, 1), + ], + dim=1, + ) + all_trajectory = torch.cat([arm_trajectory, hand_trajectory], dim=-1) + interp_trajectory = interpolate_with_distance( + trajectory=all_trajectory, interp_num=200, device=sim.device + ) + return interp_trajectory + + +if __name__ == "__main__": + import time + + args = parse_arguments() + sim = initialize_simulation(args) + robot = create_robot(sim, position=[0.0, 0.0, 0.0]) + mug = create_mug(sim) + + # get mug grasp pose + grasp_cfg = GraspAnnotatorCfg( + viser_port=11801, + antipodal_sampler_cfg=AntipodalSamplerCfg( + n_sample=20000, max_length=0.088, min_length=0.003 + ), + force_regenerate=True, # force user to annotate grasp region each time + ) + sim.open_window() + + # Annotate part of the mug to be grasped by following the instructions in the visualization window: + # 1. View grasp object in browser (e.g http://localhost:11801) + # 2. press 'Rect Select Region', select grasp region + # 3. press 'Confirm Selection' to finish grasp region selection. + + start_time = time.time() + grasp_xpos = mug.get_grasp_pose( + approach_direction=torch.tensor( + [0, 0, -1], dtype=torch.float32, device=sim.device + ), # gripper approach direction in the world frame + cfg=grasp_cfg, + is_visual=True, # visualize selected grasp pose finally + ) + cost_time = time.time() - start_time + logger.log_info(f"Get grasp pose cost time: {cost_time:.2f} seconds") + + grab_traj = get_grasp_traj(sim, robot, grasp_xpos) + input("Press Enter to start the grab mug demo...") + n_waypoint = grab_traj.shape[1] + for i in range(n_waypoint): + robot.set_qpos(grab_traj[:, i, :]) + sim.update(step=4) + time.sleep(1e-2) + input("Press Enter to exit the simulation...")