diff --git a/cirq-core/cirq/transformers/merge_single_qubit_gates.py b/cirq-core/cirq/transformers/merge_single_qubit_gates.py index 342c6322643..795856f2b12 100644 --- a/cirq-core/cirq/transformers/merge_single_qubit_gates.py +++ b/cirq-core/cirq/transformers/merge_single_qubit_gates.py @@ -118,7 +118,7 @@ def merge_single_qubit_moments_to_phxz( def can_merge_moment(m: cirq.Moment): return all( - protocols.num_qubits(op) == 1 + protocols.num_qubits(op) <= 1 and protocols.has_unitary(op) and tags_to_ignore.isdisjoint(op.tags) for op in m @@ -146,6 +146,10 @@ def merge_func(m1: cirq.Moment, m2: cirq.Moment) -> cirq.Moment | None: ) if gate: ret_ops.append(gate(q)) + # Transfer global phase + for op in m1.operations + m2.operations: + if protocols.num_qubits(op) == 0: + ret_ops.append(op) return circuits.Moment(ret_ops) return transformer_primitives.merge_moments( diff --git a/cirq-core/cirq/transformers/merge_single_qubit_gates_test.py b/cirq-core/cirq/transformers/merge_single_qubit_gates_test.py index d09af63be69..b0a8fe47801 100644 --- a/cirq-core/cirq/transformers/merge_single_qubit_gates_test.py +++ b/cirq-core/cirq/transformers/merge_single_qubit_gates_test.py @@ -231,3 +231,35 @@ def test_merge_single_qubit_moments_to_phased_x_and_z_global_phase(): c = cirq.Circuit(cirq.GlobalPhaseGate(1j).on()) c2 = cirq.merge_single_qubit_gates_to_phased_x_and_z(c) assert c == c2 + + +def test_merge_single_qubit_moments_to_phxz_with_global_phase_in_first_moment(): + q0 = cirq.LineQubit(0) + c_orig = cirq.Circuit( + cirq.Moment(cirq.Y(q0) ** 0.5, cirq.GlobalPhaseGate(1j**0.5).on()), cirq.Moment(cirq.X(q0)) + ) + c_expected = cirq.Circuit( + cirq.Moment( + cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=0.5, z_exponent=-1.0).on(q0), + cirq.GlobalPhaseGate(1j**0.5).on(), + ) + ) + context = cirq.TransformerContext(tags_to_ignore=["ignore"]) + c_new = cirq.merge_single_qubit_moments_to_phxz(c_orig, context=context) + assert c_new == c_expected + + +def test_merge_single_qubit_moments_to_phxz_with_global_phase_in_second_moment(): + q0 = cirq.LineQubit(0) + c_orig = cirq.Circuit( + cirq.Moment(cirq.Y(q0) ** 0.5), cirq.Moment(cirq.X(q0), cirq.GlobalPhaseGate(1j**0.5).on()) + ) + c_expected = cirq.Circuit( + cirq.Moment( + cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=0.5, z_exponent=-1.0).on(q0), + cirq.GlobalPhaseGate(1j**0.5).on(), + ) + ) + context = cirq.TransformerContext(tags_to_ignore=["ignore"]) + c_new = cirq.merge_single_qubit_moments_to_phxz(c_orig, context=context) + assert c_new == c_expected