47
47
import numpy as np
48
48
import opt_einsum as oe
49
49
from numpy import linalg as nla
50
- from numpy .core import ndarray
51
50
from scipy import linalg as sla
52
51
from sparse import COO
53
52
56
55
__all__ = ['Basis' , 'expand' , 'ggm_expand' , 'normalize' ]
57
56
58
57
59
- class Basis (ndarray ):
58
+ class Basis (np . ndarray ):
60
59
r"""
61
60
Class for operator bases. There are several ways to instantiate a
62
61
Basis object:
@@ -217,22 +216,26 @@ def __eq__(self, other: object) -> bool:
217
216
# Not ndarray
218
217
return np .equal (self , other )
219
218
220
- return np .allclose (self .view (ndarray ), other .view (ndarray ),
219
+ return np .allclose (self .view (np . ndarray ), other .view (np . ndarray ),
221
220
atol = self ._atol , rtol = self ._rtol )
222
221
223
- def __contains__ (self , item : ndarray ) -> bool :
222
+ def __contains__ (self , item : np . ndarray ) -> bool :
224
223
"""Implement 'in' operator."""
225
- return any (np .isclose (item .view (ndarray ), self .view (ndarray ),
224
+ return any (np .isclose (item .view (np . ndarray ), self .view (np . ndarray ),
226
225
rtol = self ._rtol , atol = self ._atol ).all (axis = (1 , 2 )))
227
226
228
- def __array_wrap__ (self , out_arr , context = None ):
227
+ def __array_wrap__ (self , arr , context = None , return_scalar = False ):
229
228
"""
230
229
Fixes problem that ufuncs return 0-d arrays instead of scalars.
231
230
232
231
https://github.com/numpy/numpy/issues/5819#issue-72454838
233
232
"""
234
- if out_arr .ndim :
235
- return ndarray .__array_wrap__ (self , out_arr , context )
233
+ try :
234
+ return super ().__array_wrap__ (arr , context , return_scalar = True )
235
+ except TypeError :
236
+ if arr .ndim :
237
+ # Numpy < 2
238
+ return np .ndarray .__array_wrap__ (self , arr , context )
236
239
237
240
def _print_checks (self ) -> None :
238
241
"""Print checks for debug purposes."""
@@ -265,7 +268,7 @@ def isorthonorm(self) -> bool:
265
268
actual = U .conj () @ U .T
266
269
target = np .identity (dim )
267
270
atol = self ._eps * (self .d ** 2 )** 3
268
- self ._isorthonorm = np .allclose (actual .view (ndarray ), target ,
271
+ self ._isorthonorm = np .allclose (actual .view (np . ndarray ), target ,
269
272
atol = atol , rtol = self ._rtol )
270
273
271
274
return self ._isorthonorm
@@ -278,13 +281,16 @@ def istraceless(self) -> bool:
278
281
if self ._istraceless is None :
279
282
trace = np .einsum ('...jj' , self )
280
283
trace = util .remove_float_errors (trace , self .d ** 2 )
281
- nonzero = trace .nonzero ()
284
+ nonzero = np . atleast_1d ( trace ) .nonzero ()
282
285
if nonzero [0 ].size == 0 :
283
286
self ._istraceless = True
284
287
elif nonzero [0 ].size == 1 :
285
288
# Single element has nonzero trace, check if (proportional to)
286
289
# identity
287
- elem = self [nonzero ][0 ].view (ndarray ) if self .ndim == 3 else self .view (ndarray )
290
+ if self .ndim == 3 :
291
+ elem = self [nonzero ][0 ].view (np .ndarray )
292
+ else :
293
+ elem = self .view (np .ndarray )
288
294
offdiag_nonzero = elem [~ np .eye (self .d , dtype = bool )].nonzero ()
289
295
diag_equal = np .diag (elem ) == elem [0 , 0 ]
290
296
if diag_equal .all () and not offdiag_nonzero [0 ].any ():
@@ -597,7 +603,7 @@ def _full_from_partial(elems: Sequence, traceless: bool, labels: Sequence[str])
597
603
# sort Identity label to the front, default to first if not found
598
604
# (should not happen since traceless checks that it is present)
599
605
id_idx = next ((i for i , elem in enumerate (elems )
600
- if np .allclose (Id .view (ndarray ), elem .view (ndarray ),
606
+ if np .allclose (Id .view (np . ndarray ), elem .view (np . ndarray ),
601
607
rtol = elems ._rtol , atol = elems ._atol )), 0 )
602
608
labels .insert (0 , labels .pop (id_idx ))
603
609
@@ -606,7 +612,7 @@ def _full_from_partial(elems: Sequence, traceless: bool, labels: Sequence[str])
606
612
return basis , labels
607
613
608
614
609
- def _norm (b : Sequence ) -> ndarray :
615
+ def _norm (b : Sequence ) -> np . ndarray :
610
616
"""Frobenius norm with two singleton dimensions inserted at the end."""
611
617
b = np .asanyarray (b )
612
618
norm = nla .norm (b , axis = (- 1 , - 2 ))
@@ -633,8 +639,8 @@ def normalize(b: Basis) -> Basis:
633
639
return (b / _norm (b )).squeeze ().view (Basis )
634
640
635
641
636
- def expand (M : Union [ndarray , Basis ], basis : Union [ndarray , Basis ],
637
- normalized : bool = True , hermitian : bool = False , tidyup : bool = False ) -> ndarray :
642
+ def expand (M : Union [np . ndarray , Basis ], basis : Union [np . ndarray , Basis ],
643
+ normalized : bool = True , hermitian : bool = False , tidyup : bool = False ) -> np . ndarray :
638
644
r"""
639
645
Expand the array *M* in the basis given by *basis*.
640
646
@@ -684,8 +690,8 @@ def cast(arr):
684
690
return util .remove_float_errors (coefficients ) if tidyup else coefficients
685
691
686
692
687
- def ggm_expand (M : Union [ndarray , Basis ], traceless : bool = False ,
688
- hermitian : bool = False ) -> ndarray :
693
+ def ggm_expand (M : Union [np . ndarray , Basis ], traceless : bool = False ,
694
+ hermitian : bool = False ) -> np . ndarray :
689
695
r"""
690
696
Expand the matrix *M* in a Generalized Gell-Mann basis [Bert08]_.
691
697
This function makes use of the explicit construction prescription of
@@ -767,7 +773,7 @@ def cast(arr):
767
773
return coeffs .squeeze () if square else coeffs
768
774
769
775
770
- def equivalent_pauli_basis_elements (idx : Union [Sequence [int ], int ], N : int ) -> ndarray :
776
+ def equivalent_pauli_basis_elements (idx : Union [Sequence [int ], int ], N : int ) -> np . ndarray :
771
777
"""
772
778
Get the indices of the equivalent (up to identities tensored to it)
773
779
basis elements of Pauli bases of qubits at position idx in the total
@@ -780,7 +786,7 @@ def equivalent_pauli_basis_elements(idx: Union[Sequence[int], int], N: int) -> n
780
786
return elem_idx
781
787
782
788
783
- def remap_pauli_basis_elements (order : Sequence [int ], N : int ) -> ndarray :
789
+ def remap_pauli_basis_elements (order : Sequence [int ], N : int ) -> np . ndarray :
784
790
"""
785
791
For a N-qubit Pauli basis, transpose the order of the subsystems and
786
792
return the indices that permute the old basis to the new.
0 commit comments