Skip to content

Commit a4df799

Browse files
committed
Correct behavior
1 parent 21ef967 commit a4df799

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

src/compiler/chainrules.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ z2d(::Tuple{Vararg{Nothing}}, ::Tuple) = NoTangent() # collapse all-zero case
294294
z2d(dx, ::Any) = dx
295295
z2d(dx::AbstractArray{<:Number}, primal::AbstractArray) = dx
296296
z2d(dx::AbstractArray{<:AbstractArray{<:Number}}, primal::AbstractArray) = dx
297-
z2d(dx::AbstractArray, primal::AbstractArray) = isempty(dx) ? NoTangent() : map(Zygote.z2d, dx, primal)
297+
z2d(dx::AbstractArray, primal::AbstractArray) = isempty(dx) ? dx : map(Zygote.z2d, dx, primal)
298298

299299
#=
300300
# As an optimisation, we can convert by `reinterpret` for bitstypes, e.g. arrays of tuples of numbers

test/chainrules_tests.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -415,8 +415,9 @@ end
415415
@test z2d_compiled.c.a === z2d_fallback.c.a
416416
@test z2d_compiled.c.b === z2d_fallback.c.b
417417

418-
# empty arrays => NoTangent()
419-
@test @inferred(Zygote.z2d(ones(1, 0), ones(16, 0))) === NoTangent()
418+
# empty dx => returns the dx
419+
@test @inferred(Zygote.z2d(ones(1, 0), ones(16, 0))) === ones(1, 0)
420+
@test @inferred(Zygote.z2d(Union{Nothing, Float64}[], ones(16, 0))) === Union{Nothing, Float64}[]
420421
end
421422

422423
@testset "ChainRules translation" begin

0 commit comments

Comments
 (0)