56
56
function ChainRulesCore. rrule (:: typeof (TensorKit. eig!), t:: AbstractTensorMap ; kwargs... )
57
57
D, V = eig (t; kwargs... )
58
58
59
- function eig!_pullback ((ΔD, ΔV))
59
+ function eig!_pullback ((_ΔD, _ΔV))
60
+ ΔD, ΔV = unthunk (_ΔD), unthunk (_ΔV)
60
61
Δt = similar (t)
61
62
for (c, b) in blocks (Δt)
62
63
Dc, Vc = block (D, c), block (V, c)
77
78
function ChainRulesCore. rrule (:: typeof (TensorKit. eigh!), t:: AbstractTensorMap ; kwargs... )
78
79
D, V = eigh (t; kwargs... )
79
80
80
- function eigh!_pullback ((ΔD, ΔV))
81
+ function eigh!_pullback ((_ΔD, _ΔV))
82
+ ΔD, ΔV = unthunk (_ΔD), unthunk (_ΔV)
81
83
Δt = similar (t)
82
84
for (c, b) in blocks (Δt)
83
85
Dc, Vc = block (D, c), block (V, c)
@@ -114,7 +116,8 @@ function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRp
114
116
alg isa TensorKit. QR || alg isa TensorKit. QRpos ||
115
117
error (" only `alg=QR()` and `alg=QRpos()` are supported" )
116
118
Q, R = leftorth (t; alg)
117
- function leftorth!_pullback ((ΔQ, ΔR))
119
+ function leftorth!_pullback ((_ΔQ, _ΔR))
120
+ ΔQ, ΔR = unthunk (_ΔQ), unthunk (_ΔR)
118
121
Δt = similar (t)
119
122
for (c, b) in blocks (Δt)
120
123
qr_pullback! (b, block (Q, c), block (R, c), block (ΔQ, c), block (ΔR, c))
@@ -129,7 +132,8 @@ function ChainRulesCore.rrule(::typeof(rightorth!), t::AbstractTensorMap; alg=LQ
129
132
alg isa TensorKit. LQ || alg isa TensorKit. LQpos ||
130
133
error (" only `alg=LQ()` and `alg=LQpos()` are supported" )
131
134
L, Q = rightorth (t; alg)
132
- function rightorth!_pullback ((ΔL, ΔQ))
135
+ function rightorth!_pullback ((_ΔL, _ΔQ))
136
+ ΔL, ΔQ = unthunk (_ΔL), unthunk (_ΔQ)
133
137
Δt = similar (t)
134
138
for (c, b) in blocks (Δt)
135
139
lq_pullback! (b, block (L, c), block (Q, c), block (ΔL, c), block (ΔQ, c))
0 commit comments