diff --git a/Project.toml b/Project.toml index 3f76e35b59..0d38419a99 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.2.123" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +Bijections = "e2ed5e7c-b2de-5872-ae92-c73ca462fb04" CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6" EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" @@ -65,6 +66,7 @@ ReactantYaoBlocksExt = "YaoBlocks" AbstractFFTs = "1.5" Adapt = "4.1" ArrayInterface = "7.17.1" +Bijections = "0.2.1" CEnum = "0.5" CUDA = "5.6" Downloads = "1.6" diff --git a/src/Compiler.jl b/src/Compiler.jl index a38f8069b2..8258a164c6 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -3,6 +3,7 @@ module Compiler using Reactant_jll using Libdl: dlsym using LinearAlgebra: BLAS +using Bijections import ..Reactant: Reactant, @@ -32,8 +33,8 @@ const DEBUG_ALIASED_BUFFER_ASSIGNMENT_ERROR = Ref(false) const DEBUG_BUFFER_POINTERS_STORE_DICT = Base.IdDict() -@inline function traced_getfield(@nospecialize(obj::Dict), field) - return Base.getindex(obj, field) +@inline function traced_getfield(@nospecialize(obj::AbstractDict), idx) + return first(Iterators.drop(obj, idx - 1)) end @inline function traced_getfield(@nospecialize(obj), field) @@ -109,7 +110,7 @@ end return setfield_carray!(obj, field, val, path) end -@inline function traced_setfield!(@nospecialize(obj::Dict), field, val, path) +@inline function traced_setfield!(@nospecialize(obj::AbstractDict), field, val, path) return Base.setindex!(obj, field, val) end @@ -635,12 +636,15 @@ function create_result( result_cache[tocopy] = sym - for (k, v) in pairs(tocopy) - subexpr = create_result(v, append_path(path, k), args...) + for (i, (k, v)) in enumerate(pairs(tocopy)) + path_k = append_path(append_path(path, i), 1) + k_expr = create_result(k, path_k, args...) + path_v = append_path(append_path(path, i), 2) + v_expr = create_result(v, path_v, args...) push!( resultgen_code, quote - @inbounds $sym[$k] = $subexpr + @inbounds $sym[$k_expr] = $v_expr end, ) end @@ -3284,8 +3288,7 @@ function compile_xla(f, args; client=nothing, serializable::Bool=false, kwargs.. end # inspired by RuntimeGeneratedFunction.jl -const __thunk_fwd_body_cache = Dict{Symbol,Expr}() -const __thunk_rev_body_cache = Dict{Expr,Symbol}() +const __thunk_body_cache = Bijection{Symbol,Expr}() function compile(f, args; sync=false, kwargs...) _, exec, mlir_fn_res, device, client, str = compile_xla(f, args; kwargs...) @@ -3397,12 +3400,11 @@ function compile(f, args; sync=false, kwargs...) display(mlir_fn_res.donated_args_mask) end - fname = if body in keys(__thunk_rev_body_cache) - __thunk_rev_body_cache[body] + fname = if hasvalue(__thunk_body_cache, body) + __thunk_body_cache(body) else fname2 = gensym(Symbol(Symbol(f), :_reactant)) - __thunk_rev_body_cache[body] = fname2 - __thunk_fwd_body_cache[fname2] = body + __thunk_body_cache[fname2] = body fname2 end @@ -3484,7 +3486,7 @@ end ) end end - body = __thunk_fwd_body_cache[tag] + body = __thunk_body_cache[tag] if IsClosure return quote args = (thunk.f, args...) diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 7c401db712..d9bd0e732d 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -503,9 +503,13 @@ function prepare_mlir_fn_args( end aval = args[path[2]] for (cidx, idx) in enumerate(path[3:end]) - if aval isa Array || aval isa Dict + if aval isa Array #|| aval isa Dict aval = getindex(aval, idx) stridx = stridx * "[" * string(idx) * "]" + elseif aval isa AbstractDict + # TODO maybe we want a way to customize this behavior like we did with `traced_getfield`? or a more powerfull `traced_getfield`? + aval = Reactant.Compiler.traced_getfield(aval, idx) + stridx = stridx * "[" * string(idx) * "]" else fldname = if idx isa Integer string(fieldname(Core.Typeof(aval), idx)) diff --git a/src/Tracing.jl b/src/Tracing.jl index aa39f08141..be9c9f0b8c 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -1,3 +1,5 @@ +using Bijections + @enum TraceMode begin ConcreteToTraced = 1 TracedTrack = 2 @@ -214,28 +216,99 @@ Base.@nospecializeinfer function traced_type_inner( @nospecialize(sharding), @nospecialize(runtime) ) + K = dict_key(T) V = dict_value(T) - if V === nothing + + K_traced = if !isnothing(K) + traced_type_inner(K, seen, mode, track_numbers, sharding, runtime) + else + nothing + end + V_traced = if !isnothing(V) + traced_type_inner(V, seen, mode, track_numbers, sharding, runtime) + else + nothing + end + + if K == K_traced && V == V_traced return T + end + + dictty = if T isa UnionAll + T.body.name.wrapper else - K = dict_key(T) - V2 = traced_type_inner(V, seen, mode, track_numbers, sharding, runtime) - if V == V2 - return T - end - dictty = if T isa UnionAll - T.body.name.wrapper - else - T.name.wrapper - end - if K !== nothing - return dictty{K,V2} - else - return (dictty{KT,V2} where {KT}) - end + T.name.wrapper + end + + if isnothing(K_traced) && isnothing(V_traced) + return (dictty{Kt,Vt} where {Kt,Vt}) + elseif isnothing(K_traced) + return (dictty{Kt,V_traced} where {Kt}) + elseif isnothing(V_traced) + return (dictty{K_traced,Vt} where {Vt}) + else + return dictty{K_traced,V_traced} end end +Base.@nospecializeinfer @inline bijection_fwd_type(::Type{<:Bijection}) = nothing +Base.@nospecializeinfer @inline function bijection_fwd_type( + ::Type{<:(Bijection{K,V,F} where {K,V})} +) where {F} + return F +end +Base.@nospecializeinfer @inline bijection_bwd_type(::Type{<:Bijection}) = nothing +Base.@nospecializeinfer @inline function bijection_bwd_type( + ::Type{<:(Bijection{K,V,F,Finv} where {K,V,F})} +) where {Finv} + return Finv +end + +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:Bijection}), + seen, + @nospecialize(mode::TraceMode), + @nospecialize(track_numbers::Type), + @nospecialize(sharding), + @nospecialize(runtime) +) + B = Bijection + + K = dict_key(T) + if !isnothing(K) + K = traced_type_inner(K, seen, mode, track_numbers, sharding, runtime) + B = B{K} + else + B = (B{Kt} where {Kt}) + end + + V = dict_value(T) + if !isnothing(V) + V = traced_type_inner(V, seen, mode, track_numbers, sharding, runtime) + B = B{V} + else + B = (B{Vt} where {Vt}) + end + + F = bijection_fwd_type(T) + if !isnothing(F) + F = traced_type_inner(F, seen, mode, track_numbers, sharding, runtime) + B = B{F} + else + B = (B{Ft} where {Ft}) + end + + Finv = bijection_bwd_type(T) + if !isnothing(Finv) + Finv = traced_type_inner(Finv, seen, mode, track_numbers, sharding, runtime) + B = B{Finv} + else + B = (B{Finvt} where {Finvt}) + end + + return B +end + Base.@nospecializeinfer function traced_type_inner( @nospecialize(T0::Type{<:ConcretePJRTNumber}), seen, @@ -1581,14 +1654,14 @@ end Base.@nospecializeinfer function make_tracer( seen, - @nospecialize(prev::Dict{Key,Value}), + prev::D, @nospecialize(path), mode; @nospecialize(track_numbers::Type = Union{}), @nospecialize(sharding = Sharding.NoSharding()), @nospecialize(runtime = nothing), kwargs..., -) where {Key,Value} +) where {D<:AbstractDict} RT = Core.Typeof(prev) if mode != NoStopTracedTrack && haskey(seen, prev) if mode == TracedToTypes @@ -1619,31 +1692,41 @@ Base.@nospecializeinfer function make_tracer( end return nothing end - Value2 = traced_type(Value, Val(mode), track_numbers, sharding, runtime) - newa = Dict{Key,Value2}() - seen[prev] = newa + Dt = traced_type(D, Val(mode), track_numbers, sharding, runtime) + dict_traced = Dt() + seen[prev] = dict_traced same = true - for (k, v) in prev - nv = make_tracer( + for (i, (k, v)) in enumerate(prev) + kt = make_tracer( + seen, + k, + append_path(append_path(path, i), 1), + mode; + track_numbers, + sharding=Base.getproperty(sharding, k), + runtime, + kwargs..., + ) + vt = make_tracer( seen, v, - append_path(path, k), + append_path(append_path(path, i), 2), mode; track_numbers, sharding=Base.getproperty(sharding, k), runtime, kwargs..., ) - if v !== nv + if k !== kt || v !== vt same = false end - newa[k] = nv + dict_traced[kt] = vt end if same seen[prev] = prev return prev end - return newa + return dict_traced end Base.@nospecializeinfer function make_tracer( diff --git a/test/Project.toml b/test/Project.toml index befb0973bd..695c09c77b 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,7 @@ [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +Bijections = "e2ed5e7c-b2de-5872-ae92-c73ca462fb04" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" diff --git a/test/basic.jl b/test/basic.jl index 4a108beea6..8c7dae34b1 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -4,61 +4,62 @@ using Enzyme using Statistics using Random Random.seed!(123) +using Bijections fastmax(x::AbstractArray{T}) where {T} = reduce(max, x; dims=1, init=float(T)(-Inf)) using InteractiveUtils @testset "2D sum" begin - x = rand(2, 10) + x = rand(2, 10) - r_res = sum(x) + r_res = sum(x) - a = Reactant.to_rarray(x) + a = Reactant.to_rarray(x) - c_res = @allowscalar sum(a) - @test c_res ≈ r_res + c_res = @allowscalar sum(a) + @test c_res ≈ r_res - @test @jit(sum(a)) ≈ r_res + @test @jit(sum(a)) ≈ r_res end @testset "Julia Compilation cache" begin - x = @compile -(Reactant.to_rarray(ones(2))) - y = @compile -(Reactant.to_rarray(ones(2))) + x = @compile -(Reactant.to_rarray(ones(2))) + y = @compile -(Reactant.to_rarray(ones(2))) - @test typeof(x) == typeof(y) - # TODO, currently x and y are not equal as x.exec != y.exec - # as the executable we generate is itself not cached - # (which clearly we should do to improve jit time) + @test typeof(x) == typeof(y) + # TODO, currently x and y are not equal as x.exec != y.exec + # as the executable we generate is itself not cached + # (which clearly we should do to improve jit time) end @testset "Basic reduce max" begin - x = rand(2, 10) + x = rand(2, 10) - r_res = fastmax(x) + r_res = fastmax(x) - a = Reactant.to_rarray(x) + a = Reactant.to_rarray(x) - c_res = @allowscalar fastmax(a) - @test c_res ≈ r_res + c_res = @allowscalar fastmax(a) + @test c_res ≈ r_res - @test @jit(fastmax(a)) ≈ r_res + @test @jit(fastmax(a)) ≈ r_res end sinexp(x) = sin(exp(x)) sinexpbc(x) = sinexp.(x) @testset "Broadcast combined" begin - x = rand(2, 10) + x = rand(2, 10) - r_res = sinexpbc(x) + r_res = sinexpbc(x) - a = Reactant.to_rarray(x) + a = Reactant.to_rarray(x) - c_res = @allowscalar sinexpbc(a) - @test c_res isa ConcreteRArray - @test c_res ≈ r_res - @test @jit(sinexpbc(a)) ≈ r_res + c_res = @allowscalar sinexpbc(a) + @test c_res isa ConcreteRArray + @test c_res ≈ r_res + @test @jit(sinexpbc(a)) ≈ r_res end sumexp(x) = sum(exp, x) @@ -66,1274 +67,1294 @@ sumexp(x) = sum(exp, x) sum_compare(x) = sum(x) > 0 @testset "Basic mapreduce" begin - x = rand(Float32, 10) - a = Reactant.to_rarray(x) - r_res = sumexp(x) + x = rand(Float32, 10) + a = Reactant.to_rarray(x) + r_res = sumexp(x) - f_res = @jit sumexp(a) + f_res = @jit sumexp(a) - @test f_res ≈ r_res + @test f_res ≈ r_res - # Ensure we are tracing as scalars. Else this will fail due to > not being defined on - # arrays - @test @jit(sum_compare(a)) == sum_compare(x) + # Ensure we are tracing as scalars. Else this will fail due to > not being defined on + # arrays + @test @jit(sum_compare(a)) == sum_compare(x) end function mysoftmax!(x) - max_ = fastmax(x) - return x .- max_ + max_ = fastmax(x) + return x .- max_ end @testset "Basic softmax" begin - x = rand(2, 10) - r_res = mysoftmax!(x) + x = rand(2, 10) + r_res = mysoftmax!(x) - a = Reactant.to_rarray(x) + a = Reactant.to_rarray(x) - f_res = @jit mysoftmax!(a) - @test f_res ≈ r_res + f_res = @jit mysoftmax!(a) + @test f_res ≈ r_res end bcast_cos(x) = cos.(x) @testset "Basic cos" begin - x = rand(3, 2) - c = Reactant.to_rarray(x) + x = rand(3, 2) + c = Reactant.to_rarray(x) - @test @jit(bcast_cos(c)) ≈ cos.(x) + @test @jit(bcast_cos(c)) ≈ cos.(x) end f_var(args...) = sum(args) @testset "Vararg" begin - x = Reactant.to_rarray(ones(3)) - y = Reactant.to_rarray(3 * ones(3)) - z = Reactant.to_rarray(2.6 * ones(3)) + x = Reactant.to_rarray(ones(3)) + y = Reactant.to_rarray(3 * ones(3)) + z = Reactant.to_rarray(2.6 * ones(3)) - @test @jit(f_var(x, y, z)) ≈ [6.6, 6.6, 6.6] + @test @jit(f_var(x, y, z)) ≈ [6.6, 6.6, 6.6] end sumcos(x) = sum(cos.(x)) function grad_ip(x) - dx = Enzyme.make_zero(x) - Enzyme.autodiff(Reverse, sumcos, Active, Duplicated(x, dx)) - return dx + dx = Enzyme.make_zero(x) + Enzyme.autodiff(Reverse, sumcos, Active, Duplicated(x, dx)) + return dx end function resgrad_ip(x) - dx = Enzyme.make_zero(x) - res = Enzyme.autodiff(ReverseWithPrimal, sumcos, Active, Duplicated(x, dx)) - return (res, dx) + dx = Enzyme.make_zero(x) + res = Enzyme.autodiff(ReverseWithPrimal, sumcos, Active, Duplicated(x, dx)) + return (res, dx) end @testset "Basic grad cos" begin - c = Reactant.to_rarray(ones(3, 2)) + c = Reactant.to_rarray(ones(3, 2)) - @test @jit(grad_ip(c)) ≈ -sin.(ones(3, 2)) + @test @jit(grad_ip(c)) ≈ -sin.(ones(3, 2)) - orig, r = @jit(resgrad_ip(c)) + orig, r = @jit(resgrad_ip(c)) - @test orig[2] ≈ sum(cos.(ones(3, 2))) - @test r ≈ -sin.(ones(3, 2)) + @test orig[2] ≈ sum(cos.(ones(3, 2))) + @test r ≈ -sin.(ones(3, 2)) end @testset "matmul" begin - c = Reactant.to_rarray(ones(50, 70)) - d = Reactant.to_rarray(ones(70, 30)) + c = Reactant.to_rarray(ones(50, 70)) + d = Reactant.to_rarray(ones(70, 30)) - @test @jit(*(c, d)) ≈ *(ones(50, 70), ones(70, 30)) + @test @jit(*(c, d)) ≈ *(ones(50, 70), ones(70, 30)) end @testset "similar Reactant.to_rarray" begin - c = Reactant.to_rarray(ones(50, 70)) - sim_c = similar(c) - @test typeof(sim_c) == typeof(c) && size(sim_c) == size(sim_c) + c = Reactant.to_rarray(ones(50, 70)) + sim_c = similar(c) + @test typeof(sim_c) == typeof(c) && size(sim_c) == size(sim_c) end @testset "@code_hlo" begin - W = Reactant.to_rarray(randn(Float32, 10, 20)) - x = Reactant.to_rarray(randn(Float32, 20, 5)) - res = @code_hlo W * x - res_repr = sprint(show, res) + W = Reactant.to_rarray(randn(Float32, 10, 20)) + x = Reactant.to_rarray(randn(Float32, 20, 5)) + res = @code_hlo W * x + res_repr = sprint(show, res) - @test contains(res_repr, "stablehlo.dot_general") + @test contains(res_repr, "stablehlo.dot_general") end @testset "@code_hlo broadcasting" begin - x = Reactant.to_rarray(randn(Float32, 2, 2)) - y = Reactant.to_rarray(randn(Float32, 2, 2)) - res = @code_hlo (.+)(x, y) - res_repr = sprint(show, res) + x = Reactant.to_rarray(randn(Float32, 2, 2)) + y = Reactant.to_rarray(randn(Float32, 2, 2)) + res = @code_hlo (.+)(x, y) + res_repr = sprint(show, res) - @test contains(res_repr, "stablehlo.add") + @test contains(res_repr, "stablehlo.add") end @testset "Statistics: `mean` & `var`" begin - x = randn(2, 3, 4) - x_ca = Reactant.to_rarray(x) + x = randn(2, 3, 4) + x_ca = Reactant.to_rarray(x) - @test @jit(mean(x_ca)) ≈ mean(x) - @test @jit(mean(x_ca; dims=1)) ≈ mean(x; dims=1) - @test @jit(mean(x_ca; dims=(1, 2))) ≈ mean(x; dims=(1, 2)) - @test @jit(mean(x_ca; dims=(1, 3))) ≈ mean(x; dims=(1, 3)) + @test @jit(mean(x_ca)) ≈ mean(x) + @test @jit(mean(x_ca; dims=1)) ≈ mean(x; dims=1) + @test @jit(mean(x_ca; dims=(1, 2))) ≈ mean(x; dims=(1, 2)) + @test @jit(mean(x_ca; dims=(1, 3))) ≈ mean(x; dims=(1, 3)) - @test @jit(var(x_ca)) ≈ var(x) - @test @jit(var(x_ca, dims=1)) ≈ var(x; dims=1) - @test @jit(var(x_ca, dims=(1, 2); corrected=false)) ≈ + @test @jit(var(x_ca)) ≈ var(x) + @test @jit(var(x_ca, dims=1)) ≈ var(x; dims=1) + @test @jit(var(x_ca, dims=(1, 2); corrected=false)) ≈ var(x; dims=(1, 2), corrected=false) - @test @jit(var(x_ca; dims=(1, 3), corrected=false)) ≈ + @test @jit(var(x_ca; dims=(1, 3), corrected=false)) ≈ var(x; dims=(1, 3), corrected=false) end @testset "concatenation" begin - @testset "Number" begin - x = fill(true) - x_concrete = Reactant.to_rarray(x) - - # NOTE [,,,] is a call to `vect`, not `*cat` - # f = Reactant.compile((x_concrete,)) do x - # return [x, x, x] - # end - # @test f(x_concrete) ≈ ones(3) - - # vcat - test_vcat(x) = begin - x = x[] # unwrap scalar - [x; x; x] - end - y = @jit test_vcat(x_concrete) - @test y == test_vcat(x) - @test eltype(y) === Bool - - # hcat - test_hcat(x) = begin - x = x[] # unwrap scalar - [x x x] - end - y = @jit test_hcat(x_concrete) - @test y == test_hcat(x) - @test eltype(y) === Bool - - # hvcat - test_hvcat(x) = begin - x = x[] # unwrap scalar - [x x x; x x x] - end - y = @jit test_hvcat(x_concrete) - @test y == test_hvcat(x) - @test eltype(y) === Bool - - # hvncat - test_hvncat(x) = begin - x = x[] # unwrap scalar - [x x x; x x x;;; x x x; x x x] - end - y = @jit test_hvncat(x_concrete) - @test y == test_hvncat(x) - @test eltype(y) === Bool - - # typed_vcat - test_typed_vcat(x) = begin - x = x[] # unwrap scalar - Int[x; x; x] - end - y = @jit test_typed_vcat(x_concrete) - @test y == test_typed_vcat(x) - @test eltype(y) === Int - - # typed_hcat - test_typed_hcat(x) = begin - x = x[] # unwrap scalar - Int[x x x] - end - y = @jit test_typed_hcat(x_concrete) - @test y == test_typed_hcat(x) - @test eltype(y) === Int - - # typed_hvcat - test_typed_hvcat(x) = begin - x = x[] # unwrap scalar - Int[x x x; x x x] - end - y = @jit test_typed_hvcat(x_concrete) - @test y == test_typed_hvcat(x) - @test eltype(y) === Int - - # typed_hvncat - test_typed_hvncat(x) = begin - x = x[] # unwrap scalar - Int[x x x; x x x;;; x x x; x x x] - end - y = @jit test_typed_hvncat(x_concrete) - @test y == test_typed_hvncat(x) - @test eltype(y) === Int + @testset "Number" begin + x = fill(true) + x_concrete = Reactant.to_rarray(x) + + # NOTE [,,,] is a call to `vect`, not `*cat` + # f = Reactant.compile((x_concrete,)) do x + # return [x, x, x] + # end + # @test f(x_concrete) ≈ ones(3) + + # vcat + test_vcat(x) = begin + x = x[] # unwrap scalar + [x; x; x] end - - @testset "$(ndims(x))-dim Array" for x in [ - fill(true), - [true, false], - [true false], - [true true; true false], - [ - true true true true; true true true false;;; - true true false true; true true false false;;; - true false true true; true false true false - ], - ] - x_concrete = Reactant.to_rarray(x) - - # NOTE [,,,] is a call to `vect`, not `*cat` - # f = Reactant.compile((x_concrete,)) do x - # return [x, x, x] - # end - # @test f(x_concrete) ≈ ones(3) - - # vcat - test_vcat(x) = [x; x; x] - y = @jit test_vcat(x_concrete) - @test y == test_vcat(x) - @test eltype(y) === Bool - - # hcat - test_hcat(x) = [x x x] - y = @jit test_hcat(x_concrete) - @test y == test_hcat(x) - @test eltype(y) === Bool - - # hvcat - test_hvcat(x) = [x x x; x x x] - y = @jit test_hvcat(x_concrete) - @test y == test_hvcat(x) - @test eltype(y) === Bool - - # hvncat - test_hvncat(x) = [x x x; x x x;;; x x x; x x x] - y = @jit test_hvncat(x_concrete) - @test y == test_hvncat(x) - @test eltype(y) === Bool - - # typed_vcat - test_typed_vcat(x) = Int[x; x; x] - y = @jit test_typed_vcat(x_concrete) - @test y == test_typed_vcat(x) - @test eltype(y) === Int - - # typed_hcat - test_typed_hcat(x) = Int[x x x] - y = @jit test_typed_hcat(x_concrete) - @test y == test_typed_hcat(x) - @test eltype(y) === Int - - # typed_hvcat - test_typed_hvcat(x) = Int[x x x; x x x] - y = @jit test_typed_hvcat(x_concrete) - @test y == test_typed_hvcat(x) - @test eltype(y) === Int - - # typed_hvncat - test_typed_hvncat(x) = Int[x x x; x x x;;; x x x; x x x] - y = @jit test_typed_hvncat(x_concrete) - @test y == test_typed_hvncat(x) - @test eltype(y) === Int + y = @jit test_vcat(x_concrete) + @test y == test_vcat(x) + @test eltype(y) === Bool + + # hcat + test_hcat(x) = begin + x = x[] # unwrap scalar + [x x x] end - - @testset "Number and RArray" for a in [1.0f0, 1.0e0] - typeof_a = typeof(a) - _b = typeof_a.([2.0, 3.0, 4.0]) - _c = typeof_a.([2.0 3.0 4.0]) - b = Reactant.to_rarray(_b) - c = Reactant.to_rarray(_c) - - # vcat test - y = @jit vcat(a, b) - @test y == vcat(a, _b) - @test y isa ConcreteRArray{typeof_a,1} - - ## vcat test - adjoint - y1 = @jit vcat(a, c') - @test y1 == vcat(a, _c') - @test y1 isa ConcreteRArray{typeof_a,2} - - # hcat test - z = @jit hcat(a, c) - @test z == hcat(a, _c) - @test z isa ConcreteRArray{typeof_a,2} - - ## hcat test - adjoint - z1 = @jit hcat(a, b') - @test z1 == hcat(a, _b') - @test z1 isa ConcreteRArray{typeof_a,2} + y = @jit test_hcat(x_concrete) + @test y == test_hcat(x) + @test eltype(y) === Bool + + # hvcat + test_hvcat(x) = begin + x = x[] # unwrap scalar + [x x x; x x x] + end + y = @jit test_hvcat(x_concrete) + @test y == test_hvcat(x) + @test eltype(y) === Bool + + # hvncat + test_hvncat(x) = begin + x = x[] # unwrap scalar + [x x x; x x x;;; x x x; x x x] + end + y = @jit test_hvncat(x_concrete) + @test y == test_hvncat(x) + @test eltype(y) === Bool + + # typed_vcat + test_typed_vcat(x) = begin + x = x[] # unwrap scalar + Int[x; x; x] + end + y = @jit test_typed_vcat(x_concrete) + @test y == test_typed_vcat(x) + @test eltype(y) === Int + + # typed_hcat + test_typed_hcat(x) = begin + x = x[] # unwrap scalar + Int[x x x] + end + y = @jit test_typed_hcat(x_concrete) + @test y == test_typed_hcat(x) + @test eltype(y) === Int + + # typed_hvcat + test_typed_hvcat(x) = begin + x = x[] # unwrap scalar + Int[x x x; x x x] end + y = @jit test_typed_hvcat(x_concrete) + @test y == test_typed_hvcat(x) + @test eltype(y) === Int + + # typed_hvncat + test_typed_hvncat(x) = begin + x = x[] # unwrap scalar + Int[x x x; x x x;;; x x x; x x x] + end + y = @jit test_typed_hvncat(x_concrete) + @test y == test_typed_hvncat(x) + @test eltype(y) === Int + end + + @testset "$(ndims(x))-dim Array" for x in [ + fill(true), + [true, false], + [true false], + [true true; true false], + [ + true true true true; true true true false;;; + true true false true; true true false false;;; + true false true true; true false true false + ], + ] + x_concrete = Reactant.to_rarray(x) + + # NOTE [,,,] is a call to `vect`, not `*cat` + # f = Reactant.compile((x_concrete,)) do x + # return [x, x, x] + # end + # @test f(x_concrete) ≈ ones(3) + + # vcat + test_vcat(x) = [x; x; x] + y = @jit test_vcat(x_concrete) + @test y == test_vcat(x) + @test eltype(y) === Bool + + # hcat + test_hcat(x) = [x x x] + y = @jit test_hcat(x_concrete) + @test y == test_hcat(x) + @test eltype(y) === Bool + + # hvcat + test_hvcat(x) = [x x x; x x x] + y = @jit test_hvcat(x_concrete) + @test y == test_hvcat(x) + @test eltype(y) === Bool + + # hvncat + test_hvncat(x) = [x x x; x x x;;; x x x; x x x] + y = @jit test_hvncat(x_concrete) + @test y == test_hvncat(x) + @test eltype(y) === Bool + + # typed_vcat + test_typed_vcat(x) = Int[x; x; x] + y = @jit test_typed_vcat(x_concrete) + @test y == test_typed_vcat(x) + @test eltype(y) === Int + + # typed_hcat + test_typed_hcat(x) = Int[x x x] + y = @jit test_typed_hcat(x_concrete) + @test y == test_typed_hcat(x) + @test eltype(y) === Int + + # typed_hvcat + test_typed_hvcat(x) = Int[x x x; x x x] + y = @jit test_typed_hvcat(x_concrete) + @test y == test_typed_hvcat(x) + @test eltype(y) === Int + + # typed_hvncat + test_typed_hvncat(x) = Int[x x x; x x x;;; x x x; x x x] + y = @jit test_typed_hvncat(x_concrete) + @test y == test_typed_hvncat(x) + @test eltype(y) === Int + end + + @testset "Number and RArray" for a in [1.0f0, 1.0e0] + typeof_a = typeof(a) + _b = typeof_a.([2.0, 3.0, 4.0]) + _c = typeof_a.([2.0 3.0 4.0]) + b = Reactant.to_rarray(_b) + c = Reactant.to_rarray(_c) + + # vcat test + y = @jit vcat(a, b) + @test y == vcat(a, _b) + @test y isa ConcreteRArray{typeof_a,1} + + ## vcat test - adjoint + y1 = @jit vcat(a, c') + @test y1 == vcat(a, _c') + @test y1 isa ConcreteRArray{typeof_a,2} + + # hcat test + z = @jit hcat(a, c) + @test z == hcat(a, _c) + @test z isa ConcreteRArray{typeof_a,2} + + ## hcat test - adjoint + z1 = @jit hcat(a, b') + @test z1 == hcat(a, _b') + @test z1 isa ConcreteRArray{typeof_a,2} + end end @testset "repeat" begin - @testset for (size, counts) in Iterators.product( - [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)], - [(), (1,), (2,), (2, 1), (1, 2), (2, 2), (2, 2, 2), (1, 1, 1, 1, 1)], - ) - x = rand(size...) - - @testset "outer repeat" begin - @test (@jit repeat(Reactant.to_rarray(x), counts...)) == repeat(x, counts...) - end + @testset for (size, counts) in Iterators.product( + [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)], + [(), (1,), (2,), (2, 1), (1, 2), (2, 2), (2, 2, 2), (1, 1, 1, 1, 1)], + ) + x = rand(size...) + + @testset "outer repeat" begin + @test (@jit repeat(Reactant.to_rarray(x), counts...)) == repeat(x, counts...) + end - length(counts) < length(size) && continue + length(counts) < length(size) && continue - @testset "inner repeat" begin - @test (@jit repeat(Reactant.to_rarray(x); inner=counts)) == - repeat(x; inner=counts) - end + @testset "inner repeat" begin + @test (@jit repeat(Reactant.to_rarray(x); inner=counts)) == + repeat(x; inner=counts) end + end end tuple_byref(x) = (; a=(; b=x)) tuple_byref2(x) = abs2.(x), tuple_byref(x) @testset "Tuple byref" begin - x = Reactant.to_rarray([1.0 -2.0; -3.0 4.0]) - @test @jit(tuple_byref(x)).a.b.data === x.data + x = Reactant.to_rarray([1.0 -2.0; -3.0 4.0]) + @test @jit(tuple_byref(x)).a.b.data === x.data - f2 = @compile tuple_byref2(x) - r2 = f2(x) - @test r2[2].a.b.data === x.data - @test r2[1] == abs2.([1.0 -2.0; -3.0 4.0]) + f2 = @compile tuple_byref2(x) + r2 = f2(x) + @test r2[2].a.b.data === x.data + @test r2[1] == abs2.([1.0 -2.0; -3.0 4.0]) end sum_xxᵀ(x) = sum(x .* x') @testset "sum(x .* x')" begin - @testset "size(x): $(size(x))" for x in (rand(4, 4), rand(4)) - x_ca = Reactant.to_rarray(x) + @testset "size(x): $(size(x))" for x in (rand(4, 4), rand(4)) + x_ca = Reactant.to_rarray(x) - @test @jit(sum_xxᵀ(x_ca)) ≈ sum_xxᵀ(x) - end + @test @jit(sum_xxᵀ(x_ca)) ≈ sum_xxᵀ(x) + end end @testset "similar" begin - x = zeros(2, 3) - y = Reactant.to_rarray(x) - f = @compile similar(y) - @test size(f(y)) == size(x) - @test eltype(f(y)) == eltype(x) + x = zeros(2, 3) + y = Reactant.to_rarray(x) + f = @compile similar(y) + @test size(f(y)) == size(x) + @test eltype(f(y)) == eltype(x) end @testset "Complex runtime: $CT" for CT in (ComplexF32, ComplexF64) - a = Reactant.to_rarray(ones(CT, 2)) - b = Reactant.to_rarray(ones(CT, 2)) - c = Reactant.compile(+, (a, b))(a, b) - @test c == ones(CT, 2) + ones(CT, 2) + a = Reactant.to_rarray(ones(CT, 2)) + b = Reactant.to_rarray(ones(CT, 2)) + c = Reactant.compile(+, (a, b))(a, b) + @test c == ones(CT, 2) + ones(CT, 2) end @testset "Scalars" begin - @testset "Only Scalars" begin - x = (3, 3.14) - - f1(x) = x[1] * x[2] - - x_ra = Reactant.to_rarray(x; track_numbers=Number) - f2 = @compile f1(x_ra) - @test f2(Reactant.to_rarray((5, 5.2); track_numbers=Number)) ≈ 5 * 5.2 - @test f2(Reactant.to_rarray((5, 5.2); track_numbers=Number)) isa ConcreteRNumber - - x_ra = Reactant.to_rarray(x) - f3 = @compile f1(x_ra) - @test f3(Reactant.to_rarray((5, 5.2))) ≈ f1(x) - @test !(f3(Reactant.to_rarray((5, 5.2))) isa ConcreteRNumber) - @test f3(Reactant.to_rarray((5, 5.2))) isa Number - - x_ra = Reactant.to_rarray(x; track_numbers=Int) - f4 = @compile f1(x_ra) - @test f4(Reactant.to_rarray((5, 5.2); track_numbers=Int)) ≈ 5 * 3.14 - @test f4(Reactant.to_rarray((5, 5.2); track_numbers=Int)) isa ConcreteRNumber - end + @testset "Only Scalars" begin + x = (3, 3.14) + + f1(x) = x[1] * x[2] - @testset "Mixed" begin - x = (3, [3.14]) + x_ra = Reactant.to_rarray(x; track_numbers=Number) + f2 = @compile f1(x_ra) + @test f2(Reactant.to_rarray((5, 5.2); track_numbers=Number)) ≈ 5 * 5.2 + @test f2(Reactant.to_rarray((5, 5.2); track_numbers=Number)) isa ConcreteRNumber - f1(x) = x[1] * x[2] + x_ra = Reactant.to_rarray(x) + f3 = @compile f1(x_ra) + @test f3(Reactant.to_rarray((5, 5.2))) ≈ f1(x) + @test !(f3(Reactant.to_rarray((5, 5.2))) isa ConcreteRNumber) + @test f3(Reactant.to_rarray((5, 5.2))) isa Number - x_ra = Reactant.to_rarray(x; track_numbers=Number) + x_ra = Reactant.to_rarray(x; track_numbers=Int) + f4 = @compile f1(x_ra) + @test f4(Reactant.to_rarray((5, 5.2); track_numbers=Int)) ≈ 5 * 3.14 + @test f4(Reactant.to_rarray((5, 5.2); track_numbers=Int)) isa ConcreteRNumber + end - f2 = @compile f1(x_ra) - res2 = f2(Reactant.to_rarray((5, [3.14]); track_numbers=Number)) - @test @allowscalar(only(res2)) ≈ 5 * 3.14 - @test res2 isa ConcreteRArray + @testset "Mixed" begin + x = (3, [3.14]) - x_ra = Reactant.to_rarray(x) + f1(x) = x[1] * x[2] - f3 = @compile f1(x_ra) - res3 = f3(Reactant.to_rarray((5, [3.14]))) - @test @allowscalar(only(res3)) ≈ only(f1(x)) - @test res3 isa ConcreteRArray - end + x_ra = Reactant.to_rarray(x; track_numbers=Number) + + f2 = @compile f1(x_ra) + res2 = f2(Reactant.to_rarray((5, [3.14]); track_numbers=Number)) + @test @allowscalar(only(res2)) ≈ 5 * 3.14 + @test res2 isa ConcreteRArray + + x_ra = Reactant.to_rarray(x) + + f3 = @compile f1(x_ra) + res3 = f3(Reactant.to_rarray((5, [3.14]))) + @test @allowscalar(only(res3)) ≈ only(f1(x)) + @test res3 isa ConcreteRArray + end end relu(x::T) where {T<:Number} = max(T(0), x) relu(x) = relu.(x) @testset "type casting" begin - x = randn(2, 10) - x_ra = Reactant.to_rarray(x) + x = randn(2, 10) + x_ra = Reactant.to_rarray(x) - @test @jit(relu(x_ra)) ≈ relu(x) + @test @jit(relu(x_ra)) ≈ relu(x) end @testset "concrete number to julia number" begin - x = Reactant.to_rarray(3.14; track_numbers=Number) - @test Float32(x) isa Float32 - @test Float64(x) isa Float64 - @test_throws InexactError Int(x) + x = Reactant.to_rarray(3.14; track_numbers=Number) + @test Float32(x) isa Float32 + @test Float64(x) isa Float64 + @test_throws InexactError Int(x) - x = Reactant.to_rarray(3; track_numbers=Number) - @test Float32(x) isa Float32 - @test Float64(x) isa Float64 - @test Int(x) isa Int - @test float(x) isa ConcreteRNumber{Float64} + x = Reactant.to_rarray(3; track_numbers=Number) + @test Float32(x) isa Float32 + @test Float64(x) isa Float64 + @test Int(x) isa Int + @test float(x) isa ConcreteRNumber{Float64} end @testset "concrete number with fill" begin - x = Reactant.to_rarray(10; track_numbers=Number) - x_ra = @jit fill(x, (10, 10)) - @test fill(x, (10, 10)) == Array(x_ra) + x = Reactant.to_rarray(10; track_numbers=Number) + x_ra = @jit fill(x, (10, 10)) + @test fill(x, (10, 10)) == Array(x_ra) end @testset "clamp" begin - x = randn(2, 3) - x_ra = Reactant.to_rarray(x) + x = randn(2, 3) + x_ra = Reactant.to_rarray(x) - y = @jit(clamp!(x_ra, 0.0, 0.25)) - @allowscalar begin - @test maximum(y) ≤ 0.25 - @test minimum(y) ≥ 0.0 - @test maximum(x_ra) == maximum(y) - @test minimum(x_ra) == minimum(y) - end + y = @jit(clamp!(x_ra, 0.0, 0.25)) + @allowscalar begin + @test maximum(y) ≤ 0.25 + @test minimum(y) ≥ 0.0 + @test maximum(x_ra) == maximum(y) + @test minimum(x_ra) == minimum(y) + end - x = randn(2, 3) - x_ra = Reactant.to_rarray(x) + x = randn(2, 3) + x_ra = Reactant.to_rarray(x) - y = @jit(clamp.(x_ra, 0.0, 0.25)) - @allowscalar begin - @test maximum(y) ≤ 0.25 - @test minimum(y) ≥ 0.0 - @test x_ra ≈ x - end + y = @jit(clamp.(x_ra, 0.0, 0.25)) + @allowscalar begin + @test maximum(y) ≤ 0.25 + @test minimum(y) ≥ 0.0 + @test x_ra ≈ x + end - x_ra = ConcreteRNumber(3.0) - y = @jit(clamp(x_ra, 0.0, 0.25)) - @test y isa ConcreteRNumber{Float64} + x_ra = ConcreteRNumber(3.0) + y = @jit(clamp(x_ra, 0.0, 0.25)) + @test y isa ConcreteRNumber{Float64} end @testset for op in [round, ceil, floor] - @testset "$(typeof(x)) : $(size(x))" for x in (rand(Float32, (3, 3)), rand(Float64)) - intop = Base.Fix1(op, Int) - x_ra = Reactant.to_rarray.(x; track_numbers=Number) + @testset "$(typeof(x)) : $(size(x))" for x in (rand(Float32, (3, 3)), rand(Float64)) + intop = Base.Fix1(op, Int) + x_ra = Reactant.to_rarray.(x; track_numbers=Number) - @test @jit(op.(x_ra)) ≈ op.(x) - @test @jit(intop.(x_ra)) ≈ intop.(x) - end + @test @jit(op.(x_ra)) ≈ op.(x) + @test @jit(intop.(x_ra)) ≈ intop.(x) + end end @testset "sign" begin - x = collect(Float64, 0:0.01:1) .- 0.5 - x_ra = Reactant.to_rarray(x) - @test Array(@jit(sign.(x_ra))) ≈ sign.(x) + x = collect(Float64, 0:0.01:1) .- 0.5 + x_ra = Reactant.to_rarray(x) + @test Array(@jit(sign.(x_ra))) ≈ sign.(x) end @testset "aos_to_soa" begin - using ArrayInterface + using ArrayInterface - x_res = collect(reshape(1.0:4.0, 2, 1, 2)) - x_ca = Reactant.to_rarray.(x_res; track_numbers=Number) + x_res = collect(reshape(1.0:4.0, 2, 1, 2)) + x_ca = Reactant.to_rarray.(x_res; track_numbers=Number) - y_ca1 = @allowscalar ArrayInterface.aos_to_soa(x_ca) - @test y_ca1 ≈ x_res - @test y_ca1 isa ConcreteRArray + y_ca1 = @allowscalar ArrayInterface.aos_to_soa(x_ca) + @test y_ca1 ≈ x_res + @test y_ca1 isa ConcreteRArray - y_ca2 = @jit(ArrayInterface.aos_to_soa(x_ca)) - @test y_ca2 ≈ x_res - @test y_ca2 isa ConcreteRArray + y_ca2 = @jit(ArrayInterface.aos_to_soa(x_ca)) + @test y_ca2 ≈ x_res + @test y_ca2 isa ConcreteRArray end @testset "collect" begin - x = randn(2, 3) - x_ra = Reactant.to_rarray(x) + x = randn(2, 3) + x_ra = Reactant.to_rarray(x) - @testset "Reactant.to_rarray" begin - y = collect(x_ra) - @test y == x - @test y !== x_ra - end + @testset "Reactant.to_rarray" begin + y = collect(x_ra) + @test y == x + @test y !== x_ra + end - @testset "TracedRArray" begin - y = @jit(collect(x_ra)) - @test y == x - @test y !== x_ra - end + @testset "TracedRArray" begin + y = @jit(collect(x_ra)) + @test y == x + @test y !== x_ra + end - x = 5 - x_ra = ConcreteRNumber(x) + x = 5 + x_ra = ConcreteRNumber(x) - @testset "ConcreteRNumber" begin - y = collect(x_ra) - @test y isa Array{Int,0} - end + @testset "ConcreteRNumber" begin + y = collect(x_ra) + @test y isa Array{Int,0} + end - @testset "TracedRArray" begin - y = @jit(collect(x_ra)) - @test y isa ConcreteRArray{Int,0} - @test y == x - end + @testset "TracedRArray" begin + y = @jit(collect(x_ra)) + @test y isa ConcreteRArray{Int,0} + @test y == x + end end function f_row_major(x::AbstractArray{T}) where {T} - y = [1 2; 3 4; 5 6] - if x isa Reactant.TracedRArray - y = Reactant.TracedUtils.promote_to( - Reactant.TracedRArray{Reactant.unwrapped_eltype(T),2}, y - ) - end - return x .+ y + y = [1 2; 3 4; 5 6] + if x isa Reactant.TracedRArray + y = Reactant.TracedUtils.promote_to( + Reactant.TracedRArray{Reactant.unwrapped_eltype(T),2}, y + ) + end + return x .+ y end @testset "array attributes: row major" begin - x = zeros(Int, 3, 2) - x_ra = Reactant.to_rarray(x) + x = zeros(Int, 3, 2) + x_ra = Reactant.to_rarray(x) - @test @jit(f_row_major(x_ra)) ≈ f_row_major(x) + @test @jit(f_row_major(x_ra)) ≈ f_row_major(x) end @testset "ifelse" begin - @test 1.0 == @test_warn r"`ifelse` with different element-types" @jit( - ifelse(ConcreteRNumber(true), ConcreteRNumber(1.0), ConcreteRNumber(0.0f0)) - ) - @test @jit( - ifelse(ConcreteRNumber(false), ConcreteRNumber(1.0), ConcreteRNumber(0.0f0)) - ) isa ConcreteRNumber{Float64} - @test 0.0f0 == + @test 1.0 == @test_warn r"`ifelse` with different element-types" @jit( + ifelse(ConcreteRNumber(true), ConcreteRNumber(1.0), ConcreteRNumber(0.0f0)) + ) + @test @jit( + ifelse(ConcreteRNumber(false), ConcreteRNumber(1.0), ConcreteRNumber(0.0f0)) + ) isa ConcreteRNumber{Float64} + @test 0.0f0 == @jit ifelse(ConcreteRNumber(false), ConcreteRNumber(1.0), ConcreteRNumber(0.0f0)) - @test @jit( - ifelse(ConcreteRNumber(false), ConcreteRNumber(1.0f0), ConcreteRNumber(0.0f0)) - ) isa ConcreteRNumber{Float32} + @test @jit( + ifelse(ConcreteRNumber(false), ConcreteRNumber(1.0f0), ConcreteRNumber(0.0f0)) + ) isa ConcreteRNumber{Float32} - cond = ConcreteRNumber(true) - x = ConcreteRNumber(1.0) - @test @jit(ifelse(cond, x, 0.0)) == ConcreteRNumber(1.0) - @test @jit(ifelse(cond, 0.0, x)) == ConcreteRNumber(0.0) - @test @jit(ifelse(cond, 1.0, 0.0)) == ConcreteRNumber(1.0) - @test @jit(ifelse(cond, 0.0, 1.0)) == ConcreteRNumber(0.0) + cond = ConcreteRNumber(true) + x = ConcreteRNumber(1.0) + @test @jit(ifelse(cond, x, 0.0)) == ConcreteRNumber(1.0) + @test @jit(ifelse(cond, 0.0, x)) == ConcreteRNumber(0.0) + @test @jit(ifelse(cond, 1.0, 0.0)) == ConcreteRNumber(1.0) + @test @jit(ifelse(cond, 0.0, 1.0)) == ConcreteRNumber(0.0) end @testset "fill! and zero on Reactant.to_rarray" begin - x_ra = Reactant.to_rarray(rand(3, 4)) + x_ra = Reactant.to_rarray(rand(3, 4)) - z = zero(x_ra) - @test z isa ConcreteRArray - @test size(z) == size(x_ra) - @test all(iszero, Array(z)) + z = zero(x_ra) + @test z isa ConcreteRArray + @test size(z) == size(x_ra) + @test all(iszero, Array(z)) - fill!(z, 1.0) - @test all(==(1.0), Array(z)) + fill!(z, 1.0) + @test all(==(1.0), Array(z)) end @testset "Preserve Aliasing" begin - x = Reactant.to_rarray([3]) - - if x isa ConcretePJRTArray - # For IFRT arrays we don't have unsafe_buffer_pointer implemented - T = Any[nothing] - - function ip(m, T) - @allowscalar m[1] = 2 - T[1] = m - return m - end - - res = @jit ip(x, T) - @test @allowscalar res[1] == 2 - @test @allowscalar x[1] == 2 - @test @allowscalar T[1][1] == 2 - - ptr_x = Base.unsafe_convert( - Ptr{Float64}, Reactant.XLA.unsafe_buffer_pointer(x.data[1].buffer) - ) - ptr_res = Base.unsafe_convert( - Ptr{Float64}, Reactant.XLA.unsafe_buffer_pointer(res.data[1].buffer) - ) - ptr_T1 = Base.unsafe_convert( - Ptr{Float64}, Reactant.XLA.unsafe_buffer_pointer(T[1].data[1].buffer) - ) - - @test ptr_x == ptr_res == ptr_T1 - end -end + x = Reactant.to_rarray([3]) -@testset "eltype conversion inside interpreter" begin - function test_convert(x::AbstractArray{T}, eta) where {T} - eta = T(eta) - return x .* eta, eta + if x isa ConcretePJRTArray + # For IFRT arrays we don't have unsafe_buffer_pointer implemented + T = Any[nothing] + + function ip(m, T) + @allowscalar m[1] = 2 + T[1] = m + return m end - res = @jit test_convert( - Reactant.to_rarray(rand(4, 2)), Reactant.to_rarray(3.0f0; track_numbers=Number) + res = @jit ip(x, T) + @test @allowscalar res[1] == 2 + @test @allowscalar x[1] == 2 + @test @allowscalar T[1][1] == 2 + + ptr_x = Base.unsafe_convert( + Ptr{Float64}, Reactant.XLA.unsafe_buffer_pointer(x.data[1].buffer) + ) + ptr_res = Base.unsafe_convert( + Ptr{Float64}, Reactant.XLA.unsafe_buffer_pointer(res.data[1].buffer) + ) + ptr_T1 = Base.unsafe_convert( + Ptr{Float64}, Reactant.XLA.unsafe_buffer_pointer(T[1].data[1].buffer) ) - @test res[1] isa ConcreteRArray{Float64,2} - @test res[2] isa ConcreteRNumber{Float64} + @test ptr_x == ptr_res == ptr_T1 + end end -@testset "stack" begin - x = rand(4, 4) - y = rand(4, 4) - x_ra = Reactant.to_rarray(x) - y_ra = Reactant.to_rarray(y) +@testset "eltype conversion inside interpreter" begin + function test_convert(x::AbstractArray{T}, eta) where {T} + eta = T(eta) + return x .* eta, eta + end - @test @jit(stack((x_ra, x_ra))) ≈ stack((x, x)) - @test @jit(stack((x_ra, x_ra); dims=2)) ≈ stack((x, x); dims=2) - @test @jit(stack((x_ra, y_ra); dims=2)) ≈ stack((x, y); dims=2) - @test @jit(stack((x_ra, y_ra, x_ra); dims=1)) ≈ stack((x, y, x); dims=1) + res = @jit test_convert( + Reactant.to_rarray(rand(4, 2)), Reactant.to_rarray(3.0f0; track_numbers=Number) + ) - # Test that we don't hit illegal instruction; `x` is intentionally not a traced array - @test @jit(stack((x, x))) isa Any - @test @jit(stack((x, x); dims=2)) isa Any - @test @jit(stack((x, y); dims=2)) isa Any - @test @jit(stack((x, y, x); dims=1)) isa Any + @test res[1] isa ConcreteRArray{Float64,2} + @test res[2] isa ConcreteRNumber{Float64} end -@testset "unstable stack" begin - x = rand(4, 4) - y = rand(4, 4) - x_ra = Reactant.to_rarray(x) - y_ra = Reactant.to_rarray(y) +@testset "stack" begin + x = rand(4, 4) + y = rand(4, 4) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) - function s1(x) - xs = [] - push!(xs, x) - push!(xs, x) - return stack(xs) - end - function s2(x) - xs = [] - push!(xs, x) - push!(xs, x) - return stack(xs; dims=2) - end - function s3(x, y) - xs = [] - push!(xs, x) - push!(xs, y) - return stack(xs; dims=2) - end - function s4(x, y) - xs = [] - push!(xs, x) - push!(xs, y) - push!(xs, x) - return stack(xs; dims=2) - end + @test @jit(stack((x_ra, x_ra))) ≈ stack((x, x)) + @test @jit(stack((x_ra, x_ra); dims=2)) ≈ stack((x, x); dims=2) + @test @jit(stack((x_ra, y_ra); dims=2)) ≈ stack((x, y); dims=2) + @test @jit(stack((x_ra, y_ra, x_ra); dims=1)) ≈ stack((x, y, x); dims=1) - @test @jit(s1(x_ra)) ≈ s1(x) - @test @jit(s2(x_ra)) ≈ s2(x) - @test @jit(s3(x_ra, y_ra)) ≈ s3(x, y) - @test @jit(s4(x_ra, y_ra)) ≈ s4(x, y) + # Test that we don't hit illegal instruction; `x` is intentionally not a traced array + @test @jit(stack((x, x))) isa Any + @test @jit(stack((x, x); dims=2)) isa Any + @test @jit(stack((x, y); dims=2)) isa Any + @test @jit(stack((x, y, x); dims=1)) isa Any +end - # Test that we don't hit illegal instruction; `x` is intentionally not a traced array - @test @jit(s1(x)) isa Any - @test @jit(s2(x)) isa Any - @test @jit(s3(x, y)) isa Any - @test @jit(s4(x, y)) isa Any +@testset "unstable stack" begin + x = rand(4, 4) + y = rand(4, 4) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) + + function s1(x) + xs = [] + push!(xs, x) + push!(xs, x) + return stack(xs) + end + function s2(x) + xs = [] + push!(xs, x) + push!(xs, x) + return stack(xs; dims=2) + end + function s3(x, y) + xs = [] + push!(xs, x) + push!(xs, y) + return stack(xs; dims=2) + end + function s4(x, y) + xs = [] + push!(xs, x) + push!(xs, y) + push!(xs, x) + return stack(xs; dims=2) + end + + @test @jit(s1(x_ra)) ≈ s1(x) + @test @jit(s2(x_ra)) ≈ s2(x) + @test @jit(s3(x_ra, y_ra)) ≈ s3(x, y) + @test @jit(s4(x_ra, y_ra)) ≈ s4(x, y) + + # Test that we don't hit illegal instruction; `x` is intentionally not a traced array + @test @jit(s1(x)) isa Any + @test @jit(s2(x)) isa Any + @test @jit(s3(x, y)) isa Any + @test @jit(s4(x, y)) isa Any end @testset "duplicate args (#226)" begin - first_arg(x, y) = x - x_ra = Reactant.to_rarray(rand(2, 2)) - res = @jit first_arg(x_ra, x_ra) - @test res ≈ x_ra + first_arg(x, y) = x + x_ra = Reactant.to_rarray(rand(2, 2)) + res = @jit first_arg(x_ra, x_ra) + @test res ≈ x_ra end @testset "Common Trig Functions" begin - x = rand(Float32, 4, 16) - x_ra = Reactant.to_rarray(x) - - @testset for fn in (sinpi, cospi, tanpi, sin, cos, tan) - @test @jit(fn.(x_ra)) ≈ fn.(x) - @test @jit(fn.(x_ra)) isa ConcreteRArray{Float32,2} - end - - x = 0.235f0 - x_ra = Reactant.to_rarray(x; track_numbers=Number) - - @testset for fn in (sinpi, cospi, tanpi, sin, cos, tan) - @test @jit(fn.(x_ra)) ≈ fn.(x) - @test @jit(fn.(x_ra)) isa ConcreteRNumber{Float32} - end - @testset for fn in (sincospi, sincos) - res = @jit fn(x_ra) - @test res[1] ≈ fn(x)[1] - @test res[2] ≈ fn(x)[2] - @test res[1] isa ConcreteRNumber{Float32} - @test res[2] isa ConcreteRNumber{Float32} - end + x = rand(Float32, 4, 16) + x_ra = Reactant.to_rarray(x) + + @testset for fn in (sinpi, cospi, tanpi, sin, cos, tan) + @test @jit(fn.(x_ra)) ≈ fn.(x) + @test @jit(fn.(x_ra)) isa ConcreteRArray{Float32,2} + end + + x = 0.235f0 + x_ra = Reactant.to_rarray(x; track_numbers=Number) + + @testset for fn in (sinpi, cospi, tanpi, sin, cos, tan) + @test @jit(fn.(x_ra)) ≈ fn.(x) + @test @jit(fn.(x_ra)) isa ConcreteRNumber{Float32} + end + @testset for fn in (sincospi, sincos) + res = @jit fn(x_ra) + @test res[1] ≈ fn(x)[1] + @test res[2] ≈ fn(x)[2] + @test res[1] isa ConcreteRNumber{Float32} + @test res[2] isa ConcreteRNumber{Float32} + end end @testset "isfinite" begin - x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN]) - @test @jit(isfinite.(x)) == [true, false, false, false, false] + x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN]) + @test @jit(isfinite.(x)) == [true, false, false, false, false] - x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im) - @test @jit(isfinite.(x)) == [true, false, false, false, false] + x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im) + @test @jit(isfinite.(x)) == [true, false, false, false, false] end @testset "isnan" begin - x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN]) - @test @jit(isnan.(x)) == [false, true, false, false, true] + x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN]) + @test @jit(isnan.(x)) == [false, true, false, false, true] - x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im) - @test @jit(isnan.(x)) == [false, true, false, false, true] + x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im) + @test @jit(isnan.(x)) == [false, true, false, false, true] end @testset "isnan/isfinite" begin - @test isnan(Reactant.to_rarray(NaN; track_numbers=Number)) - @test !isnan(Reactant.to_rarray(0.0; track_numbers=Number)) - @test isfinite(Reactant.to_rarray(0.0; track_numbers=Number)) - @test !isfinite(Reactant.to_rarray(Inf; track_numbers=Number)) + @test isnan(Reactant.to_rarray(NaN; track_numbers=Number)) + @test !isnan(Reactant.to_rarray(0.0; track_numbers=Number)) + @test isfinite(Reactant.to_rarray(0.0; track_numbers=Number)) + @test !isfinite(Reactant.to_rarray(Inf; track_numbers=Number)) end @testset "isinf" begin - @test Bool(@jit(isinf(ConcreteRNumber(Inf)))) - @test Bool(@jit(isinf(ConcreteRNumber(-Inf)))) - @test !Bool(@jit(isinf(ConcreteRNumber(2)))) - @test !Bool(@jit(isinf(ConcreteRNumber(2.0)))) - @test !Bool(@jit(isinf(ConcreteRNumber(true)))) + @test Bool(@jit(isinf(ConcreteRNumber(Inf)))) + @test Bool(@jit(isinf(ConcreteRNumber(-Inf)))) + @test !Bool(@jit(isinf(ConcreteRNumber(2)))) + @test !Bool(@jit(isinf(ConcreteRNumber(2.0)))) + @test !Bool(@jit(isinf(ConcreteRNumber(true)))) end @testset "mod and rem" begin - a = [-1.1, 7.7, -3.3, 9.9, -5.5] - b = [6.6, -2.2, -8.8, 4.4, -10.1] + a = [-1.1, 7.7, -3.3, 9.9, -5.5] + b = [6.6, -2.2, -8.8, 4.4, -10.1] - expected_mod = mod.(a, b) - @test @jit(mod.(Reactant.to_rarray(a), Reactant.to_rarray(b))) ≈ expected_mod - @test @jit(mod.(a, Reactant.to_rarray(b))) ≈ expected_mod - @test @jit(mod.(Reactant.to_rarray(a), b)) ≈ expected_mod + expected_mod = mod.(a, b) + @test @jit(mod.(Reactant.to_rarray(a), Reactant.to_rarray(b))) ≈ expected_mod + @test @jit(mod.(a, Reactant.to_rarray(b))) ≈ expected_mod + @test @jit(mod.(Reactant.to_rarray(a), b)) ≈ expected_mod - expected_rem = rem.(a, b) - @test @jit(rem.(Reactant.to_rarray(a), Reactant.to_rarray(b))) ≈ expected_rem - @test @jit(rem.(a, Reactant.to_rarray(b))) ≈ expected_rem - @test @jit(rem.(Reactant.to_rarray(a), b)) ≈ expected_rem + expected_rem = rem.(a, b) + @test @jit(rem.(Reactant.to_rarray(a), Reactant.to_rarray(b))) ≈ expected_rem + @test @jit(rem.(a, Reactant.to_rarray(b))) ≈ expected_rem + @test @jit(rem.(Reactant.to_rarray(a), b)) ≈ expected_rem end @testset "xor" begin - for a in (true, false), b in (true, false) - @test @jit(xor(ConcreteRNumber(a), ConcreteRNumber(b))) == xor(a, b) - end + for a in (true, false), b in (true, false) + @test @jit(xor(ConcreteRNumber(a), ConcreteRNumber(b))) == xor(a, b) + end end @testset "signbit" begin - for x in (-4, -3.14, -0.0f0, 0.0, 0, 5, 6.28f0) - @test @jit(signbit(ConcreteRNumber(x))) == signbit(x) - end + for x in (-4, -3.14, -0.0f0, 0.0, 0, 5, 6.28f0) + @test @jit(signbit(ConcreteRNumber(x))) == signbit(x) + end end @testset "copysign" begin - for a in (-3.14, -2, 0.0, 2.71, 42), b in (-7, -0.57, -0.0, 1, 3.14) - # Make sure also the return type is correct - @test Reactant.to_number(@jit(copysign(ConcreteRNumber(a), ConcreteRNumber(b)))) === - copysign(a, b) - end + for a in (-3.14, -2, 0.0, 2.71, 42), b in (-7, -0.57, -0.0, 1, 3.14) + # Make sure also the return type is correct + @test Reactant.to_number(@jit(copysign(ConcreteRNumber(a), ConcreteRNumber(b)))) === + copysign(a, b) + end end @testset "reduce integers" begin - x = rand(Bool, 100) - x_ra = Reactant.to_rarray(x) + x = rand(Bool, 100) + x_ra = Reactant.to_rarray(x) - @test @jit(sum(x_ra)) == sum(x) + @test @jit(sum(x_ra)) == sum(x) - x = rand(Int16, 100) - x_ra = Reactant.to_rarray(x) + x = rand(Int16, 100) + x_ra = Reactant.to_rarray(x) - @test @jit(sum(x_ra)) == sum(x) + @test @jit(sum(x_ra)) == sum(x) end @testset "/ on integers" begin - @test @jit(/(ConcreteRNumber(2), ConcreteRNumber(4))) ≈ 0.5 - @test @jit(/(ConcreteRNumber(2), 4)) ≈ 0.5 - @test @jit(/(2, ConcreteRNumber(4))) ≈ 0.5 - @test @jit(/(2, ConcreteRNumber(Int32(4)))) ≈ 0.5 + @test @jit(/(ConcreteRNumber(2), ConcreteRNumber(4))) ≈ 0.5 + @test @jit(/(ConcreteRNumber(2), 4)) ≈ 0.5 + @test @jit(/(2, ConcreteRNumber(4))) ≈ 0.5 + @test @jit(/(2, ConcreteRNumber(Int32(4)))) ≈ 0.5 end @testset "Broadcasting with Range" begin - x = Reactant.to_rarray(rand(10)) - fn(x) = x .+ (1:length(x)) + x = Reactant.to_rarray(rand(10)) + fn(x) = x .+ (1:length(x)) - @test @jit(fn(x)) ≈ fn(Array(x)) + @test @jit(fn(x)) ≈ fn(Array(x)) end function fntest1(x) - y = similar(x, 1, 1, 8) - sum!(y, x) - return y + y = similar(x, 1, 1, 8) + sum!(y, x) + return y end function fntest2(x) - y = similar(x, 2, 1, 8) - sum!(y, x) - return y + y = similar(x, 2, 1, 8) + sum!(y, x) + return y end function fntest3(x) - y = similar(x, 2, 1, 1) - sum!(abs2, y, x) - return y + y = similar(x, 2, 1, 1) + sum!(abs2, y, x) + return y end @testset "mapreducedim!" begin - x = reshape(collect(Float32, 1:64), 2, 4, 8) ./ 64 - x_ra = Reactant.to_rarray(x) + x = reshape(collect(Float32, 1:64), 2, 4, 8) ./ 64 + x_ra = Reactant.to_rarray(x) - @test Array(@jit(fntest1(x_ra))) ≈ fntest1(x) - @test Array(@jit(fntest2(x_ra))) ≈ fntest2(x) - @test Array(@jit(fntest3(x_ra))) ≈ fntest3(x) + @test Array(@jit(fntest1(x_ra))) ≈ fntest1(x) + @test Array(@jit(fntest2(x_ra))) ≈ fntest2(x) + @test Array(@jit(fntest3(x_ra))) ≈ fntest3(x) end @testset "don't expand ranges by default" begin - fn(x) = Reactant.TracedUtils.broadcast_to_size(x, (length(x),)) + fn(x) = Reactant.TracedUtils.broadcast_to_size(x, (length(x),)) - hlo = repr(@code_hlo(fn(1:10000))) - @test contains(hlo, "stablehlo.iota") - @test contains(hlo, "stablehlo.add") - @test Array(@jit(fn(1:10000))) ≈ collect(1:10000) + hlo = repr(@code_hlo(fn(1:10000))) + @test contains(hlo, "stablehlo.iota") + @test contains(hlo, "stablehlo.add") + @test Array(@jit(fn(1:10000))) ≈ collect(1:10000) - hlo = repr(@code_hlo(fn(32:10000))) - @test contains(hlo, "stablehlo.iota") - @test contains(hlo, "stablehlo.add") - @test Array(@jit(fn(32:10000))) ≈ collect(32:10000) + hlo = repr(@code_hlo(fn(32:10000))) + @test contains(hlo, "stablehlo.iota") + @test contains(hlo, "stablehlo.add") + @test Array(@jit(fn(32:10000))) ≈ collect(32:10000) - hlo = repr(@code_hlo(fn(0:10000))) - @test contains(hlo, "stablehlo.iota") - @test !contains(hlo, "stablehlo.add") - @test Array(@jit(fn(0:10000))) ≈ collect(0:10000) + hlo = repr(@code_hlo(fn(0:10000))) + @test contains(hlo, "stablehlo.iota") + @test !contains(hlo, "stablehlo.add") + @test Array(@jit(fn(0:10000))) ≈ collect(0:10000) - hlo = repr(@code_hlo(fn(Base.OneTo(10000)))) - @test contains(hlo, "stablehlo.iota") - @test contains(hlo, "stablehlo.add") - @test Array(@jit(fn(Base.OneTo(10000)))) ≈ collect(Base.OneTo(10000)) + hlo = repr(@code_hlo(fn(Base.OneTo(10000)))) + @test contains(hlo, "stablehlo.iota") + @test contains(hlo, "stablehlo.add") + @test Array(@jit(fn(Base.OneTo(10000)))) ≈ collect(Base.OneTo(10000)) end function dip!(x) - x[:a] = x[:a] .* x[:b] - return nothing + x[:a] = x[:a] .* x[:b] + return nothing end @testset "Dict" begin - x = Dict{Symbol,Vector{Float32}}() - x[:a] = 2.7 * ones(4) - x[:b] = 3.1 * ones(4) + x = Dict{Symbol,Vector{Float32}}() + x[:a] = 2.7 * ones(4) + x[:b] = 3.1 * ones(4) + + ra = Reactant.to_rarray(x) + @jit dip!(ra) + ra[:a] ≈ (2.7 * 2) * ones(4) +end - ra = Reactant.to_rarray(x) - @jit dip!(ra) - ra[:a] ≈ (2.7 * 2) * ones(4) +function combine_ab_to_c!(d) + d[:c] = d[:a] + d[:b] + return nothing +end + +@testset "Bijection" begin + d = Bijection{Symbol,ConcreteRArray{Float64,1},Dict{Symbol,ConcreteRArray{Float64,1}},IdDict{ConcreteRArray{Float64,1},Symbol}}() + a = Reactant.to_rarray([1.0]) + b = Reactant.to_rarray([1.0]) + d[:a] = a + d[:b] = b + + d2 = @jit identity(d) + @test d == d2 + + @jit combine_ab_to_c!(d) + @test haskey(d, :c) + @test d[:c] == [2.0] end @testset "@code_xla" begin - x_ra = Reactant.to_rarray(ones(4)) - hlo = repr(@code_xla(sin.(x_ra))) - @test contains(hlo, "HloModule") - @test contains(hlo, "sine") + x_ra = Reactant.to_rarray(ones(4)) + hlo = repr(@code_xla(sin.(x_ra))) + @test contains(hlo, "HloModule") + @test contains(hlo, "sine") end @testset "Raise keyword" begin - v = randn(Float32, 16) - rv = Reactant.to_rarray(v) - @test sin.(v) ≈ @jit raise = true sin.(rv) - @test cos.(v) ≈ @jit raise = false cos.(rv) - @test exp.(v) ≈ @jit raise = "canonicalize" exp.(rv) - @test_throws Reactant.MLIR.IR.AddPipelineException @jit raise = "this_pass-does_not_ExisT" exp.( - rv - ) + v = randn(Float32, 16) + rv = Reactant.to_rarray(v) + @test sin.(v) ≈ @jit raise = true sin.(rv) + @test cos.(v) ≈ @jit raise = false cos.(rv) + @test exp.(v) ≈ @jit raise = "canonicalize" exp.(rv) + @test_throws Reactant.MLIR.IR.AddPipelineException @jit raise = "this_pass-does_not_ExisT" exp.( + rv + ) end @testset "map!" begin - x = randn(Float32, 2, 3) - y = zeros(Float32, 2, 3) + x = randn(Float32, 2, 3) + y = zeros(Float32, 2, 3) - x_ra = Reactant.to_rarray(x) - y_ra = Reactant.to_rarray(y) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) - @test Array(@jit(map!(abs2, y_ra, x_ra))) ≈ map!(abs2, y, x) - @test Array(y_ra) ≈ y + @test Array(@jit(map!(abs2, y_ra, x_ra))) ≈ map!(abs2, y, x) + @test Array(y_ra) ≈ y end @testset "ConcreteRArray inplace broadcast" begin - x = Reactant.to_rarray(zeros(Float32, 2, 3)) - y = Reactant.to_rarray(reshape(collect(Float32, 1:6), 2, 3)) + x = Reactant.to_rarray(zeros(Float32, 2, 3)) + y = Reactant.to_rarray(reshape(collect(Float32, 1:6), 2, 3)) - x .= y ./ 2 + x .= y ./ 2 - @test Array(x) ≈ Array(y) ./ 2 + @test Array(x) ≈ Array(y) ./ 2 - x = zeros(Float32, 2, 3) - x .= y ./ 2 + x = zeros(Float32, 2, 3) + x .= y ./ 2 - @test Array(x) ≈ Array(y) ./ 2 + @test Array(x) ≈ Array(y) ./ 2 - x = view(zeros(Float32, 2, 5), :, 1:3) - x .= y ./ 2 + x = view(zeros(Float32, 2, 5), :, 1:3) + x .= y ./ 2 - @test Array(x) ≈ Array(y) ./ 2 + @test Array(x) ≈ Array(y) ./ 2 end @testset "Hlo Cost Analysis" begin - x_ra = Reactant.to_rarray(rand(4, 4)) - mul_comp = @compile x_ra * x_ra - cost = Reactant.XLA.cost_analysis(mul_comp) + x_ra = Reactant.to_rarray(rand(4, 4)) + mul_comp = @compile x_ra * x_ra + cost = Reactant.XLA.cost_analysis(mul_comp) - @test cost isa Reactant.XLA.HloCostAnalysisProperties + @test cost isa Reactant.XLA.HloCostAnalysisProperties end function fractional_idx(times, t) - n₂ = searchsortedfirst(times, t) - n₁ = max(1, n₂ - 1) - Nt = length(times) - n₂ = min(Nt, n₂) + n₂ = searchsortedfirst(times, t) + n₁ = max(1, n₂ - 1) + Nt = length(times) + n₂ = min(Nt, n₂) - t₁ = times[n₁] - t₂ = times[n₂] + t₁ = times[n₁] + t₂ = times[n₂] - ñ = (t - t₁) / (t₂ - t₁) + ñ = (t - t₁) / (t₂ - t₁) - return ñ, n₁, n₂ + return ñ, n₁, n₂ end @testset "Fractional index" begin - times = 0:0.01:4.5 - @test times isa Base.StepRangeLen - res = @jit fractional_idx(times, ConcreteRNumber(2.143)) - @test res[1] == 0.29999999999997334 - @test res[2] == 215 - @test res[3] == 216 + times = 0:0.01:4.5 + @test times isa Base.StepRangeLen + res = @jit fractional_idx(times, ConcreteRNumber(2.143)) + @test res[1] == 0.29999999999997334 + @test res[2] == 215 + @test res[3] == 216 end @testset "Traced fractional index" begin - times = Reactant.to_rarray(0:0.01:4.5; track_numbers=Number) - @test times isa Reactant.TracedRNumberOverrides.TracedStepRangeLen - res = @jit fractional_idx(times, ConcreteRNumber(2.143)) - @test res[1] == 0.29999999999997334 - @test res[2] == 215 - @test res[3] == 216 + times = Reactant.to_rarray(0:0.01:4.5; track_numbers=Number) + @test times isa Reactant.TracedRNumberOverrides.TracedStepRangeLen + res = @jit fractional_idx(times, ConcreteRNumber(2.143)) + @test res[1] == 0.29999999999997334 + @test res[2] == 215 + @test res[3] == 216 end function unitrange_test(r, i) - return r[i] + return r[i] end @testset "Unitrange" begin - x = 2:10 - @test (@jit unitrange_test(x, 3)) == 4 - @test (@jit unitrange_test(x, Reactant.ConcreteRNumber(4))) == 5 + x = 2:10 + @test (@jit unitrange_test(x, 3)) == 4 + @test (@jit unitrange_test(x, Reactant.ConcreteRNumber(4))) == 5 - x = Reactant.to_rarray(2:10; track_numbers=Number) - @test (@jit unitrange_test(x, 3)) == 4 - @test (@jit unitrange_test(x, Reactant.ConcreteRNumber(4))) == 5 + x = Reactant.to_rarray(2:10; track_numbers=Number) + @test (@jit unitrange_test(x, 3)) == 4 + @test (@jit unitrange_test(x, Reactant.ConcreteRNumber(4))) == 5 end mulpi(x) = π * x @testset "Irrational promotion" begin - x = Reactant.to_rarray(ones(2)) - y = @jit mulpi(x) - @test all(Array(y) .≈ π) + x = Reactant.to_rarray(ones(2)) + y = @jit mulpi(x) + @test all(Array(y) .≈ π) end @testset "copyto! ConcreteArray" begin - x_ra = Reactant.to_rarray(ones(4, 4)) - y_ra = Reactant.to_rarray(zeros(2, 2)) - copyto!(view(x_ra, 1:2, 1:2), y_ra) - @test Array(x_ra) == + x_ra = Reactant.to_rarray(ones(4, 4)) + y_ra = Reactant.to_rarray(zeros(2, 2)) + copyto!(view(x_ra, 1:2, 1:2), y_ra) + @test Array(x_ra) == [0.0 0.0 1.0 1.0; 0.0 0.0 1.0 1.0; 1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0] end @testset "copyto! ConcreteArray Array" begin - x_ra = Reactant.to_rarray(ones(4, 4)) - y_ra = view(zeros(4, 4), 1:2, 1:2) - copyto!(view(x_ra, 1:2, 1:2), y_ra) - @test Array(x_ra) == + x_ra = Reactant.to_rarray(ones(4, 4)) + y_ra = view(zeros(4, 4), 1:2, 1:2) + copyto!(view(x_ra, 1:2, 1:2), y_ra) + @test Array(x_ra) == [0.0 0.0 1.0 1.0; 0.0 0.0 1.0 1.0; 1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0] end @testset "copyto! TracedRArray" begin - x_ra = Reactant.to_rarray(ones(4, 4)) - y_ra = Reactant.to_rarray(zeros(2, 2)) - @jit copyto!(x_ra, 6, y_ra, 3, 2) + x_ra = Reactant.to_rarray(ones(4, 4)) + y_ra = Reactant.to_rarray(zeros(2, 2)) + @jit copyto!(x_ra, 6, y_ra, 3, 2) - x = ones(4, 4) - y = zeros(2, 2) - copyto!(x, 6, y, 3, 2) - @test Array(x_ra) == x + x = ones(4, 4) + y = zeros(2, 2) + copyto!(x, 6, y, 3, 2) + @test Array(x_ra) == x end function reshapecopy!(x, y) - Base.copyto!(x, reshape(y, size(x))) - return nothing + Base.copyto!(x, reshape(y, size(x))) + return nothing end @testset "copyto! Reshaped TracedRArray" begin - x = zeros(3, 4, 5) - y = collect(reshape(1:60, (3, 20))) + x = zeros(3, 4, 5) + y = collect(reshape(1:60, (3, 20))) - xr = Reactant.to_rarray(x) - yr = Reactant.to_rarray(y) + xr = Reactant.to_rarray(x) + yr = Reactant.to_rarray(y) - @jit reshapecopy!(xr, yr) + @jit reshapecopy!(xr, yr) - reshapecopy!(x, y) - @test Array(xr) == x + reshapecopy!(x, y) + @test Array(xr) == x end @testset "copy(::Broadcast.Broadcasted{ArrayStyle{ConcreteRArray}})" begin - x_ra = Reactant.to_rarray(ones(4, 4)) - res = copy(Broadcast.broadcasted(-, Broadcast.broadcasted(+, x_ra, 1))) - @test res ≈ -(Array(x_ra) .+ 1) + x_ra = Reactant.to_rarray(ones(4, 4)) + res = copy(Broadcast.broadcasted(-, Broadcast.broadcasted(+, x_ra, 1))) + @test res ≈ -(Array(x_ra) .+ 1) end @testset "typemin/typemax" begin - fn(x) = [typemin(eltype(x)), typemax(eltype(x))] + fn(x) = [typemin(eltype(x)), typemax(eltype(x))] - x_ra = Reactant.to_rarray(ones(4)) - @test @jit(fn(x_ra)) == fn(ones(4)) + x_ra = Reactant.to_rarray(ones(4)) + @test @jit(fn(x_ra)) == fn(ones(4)) - x_ra = Reactant.to_rarray(ones(Int, 4)) - @test @jit(fn(x_ra)) == fn(ones(Int, 4)) + x_ra = Reactant.to_rarray(ones(Int, 4)) + @test @jit(fn(x_ra)) == fn(ones(Int, 4)) end @testset "Module printing" begin - for opt in (true, false, :before_jit), debug in (true, false) - v = collect(Float32(1):Float32(64)) - vr = Reactant.to_rarray(v) - mod = @code_hlo optimize = opt log.(vr) - - # Store the module as a string with different debug options. - io = IOBuffer() - show(IOContext(io, :debug => debug), mod) - mod_string = String(take!(io)) - - # Test that we can parse back the string as an MLIR module, compile it - # and get correct results. - res = @jit(Reactant.Ops.hlo_call(mod_string, vr))[1] - @test res ≈ log.(v) - end + for opt in (true, false, :before_jit), debug in (true, false) + v = collect(Float32(1):Float32(64)) + vr = Reactant.to_rarray(v) + mod = @code_hlo optimize = opt log.(vr) + + # Store the module as a string with different debug options. + io = IOBuffer() + show(IOContext(io, :debug => debug), mod) + mod_string = String(take!(io)) + + # Test that we can parse back the string as an MLIR module, compile it + # and get correct results. + res = @jit(Reactant.Ops.hlo_call(mod_string, vr))[1] + @test res ≈ log.(v) + end end @testset "Dump MLIR modules" begin - always_old = Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] - dir_old = Reactant.MLIR.IR.DUMP_MLIR_DIR[] - - mktempdir() do dir - Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = true - Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir - @compile sin.(Reactant.to_rarray(Float32[1.0])) - for mod in readdir(dir; join=true) - @test contains(read(mod, String), "hlo.sine") - end + always_old = Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] + dir_old = Reactant.MLIR.IR.DUMP_MLIR_DIR[] + + mktempdir() do dir + Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = true + Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir + @compile sin.(Reactant.to_rarray(Float32[1.0])) + for mod in readdir(dir; join=true) + @test contains(read(mod, String), "hlo.sine") end + end - mktempdir() do dir - Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = false - Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir - @compile exp.(Reactant.to_rarray(Float32[1.0])) - # Make sure we don't save anything to file when compilation is - # successful and `DUMP_MLIR_ALWAYS=false`. - @test isempty(readdir(dir; join=true)) - end + mktempdir() do dir + Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = false + Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir + @compile exp.(Reactant.to_rarray(Float32[1.0])) + # Make sure we don't save anything to file when compilation is + # successful and `DUMP_MLIR_ALWAYS=false`. + @test isempty(readdir(dir; join=true)) + end - Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = always_old - Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir_old + Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = always_old + Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir_old end @testset "Allocator Stats" begin - platform_name = lowercase(Reactant.XLA.platform_name(Reactant.XLA.default_backend())) - if platform_name != "cpu" # not supported on CPU - @test Reactant.XLA.allocatorstats() isa Reactant.XLA.AllocatorStats - else - @test_throws Reactant.XLA.ReactantInternalError Reactant.XLA.allocatorstats() - end + platform_name = lowercase(Reactant.XLA.platform_name(Reactant.XLA.default_backend())) + if platform_name != "cpu" # not supported on CPU + @test Reactant.XLA.allocatorstats() isa Reactant.XLA.AllocatorStats + else + @test_throws Reactant.XLA.ReactantInternalError Reactant.XLA.allocatorstats() + end end @testset "copy/deepcopy" begin - for op in (copy, deepcopy) - x = Reactant.to_rarray(ones(4, 4)) - if x isa Reactant.ConcretePJRTArray - orig_ptr = only(x.data).buffer.buffer - y = op(x) - @test y isa Reactant.ConcretePJRTArray - @test only(y.data).buffer.buffer != orig_ptr - @test only(x.data).buffer.buffer == orig_ptr - else - orig_ptr = x.data.buffer.buffer - y = op(x) - @test y isa Reactant.ConcreteIFRTArray - @test y.data.buffer.buffer != orig_ptr - @test x.data.buffer.buffer == orig_ptr - end - - x = Reactant.to_rarray(4.0; track_numbers=Number) - if x isa Reactant.ConcretePJRTNumber - orig_ptr = only(x.data).buffer.buffer - y = op(x) - @test y isa Reactant.ConcretePJRTNumber - @test only(y.data).buffer.buffer != orig_ptr - @test only(x.data).buffer.buffer == orig_ptr - else - orig_ptr = x.data.buffer.buffer - y = op(x) - @test y isa Reactant.ConcreteIFRTNumber - @test y.data.buffer.buffer != orig_ptr - @test x.data.buffer.buffer == orig_ptr - end + for op in (copy, deepcopy) + x = Reactant.to_rarray(ones(4, 4)) + if x isa Reactant.ConcretePJRTArray + orig_ptr = only(x.data).buffer.buffer + y = op(x) + @test y isa Reactant.ConcretePJRTArray + @test only(y.data).buffer.buffer != orig_ptr + @test only(x.data).buffer.buffer == orig_ptr + else + orig_ptr = x.data.buffer.buffer + y = op(x) + @test y isa Reactant.ConcreteIFRTArray + @test y.data.buffer.buffer != orig_ptr + @test x.data.buffer.buffer == orig_ptr + end + + x = Reactant.to_rarray(4.0; track_numbers=Number) + if x isa Reactant.ConcretePJRTNumber + orig_ptr = only(x.data).buffer.buffer + y = op(x) + @test y isa Reactant.ConcretePJRTNumber + @test only(y.data).buffer.buffer != orig_ptr + @test only(x.data).buffer.buffer == orig_ptr + else + orig_ptr = x.data.buffer.buffer + y = op(x) + @test y isa Reactant.ConcreteIFRTNumber + @test y.data.buffer.buffer != orig_ptr + @test x.data.buffer.buffer == orig_ptr end + end end function test_aliased_numbers(ps, x) - return map(Returns(x), ps) + return map(Returns(x), ps) end @testset "Correct Aliasing" begin - ps = Reactant.to_rarray((a=rand(4), b=rand(2), c=rand(4))) - x = ConcreteRNumber(3.14) - res = @jit test_aliased_numbers(ps, x) + ps = Reactant.to_rarray((a=rand(4), b=rand(2), c=rand(4))) + x = ConcreteRNumber(3.14) + res = @jit test_aliased_numbers(ps, x) - @test res[1] === res[2] === res[3] + @test res[1] === res[2] === res[3] end accum_fn(x, y) = abs2(x) + abs2(y) @testset "accumulate" begin - a = collect(Float32, 1:10) ./ 10 - a_ra = Reactant.to_rarray(a) - - b = reshape(collect(Float32, 1:60), (3, 4, 5)) ./ 60 - b_ra = Reactant.to_rarray(b) - - @testset "cumsum" begin - @test @jit(cumsum(a_ra)) ≈ cumsum(a) - - @test @jit(cumsum(b_ra; dims=1)) ≈ cumsum(b; dims=1) - @test @jit(cumsum(b_ra; dims=2)) ≈ cumsum(b; dims=2) - @test @jit(cumsum(b_ra; dims=3)) ≈ cumsum(b; dims=3) - - @test begin - z = similar(a_ra) - @jit(cumsum!(z, a_ra)) - z - end ≈ cumsum(a) - - @test begin - z = similar(b_ra) - @jit(cumsum!(z, b_ra; dims=1)) - z - end ≈ cumsum(b; dims=1) - @test begin - z = similar(b_ra) - @jit(cumsum!(z, b_ra; dims=2)) - z - end ≈ cumsum(b; dims=2) - @test begin - z = similar(b_ra) - @jit(cumsum!(z, b_ra; dims=3)) - z - end ≈ cumsum(b; dims=3) - end - - @testset "cumprod" begin - @test @jit(cumprod(a_ra)) ≈ cumprod(a) - - @test @jit(cumprod(b_ra; dims=1)) ≈ cumprod(b; dims=1) - @test @jit(cumprod(b_ra; dims=2)) ≈ cumprod(b; dims=2) - @test @jit(cumprod(b_ra; dims=3)) ≈ cumprod(b; dims=3) - - @test begin - z = similar(a_ra) - @jit(cumprod!(z, a_ra)) - z - end ≈ cumprod(a) - @test begin - z = similar(b_ra) - @jit(cumprod!(z, b_ra; dims=1)) - z - end ≈ cumprod(b; dims=1) - @test begin - z = similar(b_ra) - @jit(cumprod!(z, b_ra; dims=2)) - z - end ≈ cumprod(b; dims=2) - @test begin - z = similar(b_ra) - @jit(cumprod!(z, b_ra; dims=3)) - z - end ≈ cumprod(b; dims=3) - end - - @testset "accumulate" begin - @test @jit(accumulate(accum_fn, a_ra; init=0.0f0)) ≈ - accumulate(accum_fn, a; init=0.0f0) - - @test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=1)) ≈ - accumulate(accum_fn, b; dims=1, init=0.0f0) - @test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=2)) ≈ - accumulate(accum_fn, b; dims=2, init=0.0f0) - @test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=3)) ≈ - accumulate(accum_fn, b; dims=3, init=0.0f0) - - @test begin - z = similar(a_ra) - @jit(accumulate!(accum_fn, z, a_ra; init=0.0f0)) - z - end ≈ accumulate(accum_fn, a; init=0.0f0) - - @test begin - z = similar(b_ra) - @jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=1)) - z - end ≈ accumulate(accum_fn, b; dims=1, init=0.0f0) - @test begin - z = similar(b_ra) - @jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=2)) - z - end ≈ accumulate(accum_fn, b; dims=2, init=0.0f0) - @test begin - z = similar(b_ra) - @jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=3)) - z - end ≈ accumulate(accum_fn, b; dims=3, init=0.0f0) - end + a = collect(Float32, 1:10) ./ 10 + a_ra = Reactant.to_rarray(a) + + b = reshape(collect(Float32, 1:60), (3, 4, 5)) ./ 60 + b_ra = Reactant.to_rarray(b) + + @testset "cumsum" begin + @test @jit(cumsum(a_ra)) ≈ cumsum(a) + + @test @jit(cumsum(b_ra; dims=1)) ≈ cumsum(b; dims=1) + @test @jit(cumsum(b_ra; dims=2)) ≈ cumsum(b; dims=2) + @test @jit(cumsum(b_ra; dims=3)) ≈ cumsum(b; dims=3) + + @test begin + z = similar(a_ra) + @jit(cumsum!(z, a_ra)) + z + end ≈ cumsum(a) + + @test begin + z = similar(b_ra) + @jit(cumsum!(z, b_ra; dims=1)) + z + end ≈ cumsum(b; dims=1) + @test begin + z = similar(b_ra) + @jit(cumsum!(z, b_ra; dims=2)) + z + end ≈ cumsum(b; dims=2) + @test begin + z = similar(b_ra) + @jit(cumsum!(z, b_ra; dims=3)) + z + end ≈ cumsum(b; dims=3) + end + + @testset "cumprod" begin + @test @jit(cumprod(a_ra)) ≈ cumprod(a) + + @test @jit(cumprod(b_ra; dims=1)) ≈ cumprod(b; dims=1) + @test @jit(cumprod(b_ra; dims=2)) ≈ cumprod(b; dims=2) + @test @jit(cumprod(b_ra; dims=3)) ≈ cumprod(b; dims=3) + + @test begin + z = similar(a_ra) + @jit(cumprod!(z, a_ra)) + z + end ≈ cumprod(a) + @test begin + z = similar(b_ra) + @jit(cumprod!(z, b_ra; dims=1)) + z + end ≈ cumprod(b; dims=1) + @test begin + z = similar(b_ra) + @jit(cumprod!(z, b_ra; dims=2)) + z + end ≈ cumprod(b; dims=2) + @test begin + z = similar(b_ra) + @jit(cumprod!(z, b_ra; dims=3)) + z + end ≈ cumprod(b; dims=3) + end + + @testset "accumulate" begin + @test @jit(accumulate(accum_fn, a_ra; init=0.0f0)) ≈ + accumulate(accum_fn, a; init=0.0f0) + + @test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=1)) ≈ + accumulate(accum_fn, b; dims=1, init=0.0f0) + @test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=2)) ≈ + accumulate(accum_fn, b; dims=2, init=0.0f0) + @test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=3)) ≈ + accumulate(accum_fn, b; dims=3, init=0.0f0) + + @test begin + z = similar(a_ra) + @jit(accumulate!(accum_fn, z, a_ra; init=0.0f0)) + z + end ≈ accumulate(accum_fn, a; init=0.0f0) + + @test begin + z = similar(b_ra) + @jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=1)) + z + end ≈ accumulate(accum_fn, b; dims=1, init=0.0f0) + @test begin + z = similar(b_ra) + @jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=2)) + z + end ≈ accumulate(accum_fn, b; dims=2, init=0.0f0) + @test begin + z = similar(b_ra) + @jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=3)) + z + end ≈ accumulate(accum_fn, b; dims=3, init=0.0f0) + end end sameunitrange(x, y) = first(x) == first(y) && last(x) == last(y) @testset "searchsorted" begin - x = [1, 2, 4, 5, 5, 7] - x_ra = Reactant.to_rarray(x) - - @testset "searchsortedfirst" begin - @testset for val in (4, 5, 3, 9, 0) - @test @jit(searchsortedfirst(x_ra, val)) == searchsortedfirst(x, val) - @test @jit(searchsortedfirst(x_ra, ConcreteRNumber(val))) == - searchsortedfirst(x, val) - end + x = [1, 2, 4, 5, 5, 7] + x_ra = Reactant.to_rarray(x) + + @testset "searchsortedfirst" begin + @testset for val in (4, 5, 3, 9, 0) + @test @jit(searchsortedfirst(x_ra, val)) == searchsortedfirst(x, val) + @test @jit(searchsortedfirst(x_ra, ConcreteRNumber(val))) == + searchsortedfirst(x, val) end + end - @testset "searchsortedlast" begin - @testset for val in (4, 5, 3, 9, 0) - @test @jit(searchsortedlast(x_ra, val)) == searchsortedlast(x, val) - @test @jit(searchsortedlast(x_ra, ConcreteRNumber(val))) == - searchsortedlast(x, val) - end + @testset "searchsortedlast" begin + @testset for val in (4, 5, 3, 9, 0) + @test @jit(searchsortedlast(x_ra, val)) == searchsortedlast(x, val) + @test @jit(searchsortedlast(x_ra, ConcreteRNumber(val))) == + searchsortedlast(x, val) end - - @testset "searchsorted" begin - @testset for val in (4, 5, 3, 9, 0) - @test sameunitrange(@jit(searchsorted(x_ra, val)), searchsorted(x, val)) - @test sameunitrange( - @jit(searchsorted(x_ra, ConcreteRNumber(val))), searchsorted(x, val) - ) - end + end + + @testset "searchsorted" begin + @testset for val in (4, 5, 3, 9, 0) + @test sameunitrange(@jit(searchsorted(x_ra, val)), searchsorted(x, val)) + @test sameunitrange( + @jit(searchsorted(x_ra, ConcreteRNumber(val))), searchsorted(x, val) + ) end + end end diff --git a/test/tracing.jl b/test/tracing.jl index 503f8139b1..4d595d84ce 100644 --- a/test/tracing.jl +++ b/test/tracing.jl @@ -1,289 +1,320 @@ using Reactant using Reactant: - traced_type, - TracedRArray, - TracedRNumber, - ConcreteToTraced, - ArrayToConcrete, - NoFieldMatchError, - TracedTypeError, - ReactantPrimitive + traced_type, + TracedRArray, + TracedRNumber, + ConcreteToTraced, + ArrayToConcrete, + NoFieldMatchError, + TracedTypeError, + ReactantPrimitive +using Bijections using Test struct Wrapper{A,B} - a::A - b::B + a::A + b::B end struct Descent{T} - eta::T + eta::T end struct RMSProp{Teta,Trho,Teps,C<:Bool} - eta::Teta - rho::Trho - epsilon::Teps - centred::C + eta::Teta + rho::Trho + epsilon::Teps + centred::C end @testset "Traced Type" begin - @test !(Vector{Union{}} <: Reactant.AnyTracedRArray) + @test !(Vector{Union{}} <: Reactant.AnyTracedRArray) end @testset "Tracing" begin - @testset "trace_type" begin - @testset "mode = ConcreteToTraced" begin - @testset "$origty" for (origty, targetty, targettynum) in [ - (Any, Any, Any), - (Real, Real, Real), - (Module, Module, Module), - (DataType, DataType, DataType), - # (Union{}, Union{}), # fails - (Nothing, Nothing, Nothing), - (Symbol, Symbol, Symbol), - (Char, Char, Char), - (AbstractString, AbstractString, AbstractString), - (String, String, String), - (VersionNumber, VersionNumber, VersionNumber), + @testset "trace_type" begin + @testset "mode = ConcreteToTraced" begin + @testset "$origty" for (origty, targetty, targettynum) in [ + (Any, Any, Any), + (Real, Real, Real), + (Module, Module, Module), + (DataType, DataType, DataType), + # (Union{}, Union{}), # fails + (Nothing, Nothing, Nothing), + (Symbol, Symbol, Symbol), + (Char, Char, Char), + (AbstractString, AbstractString, AbstractString), + (String, String, String), + (VersionNumber, VersionNumber, VersionNumber), - # Numeric types - (AbstractFloat, AbstractFloat, AbstractFloat), - (Float16, Float16, TracedRNumber{Float16}), - (Float32, Float32, TracedRNumber{Float32}), - (Float64, Float64, TracedRNumber{Float64}), - (Integer, Integer, Integer), - (Int8, Int8, TracedRNumber{Int8}), - (Int16, Int16, TracedRNumber{Int16}), - (Int32, Int32, TracedRNumber{Int32}), - (Int64, Int64, TracedRNumber{Int64}), - (UInt8, UInt8, TracedRNumber{UInt8}), - (UInt16, UInt16, TracedRNumber{UInt16}), - (UInt32, UInt32, TracedRNumber{UInt32}), - (UInt64, UInt64, TracedRNumber{UInt64}), - (Complex{Float32}, Complex{Float32}, TracedRNumber{Complex{Float32}}), - (Complex{Float64}, Complex{Float64}, TracedRNumber{Complex{Float64}}), - (Complex{Int8}, Complex{Int8}, TracedRNumber{Complex{Int8}}), - (Complex{Int16}, Complex{Int16}, TracedRNumber{Complex{Int16}}), - (Complex{Int32}, Complex{Int32}, TracedRNumber{Complex{Int32}}), - (Complex{Int64}, Complex{Int64}, TracedRNumber{Complex{Int64}}), - (Complex{UInt8}, Complex{UInt8}, TracedRNumber{Complex{UInt8}}), - (Complex{UInt16}, Complex{UInt16}, TracedRNumber{Complex{UInt16}}), - (Complex{UInt32}, Complex{UInt32}, TracedRNumber{Complex{UInt32}}), - (Complex{UInt64}, Complex{UInt64}, TracedRNumber{Complex{UInt64}}), + # Numeric types + (AbstractFloat, AbstractFloat, AbstractFloat), + (Float16, Float16, TracedRNumber{Float16}), + (Float32, Float32, TracedRNumber{Float32}), + (Float64, Float64, TracedRNumber{Float64}), + (Integer, Integer, Integer), + (Int8, Int8, TracedRNumber{Int8}), + (Int16, Int16, TracedRNumber{Int16}), + (Int32, Int32, TracedRNumber{Int32}), + (Int64, Int64, TracedRNumber{Int64}), + (UInt8, UInt8, TracedRNumber{UInt8}), + (UInt16, UInt16, TracedRNumber{UInt16}), + (UInt32, UInt32, TracedRNumber{UInt32}), + (UInt64, UInt64, TracedRNumber{UInt64}), + (Complex{Float32}, Complex{Float32}, TracedRNumber{Complex{Float32}}), + (Complex{Float64}, Complex{Float64}, TracedRNumber{Complex{Float64}}), + (Complex{Int8}, Complex{Int8}, TracedRNumber{Complex{Int8}}), + (Complex{Int16}, Complex{Int16}, TracedRNumber{Complex{Int16}}), + (Complex{Int32}, Complex{Int32}, TracedRNumber{Complex{Int32}}), + (Complex{Int64}, Complex{Int64}, TracedRNumber{Complex{Int64}}), + (Complex{UInt8}, Complex{UInt8}, TracedRNumber{Complex{UInt8}}), + (Complex{UInt16}, Complex{UInt16}, TracedRNumber{Complex{UInt16}}), + (Complex{UInt32}, Complex{UInt32}, TracedRNumber{Complex{UInt32}}), + (Complex{UInt64}, Complex{UInt64}, TracedRNumber{Complex{UInt64}}), - # RArray types - ( - ConcreteRArray{Float64,0}, - TracedRArray{Float64,0}, - TracedRArray{Float64,0}, - ), - ( - ConcreteRArray{Float64,1}, - TracedRArray{Float64,1}, - TracedRArray{Float64,1}, - ), - ( - ConcreteRArray{Float64,2}, - TracedRArray{Float64,2}, - TracedRArray{Float64,2}, - ), - ( - ConcreteRArray{Float64,3}, - TracedRArray{Float64,3}, - TracedRArray{Float64,3}, - ), + # RArray types + ( + ConcreteRArray{Float64,0}, + TracedRArray{Float64,0}, + TracedRArray{Float64,0}, + ), + ( + ConcreteRArray{Float64,1}, + TracedRArray{Float64,1}, + TracedRArray{Float64,1}, + ), + ( + ConcreteRArray{Float64,2}, + TracedRArray{Float64,2}, + TracedRArray{Float64,2}, + ), + ( + ConcreteRArray{Float64,3}, + TracedRArray{Float64,3}, + TracedRArray{Float64,3}, + ), - # Array types - (Array{Float64,1}, Array{Float64,1}, Array{TracedRNumber{Float64},1}), - ( - Array{ConcreteRArray{Float64,2},1}, - Array{TracedRArray{Float64,2},1}, - Array{TracedRArray{Float64,2},1}, - ), + # Array types + (Array{Float64,1}, Array{Float64,1}, Array{TracedRNumber{Float64},1}), + ( + Array{ConcreteRArray{Float64,2},1}, + Array{TracedRArray{Float64,2},1}, + Array{TracedRArray{Float64,2},1}, + ), - # Union types - (Union{Nothing,Int}, Union{Nothing,Int}, Union{Nothing,TracedRNumber{Int}}), - ( - Union{Nothing,ConcreteRArray{Float64,1}}, - Union{Nothing,TracedRArray{Float64,1}}, - Union{Nothing,TracedRArray{Float64,1}}, - ), + # Union types + (Union{Nothing,Int}, Union{Nothing,Int}, Union{Nothing,TracedRNumber{Int}}), + ( + Union{Nothing,ConcreteRArray{Float64,1}}, + Union{Nothing,TracedRArray{Float64,1}}, + Union{Nothing,TracedRArray{Float64,1}}, + ), - # Ptr types - (Ptr{Float64}, Ptr{Float64}, Ptr{TracedRNumber{Float64}}), - ( - Ptr{ConcreteRArray{Float64,1}}, - Ptr{TracedRArray{Float64,1}}, - Ptr{TracedRArray{Float64,1}}, - ), - ( - Core.LLVMPtr{Float64}, - Core.LLVMPtr{Float64}, - Core.LLVMPtr{TracedRNumber{Float64}}, - ), - ( - Core.LLVMPtr{ConcreteRArray{Float64,1}}, - Core.LLVMPtr{TracedRArray{Float64,1}}, - Core.LLVMPtr{TracedRArray{Float64,1}}, - ), - ( - Base.RefValue{Float64}, - Base.RefValue{Float64}, - Base.RefValue{TracedRNumber{Float64}}, - ), - ( - Base.RefValue{ConcreteRArray{Float64,1}}, - Base.RefValue{TracedRArray{Float64,1}}, - Base.RefValue{TracedRArray{Float64,1}}, - ), + # Ptr types + (Ptr{Float64}, Ptr{Float64}, Ptr{TracedRNumber{Float64}}), + ( + Ptr{ConcreteRArray{Float64,1}}, + Ptr{TracedRArray{Float64,1}}, + Ptr{TracedRArray{Float64,1}}, + ), + ( + Core.LLVMPtr{Float64}, + Core.LLVMPtr{Float64}, + Core.LLVMPtr{TracedRNumber{Float64}}, + ), + ( + Core.LLVMPtr{ConcreteRArray{Float64,1}}, + Core.LLVMPtr{TracedRArray{Float64,1}}, + Core.LLVMPtr{TracedRArray{Float64,1}}, + ), + ( + Base.RefValue{Float64}, + Base.RefValue{Float64}, + Base.RefValue{TracedRNumber{Float64}}, + ), + ( + Base.RefValue{ConcreteRArray{Float64,1}}, + Base.RefValue{TracedRArray{Float64,1}}, + Base.RefValue{TracedRArray{Float64,1}}, + ), - # Val types - (Val{0}, Val{0}, Val{0}), - (Val{0.5}, Val{0.5}, Val{0.5}), - (Val{:x}, Val{:x}, Val{:x}), - ( - Dict{Int,ConcreteRArray{Float64,0}}, - Dict{Int,TracedRArray{Float64,0}}, - Dict{Int,TracedRArray{Float64,0}}, - ), - (Dict{Int}, Dict{Int}, Dict{Int}), - (Dict, Dict, Dict), - ( - (Dict{A,ConcreteRArray{Float64,0}} where {A}), - (Dict{A,TracedRArray{Float64,0}} where {A}), - (Dict{A,TracedRArray{Float64,0}} where {A}), - ), - ( - ( - Dict{ - Symbol,NTuple{nsteps,SpectralVariable3D} - } where {nsteps} where {SpectralVariable3D} - ), - ( - Dict{ - Symbol,NTuple{nsteps,SpectralVariable3D} - } where {nsteps} where {SpectralVariable3D} - ), - ( - Dict{ - Symbol,NTuple{nsteps,SpectralVariable3D} - } where {nsteps} where {SpectralVariable3D} - ), - ), - ( - Base.Pairs{Symbol,Union{}}, - Base.Pairs{Symbol,Union{}}, - Base.Pairs{Symbol,Union{}}, - ), - ( - NTuple{nsteps,SpectralVariable3D} where {nsteps,SpectralVariable3D}, - NTuple{nsteps,SpectralVariable3D} where {nsteps,SpectralVariable3D}, - NTuple{nsteps,SpectralVariable3D} where {nsteps,SpectralVariable3D}, - ), - ( - Base.RefValue{A} where {A}, - Base.RefValue{A} where {A}, - Base.RefValue{A} where {A}, - ), - (Wrapper{Symbol,Symbol}, Wrapper{Symbol,Symbol}, Wrapper{Symbol,Symbol}), - ( - Wrapper{Float64,Vector{Float64}}, - Wrapper{Float64,Vector{Float64}}, - Wrapper{TracedRNumber{Float64},Vector{Float64}}, - ), - ( - Wrapper{Float64,ConcreteRArray{Float64,1}}, - Wrapper{Float64,TracedRArray{Float64,1}}, - Wrapper{TracedRNumber{Float64},TracedRArray{Float64,1}}, - ), - (Wrapper{Symbol}, Wrapper{Symbol}, Wrapper{Symbol}), - (Wrapper{Float64}, Wrapper{Float64}, Wrapper{TracedRNumber{Float64}}), - ( - Wrapper{ConcreteRArray{Float64,1}}, - Wrapper{TracedRArray{Float64,1}}, - Wrapper{TracedRArray{Float64,1}}, - ), - (Wrapper, Wrapper, Wrapper), - ] - tracedty = traced_type( - origty, - Val(ConcreteToTraced), - Union{}, - Sharding.NoSharding(), - Reactant.XLA.runtime(), - ) - @test tracedty == targetty + # Val types + (Val{0}, Val{0}, Val{0}), + (Val{0.5}, Val{0.5}, Val{0.5}), + (Val{:x}, Val{:x}, Val{:x}), + ( + Dict{Int,ConcreteRArray{Float64,0}}, + Dict{Int,TracedRArray{Float64,0}}, + Dict{Int,TracedRArray{Float64,0}}, + ), + (Dict{Int}, Dict{Int}, Dict{Int}), + (Dict, Dict, Dict), + ( + (Dict{A,ConcreteRArray{Float64,0}} where {A}), + (Dict{A,TracedRArray{Float64,0}} where {A}), + (Dict{A,TracedRArray{Float64,0}} where {A}), + ), + ( + ( + Dict{ + Symbol,NTuple{nsteps,SpectralVariable3D} + } where {nsteps} where {SpectralVariable3D} + ), + ( + Dict{ + Symbol,NTuple{nsteps,SpectralVariable3D} + } where {nsteps} where {SpectralVariable3D} + ), + ( + Dict{ + Symbol,NTuple{nsteps,SpectralVariable3D} + } where {nsteps} where {SpectralVariable3D} + ), + ), + ( + Bijection{Symbol,Int,Dict{Symbol,Int},Dict{Int,Symbol}}, + Bijection{Symbol,Int,Dict{Symbol,Int},Dict{Int,Symbol}}, + Bijection{Symbol,Int,Dict{Symbol,Int},Dict{Int,Symbol}} + ), + ( + Bijection{Symbol,Int,Dict{Symbol,Int},IdDict{Int,Symbol}}, + Bijection{Symbol,Int,Dict{Symbol,Int},IdDict{Int,Symbol}}, + Bijection{Symbol,Int,Dict{Symbol,Int},IdDict{Int,Symbol}} + ), + ( + Bijection{Symbol,Int}, + Bijection{Symbol,Int}, + Bijection{Symbol,Int} + ), + ( + Bijection{Int,ConcreteRArray{Float64,0},Dict{Int,ConcreteRArray{Float64,0}},Dict{ConcreteRArray{Float64,0},Int}}, + Bijection{Int,TracedRArray{Float64,0},Dict{Int,TracedRArray{Float64,0}},Dict{TracedRArray{Float64,0},Int}}, + Bijection{Int,TracedRArray{Float64,0},Dict{Int,TracedRArray{Float64,0}},Dict{TracedRArray{Float64,0},Int}} + ), + ( + Bijection{Int,ConcreteRArray{Float64,0},Dict{Int,ConcreteRArray{Float64,0}},IdDict{ConcreteRArray{Float64,0},Int}}, + Bijection{Int,TracedRArray{Float64,0},Dict{Int,TracedRArray{Float64,0}},IdDict{TracedRArray{Float64,0},Int}}, + Bijection{Int,TracedRArray{Float64,0},Dict{Int,TracedRArray{Float64,0}},IdDict{TracedRArray{Float64,0},Int}} + ), + ( + Bijection{Int,ConcreteRArray{Float64,0}}, + Bijection{Int,TracedRArray{Float64,0}}, + Bijection{Int,TracedRArray{Float64,0}} + ), + ( + Base.Pairs{Symbol,Union{}}, + Base.Pairs{Symbol,Union{}}, + Base.Pairs{Symbol,Union{}}, + ), + ( + NTuple{nsteps,SpectralVariable3D} where {nsteps,SpectralVariable3D}, + NTuple{nsteps,SpectralVariable3D} where {nsteps,SpectralVariable3D}, + NTuple{nsteps,SpectralVariable3D} where {nsteps,SpectralVariable3D}, + ), + ( + Base.RefValue{A} where {A}, + Base.RefValue{A} where {A}, + Base.RefValue{A} where {A}, + ), + (Wrapper{Symbol,Symbol}, Wrapper{Symbol,Symbol}, Wrapper{Symbol,Symbol}), + ( + Wrapper{Float64,Vector{Float64}}, + Wrapper{Float64,Vector{Float64}}, + Wrapper{TracedRNumber{Float64},Vector{Float64}}, + ), + ( + Wrapper{Float64,ConcreteRArray{Float64,1}}, + Wrapper{Float64,TracedRArray{Float64,1}}, + Wrapper{TracedRNumber{Float64},TracedRArray{Float64,1}}, + ), + (Wrapper{Symbol}, Wrapper{Symbol}, Wrapper{Symbol}), + (Wrapper{Float64}, Wrapper{Float64}, Wrapper{TracedRNumber{Float64}}), + ( + Wrapper{ConcreteRArray{Float64,1}}, + Wrapper{TracedRArray{Float64,1}}, + Wrapper{TracedRArray{Float64,1}}, + ), + (Wrapper, Wrapper, Wrapper), + ] + tracedty = traced_type( + origty, + Val(ConcreteToTraced), + Union{}, + Sharding.NoSharding(), + Reactant.XLA.runtime(), + ) + @test tracedty == targetty - tracedty2 = traced_type( - origty, - Val(ConcreteToTraced), - ReactantPrimitive, - Sharding.NoSharding(), - Reactant.XLA.runtime(), - ) - @test tracedty2 == targetty - end + tracedty2 = traced_type( + origty, + Val(ConcreteToTraced), + ReactantPrimitive, + Sharding.NoSharding(), + Reactant.XLA.runtime(), + ) + @test tracedty2 == targetty + end - @testset "$type" for type in [ - TracedRArray{Float64,0}, - TracedRArray{Float64,1}, - TracedRArray{Float64,2}, - TracedRArray{Float64,3}, - ] - @test_throws Union{ErrorException,String} traced_type( - type, - Val(ConcreteToTraced), - Union{}, - Sharding.NoSharding(), - Reactant.XLA.runtime(), - ) - end - end - @testset "traced_type exceptions" begin - struct Node - x::Vector{Float64} - y::Union{Nothing,Node} - end - @test_throws NoFieldMatchError traced_type( - Node, - Val(ArrayToConcrete), - Union{}, - Sharding.NoSharding(), - Reactant.XLA.runtime(), - ) - end + @testset "$type" for type in [ + TracedRArray{Float64,0}, + TracedRArray{Float64,1}, + TracedRArray{Float64,2}, + TracedRArray{Float64,3}, + ] + @test_throws Union{ErrorException,String} traced_type( + type, + Val(ConcreteToTraced), + Union{}, + Sharding.NoSharding(), + Reactant.XLA.runtime(), + ) + end end + @testset "traced_type exceptions" begin + struct Node + x::Vector{Float64} + y::Union{Nothing,Node} + end + @test_throws NoFieldMatchError traced_type( + Node, + Val(ArrayToConcrete), + Union{}, + Sharding.NoSharding(), + Reactant.XLA.runtime(), + ) + end + end - @testset "specialized dispatches" begin - @test @inferred Union{Float64,ConcreteRArray{Float64}} Reactant.to_rarray( - 1.0; track_numbers=Number - ) isa ConcreteRNumber - @test @inferred Reactant.to_rarray(1.0) isa Float64 - @test @inferred Reactant.to_rarray(rand(3)) isa ConcreteRArray + @testset "specialized dispatches" begin + @test @inferred Union{Float64,ConcreteRArray{Float64}} Reactant.to_rarray( + 1.0; track_numbers=Number + ) isa ConcreteRNumber + @test @inferred Reactant.to_rarray(1.0) isa Float64 + @test @inferred Reactant.to_rarray(rand(3)) isa ConcreteRArray - x_ra = Reactant.to_rarray(rand(3)) - @test @inferred Reactant.to_rarray(x_ra) isa ConcreteRArray + x_ra = Reactant.to_rarray(rand(3)) + @test @inferred Reactant.to_rarray(x_ra) isa ConcreteRArray - x_ra = Reactant.to_rarray(1.0; track_numbers=Number) - @test @inferred Reactant.to_rarray(x_ra) isa ConcreteRNumber - end + x_ra = Reactant.to_rarray(1.0; track_numbers=Number) + @test @inferred Reactant.to_rarray(x_ra) isa ConcreteRNumber + end - @testset "no trace Val" begin - st = (; a=1, training=Val(true)) - st_traced = Reactant.to_rarray(st; track_numbers=Number) - @test st_traced.training isa Val{true} - end + @testset "no trace Val" begin + st = (; a=1, training=Val(true)) + st_traced = Reactant.to_rarray(st; track_numbers=Number) + @test st_traced.training isa Val{true} + end - @testset "to_rarray(::AbstractRule)" begin - opt = Descent(0.1) - opt_traced = Reactant.to_rarray(opt; track_numbers=AbstractFloat) - @test opt_traced.eta isa ConcreteRNumber{Float64} + @testset "to_rarray(::AbstractRule)" begin + opt = Descent(0.1) + opt_traced = Reactant.to_rarray(opt; track_numbers=AbstractFloat) + @test opt_traced.eta isa ConcreteRNumber{Float64} - opt = RMSProp(0.1, 0.9, 1e-8, true) - opt_traced = Reactant.to_rarray(opt; track_numbers=AbstractFloat) - @test opt_traced.eta isa ConcreteRNumber{Float64} - @test opt_traced.rho isa ConcreteRNumber{Float64} - @test opt_traced.epsilon isa ConcreteRNumber{Float64} - @test opt_traced.centred isa Bool - end + opt = RMSProp(0.1, 0.9, 1e-8, true) + opt_traced = Reactant.to_rarray(opt; track_numbers=AbstractFloat) + @test opt_traced.eta isa ConcreteRNumber{Float64} + @test opt_traced.rho isa ConcreteRNumber{Float64} + @test opt_traced.epsilon isa ConcreteRNumber{Float64} + @test opt_traced.centred isa Bool + end end