@@ -299,7 +299,7 @@ def extract_boxes(self, predictions):
299
299
300
300
301
301
class DocLayoutPostProcess :
302
- def __init__ (self , labels : List [str ], conf_thres = 0.7 , iou_thres = 0.5 ):
302
+ def __init__ (self , labels : List [str ], conf_thres = 0.2 , iou_thres = 0.5 ):
303
303
self .labels = labels
304
304
self .conf_threshold = conf_thres
305
305
self .iou_threshold = iou_thres
@@ -308,31 +308,18 @@ def __init__(self, labels: List[str], conf_thres=0.7, iou_thres=0.5):
308
308
309
309
def __call__ (
310
310
self ,
311
- output ,
311
+ preds ,
312
312
ori_img_shape : Tuple [int , int ],
313
313
img_shape : Tuple [int , int ] = (1024 , 1024 ),
314
314
):
315
- self .img_height , self .img_width = ori_img_shape
316
- self .input_height , self .input_width = img_shape
317
-
318
- output = output [0 ].squeeze ()
319
- boxes = output [:, :- 2 ]
320
- confidences = output [:, - 2 ]
321
- class_ids = output [:, - 1 ].astype (int )
322
-
323
- mask = confidences > self .conf_threshold
324
- boxes = boxes [mask , :]
325
- confidences = confidences [mask ]
326
- class_ids = class_ids [mask ]
327
-
328
- # Rescale boxes to original image dimensions
329
- boxes = rescale_boxes (
330
- boxes ,
331
- self .input_width ,
332
- self .input_height ,
333
- self .img_width ,
334
- self .img_height ,
335
- )
315
+ preds = preds [0 ]
316
+ mask = preds [..., 4 ] > self .conf_threshold
317
+ preds = [p [mask [idx ]] for idx , p in enumerate (preds )][0 ]
318
+ preds [:, :4 ] = scale_boxes (list (img_shape ), preds [:, :4 ], list (ori_img_shape ))
319
+
320
+ boxes = preds [:, :4 ]
321
+ confidences = preds [:, 4 ]
322
+ class_ids = preds [:, 5 ].astype (int )
336
323
labels = [self .labels [i ] for i in class_ids ]
337
324
return boxes , confidences , labels
338
325
@@ -345,6 +332,54 @@ def rescale_boxes(boxes, input_width, input_height, img_width, img_height):
345
332
return boxes
346
333
347
334
335
+ def scale_boxes (
336
+ img1_shape , boxes , img0_shape , ratio_pad = None , padding = True , xywh = False
337
+ ):
338
+ """
339
+ Rescales bounding boxes (in the format of xyxy by default) from the shape of the image they were originally
340
+ specified in (img1_shape) to the shape of a different image (img0_shape).
341
+
342
+ Args:
343
+ img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
344
+ boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2)
345
+ img0_shape (tuple): the shape of the target image, in the format of (height, width).
346
+ ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be
347
+ calculated based on the size difference between the two images.
348
+ padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
349
+ rescaling.
350
+ xywh (bool): The box format is xywh or not, default=False.
351
+
352
+ Returns:
353
+ boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
354
+ """
355
+ if ratio_pad is None : # calculate from img0_shape
356
+ gain = min (
357
+ img1_shape [0 ] / img0_shape [0 ], img1_shape [1 ] / img0_shape [1 ]
358
+ ) # gain = old / new
359
+ pad = (
360
+ round ((img1_shape [1 ] - img0_shape [1 ] * gain ) / 2 - 0.1 ),
361
+ round ((img1_shape [0 ] - img0_shape [0 ] * gain ) / 2 - 0.1 ),
362
+ ) # wh padding
363
+ else :
364
+ gain = ratio_pad [0 ][0 ]
365
+ pad = ratio_pad [1 ]
366
+
367
+ if padding :
368
+ boxes [..., 0 ] -= pad [0 ] # x padding
369
+ boxes [..., 1 ] -= pad [1 ] # y padding
370
+ if not xywh :
371
+ boxes [..., 2 ] -= pad [0 ] # x padding
372
+ boxes [..., 3 ] -= pad [1 ] # y padding
373
+ boxes [..., :4 ] /= gain
374
+ return clip_boxes (boxes , img0_shape )
375
+
376
+
377
+ def clip_boxes (boxes , shape ):
378
+ boxes [..., [0 , 2 ]] = boxes [..., [0 , 2 ]].clip (0 , shape [1 ]) # x1, x2
379
+ boxes [..., [1 , 3 ]] = boxes [..., [1 , 3 ]].clip (0 , shape [0 ]) # y1, y2
380
+ return boxes
381
+
382
+
348
383
def nms (boxes , scores , iou_threshold ):
349
384
# Sort by score
350
385
sorted_indices = np .argsort (scores )[::- 1 ]
0 commit comments