diff --git a/merlin/utils/__init__.py b/merlin/utils/__init__.py
index a52327a8..d59099e1 100644
--- a/merlin/utils/__init__.py
+++ b/merlin/utils/__init__.py
@@ -36,6 +36,7 @@
     normalize_probabilities_and_amplitudes,
     probabilities_from_amplitudes,
 )
+from .post_selection import post_select_probs
 
 __all__ = [
     "LexGrouping",
@@ -47,4 +48,5 @@
     "sanitize_parameters",
     "resolve_float_complex",
     "to_torch_dtype",
+    "post_select_probs",
 ]
diff --git a/merlin/utils/post_selection.py b/merlin/utils/post_selection.py
new file mode 100644
index 00000000..b42724a3
--- /dev/null
+++ b/merlin/utils/post_selection.py
@@ -0,0 +1,72 @@
+from perceval import PostSelect, BasicState
+import torch
+
+
+def post_select_probs(
+    post_select: PostSelect | str,
+    keys: list[tuple],
+    probs: torch.Tensor,
+    same_keys: bool = True,
+) -> tuple[list[tuple], torch.Tensor]:
+    """
+    Apply a Perceval-style post-selection to a batch of probabilities.
+
+    Every state in ``keys`` is tested against ``post_select``. The
+    probability of each rejected state is discarded and the surviving
+    probabilities are renormalized so that each vector sums to 1.
+
+    Args:
+        post_select: Post-selection condition, either a ``PostSelect``
+            object or a string that is passed to ``PostSelect``.
+        keys: List of states representing the basis the probabilities
+            are written in, one per entry of the last dimension of
+            ``probs``.
+        probs: Probabilities to be post-selected on. The last dimension
+            indexes ``keys``; any leading dimensions are treated as
+            batch dimensions.
+        same_keys: If True (default), keep the original basis and set
+            the probabilities of rejected states to zero. If False,
+            drop the rejected states from both the keys and the
+            probabilities.
+
+    Returns:
+        A ``(new_keys, new_probs)`` tuple. With ``same_keys=True``,
+        ``new_keys`` is ``keys`` itself and ``new_probs`` has the shape
+        of ``probs``. With ``same_keys=False``, ``new_keys`` holds only
+        the accepted states and the last dimension of ``new_probs``
+        shrinks accordingly. Vectors whose surviving probabilities do
+        not sum to a positive value are left unnormalized (all zeros
+        for valid probabilities).
+
+    Raises:
+        ValueError: If ``probs`` is 0-dimensional or its last dimension
+            does not match ``len(keys)``.
+    """
+    if probs.ndim == 0 or len(keys) != probs.shape[-1]:
+        raise ValueError("Probabilities do not match keys shape.")
+
+    if isinstance(post_select, str):
+        post_select = PostSelect(post_select)
+
+    keep = [bool(post_select(BasicState(key))) for key in keys]
+    mask = torch.tensor(keep, dtype=torch.bool, device=probs.device)
+
+    if same_keys:
+        new_keys = keys
+        new_probs = mask * probs
+    else:
+        new_keys = [key for key, kept in zip(keys, keep) if kept]
+        # Index the last dimension so any number of batch dimensions works.
+        new_probs = probs[..., mask]
+
+    # Normalize vectors that are not zero vectors. Zero norms are replaced
+    # by 1 before dividing: a 0/0 hidden behind torch.where would still
+    # propagate NaN into the gradients.
+    norm = new_probs.sum(dim=-1, keepdim=True)
+    safe_norm = torch.where(norm > 0, norm, torch.ones_like(norm))
+    new_probs = new_probs / safe_norm
+
+    assert len(new_keys) == new_probs.shape[-1], (
+        "Output probabilities do not match output keys shape."
+    )
+    return new_keys, new_probs
diff --git a/tests/utils/test_post_selection.py b/tests/utils/test_post_selection.py
new file mode 100644
index 00000000..22a6ff85
--- /dev/null
+++ b/tests/utils/test_post_selection.py
@@ -0,0 +1,125 @@
+import pytest
+import torch
+from perceval import PostSelect
+
+from merlin.utils import post_select_probs
+
+
+@pytest.fixture
+def uniform_probs():
+    return torch.tensor([0.25, 0.25, 0.25, 0.25])
+
+
+@pytest.fixture
+def keys():
+    return [(0, 2), (1, 1), (2, 0), (0, 0)]
+
+
+def test_keep_all(keys, uniform_probs):
+    """Post-selection that accepts everything leaves probs unchanged."""
+    ps = "[0,1]>=0"
+    new_keys, new_probs = post_select_probs(ps, keys, uniform_probs)
+
+    assert new_keys == keys
+    torch.testing.assert_close(new_probs, uniform_probs)
+
+
+def test_keep_none_same_keys(keys, uniform_probs):
+    """Rejecting everything with same_keys yields a zero vector."""
+    ps = PostSelect("[0]==99")  # never true
+    _, new_probs = post_select_probs(ps, keys, uniform_probs, same_keys=True)
+    torch.testing.assert_close(new_probs, torch.zeros_like(uniform_probs))
+
+
+def test_normalization(keys):
+    """Surviving probabilities are renormalized to sum to 1."""
+    probs = torch.tensor([0.1, 0.4, 0.4, 0.1])
+    ps = PostSelect("[0]==1")  # keeps only (1,1)
+    _, new_probs = post_select_probs(ps, keys, probs, same_keys=False)
+
+    assert pytest.approx(new_probs.sum().item(), abs=1e-6) == 1.0
+
+
+def test_same_keys_true_zeros_rejected(keys, uniform_probs):
+    """same_keys=True: rejected entries become 0, key list is unchanged."""
+    ps = PostSelect("[0]==1")
+    new_keys, new_probs = post_select_probs(ps, keys, uniform_probs, same_keys=True)
+
+    assert new_keys == keys
+    assert new_probs.shape[-1] == len(keys)
+
+    # Only the (1,1) slot survives
+    kept_idx = keys.index((1, 1))
+    for i, p in enumerate(new_probs):
+        if i != kept_idx:
+            assert p.item() == pytest.approx(0.0)
+
+
+def test_same_keys_false_shrinks_keys(keys, uniform_probs):
+    """same_keys=False: new_keys contains only accepted states."""
+    ps = PostSelect("[0]==1")  # accepts (1,1) only
+    new_keys, new_probs = post_select_probs(ps, keys, uniform_probs, same_keys=False)
+
+    assert new_keys == [(1, 1)]
+    assert new_probs.shape[-1] == 1
+
+
+def test_batch_probs(keys):
+    """2-D input (batch × states) is handled correctly."""
+    probs = torch.tensor([[0.5, 0.5, 0.0, 0.0],
+                          [0.0, 0.0, 0.5, 0.5]])
+    ps = PostSelect("[0]==1")  # keeps index 1: (1,1)
+    _, new_probs = post_select_probs(ps, keys, probs, same_keys=True)
+
+    assert new_probs.shape == probs.shape
+
+    # First batch row should sum to 1 (had prob on kept state)
+    assert new_probs[0].sum().item() == pytest.approx(1.0)
+
+    # Second batch row is all-zero (no prob on kept state)
+    torch.testing.assert_close(new_probs[1], torch.zeros(len(keys)))
+
+
+def test_extra_batch_dims(keys):
+    """Inputs with several leading batch dimensions are supported."""
+    probs = torch.full((2, 3, 4), 0.25)
+    new_keys, new_probs = post_select_probs("[0]==1", keys, probs, same_keys=False)
+
+    assert new_keys == [(1, 1)]
+    torch.testing.assert_close(new_probs, torch.ones(2, 3, 1))
+
+
+def test_zero_norm_rows_keep_finite_gradients(keys):
+    """Rows without surviving probability do not yield NaN gradients."""
+    probs = torch.tensor(
+        [[0.5, 0.5, 0.0, 0.0], [0.5, 0.0, 0.5, 0.0]], requires_grad=True
+    )
+    _, new_probs = post_select_probs("[0]==1", keys, probs)
+    new_probs.sum().backward()
+
+    assert torch.isfinite(probs.grad).all()
+
+
+def test_1d_input_returns_1d(keys, uniform_probs):
+    """1-D input is returned as 1-D."""
+    _, new_probs = post_select_probs("[0]==1", keys, uniform_probs)
+
+    assert new_probs.ndim == 1
+
+
+def test_single_state_kept():
+    """Works correctly with a single-element key list."""
+    keys = [(1,)]
+    probs = torch.tensor([1.0])
+    ps = PostSelect("[0]==1")
+    new_keys, new_probs = post_select_probs(ps, keys, probs, same_keys=False)
+
+    assert new_keys == [(1,)]
+    assert new_probs.item() == pytest.approx(1.0)
+
+
+def test_device_preserved(keys, uniform_probs):
+    """Output tensor lives on the same device as the input."""
+    _, new_probs = post_select_probs("[0]==1", keys, uniform_probs)
+
+    assert new_probs.device == uniform_probs.device