Skip to content

Commit 9fdce80

Browse files
authored
Handle case when no classification model exists (#20257)
1 parent 12f8c3f commit 9fdce80

File tree

1 file changed

+50
-16
lines changed

1 file changed

+50
-16
lines changed

frigate/data_processing/real_time/custom_classification.py

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ def __init__(
4848
self.requestor = requestor
4949
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
5050
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
51-
self.interpreter: Interpreter = None
52-
self.tensor_input_details: dict[str, Any] = None
53-
self.tensor_output_details: dict[str, Any] = None
51+
self.interpreter: Interpreter | None = None
52+
self.tensor_input_details: dict[str, Any] | None = None
53+
self.tensor_output_details: dict[str, Any] | None = None
5454
self.labelmap: dict[int, str] = {}
5555
self.classifications_per_second = EventsPerSecond()
5656
self.inference_speed = InferenceSpeed(
@@ -61,17 +61,24 @@ def __init__(
6161

6262
@redirect_output_to_logger(logger, logging.DEBUG)
6363
def __build_detector(self) -> None:
64+
model_path = os.path.join(self.model_dir, "model.tflite")
65+
labelmap_path = os.path.join(self.model_dir, "labelmap.txt")
66+
67+
if not os.path.exists(model_path) or not os.path.exists(labelmap_path):
68+
self.interpreter = None
69+
self.tensor_input_details = None
70+
self.tensor_output_details = None
71+
self.labelmap = {}
72+
return
73+
6474
self.interpreter = Interpreter(
65-
model_path=os.path.join(self.model_dir, "model.tflite"),
75+
model_path=model_path,
6676
num_threads=2,
6777
)
6878
self.interpreter.allocate_tensors()
6979
self.tensor_input_details = self.interpreter.get_input_details()
7080
self.tensor_output_details = self.interpreter.get_output_details()
71-
self.labelmap = load_labels(
72-
os.path.join(self.model_dir, "labelmap.txt"),
73-
prefill=0,
74-
)
81+
self.labelmap = load_labels(labelmap_path, prefill=0)
7582
self.classifications_per_second.start()
7683

7784
def __update_metrics(self, duration: float) -> None:
@@ -140,6 +147,16 @@ def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray):
140147
logger.warning("Failed to resize image for state classification")
141148
return
142149

150+
if self.interpreter is None:
151+
write_classification_attempt(
152+
self.train_dir,
153+
cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
154+
now,
155+
"unknown",
156+
0.0,
157+
)
158+
return
159+
143160
input = np.expand_dims(frame, axis=0)
144161
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
145162
self.interpreter.invoke()
@@ -197,10 +214,10 @@ def __init__(
197214
self.model_config = model_config
198215
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
199216
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
200-
self.interpreter: Interpreter = None
217+
self.interpreter: Interpreter | None = None
201218
self.sub_label_publisher = sub_label_publisher
202-
self.tensor_input_details: dict[str, Any] = None
203-
self.tensor_output_details: dict[str, Any] = None
219+
self.tensor_input_details: dict[str, Any] | None = None
220+
self.tensor_output_details: dict[str, Any] | None = None
204221
self.detected_objects: dict[str, float] = {}
205222
self.labelmap: dict[int, str] = {}
206223
self.classifications_per_second = EventsPerSecond()
@@ -211,17 +228,24 @@ def __init__(
211228

212229
@redirect_output_to_logger(logger, logging.DEBUG)
213230
def __build_detector(self) -> None:
231+
model_path = os.path.join(self.model_dir, "model.tflite")
232+
labelmap_path = os.path.join(self.model_dir, "labelmap.txt")
233+
234+
if not os.path.exists(model_path) or not os.path.exists(labelmap_path):
235+
self.interpreter = None
236+
self.tensor_input_details = None
237+
self.tensor_output_details = None
238+
self.labelmap = {}
239+
return
240+
214241
self.interpreter = Interpreter(
215-
model_path=os.path.join(self.model_dir, "model.tflite"),
242+
model_path=model_path,
216243
num_threads=2,
217244
)
218245
self.interpreter.allocate_tensors()
219246
self.tensor_input_details = self.interpreter.get_input_details()
220247
self.tensor_output_details = self.interpreter.get_output_details()
221-
self.labelmap = load_labels(
222-
os.path.join(self.model_dir, "labelmap.txt"),
223-
prefill=0,
224-
)
248+
self.labelmap = load_labels(labelmap_path, prefill=0)
225249

226250
def __update_metrics(self, duration: float) -> None:
227251
self.classifications_per_second.update()
@@ -265,6 +289,16 @@ def process_frame(self, obj_data, frame):
265289
logger.warning("Failed to resize image for state classification")
266290
return
267291

292+
if self.interpreter is None:
293+
write_classification_attempt(
294+
self.train_dir,
295+
cv2.cvtColor(crop, cv2.COLOR_RGB2BGR),
296+
now,
297+
"unknown",
298+
0.0,
299+
)
300+
return
301+
268302
input = np.expand_dims(crop, axis=0)
269303
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
270304
self.interpreter.invoke()

0 commit comments

Comments
 (0)