Skip to content

Commit 660bdf7

Browse files
authored
add unthunk in rrules of eig! eigh! leftorth! rightorth! (#207)
* add in s * add unthunk in factorization rrules * Apply suggestions from code review
1 parent 0f352ce commit 660bdf7

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

ext/TensorKitChainRulesCoreExt/factorizations.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ end
5656
function ChainRulesCore.rrule(::typeof(TensorKit.eig!), t::AbstractTensorMap; kwargs...)
5757
D, V = eig(t; kwargs...)
5858

59-
function eig!_pullback((ΔD, ΔV))
59+
function eig!_pullback((_ΔD, _ΔV))
60+
ΔD, ΔV = unthunk(_ΔD), unthunk(_ΔV)
6061
Δt = similar(t)
6162
for (c, b) in blocks(Δt)
6263
Dc, Vc = block(D, c), block(V, c)
@@ -77,7 +78,8 @@ end
7778
function ChainRulesCore.rrule(::typeof(TensorKit.eigh!), t::AbstractTensorMap; kwargs...)
7879
D, V = eigh(t; kwargs...)
7980

80-
function eigh!_pullback((ΔD, ΔV))
81+
function eigh!_pullback((_ΔD, _ΔV))
82+
ΔD, ΔV = unthunk(_ΔD), unthunk(_ΔV)
8183
Δt = similar(t)
8284
for (c, b) in blocks(Δt)
8385
Dc, Vc = block(D, c), block(V, c)
@@ -114,7 +116,8 @@ function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRp
114116
alg isa TensorKit.QR || alg isa TensorKit.QRpos ||
115117
error("only `alg=QR()` and `alg=QRpos()` are supported")
116118
Q, R = leftorth(t; alg)
117-
function leftorth!_pullback((ΔQ, ΔR))
119+
function leftorth!_pullback((_ΔQ, _ΔR))
120+
ΔQ, ΔR = unthunk(_ΔQ), unthunk(_ΔR)
118121
Δt = similar(t)
119122
for (c, b) in blocks(Δt)
120123
qr_pullback!(b, block(Q, c), block(R, c), block(ΔQ, c), block(ΔR, c))
@@ -129,7 +132,8 @@ function ChainRulesCore.rrule(::typeof(rightorth!), t::AbstractTensorMap; alg=LQ
129132
alg isa TensorKit.LQ || alg isa TensorKit.LQpos ||
130133
error("only `alg=LQ()` and `alg=LQpos()` are supported")
131134
L, Q = rightorth(t; alg)
132-
function rightorth!_pullback((ΔL, ΔQ))
135+
function rightorth!_pullback((_ΔL, _ΔQ))
136+
ΔL, ΔQ = unthunk(_ΔL), unthunk(_ΔQ)
133137
Δt = similar(t)
134138
for (c, b) in blocks(Δt)
135139
lq_pullback!(b, block(L, c), block(Q, c), block(ΔL, c), block(ΔQ, c))

0 commit comments

Comments
 (0)