diff --git a/pytorch3d/ops/points_alignment.py b/pytorch3d/ops/points_alignment.py index 96e4b410..786795ca 100644 --- a/pytorch3d/ops/points_alignment.py +++ b/pytorch3d/ops/points_alignment.py @@ -39,6 +39,7 @@ def iterative_closest_point( X: Union[torch.Tensor, "Pointclouds"], Y: Union[torch.Tensor, "Pointclouds"], init_transform: Optional[SimilarityTransform] = None, + trim_fraction: Union[float, torch.Tensor] = 0., max_iterations: int = 100, relative_rmse_thr: float = 1e-6, estimate_scale: bool = False, @@ -67,6 +68,11 @@ def iterative_closest_point( shape `(minibatch, d, d)`, `T` is a batch of translations of shape `(minibatch, d)` and `s` is a batch of scaling factors of shape `(minibatch,)`. + **trim_fraction**: A float or 1d `Tensor` of shape `(minibatch,)` in [0, 1] + specifying the ratio of outliers in each point cloud. If float, assume + the same outliers ratio for all point clouds in the batch. Outliers will + be detected by taking the `trim_fraction * num_points_X` highest values of + `s[i] X[i] R[i] + T[i] = Y[NN[i]]`. **max_iterations**: The maximum number of ICP iterations. **relative_rmse_thr**: A threshold on the relative root mean squared error used to terminate the algorithm. @@ -152,6 +158,17 @@ def iterative_closest_point( T = Xt.new_zeros((b, dim)) s = Xt.new_ones(b) + # initialize trim fraction parameter + if isinstance(trim_fraction, float): + trim_fraction = torch.as_tensor(trim_fraction) + trim_fraction = trim_fraction.to(Xt.device) # type: ignore + if trim_fraction.ndim == 0: + trim_fraction = trim_fraction.repeat(b) + trim = trim_fraction.min() > 0.0 + + # initial mask: no trim considered, only padding + mask = mask_X.bool().clone() + prev_rmse = None rmse = None iteration = -1 @@ -170,7 +187,7 @@ def iterative_closest_point( R, T, s = corresponding_points_alignment( Xt_init, Xt_nn_points, - weights=mask_X, + weights=mask, estimate_scale=estimate_scale, allow_reflection=allow_reflection, ) @@ -184,7 +201,15 @@ def iterative_closest_point( # compute the root mean squared error # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`. Xt_sq_diff = ((Xt - Xt_nn_points) ** 2).sum(2) - rmse = oputil.wmean(Xt_sq_diff[:, :, None], mask_X).sqrt()[:, 0, 0] + + # trimming: select `1 - trim_fraction` lowest distances. + if trim: + diff_thresholds = Xt_sq_diff[mask_X.bool()].quantile(1 - trim_fraction) + mask_trim = Xt_sq_diff < diff_thresholds[:, None] + # final mask is (trim_mask AND pad_mask) + mask = torch.logical_and(mask_trim, mask_X) + + rmse = oputil.wmean(Xt_sq_diff[:, :, None], mask).sqrt()[:, 0, 0] # compute the relative rmse if prev_rmse is None: