Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions merlin/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
normalize_probabilities_and_amplitudes,
probabilities_from_amplitudes,
)
from .post_selection import post_select_probs

__all__ = [
"LexGrouping",
Expand All @@ -47,4 +48,5 @@
"sanitize_parameters",
"resolve_float_complex",
"to_torch_dtype",
"post_select_probs"
]
60 changes: 60 additions & 0 deletions merlin/utils/post_selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from perceval import PostSelect, BasicState
import torch


def post_select_probs(
post_select: PostSelect | str,
keys: list[tuple],
probs: torch.Tensor,
same_keys: bool = True
) -> tuple[list[tuple], torch.Tensor]:
"""
Given a batch of probabilities and corresponding keys, perform
post-selection in the style of Perceval.

Args:
post_select: Post-selection object
keys: List of states representing the basis the probabilities
are written in.
probs: Batch of probabilities to be post-selected on.
same_keys: Determines whether to write the probability vectors
in the new post-selected basis or the original basis.
"""
if len(keys) != probs.shape[-1]:
raise ValueError("Probabilities do not match keys shape.")

if isinstance(post_select, str):
post_select = PostSelect(post_select)

was_1d = probs.ndim == 1
if was_1d:
probs = probs.unsqueeze(0)

new_keys = [] if not same_keys else keys

keep = []
for key in keys:
kept = post_select(BasicState(key))
keep.append(kept)

if not same_keys and kept:
new_keys.append(key)

mask = torch.tensor(keep, dtype=torch.bool, device=probs.device)

if same_keys:
new_probs = mask * probs
else:
new_probs = probs[:, mask]

# Normalize vectors that are not zero vectors.
norm = new_probs.sum(dim=-1, keepdim=True)
new_probs = torch.where(norm > 0, new_probs / norm, new_probs)

if was_1d:
new_probs = new_probs.squeeze(0)

assert len(new_keys) == new_probs.shape[-1], (
"Output probabilities do not match output keys shape."
)
return new_keys, new_probs
104 changes: 104 additions & 0 deletions tests/utils/test_post_selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import pytest
import torch
from perceval import PostSelect
from merlin.utils import post_select_probs


@pytest.fixture
def uniform_probs():
return torch.tensor([0.25, 0.25, 0.25, 0.25])

@pytest.fixture
def keys():
return [(0, 2), (1, 1), (2, 0), (0, 0)]



def test_keep_all(keys, uniform_probs):
"""Post-selection that accepts everything leaves probs unchanged."""
ps = "[0,1]>=0"
new_keys, new_probs = post_select_probs(ps, keys, uniform_probs)

assert new_keys == keys
torch.testing.assert_close(new_probs, uniform_probs)


def test_keep_none_same_keys(keys, uniform_probs):
"""Rejecting everything with same_keys yields a zero vector."""
ps = PostSelect("[0]==99") # never true
_, new_probs = post_select_probs(ps, keys, uniform_probs, same_keys=True)
torch.testing.assert_close(new_probs, torch.zeros_like(uniform_probs))


def test_normalization(keys):
"""Surviving probabilities are renormalized to sum to 1."""
probs = torch.tensor([0.1, 0.4, 0.4, 0.1])
ps = PostSelect("[0]==1") # keeps only (1,1)
_, new_probs = post_select_probs(ps, keys, probs, same_keys=False)

assert pytest.approx(new_probs.sum().item(), abs=1e-6) == 1.0


def test_same_keys_true_zeros_rejected(keys, uniform_probs):
"""same_keys=True: rejected entries become 0, key list is unchanged."""
ps = PostSelect("[0]==1")
new_keys, new_probs = post_select_probs(ps, keys, uniform_probs, same_keys=True)

assert new_keys == keys
assert new_probs.shape[-1] == len(keys)

# Only the (1,1) slot survives
kept_idx = keys.index((1, 1))
for i, p in enumerate(new_probs):
if i != kept_idx:
assert p.item() == pytest.approx(0.0)


def test_same_keys_false_shrinks_keys(keys, uniform_probs):
"""same_keys=False: new_keys contains only accepted states."""
ps = PostSelect("[0]==1") # accepts (1,1) only
new_keys, new_probs = post_select_probs(ps, keys, uniform_probs, same_keys=False)

assert new_keys == [(1, 1)]
assert new_probs.shape[-1] == 1


def test_batch_probs(keys):
"""2-D input (batch × states) is handled correctly."""
probs = torch.tensor([[0.5, 0.5, 0.0, 0.0],
[0.0, 0.0, 0.5, 0.5]])
ps = PostSelect("[0]==1") # keeps index 1: (1,1)
_, new_probs = post_select_probs(ps, keys, probs, same_keys=True)

assert new_probs.shape == probs.shape

# First batch row should sum to 1 (had prob on kept state)
assert new_probs[0].sum().item() == pytest.approx(1.0)

# Second batch row is all-zero (no prob on kept state)
torch.testing.assert_close(new_probs[1], torch.zeros(len(keys)))


def test_1d_input_returns_1d(keys, uniform_probs):
"""1-D input is returned as 1-D."""
_, new_probs = post_select_probs("[0]==1", keys, uniform_probs)

assert new_probs.ndim == 1


def test_single_state_kept():
"""Works correctly with a single-element key list."""
keys = [(1,)]
probs = torch.tensor([1.0])
ps = PostSelect("[0]==1")
new_keys, new_probs = post_select_probs(ps, keys, probs, same_keys=False)

assert new_keys == [(1,)]
assert new_probs.item() == pytest.approx(1.0)


def test_device_preserved(keys, uniform_probs):
"""Output tensor lives on the same device as the input."""
_, new_probs = post_select_probs("[0]==1", keys, uniform_probs)

assert new_probs.device == uniform_probs.device
Loading