Skip to content

Commit 4a128dc

Browse files
authored
Cirq x exponent symbol (#7400)
Raise TypeError when cirq.X or other EigenGate-derived gate is raised to a string exponent, e.g., in `cirq.X ** "text"`. Raise TypeError also for other unsupported exponent types. Fixes #2936
1 parent 15345b4 commit 4a128dc

File tree

4 files changed

+32
-3
lines changed

4 files changed

+32
-3
lines changed

cirq-core/cirq/ops/eigen_gate.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,15 @@ def __init__(
103103
`cirq.unitary(cirq.rx(pi))` equals -iX instead of X.
104104
105105
Raises:
106+
TypeError: If the supplied exponent is a string.
106107
ValueError: If the supplied exponent is a complex number with an
107108
imaginary component.
108109
"""
110+
if not isinstance(exponent, (numbers.Number, sympy.Expr)):
111+
raise TypeError(
112+
"Gate exponent must be a number or sympy expression. "
113+
f"Invalid type: {type(exponent).__name__!r}"
114+
)
109115
if isinstance(exponent, complex):
110116
if exponent.imag:
111117
raise ValueError(f"Gate exponent must be real. Invalid Value: {exponent}")
@@ -286,7 +292,12 @@ def _period(self) -> float | None:
286292
real_periods = [abs(2 / e) for e in exponents if e != 0]
287293
return _approximate_common_period(real_periods)
288294

289-
def __pow__(self, exponent: float | sympy.Symbol) -> EigenGate:
295+
def __pow__(self, exponent: value.TParamVal) -> EigenGate:
296+
if not isinstance(exponent, (numbers.Number, sympy.Expr)):
297+
raise TypeError(
298+
"Gate exponent must be a number or sympy expression. "
299+
f"Invalid type: {type(exponent).__name__!r}"
300+
)
290301
new_exponent = protocols.mul(self._exponent, exponent, NotImplemented)
291302
if new_exponent is NotImplemented:
292303
return NotImplemented # pragma: no cover

cirq-core/cirq/ops/eigen_gate_test.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,10 +248,19 @@ def test_pow() -> None:
248248
assert ZGateDef(exponent=0.25, global_shift=0.5) ** 2 == ZGateDef(
249249
exponent=0.5, global_shift=0.5
250250
)
251-
with pytest.raises(ValueError, match="real"):
251+
with pytest.raises(ValueError, match="Gate exponent must be real."):
252252
assert ZGateDef(exponent=0.5) ** 0.5j
253253
assert ZGateDef(exponent=0.5) ** (1 + 0j) == ZGateDef(exponent=0.5)
254254

255+
with pytest.raises(TypeError, match="Gate exponent must be a number or sympy expression."):
256+
assert ZGateDef(exponent=0.5) ** "text"
257+
258+
with pytest.raises(TypeError, match="Gate exponent must be a number or sympy expression."):
259+
assert ZGateDef(exponent="text")
260+
261+
with pytest.raises(TypeError, match="Gate exponent must be a number or sympy expression."):
262+
assert ZGateDef(exponent=sympy.Symbol('a')) ** "text"
263+
255264

256265
def test_inverse() -> None:
257266
assert cirq.inverse(CExpZinGate(0.25)) == CExpZinGate(-0.25)

cirq-core/cirq/ops/pauli_gates_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,3 +225,10 @@ def test_powers() -> None:
225225
assert isinstance(cirq.X**1, cirq.Pauli)
226226
assert isinstance(cirq.Y**1, cirq.Pauli)
227227
assert isinstance(cirq.Z**1, cirq.Pauli)
228+
229+
with pytest.raises(TypeError, match="Gate exponent must be a number or sympy expression."):
230+
assert cirq.X ** 'text'
231+
with pytest.raises(TypeError, match="Gate exponent must be a number or sympy expression."):
232+
assert cirq.Y ** 'text'
233+
with pytest.raises(TypeError, match="Gate exponent must be a number or sympy expression."):
234+
assert cirq.Z ** 'text'

cirq-core/cirq/sim/simulator_base_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,9 @@ def test_noise_applied_measurement_gate():
222222
def test_parameterized_copies_all_but_last():
223223
sim = CountingSimulator()
224224
n = 4
225-
rs = sim.simulate_sweep(cirq.Circuit(cirq.X(q0) ** 'a'), [{'a': i} for i in range(n)])
225+
rs = sim.simulate_sweep(
226+
cirq.Circuit(cirq.X(q0) ** sympy.Symbol('a')), [{'a': i} for i in range(n)]
227+
)
226228
for i in range(n):
227229
r = rs[i]
228230
assert r._final_simulator_state.gate_count == 1

0 commit comments

Comments
 (0)