Skip to content

Commit 4e590de

Browse files
authored
Merge pull request #30 from midas-research/consensus-api-fix
Consensus API added to the quality report
2 parents f9a399e + 1c5598f commit 4e590de

File tree

2 files changed

+226
-1
lines changed

2 files changed

+226
-1
lines changed

cvat/apps/quality_control/serializers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ class QualityReportCreateSerializer(serializers.Serializer):
7878
class ImmediateQualityReportCreateSerializer(serializers.Serializer):
7979
job_id = serializers.IntegerField(write_only=True)
8080

81+
class ConsensusCreateSerializer(serializers.Serializer):
    """Request body for the consensus-reports endpoint.

    ``files`` is the list of file descriptors to merge.  NOTE(review): the
    ListField declares no ``child`` serializer, so items are not validated
    here; the view checks each entry for the 'bucket_name', 'chain_id',
    'escrow_address' and 'file_name' keys itself — confirm that is intended.
    """

    # List of per-annotator file descriptors (plain dicts, validated in the view).
    files = serializers.ListField()
83+
8184
class QualitySettingsSerializer(serializers.ModelSerializer):
8285
class Meta:
8386
model = models.QualitySettings

cvat/apps/quality_control/views.py

Lines changed: 223 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,18 @@
22
#
33
# SPDX-License-Identifier: MIT
44

5+
import os
56
import textwrap
7+
import requests
8+
import json
9+
import tempfile
10+
import shutil
11+
import zipfile
12+
import io
613
import random
714

15+
from collections import namedtuple, Counter, defaultdict
16+
from django.http import FileResponse
817
from django.db import transaction
918
from django.db.models import Q, QuerySet
1019
from django.http import HttpResponse
@@ -16,7 +25,8 @@
1625
extend_schema_view,
1726
)
1827
from rest_framework import mixins, status, viewsets
19-
from rest_framework.decorators import action
28+
from rest_framework.permissions import AllowAny
29+
from rest_framework.decorators import action, api_view, permission_classes
2030
from rest_framework.exceptions import NotFound, ValidationError
2131
from rest_framework.response import Response
2232

@@ -42,6 +52,7 @@
4252
QualityReportCreateSerializer,
4353
QualityReportSerializer,
4454
QualitySettingsSerializer,
55+
ConsensusCreateSerializer,
4556
)
4657
from rest_framework.permissions import IsAuthenticated
4758

@@ -584,6 +595,217 @@ def calculate_score(gt_samples, ds_samples, start_time=0):
584595
except Exception as e:
585596
raise ValidationError(f"An internal server error occurred: {str(e)}")
586597

598+
@extend_schema(
    summary='Get the consensus result from multiple ZIP files containing JSONs.',
    description=textwrap.dedent(
        """
        Accepts multiple ZIP files containing JSON annotations via URLs, processes them to build a consensus result,
        and returns the consensus JSON file. Each ZIP file is expected to contain a single JSON file with annotations.
        The URLs should be provided in the request body as a list under the 'files' key.
        Each file object in the list must contain 'bucket_name', 'chain_id', 'escrow_address', and 'file_name' fields.
        """
    ),
    request=ConsensusCreateSerializer,
    responses={
        "200": OpenApiResponse(
            response={"type": "object"},
            description="Consensus JSON file containing combined annotations from the provided files."
        ),
        "400": OpenApiResponse(description="Bad Request: Missing or invalid data."),
    },
)
@action(detail=False, methods=['POST'], url_path='consensus-reports', permission_classes=[IsAuthenticated])
def process_json_from_url(self, request, *args, **kwargs):
    """Build a consensus bundle from multiple remote ZIP files of JSON annotations.

    Each entry in ``request.data['files']`` describes one annotator's export on
    S3. Every referenced ZIP is downloaded, its first top-level JSON member is
    parsed, per-annotation and per-audio-file consensus is computed, and the two
    resulting JSON documents are streamed back as ``consensus_bundle.zip``.

    Returns 400 for a missing/invalid ``files`` list or missing descriptor
    fields, and 500 for any download/extraction/parse failure.
    """
    files = request.data.get('files', [])
    if not files or not isinstance(files, list):
        return Response({"error": "Invalid request"}, status=status.HTTP_400_BAD_REQUEST)

    try:
        datasets = []

        def download_and_extract_json(url):
            """Download a ZIP from `url` and return its first top-level JSON member, parsed."""
            # Explicit timeout: requests has no default, so without one a
            # stalled S3 connection would hang this worker indefinitely.
            response = requests.get(url, stream=True, timeout=(10, 300))
            response.raise_for_status()

            # Buffer the archive in memory and read the JSON member directly.
            # This avoids extractall() on an untrusted archive (zip-slip risk)
            # and the round trip of writing/re-reading files on disk.
            buffer = io.BytesIO()
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    buffer.write(chunk)
            buffer.seek(0)

            with zipfile.ZipFile(buffer) as zip_ref:
                # Top-level members only, matching the original os.listdir() lookup.
                json_names = [
                    name for name in zip_ref.namelist()
                    if name.endswith('.json') and '/' not in name
                ]
                if not json_names:
                    raise ValueError(f"No JSON file found in ZIP from {url}")
                return json.loads(zip_ref.read(json_names[0]).decode('utf-8'))

        def build_consensus(raw_labels):
            """Aggregate raw 'Emotion_Intensity' labels into combined confidence
            distributions and majority (possibly tied) consensus labels."""
            emotion_counts = Counter()
            intensity_counts = Counter()
            total = 0

            def split_label(label):
                """Split a label like 'Positive_Mild' into (emotion, intensity)."""
                if label == "Can't Predict":
                    return None, None
                parts = label.split("_", 1)
                if len(parts) == 2:
                    return parts[0], parts[1]
                return label, None

            for lbl in raw_labels:
                emotion, intensity = split_label(lbl)
                if emotion is None:  # "Can't Predict" votes are excluded from the tally
                    continue
                emotion_counts[emotion] += 1
                if intensity:
                    intensity_counts[intensity] += 1
                total += 1

            result = {
                "raw_labels": raw_labels,
                "combined_label_emotion": [],
                "combined_label_intensity": [],
                "consensus_label_emotion": [],
                "consensus_label_intensity": []
            }

            if total == 0:  # every annotator said "Can't Predict"
                result["combined_label_emotion"] = [{"label": "Can't Predict", "confidence": 1.0}]
                result["consensus_label_emotion"] = ["Can't Predict"]
                return result

            # Confidence = share of countable (non-"Can't Predict") votes per label.
            for emo, count in emotion_counts.items():
                result["combined_label_emotion"].append({
                    "label": emo,
                    "confidence": count / total
                })
            for inten, count in intensity_counts.items():
                result["combined_label_intensity"].append({
                    "label": inten,
                    "confidence": count / total
                })

            # Consensus keeps every label reaching the maximal count (ties included).
            # emotion_counts is non-empty whenever total > 0.
            max_emotion = max(emotion_counts.values())
            result["consensus_label_emotion"] = [
                emo for emo, c in emotion_counts.items() if c == max_emotion
            ]
            if intensity_counts:
                max_intensity = max(intensity_counts.values())
                result["consensus_label_intensity"] = [
                    inten for inten, c in intensity_counts.items() if c == max_intensity
                ]
            return result

        def apply_consensus_logic(datasets):
            """Per-annotation consensus: merge the i-th annotation of every dataset."""
            if not datasets:
                return []

            consensus = []
            # NOTE(review): zip() truncates to the shortest dataset — this assumes
            # all annotators exported parallel, equal-length annotation lists
            # over the same segments; confirm upstream guarantees this.
            for items in zip(*datasets):
                raw_labels = [item["label"] for item in items]
                consensus.append({
                    "audio_file": items[0]["audio_file"],
                    "start": items[0]["start"],
                    "end": items[0]["end"],
                    **build_consensus(raw_labels)
                })
            return consensus

        def apply_consensus_logic_audio(consensus):
            """Per-audio-file consensus, aggregated over all segment-level results."""
            audio_groups = defaultdict(lambda: {"emotions": [], "intensities": []})

            for ann in consensus:
                audio_file = ann["audio_file"]
                audio_groups[audio_file]["emotions"].extend(ann["consensus_label_emotion"])
                audio_groups[audio_file]["intensities"].extend(ann["consensus_label_intensity"])

            audio_consensus = []
            for audio_file, vals in audio_groups.items():
                # Re-use build_consensus by attaching a dummy counterpart to each label.
                # NOTE(review): a "Can't Predict" segment consensus becomes
                # "Can't Predict_Mild" here and is tallied as a regular emotion
                # instead of being skipped — confirm this is intended.
                emotion_result = build_consensus([f"{e}_Mild" for e in vals["emotions"]])
                intensity_result = build_consensus([f"Positive_{i}" for i in vals["intensities"]])

                audio_consensus.append({
                    "audio_file": audio_file,
                    "combined_label_emotion": emotion_result["combined_label_emotion"],
                    "combined_label_intensity": intensity_result["combined_label_intensity"],
                    "consensus_label_emotion": emotion_result["consensus_label_emotion"],
                    "consensus_label_intensity": intensity_result["consensus_label_intensity"]
                })
            return audio_consensus

        for file_obj in files:
            required_fields = ('bucket_name', 'chain_id', 'escrow_address', 'file_name')
            if not all(file_obj.get(field) for field in required_fields):
                return Response({"error": "Missing required fields"}, status=status.HTTP_400_BAD_REQUEST)

            # %40 is the URL-encoded '@' separating escrow address and chain id.
            url = f"https://{file_obj['bucket_name']}.s3.ap-south-1.amazonaws.com/{file_obj['escrow_address']}%40{file_obj['chain_id']}/{file_obj['file_name']}"
            datasets.append(download_and_extract_json(url))

        if not datasets:
            return Response({"error": "No data found"}, status=status.HTTP_400_BAD_REQUEST)

        consensus_data = apply_consensus_logic(datasets)
        audio_consensus_data = apply_consensus_logic_audio(consensus_data)

        # Build the result ZIP entirely in memory; the extra copies previously
        # written to a temp directory were never read and have been removed.
        zip_buffer = io.BytesIO()
        with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
            zipf.writestr("consensus.json", json.dumps(consensus_data, ensure_ascii=False))
            zipf.writestr("audio_consensus.json", json.dumps(audio_consensus_data, ensure_ascii=False))
        zip_buffer.seek(0)  # rewind so HttpResponse streams from the start

        response = HttpResponse(zip_buffer, content_type='application/zip')
        response['Content-Disposition'] = 'attachment; filename=consensus_bundle.zip'
        return response

    except Exception as e:
        # Broad catch is deliberate: any download/extraction/parse failure is
        # mapped to a 500 with the error message; consider also logging the
        # traceback so failures are diagnosable server-side.
        return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
808+
587809
@extend_schema(
588810
operation_id="quality_retrieve_report_data",
589811
summary="Get quality report contents",

0 commit comments

Comments
 (0)