From 12005ce5c37b8da24bb33d958acc1639da729441 Mon Sep 17 00:00:00 2001 From: LF-Vigneux <156103879+LF-Vigneux@users.noreply.github.com> Date: Thu, 28 May 2026 16:08:15 -0400 Subject: [PATCH 01/17] Implmented the ValidatedLayerConfig class, and plugged it in the processor --- merlin/core/merlin_processor.py | 99 ++++++++++++++++++++++++++++----- 1 file changed, 84 insertions(+), 15 deletions(-) diff --git a/merlin/core/merlin_processor.py b/merlin/core/merlin_processor.py index da95fde8..ed561ba4 100644 --- a/merlin/core/merlin_processor.py +++ b/merlin/core/merlin_processor.py @@ -4,10 +4,11 @@ import uuid import warnings import zlib -from collections.abc import Iterable +from collections.abc import Iterable, Sequence from contextlib import suppress from dataclasses import dataclass from typing import Any, cast +from numbers import Integral import numpy as np import perceval as pcvl @@ -40,6 +41,74 @@ class BackendCapabilities: available_commands: tuple[str] +class ValidatedLayerConfig: + def __init__(self, config_to_verify: dict): + # circuit + try: + self.circuit: pcvl.ACircuit = config_to_verify["circuit"] + except KeyError: + raise KeyError(f"There must be a key 'circuit' in the configs dictionary") + if not isinstance(self.circuit, pcvl.ACircuit): + raise ValueError( + f"The 'circuit' key of the config dictionary must be a Perceval ACircuit, got {type(self.circuit)}" + ) + + # input_state + try: + self.input_state: Sequence[Integral] | None = config_to_verify[ + "input_state" + ] + except KeyError: + raise KeyError( + "There must be a key 'input_state' in the configs dictionary" + ) + + if self.input_state is not None: + if not isinstance(self.input_state, Sequence): + raise ValueError( + f"'input_state' must be a sequence of integers or None, " + f"got {type(self.input_state).__name__}" + ) + + bad_types = { + type(x).__name__ + for x in self.input_state + if not isinstance(x, Integral) + } + + if bad_types: + raise ValueError( + f"'input_state' must contain 
only integers. " + f"Got sequence type {type(self.input_state).__name__} " + f"with non-integer element types: {sorted(bad_types)}" + ) + + # input_param_order + try: + self.input_param_order: Sequence[str] = config_to_verify[ + "input_param_order" + ] + except KeyError: + raise KeyError( + f"There must be a key 'input_param_order' in the configs dictionary" + ) + if not isinstance(self.input_param_order, Sequence): + raise ValueError( + f"'input_state' must be a sequence of strings, got {type(self.input_param_order).__name__}" + ) + + bad_types = { + type(x).__name__ for x in self.input_param_order if not isinstance(x, str) + } + + if bad_types: + raise ValueError( + f"'input_param_order' must contain only integers. " + f"Got sequence type {type(self.input_param_order).__name__} " + f"with non-integer element types: {sorted(bad_types)}" + ) + + class MerlinProcessor: """RPC-style processor for quantum execution. @@ -566,7 +635,7 @@ def _offload_quantum_layer_with_chunking( cache = self._layer_cache.get(id(layer)) if cache is None: - config = cast(Any, layer).export_config() + config = ValidatedLayerConfig(cast(Any, layer).export_config()) self._layer_cache[id(layer)] = {"config": config} else: config = cache["config"] @@ -586,7 +655,7 @@ def _offload_quantum_layer_with_chunking( def _run_chunks_pooled( self, layer: MerlinModule, - config: dict, + config: ValidatedLayerConfig, input_tensor: torch.Tensor, chunks: list[tuple[int, int]], nsample: int | None, @@ -655,7 +724,7 @@ def _call(s: int, e: int, idx: int): def _run_chunk( self, layer: MerlinModule, - config: dict, + config: ValidatedLayerConfig, input_chunk: torch.Tensor, nsample: int | None, state: dict, @@ -706,11 +775,11 @@ def _capped_name(base: str, cmd: str) -> str: # Build a fresh RemoteProcessor and Sampler on each attempt so that # a corrupted RP doesn't poison retries. 
rp = self._create_fresh_rp() - rp.set_circuit(config["circuit"]) - if config.get("input_state"): - input_state = pcvl.BasicState(config["input_state"]) + rp.set_circuit(config.circuit) + if config.input_state: + input_state = pcvl.BasicState(config.input_state) rp.with_input(input_state) - n_photons = sum(config["input_state"]) + n_photons = sum(config.input_state) rp.min_detected_photons_filter(n_photons) max_shots_arg = ( @@ -1055,9 +1124,9 @@ def _iter_layers_in_order(self, module: nn.Module) -> Iterable[nn.Module]: for child in children: yield from self._iter_layers_in_order(child) - def _extract_input_params(self, config: dict) -> list[str]: + def _extract_input_params(self, config: ValidatedLayerConfig) -> list[str]: """Extract circuit parameter names that correspond to model inputs.""" - return list(config["input_param_order"]) + return list(config.input_param_order) def _process_batch_results( self, @@ -1252,14 +1321,14 @@ def estimate_required_shots_per_input( else: raise ValueError("input must be 1D or 2D tensor") - config = cast(Any, layer).export_config() + config = ValidatedLayerConfig(cast(Any, layer).export_config()) child_rp = self._create_fresh_rp() - child_rp.set_circuit(config["circuit"]) + child_rp.set_circuit(config.circuit) - if config.get("input_state"): - input_state = pcvl.BasicState(config["input_state"]) + if config.input_state: + input_state = pcvl.BasicState(config.input_state) child_rp.with_input(input_state) - n_photons = sum(config["input_state"]) + n_photons = sum(config.input_state) child_rp.min_detected_photons_filter(n_photons) input_param_names = self._extract_input_params(config) From 164231f11ea2555b72982676ae18fc2be6144466 Mon Sep 17 00:00:00 2001 From: LF-Vigneux <156103879+LF-Vigneux@users.noreply.github.com> Date: Thu, 28 May 2026 17:06:05 -0400 Subject: [PATCH 02/17] Finshed implementation? Only correcting tests and implementing the new ones remaining? 
--- merlin/algorithms/module.py | 3 +- merlin/core/merlin_processor.py | 67 ++++++++++++++++++++------------- 2 files changed, 43 insertions(+), 27 deletions(-) diff --git a/merlin/algorithms/module.py b/merlin/algorithms/module.py index 024236ab..fa9cbe52 100644 --- a/merlin/algorithms/module.py +++ b/merlin/algorithms/module.py @@ -23,7 +23,7 @@ from __future__ import annotations from contextlib import contextmanager - +import uuid import torch import torch.nn as nn @@ -53,6 +53,7 @@ class MerlinModule(nn.Module): """ # -------------------- Execution policy & helpers -------------------- + uid = uuid.uuid4() @property def force_local(self) -> bool: diff --git a/merlin/core/merlin_processor.py b/merlin/core/merlin_processor.py index ed561ba4..ea072afc 100644 --- a/merlin/core/merlin_processor.py +++ b/merlin/core/merlin_processor.py @@ -7,7 +7,7 @@ from collections.abc import Iterable, Sequence from contextlib import suppress from dataclasses import dataclass -from typing import Any, cast +from typing import Any, cast, Protocol, runtime_checkable from numbers import Integral import numpy as np @@ -85,28 +85,34 @@ def __init__(self, config_to_verify: dict): # input_param_order try: - self.input_param_order: Sequence[str] = config_to_verify[ + self.input_param_order: Sequence[str] | None = config_to_verify[ "input_param_order" ] except KeyError: - raise KeyError( - f"There must be a key 'input_param_order' in the configs dictionary" - ) - if not isinstance(self.input_param_order, Sequence): - raise ValueError( - f"'input_state' must be a sequence of strings, got {type(self.input_param_order).__name__}" - ) + self.input_param_order = None + if self.input_param_order is not None: + if not isinstance(self.input_param_order, Sequence): + raise ValueError( + f"'input_state' must be a sequence of strings or None, got {type(self.input_param_order).__name__}" + ) - bad_types = { - type(x).__name__ for x in self.input_param_order if not isinstance(x, str) - } + bad_types = { 
+ type(x).__name__ + for x in self.input_param_order + if not isinstance(x, str) + } + + if bad_types: + raise ValueError( + f"'input_param_order' must contain only strings. " + f"Got sequence type {type(self.input_param_order).__name__} " + f"with non-integer element types: {sorted(bad_types)}" + ) - if bad_types: - raise ValueError( - f"'input_param_order' must contain only integers. " - f"Got sequence type {type(self.input_param_order).__name__} " - f"with non-integer element types: {sorted(bad_types)}" - ) + +@runtime_checkable +class SupportsExportConfig(Protocol): + def export_config(self) -> dict: ... class MerlinProcessor: @@ -633,10 +639,14 @@ def _offload_quantum_layer_with_chunking( if input_tensor.is_cuda: input_tensor = input_tensor.cpu() - cache = self._layer_cache.get(id(layer)) + cache = self._layer_cache.get(layer.uid) if cache is None: - config = ValidatedLayerConfig(cast(Any, layer).export_config()) - self._layer_cache[id(layer)] = {"config": config} + if not isinstance(layer, SupportsExportConfig): + raise TypeError( + "The layer must have a export_config() method returning a dictionary of this type: {'circuit':perceval.ACircuit, 'input_state': Sequence[Integral]|None, 'input_param_order': Sequence[str]|None}." + ) + config = ValidatedLayerConfig(layer.export_config()) + self._layer_cache[layer.uid] = {"config": config} else: config = cache["config"] @@ -1309,10 +1319,11 @@ def estimate_required_shots_per_input( ValueError If ``input`` is not one- or two-dimensional. """ - if not hasattr(layer, "export_config") or not callable( - cast(Any, layer).export_config - ): - raise TypeError("layer must provide export_config() for shot estimation") + if not isinstance(layer, SupportsExportConfig): + raise TypeError( + "For shot estimation, the layer must have a export_config() method returning a dictionary of this type: {'circuit':perceval.ACircuit, 'input_state': Sequence[Integral]|None, 'input_param_order': Sequence[str]|None}." 
+ ) + config = ValidatedLayerConfig(layer.export_config()) if input.dim() == 1: x = input.unsqueeze(0) @@ -1321,7 +1332,11 @@ def estimate_required_shots_per_input( else: raise ValueError("input must be 1D or 2D tensor") - config = ValidatedLayerConfig(cast(Any, layer).export_config()) + if not isinstance(layer, SupportsExportConfig): + raise TypeError( + "The layer must have a export_config() method returning a dictionary of this type: {'circuit':perceval.ACircuit, 'input_state': Sequence[Integral]|None, 'input_param_order': Sequence[str]|None}." + ) + config = ValidatedLayerConfig(layer.export_config()) child_rp = self._create_fresh_rp() child_rp.set_circuit(config.circuit) From 1b5d503c13df5c0ca64f09344b4deab6a0f5db75 Mon Sep 17 00:00:00 2001 From: LF-Vigneux <156103879+LF-Vigneux@users.noreply.github.com> Date: Thu, 28 May 2026 17:14:10 -0400 Subject: [PATCH 03/17] Specified the errors even more --- merlin/core/merlin_processor.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/merlin/core/merlin_processor.py b/merlin/core/merlin_processor.py index ea072afc..591d630f 100644 --- a/merlin/core/merlin_processor.py +++ b/merlin/core/merlin_processor.py @@ -47,10 +47,12 @@ def __init__(self, config_to_verify: dict): try: self.circuit: pcvl.ACircuit = config_to_verify["circuit"] except KeyError: - raise KeyError(f"There must be a key 'circuit' in the configs dictionary") + raise KeyError( + f"There must be a key 'circuit' in the configs dictionary that is associated with a perceval.ACircuit" + ) if not isinstance(self.circuit, pcvl.ACircuit): raise ValueError( - f"The 'circuit' key of the config dictionary must be a Perceval ACircuit, got {type(self.circuit)}" + f"The 'circuit' key of the config dictionary must be a perceval.ACircuit, got {type(self.circuit)}" ) # input_state @@ -60,7 +62,7 @@ def __init__(self, config_to_verify: dict): ] except KeyError: raise KeyError( - "There must be a key 'input_state' in the configs dictionary" + 
"There must be a key 'input_state' in the configs dictionary that is associated to None or a Sequence[Integral]." ) if self.input_state is not None: From cff6c34e0530e49b17738eec210454a632936f0f Mon Sep 17 00:00:00 2001 From: LF-Vigneux <156103879+LF-Vigneux@users.noreply.github.com> Date: Thu, 28 May 2026 17:39:07 -0400 Subject: [PATCH 04/17] Fixed the test to modify with the new validated config class --- merlin/core/merlin_processor.py | 59 +++++++++++++++------------ tests/core/cloud/test_parammapping.py | 57 +++++++++++++------------- 2 files changed, 62 insertions(+), 54 deletions(-) diff --git a/merlin/core/merlin_processor.py b/merlin/core/merlin_processor.py index 591d630f..2b8b2e45 100644 --- a/merlin/core/merlin_processor.py +++ b/merlin/core/merlin_processor.py @@ -41,6 +41,15 @@ class BackendCapabilities: available_commands: tuple[str] +_ALLOWED_STATE_TYPES = ( + pcvl.StateVector, + pcvl.FockState, + pcvl.NoisyFockState, + pcvl.BasicState, + pcvl.LogicalState, +) + + class ValidatedLayerConfig: def __init__(self, config_to_verify: dict): # circuit @@ -57,33 +66,33 @@ def __init__(self, config_to_verify: dict): # input_state try: - self.input_state: Sequence[Integral] | None = config_to_verify[ - "input_state" - ] + self.input_state: pcvl.ACircuit = config_to_verify["input_state"] except KeyError: raise KeyError( - "There must be a key 'input_state' in the configs dictionary that is associated to None or a Sequence[Integral]." + f"There must be a key 'input_state' in the configs dictionary that is associated with a Sequence[Integral], a Perceval State object or None." 
) - if self.input_state is not None: - if not isinstance(self.input_state, Sequence): + if isinstance(self.input_state, _ALLOWED_STATE_TYPES): + pass + elif not isinstance(self.input_state, Sequence): raise ValueError( - f"'input_state' must be a sequence of integers or None, " - f"got {type(self.input_state).__name__}" - ) - - bad_types = { - type(x).__name__ - for x in self.input_state - if not isinstance(x, Integral) - } - - if bad_types: - raise ValueError( - f"'input_state' must contain only integers. " - f"Got sequence type {type(self.input_state).__name__} " - f"with non-integer element types: {sorted(bad_types)}" + "'input_state' must be None, a sequence of integers, " + "or an Perceval state object " + f"(got {type(self.input_state).__name__})" ) + else: + bad_types = { + type(x).__name__ + for x in self.input_state + if not isinstance(x, Integral) + } + + if bad_types: + raise ValueError( + f"'input_state' must contain only integers when it is a sequence. " + f"Got sequence type {type(self.input_state).__name__} " + f"with non-integer element types: {sorted(bad_types)}" + ) # input_param_order try: @@ -95,7 +104,7 @@ def __init__(self, config_to_verify: dict): if self.input_param_order is not None: if not isinstance(self.input_param_order, Sequence): raise ValueError( - f"'input_state' must be a sequence of strings or None, got {type(self.input_param_order).__name__}" + f"'input_param_order' must be a sequence of strings or None, got {type(self.input_param_order).__name__}" ) bad_types = { @@ -645,7 +654,7 @@ def _offload_quantum_layer_with_chunking( if cache is None: if not isinstance(layer, SupportsExportConfig): raise TypeError( - "The layer must have a export_config() method returning a dictionary of this type: {'circuit':perceval.ACircuit, 'input_state': Sequence[Integral]|None, 'input_param_order': Sequence[str]|None}." 
+ "The layer must have a export_config() method returning a dictionary of this type: {'circuit':perceval.ACircuit, 'input_state': Sequence[Integral]|'perceval state object'|None, 'input_param_order': Sequence[str]|None}." ) config = ValidatedLayerConfig(layer.export_config()) self._layer_cache[layer.uid] = {"config": config} @@ -1323,7 +1332,7 @@ def estimate_required_shots_per_input( """ if not isinstance(layer, SupportsExportConfig): raise TypeError( - "For shot estimation, the layer must have a export_config() method returning a dictionary of this type: {'circuit':perceval.ACircuit, 'input_state': Sequence[Integral]|None, 'input_param_order': Sequence[str]|None}." + "For shot estimation, the layer must have a export_config() method returning a dictionary of this type: {'circuit':perceval.ACircuit, 'input_state': Sequence[Integral]|'perceval state object'|None, 'input_param_order': Sequence[str]|None}." ) config = ValidatedLayerConfig(layer.export_config()) @@ -1336,7 +1345,7 @@ def estimate_required_shots_per_input( if not isinstance(layer, SupportsExportConfig): raise TypeError( - "The layer must have a export_config() method returning a dictionary of this type: {'circuit':perceval.ACircuit, 'input_state': Sequence[Integral]|None, 'input_param_order': Sequence[str]|None}." + "The layer must have a export_config() method returning a dictionary of this type: {'circuit':perceval.ACircuit, Sequence[Integral]|'perceval state object'|None, 'input_param_order': Sequence[str]|None}." 
) config = ValidatedLayerConfig(layer.export_config()) child_rp = self._create_fresh_rp() diff --git a/tests/core/cloud/test_parammapping.py b/tests/core/cloud/test_parammapping.py index 5f7af02c..20b305f9 100644 --- a/tests/core/cloud/test_parammapping.py +++ b/tests/core/cloud/test_parammapping.py @@ -34,7 +34,7 @@ from merlin.algorithms.layer import QuantumLayer from merlin.builder.circuit_builder import CircuitBuilder from merlin.core.computation_space import ComputationSpace -from merlin.core.merlin_processor import MerlinProcessor +from merlin.core.merlin_processor import MerlinProcessor, ValidatedLayerConfig from merlin.measurement import MeasurementStrategy # --------------------------------------------------------------------------- @@ -177,10 +177,9 @@ def _expected_from_converter(layer: QuantumLayer) -> list[str]: return out -def _assert_config_contract(cfg: dict): - assert "input_param_order" in cfg, f"export_config keys={sorted(cfg.keys())}" - assert isinstance(cfg["input_param_order"], list) - assert all(isinstance(x, str) for x in cfg["input_param_order"]) +def _assert_config_contract(cfg: ValidatedLayerConfig): + assert isinstance(cfg.input_param_order, list) + assert all(isinstance(x, str) for x in cfg.input_param_order) # --------------------------------------------------------------------------- @@ -192,42 +191,42 @@ class TestPercevalUserBuilt: @pytest.mark.parametrize("n", [3, 5, 10, 12]) def test_export_matches_converter_and_is_numeric(self, n): layer = _make_perceval_layer(n, prefix="px") - cfg = layer.export_config() + cfg = ValidatedLayerConfig(layer.export_config()) _assert_config_contract(cfg) - assert cfg["input_param_order"] == _expected_from_converter(layer) - assert cfg["input_param_order"] == [f"px{i + 1}" for i in range(n)] + assert cfg.input_param_order == _expected_from_converter(layer) + assert cfg.input_param_order == [f"px{i + 1}" for i in range(n)] def test_two_prefixes_12_each(self): layer = _make_perceval_layer_two_prefixes( 
prefixes=["a", "b"], counts=[12, 12], m=24 ) - cfg = layer.export_config() + cfg = ValidatedLayerConfig(layer.export_config()) _assert_config_contract(cfg) expected = [f"a{i + 1}" for i in range(12)] + [f"b{i + 1}" for i in range(12)] - assert cfg["input_param_order"] == expected + assert cfg.input_param_order == expected def test_reversed_prefix_order(self): layer = _make_perceval_layer_two_prefixes( prefixes=["beta", "alpha"], counts=[4, 4], m=8 ) - cfg = layer.export_config() + cfg = ValidatedLayerConfig(layer.export_config()) _assert_config_contract(cfg) expected = [f"beta{i + 1}" for i in range(4)] + [ f"alpha{i + 1}" for i in range(4) ] - assert cfg["input_param_order"] == expected + assert cfg.input_param_order == expected @pytest.mark.parametrize("n", [10, 12]) def test_merlinprocessor_extract_and_route(self, n): layer = _make_perceval_layer(n, prefix="px") - cfg = layer.export_config() + cfg = ValidatedLayerConfig(layer.export_config()) proc = _mock_processor() names = proc._extract_input_params(cfg) - assert names == cfg["input_param_order"] + assert names == cfg.input_param_order row = np.array([(j + 1) * 0.1 for j in range(n)], dtype=float) params = {name: float(row[j]) for j, name in enumerate(names)} @@ -239,7 +238,7 @@ def test_merlinprocessor_extract_and_route(self, n): def test_user_scenario_2ph_12logical_24modes(self): layer = _make_perceval_layer(12, n_physical=24, n_photons=2, prefix="px") - cfg = layer.export_config() + cfg = ValidatedLayerConfig(layer.export_config()) proc = _mock_processor() names = proc._extract_input_params(cfg) @@ -261,15 +260,15 @@ class TestBuilderDeclarative: @pytest.mark.parametrize("n", [5, 10, 12]) def test_builder_export_matches_converter_and_routes(self, n): layer, _b = _make_builder_layer(n, include_trainable=True, scale=1.0) - cfg = layer.export_config() + cfg = ValidatedLayerConfig(layer.export_config()) _assert_config_contract(cfg) - assert cfg["input_param_order"] == _expected_from_converter(layer) - assert 
len(cfg["input_param_order"]) == n + assert cfg.input_param_order == _expected_from_converter(layer) + assert len(cfg.input_param_order) == n proc = _mock_processor() names = proc._extract_input_params(cfg) - assert names == cfg["input_param_order"] + assert names == cfg.input_param_order row = np.array([(j + 1) * 0.1 for j in range(n)], dtype=float) params = {name: float(row[j]) for j, name in enumerate(names)} @@ -278,21 +277,21 @@ def test_builder_export_matches_converter_and_routes(self, n): def test_builder_no_trainable_leakage(self): layer, _b = _make_builder_layer(12, include_trainable=True, scale=1.0) - cfg = layer.export_config() + cfg = ValidatedLayerConfig(layer.export_config()) _assert_config_contract(cfg) - for name in cfg["input_param_order"]: - assert not name.startswith("W"), ( - f"trainable leaked into input_param_order: {name}" - ) + for name in cfg.input_param_order: + assert not name.startswith( + "W" + ), f"trainable leaked into input_param_order: {name}" def test_builder_input_only_no_trainables(self): layer, _b = _make_builder_layer(8, include_trainable=False, scale=1.0) - cfg = layer.export_config() + cfg = ValidatedLayerConfig(layer.export_config()) _assert_config_contract(cfg) - assert cfg["input_param_order"] == _expected_from_converter(layer) - assert len(cfg["input_param_order"]) == 8 + assert cfg.input_param_order == _expected_from_converter(layer) + assert len(cfg.input_param_order) == 8 def test_builder_scale_does_not_change_order(self): layer1, b1 = _make_builder_layer(10, include_trainable=True, scale=1.0) @@ -300,8 +299,8 @@ def test_builder_scale_does_not_change_order(self): cfg1 = layer1.export_config() cfg2 = layer2.export_config() - _assert_config_contract(cfg1) - _assert_config_contract(cfg2) + _assert_config_contract(ValidatedLayerConfig(cfg1)) + _assert_config_contract(ValidatedLayerConfig(cfg2)) assert cfg1["input_param_order"] == cfg2["input_param_order"] assert cfg1["input_param_order"] == 
_expected_from_converter(layer1) From f46d29bdaec31e5d0974794b6c7fad202a061b53 Mon Sep 17 00:00:00 2001 From: LF-Vigneux <156103879+LF-Vigneux@users.noreply.github.com> Date: Fri, 29 May 2026 10:26:35 -0400 Subject: [PATCH 05/17] every tests done --- merlin/algorithms/module.py | 1 + merlin/core/merlin_processor.py | 233 +++++++++++++-- tests/core/test_merlin_processor_unit.py | 357 ++++++++++++++++++++++- 3 files changed, 554 insertions(+), 37 deletions(-) diff --git a/merlin/algorithms/module.py b/merlin/algorithms/module.py index fa9cbe52..fa41cc66 100644 --- a/merlin/algorithms/module.py +++ b/merlin/algorithms/module.py @@ -24,6 +24,7 @@ from contextlib import contextmanager import uuid + import torch import torch.nn as nn diff --git a/merlin/core/merlin_processor.py b/merlin/core/merlin_processor.py index 2b8b2e45..9387473c 100644 --- a/merlin/core/merlin_processor.py +++ b/merlin/core/merlin_processor.py @@ -7,8 +7,8 @@ from collections.abc import Iterable, Sequence from contextlib import suppress from dataclasses import dataclass -from typing import Any, cast, Protocol, runtime_checkable from numbers import Integral +from typing import Any, Protocol, cast, runtime_checkable import numpy as np import perceval as pcvl @@ -50,49 +50,198 @@ class BackendCapabilities: ) +def check_sequence(input: Any) -> bool | Sequence: + """ + Check whether an object can be treated as a sequence. + + Parameters + ---------- + input : Any + Object to validate. + + Returns + ------- + Sequence | bool + The original object if it is an instance of + ``collections.abc.Sequence``. + + Otherwise, if the object is iterable, a tuple containing its + elements. + + Returns ``False`` if the object is not iterable. + + Notes + ----- + This helper accepts objects that are not instances of + ``collections.abc.Sequence`` but can be iterated over, such as + NumPy arrays and PyTorch tensors. Such objects are converted to + tuples before being returned. 
+ + Examples + -------- + >>> check_sequence([1, 2, 3]) + [1, 2, 3] + + >>> check_sequence((1, 2, 3)) + (1, 2, 3) + + >>> check_sequence(np.array([1, 2, 3])) + (1, 2, 3) + + >>> check_sequence(42) + False + """ + + if isinstance(input, Sequence): + return input + try: + values = tuple(input) + except TypeError: + return False + + return values + + class ValidatedLayerConfig: + """ + Validate and normalize the configuration dictionary returned by + ``export_config()``. + + Parameters + ---------- + config_to_verify : dict + Configuration dictionary containing the layer definition. + + Attributes + ---------- + circuit : pcvl.ACircuit + Perceval circuit associated with the layer. + + input_state : Sequence[Integral] | pcvl.BasicState | pcvl.StateVector | pcvl.BSDistribution | pcvl.SVDistribution | None + Input state for the circuit. May be ``None``, a sequence of integers, + or one of the supported Perceval state objects. Sequence-like inputs + are normalized through ``check_sequence()``. + + input_param_order : Sequence[str] | None + Ordered names of the circuit parameters expected by the layer. + Sequence-like inputs are normalized through ``check_sequence()``. + + Raises + ------ + KeyError + If one of the required configuration keys is missing: + + - ``"circuit"`` + - ``"input_state"`` + - ``"input_param_order"`` + + ValueError + If: + + - ``circuit`` is not a ``pcvl.ACircuit``. + - ``input_state`` is neither ``None``, a supported Perceval state + object, nor a sequence. + - ``input_state`` is a sequence containing non-integer elements. + - ``input_param_order`` is neither ``None`` nor a sequence. + - ``input_param_order`` contains non-string elements. + + Notes + ----- + Sequence validation relies on ``check_sequence()``. Accepted sequence + implementations may include Python sequences as well as array-like objects + supported by that helper. + """ + def __init__(self, config_to_verify: dict): + """ + Validate and normalize a layer configuration dictionary. 
+ + Parameters + ---------- + config_to_verify : dict + Configuration dictionary containing the following required keys: + + - ``"circuit"``: a ``pcvl.ACircuit`` instance. + - ``"input_state"``: ``None``, a sequence of integers, or a supported + Perceval state object. + - ``"input_param_order"``: ``None`` or a sequence of strings. + + Raises + ------ + KeyError + If one of the required keys is missing from ``config_to_verify``. + + ValueError + If: + + - ``config_to_verify["circuit"]`` is not a ``pcvl.ACircuit``. + - ``config_to_verify["input_state"]`` is neither ``None``, a valid + Perceval state object, nor a sequence. + - ``config_to_verify["input_state"]`` contains non-integer elements. + - ``config_to_verify["input_param_order"]`` is neither ``None`` nor a + sequence. + - ``config_to_verify["input_param_order"]`` contains non-string + elements. + + Notes + ----- + Sequence-like inputs are normalized using ``check_sequence()``. Objects + that are iterable but not instances of ``collections.abc.Sequence`` + (e.g. NumPy arrays or PyTorch tensors) may therefore be accepted and + converted to tuples. + """ # circuit try: self.circuit: pcvl.ACircuit = config_to_verify["circuit"] except KeyError: raise KeyError( - f"There must be a key 'circuit' in the configs dictionary that is associated with a perceval.ACircuit" + "There must be a key 'circuit' in the configs dictionary that is associated with a perceval.ACircuit." ) if not isinstance(self.circuit, pcvl.ACircuit): raise ValueError( - f"The 'circuit' key of the config dictionary must be a perceval.ACircuit, got {type(self.circuit)}" + f"The 'circuit' key of the config dictionary must be a perceval.ACircuit, got {type(self.circuit)}." 
) # input_state try: - self.input_state: pcvl.ACircuit = config_to_verify["input_state"] + self.input_state: ( + Sequence[Integral] + | pcvl.BasicState + | pcvl.StateVector + | pcvl.BSDistribution + | pcvl.SVDistribution + | None + ) = config_to_verify["input_state"] except KeyError: raise KeyError( - f"There must be a key 'input_state' in the configs dictionary that is associated with a Sequence[Integral], a Perceval State object or None." + "There must be a key 'input_state' in the configs dictionary that is associated with a Sequence[Integral], a Perceval State object or None." ) if self.input_state is not None: if isinstance(self.input_state, _ALLOWED_STATE_TYPES): pass - elif not isinstance(self.input_state, Sequence): - raise ValueError( - "'input_state' must be None, a sequence of integers, " - "or an Perceval state object " - f"(got {type(self.input_state).__name__})" - ) - else: - bad_types = { - type(x).__name__ - for x in self.input_state - if not isinstance(x, Integral) - } - if bad_types: + else: + input_state_sequence = check_sequence(self.input_state) + if not input_state_sequence: raise ValueError( - f"'input_state' must contain only integers when it is a sequence. " - f"Got sequence type {type(self.input_state).__name__} " - f"with non-integer element types: {sorted(bad_types)}" + "'input_state' must be None, a sequence of integers, " + "or an Perceval state object " + f"(got {type(self.input_state).__name__})." ) + else: + self.input_state = input_state_sequence + bad_types = { + type(x).__name__ + for x in self.input_state + if not isinstance(x, Integral) + } + + if bad_types: + raise ValueError( + f"'input_state' must contain only integers when it is a sequence. " + f"Got sequence type {type(self.input_state).__name__} " + f"with non-integer element types: {sorted(bad_types)}." 
+ ) # input_param_order try: @@ -100,13 +249,16 @@ def __init__(self, config_to_verify: dict): "input_param_order" ] except KeyError: - self.input_param_order = None + raise KeyError( + f"There must be a key 'input_param_order' in the configs dictionary that is associated with a Sequence[str] or None." + ) if self.input_param_order is not None: - if not isinstance(self.input_param_order, Sequence): + input_param_order_sequence = check_sequence(self.input_param_order) + if not input_param_order_sequence: raise ValueError( - f"'input_param_order' must be a sequence of strings or None, got {type(self.input_param_order).__name__}" + f"'input_param_order' must be a sequence of strings or None, got {type(self.input_param_order).__name__}." ) - + self.input_param_order = input_param_order_sequence bad_types = { type(x).__name__ for x in self.input_param_order @@ -117,13 +269,40 @@ def __init__(self, config_to_verify: dict): raise ValueError( f"'input_param_order' must contain only strings. " f"Got sequence type {type(self.input_param_order).__name__} " - f"with non-integer element types: {sorted(bad_types)}" + f"with non-integer element types: {sorted(bad_types)}." ) @runtime_checkable class SupportsExportConfig(Protocol): - def export_config(self) -> dict: ... + """ + Protocol for objects that can export their configuration as a dictionary. + + Implementations must provide an ``export_config()`` method returning a + dictionary containing the information required to reconstruct or validate + the object's configuration. + + Notes + ----- + This protocol is marked as ``@runtime_checkable``, allowing runtime checks + with ``isinstance()`` and ``issubclass()``. + + Examples + -------- + >>> isinstance(obj, SupportsExportConfig) + True + """ + + def export_config(self) -> dict: + """ + Export the object's configuration. + + Returns + ------- + dict + Dictionary containing the configuration of the object. + """ + ... 
class MerlinProcessor: diff --git a/tests/core/test_merlin_processor_unit.py b/tests/core/test_merlin_processor_unit.py index a62556d6..1d755cea 100644 --- a/tests/core/test_merlin_processor_unit.py +++ b/tests/core/test_merlin_processor_unit.py @@ -15,7 +15,19 @@ from perceval.runtime.session import ISession import merlin.core.merlin_processor as merlin_processor_module -from merlin.core.merlin_processor import BackendCapabilities, MerlinProcessor +from merlin.core.merlin_processor import ( + BackendCapabilities, + MerlinProcessor, + ValidatedLayerConfig, + SupportsExportConfig, +) +from collections.abc import Sequence +from numbers import Integral +import perceval as pcvl +import numpy as np +import re +from merlin.core.circuit import Circuit +from merlin.core.state_vector import StateVector class FakeCommand: @@ -122,6 +134,8 @@ def make_processor(available_commands: list[str]) -> MerlinProcessor: ) proc._lock = threading.Lock() proc._active_jobs = set() + proc._layer_cache = {} + proc.microbatch_size = 32 return proc @@ -281,7 +295,7 @@ def test_session_path_with_empty_commands_and_sampling_only(): session.build_remote_processor.return_value = remote_processor with patch.object(MerlinProcessor, "_extract_rp_token", return_value="token"): with pytest.warns( - UserWarning, match="Remote processor has no available commands" + UserWarning, match=r"Remote processor has no available commands" ): proc = MerlinProcessor(session=session) sampler = FakeSampler() @@ -693,7 +707,7 @@ def test_poll_job_failed_status_raises_with_stop_message_and_job_id(): ) proc._active_jobs.add(job) - with pytest.raises(RuntimeError, match="hardware rejected job.*job-failed"): + with pytest.raises(RuntimeError, match=r"hardware rejected job.*job-failed"): proc._poll_job(job, make_state(), None, 1, object(), None) assert job not in proc._active_jobs @@ -706,7 +720,7 @@ def test_poll_job_cancel_request_cancels_remote_job(): state = make_state() state["cancel_requested"] = True - with 
pytest.raises(CancelledError, match="Remote call was cancelled"): + with pytest.raises(CancelledError, match=r"Remote call was cancelled"): proc._poll_job(job, state, None, 1, object(), None) assert job.cancelled is True @@ -717,7 +731,7 @@ def test_poll_job_timeout_cancels_remote_job(): proc = make_poll_processor() job = FakeJob(is_complete=False) - with pytest.raises(TimeoutError, match="remote cancel issued"): + with pytest.raises(TimeoutError, match=r"remote cancel issued"): proc._poll_job(job, make_state(), time.time() - 1.0, 1, object(), None) assert job.cancelled is True @@ -733,7 +747,7 @@ def test_poll_job_cancel_requested_stop_message_raises_cancelled_error(): ) proc._active_jobs.add(job) - with pytest.raises(CancelledError, match="Remote call was cancelled"): + with pytest.raises(CancelledError, match=r"Remote call was cancelled"): proc._poll_job(job, make_state(), None, 1, object(), None) assert job not in proc._active_jobs @@ -745,7 +759,7 @@ def test_poll_job_cancel_requested_get_results_exception_raises_cancelled_error( job = FakeJob(result_events=[RuntimeError("Cancel requested on backend")]) proc._active_jobs.add(job) - with pytest.raises(CancelledError, match="Remote call was cancelled"): + with pytest.raises(CancelledError, match=r"Remote call was cancelled"): proc._poll_job(job, make_state(), None, 1, object(), None) assert job not in proc._active_jobs @@ -777,7 +791,7 @@ def test_poll_job_retries_complete_non_dict_payloads_then_fails(monkeypatch): job = FakeJob(job_id="job-nondict", result_events=[["not", "a", "dict"]]) proc._active_jobs.add(job) - with pytest.raises(RuntimeError, match="not a dict after 60 re-polls"): + with pytest.raises(RuntimeError, match=r"not a dict after 60 re-polls"): proc._poll_job(job, make_state(), None, 1, object(), None) assert job.get_results_calls == 60 @@ -788,7 +802,7 @@ def test_process_batch_results_rejects_missing_payload(): """A missing backend payload currently raises a runtime failure.""" proc = 
make_processor(["probs"]) - with pytest.raises(RuntimeError, match="returned no results"): + with pytest.raises(RuntimeError, match=r"returned no results"): proc._process_batch_results(None, 1, FakeLayer()) @@ -796,7 +810,7 @@ def test_process_batch_results_rejects_non_dict_payload(): """A non-dict backend payload currently raises a runtime failure.""" proc = make_processor(["probs"]) - with pytest.raises(RuntimeError, match="Unexpected remote results type"): + with pytest.raises(RuntimeError, match=r"Unexpected remote results type"): proc._process_batch_results(["not", "a", "dict"], 1, FakeLayer()) @@ -1041,3 +1055,326 @@ def test_create_fresh_rp_remote_processor_path_with_cloning_disabled(): mock_clone.assert_called_once_with(original_rp) assert fresh_rp is mock_clone.return_value + + +def test_different_valid_configs(): + # BasicState input + config = { + "circuit": pcvl.Circuit(m=2, name="Circuit"), + "input_state": pcvl.BasicState("|1,0>"), + "input_param_order": ["px", "el", "s"], + } + v_config = ValidatedLayerConfig(config) + assert isinstance(v_config.circuit, pcvl.Circuit) + assert v_config.circuit == config["circuit"] + assert isinstance(v_config.input_state, pcvl.BasicState) + assert v_config.input_state == config["input_state"] + assert isinstance(v_config.input_param_order, Sequence) + for i in v_config.input_param_order: + assert isinstance(i, str) + assert v_config.input_param_order == config["input_param_order"] + + # StateVector input + config = { + "circuit": pcvl.Circuit(m=2, name="Circuit"), + "input_state": pcvl.StateVector("|1,0>"), + "input_param_order": ["px", "el", "s"], + } + v_config = ValidatedLayerConfig(config) + assert isinstance(v_config.circuit, pcvl.Circuit) + assert v_config.circuit == config["circuit"] + assert isinstance(v_config.input_state, pcvl.StateVector) + assert v_config.input_state == config["input_state"] + assert isinstance(v_config.input_param_order, Sequence) + for i in v_config.input_param_order: + assert 
isinstance(i, str) + assert v_config.input_param_order == config["input_param_order"] + + # FockState input + config = { + "circuit": pcvl.Circuit(m=2, name="Circuit"), + "input_state": pcvl.FockState("|1,0>"), + "input_param_order": ["px", "el", "s"], + } + v_config = ValidatedLayerConfig(config) + assert isinstance(v_config.circuit, pcvl.Circuit) + assert v_config.circuit == config["circuit"] + assert isinstance(v_config.input_state, pcvl.FockState) + assert v_config.input_state == config["input_state"] + assert isinstance(v_config.input_param_order, Sequence) + for i in v_config.input_param_order: + assert isinstance(i, str) + assert v_config.input_param_order == config["input_param_order"] + + # NoisyFockState input + config = { + "circuit": pcvl.Circuit(m=2, name="Circuit"), + "input_state": pcvl.NoisyFockState(pcvl.FockState([1, 0])), + "input_param_order": ["px", "el", "s"], + } + v_config = ValidatedLayerConfig(config) + assert isinstance(v_config.circuit, pcvl.Circuit) + assert v_config.circuit == config["circuit"] + assert isinstance(v_config.input_state, pcvl.NoisyFockState) + assert v_config.input_state == config["input_state"] + assert isinstance(v_config.input_param_order, Sequence) + for i in v_config.input_param_order: + assert isinstance(i, str) + assert v_config.input_param_order == config["input_param_order"] + + # LogicalState input + config = { + "circuit": pcvl.Circuit(m=2, name="Circuit"), + "input_state": pcvl.LogicalState([1, 0]), + "input_param_order": ["px", "el", "s"], + } + v_config = ValidatedLayerConfig(config) + assert isinstance(v_config.circuit, pcvl.Circuit) + assert v_config.circuit == config["circuit"] + assert isinstance(v_config.input_state, pcvl.LogicalState) + assert v_config.input_state == config["input_state"] + assert isinstance(v_config.input_param_order, Sequence) + for i in v_config.input_param_order: + assert isinstance(i, str) + assert v_config.input_param_order == config["input_param_order"] + + # Sequence input 
list + config = { + "circuit": pcvl.Circuit(m=2, name="Circuit"), + "input_state": [1, 0], + "input_param_order": ["px", "el", "s"], + } + v_config = ValidatedLayerConfig(config) + assert isinstance(v_config.circuit, pcvl.Circuit) + assert v_config.circuit == config["circuit"] + assert isinstance(v_config.input_state, Sequence) + for i in v_config.input_state: + assert isinstance(i, Integral) + assert v_config.input_state == config["input_state"] + assert isinstance(v_config.input_param_order, Sequence) + for i in v_config.input_param_order: + assert isinstance(i, str) + assert v_config.input_param_order == config["input_param_order"] + + # Sequence input torch ints as tuple + config = { + "circuit": pcvl.Circuit(m=2, name="Circuit"), + "input_state": tuple([torch.tensor(1).item(), torch.tensor(0).item()]), + "input_param_order": ["px", "el", "s"], + } + v_config = ValidatedLayerConfig(config) + assert isinstance(v_config.circuit, pcvl.Circuit) + assert v_config.circuit == config["circuit"] + assert isinstance(v_config.input_state, Sequence) + for i in v_config.input_state: + assert isinstance(i, Integral) + assert v_config.input_state == config["input_state"] + assert isinstance(v_config.input_param_order, Sequence) + for i in v_config.input_param_order: + assert isinstance(i, str) + assert v_config.input_param_order == config["input_param_order"] + + # Sequence input array + config = { + "circuit": pcvl.Circuit(m=2, name="Circuit"), + "input_state": np.array([1, 0]), + "input_param_order": ["px", "el", "s"], + } + v_config = ValidatedLayerConfig(config) + assert isinstance(v_config.circuit, pcvl.Circuit) + assert v_config.circuit == config["circuit"] + assert isinstance(v_config.input_state, Sequence) + for i in v_config.input_state: + assert isinstance(i, Integral) + assert v_config.input_state == tuple(config["input_state"]) + assert isinstance(v_config.input_param_order, Sequence) + for i in v_config.input_param_order: + assert isinstance(i, str) + assert 
v_config.input_param_order == config["input_param_order"] + + +def test_missing_required_fiels_in_configs(): + # Missing Circuit + config = { + "circuit": pcvl.Circuit(m=2, name="Circuit"), + "input_state": pcvl.BasicState("|1,0>"), + "input_param_order": ["px", "el", "s"], + } + + # Missing Circuit + config = { + "input_state": pcvl.BasicState("|1,0>"), + "input_param_order": ["px", "el", "s"], + } + with pytest.raises( + KeyError, + match=r"There must be a key 'circuit' in the configs dictionary that is associated with a perceval.ACircuit.", + ): + v_config = ValidatedLayerConfig(config) + + # Missing State + config = { + "circuit": pcvl.Circuit(m=2, name="Circuit"), + "input_param_order": ["px", "el", "s"], + } + with pytest.raises( + KeyError, + match=r".*There must be a key 'input_state' in the configs dictionary.*", + ): + v_config = ValidatedLayerConfig(config) + + # Missing input param order + config = { + "circuit": pcvl.Circuit(m=2, name="Circuit"), + "input_state": pcvl.BasicState("|1,0>"), + } + with pytest.raises( + KeyError, + match=r".*There must be a key 'input_param_order' in the configs dictionary that is associated with a Sequence\[str\] or None\..*", + ): + v_config = ValidatedLayerConfig(config) + + +def test_wrong_types_config(): + # Bad Circuit + config = { + "circuit": None, + "input_state": pcvl.BasicState("|1,0>"), + "input_param_order": ["px", "el", "s"], + } + with pytest.raises( + ValueError, + match=r"The 'circuit' key of the config dictionary must be a perceval.ACircuit", + ): + v_config = ValidatedLayerConfig(config) + + config = { + "circuit": Circuit(n_modes=2, components=[pcvl.components.BS()]), + "input_state": pcvl.BasicState("|1,0>"), + "input_param_order": ["px", "el", "s"], + } + with pytest.raises( + ValueError, + match=r"The 'circuit' key of the config dictionary must be a perceval.ACircuit", + ): + v_config = ValidatedLayerConfig(config) + + # input_state + config = { + "circuit": pcvl.Circuit(m=2), + "input_state": [1.1, 
2.0], + "input_param_order": ["px", "el", "s"], + } + with pytest.raises( + ValueError, + match=r"'input_state' must contain only integers when it is a sequence.", + ): + v_config = ValidatedLayerConfig(config) + + config = { + "circuit": pcvl.Circuit(m=2), + "input_state": StateVector(torch.tensor([1, 0]), n_modes=2, n_photons=1), + "input_param_order": ["px", "el", "s"], + } + with pytest.raises( + ValueError, + match=r"'input_state' must be None, a sequence of integers, or an Perceval state object.", + ): + v_config = ValidatedLayerConfig(config) + + # input_param_order + config = { + "circuit": pcvl.Circuit(m=2), + "input_state": [2, 0], + "input_param_order": 3, + } + with pytest.raises( + ValueError, + match=r"'input_param_order' must be a sequence of strings or None, got int.", + ): + v_config = ValidatedLayerConfig(config) + + config = { + "circuit": pcvl.Circuit(m=2), + "input_state": [1, 2], + "input_param_order": [11, 2, 1], + } + with pytest.raises( + ValueError, + match=r"'input_param_order' must contain only strings.", + ): + v_config = ValidatedLayerConfig(config) + + +def test_has_export_config(): + class GoodLayer: + def __init__(self): + pass + + def export_config(): + return { + "circuit": pcvl.Circuit(m=2, name="Circuit"), + "input_state": [1, 0], + "input_param_order": ["px", "el", "s"], + } + + class BadLayer: + def __init__(self): + pass + + assert isinstance(GoodLayer(), SupportsExportConfig) + assert not isinstance(BadLayer(), SupportsExportConfig) + + +def test_offload_quantum_layer_with_chunking_validates_and_caches_export_config(): + proc = make_processor(["probs", "sample_count"]) + + class LayerWithExportConfig: + uid = 42 + + def __init__(self) -> None: + self.export_config_calls = 0 + + def export_config(self): + self.export_config_calls += 1 + return { + "circuit": pcvl.Circuit(m=2, name="Circuit"), + "input_state": [1, 0], + "input_param_order": ["px", "el", "s"], + } + + layer = LayerWithExportConfig() + + def 
fake_run_chunks_pooled( + layer_arg, config, input_tensor, chunks, nsample, state, deadline + ): + assert layer_arg is layer + assert isinstance(config, ValidatedLayerConfig) + assert isinstance(config.circuit, pcvl.Circuit) + assert config.input_param_order == ["px", "el", "s"] + return torch.tensor([[1.0]]) + + proc._run_chunks_pooled = fake_run_chunks_pooled + + result = proc._offload_quantum_layer_with_chunking( + layer, + torch.zeros(1, 2), + None, + {}, + None, + ) + + assert torch.equal(result, torch.tensor([[1.0]])) + assert layer.export_config_calls == 1 + assert layer.uid in proc._layer_cache + assert isinstance(proc._layer_cache[layer.uid]["config"], ValidatedLayerConfig) + + # Calling again should reuse the cached config and not call export_config again. + proc._offload_quantum_layer_with_chunking( + layer, + torch.zeros(1, 2), + None, + {}, + None, + ) + assert layer.export_config_calls == 1 From 449835fac59fb1044f1f9d30257ea1aa88ac8ee6 Mon Sep 17 00:00:00 2001 From: LF-Vigneux <156103879+LF-Vigneux@users.noreply.github.com> Date: Fri, 29 May 2026 10:29:57 -0400 Subject: [PATCH 06/17] ruff check --- merlin/algorithms/module.py | 3 ++- merlin/core/merlin_processor.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/merlin/algorithms/module.py b/merlin/algorithms/module.py index fa41cc66..2452b4f3 100644 --- a/merlin/algorithms/module.py +++ b/merlin/algorithms/module.py @@ -22,9 +22,10 @@ from __future__ import annotations -from contextlib import contextmanager import uuid +from contextlib import contextmanager + import torch import torch.nn as nn diff --git a/merlin/core/merlin_processor.py b/merlin/core/merlin_processor.py index 9387473c..ffad3bc8 100644 --- a/merlin/core/merlin_processor.py +++ b/merlin/core/merlin_processor.py @@ -250,7 +250,7 @@ def __init__(self, config_to_verify: dict): ] except KeyError: raise KeyError( - f"There must be a key 'input_param_order' in the configs dictionary that is associated with a 
Sequence[str] or None." + "There must be a key 'input_param_order' in the configs dictionary that is associated with a Sequence[str] or None." ) if self.input_param_order is not None: input_param_order_sequence = check_sequence(self.input_param_order) From 6cf135bc5d7ddb42b691d01accbc61294444cc40 Mon Sep 17 00:00:00 2001 From: LF-Vigneux <156103879+LF-Vigneux@users.noreply.github.com> Date: Fri, 29 May 2026 10:32:36 -0400 Subject: [PATCH 07/17] ruff --- merlin/algorithms/module.py | 1 - 1 file changed, 1 deletion(-) diff --git a/merlin/algorithms/module.py b/merlin/algorithms/module.py index 2452b4f3..63d55ddd 100644 --- a/merlin/algorithms/module.py +++ b/merlin/algorithms/module.py @@ -23,7 +23,6 @@ from __future__ import annotations import uuid - from contextlib import contextmanager import torch From 8ae242792237f1bcfc408eecb3a46c6bc4ac3306 Mon Sep 17 00:00:00 2001 From: LF-Vigneux <156103879+LF-Vigneux@users.noreply.github.com> Date: Fri, 29 May 2026 10:44:04 -0400 Subject: [PATCH 08/17] mypy --- merlin/core/merlin_processor.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/merlin/core/merlin_processor.py b/merlin/core/merlin_processor.py index ffad3bc8..16b729fd 100644 --- a/merlin/core/merlin_processor.py +++ b/merlin/core/merlin_processor.py @@ -221,20 +221,21 @@ def __init__(self, config_to_verify: dict): pass else: - input_state_sequence = check_sequence(self.input_state) - if not input_state_sequence: + input_state_sequence: Sequence[Integral] | False = check_sequence( + self.input_state + ) + if input_state_sequence is False: raise ValueError( "'input_state' must be None, a sequence of integers, " "or an Perceval state object " f"(got {type(self.input_state).__name__})." 
) - else: - self.input_state = input_state_sequence - bad_types = { - type(x).__name__ - for x in self.input_state - if not isinstance(x, Integral) - } + self.input_state = input_state_sequence + bad_types = { + type(x).__name__ + for x in self.input_state + if not isinstance(x, Integral) + } if bad_types: raise ValueError( @@ -253,8 +254,10 @@ def __init__(self, config_to_verify: dict): "There must be a key 'input_param_order' in the configs dictionary that is associated with a Sequence[str] or None." ) if self.input_param_order is not None: - input_param_order_sequence = check_sequence(self.input_param_order) - if not input_param_order_sequence: + input_param_order_sequence: Sequence[str] | False = check_sequence( + self.input_param_order + ) + if input_param_order_sequence is False: raise ValueError( f"'input_param_order' must be a sequence of strings or None, got {type(self.input_param_order).__name__}." ) @@ -509,7 +512,7 @@ def __init__( self.chunk_concurrency = max(1, int(chunk_concurrency)) # Caches & global tracking - self._layer_cache: dict[int, dict] = {} + self._layer_cache: dict[uuid.UUID, dict[str, Any]] = {} self._job_history: list[RemoteJob] = [] # Lifecycle/cancellation From 225aa682a646361beec9a01791bcb5c45cbe4954 Mon Sep 17 00:00:00 2001 From: LF-Vigneux <156103879+LF-Vigneux@users.noreply.github.com> Date: Fri, 29 May 2026 10:44:52 -0400 Subject: [PATCH 09/17] Bad update --- merlin/core/merlin_processor.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/merlin/core/merlin_processor.py b/merlin/core/merlin_processor.py index 16b729fd..9b4a6419 100644 --- a/merlin/core/merlin_processor.py +++ b/merlin/core/merlin_processor.py @@ -237,12 +237,12 @@ def __init__(self, config_to_verify: dict): if not isinstance(x, Integral) } - if bad_types: - raise ValueError( - f"'input_state' must contain only integers when it is a sequence. 
" - f"Got sequence type {type(self.input_state).__name__} " - f"with non-integer element types: {sorted(bad_types)}." - ) + if bad_types: + raise ValueError( + f"'input_state' must contain only integers when it is a sequence. " + f"Got sequence type {type(self.input_state).__name__} " + f"with non-integer element types: {sorted(bad_types)}." + ) # input_param_order try: From 11df51d9a04e4fee6979032581c12107bcd98867 Mon Sep 17 00:00:00 2001 From: LF-Vigneux <156103879+LF-Vigneux@users.noreply.github.com> Date: Fri, 29 May 2026 10:52:33 -0400 Subject: [PATCH 10/17] mypy --- merlin/core/merlin_processor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/merlin/core/merlin_processor.py b/merlin/core/merlin_processor.py index 9b4a6419..dbad02b3 100644 --- a/merlin/core/merlin_processor.py +++ b/merlin/core/merlin_processor.py @@ -8,7 +8,7 @@ from contextlib import suppress from dataclasses import dataclass from numbers import Integral -from typing import Any, Protocol, cast, runtime_checkable +from typing import Any, Literal, Protocol, cast, runtime_checkable import numpy as np import perceval as pcvl @@ -221,7 +221,7 @@ def __init__(self, config_to_verify: dict): pass else: - input_state_sequence: Sequence[Integral] | False = check_sequence( + input_state_sequence: Sequence[Integral] | Literal[False] = check_sequence( self.input_state ) if input_state_sequence is False: @@ -254,7 +254,7 @@ def __init__(self, config_to_verify: dict): "There must be a key 'input_param_order' in the configs dictionary that is associated with a Sequence[str] or None." 
) if self.input_param_order is not None: - input_param_order_sequence: Sequence[str] | False = check_sequence( + input_param_order_sequence: Sequence[str] | Literal[False] = check_sequence( self.input_param_order ) if input_param_order_sequence is False: From 5bcc8b2a64c906538afb19fb21f91f4f60d6ada4 Mon Sep 17 00:00:00 2001 From: LF-Vigneux <156103879+LF-Vigneux@users.noreply.github.com> Date: Fri, 29 May 2026 10:58:54 -0400 Subject: [PATCH 11/17] formatting --- merlin/core/merlin_processor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/merlin/core/merlin_processor.py b/merlin/core/merlin_processor.py index dbad02b3..0b2f00c2 100644 --- a/merlin/core/merlin_processor.py +++ b/merlin/core/merlin_processor.py @@ -221,8 +221,8 @@ def __init__(self, config_to_verify: dict): pass else: - input_state_sequence: Sequence[Integral] | Literal[False] = check_sequence( - self.input_state + input_state_sequence: Sequence[Integral] | Literal[False] = ( + check_sequence(self.input_state) ) if input_state_sequence is False: raise ValueError( From bb0d0513bab8f3d977312724e9fdfc42c9dc6f05 Mon Sep 17 00:00:00 2001 From: LF-Vigneux <156103879+LF-Vigneux@users.noreply.github.com> Date: Fri, 29 May 2026 11:05:10 -0400 Subject: [PATCH 12/17] mypy again --- merlin/core/merlin_processor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/merlin/core/merlin_processor.py b/merlin/core/merlin_processor.py index 0b2f00c2..719a27bd 100644 --- a/merlin/core/merlin_processor.py +++ b/merlin/core/merlin_processor.py @@ -221,8 +221,8 @@ def __init__(self, config_to_verify: dict): pass else: - input_state_sequence: Sequence[Integral] | Literal[False] = ( - check_sequence(self.input_state) + input_state_sequence: Sequence[Integral] | bool = check_sequence( + self.input_state ) if input_state_sequence is False: raise ValueError( @@ -254,7 +254,7 @@ def __init__(self, config_to_verify: dict): "There must be a key 'input_param_order' in the configs 
dictionary that is associated with a Sequence[str] or None." ) if self.input_param_order is not None: - input_param_order_sequence: Sequence[str] | Literal[False] = check_sequence( + input_param_order_sequence: Sequence[str] | bool = check_sequence( self.input_param_order ) if input_param_order_sequence is False: From e20d4e7a0975229a2f9a4ef2581c46321d2a6f8f Mon Sep 17 00:00:00 2001 From: LF-Vigneux <156103879+LF-Vigneux@users.noreply.github.com> Date: Fri, 29 May 2026 11:17:11 -0400 Subject: [PATCH 13/17] ruff --- merlin/core/merlin_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/merlin/core/merlin_processor.py b/merlin/core/merlin_processor.py index 719a27bd..adcb8f83 100644 --- a/merlin/core/merlin_processor.py +++ b/merlin/core/merlin_processor.py @@ -8,7 +8,7 @@ from contextlib import suppress from dataclasses import dataclass from numbers import Integral -from typing import Any, Literal, Protocol, cast, runtime_checkable +from typing import Any, Protocol, cast, runtime_checkable import numpy as np import perceval as pcvl From a675f11cdf6c851ebd889ba1feb575eb133ec4e5 Mon Sep 17 00:00:00 2001 From: LF-Vigneux <156103879+LF-Vigneux@users.noreply.github.com> Date: Fri, 29 May 2026 11:25:54 -0400 Subject: [PATCH 14/17] Changing the function for mypy --- merlin/core/merlin_processor.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/merlin/core/merlin_processor.py b/merlin/core/merlin_processor.py index adcb8f83..a0228df9 100644 --- a/merlin/core/merlin_processor.py +++ b/merlin/core/merlin_processor.py @@ -50,7 +50,7 @@ class BackendCapabilities: ) -def check_sequence(input: Any) -> bool | Sequence: +def check_sequence(input: Any) -> Sequence[Any] | None: """ Check whether an object can be treated as a sequence. 
@@ -61,14 +61,14 @@ def check_sequence(input: Any) -> bool | Sequence: Returns ------- - Sequence | bool + Sequence | None The original object if it is an instance of ``collections.abc.Sequence``. Otherwise, if the object is iterable, a tuple containing its elements. - Returns ``False`` if the object is not iterable. + Returns None if the object is not iterable. Notes ----- @@ -89,17 +89,15 @@ def check_sequence(input: Any) -> bool | Sequence: (1, 2, 3) >>> check_sequence(42) - False + None """ - if isinstance(input, Sequence): + if isinstance(input, Sequence) and not isinstance(input, (str, bytes)): return input try: - values = tuple(input) + return tuple(input) except TypeError: - return False - - return values + return None class ValidatedLayerConfig: @@ -221,10 +219,10 @@ def __init__(self, config_to_verify: dict): pass else: - input_state_sequence: Sequence[Integral] | bool = check_sequence( + input_state_sequence: Sequence[Integral] | None = check_sequence( self.input_state ) - if input_state_sequence is False: + if input_state_sequence is None: raise ValueError( "'input_state' must be None, a sequence of integers, " "or an Perceval state object " @@ -254,10 +252,10 @@ def __init__(self, config_to_verify: dict): "There must be a key 'input_param_order' in the configs dictionary that is associated with a Sequence[str] or None." ) if self.input_param_order is not None: - input_param_order_sequence: Sequence[str] | bool = check_sequence( + input_param_order_sequence: Sequence[str] | None = check_sequence( self.input_param_order ) - if input_param_order_sequence is False: + if input_param_order_sequence is None: raise ValueError( f"'input_param_order' must be a sequence of strings or None, got {type(self.input_param_order).__name__}." 
) From f943aa36e6d83495eb550c1c72c564a089fecc82 Mon Sep 17 00:00:00 2001 From: Benjamin STOTT Date: Fri, 29 May 2026 18:17:38 +0200 Subject: [PATCH 15/17] Fix Scaleway compatibility with Perceval 1.2 --- .../PML-302_scaleway_compatibility_notes.md | 243 ++++++++++++++++++ merlin/core/merlin_processor.py | 59 ++++- tests/core/test_merlin_processor_unit.py | 73 ++++++ 3 files changed, 364 insertions(+), 11 deletions(-) create mode 100644 .github/PML-302_scaleway_compatibility_notes.md diff --git a/.github/PML-302_scaleway_compatibility_notes.md b/.github/PML-302_scaleway_compatibility_notes.md new file mode 100644 index 00000000..c14d84f7 --- /dev/null +++ b/.github/PML-302_scaleway_compatibility_notes.md @@ -0,0 +1,243 @@ +# PML-302 Scaleway Compatibility Notes + +Branch under review: `PML-302-MerLinProcessor-export_config-and-typing` + +This note explains the local fixes applied while probing the Scaleway failures. +Nothing has been pushed. + +## Summary + +Two local fixes were applied: + +1. `MerlinProcessor(session=...)` no longer requires a token extracted from a + `RemoteProcessor`. +2. Perceval 1.2 sampler iteration payloads are normalized before remote + submission so Scaleway can JSON-serialize them. + +These address two different causes: + +- The Perceval iterator serialization issue affects `main` with + `perceval-quandela==1.2.1`. +- The session-token failure is caused by PML-302. + +## General Fix: Perceval 1.2 Iterator Compatibility + +### Problem + +Merlin batches remote inputs through Perceval `Sampler.add_iteration()`: + +```python +sampler.clear_iterations() +for params in iteration_params: + sampler.add_iteration(circuit_params=params) +``` + +With Perceval 1.1, sampler iterations were stored as a plain list, so the +Scaleway payload was JSON-serializable. + +With Perceval 1.2.1, the sampler stores iterations in a `ParameterIterator` +object. 
Perceval then places that object in the remote job payload, but the +Scaleway handler still calls: + +```python +json.dumps(payload.get("payload", {})) +``` + +That fails with: + +```text +TypeError: Object of type ParameterIterator is not JSON serializable +``` + +### Local Fix + +Added `_ensure_serializable_sampler_iterator()` in +`merlin/core/merlin_processor.py`. + +The helper checks for the Perceval 1.2 shape: + +```python +iterator = getattr(sampler, "_iterator", None) +iterations = getattr(iterator, "iterations", None) +``` + +If present, it replaces the private remote-job payload iterator object with a +plain list: + +```python +payload["iterator"] = list(iterations) +``` + +This is backwards compatible with Perceval 1.1 because `sampler._iterator` is +already a list there and has no `.iterations` attribute, so the helper returns +without changing anything. + +### Test Added + +Added: + +```text +tests/core/test_merlin_processor_unit.py::test_submit_job_serializes_perceval_12_parameter_iterator_payload +``` + +It uses a fake Perceval 1.2-style sampler payload and confirms `_submit_job()` +converts the iterator before `execute_async()`. + +## PML-302 Fix: Session Path Should Not Require Token Extraction + +### Problem + +PML-302 unified the `remote_processor=` and `session=` constructor paths and +then always attempted to extract a token: + +```python +if self._token is None: + self._token = self._extract_rp_token(remote_processor) + +if self._token is None: + raise ValueError(...) +``` + +That is valid for `remote_processor=...`, because Merlin later clones the +processor with: + +```python +RemoteProcessor(name=rp.name, token=self._token, ...) +``` + +It is not valid for `session=...`. In the session path, authentication is owned +by the `ISession`, and Merlin creates fresh processors with: + +```python +self.session.build_remote_processor() +``` + +The session path should not require Merlin to extract a token from the generated +remote processor. 
+ +### Local Fix + +Token extraction is now only performed when `self.session is None`. + +### Test Added + +Added: + +```text +tests/core/test_merlin_processor_unit.py::test_session_path_does_not_require_remote_processor_token +``` + +It verifies `MerlinProcessor(session=session)` succeeds even when +`_extract_rp_token()` would return `None`, and that token extraction is not +called in the session path. + +## Remaining PML-302 Issues + +After the two fixes above, Scaleway tests on current Perceval no longer fail on +constructor auth or `ParameterIterator` serialization. The remaining failures are +PML-302 review/test-expectation issues. + +### 1. Context Manager Double-Starts The Session + +PML-302 changes `MerlinProcessor.__enter__()` to call: + +```python +self.session.__enter__() +``` + +The Scaleway pytest fixture already creates an active session with: + +```python +with scw.Session(...) as session: + yield session +``` + +So `with MerlinProcessor(session=scaleway_session)` tries to start an already +attached session and fails with: + +```text +Exception: A session is already attached to this RPC handler +``` + +This appears branch-caused. The safer ownership rule is that callers own the +session lifecycle when they pass an existing `ISession` into `MerlinProcessor`. + +### 2. Tests Assume `probs` Exists On The Real Backend + +The real Scaleway backend currently reports: + +```text +('sample_count', 'samples') +``` + +Several PML-302 tests assert `"probs" in proc.available_commands`, so they fail +before exercising forward execution. This is a test assumption unless the remote +backend is guaranteed to expose `probs`. + +### 3. Sample Cap Warning Is Treated As A Failure + +One test calls: + +```python +proc = MerlinProcessor(..., max_shots_per_call=100) +y = proc.forward(q, X, nsample=1000) +``` + +PML-302 emits a `UserWarning` and caps `nsample` to `100`. Since `pytest.ini` +treats warnings as errors, the test fails. 
The test should either request a value +within the cap or assert the warning explicitly. + +## Test Results + +### Before Local Fixes On PML-302 With Perceval 1.2.1 + +```text +tests/core/cloud/test_scaleway_session.py +15 failed +``` + +All failures occurred at construction: + +```text +ValueError: Could not extract auth token from RemoteProcessor. +``` + +### After Local Fixes On PML-302 With Perceval 1.2.1 + +```text +tests/core/test_merlin_processor_unit.py +53 passed in 4.87s +``` + +```text +tests/core/cloud/test_scaleway_session.py +10 passed, 5 failed in 200.58s +``` + +The remaining 5 failures are the context-manager/test-expectation issues listed +above, not the Perceval iterator compatibility issue. + +### Main Branch Comparison + +On `origin/main` with `perceval-quandela==1.2.1`: + +```text +tests/core/cloud/test_scaleway_session.py +8 failed, 3 passed, 1 skipped +``` + +The failures were all: + +```text +TypeError: Object of type ParameterIterator is not JSON serializable +``` + +On `origin/main` with Perceval 1.1.0 forced through `PYTHONPATH`: + +```text +tests/core/cloud/test_scaleway_session.py +12 passed in 232.70s +``` + +That confirms the iterator serialization failure is a Perceval 1.2 +compatibility issue, not a PML-302-only issue. diff --git a/merlin/core/merlin_processor.py b/merlin/core/merlin_processor.py index a0228df9..ca15725f 100644 --- a/merlin/core/merlin_processor.py +++ b/merlin/core/merlin_processor.py @@ -485,18 +485,19 @@ def __init__( stacklevel=2, ) - # Auto-extract the token from the RP's handler when not - # explicitly provided, so cloned RPs inherit it. - if self._token is None: - self._token = self._extract_rp_token(remote_processor) + if self.session is None: + # Auto-extract the token from the RP's handler when not + # explicitly provided, so cloned RPs inherit it. 
+ if self._token is None: + self._token = self._extract_rp_token(remote_processor) - if self._token is None: - raise ValueError( - "Could not extract auth token from RemoteProcessor. " - "Either pass token= to MerlinProcessor or call " - "RemoteConfig.set_token() before constructing the " - "RemoteProcessor." - ) + if self._token is None: + raise ValueError( + "Could not extract auth token from RemoteProcessor. " + "Either pass token= to MerlinProcessor or call " + "RemoteConfig.set_token() before constructing the " + "RemoteProcessor." + ) self.microbatch_size = microbatch_size self.default_timeout = float(timeout) @@ -1070,6 +1071,7 @@ def _submit_job(self, sampler, nsample, job_base_label, _capped_name): cmd = "probs" if job_base_label: job.name = _capped_name(job_base_label, cmd) + self._ensure_serializable_sampler_iterator(job, sampler) return job.execute_async(), is_probability use_shots = self.DEFAULT_SHOTS_PER_CALL if nsample is None else int(nsample) @@ -1086,8 +1088,43 @@ def _submit_job(self, sampler, nsample, job_base_label, _capped_name): if job_base_label: job.name = _capped_name(job_base_label, cmd) + self._ensure_serializable_sampler_iterator(job, sampler) return job.execute_async(max_samples=use_shots), is_probability + @staticmethod + def _ensure_serializable_sampler_iterator(job: RemoteJob, sampler: Sampler) -> None: + """Replace Perceval 1.2 iterator objects with JSON-serializable data. + + Parameters + ---------- + job : RemoteJob + Prepared Perceval remote job whose private request payload may contain + a sampler iterator. + sampler : Sampler + Perceval sampler used to prepare the job. + + Notes + ----- + Perceval 1.1 stores sampler iterations as a plain list. Perceval 1.2 + stores them in a ``ParameterIterator`` object, but the Scaleway session + handler still serializes ``payload["payload"]`` with ``json.dumps``. 
+ Until Perceval exposes a public serializer for that object, Merlin + normalizes the remote-job payload back to the list shape accepted by the + cloud side. + """ + iterator = getattr(sampler, "_iterator", None) + iterations = getattr(iterator, "iterations", None) + if not iterations: + return + + request_data = getattr(job, "_request_data", None) + if not isinstance(request_data, dict): + return + + payload = request_data.get("payload") + if isinstance(payload, dict) and payload.get("iterator") is iterator: + payload["iterator"] = list(iterations) + def _poll_job( self, job: RemoteJob, diff --git a/tests/core/test_merlin_processor_unit.py b/tests/core/test_merlin_processor_unit.py index 1d755cea..f2005525 100644 --- a/tests/core/test_merlin_processor_unit.py +++ b/tests/core/test_merlin_processor_unit.py @@ -2,6 +2,7 @@ from __future__ import annotations +import json import threading import time from concurrent.futures import CancelledError @@ -54,6 +55,39 @@ def __init__(self) -> None: self.samples = FakeCommand() +class FakePerceval12Command(FakeCommand): + """Sampler command with a Perceval 1.2-style private request payload.""" + + def __init__(self, iterator) -> None: + super().__init__() + self._request_data = {"payload": {"iterator": iterator}} + + def execute_async(self, **kwargs): + """Serialize the payload before recording the async execution call.""" + json.dumps(self._request_data["payload"]) + return super().execute_async(**kwargs) + + +class FakePerceval12Iterator: + """Small stand-in for Perceval 1.2 ParameterIterator.""" + + def __init__(self) -> None: + self.iterations = [{"circuit_params": {"px1": 0.25}}] + + def __bool__(self) -> bool: + return True + + +class FakePerceval12Sampler: + """Sampler fake whose commands expose a Perceval 1.2 iterator payload.""" + + def __init__(self) -> None: + self._iterator = FakePerceval12Iterator() + self.probs = FakePerceval12Command(self._iterator) + self.sample_count = FakePerceval12Command(self._iterator) + 
self.samples = FakePerceval12Command(self._iterator) + + @dataclass class FakeStatus: """Small job status object with the fields read by _poll_job.""" @@ -270,6 +304,24 @@ def test_session_path_stores_backend_capabilities(): assert proc.backend_capabilities.available_commands == ("probs", "sample_count") +def test_session_path_does_not_require_remote_processor_token(): + """ISession authentication should stay owned by the session object.""" + session = MagicMock(spec=ISession) + remote_processor = MagicMock(spec=RemoteProcessor) + remote_processor.name = "perceval-qpu:scaleway" + remote_processor.available_commands = ["probs", "sample_count"] + remote_processor.proxies = None + session.build_remote_processor.return_value = remote_processor + + with patch.object(MerlinProcessor, "_extract_rp_token", return_value=None) as extract: + proc = MerlinProcessor(session=session) + + extract.assert_not_called() + assert proc.session is session + assert proc.remote_processor is None + assert proc.backend_capabilities.available_commands == ("probs", "sample_count") + + def test_remote_processor_path_copies_available_commands(): """RemoteProcessor construction freezes the current command-detection path.""" remote_processor = MagicMock(spec=RemoteProcessor) @@ -629,6 +681,27 @@ def test_submit_job_uses_sample_count_when_sampling_requested(): assert sampler.probs.executed is False +def test_submit_job_serializes_perceval_12_parameter_iterator_payload(): + """Perceval 1.2 sampler iterations must be JSON-serializable for Scaleway.""" + proc = make_processor(["sample_count"]) + sampler = FakePerceval12Sampler() + + returned_job, is_probability = proc._submit_job( + sampler, + nsample=37, + job_base_label="job", + _capped_name=lambda base, command: f"{base}:{command}", + ) + + assert returned_job is sampler.sample_count + assert is_probability is False + assert sampler.sample_count.executed is True + assert sampler.sample_count.execute_kwargs == {"max_samples": 37} + assert 
sampler.sample_count._request_data["payload"]["iterator"] == [ + {"circuit_params": {"px1": 0.25}} + ] + + def test_submit_job_falls_back_to_samples_when_sample_count_is_unavailable(): """Backends without sample_count currently use samples for sampled jobs.""" proc = make_processor(["samples"]) From 26ae083fbdaf2a1df7d276a2f01e8065d9b41b8c Mon Sep 17 00:00:00 2001 From: LF-Vigneux <156103879+LF-Vigneux@users.noreply.github.com> Date: Fri, 29 May 2026 18:34:02 -0400 Subject: [PATCH 16/17] Fixes --- merlin/core/merlin_processor.py | 5 - tests/core/cloud/conftest.py | 23 ++++ tests/core/cloud/test_scaleway_session.py | 144 +++++++++++----------- 3 files changed, 97 insertions(+), 75 deletions(-) diff --git a/merlin/core/merlin_processor.py b/merlin/core/merlin_processor.py index ca15725f..687ab7af 100644 --- a/merlin/core/merlin_processor.py +++ b/merlin/core/merlin_processor.py @@ -565,9 +565,6 @@ def __enter__(self): with self._lock: if self._closed: raise RuntimeError("MerlinProcessor is closed") - # Start session lifecycle if provided - if self.session is not None and hasattr(self.session, "__enter__"): - self.session.__enter__() return self def __exit__(self, exc_type, exc, tb): @@ -576,8 +573,6 @@ def __exit__(self, exc_type, exc, tb): self.cancel_all() finally: # End session lifecycle if provided - if self.session is not None and hasattr(self.session, "__exit__"): - suppress_exception = bool(self.session.__exit__(exc_type, exc, tb)) with self._lock: self._closed = True return suppress_exception diff --git a/tests/core/cloud/conftest.py b/tests/core/cloud/conftest.py index f47e85bc..e5f59f5f 100644 --- a/tests/core/cloud/conftest.py +++ b/tests/core/cloud/conftest.py @@ -143,3 +143,26 @@ def scaleway_session(scaleway_credentials): max_duration_s=600, ) as session: yield session + + +# Uncomment when there is a probs backend +# @pytest.fixture(scope="module") +# def scaleway_session_probs(scaleway_credentials): +# """ +# Provide a Scaleway Session for testing. 
+ +# Uses EMU-ASCELLA-6PQ platform with reasonable timeouts for testing. +# The session is shared across all tests in a module to avoid +# repeatedly creating/destroying sessions. +# """ +# scw = pytest.importorskip("perceval.providers.scaleway") + +# with scw.Session( +# "SIM-SLOS", +# project_id=scaleway_credentials["project_id"], +# token=scaleway_credentials["token"], +# deduplication_id="merlin-test-session", +# max_idle_duration_s=300, +# max_duration_s=600, +# ) as session: +# yield session diff --git a/tests/core/cloud/test_scaleway_session.py b/tests/core/cloud/test_scaleway_session.py index 6b8bb271..e3468b4b 100644 --- a/tests/core/cloud/test_scaleway_session.py +++ b/tests/core/cloud/test_scaleway_session.py @@ -97,73 +97,76 @@ def test_processor_accepts_session(self, scaleway_session): # Verify backwards-compatible attributes are set assert proc.session is not None or proc.remote_processor is not None - def test_confirm_backend_has_probs(self, scaleway_session): - """Confirm backend has probs in available_commands.""" - proc = MerlinProcessor(session=scaleway_session) - - # Check available commands - print(f"\nBackend available_commands: {proc.available_commands}") - print(f"Backend capabilities: {proc.backend_capabilities.available_commands}") - - # Confirm probs is available - assert "probs" in proc.available_commands, \ - f"'probs' not in available commands: {proc.available_commands}" - assert "probs" in proc.backend_capabilities.available_commands, \ - f"'probs' not in backend capabilities: {proc.backend_capabilities.available_commands}" - - def test_simple_forward_probs_zero(self, scaleway_session): - """Basic synchronous forward pass with nsample=0 should use probs.""" - proc = MerlinProcessor( - session=scaleway_session, - microbatch_size=32, - timeout=300.0, - max_shots_per_call=100, - ) - - # Confirm probs is available - print(f"\nAvailable commands: {proc.available_commands}") - assert "probs" in proc.available_commands, "probs not available" 
- - q = _make_layer( - 6, 2, input_size=2, computation_space=ComputationSpace.UNBUNCHED - ) - X = torch.rand(4, 2) - - print("Running forward with nsample=0 - should use PROBS command") - y = proc.forward(q, X, nsample=0) - - expected_output_size = comb(6, 2) # 15 - assert y.shape == (4, expected_output_size) - # Output should be normalized probabilities - assert torch.all(y >= 0) - assert torch.allclose(y.sum(dim=1), torch.ones(4), atol=0.001) - - def test_simple_forward_probs_None(self, scaleway_session): - """Basic synchronous forward pass with nsample=None should use probs.""" - proc = MerlinProcessor( - session=scaleway_session, - microbatch_size=32, - timeout=300.0, - max_shots_per_call=100, - ) - - # Confirm probs is available - print(f"\nAvailable commands: {proc.available_commands}") - assert "probs" in proc.available_commands, "probs not available" - - q = _make_layer( - 6, 2, input_size=2, computation_space=ComputationSpace.UNBUNCHED - ) - X = torch.rand(4, 2) - - print("Running forward with nsample=None - should use PROBS command") - y = proc.forward(q, X, nsample=None) - - expected_output_size = comb(6, 2) # 15 - assert y.shape == (4, expected_output_size) - # Output should be normalized probabilities - assert torch.all(y >= 0) - assert torch.allclose(y.sum(dim=1), torch.ones(4), atol=0.001) + # Uncomment when there is a probs backend + # def test_confirm_backend_has_probs(self, scaleway_session_probs): + # """Confirm backend has probs in available_commands.""" + # proc = MerlinProcessor(session=scaleway_session_probs) + + # # Check available commands + # print(f"\nBackend available_commands: {proc.available_commands}") + # print(f"Backend capabilities: {proc.backend_capabilities.available_commands}") + + # # Confirm probs is available + # assert ( + # "probs" in proc.available_commands + # ), f"'probs' not in available commands: {proc.available_commands}" + # assert ( + # "probs" in proc.backend_capabilities.available_commands + # ), f"'probs' not in 
backend capabilities: {proc.backend_capabilities.available_commands}" + + # def test_simple_forward_probs_zero(self, scaleway_session_probs): + # """Basic synchronous forward pass with nsample=0 should use probs.""" + # proc = MerlinProcessor( + # session=scaleway_session_probs, + # microbatch_size=32, + # timeout=300.0, + # max_shots_per_call=100, + # ) + + # # Confirm probs is available + # print(f"\nAvailable commands: {proc.available_commands}") + # assert "probs" in proc.available_commands, "probs not available" + + # q = _make_layer( + # 6, 2, input_size=2, computation_space=ComputationSpace.UNBUNCHED + # ) + # X = torch.rand(4, 2) + + # print("Running forward with nsample=0 - should use PROBS command") + # y = proc.forward(q, X, nsample=0) + + # expected_output_size = comb(6, 2) # 15 + # assert y.shape == (4, expected_output_size) + # # Output should be normalized probabilities + # assert torch.all(y >= 0) + # assert torch.allclose(y.sum(dim=1), torch.ones(4), atol=0.001) + + # def test_simple_forward_probs_None(self, scaleway_session_probs): + # """Basic synchronous forward pass with nsample=None should use probs.""" + # proc = MerlinProcessor( + # session=scaleway_session_probs, + # microbatch_size=32, + # timeout=300.0, + # max_shots_per_call=100, + # ) + + # # Confirm probs is available + # print(f"\nAvailable commands: {proc.available_commands}") + # assert "probs" in proc.available_commands, "probs not available" + + # q = _make_layer( + # 6, 2, input_size=2, computation_space=ComputationSpace.UNBUNCHED + # ) + # X = torch.rand(4, 2) + + # print("Running forward with nsample=None - should use PROBS command") + # y = proc.forward(q, X, nsample=None) + + # expected_output_size = comb(6, 2) # 15 + # assert y.shape == (4, expected_output_size) + # # Output should be normalized probabilities + # assert torch.all(y >= 0) + # assert torch.allclose(y.sum(dim=1), torch.ones(4), atol=0.001) def test_simple_forward_sample(self, scaleway_session): """Basic 
synchronous forward pass with nsample=1000 should use sample_count.""" @@ -173,7 +176,7 @@ def test_simple_forward_sample(self, scaleway_session): timeout=300.0, max_shots_per_call=100, ) - + # Check available commands print(f"\nAvailable commands: {proc.available_commands}") @@ -181,9 +184,10 @@ def test_simple_forward_sample(self, scaleway_session): 6, 2, input_size=2, computation_space=ComputationSpace.UNBUNCHED ) X = torch.rand(4, 2) - + print("Running forward with nsample=1000 - should use SAMPLE_COUNT command") - y = proc.forward(q, X, nsample=1000) + with pytest.warns(UserWarning, match=r"Number of samples requested"): + y = proc.forward(q, X, nsample=1000) expected_output_size = comb(6, 2) # 15 assert y.shape == (4, expected_output_size) From 150c4afcac032d167d846571693478592a122a26 Mon Sep 17 00:00:00 2001 From: LF-Vigneux <156103879+LF-Vigneux@users.noreply.github.com> Date: Mon, 1 Jun 2026 08:04:01 -0400 Subject: [PATCH 17/17] Check if available --- tests/core/cloud/conftest.py | 78 ++++++++---- tests/core/cloud/test_scaleway_session.py | 138 +++++++++++----------- 2 files changed, 126 insertions(+), 90 deletions(-) diff --git a/tests/core/cloud/conftest.py b/tests/core/cloud/conftest.py index e5f59f5f..2e7eae88 100644 --- a/tests/core/cloud/conftest.py +++ b/tests/core/cloud/conftest.py @@ -51,7 +51,10 @@ def pytest_collection_modifyitems( item.add_marker(cloud_skip) # Skip scaleway tests unless --run-scaleway-tests - if not run_scaleway and "scaleway_session" in fixturenames: + if not run_scaleway and ( + "scaleway_session" in fixturenames + or "scaleway_session_probs" in fixturenames + ): item.add_marker(scaleway_skip) @@ -146,23 +149,56 @@ def scaleway_session(scaleway_credentials): # Uncomment when there is a probs backend -# @pytest.fixture(scope="module") -# def scaleway_session_probs(scaleway_credentials): -# """ -# Provide a Scaleway Session for testing. - -# Uses EMU-ASCELLA-6PQ platform with reasonable timeouts for testing. 
-# The session is shared across all tests in a module to avoid -# repeatedly creating/destroying sessions. -# """ -# scw = pytest.importorskip("perceval.providers.scaleway") - -# with scw.Session( -# "SIM-SLOS", -# project_id=scaleway_credentials["project_id"], -# token=scaleway_credentials["token"], -# deduplication_id="merlin-test-session", -# max_idle_duration_s=300, -# max_duration_s=600, -# ) as session: -# yield session +@pytest.fixture(scope="module") +def scaleway_session_probs(scaleway_credentials): + """ + Provide a Scaleway Session connected to a platform that supports the ``probs`` + command. The fixture queries the Scaleway platform list at runtime and picks + the first Quandela backend that exposes ``probs``; the test is skipped if none + is available. + """ + requests = pytest.importorskip("requests") + scw = pytest.importorskip("perceval.providers.scaleway") + + token = scaleway_credentials["token"] + project_id = scaleway_credentials["project_id"] + + r = requests.get( + "https://api.scaleway.com/qaas/v1alpha1/platforms", + headers={"X-Auth-Token": token, "Content-Type": "application/json"}, + ) + r.raise_for_status() + + probs_platform: str | None = None + for p in r.json().get("platforms", []): + if p.get("provider_name") != "quandela": + continue + try: + with scw.Session( + p["name"], + project_id=project_id, + token=token, + deduplication_id="merlin-probs-probe", + ) as probe_session: + proc = probe_session.build_remote_processor() + if "probs" in proc.available_commands: + probs_platform = p["name"] + break + except Exception: + continue + + if probs_platform is None: + pytest.skip( + "No Scaleway Quandela backend with 'probs' support found; " + "skipping probs-specific tests." 
+ ) + + with scw.Session( + probs_platform, + project_id=project_id, + token=token, + deduplication_id="merlin-test-probs-session", + max_idle_duration_s=300, + max_duration_s=600, + ) as session: + yield session diff --git a/tests/core/cloud/test_scaleway_session.py b/tests/core/cloud/test_scaleway_session.py index e3468b4b..f7ee28a2 100644 --- a/tests/core/cloud/test_scaleway_session.py +++ b/tests/core/cloud/test_scaleway_session.py @@ -98,75 +98,75 @@ def test_processor_accepts_session(self, scaleway_session): assert proc.session is not None or proc.remote_processor is not None # Uncomment when there is a probs backend - # def test_confirm_backend_has_probs(self, scaleway_session_probs): - # """Confirm backend has probs in available_commands.""" - # proc = MerlinProcessor(session=scaleway_session_probs) - - # # Check available commands - # print(f"\nBackend available_commands: {proc.available_commands}") - # print(f"Backend capabilities: {proc.backend_capabilities.available_commands}") - - # # Confirm probs is available - # assert ( - # "probs" in proc.available_commands - # ), f"'probs' not in available commands: {proc.available_commands}" - # assert ( - # "probs" in proc.backend_capabilities.available_commands - # ), f"'probs' not in backend capabilities: {proc.backend_capabilities.available_commands}" - - # def test_simple_forward_probs_zero(self, scaleway_session_probs): - # """Basic synchronous forward pass with nsample=0 should use probs.""" - # proc = MerlinProcessor( - # session=scaleway_session_probs, - # microbatch_size=32, - # timeout=300.0, - # max_shots_per_call=100, - # ) - - # # Confirm probs is available - # print(f"\nAvailable commands: {proc.available_commands}") - # assert "probs" in proc.available_commands, "probs not available" - - # q = _make_layer( - # 6, 2, input_size=2, computation_space=ComputationSpace.UNBUNCHED - # ) - # X = torch.rand(4, 2) - - # print("Running forward with nsample=0 - should use PROBS command") - # y = 
proc.forward(q, X, nsample=0) - - # expected_output_size = comb(6, 2) # 15 - # assert y.shape == (4, expected_output_size) - # # Output should be normalized probabilities - # assert torch.all(y >= 0) - # assert torch.allclose(y.sum(dim=1), torch.ones(4), atol=0.001) - - # def test_simple_forward_probs_None(self, scaleway_session_probs): - # """Basic synchronous forward pass with nsample=None should use probs.""" - # proc = MerlinProcessor( - # session=scaleway_session_probs, - # microbatch_size=32, - # timeout=300.0, - # max_shots_per_call=100, - # ) - - # # Confirm probs is available - # print(f"\nAvailable commands: {proc.available_commands}") - # assert "probs" in proc.available_commands, "probs not available" - - # q = _make_layer( - # 6, 2, input_size=2, computation_space=ComputationSpace.UNBUNCHED - # ) - # X = torch.rand(4, 2) - - # print("Running forward with nsample=None - should use PROBS command") - # y = proc.forward(q, X, nsample=None) - - # expected_output_size = comb(6, 2) # 15 - # assert y.shape == (4, expected_output_size) - # # Output should be normalized probabilities - # assert torch.all(y >= 0) - # assert torch.allclose(y.sum(dim=1), torch.ones(4), atol=0.001) + def test_confirm_backend_has_probs(self, scaleway_session_probs): + """Confirm backend has probs in available_commands.""" + proc = MerlinProcessor(session=scaleway_session_probs) + + # Check available commands + print(f"\nBackend available_commands: {proc.available_commands}") + print(f"Backend capabilities: {proc.backend_capabilities.available_commands}") + + # Confirm probs is available + assert ( + "probs" in proc.available_commands + ), f"'probs' not in available commands: {proc.available_commands}" + assert ( + "probs" in proc.backend_capabilities.available_commands + ), f"'probs' not in backend capabilities: {proc.backend_capabilities.available_commands}" + + def test_simple_forward_probs_zero(self, scaleway_session_probs): + """Basic synchronous forward pass with nsample=0 
should use probs.""" + proc = MerlinProcessor( + session=scaleway_session_probs, + microbatch_size=32, + timeout=300.0, + max_shots_per_call=100, + ) + + # Confirm probs is available + print(f"\nAvailable commands: {proc.available_commands}") + assert "probs" in proc.available_commands, "probs not available" + + q = _make_layer( + 6, 2, input_size=2, computation_space=ComputationSpace.UNBUNCHED + ) + X = torch.rand(4, 2) + + print("Running forward with nsample=0 - should use PROBS command") + y = proc.forward(q, X, nsample=0) + + expected_output_size = comb(6, 2) # 15 + assert y.shape == (4, expected_output_size) + # Output should be normalized probabilities + assert torch.all(y >= 0) + assert torch.allclose(y.sum(dim=1), torch.ones(4), atol=0.001) + + def test_simple_forward_probs_None(self, scaleway_session_probs): + """Basic synchronous forward pass with nsample=None should use probs.""" + proc = MerlinProcessor( + session=scaleway_session_probs, + microbatch_size=32, + timeout=300.0, + max_shots_per_call=100, + ) + + # Confirm probs is available + print(f"\nAvailable commands: {proc.available_commands}") + assert "probs" in proc.available_commands, "probs not available" + + q = _make_layer( + 6, 2, input_size=2, computation_space=ComputationSpace.UNBUNCHED + ) + X = torch.rand(4, 2) + + print("Running forward with nsample=None - should use PROBS command") + y = proc.forward(q, X, nsample=None) + + expected_output_size = comb(6, 2) # 15 + assert y.shape == (4, expected_output_size) + # Output should be normalized probabilities + assert torch.all(y >= 0) + assert torch.allclose(y.sum(dim=1), torch.ones(4), atol=0.001) def test_simple_forward_sample(self, scaleway_session): """Basic synchronous forward pass with nsample=1000 should use sample_count."""