From d37c2e459028c8724dc4d8006a3a50207f94e703 Mon Sep 17 00:00:00 2001 From: Ethan Yu Date: Wed, 27 May 2026 06:37:41 -0100 Subject: [PATCH 1/5] Type UAV test import stubs --- .../src/uav/test/test_auto_launch.py | 45 ++++++++++++------- .../sae_2025_ws/src/uav/test/test_flake8.py | 3 +- .../src/uav/test/test_fleet_launch.py | 40 +++++++++++------ .../src/uav/test/test_launch_helpers.py | 41 +++++++++++------ .../src/uav/test/test_mission_spec.py | 7 +-- .../test_payload_dlz_convex_hull_masking.py | 19 ++++---- .../uav/test/test_peer_runtime_contract.py | 25 ++++++----- .../src/uav/test/test_runtime_behavior.py | 39 ++++++++-------- .../src/uav/test/test_schema_validation.py | 3 +- 9 files changed, 133 insertions(+), 89 deletions(-) diff --git a/controls/sae_2025_ws/src/uav/test/test_auto_launch.py b/controls/sae_2025_ws/src/uav/test/test_auto_launch.py index fecdfec8..9f0aefa5 100644 --- a/controls/sae_2025_ws/src/uav/test/test_auto_launch.py +++ b/controls/sae_2025_ws/src/uav/test/test_auto_launch.py @@ -6,6 +6,7 @@ import sys from types import SimpleNamespace import types +from typing import Any import pytest @@ -17,6 +18,8 @@ def _import_module_if_available(name: str): return None +std_msgs: Any +std_msgs_msg: Any if "std_msgs" not in sys.modules: std_msgs = _import_module_if_available("std_msgs") else: @@ -28,14 +31,14 @@ def _import_module_if_available(name: str): std_msgs.msg = std_msgs_msg sys.modules.update({"std_msgs": std_msgs, "std_msgs.msg": std_msgs_msg}) -ament_index_python = sys.modules.get("ament_index_python") +ament_index_python: Any = sys.modules.get("ament_index_python") if ament_index_python is None: ament_index_python = _import_module_if_available("ament_index_python") if ament_index_python is None: ament_index_python = types.ModuleType("ament_index_python") sys.modules["ament_index_python"] = ament_index_python -ament_index_packages = sys.modules.get("ament_index_python.packages") +ament_index_packages: Any = sys.modules.get("ament_index_python.packages") if ament_index_packages is None: ament_index_packages = _import_module_if_available("ament_index_python.packages") if ament_index_packages is None: @@ -51,8 +54,10 @@ class PackageNotFoundError(Exception): ament_index_packages.get_package_share_directory = lambda _name: str( Path(__file__).resolve().parents[1] ) -sys.modules["ament_index_python"].packages = ament_index_packages +setattr(sys.modules["ament_index_python"], "packages", ament_index_packages) +std_srvs: Any +std_srvs_srv: Any if "std_srvs" not in sys.modules: std_srvs = _import_module_if_available("std_srvs") else: @@ -69,7 +74,7 @@ class Trigger: std_srvs.srv = std_srvs_srv sys.modules.update({"std_srvs": std_srvs, "std_srvs.srv": std_srvs_srv}) -launch_module = sys.modules.get("launch") +launch_module: Any = sys.modules.get("launch") if launch_module is None: launch_module = _import_module_if_available("launch") if launch_module is None: @@ -78,7 +83,7 @@ class Trigger: if not hasattr(launch_module, "LaunchDescription"): launch_module.LaunchDescription = type("LaunchDescription", (), {}) -launch_actions = sys.modules.get("launch.actions") +launch_actions: Any = sys.modules.get("launch.actions") if launch_actions is None: launch_actions = _import_module_if_available("launch.actions") if launch_actions is None: @@ -93,7 +98,7 @@ class Trigger: if not hasattr(launch_actions, name): setattr(launch_actions, name, type(name, (), {})) -launch_sources = sys.modules.get("launch.launch_description_sources") +launch_sources: Any = sys.modules.get("launch.launch_description_sources") if launch_sources is None: launch_sources = _import_module_if_available("launch.launch_description_sources") if launch_sources is None: @@ -104,7 +109,7 @@ class Trigger: "PythonLaunchDescriptionSource", (), {} ) -launch_logging = sys.modules.get("launch.logging") +launch_logging: Any = sys.modules.get("launch.logging") if launch_logging is None: launch_logging = _import_module_if_available("launch.logging") if launch_logging is None: @@ -117,7 +122,7 @@ class Trigger: info=lambda *_a, **_k: None, ) -launch_substitutions = sys.modules.get("launch.substitutions") +launch_substitutions: Any = sys.modules.get("launch.substitutions") if launch_substitutions is None: launch_substitutions = _import_module_if_available("launch.substitutions") if launch_substitutions is None: @@ -126,7 +131,7 @@ class Trigger: if not hasattr(launch_substitutions, "LaunchConfiguration"): launch_substitutions.LaunchConfiguration = type("LaunchConfiguration", (), {}) -rclpy = sys.modules.get("rclpy") +rclpy: Any = sys.modules.get("rclpy") if rclpy is None: rclpy = _import_module_if_available("rclpy") if rclpy is None: @@ -139,7 +144,7 @@ class Trigger: if not hasattr(rclpy, "ok"): rclpy.ok = lambda: True -node_mod = sys.modules.get("rclpy.node") +node_mod: Any = sys.modules.get("rclpy.node") if node_mod is None: node_mod = _import_module_if_available("rclpy.node") if node_mod is None: @@ -153,7 +158,7 @@ def __init__(self, *_args, **_kwargs) -> None: node_mod.Node = Node -executors_mod = sys.modules.get("rclpy.executors") +executors_mod: Any = sys.modules.get("rclpy.executors") if executors_mod is None: executors_mod = _import_module_if_available("rclpy.executors") if executors_mod is None: @@ -166,7 +171,7 @@ class ExternalShutdownException(Exception): executors_mod.ExternalShutdownException = ExternalShutdownException -clock_mod = sys.modules.get("rclpy.clock") +clock_mod: Any = sys.modules.get("rclpy.clock") if clock_mod is None: clock_mod = _import_module_if_available("rclpy.clock") if clock_mod is None: @@ -175,7 +180,7 @@ class ExternalShutdownException(Exception): if not hasattr(clock_mod, "Clock"): clock_mod.Clock = type("Clock", (), {}) -parameter_mod = sys.modules.get("rclpy.parameter") +parameter_mod: Any = sys.modules.get("rclpy.parameter") if parameter_mod is None: parameter_mod = _import_module_if_available("rclpy.parameter") if parameter_mod is None: @@ -184,7 +189,7 @@ class ExternalShutdownException(Exception): if not hasattr(parameter_mod, "Parameter"): parameter_mod.Parameter = type("Parameter", (), {}) -validate_namespace_mod = sys.modules.get("rclpy.validate_namespace") +validate_namespace_mod: Any = sys.modules.get("rclpy.validate_namespace") if validate_namespace_mod is None: validate_namespace_mod = _import_module_if_available("rclpy.validate_namespace") if validate_namespace_mod is None: @@ -193,7 +198,7 @@ class ExternalShutdownException(Exception): if not hasattr(validate_namespace_mod, "validate_namespace"): validate_namespace_mod.validate_namespace = lambda namespace: None -validate_node_name_mod = sys.modules.get("rclpy.validate_node_name") +validate_node_name_mod: Any = sys.modules.get("rclpy.validate_node_name") if validate_node_name_mod is None: validate_node_name_mod = _import_module_if_available("rclpy.validate_node_name") if validate_node_name_mod is None: @@ -202,7 +207,7 @@ class ExternalShutdownException(Exception): if not hasattr(validate_node_name_mod, "validate_node_name"): validate_node_name_mod.validate_node_name = lambda node_name: None -qos_mod = sys.modules.get("rclpy.qos") +qos_mod: Any = sys.modules.get("rclpy.qos") if qos_mod is None: qos_mod = _import_module_if_available("rclpy.qos") if qos_mod is None: @@ -228,6 +233,9 @@ class ExternalShutdownException(Exception): from uav.runtime.ModeManager import ModeManager # noqa: E402 try: + uav_mission_module: Any + uav_manager_module: Any + UAVModeManager: Any import uav.runtime.uav_mission as uav_mission_module import uav.runtime.UAVModeManager as uav_manager_module from uav.runtime.UAVModeManager import UAVModeManager @@ -239,6 +247,8 @@ class ExternalShutdownException(Exception): UAVModeManager = None try: + payload_mission_module: Any + payload_manager_module: Any import uav.runtime.payload_mission as payload_mission_module import uav.runtime.PayloadModeManager as payload_manager_module except ModuleNotFoundError as exc: @@ -396,8 +406,9 @@ def _stub_mode_manager_init( def _load_main_launch_module(): launch_path = Path(__file__).resolve().parents[1] / "launch" / "main.launch.py" spec = importlib.util.spec_from_file_location("uav_main_launch", launch_path) + if spec is None or spec.loader is None: + raise ImportError(f"Unable to load launch module from {launch_path}") module = importlib.util.module_from_spec(spec) - assert spec.loader is not None spec.loader.exec_module(module) return module diff --git a/controls/sae_2025_ws/src/uav/test/test_flake8.py b/controls/sae_2025_ws/src/uav/test/test_flake8.py index ee79f31a..9b88fa2f 100644 --- a/controls/sae_2025_ws/src/uav/test/test_flake8.py +++ b/controls/sae_2025_ws/src/uav/test/test_flake8.py @@ -14,12 +14,13 @@ from ament_flake8.main import main_with_errors import pytest +from typing import cast @pytest.mark.flake8 @pytest.mark.linter def test_flake8(): - rc, errors = main_with_errors(argv=[]) + rc, errors = cast(tuple[int, list[str]], main_with_errors(argv=[])) assert rc == 0, "Found %d code style errors / warnings:\n" % len( errors ) + "\n".join(errors) diff --git a/controls/sae_2025_ws/src/uav/test/test_fleet_launch.py b/controls/sae_2025_ws/src/uav/test/test_fleet_launch.py index 3a221ebf..d8505a0f 100644 --- a/controls/sae_2025_ws/src/uav/test/test_fleet_launch.py +++ b/controls/sae_2025_ws/src/uav/test/test_fleet_launch.py @@ -46,10 +46,12 @@ def _ensure_launch_import_stubs() -> None: ament_index_packages = types.ModuleType("ament_index_python.packages") sys.modules["ament_index_python.packages"] = ament_index_packages if not hasattr(ament_index_packages, "get_package_share_directory"): - ament_index_packages.get_package_share_directory = lambda _name: str( - Path(__file__).resolve().parents[1] + setattr( + ament_index_packages, + "get_package_share_directory", + lambda _name: str(Path(__file__).resolve().parents[1]), ) - ament_index_python.packages = ament_index_packages + setattr(ament_index_python, "packages", ament_index_packages) launch_module = sys.modules.get("launch") if launch_module is None: @@ -58,7 +60,7 @@ def _ensure_launch_import_stubs() -> None: launch_module = types.ModuleType("launch") sys.modules["launch"] = launch_module if not hasattr(launch_module, "LaunchDescription"): - launch_module.LaunchDescription = type("LaunchDescription", (), {}) + setattr(launch_module, "LaunchDescription", type("LaunchDescription", (), {})) launch_actions = sys.modules.get("launch.actions") if launch_actions is None: @@ -84,8 +86,10 @@ def _ensure_launch_import_stubs() -> None: launch_sources = types.ModuleType("launch.launch_description_sources") sys.modules["launch.launch_description_sources"] = launch_sources if not hasattr(launch_sources, "PythonLaunchDescriptionSource"): - launch_sources.PythonLaunchDescriptionSource = type( - "PythonLaunchDescriptionSource", (), {} + setattr( + launch_sources, + "PythonLaunchDescriptionSource", + type("PythonLaunchDescriptionSource", (), {}), ) launch_logging = sys.modules.get("launch.logging") @@ -95,10 +99,14 @@ def _ensure_launch_import_stubs() -> None: launch_logging = types.ModuleType("launch.logging") sys.modules["launch.logging"] = launch_logging if not hasattr(launch_logging, "get_logger"): - launch_logging.get_logger = lambda *_args, **_kwargs: SimpleNamespace( - warning=lambda *_a, **_k: None, - warn=lambda *_a, **_k: None, - info=lambda *_a, **_k: None, + setattr( + launch_logging, + "get_logger", + lambda *_args, **_kwargs: SimpleNamespace( + warning=lambda *_a, **_k: None, + warn=lambda *_a, **_k: None, + info=lambda *_a, **_k: None, + ), ) launch_substitutions = sys.modules.get("launch.substitutions") @@ -108,7 +116,11 @@ def _ensure_launch_import_stubs() -> None: launch_substitutions = types.ModuleType("launch.substitutions") sys.modules["launch.substitutions"] = launch_substitutions if not hasattr(launch_substitutions, "LaunchConfiguration"): - launch_substitutions.LaunchConfiguration = type("LaunchConfiguration", (), {}) + setattr( + launch_substitutions, + "LaunchConfiguration", + type("LaunchConfiguration", (), {}), + ) def _load_fleet_module(): @@ -123,8 +135,9 @@ def _load_fleet_module(): spec = importlib.util.spec_from_file_location( "uav_fleet_launch_helpers", launch_path ) + if spec is None or spec.loader is None: + raise ImportError(f"Unable to load launch module from {launch_path}") module = importlib.util.module_from_spec(spec) - assert spec.loader is not None spec.loader.exec_module(module) return module @@ -255,8 +268,9 @@ def test_real_backend_does_not_import_sim(monkeypatch): spec = importlib.util.spec_from_file_location( "uav_fleet_launch_hardware_only", launch_path ) + if spec is None or spec.loader is None: + raise ImportError(f"Unable to load launch module from {launch_path}") module = importlib.util.module_from_spec(spec) - assert spec.loader is not None real_import = builtins.__import__ diff --git a/controls/sae_2025_ws/src/uav/test/test_launch_helpers.py b/controls/sae_2025_ws/src/uav/test/test_launch_helpers.py index 6384fa50..a86b7a35 100644 --- a/controls/sae_2025_ws/src/uav/test/test_launch_helpers.py +++ b/controls/sae_2025_ws/src/uav/test/test_launch_helpers.py @@ -42,10 +42,12 @@ def _ensure_launch_import_stubs() -> None: ament_index_packages = types.ModuleType("ament_index_python.packages") sys.modules["ament_index_python.packages"] = ament_index_packages if not hasattr(ament_index_packages, "get_package_share_directory"): - ament_index_packages.get_package_share_directory = lambda _name: str( - Path(__file__).resolve().parents[1] + setattr( + ament_index_packages, + "get_package_share_directory", + lambda _name: str(Path(__file__).resolve().parents[1]), ) - ament_index_python.packages = ament_index_packages + setattr(ament_index_python, "packages", ament_index_packages) launch_module = sys.modules.get("launch") if launch_module is None: @@ -54,7 +56,7 @@ def _ensure_launch_import_stubs() -> None: launch_module = types.ModuleType("launch") sys.modules["launch"] = launch_module if not hasattr(launch_module, "LaunchDescription"): - launch_module.LaunchDescription = type("LaunchDescription", (), {}) + setattr(launch_module, "LaunchDescription", type("LaunchDescription", (), {})) launch_actions = sys.modules.get("launch.actions") if launch_actions is None: @@ -80,8 +82,10 @@ def _ensure_launch_import_stubs() -> None: launch_sources = types.ModuleType("launch.launch_description_sources") sys.modules["launch.launch_description_sources"] = launch_sources if not hasattr(launch_sources, "PythonLaunchDescriptionSource"): - launch_sources.PythonLaunchDescriptionSource = type( - "PythonLaunchDescriptionSource", (), {} + setattr( + launch_sources, + "PythonLaunchDescriptionSource", + type("PythonLaunchDescriptionSource", (), {}), ) launch_logging = sys.modules.get("launch.logging") @@ -91,10 +95,14 @@ def _ensure_launch_import_stubs() -> None: launch_logging = types.ModuleType("launch.logging") sys.modules["launch.logging"] = launch_logging if not hasattr(launch_logging, "get_logger"): - launch_logging.get_logger = lambda *_args, **_kwargs: SimpleNamespace( - warning=lambda *_a, **_k: None, - warn=lambda *_a, **_k: None, - info=lambda *_a, **_k: None, + setattr( + launch_logging, + "get_logger", + lambda *_args, **_kwargs: SimpleNamespace( + warning=lambda *_a, **_k: None, + warn=lambda *_a, **_k: None, + info=lambda *_a, **_k: None, + ), ) launch_substitutions = sys.modules.get("launch.substitutions") @@ -104,7 +112,11 @@ def _ensure_launch_import_stubs() -> None: launch_substitutions = types.ModuleType("launch.substitutions") sys.modules["launch.substitutions"] = launch_substitutions if not hasattr(launch_substitutions, "LaunchConfiguration"): - launch_substitutions.LaunchConfiguration = type("LaunchConfiguration", (), {}) + setattr( + launch_substitutions, + "LaunchConfiguration", + type("LaunchConfiguration", (), {}), + ) launch_ros = sys.modules.get("launch_ros") if launch_ros is None: @@ -120,8 +132,8 @@ def _ensure_launch_import_stubs() -> None: launch_ros_actions = types.ModuleType("launch_ros.actions") sys.modules["launch_ros.actions"] = launch_ros_actions if not hasattr(launch_ros_actions, "Node"): - launch_ros_actions.Node = type("Node", (), {}) - launch_ros.actions = launch_ros_actions + setattr(launch_ros_actions, "Node", type("Node", (), {})) + setattr(launch_ros, "actions", launch_ros_actions) def _load_launch_module(filename: str, module_name: str): @@ -137,8 +149,9 @@ def _load_launch_module(filename: str, module_name: str): sys.path.insert(0, str(sim_package_root)) launch_path = package_root / "launch" / filename spec = importlib.util.spec_from_file_location(module_name, launch_path) + if spec is None or spec.loader is None: + raise ImportError(f"Unable to load launch module from {launch_path}") module = importlib.util.module_from_spec(spec) - assert spec.loader is not None spec.loader.exec_module(module) return module diff --git a/controls/sae_2025_ws/src/uav/test/test_mission_spec.py b/controls/sae_2025_ws/src/uav/test/test_mission_spec.py index 340788b4..935c98be 100644 --- a/controls/sae_2025_ws/src/uav/test/test_mission_spec.py +++ b/controls/sae_2025_ws/src/uav/test/test_mission_spec.py @@ -5,12 +5,13 @@ from types import SimpleNamespace import sys import types +from typing import Any import pytest if "rclpy" not in sys.modules: - rclpy = types.ModuleType("rclpy") - node_mod = types.ModuleType("rclpy.node") + rclpy: Any = types.ModuleType("rclpy") + node_mod: Any = types.ModuleType("rclpy.node") class Node: def __init__(self, *_args, **_kwargs) -> None: @@ -343,7 +344,7 @@ def test_invalid_mode_target_is_rejected(monkeypatch, mission_target): def test_load_mode_class_accepts_module_path(monkeypatch): module_name = "uav.modes.payload.PayloadAprilTagApproachMode" - fake_module = types.ModuleType(module_name) + fake_module: Any = types.ModuleType(module_name) class PayloadAprilTagApproachMode(Mode): mission_target = "payload" diff --git a/controls/sae_2025_ws/src/uav/test/test_payload_dlz_convex_hull_masking.py b/controls/sae_2025_ws/src/uav/test/test_payload_dlz_convex_hull_masking.py index d9a3e9a3..f3952269 100644 --- a/controls/sae_2025_ws/src/uav/test/test_payload_dlz_convex_hull_masking.py +++ b/controls/sae_2025_ws/src/uav/test/test_payload_dlz_convex_hull_masking.py @@ -5,6 +5,7 @@ from pathlib import Path from types import SimpleNamespace import types +from typing import Any import numpy as np import cv2 @@ -38,8 +39,8 @@ def _roi_ratio(mask: np.ndarray, x0: int, y0: int, x1: int, y1: int) -> float: def _install_import_stubs() -> None: if "rclpy" not in sys.modules: - rclpy = types.ModuleType("rclpy") - node_module = types.ModuleType("rclpy.node") + rclpy: Any = types.ModuleType("rclpy") + node_module: Any = types.ModuleType("rclpy.node") class Node: pass @@ -49,7 +50,7 @@ class Node: sys.modules.update({"rclpy": rclpy, "rclpy.node": node_module}) if "cv_bridge" not in sys.modules: - cv_bridge = types.ModuleType("cv_bridge") + cv_bridge: Any = types.ModuleType("cv_bridge") class CvBridge: def imgmsg_to_cv2(self, *_args, **_kwargs): @@ -62,8 +63,8 @@ def cv2_to_compressed_imgmsg(self, *_args, **_kwargs): sys.modules["cv_bridge"] = cv_bridge if "sensor_msgs" not in sys.modules: - sensor_msgs = types.ModuleType("sensor_msgs") - sensor_msgs_msg = types.ModuleType("sensor_msgs.msg") + sensor_msgs: Any = types.ModuleType("sensor_msgs") + sensor_msgs_msg: Any = types.ModuleType("sensor_msgs.msg") sensor_msgs_msg.CompressedImage = type("CompressedImage", (), {}) sensor_msgs_msg.Image = type("Image", (), {}) sensor_msgs.msg = sensor_msgs_msg @@ -72,17 +73,17 @@ def cv2_to_compressed_imgmsg(self, *_args, **_kwargs): ) if "uav.vehicles.Payload" not in sys.modules: - payload_module = types.ModuleType("uav.vehicles.Payload") + payload_module: Any = types.ModuleType("uav.vehicles.Payload") payload_module.Payload = type("Payload", (), {}) sys.modules["uav.vehicles.Payload"] = payload_module if "uav.vision_nodes" not in sys.modules: - vision_nodes = types.ModuleType("uav.vision_nodes") + vision_nodes: Any = types.ModuleType("uav.vision_nodes") vision_nodes.PayloadAprilTagNode = type("PayloadAprilTagNode", (), {}) sys.modules["uav.vision_nodes"] = vision_nodes if "uav.vision_nodes.payload_perception_common" not in sys.modules: - common = types.ModuleType("uav.vision_nodes.payload_perception_common") + common: Any = types.ModuleType("uav.vision_nodes.payload_perception_common") common.DEFAULT_TAG_FAMILY = "tag36h11" sys.modules["uav.vision_nodes.payload_perception_common"] = common @@ -90,7 +91,7 @@ def cv2_to_compressed_imgmsg(self, *_args, **_kwargs): sys.modules["uav_interfaces"] = types.ModuleType("uav_interfaces") if "uav_interfaces.srv" not in sys.modules: - srv_module = types.ModuleType("uav_interfaces.srv") + srv_module: Any = types.ModuleType("uav_interfaces.srv") class PayloadAprilTagState: class Request: diff --git a/controls/sae_2025_ws/src/uav/test/test_peer_runtime_contract.py b/controls/sae_2025_ws/src/uav/test/test_peer_runtime_contract.py index a4890266..42eb5751 100644 --- a/controls/sae_2025_ws/src/uav/test/test_peer_runtime_contract.py +++ b/controls/sae_2025_ws/src/uav/test/test_peer_runtime_contract.py @@ -5,6 +5,7 @@ import textwrap from types import SimpleNamespace import types +from typing import Any import pytest @@ -15,17 +16,17 @@ def _placeholder(name: str): def _install_ros_test_doubles() -> None: if "rclpy" not in sys.modules: - rclpy = types.ModuleType("rclpy") + rclpy: Any = types.ModuleType("rclpy") rclpy.init = lambda *args, **kwargs: None rclpy.shutdown = lambda: None rclpy.ok = lambda: True - node_mod = types.ModuleType("rclpy.node") - executors_mod = types.ModuleType("rclpy.executors") - clock_mod = types.ModuleType("rclpy.clock") - parameter_mod = types.ModuleType("rclpy.parameter") - validate_namespace_mod = types.ModuleType("rclpy.validate_namespace") - validate_node_name_mod = types.ModuleType("rclpy.validate_node_name") - qos_mod = types.ModuleType("rclpy.qos") + node_mod: Any = types.ModuleType("rclpy.node") + executors_mod: Any = types.ModuleType("rclpy.executors") + clock_mod: Any = types.ModuleType("rclpy.clock") + parameter_mod: Any = types.ModuleType("rclpy.parameter") + validate_namespace_mod: Any = types.ModuleType("rclpy.validate_namespace") + validate_node_name_mod: Any = types.ModuleType("rclpy.validate_node_name") + qos_mod: Any = types.ModuleType("rclpy.qos") class Node: def __init__(self, *_args, **_kwargs) -> None: @@ -65,8 +66,8 @@ class ExternalShutdownException(Exception): ) if "std_srvs" not in sys.modules: - std_srvs = types.ModuleType("std_srvs") - std_srvs_srv = types.ModuleType("std_srvs.srv") + std_srvs: Any = types.ModuleType("std_srvs") + std_srvs_srv: Any = types.ModuleType("std_srvs.srv") class Trigger: Request = _placeholder("Request") @@ -77,8 +78,8 @@ class Trigger: sys.modules.update({"std_srvs": std_srvs, "std_srvs.srv": std_srvs_srv}) if "std_msgs" not in sys.modules: - std_msgs = types.ModuleType("std_msgs") - std_msgs_msg = types.ModuleType("std_msgs.msg") + std_msgs: Any = types.ModuleType("std_msgs") + std_msgs_msg: Any = types.ModuleType("std_msgs.msg") std_msgs_msg.Empty = _placeholder("Empty") std_msgs.msg = std_msgs_msg sys.modules.update({"std_msgs": std_msgs, "std_msgs.msg": std_msgs_msg}) diff --git a/controls/sae_2025_ws/src/uav/test/test_runtime_behavior.py b/controls/sae_2025_ws/src/uav/test/test_runtime_behavior.py index a1260623..c52ca3bf 100644 --- a/controls/sae_2025_ws/src/uav/test/test_runtime_behavior.py +++ b/controls/sae_2025_ws/src/uav/test/test_runtime_behavior.py @@ -4,6 +4,7 @@ import sys import types from types import SimpleNamespace +from typing import Any import pytest @@ -20,14 +21,14 @@ def _import_module_if_available(name: str): def _install_ros_test_doubles() -> None: - ament_index_python = sys.modules.get("ament_index_python") + ament_index_python: Any = sys.modules.get("ament_index_python") if ament_index_python is None: ament_index_python = _import_module_if_available("ament_index_python") if ament_index_python is None: ament_index_python = types.ModuleType("ament_index_python") sys.modules["ament_index_python"] = ament_index_python - ament_index_packages = sys.modules.get("ament_index_python.packages") + ament_index_packages: Any = sys.modules.get("ament_index_python.packages") if ament_index_packages is None: ament_index_packages = _import_module_if_available( "ament_index_python.packages" @@ -45,7 +46,7 @@ class PackageNotFoundError(Exception): ament_index_packages.get_package_share_directory = lambda _name: "" ament_index_python.packages = ament_index_packages - rclpy = sys.modules.get("rclpy") + rclpy: Any = sys.modules.get("rclpy") if rclpy is None: rclpy = _import_module_if_available("rclpy") if rclpy is None: @@ -54,7 +55,7 @@ class PackageNotFoundError(Exception): if not hasattr(rclpy, "ok"): rclpy.ok = lambda: True - node_mod = sys.modules.get("rclpy.node") + node_mod: Any = sys.modules.get("rclpy.node") if node_mod is None: node_mod = _import_module_if_available("rclpy.node") if node_mod is None: @@ -67,7 +68,7 @@ class Node: node_mod.Node = Node - executors_mod = sys.modules.get("rclpy.executors") + executors_mod: Any = sys.modules.get("rclpy.executors") if executors_mod is None: executors_mod = _import_module_if_available("rclpy.executors") if executors_mod is None: @@ -80,7 +81,7 @@ class ExternalShutdownException(Exception): executors_mod.ExternalShutdownException = ExternalShutdownException - clock_mod = sys.modules.get("rclpy.clock") + clock_mod: Any = sys.modules.get("rclpy.clock") if clock_mod is None: clock_mod = _import_module_if_available("rclpy.clock") if clock_mod is None: @@ -89,7 +90,7 @@ class ExternalShutdownException(Exception): if not hasattr(clock_mod, "Clock"): clock_mod.Clock = _placeholder("Clock") - parameter_mod = sys.modules.get("rclpy.parameter") + parameter_mod: Any = sys.modules.get("rclpy.parameter") if parameter_mod is None: parameter_mod = _import_module_if_available("rclpy.parameter") if parameter_mod is None: @@ -98,7 +99,7 @@ class ExternalShutdownException(Exception): if not hasattr(parameter_mod, "Parameter"): parameter_mod.Parameter = _placeholder("Parameter") - validate_namespace_mod = sys.modules.get("rclpy.validate_namespace") + validate_namespace_mod: Any = sys.modules.get("rclpy.validate_namespace") if validate_namespace_mod is None: validate_namespace_mod = _import_module_if_available("rclpy.validate_namespace") if validate_namespace_mod is None: @@ -107,7 +108,7 @@ class ExternalShutdownException(Exception): if not hasattr(validate_namespace_mod, "validate_namespace"): validate_namespace_mod.validate_namespace = lambda namespace: None - validate_node_name_mod = sys.modules.get("rclpy.validate_node_name") + validate_node_name_mod: Any = sys.modules.get("rclpy.validate_node_name") if validate_node_name_mod is None: validate_node_name_mod = _import_module_if_available("rclpy.validate_node_name") if validate_node_name_mod is None: @@ -116,7 +117,7 @@ class ExternalShutdownException(Exception): if not hasattr(validate_node_name_mod, "validate_node_name"): validate_node_name_mod.validate_node_name = lambda node_name: None - qos_mod = sys.modules.get("rclpy.qos") + qos_mod: Any = sys.modules.get("rclpy.qos") if qos_mod is None: qos_mod = _import_module_if_available("rclpy.qos") if qos_mod is None: @@ -140,8 +141,8 @@ class ExternalShutdownException(Exception): rclpy.qos = qos_mod if "std_srvs" not in sys.modules: - std_srvs = types.ModuleType("std_srvs") - std_srvs_srv = types.ModuleType("std_srvs.srv") + std_srvs: Any = types.ModuleType("std_srvs") + std_srvs_srv: Any = types.ModuleType("std_srvs.srv") class Trigger: Request = _placeholder("Request") @@ -152,15 +153,15 @@ class Trigger: sys.modules.update({"std_srvs": std_srvs, "std_srvs.srv": std_srvs_srv}) if "std_msgs" not in sys.modules: - std_msgs = types.ModuleType("std_msgs") - std_msgs_msg = types.ModuleType("std_msgs.msg") + std_msgs: Any = types.ModuleType("std_msgs") + std_msgs_msg: Any = types.ModuleType("std_msgs.msg") std_msgs_msg.Empty = _placeholder("Empty") std_msgs.msg = std_msgs_msg sys.modules.update({"std_msgs": std_msgs, "std_msgs.msg": std_msgs_msg}) if "px4_msgs" not in sys.modules: - px4_msgs = types.ModuleType("px4_msgs") - px4_msgs_msg = types.ModuleType("px4_msgs.msg") + px4_msgs: Any = types.ModuleType("px4_msgs") + px4_msgs_msg: Any = types.ModuleType("px4_msgs.msg") class VehicleStatus: NAVIGATION_STATE_AUTO_LOITER = 1 @@ -186,9 +187,9 @@ class VtolVehicleStatus: sys.modules.update({"px4_msgs": px4_msgs, "px4_msgs.msg": px4_msgs_msg}) if "payload_interfaces" not in sys.modules: - payload_interfaces = types.ModuleType("payload_interfaces") - payload_interfaces_msg = types.ModuleType("payload_interfaces.msg") - payload_interfaces_srv = types.ModuleType("payload_interfaces.srv") + payload_interfaces: Any = types.ModuleType("payload_interfaces") + payload_interfaces_msg: Any = types.ModuleType("payload_interfaces.msg") + payload_interfaces_srv: Any = types.ModuleType("payload_interfaces.srv") payload_interfaces_msg.DriveCommand = _placeholder("DriveCommand") payload_interfaces_msg.ServoCommand = _placeholder("ServoCommand") diff --git a/controls/sae_2025_ws/src/uav/test/test_schema_validation.py b/controls/sae_2025_ws/src/uav/test/test_schema_validation.py index bd9ff4e2..e5f286d1 100644 --- a/controls/sae_2025_ws/src/uav/test/test_schema_validation.py +++ b/controls/sae_2025_ws/src/uav/test/test_schema_validation.py @@ -24,8 +24,9 @@ def _load_fleet_module(): spec = importlib.util.spec_from_file_location( "uav_fleet_launch_schema_tests", launch_path ) + if spec is None or spec.loader is None: + raise ImportError(f"Unable to load launch module from {launch_path}") module = importlib.util.module_from_spec(spec) - assert spec.loader is not None spec.loader.exec_module(module) return module From e63b6398777cd71791ff468db66bb6104e523366 Mon Sep 17 00:00:00 2001 From: Ethan Yu Date: Wed, 27 May 2026 06:49:59 -0100 Subject: [PATCH 2/5] Type UAV test fakes --- .../src/uav/test/test_auto_launch.py | 6 ++-- .../src/uav/test/test_runtime_behavior.py | 30 ++++++++++--------- .../src/uav/test/test_schema_validation.py | 5 ++-- 3 files changed, 22 insertions(+), 19 deletions(-) diff --git a/controls/sae_2025_ws/src/uav/test/test_auto_launch.py b/controls/sae_2025_ws/src/uav/test/test_auto_launch.py index 9f0aefa5..1103a542 100644 --- a/controls/sae_2025_ws/src/uav/test/test_auto_launch.py +++ b/controls/sae_2025_ws/src/uav/test/test_auto_launch.py @@ -6,7 +6,7 @@ import sys from types import SimpleNamespace import types -from typing import Any +from typing import Any, cast import pytest @@ -281,8 +281,8 @@ def cancel(self) -> None: self.cancelled = True -def _make_mode_manager(*, ready: bool) -> ModeManager: - manager = object.__new__(ModeManager) +def _make_mode_manager(*, ready: bool) -> Any: + manager = cast(Any, object.__new__(ModeManager)) manager.vehicle = None manager.modes = {} manager.transitions = {} diff --git a/controls/sae_2025_ws/src/uav/test/test_runtime_behavior.py b/controls/sae_2025_ws/src/uav/test/test_runtime_behavior.py index c52ca3bf..613a4c2f 100644 --- a/controls/sae_2025_ws/src/uav/test/test_runtime_behavior.py +++ b/controls/sae_2025_ws/src/uav/test/test_runtime_behavior.py @@ -4,7 +4,7 @@ import sys import types from types import SimpleNamespace -from typing import Any +from typing import Any, cast import pytest @@ -290,8 +290,8 @@ def check_status(self) -> str: return self.status -def _make_mode_manager(*, vehicle=None, auto_launch: bool = False) -> ModeManager: - manager = object.__new__(ModeManager) +def _make_mode_manager(*, vehicle=None, auto_launch: bool = False) -> Any: + manager = cast(Any, object.__new__(ModeManager)) manager.vehicle = vehicle manager.modes = {} manager.transitions = {} @@ -409,7 +409,9 @@ def _make_bootstrap(module_cls, params: dict[str, object]): return bootstrap -def _fake_mission_spec(*, target: str, is_uav: bool, is_payload: bool, vision_nodes=()): +def _fake_mission_spec( + *, target: str, is_uav: bool, is_payload: bool, vision_nodes=() +) -> Any: return SimpleNamespace( target=target, is_uav=is_uav, @@ -438,7 +440,7 @@ def __init__( required: int, optional: str = "default", ) -> None: - super().__init__(node, vehicle) + super().__init__(node, cast(Any, vehicle)) self.required = required self.optional = optional @@ -478,7 +480,7 @@ class FakeMode(Mode): mission_target = "uav" def __init__(self, node, vehicle: _ExpectedVehicle, required: int) -> None: - super().__init__(node, vehicle) + super().__init__(node, cast(Any, vehicle)) self.required = required def on_update(self, time_delta: float) -> None: @@ -510,7 +512,7 @@ class FakeMode(Mode): mission_target = "uav" def __init__(self, node, vehicle: _ExpectedVehicle, required: int) -> None: - super().__init__(node, vehicle) + super().__init__(node, cast(Any, vehicle)) self.required = required def on_update(self, time_delta: float) -> None: @@ -544,7 +546,7 @@ class FakeMode(Mode): mission_target = "uav" def __init__(self, node, vehicle: _ExpectedVehicle, required: int) -> None: - super().__init__(node, vehicle) + super().__init__(node, cast(Any, vehicle)) self.required = required def on_update(self, time_delta: float) -> None: @@ -615,7 +617,7 @@ def test_setup_vision_deduplicates_clients(monkeypatch): assert list(manager.vision_clients) == [canonical_name] assert ModeManager.get_vision_client(manager, FakeVisionNode) is client - assert created_clients == [(FakeVisionNode.srv, "vision/FakeVisionNode")] + assert created_clients == [(cast(Any, FakeVisionNode).srv, "vision/FakeVisionNode")] def test_setup_vision_rejects_vehicle_without_camera(): @@ -738,7 +740,7 @@ def test_create_entity_falls_back_to_raw_node_during_node_init(monkeypatch): raising=False, ) - manager = object.__new__(ModeManager) + manager = cast(Any, object.__new__(ModeManager)) publisher = ModeManager.create_publisher(manager, object, "/parameter_events", 10) @@ -914,7 +916,7 @@ def test_mode_manager_stop_vehicle_without_rclpy_guard(): _require_runtime_support() stop_calls: list[str] = [] - manager = object.__new__(ModeManager) + manager = cast(Any, object.__new__(ModeManager)) manager.vehicle = SimpleNamespace(stop=lambda: stop_calls.append("stop")) manager.get_logger = lambda: _FakeLogger() @@ -937,7 +939,7 @@ def deactivate(self) -> None: events.append("deactivate") self.active = False - manager = object.__new__(ModeManager) + manager = cast(Any, object.__new__(ModeManager)) manager.vehicle = SimpleNamespace(stop=lambda: events.append("stop")) manager.modes = {"start": _FakeMode()} manager.transitions = {} @@ -972,7 +974,7 @@ def deactivate(self) -> None: events.append("deactivate") self.active = False - manager = object.__new__(ModeManager) + manager = cast(Any, object.__new__(ModeManager)) manager.vehicle = SimpleNamespace(stop=lambda: events.append("stop")) manager.modes = {"start": _FakeMode()} manager.transitions = {} @@ -1012,7 +1014,7 @@ def deactivate(self) -> None: events.append("deactivate") self.active = False - manager = object.__new__(ModeManager) + manager = cast(Any, object.__new__(ModeManager)) manager.vehicle = SimpleNamespace(stop=lambda: events.append("stop")) manager.modes = {"start": _FakeMode()} manager.transitions = {} diff --git a/controls/sae_2025_ws/src/uav/test/test_schema_validation.py b/controls/sae_2025_ws/src/uav/test/test_schema_validation.py index e5f286d1..7c39c48c 100644 --- a/controls/sae_2025_ws/src/uav/test/test_schema_validation.py +++ b/controls/sae_2025_ws/src/uav/test/test_schema_validation.py @@ -4,6 +4,7 @@ import sys from pathlib import Path from types import SimpleNamespace +from typing import Any, cast import pytest import yaml @@ -51,8 +52,8 @@ def check_status(self) -> str: return self.status -def _make_mode_manager(): - manager = object.__new__(ModeManager) +def _make_mode_manager() -> Any: + manager = cast(Any, object.__new__(ModeManager)) manager.vehicle = None manager.modes = {} manager.transitions = {} From b6c8c60f46cb57c2fb4679bdfb561f7485a1554e Mon Sep 17 00:00:00 2001 From: Ethan Yu Date: Wed, 27 May 2026 07:25:10 -0100 Subject: [PATCH 3/5] Type payload DLZ fixtures --- .../src/uav/test/test_payload_dlz_convex_hull_masking.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/controls/sae_2025_ws/src/uav/test/test_payload_dlz_convex_hull_masking.py b/controls/sae_2025_ws/src/uav/test/test_payload_dlz_convex_hull_masking.py index f3952269..db3d41d2 100644 --- a/controls/sae_2025_ws/src/uav/test/test_payload_dlz_convex_hull_masking.py +++ b/controls/sae_2025_ws/src/uav/test/test_payload_dlz_convex_hull_masking.py @@ -5,7 +5,7 @@ from pathlib import Path from types import SimpleNamespace import types -from typing import Any +from typing import Any, cast import numpy as np import cv2 @@ -19,8 +19,9 @@ def _bgr_from_hsv(h: int, s: int, v: int) -> tuple[int, int, int]: + hsv_pixel = np.array([[[h, s, v]]], dtype=np.uint8) pixel = cv2.cvtColor( - np.uint8([[[h, s, v]]]), + cast(Any, hsv_pixel), cv2.COLOR_HSV2BGR, )[0, 0] return int(pixel[0]), int(pixel[1]), int(pixel[2]) @@ -184,7 +185,8 @@ def _make_turn_to_center_mode(**kwargs): def _contour_bbox(mask: np.ndarray) -> tuple[int, int, int, int]: contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) assert len(contours) == 1 - return cv2.boundingRect(contours[0]) + x, y, width, height = cv2.boundingRect(contours[0]) + return int(x), int(y), int(width), int(height) def test_build_dlz_hull_mask_keeps_orange_rectangle_and_excludes_outside(): From b8e09be144c60ba7358cd29d9b37dfd18d2c086e Mon Sep 17 00:00:00 2001 From: Ethan Yu Date: Wed, 27 May 2026 07:00:25 -0100 Subject: [PATCH 4/5] Type UAV peer contract tests --- .../uav/test/test_peer_runtime_contract.py | 50 ++++++++++++------- .../src/uav/test/test_peer_stack_reconnect.py | 30 +++++++---- 2 files changed, 53 insertions(+), 27 deletions(-) diff --git a/controls/sae_2025_ws/src/uav/test/test_peer_runtime_contract.py b/controls/sae_2025_ws/src/uav/test/test_peer_runtime_contract.py index 42eb5751..1a2e7df8 100644 --- a/controls/sae_2025_ws/src/uav/test/test_peer_runtime_contract.py +++ b/controls/sae_2025_ws/src/uav/test/test_peer_runtime_contract.py @@ -5,7 +5,7 @@ import textwrap from types import SimpleNamespace import types -from typing import Any +from typing import Any, Mapping, cast import pytest @@ -236,8 +236,8 @@ def create_service(self, _srv_type, service_name: str, *_args, **_kwargs): return SimpleNamespace(kind="service", name=service_name) -def _make_mode_manager() -> ModeManager: - manager = object.__new__(ModeManager) +def _make_mode_manager() -> Any: + manager = cast(Any, object.__new__(ModeManager)) manager.vehicle = _FakeVehicle() manager.modes = {} manager.transitions = {} @@ -409,7 +409,9 @@ class PeerAwareMode(Mode): def __init__(self, node, vehicle) -> None: super().__init__(node, vehicle) - def on_disconnect(self, time_delta: float, peers: dict[str, bool]) -> None: + def on_disconnect( + self, time_delta: float, connection_status: Mapping[str, bool] + ) -> None: pass def on_update(self, time_delta: float) -> None: @@ -530,7 +532,9 @@ def __init__(self, node, vehicle) -> None: ) self.shared_pub = self.node.create_publisher(object, "/shared/debug", 1) - def on_disconnect(self, time_delta: float, peers: dict[str, bool]) -> None: + def on_disconnect( + self, time_delta: float, connection_status: Mapping[str, bool] + ) -> None: pass def on_update(self, time_delta: float) -> None: @@ -575,7 +579,9 @@ def check_status(self) -> str: ) manager = _make_mode_manager() - mode = ModeManager.initialize_mode(manager, "fake.module.PeerAwareMode", {}) + mode = cast( + Any, ModeManager.initialize_mode(manager, "fake.module.PeerAwareMode", {}) + ) _configure_manager_for_mode(manager, mode) with manager._use_comm_builder( @@ -634,7 +640,7 @@ def __init__(self, node, vehicle) -> None: self.connection_checks: list[dict[str, bool]] = [] self.status_checks = 0 - def connection_ready(self, connection_status: dict[str, bool]) -> bool: + def connection_ready(self, connection_status: Mapping[str, bool]) -> bool: self.connection_checks.append(dict(connection_status)) return True @@ -642,7 +648,7 @@ def on_update(self, time_delta: float) -> None: self.update_calls.append(time_delta) def on_disconnect( - self, time_delta: float, connection_status: dict[str, bool] + self, time_delta: float, connection_status: Mapping[str, bool] ) -> None: self.disconnect_calls.append((time_delta, dict(connection_status))) @@ -688,12 +694,12 @@ def __init__(self, node, vehicle) -> None: self.connection_checks: list[dict[str, bool]] = [] self.status_checks = 0 - def connection_ready(self, connection_status: dict[str, bool]) -> bool: + def connection_ready(self, connection_status: Mapping[str, bool]) -> bool: self.connection_checks.append(dict(connection_status)) return False def on_disconnect( - self, time_delta: float, connection_status: dict[str, bool] + self, time_delta: float, connection_status: Mapping[str, bool] ) -> None: self.disconnect_calls.append((time_delta, dict(connection_status))) @@ -732,7 +738,9 @@ class PeerAwareMode(Mode): def __init__(self, node, vehicle) -> None: super().__init__(node, vehicle) - def on_disconnect(self, time_delta: float, peers: dict[str, bool]) -> None: + def on_disconnect( + self, time_delta: float, connection_status: Mapping[str, bool] + ) -> None: pass def on_update(self, time_delta: float) -> None: @@ -741,7 +749,7 @@ def on_update(self, time_delta: float) -> None: def check_status(self) -> str: return "continue" - mode = PeerAwareMode(_RecordingNode(), _FakeVehicle()) + mode = PeerAwareMode(cast(Any, _RecordingNode()), cast(Any, _FakeVehicle())) assert mode.connection_ready({"uav_1": True, "uav_2": True}) is True assert mode.connection_ready({"uav_1": True, "uav_2": False}) is False @@ -761,7 +769,9 @@ class PeerAwareMode(Mode): def __init__(self, node, vehicle) -> None: super().__init__(node, vehicle) - def on_disconnect(self, time_delta: float, peers: dict[str, bool]) -> None: + def on_disconnect( + self, time_delta: float, connection_status: Mapping[str, bool] + ) -> None: pass def on_update(self, time_delta: float) -> None: @@ -870,7 +880,9 @@ def __init__(self, node, vehicle) -> None: super().__init__(node, vehicle) self.node.create_publisher(object, "/shared/debug", 1) - def on_disconnect(self, time_delta: float, peers: dict[str, bool]) -> None: + def on_disconnect( + self, time_delta: float, connection_status: Mapping[str, bool] + ) -> None: pass def on_enter(self) -> None: @@ -930,7 +942,9 @@ class SharedStateMode(Mode): def __init__(self, node, vehicle) -> None: super().__init__(node, vehicle) - def on_disconnect(self, time_delta: float, peers: dict[str, bool]) -> None: + def on_disconnect( + self, time_delta: float, connection_status: Mapping[str, bool] + ) -> None: pass def on_update(self, time_delta: float) -> None: @@ -966,7 +980,9 @@ def __init__(self, node, vehicle) -> None: super().__init__(node, vehicle) self.node.create_subscription(object, "/uav_2/status", lambda _msg: None, 1) - def on_disconnect(self, time_delta: float, peers: dict[str, bool]) -> None: + def on_disconnect( + self, time_delta: float, connection_status: Mapping[str, bool] + ) -> None: pass def on_enter(self) -> None: @@ -980,7 +996,7 @@ def check_status(self) -> str: return "continue" node = _RecordingNode() - mode = InstrumentedPeerMode(node, _FakeVehicle()) + mode = InstrumentedPeerMode(cast(Any, node), cast(Any, _FakeVehicle())) mode.activate() assert _observed_peer_vehicle_names(node.calls, mode.peer_vehicle_names) == { diff --git a/controls/sae_2025_ws/src/uav/test/test_peer_stack_reconnect.py b/controls/sae_2025_ws/src/uav/test/test_peer_stack_reconnect.py index 6d6122dc..ac7dd177 100644 --- a/controls/sae_2025_ws/src/uav/test/test_peer_stack_reconnect.py +++ b/controls/sae_2025_ws/src/uav/test/test_peer_stack_reconnect.py @@ -9,6 +9,7 @@ import signal import subprocess import time +from typing import cast import pytest @@ -104,7 +105,7 @@ def __init__(self) -> None: String, f"/{vehicle_name}/peer_test/state", lambda message, vehicle_name=vehicle_name: self._on_status( - vehicle_name, message + vehicle_name, cast(String, message) ), 10, ) @@ -388,6 +389,12 @@ def _launch_vehicle_stack( ) +def _status_int(status: dict[str, object], key: str) -> int: + value = status[key] + assert isinstance(value, int) + return value + + @pytest.fixture def live_ros_environment(monkeypatch): _require_uav_package() @@ -483,9 +490,9 @@ def test_vehicle_stack_peer_reconnect_recovers_state_and_traffic( assert observer.shared_counts_by_sender.get("payload_0", 0) > 0 assert observer.shared_counts_by_sender.get("payload_1", 0) > 0 - payload_0_peer_before_disconnect = int(status_0["peer_received_total"]) - payload_0_shared_before_disconnect = int( - status_0["shared_remote_received_total"] + payload_0_peer_before_disconnect = _status_int(status_0, "peer_received_total") + payload_0_shared_before_disconnect = _status_int( + status_0, "shared_remote_received_total" ) payload_1_shared_events_before_disconnect = ( observer.shared_counts_by_sender.get("payload_1", 0) @@ -502,9 +509,12 @@ def test_vehicle_stack_peer_reconnect_recovers_state_and_traffic( state="waiting", disconnected_peers=["payload_1"], ) - assert int(waiting_0["peer_received_total"]) >= payload_0_peer_before_disconnect assert ( - int(waiting_0["shared_remote_received_total"]) + _status_int(waiting_0, "peer_received_total") + >= payload_0_peer_before_disconnect + ) + assert ( + _status_int(waiting_0, "shared_remote_received_total") >= payload_0_shared_before_disconnect ) @@ -563,15 +573,15 @@ def test_vehicle_stack_peer_reconnect_recovers_state_and_traffic( shared_remote_received_total=1, ) assert ( - int(reconnect_status_0["peer_received_total"]) + _status_int(reconnect_status_0, "peer_received_total") > payload_0_peer_before_disconnect ) assert ( - int(reconnect_status_0["shared_remote_received_total"]) + _status_int(reconnect_status_0, "shared_remote_received_total") > payload_0_shared_before_disconnect ) - assert int(reconnect_status_1["peer_received_total"]) > 0 - assert int(reconnect_status_1["shared_remote_received_total"]) > 0 + assert _status_int(reconnect_status_1, "peer_received_total") > 0 + assert _status_int(reconnect_status_1, "shared_remote_received_total") > 0 assert ( observer.shared_counts_by_sender.get("payload_1", 0) > payload_1_shared_events_before_disconnect From f6990c5956081c4c817565093c0d8e11d49b037c Mon Sep 17 00:00:00 2001 From: Ethan Yu Date: Wed, 27 May 2026 22:17:43 -0100 Subject: [PATCH 5/5] Consolidate UAV ROS test stubs --- .../src/uav/test/test_auto_launch.py | 222 +----------- .../src/uav/test/test_launch_helpers.py | 135 +------ .../src/uav/test/test_mission_spec.py | 21 +- .../test_payload_dlz_convex_hull_masking.py | 71 +--- .../src/uav/uav/test_support/__init__.py | 1 + .../src/uav/uav/test_support/ros_stubs.py | 340 ++++++++++++++++++ 6 files changed, 358 insertions(+), 432 deletions(-) create mode 100644 controls/sae_2025_ws/src/uav/uav/test_support/__init__.py create mode 100644 controls/sae_2025_ws/src/uav/uav/test_support/ros_stubs.py diff --git a/controls/sae_2025_ws/src/uav/test/test_auto_launch.py b/controls/sae_2025_ws/src/uav/test/test_auto_launch.py index 1103a542..ddb65303 100644 --- a/controls/sae_2025_ws/src/uav/test/test_auto_launch.py +++ b/controls/sae_2025_ws/src/uav/test/test_auto_launch.py @@ -1,234 +1,16 @@ from __future__ import annotations -import importlib import importlib.util from pathlib import Path -import sys from types import SimpleNamespace -import types from typing import Any, cast import pytest +from uav.test_support.ros_stubs import install_auto_launch_import_stubs -def _import_module_if_available(name: str): - try: - return importlib.import_module(name) - except ModuleNotFoundError: - return None - - -std_msgs: Any -std_msgs_msg: Any -if "std_msgs" not in sys.modules: - std_msgs = _import_module_if_available("std_msgs") -else: - std_msgs = sys.modules["std_msgs"] -if std_msgs is None: - std_msgs = types.ModuleType("std_msgs") - std_msgs_msg = types.ModuleType("std_msgs.msg") - std_msgs_msg.Empty = type("Empty", (), {}) - std_msgs.msg = std_msgs_msg - sys.modules.update({"std_msgs": std_msgs, "std_msgs.msg": std_msgs_msg}) - -ament_index_python: Any = sys.modules.get("ament_index_python") -if ament_index_python is None: - ament_index_python = _import_module_if_available("ament_index_python") -if ament_index_python is None: - ament_index_python = types.ModuleType("ament_index_python") - sys.modules["ament_index_python"] = ament_index_python - -ament_index_packages: Any = sys.modules.get("ament_index_python.packages") -if ament_index_packages is None: - ament_index_packages = _import_module_if_available("ament_index_python.packages") -if ament_index_packages is None: - ament_index_packages = types.ModuleType("ament_index_python.packages") - sys.modules["ament_index_python.packages"] = ament_index_packages -if not hasattr(ament_index_packages, "PackageNotFoundError"): - - class PackageNotFoundError(Exception): - pass - - ament_index_packages.PackageNotFoundError = PackageNotFoundError -if not hasattr(ament_index_packages, "get_package_share_directory"): - ament_index_packages.get_package_share_directory = lambda _name: str( - Path(__file__).resolve().parents[1] - ) -setattr(sys.modules["ament_index_python"], "packages", ament_index_packages) - -std_srvs: Any -std_srvs_srv: Any -if "std_srvs" not in sys.modules: - std_srvs = _import_module_if_available("std_srvs") -else: - std_srvs = sys.modules["std_srvs"] -if std_srvs is None: - std_srvs = types.ModuleType("std_srvs") - std_srvs_srv = types.ModuleType("std_srvs.srv") - - class Trigger: - Request = type("Request", (), {}) - Response = type("Response", (), {}) - - std_srvs_srv.Trigger = Trigger - std_srvs.srv = std_srvs_srv - sys.modules.update({"std_srvs": std_srvs, "std_srvs.srv": std_srvs_srv}) - -launch_module: Any = sys.modules.get("launch") -if launch_module is None: - launch_module = _import_module_if_available("launch") -if launch_module is None: - launch_module = types.ModuleType("launch") - sys.modules["launch"] = launch_module -if not hasattr(launch_module, "LaunchDescription"): - launch_module.LaunchDescription = type("LaunchDescription", (), {}) - -launch_actions: Any = sys.modules.get("launch.actions") -if launch_actions is None: - launch_actions = _import_module_if_available("launch.actions") -if launch_actions is None: - launch_actions = types.ModuleType("launch.actions") - sys.modules["launch.actions"] = launch_actions -for name in ( - "DeclareLaunchArgument", - "ExecuteProcess", - "IncludeLaunchDescription", - "OpaqueFunction", -): - if not hasattr(launch_actions, name): - setattr(launch_actions, name, type(name, (), {})) - -launch_sources: Any = sys.modules.get("launch.launch_description_sources") -if launch_sources is None: - launch_sources = _import_module_if_available("launch.launch_description_sources") -if launch_sources is None: - launch_sources = types.ModuleType("launch.launch_description_sources") - sys.modules["launch.launch_description_sources"] = launch_sources -if not hasattr(launch_sources, "PythonLaunchDescriptionSource"): - launch_sources.PythonLaunchDescriptionSource = type( - "PythonLaunchDescriptionSource", (), {} - ) - -launch_logging: Any = sys.modules.get("launch.logging") -if launch_logging is None: - launch_logging = _import_module_if_available("launch.logging") -if launch_logging is None: - launch_logging = types.ModuleType("launch.logging") - sys.modules["launch.logging"] = launch_logging -if not hasattr(launch_logging, "get_logger"): - launch_logging.get_logger = lambda *_args, **_kwargs: SimpleNamespace( - warning=lambda *_a, **_k: None, - warn=lambda *_a, **_k: None, - info=lambda *_a, **_k: None, - ) -launch_substitutions: Any = sys.modules.get("launch.substitutions") -if launch_substitutions is None: - launch_substitutions = _import_module_if_available("launch.substitutions") -if launch_substitutions is None: - launch_substitutions = types.ModuleType("launch.substitutions") - sys.modules["launch.substitutions"] = launch_substitutions -if not hasattr(launch_substitutions, "LaunchConfiguration"): - launch_substitutions.LaunchConfiguration = type("LaunchConfiguration", (), {}) - -rclpy: Any = sys.modules.get("rclpy") -if rclpy is None: - rclpy = _import_module_if_available("rclpy") -if rclpy is None: - rclpy = types.ModuleType("rclpy") - sys.modules["rclpy"] = rclpy -if not hasattr(rclpy, "init"): - rclpy.init = lambda *args, **kwargs: None -if not hasattr(rclpy, "shutdown"): - rclpy.shutdown = lambda: None -if not hasattr(rclpy, "ok"): - rclpy.ok = lambda: True - -node_mod: Any = sys.modules.get("rclpy.node") -if node_mod is None: - node_mod = _import_module_if_available("rclpy.node") -if node_mod is None: - node_mod = types.ModuleType("rclpy.node") - sys.modules["rclpy.node"] = node_mod -if not hasattr(node_mod, "Node"): - - class Node: - def __init__(self, *_args, **_kwargs) -> None: - pass - - node_mod.Node = Node - -executors_mod: Any = sys.modules.get("rclpy.executors") -if executors_mod is None: - executors_mod = _import_module_if_available("rclpy.executors") -if executors_mod is None: - executors_mod = types.ModuleType("rclpy.executors") - sys.modules["rclpy.executors"] = executors_mod -if not hasattr(executors_mod, "ExternalShutdownException"): - - class ExternalShutdownException(Exception): - pass - - executors_mod.ExternalShutdownException = ExternalShutdownException - -clock_mod: Any = sys.modules.get("rclpy.clock") -if clock_mod is None: - clock_mod = _import_module_if_available("rclpy.clock") -if clock_mod is None: - clock_mod = types.ModuleType("rclpy.clock") - sys.modules["rclpy.clock"] = clock_mod -if not hasattr(clock_mod, "Clock"): - clock_mod.Clock = type("Clock", (), {}) - -parameter_mod: Any = sys.modules.get("rclpy.parameter") -if parameter_mod is None: - parameter_mod = _import_module_if_available("rclpy.parameter") -if parameter_mod is None: - parameter_mod = types.ModuleType("rclpy.parameter") - sys.modules["rclpy.parameter"] = parameter_mod -if not hasattr(parameter_mod, "Parameter"): - parameter_mod.Parameter = type("Parameter", (), {}) - -validate_namespace_mod: Any = sys.modules.get("rclpy.validate_namespace") -if validate_namespace_mod is None: - validate_namespace_mod = _import_module_if_available("rclpy.validate_namespace") -if validate_namespace_mod is None: - validate_namespace_mod = types.ModuleType("rclpy.validate_namespace") - sys.modules["rclpy.validate_namespace"] = validate_namespace_mod -if not hasattr(validate_namespace_mod, "validate_namespace"): - validate_namespace_mod.validate_namespace = lambda namespace: None - -validate_node_name_mod: Any = sys.modules.get("rclpy.validate_node_name") -if validate_node_name_mod is None: - validate_node_name_mod = _import_module_if_available("rclpy.validate_node_name") -if validate_node_name_mod is None: - validate_node_name_mod = types.ModuleType("rclpy.validate_node_name") - sys.modules["rclpy.validate_node_name"] = validate_node_name_mod -if not hasattr(validate_node_name_mod, "validate_node_name"): - validate_node_name_mod.validate_node_name = lambda node_name: None - -qos_mod: Any = sys.modules.get("rclpy.qos") -if qos_mod is None: - qos_mod = _import_module_if_available("rclpy.qos") -if qos_mod is None: - qos_mod = types.ModuleType("rclpy.qos") - sys.modules["rclpy.qos"] = qos_mod -if not hasattr(qos_mod, "QoSProfile"): - qos_mod.QoSProfile = type("QoSProfile", (), {}) -if not hasattr(qos_mod, "QoSReliabilityPolicy"): - qos_mod.QoSReliabilityPolicy = type("QoSReliabilityPolicy", (), {}) -if not hasattr(qos_mod, "QoSHistoryPolicy"): - qos_mod.QoSHistoryPolicy = type("QoSHistoryPolicy", (), {}) -if not hasattr(qos_mod, "QoSDurabilityPolicy"): - qos_mod.QoSDurabilityPolicy = type("QoSDurabilityPolicy", (), {}) - -rclpy.node = node_mod -rclpy.executors = executors_mod -rclpy.clock = clock_mod -rclpy.parameter = parameter_mod -rclpy.validate_namespace = validate_namespace_mod -rclpy.validate_node_name = validate_node_name_mod -rclpy.qos = qos_mod +install_auto_launch_import_stubs(Path(__file__).resolve().parents[1]) from uav.runtime.ModeManager import ModeManager # noqa: E402 diff --git a/controls/sae_2025_ws/src/uav/test/test_launch_helpers.py b/controls/sae_2025_ws/src/uav/test/test_launch_helpers.py index a86b7a35..d15dbe75 100644 --- a/controls/sae_2025_ws/src/uav/test/test_launch_helpers.py +++ b/controls/sae_2025_ws/src/uav/test/test_launch_helpers.py @@ -1,147 +1,24 @@ from __future__ import annotations -import importlib import importlib.util import sys from pathlib import Path from types import SimpleNamespace -import types import pytest - -def _import_module_if_available(name: str): - try: - return importlib.import_module(name) - except ModuleNotFoundError: - return None - - -def _purge_fake_rclpy_modules() -> None: - fake_rclpy = sys.modules.get("rclpy") - if fake_rclpy is not None and not hasattr(fake_rclpy, "__path__"): - for module_name in list(sys.modules): - if module_name == "rclpy" or module_name.startswith("rclpy."): - del sys.modules[module_name] - - -def _ensure_launch_import_stubs() -> None: - ament_index_python = sys.modules.get("ament_index_python") - if ament_index_python is None: - ament_index_python = _import_module_if_available("ament_index_python") - if ament_index_python is None: - ament_index_python = types.ModuleType("ament_index_python") - sys.modules["ament_index_python"] = ament_index_python - - ament_index_packages = sys.modules.get("ament_index_python.packages") - if ament_index_packages is None: - ament_index_packages = _import_module_if_available( - "ament_index_python.packages" - ) - if ament_index_packages is None: - ament_index_packages = types.ModuleType("ament_index_python.packages") - sys.modules["ament_index_python.packages"] = ament_index_packages - if not hasattr(ament_index_packages, "get_package_share_directory"): - setattr( - ament_index_packages, - "get_package_share_directory", - lambda _name: str(Path(__file__).resolve().parents[1]), - ) - setattr(ament_index_python, "packages", ament_index_packages) - - launch_module = sys.modules.get("launch") - if launch_module is None: - launch_module = _import_module_if_available("launch") - if launch_module is None: - launch_module = types.ModuleType("launch") - sys.modules["launch"] = launch_module - if not hasattr(launch_module, "LaunchDescription"): - setattr(launch_module, "LaunchDescription", type("LaunchDescription", (), {})) - - launch_actions = sys.modules.get("launch.actions") - if launch_actions is None: - launch_actions = _import_module_if_available("launch.actions") - if launch_actions is None: - launch_actions = types.ModuleType("launch.actions") - sys.modules["launch.actions"] = launch_actions - for name in ( - "DeclareLaunchArgument", - "ExecuteProcess", - "IncludeLaunchDescription", - "OpaqueFunction", - ): - if not hasattr(launch_actions, name): - setattr(launch_actions, name, type(name, (), {})) - - launch_sources = sys.modules.get("launch.launch_description_sources") - if launch_sources is None: - launch_sources = _import_module_if_available( - "launch.launch_description_sources" - ) - if launch_sources is None: - launch_sources = types.ModuleType("launch.launch_description_sources") - sys.modules["launch.launch_description_sources"] = launch_sources - if not hasattr(launch_sources, "PythonLaunchDescriptionSource"): - setattr( - launch_sources, - "PythonLaunchDescriptionSource", - type("PythonLaunchDescriptionSource", (), {}), - ) - - launch_logging = sys.modules.get("launch.logging") - if launch_logging is None: - launch_logging = _import_module_if_available("launch.logging") - if launch_logging is None: - launch_logging = types.ModuleType("launch.logging") - sys.modules["launch.logging"] = launch_logging - if not hasattr(launch_logging, "get_logger"): - setattr( - launch_logging, - "get_logger", - lambda *_args, **_kwargs: SimpleNamespace( - warning=lambda *_a, **_k: None, - warn=lambda *_a, **_k: None, - info=lambda *_a, **_k: None, - ), - ) - - launch_substitutions = sys.modules.get("launch.substitutions") - if launch_substitutions is None: - launch_substitutions = _import_module_if_available("launch.substitutions") - if launch_substitutions is None: - launch_substitutions = types.ModuleType("launch.substitutions") - sys.modules["launch.substitutions"] = launch_substitutions - if not hasattr(launch_substitutions, "LaunchConfiguration"): - setattr( - launch_substitutions, - "LaunchConfiguration", - type("LaunchConfiguration", (), {}), - ) - - launch_ros = sys.modules.get("launch_ros") - if launch_ros is None: - launch_ros = _import_module_if_available("launch_ros") - if launch_ros is None: - launch_ros = types.ModuleType("launch_ros") - sys.modules["launch_ros"] = launch_ros - - launch_ros_actions = sys.modules.get("launch_ros.actions") - if launch_ros_actions is None: - launch_ros_actions = _import_module_if_available("launch_ros.actions") - if launch_ros_actions is None: - launch_ros_actions = types.ModuleType("launch_ros.actions") - sys.modules["launch_ros.actions"] = launch_ros_actions - if not hasattr(launch_ros_actions, "Node"): - setattr(launch_ros_actions, "Node", type("Node", (), {})) - setattr(launch_ros, "actions", launch_ros_actions) +from uav.test_support.ros_stubs import ( + ensure_launch_import_stubs, + purge_fake_rclpy_modules, +) def _load_launch_module(filename: str, module_name: str): # Runtime-behavior tests install lightweight rclpy doubles into sys.modules. # launch_ros must see the real ROS Python packages if they are available. - _purge_fake_rclpy_modules() - _ensure_launch_import_stubs() + purge_fake_rclpy_modules() package_root = Path(__file__).resolve().parents[1] + ensure_launch_import_stubs(package_root, include_launch_ros=True) if str(package_root) not in sys.path: sys.path.insert(0, str(package_root)) sim_package_root = Path(__file__).resolve().parents[2] / "sim" diff --git a/controls/sae_2025_ws/src/uav/test/test_mission_spec.py b/controls/sae_2025_ws/src/uav/test/test_mission_spec.py index 935c98be..b2919be1 100644 --- a/controls/sae_2025_ws/src/uav/test/test_mission_spec.py +++ b/controls/sae_2025_ws/src/uav/test/test_mission_spec.py @@ -9,28 +9,21 @@ import pytest -if "rclpy" not in sys.modules: - rclpy: Any = types.ModuleType("rclpy") - node_mod: Any = types.ModuleType("rclpy.node") +from uav.test_support.ros_stubs import ensure_basic_rclpy_stubs - class Node: - def __init__(self, *_args, **_kwargs) -> None: - pass - node_mod.Node = Node - rclpy.node = node_mod - sys.modules.update({"rclpy": rclpy, "rclpy.node": node_mod}) +ensure_basic_rclpy_stubs() -from uav.modes.Mode import Mode -import uav.runtime.mission_spec as mission_spec_module -import uav.runtime.schema as schema_module -from uav.runtime.mission_spec import ( +from uav.modes.Mode import Mode # noqa: E402 +import uav.runtime.mission_spec as mission_spec_module # noqa: E402 +import uav.runtime.schema as schema_module # noqa: E402 +from uav.runtime.mission_spec import ( # noqa: E402 MissionSpec, load_mode_class, load_mission_spec, mission_path_for_name, mission_root, -) +) # noqa: E402 def _write_mission(tmp_path: Path, contents: str) -> Path: diff --git a/controls/sae_2025_ws/src/uav/test/test_payload_dlz_convex_hull_masking.py b/controls/sae_2025_ws/src/uav/test/test_payload_dlz_convex_hull_masking.py index db3d41d2..1f302c4e 100644 --- a/controls/sae_2025_ws/src/uav/test/test_payload_dlz_convex_hull_masking.py +++ b/controls/sae_2025_ws/src/uav/test/test_payload_dlz_convex_hull_masking.py @@ -4,7 +4,6 @@ import sys from pathlib import Path from types import SimpleNamespace -import types from typing import Any, cast import numpy as np @@ -16,6 +15,7 @@ sys.path.insert(0, str(PACKAGE_ROOT)) from uav.cv.dlz_convex_hull import build_dlz_hull_mask # noqa: E402 +from uav.test_support.ros_stubs import install_payload_mode_import_stubs # noqa: E402 def _bgr_from_hsv(h: int, s: int, v: int) -> tuple[int, int, int]: @@ -38,74 +38,7 @@ def _roi_ratio(mask: np.ndarray, x0: int, y0: int, x1: int, y1: int) -> float: return float(np.count_nonzero(region)) / float(region.size) -def _install_import_stubs() -> None: - if "rclpy" not in sys.modules: - rclpy: Any = types.ModuleType("rclpy") - node_module: Any = types.ModuleType("rclpy.node") - - class Node: - pass - - node_module.Node = Node - rclpy.node = node_module - sys.modules.update({"rclpy": rclpy, "rclpy.node": node_module}) - - if "cv_bridge" not in sys.modules: - cv_bridge: Any = types.ModuleType("cv_bridge") - - class CvBridge: - def imgmsg_to_cv2(self, *_args, **_kwargs): - raise NotImplementedError - - def cv2_to_compressed_imgmsg(self, *_args, **_kwargs): - return SimpleNamespace(header=SimpleNamespace(stamp=None)) - - cv_bridge.CvBridge = CvBridge - sys.modules["cv_bridge"] = cv_bridge - - if "sensor_msgs" not in sys.modules: - sensor_msgs: Any = types.ModuleType("sensor_msgs") - sensor_msgs_msg: Any = types.ModuleType("sensor_msgs.msg") - sensor_msgs_msg.CompressedImage = type("CompressedImage", (), {}) - sensor_msgs_msg.Image = type("Image", (), {}) - sensor_msgs.msg = sensor_msgs_msg - sys.modules.update( - {"sensor_msgs": sensor_msgs, "sensor_msgs.msg": sensor_msgs_msg} - ) - - if "uav.vehicles.Payload" not in sys.modules: - payload_module: Any = types.ModuleType("uav.vehicles.Payload") - payload_module.Payload = type("Payload", (), {}) - sys.modules["uav.vehicles.Payload"] = payload_module - - if "uav.vision_nodes" not in sys.modules: - vision_nodes: Any = types.ModuleType("uav.vision_nodes") - vision_nodes.PayloadAprilTagNode = type("PayloadAprilTagNode", (), {}) - sys.modules["uav.vision_nodes"] = vision_nodes - - if "uav.vision_nodes.payload_perception_common" not in sys.modules: - common: Any = types.ModuleType("uav.vision_nodes.payload_perception_common") - common.DEFAULT_TAG_FAMILY = "tag36h11" - sys.modules["uav.vision_nodes.payload_perception_common"] = common - - if "uav_interfaces" not in sys.modules: - sys.modules["uav_interfaces"] = types.ModuleType("uav_interfaces") - - if "uav_interfaces.srv" not in sys.modules: - srv_module: Any = types.ModuleType("uav_interfaces.srv") - - class PayloadAprilTagState: - class Request: - pass - - class Response: - pass - - srv_module.PayloadAprilTagState = PayloadAprilTagState - sys.modules["uav_interfaces.srv"] = srv_module - - -_install_import_stubs() +install_payload_mode_import_stubs() PayloadCornerNavigateMode = importlib.import_module( "uav.modes.payload.PayloadCornerNavigateMode" diff --git a/controls/sae_2025_ws/src/uav/uav/test_support/__init__.py b/controls/sae_2025_ws/src/uav/uav/test_support/__init__.py new file mode 100644 index 00000000..92a9d1d7 --- /dev/null +++ b/controls/sae_2025_ws/src/uav/uav/test_support/__init__.py @@ -0,0 +1 @@ +"""Test-only support helpers for UAV package tests.""" diff --git a/controls/sae_2025_ws/src/uav/uav/test_support/ros_stubs.py b/controls/sae_2025_ws/src/uav/uav/test_support/ros_stubs.py new file mode 100644 index 00000000..57199bee --- /dev/null +++ b/controls/sae_2025_ws/src/uav/uav/test_support/ros_stubs.py @@ -0,0 +1,340 @@ +from __future__ import annotations + +import importlib +from pathlib import Path +import sys +from types import SimpleNamespace +import types +from typing import Any + + +def import_module_if_available(name: str) -> Any | None: + try: + return importlib.import_module(name) + except ModuleNotFoundError: + return None + + +def ensure_basic_rclpy_stubs() -> None: + if "rclpy" in sys.modules: + return + + rclpy: Any = types.ModuleType("rclpy") + node_module: Any = types.ModuleType("rclpy.node") + + class Node: + def __init__(self, *_args, **_kwargs) -> None: + pass + + node_module.Node = Node + rclpy.node = node_module + sys.modules.update({"rclpy": rclpy, "rclpy.node": node_module}) + + +def ensure_runtime_rclpy_stubs() -> None: + rclpy: Any = sys.modules.get("rclpy") or import_module_if_available("rclpy") + if rclpy is None: + rclpy = types.ModuleType("rclpy") + sys.modules["rclpy"] = rclpy + if not hasattr(rclpy, "init"): + rclpy.init = lambda *args, **kwargs: None + if not hasattr(rclpy, "shutdown"): + rclpy.shutdown = lambda: None + if not hasattr(rclpy, "ok"): + rclpy.ok = lambda: True + + node_mod: Any = sys.modules.get("rclpy.node") or import_module_if_available( + "rclpy.node" + ) + if node_mod is None: + node_mod = types.ModuleType("rclpy.node") + sys.modules["rclpy.node"] = node_mod + if not hasattr(node_mod, "Node"): + + class Node: + def __init__(self, *_args, **_kwargs) -> None: + pass + + node_mod.Node = Node + + executors_mod: Any = sys.modules.get( + "rclpy.executors" + ) or import_module_if_available("rclpy.executors") + if executors_mod is None: + executors_mod = types.ModuleType("rclpy.executors") + sys.modules["rclpy.executors"] = executors_mod + if not hasattr(executors_mod, "ExternalShutdownException"): + + class ExternalShutdownException(Exception): + pass + + executors_mod.ExternalShutdownException = ExternalShutdownException + + for module_name, attr_name in ( + ("rclpy.clock", "Clock"), + ("rclpy.parameter", "Parameter"), + ): + module: Any = sys.modules.get(module_name) or import_module_if_available( + module_name + ) + if module is None: + module = types.ModuleType(module_name) + sys.modules[module_name] = module + if not hasattr(module, attr_name): + setattr(module, attr_name, type(attr_name, (), {})) + + validate_namespace_mod: Any = sys.modules.get( + "rclpy.validate_namespace" + ) or import_module_if_available("rclpy.validate_namespace") + if validate_namespace_mod is None: + validate_namespace_mod = types.ModuleType("rclpy.validate_namespace") + sys.modules["rclpy.validate_namespace"] = validate_namespace_mod + if not hasattr(validate_namespace_mod, "validate_namespace"): + validate_namespace_mod.validate_namespace = lambda namespace: None + + validate_node_name_mod: Any = sys.modules.get( + "rclpy.validate_node_name" + ) or import_module_if_available("rclpy.validate_node_name") + if validate_node_name_mod is None: + validate_node_name_mod = types.ModuleType("rclpy.validate_node_name") + sys.modules["rclpy.validate_node_name"] = validate_node_name_mod + if not hasattr(validate_node_name_mod, "validate_node_name"): + validate_node_name_mod.validate_node_name = lambda node_name: None + + qos_mod: Any = sys.modules.get("rclpy.qos") or import_module_if_available( + "rclpy.qos" + ) + if qos_mod is None: + qos_mod = types.ModuleType("rclpy.qos") + sys.modules["rclpy.qos"] = qos_mod + for name in ( + "QoSProfile", + "QoSReliabilityPolicy", + "QoSHistoryPolicy", + "QoSDurabilityPolicy", + ): + if not hasattr(qos_mod, name): + setattr(qos_mod, name, type(name, (), {})) + + rclpy.node = node_mod + rclpy.executors = executors_mod + rclpy.clock = sys.modules["rclpy.clock"] + rclpy.parameter = sys.modules["rclpy.parameter"] + rclpy.validate_namespace = validate_namespace_mod + rclpy.validate_node_name = validate_node_name_mod + rclpy.qos = qos_mod + + +def purge_fake_rclpy_modules() -> None: + fake_rclpy = sys.modules.get("rclpy") + if fake_rclpy is not None and not hasattr(fake_rclpy, "__path__"): + for module_name in list(sys.modules): + if module_name == "rclpy" or module_name.startswith("rclpy."): + del sys.modules[module_name] + + +def ensure_std_msgs_stub() -> None: + std_msgs: Any = sys.modules.get("std_msgs") or import_module_if_available( + "std_msgs" + ) + if std_msgs is not None: + return + std_msgs = types.ModuleType("std_msgs") + std_msgs_msg: Any = types.ModuleType("std_msgs.msg") + std_msgs_msg.Empty = type("Empty", (), {}) + std_msgs.msg = std_msgs_msg + sys.modules.update({"std_msgs": std_msgs, "std_msgs.msg": std_msgs_msg}) + + +def ensure_std_srvs_stub() -> None: + std_srvs: Any = sys.modules.get("std_srvs") or import_module_if_available( + "std_srvs" + ) + if std_srvs is not None: + return + std_srvs = types.ModuleType("std_srvs") + std_srvs_srv: Any = types.ModuleType("std_srvs.srv") + + class Trigger: + Request = type("Request", (), {}) + Response = type("Response", (), {}) + + std_srvs_srv.Trigger = Trigger + std_srvs.srv = std_srvs_srv + sys.modules.update({"std_srvs": std_srvs, "std_srvs.srv": std_srvs_srv}) + + +def ensure_ament_index_stub(package_root: Path) -> None: + ament_index_python: Any = sys.modules.get( + "ament_index_python" + ) or import_module_if_available("ament_index_python") + if ament_index_python is None: + ament_index_python = types.ModuleType("ament_index_python") + sys.modules["ament_index_python"] = ament_index_python + + ament_index_packages: Any = sys.modules.get( + "ament_index_python.packages" + ) or import_module_if_available("ament_index_python.packages") + if ament_index_packages is None: + ament_index_packages = types.ModuleType("ament_index_python.packages") + sys.modules["ament_index_python.packages"] = ament_index_packages + + if not hasattr(ament_index_packages, "PackageNotFoundError"): + + class PackageNotFoundError(Exception): + pass + + ament_index_packages.PackageNotFoundError = PackageNotFoundError + if not hasattr(ament_index_packages, "get_package_share_directory"): + ament_index_packages.get_package_share_directory = lambda _name: str( + package_root + ) + ament_index_python.packages = ament_index_packages + + +def ensure_launch_import_stubs( + package_root: Path, *, include_launch_ros: bool = False +) -> None: + ensure_ament_index_stub(package_root) + + launch_module: Any = sys.modules.get("launch") or import_module_if_available( + "launch" + ) + if launch_module is None: + launch_module = types.ModuleType("launch") + sys.modules["launch"] = launch_module + if not hasattr(launch_module, "LaunchDescription"): + launch_module.LaunchDescription = type("LaunchDescription", (), {}) + + launch_actions: Any = sys.modules.get( + "launch.actions" + ) or import_module_if_available("launch.actions") + if launch_actions is None: + launch_actions = types.ModuleType("launch.actions") + sys.modules["launch.actions"] = launch_actions + for name in ( + "DeclareLaunchArgument", + "ExecuteProcess", + "IncludeLaunchDescription", + "OpaqueFunction", + ): + if not hasattr(launch_actions, name): + setattr(launch_actions, name, type(name, (), {})) + + launch_sources: Any = sys.modules.get( + "launch.launch_description_sources" + ) or import_module_if_available("launch.launch_description_sources") + if launch_sources is None: + launch_sources = types.ModuleType("launch.launch_description_sources") + sys.modules["launch.launch_description_sources"] = launch_sources + if not hasattr(launch_sources, "PythonLaunchDescriptionSource"): + launch_sources.PythonLaunchDescriptionSource = type( + "PythonLaunchDescriptionSource", (), {} + ) + + launch_logging: Any = sys.modules.get( + "launch.logging" + ) or import_module_if_available("launch.logging") + if launch_logging is None: + launch_logging = types.ModuleType("launch.logging") + sys.modules["launch.logging"] = launch_logging + if not hasattr(launch_logging, "get_logger"): + launch_logging.get_logger = lambda *_args, **_kwargs: SimpleNamespace( + warning=lambda *_a, **_k: None, + warn=lambda *_a, **_k: None, + info=lambda *_a, **_k: None, + ) + + launch_substitutions: Any = sys.modules.get( + "launch.substitutions" + ) or import_module_if_available("launch.substitutions") + if launch_substitutions is None: + launch_substitutions = types.ModuleType("launch.substitutions") + sys.modules["launch.substitutions"] = launch_substitutions + if not hasattr(launch_substitutions, "LaunchConfiguration"): + launch_substitutions.LaunchConfiguration = type("LaunchConfiguration", (), {}) + + if not include_launch_ros: + return + + launch_ros: Any = sys.modules.get("launch_ros") or import_module_if_available( + "launch_ros" + ) + if launch_ros is None: + launch_ros = types.ModuleType("launch_ros") + sys.modules["launch_ros"] = launch_ros + + launch_ros_actions: Any = sys.modules.get( + "launch_ros.actions" + ) or import_module_if_available("launch_ros.actions") + if launch_ros_actions is None: + launch_ros_actions = types.ModuleType("launch_ros.actions") + sys.modules["launch_ros.actions"] = launch_ros_actions + if not hasattr(launch_ros_actions, "Node"): + launch_ros_actions.Node = type("Node", (), {}) + launch_ros.actions = launch_ros_actions + + +def install_auto_launch_import_stubs(package_root: Path) -> None: + ensure_std_msgs_stub() + ensure_std_srvs_stub() + ensure_launch_import_stubs(package_root) + ensure_runtime_rclpy_stubs() + + +def install_payload_mode_import_stubs() -> None: + ensure_basic_rclpy_stubs() + + if "cv_bridge" not in sys.modules: + cv_bridge: Any = types.ModuleType("cv_bridge") + + class CvBridge: + def imgmsg_to_cv2(self, *_args, **_kwargs): + raise NotImplementedError + + def cv2_to_compressed_imgmsg(self, *_args, **_kwargs): + return SimpleNamespace(header=SimpleNamespace(stamp=None)) + + cv_bridge.CvBridge = CvBridge + sys.modules["cv_bridge"] = cv_bridge + + if "sensor_msgs" not in sys.modules: + sensor_msgs: Any = types.ModuleType("sensor_msgs") + sensor_msgs_msg: Any = types.ModuleType("sensor_msgs.msg") + sensor_msgs_msg.CompressedImage = type("CompressedImage", (), {}) + sensor_msgs_msg.Image = type("Image", (), {}) + sensor_msgs.msg = sensor_msgs_msg + sys.modules.update( + {"sensor_msgs": sensor_msgs, "sensor_msgs.msg": sensor_msgs_msg} + ) + + if "uav.vehicles.Payload" not in sys.modules: + payload_module: Any = types.ModuleType("uav.vehicles.Payload") + payload_module.Payload = type("Payload", (), {}) + sys.modules["uav.vehicles.Payload"] = payload_module + + if "uav.vision_nodes" not in sys.modules: + vision_nodes: Any = types.ModuleType("uav.vision_nodes") + vision_nodes.PayloadAprilTagNode = type("PayloadAprilTagNode", (), {}) + sys.modules["uav.vision_nodes"] = vision_nodes + + if "uav.vision_nodes.payload_perception_common" not in sys.modules: + common: Any = types.ModuleType("uav.vision_nodes.payload_perception_common") + common.DEFAULT_TAG_FAMILY = "tag36h11" + sys.modules["uav.vision_nodes.payload_perception_common"] = common + + if "uav_interfaces" not in sys.modules: + sys.modules["uav_interfaces"] = types.ModuleType("uav_interfaces") + + if "uav_interfaces.srv" not in sys.modules: + srv_module: Any = types.ModuleType("uav_interfaces.srv") + + class PayloadAprilTagState: + class Request: + pass + + class Response: + pass + + srv_module.PayloadAprilTagState = PayloadAprilTagState + sys.modules["uav_interfaces.srv"] = srv_module