diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py
index 309990ea03a..f9780c39a2d 100644
--- a/torchvision/ops/boxes.py
+++ b/torchvision/ops/boxes.py
@@ -69,10 +69,11 @@ def batched_nms(
         _log_api_usage_once(batched_nms)
     # Benchmarks that drove the following thresholds are at
     # https://github.com/pytorch/vision/issues/1311#issuecomment-781329339
-    if boxes.numel() > (4000 if boxes.device.type == "cpu" else 20000) and not torchvision._is_tracing():
-        return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
-    else:
-        return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
+    return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
+    #if boxes.numel() > (4000 if boxes.device.type == "cpu" else 20000) and not torchvision._is_tracing():
+    #    return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
+    #else:
+    #    return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
 
 
 @torch.jit._script_if_tracing
@@ -104,7 +105,8 @@ def _batched_nms_vanilla(
 ) -> Tensor:
     # Based on Detectron2 implementation, just manually call nms() on each class independently
     keep_mask = torch.zeros_like(scores, dtype=torch.bool)
-    for class_id in torch.unique(idxs):
+    #for class_id in torch.unique(idxs):
+    for class_id in idxs:
         curr_indices = torch.where(idxs == class_id)[0]
         curr_keep_indices = nms(boxes[curr_indices], scores[curr_indices], iou_threshold)
         keep_mask[curr_indices[curr_keep_indices]] = True