Skip to content

Commit 28d8c08

Browse files
committed
Refactor
1 parent c8ccd84 commit 28d8c08

File tree

3 files changed

+80
-115
lines changed

3 files changed

+80
-115
lines changed

cirq-core/cirq/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,7 @@
382382
merge_operations_to_circuit_op as merge_operations_to_circuit_op,
383383
merge_single_qubit_gates_to_phased_x_and_z as merge_single_qubit_gates_to_phased_x_and_z,
384384
merge_single_qubit_gates_to_phxz as merge_single_qubit_gates_to_phxz,
385+
merge_single_qubit_gates_to_phxz_symbolized as merge_single_qubit_gates_to_phxz_symbolized,
385386
merge_single_qubit_moments_to_phxz as merge_single_qubit_moments_to_phxz,
386387
optimize_for_target_gateset as optimize_for_target_gateset,
387388
parameterized_2q_op_to_sqrt_iswap_operations as parameterized_2q_op_to_sqrt_iswap_operations,

cirq-core/cirq/transformers/merge_single_qubit_gates.py

Lines changed: 65 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,24 @@
1616

1717
from __future__ import annotations
1818

19-
from typing import Callable, cast, Hashable, List, Tuple, TYPE_CHECKING
20-
21-
import sympy
19+
from typing import Callable, cast, Hashable, TYPE_CHECKING
2220

2321
from cirq import circuits, ops, protocols
2422
from cirq.study.resolver import ParamResolver
2523
from cirq.study.sweeps import dict_to_zip_sweep, ListSweep, ProductOrZipSweepLike, Sweep, Zip
2624
from cirq.transformers import (
2725
align,
2826
merge_k_qubit_gates,
29-
transformer_api,
30-
transformer_primitives,
3127
symbolize,
3228
tag_transformers,
29+
transformer_api,
30+
transformer_primitives,
3331
)
3432
from cirq.transformers.analytical_decompositions import single_qubit_decompositions
3533

3634
if TYPE_CHECKING:
35+
import sympy
36+
3737
import cirq
3838

3939

@@ -78,9 +78,9 @@ def merge_single_qubit_gates_to_phxz(
7878
circuit: cirq.AbstractCircuit,
7979
*,
8080
context: cirq.TransformerContext | None = None,
81-
merge_tags_fn: Callable[[cirq.CircuitOperation], List[Hashable]] | None = None,
81+
merge_tags_fn: Callable[[cirq.CircuitOperation], list[Hashable]] | None = None,
8282
atol: float = 1e-8,
83-
) -> 'cirq.Circuit':
83+
) -> cirq.Circuit:
8484
"""Replaces runs of single qubit rotations with a single optional `cirq.PhasedXZGate`.
8585
8686
Specifically, any run of non-parameterized single-qubit unitaries will be replaced by an
@@ -97,7 +97,7 @@ def merge_single_qubit_gates_to_phxz(
9797
Copy of the transformed input circuit.
9898
"""
9999

100-
def rewriter(circuit_op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE':
100+
def rewriter(circuit_op: cirq.CircuitOperation) -> cirq.OP_TREE:
101101
u = protocols.unitary(circuit_op)
102102
if protocols.num_qubits(circuit_op) == 0:
103103
return ops.GlobalPhaseGate(u[0, 0]).on()
@@ -170,58 +170,42 @@ def merge_func(m1: cirq.Moment, m2: cirq.Moment) -> cirq.Moment | None:
170170
).unfreeze(copy=False)
171171

172172

173-
def _all_tags_startswith(circuit: cirq.AbstractCircuit, startswith: str):
174-
tag_set: set[Hashable] = set()
175-
for op in circuit.all_operations():
176-
for tag in op.tags:
177-
if str(tag).startswith(startswith):
178-
tag_set.add(tag)
179-
return tag_set
180-
181-
182173
def _sweep_on_symbols(sweep: Sweep, symbols: set[sympy.Symbol]) -> Sweep:
183-
new_resolvers: List[cirq.ParamResolver] = []
174+
new_resolvers: list[cirq.ParamResolver] = []
184175
for resolver in sweep:
185-
param_dict: 'cirq.ParamMappingType' = {s: resolver.value_of(s) for s in symbols}
176+
param_dict: cirq.ParamMappingType = {s: resolver.value_of(s) for s in symbols}
186177
new_resolvers.append(ParamResolver(param_dict))
187178
return ListSweep(new_resolvers)
188179

189180

190-
def _parameterize_phxz_in_circuits(
191-
circuit_list: List['cirq.Circuit'],
192-
merge_tag_prefix: str,
193-
phxz_symbols: set[sympy.Symbol],
194-
remaining_symbols: set[sympy.Symbol],
195-
sweep: Sweep,
181+
def _calc_phxz_sweeps(
182+
symbolized_circuit: cirq.Circuit, resolved_circuits: list[cirq.Circuit]
196183
) -> Sweep:
197-
"""Parameterizes the circuits and returns a new sweep."""
198-
values_by_params: dict[str, List[float]] = {**{str(s): [] for s in phxz_symbols}}
199-
200-
for circuit in circuit_list:
201-
for op in circuit.all_operations():
202-
the_merge_tag: str | None = None
203-
for tag in op.tags:
204-
if str(tag).startswith(merge_tag_prefix):
205-
the_merge_tag = str(tag)
206-
if not the_merge_tag:
207-
continue
208-
sid = the_merge_tag.rsplit("_", maxsplit=-1)[-1]
209-
x, z, a = 0.0, 0.0, 0.0 # Identity gate's parameters
210-
if isinstance(op.gate, ops.PhasedXZGate):
211-
x, z, a = op.gate.x_exponent, op.gate.z_exponent, op.gate.axis_phase_exponent
212-
elif op.gate is not ops.I:
213-
raise RuntimeError(
214-
f"Expected the merged gate to be a PhasedXZGate or IdentityGate,"
215-
f" but got {op.gate}."
184+
"""Return the phxz sweep of the symbolized_circuit on resolved_circuits.
185+
186+
Raises:
187+
ValueError: Structural mismatch: A `resolved_circuit` contains an unexpected gate type.
188+
Expected a `PhasedXZGate` or `IdentityGate` at a position corresponding to a
189+
symbolic `PhasedXZGate` in the `symbolized_circuit`.
190+
"""
191+
192+
def _extract_axz(op: ops.Operation | None) -> tuple[float, float, float]:
193+
if not op or not op.gate or not isinstance(op.gate, ops.IdentityGate | ops.PhasedXZGate):
194+
raise ValueError(f"Expect a PhasedXZGate or IdentityGate on op {op}.")
195+
if isinstance(op.gate, ops.IdentityGate):
196+
return 0.0, 0.0, 0.0 # Identity gate's a, x, z in PhasedXZ
197+
return op.gate.axis_phase_exponent, op.gate.x_exponent, op.gate.z_exponent
198+
199+
values_by_params: dict[sympy.Symbol, tuple[float, ...]] = {}
200+
for mid, moment in enumerate(symbolized_circuit):
201+
for op in moment.operations:
202+
if op.gate and isinstance(op.gate, ops.PhasedXZGate) and protocols.is_parameterized(op):
203+
sa, sx, sz = op.gate.axis_phase_exponent, op.gate.x_exponent, op.gate.z_exponent
204+
values_by_params[sa], values_by_params[sx], values_by_params[sz] = zip(
205+
*[_extract_axz(c[mid].operation_at(op.qubits[0])) for c in resolved_circuits]
216206
)
217-
values_by_params[f"x{sid}"].append(x)
218-
values_by_params[f"z{sid}"].append(z)
219-
values_by_params[f"a{sid}"].append(a)
220207

221-
return Zip(
222-
dict_to_zip_sweep(cast(ProductOrZipSweepLike, values_by_params)),
223-
_sweep_on_symbols(sweep, remaining_symbols),
224-
)
208+
return dict_to_zip_sweep(cast(ProductOrZipSweepLike, values_by_params))
225209

226210

227211
def merge_single_qubit_gates_to_phxz_symbolized(
@@ -230,7 +214,7 @@ def merge_single_qubit_gates_to_phxz_symbolized(
230214
context: cirq.TransformerContext | None = None,
231215
sweep: Sweep,
232216
atol: float = 1e-8,
233-
) -> Tuple[cirq.Circuit, Sweep]:
217+
) -> tuple[cirq.Circuit, Sweep]:
234218
"""Merges consecutive single qubit gates as PhasedXZ Gates. Symbolizes if any of
235219
the consecutive gates is symbolized.
236220
@@ -288,6 +272,10 @@ def merge_single_qubit_gates_to_phxz_symbolized(
288272
for op in circuit_tagged.all_operations()
289273
]
290274
)
275+
# Remaining symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged.
276+
remaining_symbols: set[sympy.Symbol] = set(
277+
protocols.parameter_symbols(circuit) - single_qubit_gate_symbols
278+
)
291279
# If all single qubit gates are not parameterized, call the nonparamerized version of
292280
# the transformer.
293281
if not single_qubit_gate_symbols:
@@ -299,61 +287,43 @@ def merge_single_qubit_gates_to_phxz_symbolized(
299287
]
300288

301289
# Step 1, merge single qubit gates per resolved circuit, preserving
302-
# the symbolized_single_tag with indexes.
303-
merged_circuits: List['cirq.Circuit'] = []
304-
for resolved_circuit in resolved_circuits:
305-
merged_circuit = tag_transformers.index_tags(
306-
merge_single_qubit_gates_to_phxz(
307-
resolved_circuit,
308-
context=context,
309-
merge_tags_fn=lambda circuit_op: (
310-
[symbolized_single_tag]
311-
if any(
312-
symbolized_single_tag in set(op.tags)
313-
for op in circuit_op.circuit.all_operations()
314-
)
315-
else []
316-
),
317-
atol=atol,
290+
# the symbolized_single_tag to indicate the operator is a merged one.
291+
merged_circuits: list[cirq.Circuit] = [
292+
merge_single_qubit_gates_to_phxz(
293+
c,
294+
context=context,
295+
merge_tags_fn=lambda circuit_op: (
296+
[symbolized_single_tag]
297+
if any(
298+
symbolized_single_tag in set(op.tags)
299+
for op in circuit_op.circuit.all_operations()
300+
)
301+
else []
318302
),
319-
context=transformer_api.TransformerContext(deep=deep),
320-
target_tags={symbolized_single_tag},
303+
atol=atol,
321304
)
322-
merged_circuits.append(merged_circuit)
323-
324-
if not all(
325-
_all_tags_startswith(merged_circuits[0], startswith=symbolized_single_tag)
326-
== _all_tags_startswith(merged_circuit, startswith=symbolized_single_tag)
327-
for merged_circuit in merged_circuits
328-
):
329-
raise RuntimeError("Different resolvers in sweep resulted in different merged structures.")
305+
for c in resolved_circuits
306+
]
330307

331-
# Step 2, get the new symbolized circuit by symbolization on indexed symbolized_single_tag.
308+
# Step 2, get the new symbolized circuit by symbolizing on indexed symbolized_single_tag.
332309
new_circuit = align.align_right(
333-
tag_transformers.remove_tags(
310+
tag_transformers.remove_tags( # remove the temp tags used to track merges
334311
symbolize.symbolize_single_qubit_gates_by_indexed_tags(
335-
merged_circuits[0],
312+
tag_transformers.index_tags( # index all 1-qubit-ops merged from ops with symbols
313+
merged_circuits[0],
314+
context=transformer_api.TransformerContext(deep=deep),
315+
target_tags={symbolized_single_tag},
316+
),
336317
symbolize_tag=symbolize.SymbolizeTag(prefix=symbolized_single_tag),
337318
),
338319
remove_if=lambda tag: str(tag).startswith(symbolized_single_tag),
339320
)
340321
)
341322

342323
# Step 3, get N sets of parameterizations as new_sweep.
343-
phxz_symbols: set[sympy.Symbol] = set().union(
344-
*[
345-
set(
346-
[sympy.Symbol(tag.replace(f"{symbolized_single_tag}_", s)) for s in ["x", "z", "a"]]
347-
)
348-
for tag in _all_tags_startswith(merged_circuits[0], startswith=symbolized_single_tag)
349-
]
350-
)
351-
# Remaining symbols, e.g., 2 qubit gates' symbols. Sweep of those symbols keeps unchanged.
352-
remaining_symbols: set[sympy.Symbol] = set(
353-
protocols.parameter_symbols(circuit) - single_qubit_gate_symbols
354-
)
355-
new_sweep = _parameterize_phxz_in_circuits(
356-
merged_circuits, symbolized_single_tag, phxz_symbols, remaining_symbols, sweep
324+
new_sweep = Zip(
325+
_calc_phxz_sweeps(new_circuit, merged_circuits), # phxz sweeps
326+
_sweep_on_symbols(sweep, remaining_symbols), # remaining sweeps
357327
)
358328

359329
return new_circuit, new_sweep

cirq-core/cirq/transformers/merge_single_qubit_gates_test.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import List
16-
from unittest.mock import Mock, patch
1715
from unittest import TestCase
16+
from unittest.mock import Mock, patch
1817

1918
import pytest
2019
import sympy
@@ -24,7 +23,7 @@
2423

2524
def assert_optimizes(optimized: cirq.AbstractCircuit, expected: cirq.AbstractCircuit):
2625
# Ignore differences that would be caught by follow-up optimizations.
27-
followup_transformers: List[cirq.TRANSFORMER] = [
26+
followup_transformers: list[cirq.TRANSFORMER] = [
2827
cirq.drop_negligible_operations,
2928
cirq.drop_empty_moments,
3029
]
@@ -241,7 +240,7 @@ def test_merge_single_qubit_gates_to_phased_x_and_z_global_phase():
241240
class TestMergeSingleQubitGatesSymbolized(TestCase):
242241
"""Test suite for merge_single_qubit_gates_to_phxz_symbolized."""
243242

244-
def case1(self):
243+
def test_case1(self):
245244
"""Test case diagram.
246245
Input circuit:
247246
# pylint: disable=line-too-long
@@ -298,7 +297,7 @@ def case1(self):
298297
{q: q for q in input_circuit.all_qubits()},
299298
)
300299

301-
def case_non_parameterized_singles(self):
300+
def test_case_non_parameterized_singles(self):
302301
"""Test merge_single_qubit_gates_to_phxz_symbolized when all single qubit gates are not
303302
parameterized."""
304303

@@ -310,27 +309,24 @@ def case_non_parameterized_singles(self):
310309
)
311310
assert_optimizes(output_circuit, expected_circuit)
312311

313-
def fail_different_structures_error(self):
314-
"""Tests that the function raises a RuntimeError if merged structures of the circuit differ
312+
def test_fail_different_structures_error(self):
313+
"""Tests that the function raises a ValueError if merged structures of the circuit differ
315314
for different parameterizations."""
316-
a = cirq.NamedQubit("a")
317-
circuit = cirq.Circuit(cirq.H(a) ** sympy.Symbol("exp"))
315+
q0, q1 = cirq.LineQubit.range(2)
316+
circuit = cirq.Circuit(cirq.H(q0) ** sympy.Symbol("exp"))
318317
sweep = cirq.Points(key="exp", points=[0.1, 0.2])
319318

320319
with patch(
321320
"cirq.protocols.resolve_parameters",
322-
side_effect=[
323-
cirq.Circuit(cirq.H(a).with_tags("_temp_symbolize_tag")),
324-
cirq.Circuit(cirq.H(a)),
321+
side_effect=[ # Mock the return values of resolve_parameters
322+
cirq.Circuit(cirq.I(q0).with_tags("_tmp_symbolize_tag")),
323+
cirq.Circuit(cirq.CZ(q0, q1)),
325324
],
326325
):
327-
with pytest.raises(
328-
RuntimeError,
329-
match="Different resolvers in sweep resulted in different merged structures.",
330-
):
326+
with pytest.raises(ValueError, match="Expect a PhasedXZGate or IdentityGate.*"):
331327
cirq.merge_single_qubit_gates_to_phxz_symbolized(circuit, sweep=sweep)
332328

333-
def fail_unexpected_gate_error(self):
329+
def test_fail_unexpected_gate_error(self):
334330
"""Tests that the function raises a RuntimeError of unexpected gate."""
335331
a, b = cirq.LineQubit.range(2)
336332
circuit = cirq.Circuit(
@@ -352,7 +348,5 @@ def fail_unexpected_gate_error(self):
352348
".single_qubit_decompositions.single_qubit_matrix_to_phxz",
353349
return_value=cirq.H,
354350
):
355-
with pytest.raises(
356-
RuntimeError, match="Expected the merged gate to be a PhasedXZGate or IdentityGate."
357-
):
351+
with pytest.raises(ValueError, match="Expect a PhasedXZGate or IdentityGate.*"):
358352
cirq.merge_single_qubit_gates_to_phxz_symbolized(circuit, sweep=sweep)

0 commit comments

Comments
 (0)