Skip to content

Commit e1790ef

Browse files
authored
Merge pull request #318 from JdeRobot/GUI_issue243
Issue 243 - Streamlit UI
2 parents 6ed318f + 8aa4631 commit e1790ef

15 files changed

+1734
-100
lines changed

app.py

Lines changed: 339 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,339 @@
1+
import streamlit as st
2+
import os
3+
import sys
4+
import subprocess
5+
from tabs.dataset_viewer import dataset_viewer_tab
6+
from tabs.inference import inference_tab
7+
from tabs.evaluator import evaluator_tab
8+
9+
10+
def browse_folder():
11+
"""
12+
Opens a native folder selection dialog and returns the selected folder path.
13+
Works on Windows, macOS, and Linux (with zenity or kdialog).
14+
Returns None if cancelled or error.
15+
"""
16+
try:
17+
if sys.platform.startswith("win"):
18+
script = (
19+
"Add-Type -AssemblyName System.windows.forms;"
20+
"$f=New-Object System.Windows.Forms.FolderBrowserDialog;"
21+
'if($f.ShowDialog() -eq "OK"){Write-Output $f.SelectedPath}'
22+
)
23+
result = subprocess.run(
24+
["powershell", "-NoProfile", "-Command", script],
25+
capture_output=True,
26+
text=True,
27+
timeout=30,
28+
)
29+
folder = result.stdout.strip()
30+
return folder if folder else None
31+
elif sys.platform == "darwin":
32+
script = (
33+
'POSIX path of (choose folder with prompt "Select dataset folder:")'
34+
)
35+
result = subprocess.run(
36+
["osascript", "-e", script], capture_output=True, text=True, timeout=30
37+
)
38+
folder = result.stdout.strip()
39+
return folder if folder else None
40+
else:
41+
# Linux: try zenity, then kdialog
42+
for cmd in [
43+
[
44+
"zenity",
45+
"--file-selection",
46+
"--directory",
47+
"--title=Select dataset folder",
48+
],
49+
[
50+
"kdialog",
51+
"--getexistingdirectory",
52+
"--title",
53+
"Select dataset folder",
54+
],
55+
]:
56+
try:
57+
result = subprocess.run(
58+
cmd, capture_output=True, text=True, timeout=30
59+
)
60+
folder = result.stdout.strip()
61+
if folder:
62+
return folder
63+
except (subprocess.TimeoutExpired, FileNotFoundError, Exception):
64+
continue
65+
return None
66+
except Exception:
67+
return None
68+
69+
70+
st.set_page_config(page_title="DetectionMetrics", layout="wide")
71+
72+
# st.title("DetectionMetrics")
73+
74+
PAGES = {
75+
"Dataset Viewer": dataset_viewer_tab,
76+
"Inference": inference_tab,
77+
"Evaluator": evaluator_tab,
78+
}
79+
80+
# Initialize commonly used session state keys
81+
st.session_state.setdefault("dataset_path", "")
82+
st.session_state.setdefault("dataset_type_selectbox", "Coco")
83+
st.session_state.setdefault("split_selectbox", "val")
84+
st.session_state.setdefault("config_option", "Manual Configuration")
85+
st.session_state.setdefault("confidence_threshold", 0.5)
86+
st.session_state.setdefault("nms_threshold", 0.5)
87+
st.session_state.setdefault("max_detections", 100)
88+
st.session_state.setdefault("device", "cpu")
89+
st.session_state.setdefault("batch_size", 1)
90+
st.session_state.setdefault("evaluation_step", 5)
91+
st.session_state.setdefault("detection_model", None)
92+
st.session_state.setdefault("detection_model_loaded", False)
93+
94+
# Sidebar: Dataset Inputs
95+
with st.sidebar:
96+
with st.expander("Dataset Inputs", expanded=True):
97+
# First row: Type and Split
98+
col1, col2 = st.columns(2)
99+
with col1:
100+
st.selectbox(
101+
"Type",
102+
["Coco", "Custom"],
103+
key="dataset_type_selectbox",
104+
)
105+
with col2:
106+
st.selectbox(
107+
"Split",
108+
["train", "val"],
109+
key="split_selectbox",
110+
)
111+
112+
# Second row: Path and Browse button
113+
col1, col2 = st.columns([3, 1])
114+
with col1:
115+
dataset_path_input = st.text_input(
116+
"Dataset Folder Path",
117+
value=st.session_state.get("dataset_path", ""),
118+
key="dataset_path_input",
119+
)
120+
with col2:
121+
st.markdown(
122+
"<div style='margin-bottom: 1.75rem;'></div>", unsafe_allow_html=True
123+
)
124+
if st.button("Browse", key="browse_button"):
125+
folder = browse_folder()
126+
if folder and os.path.isdir(folder):
127+
st.session_state["dataset_path"] = folder
128+
st.rerun()
129+
elif folder is not None:
130+
st.warning("Selected path is not a valid folder.")
131+
else:
132+
st.warning("Could not open folder browser. Please enter the path manually")
133+
134+
if dataset_path_input != st.session_state.get("dataset_path", ""):
135+
st.session_state["dataset_path"] = dataset_path_input
136+
137+
with st.expander("Model Inputs", expanded=False):
138+
st.file_uploader(
139+
"Model File (.pt, .onnx, .h5, .pb, .pth)",
140+
type=["pt", "onnx", "h5", "pb", "pth"],
141+
key="model_file",
142+
help="Upload your trained model file.",
143+
)
144+
st.file_uploader(
145+
"Ontology File (.json)",
146+
type=["json"],
147+
key="ontology_file",
148+
help="Upload a JSON file with class labels.",
149+
)
150+
st.radio(
151+
"Configuration Method:",
152+
["Manual Configuration", "Upload Config File"],
153+
key="config_option",
154+
horizontal=True,
155+
)
156+
if (
157+
st.session_state.get("config_option", "Manual Configuration")
158+
== "Upload Config File"
159+
):
160+
st.file_uploader(
161+
"Configuration File (.json)",
162+
type=["json"],
163+
key="config_file",
164+
help="Upload a JSON configuration file.",
165+
)
166+
else:
167+
col1, col2 = st.columns(2)
168+
with col1:
169+
st.slider(
170+
"Confidence Threshold",
171+
min_value=0.0,
172+
max_value=1.0,
173+
value=st.session_state.get("confidence_threshold", 0.5),
174+
step=0.01,
175+
key="confidence_threshold",
176+
help="Minimum confidence score for detections",
177+
)
178+
st.slider(
179+
"NMS Threshold",
180+
min_value=0.0,
181+
max_value=1.0,
182+
value=st.session_state.get("nms_threshold", 0.5),
183+
step=0.01,
184+
key="nms_threshold",
185+
help="Non-maximum suppression threshold",
186+
)
187+
st.number_input(
188+
"Max Detections/Image",
189+
min_value=1,
190+
max_value=1000,
191+
value=st.session_state.get("max_detections", 100),
192+
step=1,
193+
key="max_detections",
194+
)
195+
with col2:
196+
st.selectbox(
197+
"Device",
198+
["cpu", "cuda", "mps"],
199+
index=0 if st.session_state.get("device", "cpu") == "cpu" else 1,
200+
key="device",
201+
)
202+
st.number_input(
203+
"Batch Size",
204+
min_value=1,
205+
max_value=256,
206+
value=st.session_state.get("batch_size", 1),
207+
step=1,
208+
key="batch_size",
209+
)
210+
st.number_input(
211+
"Evaluation Step",
212+
min_value=0,
213+
max_value=1000,
214+
value=st.session_state.get("evaluation_step", 10),
215+
step=1,
216+
key="evaluation_step",
217+
help="Update UI with intermediate metrics every N images (0 = disable intermediate updates)",
218+
)
219+
220+
# Load model action in sidebar
221+
from detectionmetrics.models.torch_detection import TorchImageDetectionModel
222+
import json, tempfile
223+
224+
load_model_btn = st.button(
225+
"Load Model",
226+
type="primary",
227+
use_container_width=True,
228+
help="Load and save the model for use in the Inference tab",
229+
key="sidebar_load_model_btn",
230+
)
231+
232+
if load_model_btn:
233+
model_file = st.session_state.get("model_file")
234+
ontology_file = st.session_state.get("ontology_file")
235+
config_option = st.session_state.get(
236+
"config_option", "Manual Configuration"
237+
)
238+
config_file = (
239+
st.session_state.get("config_file")
240+
if config_option == "Upload Config File"
241+
else None
242+
)
243+
244+
# Prepare configuration
245+
config_data = None
246+
config_path = None
247+
try:
248+
if config_option == "Upload Config File":
249+
if config_file is not None:
250+
config_data = json.load(config_file)
251+
with tempfile.NamedTemporaryFile(
252+
delete=False, suffix=".json", mode="w"
253+
) as tmp_cfg:
254+
json.dump(config_data, tmp_cfg)
255+
config_path = tmp_cfg.name
256+
else:
257+
st.error("Please upload a configuration file")
258+
else:
259+
confidence_threshold = float(
260+
st.session_state.get("confidence_threshold", 0.5)
261+
)
262+
nms_threshold = float(st.session_state.get("nms_threshold", 0.5))
263+
max_detections = int(st.session_state.get("max_detections", 100))
264+
device = st.session_state.get("device", "cpu")
265+
batch_size = int(st.session_state.get("batch_size", 1))
266+
evaluation_step = int(st.session_state.get("evaluation_step", 5))
267+
config_data = {
268+
"confidence_threshold": confidence_threshold,
269+
"nms_threshold": nms_threshold,
270+
"max_detections_per_image": max_detections,
271+
"device": device,
272+
"batch_size": batch_size,
273+
"evaluation_step": evaluation_step,
274+
}
275+
with tempfile.NamedTemporaryFile(
276+
delete=False, suffix=".json", mode="w"
277+
) as tmp_cfg:
278+
json.dump(config_data, tmp_cfg)
279+
config_path = tmp_cfg.name
280+
except Exception as e:
281+
st.error(f"Failed to prepare configuration: {e}")
282+
config_path = None
283+
284+
if model_file is None:
285+
st.error("Please upload a model file")
286+
elif config_path is None:
287+
st.error("Please provide a valid model configuration")
288+
elif ontology_file is None:
289+
st.error("Please upload an ontology file")
290+
else:
291+
with st.spinner("Loading model..."):
292+
# Persist ontology to temp file
293+
try:
294+
ontology_data = json.load(ontology_file)
295+
with tempfile.NamedTemporaryFile(
296+
delete=False, suffix=".json", mode="w"
297+
) as tmp_ont:
298+
json.dump(ontology_data, tmp_ont)
299+
ontology_path = tmp_ont.name
300+
except Exception as e:
301+
st.error(f"Failed to load ontology: {e}")
302+
ontology_path = None
303+
304+
# Persist model to temp file
305+
try:
306+
with tempfile.NamedTemporaryFile(
307+
delete=False, suffix=".pt", mode="wb"
308+
) as tmp_model:
309+
tmp_model.write(model_file.read())
310+
model_temp_path = tmp_model.name
311+
except Exception as e:
312+
st.error(f"Failed to save model file: {e}")
313+
model_temp_path = None
314+
315+
if ontology_path and model_temp_path:
316+
try:
317+
model = TorchImageDetectionModel(
318+
model=model_temp_path,
319+
model_cfg=config_path,
320+
ontology_fname=ontology_path,
321+
device=st.session_state.get("device", "cpu"),
322+
)
323+
st.session_state.detection_model = model
324+
st.session_state.detection_model_loaded = True
325+
st.success("Model loaded and saved for inference")
326+
except Exception as e:
327+
st.session_state.detection_model = None
328+
st.session_state.detection_model_loaded = False
329+
st.error(f"Failed to load model: {e}")
330+
331+
# Main content area with horizontal tabs
332+
tab1, tab2, tab3 = st.tabs(["Dataset Viewer", "Inference", "Evaluator"])
333+
334+
with tab1:
335+
dataset_viewer_tab()
336+
with tab2:
337+
inference_tab()
338+
with tab3:
339+
evaluator_tab()

detectionmetrics/datasets/coco.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ class CocoDataset(ImageDetectionDataset):
7777
"""
7878

7979
def __init__(self, annotation_file: str, image_dir: str, split: str = "train"):
80-
# Load COCO object once
80+
# Load COCO object once - this loads all annotations into memory with efficient indexing
8181
self.coco = COCO(annotation_file)
8282
self.image_dir = image_dir
8383
self.split = split
@@ -94,26 +94,26 @@ def read_annotation(
9494
) -> Tuple[List[List[float]], List[int], List[int]]:
9595
"""Return bounding boxes, labels, and category_ids for a given image ID.
9696
97+
This method uses COCO's efficient indexing to load annotations on-demand.
98+
The COCO object maintains an internal index that allows for very fast
99+
annotation retrieval without needing a separate cache.
100+
97101
:param fname: str (image_id in string form)
98102
:return: Tuple of (boxes, labels, category_ids)
99103
"""
100104
# Extract image ID (fname might be a path or ID string)
101105
try:
102-
image_id = int(
103-
os.path.basename(fname)
104-
) # handles both '123' and '/path/to/123'
106+
image_id = int(os.path.basename(fname))
105107
except ValueError:
106108
raise ValueError(f"Invalid annotation ID: {fname}")
107109

110+
# Use COCO's efficient indexing to get annotations for this image
111+
# getAnnIds() and loadAnns() are very fast due to COCO's internal indexing
108112
ann_ids = self.coco.getAnnIds(imgIds=image_id)
109113
anns = self.coco.loadAnns(ann_ids)
110114

111-
boxes = []
112-
labels = []
113-
category_ids = []
114-
115+
boxes, labels, category_ids = [], [], []
115116
for ann in anns:
116-
# Convert [x, y, width, height] to [x1, y1, x2, y2]
117117
x, y, w, h = ann["bbox"]
118118
boxes.append([x, y, x + w, y + h])
119119
labels.append(ann["category_id"])

0 commit comments

Comments
 (0)