diff --git a/src/codomyrmex/events/core/event_bus.py b/src/codomyrmex/events/core/event_bus.py index 171d509f6..50d82b9d9 100644 --- a/src/codomyrmex/events/core/event_bus.py +++ b/src/codomyrmex/events/core/event_bus.py @@ -8,10 +8,11 @@ import asyncio import fnmatch import inspect +import re import threading from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any # Import logging @@ -40,6 +41,20 @@ class Subscription: is_async: bool = False filter_func: Callable[[Event], bool] | None = None priority: int = 0 # Higher numbers = higher priority + _literal_patterns: set[str] = field(default_factory=set, init=False, repr=False) + _regex_patterns: list[re.Pattern] = field( + default_factory=list, init=False, repr=False + ) + + def __post_init__(self): + self._literal_patterns = set() + self._regex_patterns = [] + for pattern in self.event_patterns: + p_str = pattern.value if hasattr(pattern, "value") else str(pattern) + if any(c in p_str for c in "*?[]"): + self._regex_patterns.append(re.compile(fnmatch.translate(p_str))) + else: + self._literal_patterns.add(p_str) def matches_event(self, event: Event) -> bool: """Check if this subscription matches an event.""" @@ -51,12 +66,13 @@ def matches_event(self, event: Event) -> bool: ) match_found = False - for pattern in self.event_patterns: - # Ensure pattern is a string for fnmatch - p_str = pattern.value if hasattr(pattern, "value") else str(pattern) - if fnmatch.fnmatch(event_type_str, p_str): - match_found = True - break + if event_type_str in self._literal_patterns: + match_found = True + else: + for regex in self._regex_patterns: + if regex.match(event_type_str): + match_found = True + break if not match_found: return False diff --git a/src/codomyrmex/events/integration_bus.py b/src/codomyrmex/events/integration_bus.py index f37bdb635..3ef3b4738 100644 --- a/src/codomyrmex/events/integration_bus.py +++ b/src/codomyrmex/events/integration_bus.py @@ -6,6 +6,7 @@ from __future__ import annotations import fnmatch +import re import time from collections import defaultdict from dataclasses import dataclass, field @@ -59,6 +60,7 @@ def __init__(self) -> None: str, list[tuple[Callable[[IntegrationEvent], None], int]] ] = defaultdict(list) self._history: list[IntegrationEvent] = [] + self._compiled_patterns: dict[str, re.Pattern] = {} def subscribe( self, @@ -76,6 +78,9 @@ def subscribe( self._handlers[topic].append((handler, priority)) # Keep handlers sorted by priority self._handlers[topic].sort(key=lambda x: x[1], reverse=True) + # Pre-compile the regex pattern if it has wildcards + if topic not in self._compiled_patterns and any(c in topic for c in "*?[]"): + self._compiled_patterns[topic] = re.compile(fnmatch.translate(topic)) def unsubscribe( self, topic: str, handler: Callable[[IntegrationEvent], None] @@ -89,9 +94,7 @@ def unsubscribe( return False original_len = len(self._handlers[topic]) - self._handlers[topic] = [ - h for h in self._handlers[topic] if h[0] != handler - ] + self._handlers[topic] = [h for h in self._handlers[topic] if h[0] != handler] return len(self._handlers[topic]) < original_len def emit( @@ -105,7 +108,12 @@ def emit( matching_handlers: list[tuple[Callable[[IntegrationEvent], None], int]] = [] for pattern, handlers in self._handlers.items(): - if pattern == topic or fnmatch.fnmatch(topic, pattern): + if pattern == topic: + matching_handlers.extend(handlers) + elif pattern in self._compiled_patterns: + if self._compiled_patterns[pattern].match(topic): + matching_handlers.extend(handlers) + elif fnmatch.fnmatch(topic, pattern): matching_handlers.extend(handlers) # Sort all matching handlers by priority