diff --git a/controls/sae_2025_ws/src/uav/test/test_auto_launch.py b/controls/sae_2025_ws/src/uav/test/test_auto_launch.py index fecdfec8..ddb65303 100644 --- a/controls/sae_2025_ws/src/uav/test/test_auto_launch.py +++ b/controls/sae_2025_ws/src/uav/test/test_auto_launch.py @@ -1,233 +1,23 @@ from __future__ import annotations -import importlib import importlib.util from pathlib import Path -import sys from types import SimpleNamespace -import types +from typing import Any, cast import pytest +from uav.test_support.ros_stubs import install_auto_launch_import_stubs -def _import_module_if_available(name: str): - try: - return importlib.import_module(name) - except ModuleNotFoundError: - return None - - -if "std_msgs" not in sys.modules: - std_msgs = _import_module_if_available("std_msgs") -else: - std_msgs = sys.modules["std_msgs"] -if std_msgs is None: - std_msgs = types.ModuleType("std_msgs") - std_msgs_msg = types.ModuleType("std_msgs.msg") - std_msgs_msg.Empty = type("Empty", (), {}) - std_msgs.msg = std_msgs_msg - sys.modules.update({"std_msgs": std_msgs, "std_msgs.msg": std_msgs_msg}) - -ament_index_python = sys.modules.get("ament_index_python") -if ament_index_python is None: - ament_index_python = _import_module_if_available("ament_index_python") -if ament_index_python is None: - ament_index_python = types.ModuleType("ament_index_python") - sys.modules["ament_index_python"] = ament_index_python - -ament_index_packages = sys.modules.get("ament_index_python.packages") -if ament_index_packages is None: - ament_index_packages = _import_module_if_available("ament_index_python.packages") -if ament_index_packages is None: - ament_index_packages = types.ModuleType("ament_index_python.packages") - sys.modules["ament_index_python.packages"] = ament_index_packages -if not hasattr(ament_index_packages, "PackageNotFoundError"): - - class PackageNotFoundError(Exception): - pass - - ament_index_packages.PackageNotFoundError = PackageNotFoundError -if not hasattr(ament_index_packages, "get_package_share_directory"): - ament_index_packages.get_package_share_directory = lambda _name: str( - Path(__file__).resolve().parents[1] - ) -sys.modules["ament_index_python"].packages = ament_index_packages - -if "std_srvs" not in sys.modules: - std_srvs = _import_module_if_available("std_srvs") -else: - std_srvs = sys.modules["std_srvs"] -if std_srvs is None: - std_srvs = types.ModuleType("std_srvs") - std_srvs_srv = types.ModuleType("std_srvs.srv") - - class Trigger: - Request = type("Request", (), {}) - Response = type("Response", (), {}) - - std_srvs_srv.Trigger = Trigger - std_srvs.srv = std_srvs_srv - sys.modules.update({"std_srvs": std_srvs, "std_srvs.srv": std_srvs_srv}) - -launch_module = sys.modules.get("launch") -if launch_module is None: - launch_module = _import_module_if_available("launch") -if launch_module is None: - launch_module = types.ModuleType("launch") - sys.modules["launch"] = launch_module -if not hasattr(launch_module, "LaunchDescription"): - launch_module.LaunchDescription = type("LaunchDescription", (), {}) - -launch_actions = sys.modules.get("launch.actions") -if launch_actions is None: - launch_actions = _import_module_if_available("launch.actions") -if launch_actions is None: - launch_actions = types.ModuleType("launch.actions") - sys.modules["launch.actions"] = launch_actions -for name in ( - "DeclareLaunchArgument", - "ExecuteProcess", - "IncludeLaunchDescription", - "OpaqueFunction", -): - if not hasattr(launch_actions, name): - setattr(launch_actions, name, type(name, (), {})) - -launch_sources = sys.modules.get("launch.launch_description_sources") -if launch_sources is None: - launch_sources = _import_module_if_available("launch.launch_description_sources") -if launch_sources is None: - launch_sources = types.ModuleType("launch.launch_description_sources") - sys.modules["launch.launch_description_sources"] = launch_sources -if not hasattr(launch_sources, "PythonLaunchDescriptionSource"): - launch_sources.PythonLaunchDescriptionSource = type( - "PythonLaunchDescriptionSource", (), {} - ) - -launch_logging = sys.modules.get("launch.logging") -if launch_logging is None: - launch_logging = _import_module_if_available("launch.logging") -if launch_logging is None: - launch_logging = types.ModuleType("launch.logging") - sys.modules["launch.logging"] = launch_logging -if not hasattr(launch_logging, "get_logger"): - launch_logging.get_logger = lambda *_args, **_kwargs: SimpleNamespace( - warning=lambda *_a, **_k: None, - warn=lambda *_a, **_k: None, - info=lambda *_a, **_k: None, - ) -launch_substitutions = sys.modules.get("launch.substitutions") -if launch_substitutions is None: - launch_substitutions = _import_module_if_available("launch.substitutions") -if launch_substitutions is None: - launch_substitutions = types.ModuleType("launch.substitutions") - sys.modules["launch.substitutions"] = launch_substitutions -if not hasattr(launch_substitutions, "LaunchConfiguration"): - launch_substitutions.LaunchConfiguration = type("LaunchConfiguration", (), {}) - -rclpy = sys.modules.get("rclpy") -if rclpy is None: - rclpy = _import_module_if_available("rclpy") -if rclpy is None: - rclpy = types.ModuleType("rclpy") - sys.modules["rclpy"] = rclpy -if not hasattr(rclpy, "init"): - rclpy.init = lambda *args, **kwargs: None -if not hasattr(rclpy, "shutdown"): - rclpy.shutdown = lambda: None -if not hasattr(rclpy, "ok"): - rclpy.ok = lambda: True - -node_mod = sys.modules.get("rclpy.node") -if node_mod is None: - node_mod = _import_module_if_available("rclpy.node") -if node_mod is None: - node_mod = types.ModuleType("rclpy.node") - sys.modules["rclpy.node"] = node_mod -if not hasattr(node_mod, "Node"): - - class Node: - def __init__(self, *_args, **_kwargs) -> None: - pass - - node_mod.Node = Node - -executors_mod = sys.modules.get("rclpy.executors") -if executors_mod is None: - executors_mod = _import_module_if_available("rclpy.executors") -if executors_mod is None: - executors_mod = types.ModuleType("rclpy.executors") - sys.modules["rclpy.executors"] = executors_mod -if not hasattr(executors_mod, "ExternalShutdownException"): - - class ExternalShutdownException(Exception): - pass - - executors_mod.ExternalShutdownException = ExternalShutdownException - -clock_mod = sys.modules.get("rclpy.clock") -if clock_mod is None: - clock_mod = _import_module_if_available("rclpy.clock") -if clock_mod is None: - clock_mod = types.ModuleType("rclpy.clock") - sys.modules["rclpy.clock"] = clock_mod -if not hasattr(clock_mod, "Clock"): - clock_mod.Clock = type("Clock", (), {}) - -parameter_mod = sys.modules.get("rclpy.parameter") -if parameter_mod is None: - parameter_mod = _import_module_if_available("rclpy.parameter") -if parameter_mod is None: - parameter_mod = types.ModuleType("rclpy.parameter") - sys.modules["rclpy.parameter"] = parameter_mod -if not hasattr(parameter_mod, "Parameter"): - parameter_mod.Parameter = type("Parameter", (), {}) - -validate_namespace_mod = sys.modules.get("rclpy.validate_namespace") -if validate_namespace_mod is None: - validate_namespace_mod = _import_module_if_available("rclpy.validate_namespace") -if validate_namespace_mod is None: - validate_namespace_mod = types.ModuleType("rclpy.validate_namespace") - sys.modules["rclpy.validate_namespace"] = validate_namespace_mod -if not hasattr(validate_namespace_mod, "validate_namespace"): - validate_namespace_mod.validate_namespace = lambda namespace: None - -validate_node_name_mod = sys.modules.get("rclpy.validate_node_name") -if validate_node_name_mod is None: - validate_node_name_mod = _import_module_if_available("rclpy.validate_node_name") -if validate_node_name_mod is None: - validate_node_name_mod = types.ModuleType("rclpy.validate_node_name") - sys.modules["rclpy.validate_node_name"] = validate_node_name_mod -if not hasattr(validate_node_name_mod, "validate_node_name"): - validate_node_name_mod.validate_node_name = lambda node_name: None - -qos_mod = sys.modules.get("rclpy.qos") -if qos_mod is None: - qos_mod = _import_module_if_available("rclpy.qos") -if qos_mod is None: - qos_mod = types.ModuleType("rclpy.qos") - sys.modules["rclpy.qos"] = qos_mod -if not hasattr(qos_mod, "QoSProfile"): - qos_mod.QoSProfile = type("QoSProfile", (), {}) -if not hasattr(qos_mod, "QoSReliabilityPolicy"): - qos_mod.QoSReliabilityPolicy = type("QoSReliabilityPolicy", (), {}) -if not hasattr(qos_mod, "QoSHistoryPolicy"): - qos_mod.QoSHistoryPolicy = type("QoSHistoryPolicy", (), {}) -if not hasattr(qos_mod, "QoSDurabilityPolicy"): - qos_mod.QoSDurabilityPolicy = type("QoSDurabilityPolicy", (), {}) - -rclpy.node = node_mod -rclpy.executors = executors_mod -rclpy.clock = clock_mod -rclpy.parameter = parameter_mod -rclpy.validate_namespace = validate_namespace_mod -rclpy.validate_node_name = validate_node_name_mod -rclpy.qos = qos_mod +install_auto_launch_import_stubs(Path(__file__).resolve().parents[1]) from uav.runtime.ModeManager import ModeManager # noqa: E402 try: + uav_mission_module: Any + uav_manager_module: Any + UAVModeManager: Any import uav.runtime.uav_mission as uav_mission_module import uav.runtime.UAVModeManager as uav_manager_module from uav.runtime.UAVModeManager import UAVModeManager @@ -239,6 +29,8 @@ class ExternalShutdownException(Exception): UAVModeManager = None try: + payload_mission_module: Any + payload_manager_module: Any import uav.runtime.payload_mission as payload_mission_module import uav.runtime.PayloadModeManager as payload_manager_module except ModuleNotFoundError as exc: @@ -271,8 +63,8 @@ def cancel(self) -> None: self.cancelled = True -def _make_mode_manager(*, ready: bool) -> ModeManager: - manager = object.__new__(ModeManager) +def _make_mode_manager(*, ready: bool) -> Any: + manager = cast(Any, object.__new__(ModeManager)) manager.vehicle = None manager.modes = {} manager.transitions = {} @@ -396,8 +188,9 @@ def _stub_mode_manager_init( def _load_main_launch_module(): launch_path = Path(__file__).resolve().parents[1] / "launch" / "main.launch.py" spec = importlib.util.spec_from_file_location("uav_main_launch", launch_path) + if spec is None or spec.loader is None: + raise ImportError(f"Unable to load launch module from {launch_path}") module = importlib.util.module_from_spec(spec) - assert spec.loader is not None spec.loader.exec_module(module) return module diff --git a/controls/sae_2025_ws/src/uav/test/test_flake8.py b/controls/sae_2025_ws/src/uav/test/test_flake8.py index ee79f31a..9b88fa2f 100644 --- a/controls/sae_2025_ws/src/uav/test/test_flake8.py +++ b/controls/sae_2025_ws/src/uav/test/test_flake8.py @@ -14,12 +14,13 @@ from ament_flake8.main import main_with_errors import pytest +from typing import cast @pytest.mark.flake8 @pytest.mark.linter def test_flake8(): - rc, errors = main_with_errors(argv=[]) + rc, errors = cast(tuple[int, list[str]], main_with_errors(argv=[])) assert rc == 0, "Found %d code style errors / warnings:\n" % len( errors ) + "\n".join(errors) diff --git a/controls/sae_2025_ws/src/uav/test/test_fleet_launch.py b/controls/sae_2025_ws/src/uav/test/test_fleet_launch.py index 3a221ebf..d8505a0f 100644 --- a/controls/sae_2025_ws/src/uav/test/test_fleet_launch.py +++ b/controls/sae_2025_ws/src/uav/test/test_fleet_launch.py @@ -46,10 +46,12 @@ def _ensure_launch_import_stubs() -> None: ament_index_packages = types.ModuleType("ament_index_python.packages") sys.modules["ament_index_python.packages"] = ament_index_packages if not hasattr(ament_index_packages, "get_package_share_directory"): - ament_index_packages.get_package_share_directory = lambda _name: str( - Path(__file__).resolve().parents[1] + setattr( + ament_index_packages, + "get_package_share_directory", + lambda _name: str(Path(__file__).resolve().parents[1]), ) - ament_index_python.packages = ament_index_packages + setattr(ament_index_python, "packages", ament_index_packages) launch_module = sys.modules.get("launch") if launch_module is None: @@ -58,7 +60,7 @@ def _ensure_launch_import_stubs() -> None: launch_module = types.ModuleType("launch") sys.modules["launch"] = launch_module if not hasattr(launch_module, "LaunchDescription"): - launch_module.LaunchDescription = type("LaunchDescription", (), {}) + setattr(launch_module, "LaunchDescription", type("LaunchDescription", (), {})) launch_actions = sys.modules.get("launch.actions") if launch_actions is None: @@ -84,8 +86,10 @@ def _ensure_launch_import_stubs() -> None: launch_sources = types.ModuleType("launch.launch_description_sources") sys.modules["launch.launch_description_sources"] = launch_sources if not hasattr(launch_sources, "PythonLaunchDescriptionSource"): - launch_sources.PythonLaunchDescriptionSource = type( - "PythonLaunchDescriptionSource", (), {} + setattr( + launch_sources, + "PythonLaunchDescriptionSource", + type("PythonLaunchDescriptionSource", (), {}), ) launch_logging = sys.modules.get("launch.logging") @@ -95,10 +99,14 @@ def _ensure_launch_import_stubs() -> None: launch_logging = types.ModuleType("launch.logging") sys.modules["launch.logging"] = launch_logging if not hasattr(launch_logging, "get_logger"): - launch_logging.get_logger = lambda *_args, **_kwargs: SimpleNamespace( - warning=lambda *_a, **_k: None, - warn=lambda *_a, **_k: None, - info=lambda *_a, **_k: None, + setattr( + launch_logging, + "get_logger", + lambda *_args, **_kwargs: SimpleNamespace( + warning=lambda *_a, **_k: None, + warn=lambda *_a, **_k: None, + info=lambda *_a, **_k: None, + ), ) launch_substitutions = sys.modules.get("launch.substitutions") @@ -108,7 +116,11 @@ def _ensure_launch_import_stubs() -> None: launch_substitutions = types.ModuleType("launch.substitutions") sys.modules["launch.substitutions"] = launch_substitutions if not hasattr(launch_substitutions, "LaunchConfiguration"): - launch_substitutions.LaunchConfiguration = type("LaunchConfiguration", (), {}) + setattr( + launch_substitutions, + "LaunchConfiguration", + type("LaunchConfiguration", (), {}), + ) def _load_fleet_module(): @@ -123,8 +135,9 @@ def _load_fleet_module(): spec = importlib.util.spec_from_file_location( "uav_fleet_launch_helpers", launch_path ) + if spec is None or spec.loader is None: + raise ImportError(f"Unable to load launch module from {launch_path}") module = importlib.util.module_from_spec(spec) - assert spec.loader is not None spec.loader.exec_module(module) return module @@ -255,8 +268,9 @@ def test_real_backend_does_not_import_sim(monkeypatch): spec = importlib.util.spec_from_file_location( "uav_fleet_launch_hardware_only", launch_path ) + if spec is None or spec.loader is None: + raise ImportError(f"Unable to load launch module from {launch_path}") module = importlib.util.module_from_spec(spec) - assert spec.loader is not None real_import = builtins.__import__ diff --git a/controls/sae_2025_ws/src/uav/test/test_launch_helpers.py b/controls/sae_2025_ws/src/uav/test/test_launch_helpers.py index 6384fa50..d15dbe75 100644 --- a/controls/sae_2025_ws/src/uav/test/test_launch_helpers.py +++ b/controls/sae_2025_ws/src/uav/test/test_launch_helpers.py @@ -1,135 +1,24 @@ from __future__ import annotations -import importlib import importlib.util import sys from pathlib import Path from types import SimpleNamespace -import types import pytest - -def _import_module_if_available(name: str): - try: - return importlib.import_module(name) - except ModuleNotFoundError: - return None - - -def _purge_fake_rclpy_modules() -> None: - fake_rclpy = sys.modules.get("rclpy") - if fake_rclpy is not None and not hasattr(fake_rclpy, "__path__"): - for module_name in list(sys.modules): - if module_name == "rclpy" or module_name.startswith("rclpy."): - del sys.modules[module_name] - - -def _ensure_launch_import_stubs() -> None: - ament_index_python = sys.modules.get("ament_index_python") - if ament_index_python is None: - ament_index_python = _import_module_if_available("ament_index_python") - if ament_index_python is None: - ament_index_python = types.ModuleType("ament_index_python") - sys.modules["ament_index_python"] = ament_index_python - - ament_index_packages = sys.modules.get("ament_index_python.packages") - if ament_index_packages is None: - ament_index_packages = _import_module_if_available( - "ament_index_python.packages" - ) - if ament_index_packages is None: - ament_index_packages = types.ModuleType("ament_index_python.packages") - sys.modules["ament_index_python.packages"] = ament_index_packages - if not hasattr(ament_index_packages, "get_package_share_directory"): - ament_index_packages.get_package_share_directory = lambda _name: str( - Path(__file__).resolve().parents[1] - ) - ament_index_python.packages = ament_index_packages - - launch_module = sys.modules.get("launch") - if launch_module is None: - launch_module = _import_module_if_available("launch") - if launch_module is None: - launch_module = types.ModuleType("launch") - sys.modules["launch"] = launch_module - if not hasattr(launch_module, "LaunchDescription"): - launch_module.LaunchDescription = type("LaunchDescription", (), {}) - - launch_actions = sys.modules.get("launch.actions") - if launch_actions is None: - launch_actions = _import_module_if_available("launch.actions") - if launch_actions is None: - launch_actions = types.ModuleType("launch.actions") - sys.modules["launch.actions"] = launch_actions - for name in ( - "DeclareLaunchArgument", - "ExecuteProcess", - "IncludeLaunchDescription", - "OpaqueFunction", - ): - if not hasattr(launch_actions, name): - setattr(launch_actions, name, type(name, (), {})) - - launch_sources = sys.modules.get("launch.launch_description_sources") - if launch_sources is None: - launch_sources = _import_module_if_available( - "launch.launch_description_sources" - ) - if launch_sources is None: - launch_sources = types.ModuleType("launch.launch_description_sources") - sys.modules["launch.launch_description_sources"] = launch_sources - if not hasattr(launch_sources, "PythonLaunchDescriptionSource"): - launch_sources.PythonLaunchDescriptionSource = type( - "PythonLaunchDescriptionSource", (), {} - ) - - launch_logging = sys.modules.get("launch.logging") - if launch_logging is None: - launch_logging = _import_module_if_available("launch.logging") - if launch_logging is None: - launch_logging = types.ModuleType("launch.logging") - sys.modules["launch.logging"] = launch_logging - if not hasattr(launch_logging, "get_logger"): - launch_logging.get_logger = lambda *_args, **_kwargs: SimpleNamespace( - warning=lambda *_a, **_k: None, - warn=lambda *_a, **_k: None, - info=lambda *_a, **_k: None, - ) - - launch_substitutions = sys.modules.get("launch.substitutions") - if launch_substitutions is None: - launch_substitutions = _import_module_if_available("launch.substitutions") - if launch_substitutions is None: - launch_substitutions = types.ModuleType("launch.substitutions") - sys.modules["launch.substitutions"] = launch_substitutions - if not hasattr(launch_substitutions, "LaunchConfiguration"): - launch_substitutions.LaunchConfiguration = type("LaunchConfiguration", (), {}) - - launch_ros = sys.modules.get("launch_ros") - if launch_ros is None: - launch_ros = _import_module_if_available("launch_ros") - if launch_ros is None: - launch_ros = types.ModuleType("launch_ros") - sys.modules["launch_ros"] = launch_ros - - launch_ros_actions = sys.modules.get("launch_ros.actions") - if launch_ros_actions is None: - launch_ros_actions = _import_module_if_available("launch_ros.actions") - if launch_ros_actions is None: - launch_ros_actions = types.ModuleType("launch_ros.actions") - sys.modules["launch_ros.actions"] = launch_ros_actions - if not hasattr(launch_ros_actions, "Node"): - launch_ros_actions.Node = type("Node", (), {}) - launch_ros.actions = launch_ros_actions +from uav.test_support.ros_stubs import ( + ensure_launch_import_stubs, + purge_fake_rclpy_modules, +) def _load_launch_module(filename: str, module_name: str): # Runtime-behavior tests install lightweight rclpy doubles into sys.modules. # launch_ros must see the real ROS Python packages if they are available. - _purge_fake_rclpy_modules() - _ensure_launch_import_stubs() + purge_fake_rclpy_modules() package_root = Path(__file__).resolve().parents[1] + ensure_launch_import_stubs(package_root, include_launch_ros=True) if str(package_root) not in sys.path: sys.path.insert(0, str(package_root)) sim_package_root = Path(__file__).resolve().parents[2] / "sim" @@ -137,8 +26,9 @@ def _load_launch_module(filename: str, module_name: str): sys.path.insert(0, str(sim_package_root)) launch_path = package_root / "launch" / filename spec = importlib.util.spec_from_file_location(module_name, launch_path) + if spec is None or spec.loader is None: + raise ImportError(f"Unable to load launch module from {launch_path}") module = importlib.util.module_from_spec(spec) - assert spec.loader is not None spec.loader.exec_module(module) return module diff --git a/controls/sae_2025_ws/src/uav/test/test_mission_spec.py b/controls/sae_2025_ws/src/uav/test/test_mission_spec.py index 340788b4..b2919be1 100644 --- a/controls/sae_2025_ws/src/uav/test/test_mission_spec.py +++ b/controls/sae_2025_ws/src/uav/test/test_mission_spec.py @@ -5,31 +5,25 @@ from types import SimpleNamespace import sys import types +from typing import Any import pytest -if "rclpy" not in sys.modules: - rclpy = types.ModuleType("rclpy") - node_mod = types.ModuleType("rclpy.node") +from uav.test_support.ros_stubs import ensure_basic_rclpy_stubs - class Node: - def __init__(self, *_args, **_kwargs) -> None: - pass - node_mod.Node = Node - rclpy.node = node_mod - sys.modules.update({"rclpy": rclpy, "rclpy.node": node_mod}) +ensure_basic_rclpy_stubs() -from uav.modes.Mode import Mode -import uav.runtime.mission_spec as mission_spec_module -import uav.runtime.schema as schema_module -from uav.runtime.mission_spec import ( +from uav.modes.Mode import Mode # noqa: E402 +import uav.runtime.mission_spec as mission_spec_module # noqa: E402 +import uav.runtime.schema as schema_module # noqa: E402 +from uav.runtime.mission_spec import ( # noqa: E402 MissionSpec, load_mode_class, load_mission_spec, mission_path_for_name, mission_root, -) +) # noqa: E402 def _write_mission(tmp_path: Path, contents: str) -> Path: @@ -343,7 +337,7 @@ def test_invalid_mode_target_is_rejected(monkeypatch, mission_target): def test_load_mode_class_accepts_module_path(monkeypatch): module_name = "uav.modes.payload.PayloadAprilTagApproachMode" - fake_module = types.ModuleType(module_name) + fake_module: Any = types.ModuleType(module_name) class PayloadAprilTagApproachMode(Mode): mission_target = "payload" diff --git a/controls/sae_2025_ws/src/uav/test/test_payload_dlz_convex_hull_masking.py b/controls/sae_2025_ws/src/uav/test/test_payload_dlz_convex_hull_masking.py index d9a3e9a3..1f302c4e 100644 --- a/controls/sae_2025_ws/src/uav/test/test_payload_dlz_convex_hull_masking.py +++ b/controls/sae_2025_ws/src/uav/test/test_payload_dlz_convex_hull_masking.py @@ -4,7 +4,7 @@ import sys from pathlib import Path from types import SimpleNamespace -import types +from typing import Any, cast import numpy as np import cv2 @@ -15,11 +15,13 @@ sys.path.insert(0, str(PACKAGE_ROOT)) from uav.cv.dlz_convex_hull import build_dlz_hull_mask # noqa: E402 +from uav.test_support.ros_stubs import install_payload_mode_import_stubs # noqa: E402 def _bgr_from_hsv(h: int, s: int, v: int) -> tuple[int, int, int]: + hsv_pixel = np.array([[[h, s, v]]], dtype=np.uint8) pixel = cv2.cvtColor( - np.uint8([[[h, s, v]]]), + cast(Any, hsv_pixel), cv2.COLOR_HSV2BGR, )[0, 0] return int(pixel[0]), int(pixel[1]), int(pixel[2]) @@ -36,74 +38,7 @@ def _roi_ratio(mask: np.ndarray, x0: int, y0: int, x1: int, y1: int) -> float: return float(np.count_nonzero(region)) / float(region.size) -def _install_import_stubs() -> None: - if "rclpy" not in sys.modules: - rclpy = types.ModuleType("rclpy") - node_module = types.ModuleType("rclpy.node") - - class Node: - pass - - node_module.Node = Node - rclpy.node = node_module - sys.modules.update({"rclpy": rclpy, "rclpy.node": node_module}) - - if "cv_bridge" not in sys.modules: - cv_bridge = types.ModuleType("cv_bridge") - - class CvBridge: - def imgmsg_to_cv2(self, *_args, **_kwargs): - raise NotImplementedError - - def cv2_to_compressed_imgmsg(self, *_args, **_kwargs): - return SimpleNamespace(header=SimpleNamespace(stamp=None)) - - cv_bridge.CvBridge = CvBridge - sys.modules["cv_bridge"] = cv_bridge - - if "sensor_msgs" not in sys.modules: - sensor_msgs = types.ModuleType("sensor_msgs") - sensor_msgs_msg = types.ModuleType("sensor_msgs.msg") - sensor_msgs_msg.CompressedImage = type("CompressedImage", (), {}) - sensor_msgs_msg.Image = type("Image", (), {}) - sensor_msgs.msg = sensor_msgs_msg - sys.modules.update( - {"sensor_msgs": sensor_msgs, "sensor_msgs.msg": sensor_msgs_msg} - ) - - if "uav.vehicles.Payload" not in sys.modules: - payload_module = types.ModuleType("uav.vehicles.Payload") - payload_module.Payload = type("Payload", (), {}) - sys.modules["uav.vehicles.Payload"] = payload_module - - if "uav.vision_nodes" not in sys.modules: - vision_nodes = types.ModuleType("uav.vision_nodes") - vision_nodes.PayloadAprilTagNode = type("PayloadAprilTagNode", (), {}) - sys.modules["uav.vision_nodes"] = vision_nodes - - if "uav.vision_nodes.payload_perception_common" not in sys.modules: - common = types.ModuleType("uav.vision_nodes.payload_perception_common") - common.DEFAULT_TAG_FAMILY = "tag36h11" - sys.modules["uav.vision_nodes.payload_perception_common"] = common - - if "uav_interfaces" not in sys.modules: - sys.modules["uav_interfaces"] = types.ModuleType("uav_interfaces") - - if "uav_interfaces.srv" not in sys.modules: - srv_module = types.ModuleType("uav_interfaces.srv") - - class PayloadAprilTagState: - class Request: - pass - - class Response: - pass - - srv_module.PayloadAprilTagState = PayloadAprilTagState - sys.modules["uav_interfaces.srv"] = srv_module - - -_install_import_stubs() +install_payload_mode_import_stubs() PayloadCornerNavigateMode = importlib.import_module( "uav.modes.payload.PayloadCornerNavigateMode" @@ -183,7 +118,8 @@ def _make_turn_to_center_mode(**kwargs): def _contour_bbox(mask: np.ndarray) -> tuple[int, int, int, int]: contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) assert len(contours) == 1 - return cv2.boundingRect(contours[0]) + x, y, width, height = cv2.boundingRect(contours[0]) + return int(x), int(y), int(width), int(height) def test_build_dlz_hull_mask_keeps_orange_rectangle_and_excludes_outside(): diff --git a/controls/sae_2025_ws/src/uav/test/test_peer_runtime_contract.py b/controls/sae_2025_ws/src/uav/test/test_peer_runtime_contract.py index a4890266..1a2e7df8 100644 --- a/controls/sae_2025_ws/src/uav/test/test_peer_runtime_contract.py +++ b/controls/sae_2025_ws/src/uav/test/test_peer_runtime_contract.py @@ -5,6 +5,7 @@ import textwrap from types import SimpleNamespace import types +from typing import Any, Mapping, cast import pytest @@ -15,17 +16,17 @@ def _placeholder(name: str): def _install_ros_test_doubles() -> None: if "rclpy" not in sys.modules: - rclpy = types.ModuleType("rclpy") + rclpy: Any = types.ModuleType("rclpy") rclpy.init = lambda *args, **kwargs: None rclpy.shutdown = lambda: None rclpy.ok = lambda: True - node_mod = types.ModuleType("rclpy.node") - executors_mod = types.ModuleType("rclpy.executors") - clock_mod = types.ModuleType("rclpy.clock") - parameter_mod = types.ModuleType("rclpy.parameter") - validate_namespace_mod = types.ModuleType("rclpy.validate_namespace") - validate_node_name_mod = types.ModuleType("rclpy.validate_node_name") - qos_mod = types.ModuleType("rclpy.qos") + node_mod: Any = types.ModuleType("rclpy.node") + executors_mod: Any = types.ModuleType("rclpy.executors") + clock_mod: Any = types.ModuleType("rclpy.clock") + parameter_mod: Any = types.ModuleType("rclpy.parameter") + validate_namespace_mod: Any = types.ModuleType("rclpy.validate_namespace") + validate_node_name_mod: Any = types.ModuleType("rclpy.validate_node_name") + qos_mod: Any = types.ModuleType("rclpy.qos") class Node: def __init__(self, *_args, **_kwargs) -> None: @@ -65,8 +66,8 @@ class ExternalShutdownException(Exception): ) if "std_srvs" not in sys.modules: - std_srvs = types.ModuleType("std_srvs") - std_srvs_srv = types.ModuleType("std_srvs.srv") + std_srvs: Any = types.ModuleType("std_srvs") + std_srvs_srv: Any = types.ModuleType("std_srvs.srv") class Trigger: Request = _placeholder("Request") @@ -77,8 +78,8 @@ class Trigger: sys.modules.update({"std_srvs": std_srvs, "std_srvs.srv": std_srvs_srv}) if "std_msgs" not in sys.modules: - std_msgs = types.ModuleType("std_msgs") - std_msgs_msg = types.ModuleType("std_msgs.msg") + std_msgs: Any = types.ModuleType("std_msgs") + std_msgs_msg: Any = types.ModuleType("std_msgs.msg") std_msgs_msg.Empty = _placeholder("Empty") std_msgs.msg = std_msgs_msg sys.modules.update({"std_msgs": std_msgs, "std_msgs.msg": std_msgs_msg}) @@ -235,8 +236,8 @@ def create_service(self, _srv_type, service_name: str, *_args, **_kwargs): return SimpleNamespace(kind="service", name=service_name) -def _make_mode_manager() -> ModeManager: - manager = object.__new__(ModeManager) +def _make_mode_manager() -> Any: + manager = cast(Any, object.__new__(ModeManager)) manager.vehicle = _FakeVehicle() manager.modes = {} manager.transitions = {} @@ -408,7 +409,9 @@ class PeerAwareMode(Mode): def __init__(self, node, vehicle) -> None: super().__init__(node, vehicle) - def on_disconnect(self, time_delta: float, peers: dict[str, bool]) -> None: + def on_disconnect( + self, time_delta: float, connection_status: Mapping[str, bool] + ) -> None: pass def on_update(self, time_delta: float) -> None: @@ -529,7 +532,9 @@ def __init__(self, node, vehicle) -> None: ) self.shared_pub = self.node.create_publisher(object, "/shared/debug", 1) - def on_disconnect(self, time_delta: float, peers: dict[str, bool]) -> None: + def on_disconnect( + self, time_delta: float, connection_status: Mapping[str, bool] + ) -> None: pass def on_update(self, time_delta: float) -> None: @@ -574,7 +579,9 @@ def check_status(self) -> str: ) manager = _make_mode_manager() - mode = ModeManager.initialize_mode(manager, "fake.module.PeerAwareMode", {}) + mode = cast( + Any, ModeManager.initialize_mode(manager, "fake.module.PeerAwareMode", {}) + ) _configure_manager_for_mode(manager, mode) with manager._use_comm_builder( @@ -633,7 +640,7 @@ def __init__(self, node, vehicle) -> None: self.connection_checks: list[dict[str, bool]] = [] self.status_checks = 0 - def connection_ready(self, connection_status: dict[str, bool]) -> bool: + def connection_ready(self, connection_status: Mapping[str, bool]) -> bool: self.connection_checks.append(dict(connection_status)) return True @@ -641,7 +648,7 @@ def on_update(self, time_delta: float) -> None: self.update_calls.append(time_delta) def on_disconnect( - self, time_delta: float, connection_status: dict[str, bool] + self, time_delta: float, connection_status: Mapping[str, bool] ) -> None: self.disconnect_calls.append((time_delta, dict(connection_status))) @@ -687,12 +694,12 @@ def __init__(self, node, vehicle) -> None: self.connection_checks: list[dict[str, bool]] = [] self.status_checks = 0 - def connection_ready(self, connection_status: dict[str, bool]) -> bool: + def connection_ready(self, connection_status: Mapping[str, bool]) -> bool: self.connection_checks.append(dict(connection_status)) return False def on_disconnect( - self, time_delta: float, connection_status: dict[str, bool] + self, time_delta: float, connection_status: Mapping[str, bool] ) -> None: self.disconnect_calls.append((time_delta, dict(connection_status))) @@ -731,7 +738,9 @@ class PeerAwareMode(Mode): def __init__(self, node, vehicle) -> None: super().__init__(node, vehicle) - def on_disconnect(self, time_delta: float, peers: dict[str, bool]) -> None: + def on_disconnect( + self, time_delta: float, connection_status: Mapping[str, bool] + ) -> None: pass def on_update(self, time_delta: float) -> None: @@ -740,7 +749,7 @@ def on_update(self, time_delta: float) -> None: def check_status(self) -> str: return "continue" - mode = PeerAwareMode(_RecordingNode(), _FakeVehicle()) + mode = PeerAwareMode(cast(Any, _RecordingNode()), cast(Any, _FakeVehicle())) assert mode.connection_ready({"uav_1": True, "uav_2": True}) is True assert mode.connection_ready({"uav_1": True, "uav_2": False}) is False @@ -760,7 +769,9 @@ class PeerAwareMode(Mode): def __init__(self, node, vehicle) -> None: super().__init__(node, vehicle) - def on_disconnect(self, time_delta: float, peers: dict[str, bool]) -> None: + def on_disconnect( + self, time_delta: float, connection_status: Mapping[str, bool] + ) -> None: pass def on_update(self, time_delta: float) -> None: @@ -869,7 +880,9 @@ def __init__(self, node, vehicle) -> None: super().__init__(node, vehicle) self.node.create_publisher(object, "/shared/debug", 1) - def on_disconnect(self, time_delta: float, peers: dict[str, bool]) -> None: + def on_disconnect( + self, time_delta: float, connection_status: Mapping[str, bool] + ) -> None: pass def on_enter(self) -> None: @@ -929,7 +942,9 @@ class SharedStateMode(Mode): def __init__(self, node, vehicle) -> None: super().__init__(node, vehicle) - def on_disconnect(self, time_delta: float, peers: dict[str, bool]) -> None: + def on_disconnect( + self, time_delta: float, connection_status: Mapping[str, bool] + ) -> None: pass def on_update(self, time_delta: float) -> None: @@ -965,7 +980,9 @@ def __init__(self, node, vehicle) -> None: super().__init__(node, vehicle) self.node.create_subscription(object, "/uav_2/status", lambda _msg: None, 1) - def on_disconnect(self, time_delta: float, peers: dict[str, bool]) -> None: + def on_disconnect( + self, time_delta: float, connection_status: Mapping[str, bool] + ) -> None: pass def on_enter(self) -> None: @@ -979,7 +996,7 @@ def check_status(self) -> str: return "continue" node = _RecordingNode() - mode = InstrumentedPeerMode(node, _FakeVehicle()) + mode = InstrumentedPeerMode(cast(Any, node), cast(Any, _FakeVehicle())) mode.activate() assert _observed_peer_vehicle_names(node.calls, mode.peer_vehicle_names) == { diff --git a/controls/sae_2025_ws/src/uav/test/test_peer_stack_reconnect.py b/controls/sae_2025_ws/src/uav/test/test_peer_stack_reconnect.py index 6d6122dc..ac7dd177 100644 --- a/controls/sae_2025_ws/src/uav/test/test_peer_stack_reconnect.py +++ b/controls/sae_2025_ws/src/uav/test/test_peer_stack_reconnect.py @@ -9,6 +9,7 @@ import signal import subprocess import time +from typing import cast import pytest @@ -104,7 +105,7 @@ def __init__(self) -> None: String, f"/{vehicle_name}/peer_test/state", lambda message, vehicle_name=vehicle_name: self._on_status( - vehicle_name, message + vehicle_name, cast(String, message) ), 10, ) @@ -388,6 +389,12 @@ def _launch_vehicle_stack( ) +def _status_int(status: dict[str, object], key: str) -> int: + value = status[key] + assert isinstance(value, int) + return value + + @pytest.fixture def live_ros_environment(monkeypatch): _require_uav_package() @@ -483,9 +490,9 @@ def test_vehicle_stack_peer_reconnect_recovers_state_and_traffic( assert observer.shared_counts_by_sender.get("payload_0", 0) > 0 assert observer.shared_counts_by_sender.get("payload_1", 0) > 0 - payload_0_peer_before_disconnect = int(status_0["peer_received_total"]) - payload_0_shared_before_disconnect = int( - status_0["shared_remote_received_total"] + payload_0_peer_before_disconnect = _status_int(status_0, "peer_received_total") + payload_0_shared_before_disconnect = _status_int( + status_0, "shared_remote_received_total" ) payload_1_shared_events_before_disconnect = ( observer.shared_counts_by_sender.get("payload_1", 0) @@ -502,9 +509,12 @@ def test_vehicle_stack_peer_reconnect_recovers_state_and_traffic( state="waiting", disconnected_peers=["payload_1"], ) - assert int(waiting_0["peer_received_total"]) >= payload_0_peer_before_disconnect assert ( - int(waiting_0["shared_remote_received_total"]) + _status_int(waiting_0, "peer_received_total") + >= payload_0_peer_before_disconnect + ) + assert ( + _status_int(waiting_0, "shared_remote_received_total") >= payload_0_shared_before_disconnect ) @@ -563,15 +573,15 @@ def test_vehicle_stack_peer_reconnect_recovers_state_and_traffic( shared_remote_received_total=1, ) assert ( - int(reconnect_status_0["peer_received_total"]) + _status_int(reconnect_status_0, "peer_received_total") > payload_0_peer_before_disconnect ) assert ( - int(reconnect_status_0["shared_remote_received_total"]) + _status_int(reconnect_status_0, "shared_remote_received_total") > payload_0_shared_before_disconnect ) - assert int(reconnect_status_1["peer_received_total"]) > 0 - assert int(reconnect_status_1["shared_remote_received_total"]) > 0 + assert _status_int(reconnect_status_1, "peer_received_total") > 0 + assert _status_int(reconnect_status_1, "shared_remote_received_total") > 0 assert ( observer.shared_counts_by_sender.get("payload_1", 0) > payload_1_shared_events_before_disconnect diff --git a/controls/sae_2025_ws/src/uav/test/test_runtime_behavior.py b/controls/sae_2025_ws/src/uav/test/test_runtime_behavior.py index a1260623..613a4c2f 100644 --- a/controls/sae_2025_ws/src/uav/test/test_runtime_behavior.py +++ b/controls/sae_2025_ws/src/uav/test/test_runtime_behavior.py @@ -4,6 +4,7 @@ import sys import types from types import SimpleNamespace +from typing import Any, cast import pytest @@ -20,14 +21,14 @@ def _import_module_if_available(name: str): def _install_ros_test_doubles() -> None: - ament_index_python = sys.modules.get("ament_index_python") + ament_index_python: Any = sys.modules.get("ament_index_python") if ament_index_python is None: ament_index_python = _import_module_if_available("ament_index_python") if ament_index_python is None: ament_index_python = types.ModuleType("ament_index_python") sys.modules["ament_index_python"] = ament_index_python - ament_index_packages = sys.modules.get("ament_index_python.packages") + ament_index_packages: Any = sys.modules.get("ament_index_python.packages") if ament_index_packages is None: ament_index_packages = _import_module_if_available( "ament_index_python.packages" @@ -45,7 +46,7 @@ class PackageNotFoundError(Exception): ament_index_packages.get_package_share_directory = lambda _name: "" ament_index_python.packages = ament_index_packages - rclpy = sys.modules.get("rclpy") + rclpy: Any = sys.modules.get("rclpy") if rclpy is None: rclpy = _import_module_if_available("rclpy") if rclpy is None: @@ -54,7 +55,7 @@ class PackageNotFoundError(Exception): if not hasattr(rclpy, "ok"): rclpy.ok = lambda: True - node_mod = sys.modules.get("rclpy.node") + node_mod: Any = sys.modules.get("rclpy.node") if node_mod is None: node_mod = _import_module_if_available("rclpy.node") if node_mod is None: @@ -67,7 +68,7 @@ class Node: node_mod.Node = Node - executors_mod = sys.modules.get("rclpy.executors") + executors_mod: Any = sys.modules.get("rclpy.executors") if executors_mod is None: executors_mod = _import_module_if_available("rclpy.executors") if executors_mod is None: @@ -80,7 +81,7 @@ class ExternalShutdownException(Exception): executors_mod.ExternalShutdownException = ExternalShutdownException - clock_mod = sys.modules.get("rclpy.clock") + clock_mod: Any = sys.modules.get("rclpy.clock") if clock_mod is None: clock_mod = _import_module_if_available("rclpy.clock") if clock_mod is None: @@ -89,7 +90,7 @@ class ExternalShutdownException(Exception): if not hasattr(clock_mod, "Clock"): clock_mod.Clock = _placeholder("Clock") - parameter_mod = sys.modules.get("rclpy.parameter") + parameter_mod: Any = sys.modules.get("rclpy.parameter") if parameter_mod is None: parameter_mod = _import_module_if_available("rclpy.parameter") if parameter_mod is None: @@ -98,7 +99,7 @@ class ExternalShutdownException(Exception): if not hasattr(parameter_mod, "Parameter"): parameter_mod.Parameter = _placeholder("Parameter") - validate_namespace_mod = sys.modules.get("rclpy.validate_namespace") + validate_namespace_mod: Any = sys.modules.get("rclpy.validate_namespace") if validate_namespace_mod is None: validate_namespace_mod = _import_module_if_available("rclpy.validate_namespace") if validate_namespace_mod is None: @@ -107,7 +108,7 @@ class ExternalShutdownException(Exception): if not hasattr(validate_namespace_mod, "validate_namespace"): validate_namespace_mod.validate_namespace = lambda namespace: None - validate_node_name_mod = sys.modules.get("rclpy.validate_node_name") + validate_node_name_mod: Any = sys.modules.get("rclpy.validate_node_name") if validate_node_name_mod is None: validate_node_name_mod = _import_module_if_available("rclpy.validate_node_name") if validate_node_name_mod is None: @@ -116,7 +117,7 @@ class ExternalShutdownException(Exception): if not hasattr(validate_node_name_mod, "validate_node_name"): validate_node_name_mod.validate_node_name = lambda node_name: None - qos_mod = sys.modules.get("rclpy.qos") + qos_mod: Any = sys.modules.get("rclpy.qos") if qos_mod is None: qos_mod = _import_module_if_available("rclpy.qos") if qos_mod is None: @@ -140,8 +141,8 @@ class ExternalShutdownException(Exception): rclpy.qos = qos_mod if "std_srvs" not in sys.modules: - std_srvs = types.ModuleType("std_srvs") - std_srvs_srv = types.ModuleType("std_srvs.srv") + std_srvs: Any = types.ModuleType("std_srvs") + std_srvs_srv: Any = types.ModuleType("std_srvs.srv") class Trigger: Request = _placeholder("Request") @@ -152,15 +153,15 @@ class Trigger: sys.modules.update({"std_srvs": std_srvs, "std_srvs.srv": std_srvs_srv}) if "std_msgs" not in sys.modules: - std_msgs = types.ModuleType("std_msgs") - std_msgs_msg = types.ModuleType("std_msgs.msg") + std_msgs: Any = types.ModuleType("std_msgs") + std_msgs_msg: Any = types.ModuleType("std_msgs.msg") std_msgs_msg.Empty = _placeholder("Empty") std_msgs.msg = std_msgs_msg sys.modules.update({"std_msgs": std_msgs, "std_msgs.msg": std_msgs_msg}) if "px4_msgs" not in sys.modules: - px4_msgs = types.ModuleType("px4_msgs") - px4_msgs_msg = types.ModuleType("px4_msgs.msg") + px4_msgs: Any = types.ModuleType("px4_msgs") + px4_msgs_msg: Any = types.ModuleType("px4_msgs.msg") class VehicleStatus: NAVIGATION_STATE_AUTO_LOITER = 1 @@ -186,9 +187,9 @@ class VtolVehicleStatus: sys.modules.update({"px4_msgs": px4_msgs, "px4_msgs.msg": px4_msgs_msg}) if "payload_interfaces" not in sys.modules: - payload_interfaces = types.ModuleType("payload_interfaces") - payload_interfaces_msg = types.ModuleType("payload_interfaces.msg") - payload_interfaces_srv = types.ModuleType("payload_interfaces.srv") + payload_interfaces: Any = types.ModuleType("payload_interfaces") + payload_interfaces_msg: Any = types.ModuleType("payload_interfaces.msg") + payload_interfaces_srv: Any = types.ModuleType("payload_interfaces.srv") payload_interfaces_msg.DriveCommand = _placeholder("DriveCommand") payload_interfaces_msg.ServoCommand = _placeholder("ServoCommand") @@ -289,8 +290,8 @@ def check_status(self) -> str: return self.status -def _make_mode_manager(*, vehicle=None, auto_launch: bool = False) -> ModeManager: - manager = object.__new__(ModeManager) +def _make_mode_manager(*, vehicle=None, auto_launch: bool = False) -> Any: + manager = cast(Any, object.__new__(ModeManager)) manager.vehicle = vehicle manager.modes = {} manager.transitions = {} @@ -408,7 +409,9 @@ def _make_bootstrap(module_cls, params: dict[str, object]): return bootstrap -def _fake_mission_spec(*, target: str, is_uav: bool, is_payload: bool, vision_nodes=()): +def _fake_mission_spec( + *, target: str, is_uav: bool, is_payload: bool, vision_nodes=() +) -> Any: return SimpleNamespace( target=target, is_uav=is_uav, @@ -437,7 +440,7 @@ def __init__( required: int, optional: str = "default", ) -> None: - super().__init__(node, vehicle) + super().__init__(node, cast(Any, vehicle)) self.required = required self.optional = optional @@ -477,7 +480,7 @@ class FakeMode(Mode): mission_target = "uav" def __init__(self, node, vehicle: _ExpectedVehicle, required: int) -> None: - super().__init__(node, vehicle) + super().__init__(node, cast(Any, vehicle)) self.required = required def on_update(self, time_delta: float) -> None: @@ -509,7 +512,7 @@ class FakeMode(Mode): mission_target = "uav" def __init__(self, node, vehicle: _ExpectedVehicle, required: int) -> None: - super().__init__(node, vehicle) + super().__init__(node, cast(Any, vehicle)) self.required = required def on_update(self, time_delta: float) -> None: @@ -543,7 +546,7 @@ class FakeMode(Mode): mission_target = "uav" def __init__(self, node, vehicle: _ExpectedVehicle, required: int) -> None: - super().__init__(node, vehicle) + super().__init__(node, cast(Any, vehicle)) self.required = required def on_update(self, time_delta: float) -> None: @@ -614,7 +617,7 @@ def test_setup_vision_deduplicates_clients(monkeypatch): assert list(manager.vision_clients) == [canonical_name] assert ModeManager.get_vision_client(manager, FakeVisionNode) is client - assert created_clients == [(FakeVisionNode.srv, "vision/FakeVisionNode")] + assert created_clients == [(cast(Any, FakeVisionNode).srv, "vision/FakeVisionNode")] def test_setup_vision_rejects_vehicle_without_camera(): @@ -737,7 +740,7 @@ def test_create_entity_falls_back_to_raw_node_during_node_init(monkeypatch): raising=False, ) - manager = object.__new__(ModeManager) + manager = cast(Any, object.__new__(ModeManager)) publisher = ModeManager.create_publisher(manager, object, "/parameter_events", 10) @@ -913,7 +916,7 @@ def test_mode_manager_stop_vehicle_without_rclpy_guard(): _require_runtime_support() stop_calls: list[str] = [] - manager = object.__new__(ModeManager) + manager = cast(Any, object.__new__(ModeManager)) manager.vehicle = SimpleNamespace(stop=lambda: stop_calls.append("stop")) manager.get_logger = lambda: _FakeLogger() @@ -936,7 +939,7 @@ def deactivate(self) -> None: events.append("deactivate") self.active = False - manager = object.__new__(ModeManager) + manager = cast(Any, object.__new__(ModeManager)) manager.vehicle = SimpleNamespace(stop=lambda: events.append("stop")) manager.modes = {"start": _FakeMode()} manager.transitions = {} @@ -971,7 +974,7 @@ def deactivate(self) -> None: events.append("deactivate") self.active = False - manager = object.__new__(ModeManager) + manager = cast(Any, object.__new__(ModeManager)) manager.vehicle = SimpleNamespace(stop=lambda: events.append("stop")) manager.modes = {"start": _FakeMode()} manager.transitions = {} @@ -1011,7 +1014,7 @@ def deactivate(self) -> None: events.append("deactivate") self.active = False - manager = object.__new__(ModeManager) + manager = cast(Any, object.__new__(ModeManager)) manager.vehicle = SimpleNamespace(stop=lambda: events.append("stop")) manager.modes = {"start": _FakeMode()} manager.transitions = {} diff --git a/controls/sae_2025_ws/src/uav/test/test_schema_validation.py b/controls/sae_2025_ws/src/uav/test/test_schema_validation.py index bd9ff4e2..7c39c48c 100644 --- a/controls/sae_2025_ws/src/uav/test/test_schema_validation.py +++ b/controls/sae_2025_ws/src/uav/test/test_schema_validation.py @@ -4,6 +4,7 @@ import sys from pathlib import Path from types import SimpleNamespace +from typing import Any, cast import pytest import yaml @@ -24,8 +25,9 @@ def _load_fleet_module(): spec = importlib.util.spec_from_file_location( "uav_fleet_launch_schema_tests", launch_path ) + if spec is None or spec.loader is None: + raise ImportError(f"Unable to load launch module from {launch_path}") module = importlib.util.module_from_spec(spec) - assert spec.loader is not None spec.loader.exec_module(module) return module @@ -50,8 +52,8 @@ def check_status(self) -> str: return self.status -def _make_mode_manager(): - manager = object.__new__(ModeManager) +def _make_mode_manager() -> Any: + manager = cast(Any, object.__new__(ModeManager)) manager.vehicle = None manager.modes = {} manager.transitions = {} diff --git a/controls/sae_2025_ws/src/uav/uav/test_support/__init__.py b/controls/sae_2025_ws/src/uav/uav/test_support/__init__.py new file mode 100644 index 00000000..92a9d1d7 --- /dev/null +++ b/controls/sae_2025_ws/src/uav/uav/test_support/__init__.py @@ -0,0 +1 @@ +"""Test-only support helpers for UAV package tests.""" diff --git a/controls/sae_2025_ws/src/uav/uav/test_support/ros_stubs.py b/controls/sae_2025_ws/src/uav/uav/test_support/ros_stubs.py new file mode 100644 index 00000000..57199bee --- /dev/null +++ b/controls/sae_2025_ws/src/uav/uav/test_support/ros_stubs.py @@ -0,0 +1,340 @@ +from __future__ import annotations + +import importlib +from pathlib import Path +import sys +from types import SimpleNamespace +import types +from typing import Any + + +def import_module_if_available(name: str) -> Any | None: + try: + return importlib.import_module(name) + except ModuleNotFoundError: + return None + + +def ensure_basic_rclpy_stubs() -> None: + if "rclpy" in sys.modules: + return + + rclpy: Any = types.ModuleType("rclpy") + node_module: Any = types.ModuleType("rclpy.node") + + class Node: + def __init__(self, *_args, **_kwargs) -> None: + pass + + node_module.Node = Node + rclpy.node = node_module + sys.modules.update({"rclpy": rclpy, "rclpy.node": node_module}) + + +def ensure_runtime_rclpy_stubs() -> None: + rclpy: Any = sys.modules.get("rclpy") or import_module_if_available("rclpy") + if rclpy is None: + rclpy = types.ModuleType("rclpy") + sys.modules["rclpy"] = rclpy + if not hasattr(rclpy, "init"): + rclpy.init = lambda *args, **kwargs: None + if not hasattr(rclpy, "shutdown"): + rclpy.shutdown = lambda: None + if not hasattr(rclpy, "ok"): + rclpy.ok = lambda: True + + node_mod: Any = sys.modules.get("rclpy.node") or import_module_if_available( + "rclpy.node" + ) + if node_mod is None: + node_mod = types.ModuleType("rclpy.node") + sys.modules["rclpy.node"] = node_mod + if not hasattr(node_mod, "Node"): + + class Node: + def __init__(self, *_args, **_kwargs) -> None: + pass + + node_mod.Node = Node + + executors_mod: Any = sys.modules.get( + "rclpy.executors" + ) or import_module_if_available("rclpy.executors") + if executors_mod is None: + executors_mod = types.ModuleType("rclpy.executors") + sys.modules["rclpy.executors"] = executors_mod + if not hasattr(executors_mod, "ExternalShutdownException"): + + class ExternalShutdownException(Exception): + pass + + executors_mod.ExternalShutdownException = ExternalShutdownException + + for module_name, attr_name in ( + ("rclpy.clock", "Clock"), + ("rclpy.parameter", "Parameter"), + ): + module: Any = sys.modules.get(module_name) or import_module_if_available( + module_name + ) + if module is None: + module = types.ModuleType(module_name) + sys.modules[module_name] = module + if not hasattr(module, attr_name): + setattr(module, attr_name, type(attr_name, (), {})) + + validate_namespace_mod: Any = sys.modules.get( + "rclpy.validate_namespace" + ) or import_module_if_available("rclpy.validate_namespace") + if validate_namespace_mod is None: + validate_namespace_mod = types.ModuleType("rclpy.validate_namespace") + sys.modules["rclpy.validate_namespace"] = validate_namespace_mod + if not hasattr(validate_namespace_mod, "validate_namespace"): + validate_namespace_mod.validate_namespace = lambda namespace: None + + validate_node_name_mod: Any = sys.modules.get( + "rclpy.validate_node_name" + ) or import_module_if_available("rclpy.validate_node_name") + if validate_node_name_mod is None: + validate_node_name_mod = types.ModuleType("rclpy.validate_node_name") + sys.modules["rclpy.validate_node_name"] = validate_node_name_mod + if not hasattr(validate_node_name_mod, "validate_node_name"): + validate_node_name_mod.validate_node_name = lambda node_name: None + + qos_mod: Any = sys.modules.get("rclpy.qos") or import_module_if_available( + "rclpy.qos" + ) + if qos_mod is None: + qos_mod = types.ModuleType("rclpy.qos") + sys.modules["rclpy.qos"] = qos_mod + for name in ( + "QoSProfile", + "QoSReliabilityPolicy", + "QoSHistoryPolicy", + "QoSDurabilityPolicy", + ): + if not hasattr(qos_mod, name): + setattr(qos_mod, name, type(name, (), {})) + + rclpy.node = node_mod + rclpy.executors = executors_mod + rclpy.clock = sys.modules["rclpy.clock"] + rclpy.parameter = sys.modules["rclpy.parameter"] + rclpy.validate_namespace = validate_namespace_mod + rclpy.validate_node_name = validate_node_name_mod + rclpy.qos = qos_mod + + +def purge_fake_rclpy_modules() -> None: + fake_rclpy = sys.modules.get("rclpy") + if fake_rclpy is not None and not hasattr(fake_rclpy, "__path__"): + for module_name in list(sys.modules): + if module_name == "rclpy" or module_name.startswith("rclpy."): + del sys.modules[module_name] + + +def ensure_std_msgs_stub() -> None: + std_msgs: Any = sys.modules.get("std_msgs") or import_module_if_available( + "std_msgs" + ) + if std_msgs is not None: + return + std_msgs = types.ModuleType("std_msgs") + std_msgs_msg: Any = types.ModuleType("std_msgs.msg") + std_msgs_msg.Empty = type("Empty", (), {}) + std_msgs.msg = std_msgs_msg + sys.modules.update({"std_msgs": std_msgs, "std_msgs.msg": std_msgs_msg}) + + +def ensure_std_srvs_stub() -> None: + std_srvs: Any = sys.modules.get("std_srvs") or import_module_if_available( + "std_srvs" + ) + if std_srvs is not None: + return + std_srvs = types.ModuleType("std_srvs") + std_srvs_srv: Any = types.ModuleType("std_srvs.srv") + + class Trigger: + Request = type("Request", (), {}) + Response = type("Response", (), {}) + + std_srvs_srv.Trigger = Trigger + std_srvs.srv = std_srvs_srv + sys.modules.update({"std_srvs": std_srvs, "std_srvs.srv": std_srvs_srv}) + + +def ensure_ament_index_stub(package_root: Path) -> None: + ament_index_python: Any = sys.modules.get( + "ament_index_python" + ) or import_module_if_available("ament_index_python") + if ament_index_python is None: + ament_index_python = types.ModuleType("ament_index_python") + sys.modules["ament_index_python"] = ament_index_python + + ament_index_packages: Any = sys.modules.get( + "ament_index_python.packages" + ) or import_module_if_available("ament_index_python.packages") + if ament_index_packages is None: + ament_index_packages = types.ModuleType("ament_index_python.packages") + sys.modules["ament_index_python.packages"] = ament_index_packages + + if not hasattr(ament_index_packages, "PackageNotFoundError"): + + class PackageNotFoundError(Exception): + pass + + ament_index_packages.PackageNotFoundError = PackageNotFoundError + if not hasattr(ament_index_packages, "get_package_share_directory"): + ament_index_packages.get_package_share_directory = lambda _name: str( + package_root + ) + ament_index_python.packages = ament_index_packages + + +def ensure_launch_import_stubs( + package_root: Path, *, include_launch_ros: bool = False +) -> None: + ensure_ament_index_stub(package_root) + + launch_module: Any = sys.modules.get("launch") or import_module_if_available( + "launch" + ) + if launch_module is None: + launch_module = types.ModuleType("launch") + sys.modules["launch"] = launch_module + if not hasattr(launch_module, "LaunchDescription"): + launch_module.LaunchDescription = type("LaunchDescription", (), {}) + + launch_actions: Any = sys.modules.get( + "launch.actions" + ) or import_module_if_available("launch.actions") + if launch_actions is None: + launch_actions = types.ModuleType("launch.actions") + sys.modules["launch.actions"] = launch_actions + for name in ( + "DeclareLaunchArgument", + "ExecuteProcess", + "IncludeLaunchDescription", + "OpaqueFunction", + ): + if not hasattr(launch_actions, name): + setattr(launch_actions, name, type(name, (), {})) + + launch_sources: Any = sys.modules.get( + "launch.launch_description_sources" + ) or import_module_if_available("launch.launch_description_sources") + if launch_sources is None: + launch_sources = types.ModuleType("launch.launch_description_sources") + sys.modules["launch.launch_description_sources"] = launch_sources + if not hasattr(launch_sources, "PythonLaunchDescriptionSource"): + launch_sources.PythonLaunchDescriptionSource = type( + "PythonLaunchDescriptionSource", (), {} + ) + + launch_logging: Any = sys.modules.get( + "launch.logging" + ) or import_module_if_available("launch.logging") + if launch_logging is None: + launch_logging = types.ModuleType("launch.logging") + sys.modules["launch.logging"] = launch_logging + if not hasattr(launch_logging, "get_logger"): + launch_logging.get_logger = lambda *_args, **_kwargs: SimpleNamespace( + warning=lambda *_a, **_k: None, + warn=lambda *_a, **_k: None, + info=lambda *_a, **_k: None, + ) + + launch_substitutions: Any = sys.modules.get( + "launch.substitutions" + ) or import_module_if_available("launch.substitutions") + if launch_substitutions is None: + launch_substitutions = types.ModuleType("launch.substitutions") + sys.modules["launch.substitutions"] = launch_substitutions + if not hasattr(launch_substitutions, "LaunchConfiguration"): + launch_substitutions.LaunchConfiguration = type("LaunchConfiguration", (), {}) + + if not include_launch_ros: + return + + launch_ros: Any = sys.modules.get("launch_ros") or import_module_if_available( + "launch_ros" + ) + if launch_ros is None: + launch_ros = types.ModuleType("launch_ros") + sys.modules["launch_ros"] = launch_ros + + launch_ros_actions: Any = sys.modules.get( + "launch_ros.actions" + ) or import_module_if_available("launch_ros.actions") + if launch_ros_actions is None: + launch_ros_actions = types.ModuleType("launch_ros.actions") + sys.modules["launch_ros.actions"] = launch_ros_actions + if not hasattr(launch_ros_actions, "Node"): + launch_ros_actions.Node = type("Node", (), {}) + launch_ros.actions = launch_ros_actions + + +def install_auto_launch_import_stubs(package_root: Path) -> None: + ensure_std_msgs_stub() + ensure_std_srvs_stub() + ensure_launch_import_stubs(package_root) + ensure_runtime_rclpy_stubs() + + +def install_payload_mode_import_stubs() -> None: + ensure_basic_rclpy_stubs() + + if "cv_bridge" not in sys.modules: + cv_bridge: Any = types.ModuleType("cv_bridge") + + class CvBridge: + def imgmsg_to_cv2(self, *_args, **_kwargs): + raise NotImplementedError + + def cv2_to_compressed_imgmsg(self, *_args, **_kwargs): + return SimpleNamespace(header=SimpleNamespace(stamp=None)) + + cv_bridge.CvBridge = CvBridge + sys.modules["cv_bridge"] = cv_bridge + + if "sensor_msgs" not in sys.modules: + sensor_msgs: Any = types.ModuleType("sensor_msgs") + sensor_msgs_msg: Any = types.ModuleType("sensor_msgs.msg") + sensor_msgs_msg.CompressedImage = type("CompressedImage", (), {}) + sensor_msgs_msg.Image = type("Image", (), {}) + sensor_msgs.msg = sensor_msgs_msg + sys.modules.update( + {"sensor_msgs": sensor_msgs, "sensor_msgs.msg": sensor_msgs_msg} + ) + + if "uav.vehicles.Payload" not in sys.modules: + payload_module: Any = types.ModuleType("uav.vehicles.Payload") + payload_module.Payload = type("Payload", (), {}) + sys.modules["uav.vehicles.Payload"] = payload_module + + if "uav.vision_nodes" not in sys.modules: + vision_nodes: Any = types.ModuleType("uav.vision_nodes") + vision_nodes.PayloadAprilTagNode = type("PayloadAprilTagNode", (), {}) + sys.modules["uav.vision_nodes"] = vision_nodes + + if "uav.vision_nodes.payload_perception_common" not in sys.modules: + common: Any = types.ModuleType("uav.vision_nodes.payload_perception_common") + common.DEFAULT_TAG_FAMILY = "tag36h11" + sys.modules["uav.vision_nodes.payload_perception_common"] = common + + if "uav_interfaces" not in sys.modules: + sys.modules["uav_interfaces"] = types.ModuleType("uav_interfaces") + + if "uav_interfaces.srv" not in sys.modules: + srv_module: Any = types.ModuleType("uav_interfaces.srv") + + class PayloadAprilTagState: + class Request: + pass + + class Response: + pass + + srv_module.PayloadAprilTagState = PayloadAprilTagState + sys.modules["uav_interfaces.srv"] = srv_module