diff --git a/detectron2/layers/nms.py b/detectron2/layers/nms.py index 37ba18b2af..5873a481f5 100644 --- a/detectron2/layers/nms.py +++ b/detectron2/layers/nms.py @@ -19,7 +19,8 @@ def batched_nms( # to decide whether to use coordinate trick or for loop to implement batched_nms. So we # just call it directly. # Fp16 does not have enough range for batched NMS, so adding float(). - return box_ops.batched_nms(boxes.float(), scores, idxs, iou_threshold) + boxes_value = boxes.double() if scores.dtype == torch.float64 else boxes.float() + return box_ops.batched_nms(boxes_value, scores, idxs, iou_threshold) # Note: this function (nms_rotated) might be moved into