Commit 74742ff

Merge pull request #28 from midas-research/feat/gt-annotation
feat: enhance score calculation by adjusting ground truth sample points
2 parents: 5caacc9 + bed5655

1 file changed: +13 -11 lines changed

cvat/apps/quality_control/views.py

Lines changed: 13 additions & 11 deletions
@@ -446,10 +446,18 @@ def immediate_reports(self, request, *args, **kwargs):
             # handle label matching task
             else:
                 penalty_factor = 0.1
-                def calculate_score(gt_samples, ds_samples):
-                    # Group annotations by label
+                def calculate_score(gt_samples, ds_samples, start_time=0):
+                    gt_samples_adjusted = []
+                    for ann in gt_samples:
+                        adjusted_ann = ann.copy()
+                        adjusted_points = adjusted_ann["points"].copy()
+                        adjusted_points[0] = round(adjusted_points[0] - start_time, 10)  # Round to 10 decimal places
+                        adjusted_points[3] = round(adjusted_points[3] - start_time, 10)
+                        adjusted_ann["points"] = adjusted_points
+                        gt_samples_adjusted.append(adjusted_ann)
+
                     gt_by_label = {}
-                    for idx, ann in enumerate(gt_samples):
+                    for idx, ann in enumerate(gt_samples_adjusted):
                         label = ann.get("label_id", "")
                         if label not in gt_by_label:
                             gt_by_label[label] = []
@@ -466,27 +474,22 @@ def calculate_score(gt_samples, ds_samples):
                     total_gt_count = 0
                     unused_predictions = []
 
-                    # Process each label separately
                     for label in set(list(gt_by_label.keys()) + list(ds_by_label.keys())):
                         gt_list = gt_by_label.get(label, [])
                         ds_list = ds_by_label.get(label, [])
                         total_gt_count += len(gt_list)
 
-                        # Track best coverage for each GT
                         gt_coverage = [0.0] * len(gt_list)
                         used_predictions = [False] * len(ds_list)
 
-                        # Check each GT against all predictions
                         for gt_idx, (orig_gt_idx, gt_ann) in enumerate(gt_list):
                             gt_start, gt_end = gt_ann["points"][0], gt_ann["points"][3]
                             gt_length = gt_end - gt_start
 
-                            # Find all overlapping predictions
                             overlapping_predictions = []
                             for ds_idx, ds_ann in enumerate(ds_list):
                                 ds_start, ds_end = ds_ann["points"][0], ds_ann["points"][3]
 
-                                # Calculate overlap
                                 overlap_start = max(gt_start, ds_start)
                                 overlap_end = min(gt_end, ds_end)
                                 overlap = max(0, overlap_end - overlap_start)
@@ -495,7 +498,6 @@ def calculate_score(gt_samples, ds_samples):
                                     overlapping_predictions.append((ds_idx, overlap))
                                     used_predictions[ds_idx] = True
 
-                            # Merge overlaps to calculate total coverage
                             overlapping_predictions.sort(key=lambda x: x[1], reverse=True)
                             covered_intervals = []
 
@@ -520,7 +522,7 @@ def calculate_score(gt_samples, ds_samples):
 
                             # Calculate total coverage
                             total_covered = sum(end - start for start, end in covered_intervals)
-                            coverage_ratio = min(1.0, total_covered / gt_length)
+                            coverage_ratio = min(1.0, total_covered / gt_length) if gt_length > 0 else 0
                             gt_coverage[gt_idx] = coverage_ratio
 
                         # Add unused predictions for this label to global list
@@ -547,7 +549,7 @@ def calculate_score(gt_samples, ds_samples):
 
                     return final_score
 
-                score = calculate_score(gt_samples_filtered, ds_samples_filtered)
+                score = calculate_score(gt_samples_filtered, ds_samples_filtered, start_time)
 
                 response_data = {
                     "score": score,
