Skip to content

Trace over AbstractDict #1398

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.2.123"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Bijections = "e2ed5e7c-b2de-5872-ae92-c73ca462fb04"
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
Expand Down Expand Up @@ -65,6 +66,7 @@ ReactantYaoBlocksExt = "YaoBlocks"
AbstractFFTs = "1.5"
Adapt = "4.1"
ArrayInterface = "7.17.1"
Bijections = "0.2.1"
CEnum = "0.5"
CUDA = "5.6"
Downloads = "1.6"
Expand Down
28 changes: 15 additions & 13 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module Compiler
using Reactant_jll
using Libdl: dlsym
using LinearAlgebra: BLAS
using Bijections

import ..Reactant:
Reactant,
Expand Down Expand Up @@ -32,8 +33,8 @@ const DEBUG_ALIASED_BUFFER_ASSIGNMENT_ERROR = Ref(false)

const DEBUG_BUFFER_POINTERS_STORE_DICT = Base.IdDict()

@inline function traced_getfield(@nospecialize(obj::Dict), field)
return Base.getindex(obj, field)
@inline function traced_getfield(@nospecialize(obj::AbstractDict), idx)
return first(Iterators.drop(obj, idx - 1))
end

@inline function traced_getfield(@nospecialize(obj), field)
Expand Down Expand Up @@ -109,7 +110,7 @@ end
return setfield_carray!(obj, field, val, path)
end

@inline function traced_setfield!(@nospecialize(obj::Dict), field, val, path)
@inline function traced_setfield!(@nospecialize(obj::AbstractDict), field, val, path)
return Base.setindex!(obj, field, val)
end

Expand Down Expand Up @@ -635,12 +636,15 @@ function create_result(

result_cache[tocopy] = sym

for (k, v) in pairs(tocopy)
subexpr = create_result(v, append_path(path, k), args...)
for (i, (k, v)) in enumerate(pairs(tocopy))
path_k = append_path(append_path(path, i), 1)
k_expr = create_result(k, path_k, args...)
path_v = append_path(append_path(path, i), 2)
v_expr = create_result(v, path_v, args...)
push!(
resultgen_code,
quote
@inbounds $sym[$k] = $subexpr
@inbounds $sym[$k_expr] = $v_expr
end,
)
end
Expand Down Expand Up @@ -3284,8 +3288,7 @@ function compile_xla(f, args; client=nothing, serializable::Bool=false, kwargs..
end

# inspired by RuntimeGeneratedFunction.jl
const __thunk_fwd_body_cache = Dict{Symbol,Expr}()
const __thunk_rev_body_cache = Dict{Expr,Symbol}()
const __thunk_body_cache = Bijection{Symbol,Expr}()

function compile(f, args; sync=false, kwargs...)
_, exec, mlir_fn_res, device, client, str = compile_xla(f, args; kwargs...)
Expand Down Expand Up @@ -3397,12 +3400,11 @@ function compile(f, args; sync=false, kwargs...)
display(mlir_fn_res.donated_args_mask)
end

fname = if body in keys(__thunk_rev_body_cache)
__thunk_rev_body_cache[body]
fname = if hasvalue(__thunk_body_cache, body)
__thunk_body_cache(body)
else
fname2 = gensym(Symbol(Symbol(f), :_reactant))
__thunk_rev_body_cache[body] = fname2
__thunk_fwd_body_cache[fname2] = body
__thunk_body_cache[fname2] = body
fname2
end

Expand Down Expand Up @@ -3484,7 +3486,7 @@ end
)
end
end
body = __thunk_fwd_body_cache[tag]
body = __thunk_body_cache[tag]
if IsClosure
return quote
args = (thunk.f, args...)
Expand Down
6 changes: 5 additions & 1 deletion src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -503,9 +503,13 @@ function prepare_mlir_fn_args(
end
aval = args[path[2]]
for (cidx, idx) in enumerate(path[3:end])
if aval isa Array || aval isa Dict
if aval isa Array #|| aval isa Dict
aval = getindex(aval, idx)
stridx = stridx * "[" * string(idx) * "]"
elseif aval isa AbstractDict
# TODO maybe we want a way to customize this behavior like we did with `traced_getfield`? or a more powerfull `traced_getfield`?
aval = Reactant.Compiler.traced_getfield(aval, idx)
stridx = stridx * "[" * string(idx) * "]"
else
fldname = if idx isa Integer
string(fieldname(Core.Typeof(aval), idx))
Expand Down
137 changes: 110 additions & 27 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using Bijections

@enum TraceMode begin
ConcreteToTraced = 1
TracedTrack = 2
Expand Down Expand Up @@ -214,28 +216,99 @@ Base.@nospecializeinfer function traced_type_inner(
@nospecialize(sharding),
@nospecialize(runtime)
)
K = dict_key(T)
V = dict_value(T)
if V === nothing

K_traced = if !isnothing(K)
traced_type_inner(K, seen, mode, track_numbers, sharding, runtime)
else
nothing
end
V_traced = if !isnothing(V)
traced_type_inner(V, seen, mode, track_numbers, sharding, runtime)
else
nothing
end

if K == K_traced && V == V_traced
return T
end

dictty = if T isa UnionAll
T.body.name.wrapper
else
K = dict_key(T)
V2 = traced_type_inner(V, seen, mode, track_numbers, sharding, runtime)
if V == V2
return T
end
dictty = if T isa UnionAll
T.body.name.wrapper
else
T.name.wrapper
end
if K !== nothing
return dictty{K,V2}
else
return (dictty{KT,V2} where {KT})
end
T.name.wrapper
end

if isnothing(K_traced) && isnothing(V_traced)
return (dictty{Kt,Vt} where {Kt,Vt})
elseif isnothing(K_traced)
return (dictty{Kt,V_traced} where {Kt})
elseif isnothing(V_traced)
return (dictty{K_traced,Vt} where {Vt})
else
return dictty{K_traced,V_traced}
end
end

Base.@nospecializeinfer @inline bijection_fwd_type(::Type{<:Bijection}) = nothing
Base.@nospecializeinfer @inline function bijection_fwd_type(
::Type{<:(Bijection{K,V,F} where {K,V})}
) where {F}
return F
end
Base.@nospecializeinfer @inline bijection_bwd_type(::Type{<:Bijection}) = nothing
Base.@nospecializeinfer @inline function bijection_bwd_type(
::Type{<:(Bijection{K,V,F,Finv} where {K,V,F})}
) where {Finv}
return Finv
end

Base.@nospecializeinfer function traced_type_inner(
@nospecialize(T::Type{<:Bijection}),
seen,
@nospecialize(mode::TraceMode),
@nospecialize(track_numbers::Type),
@nospecialize(sharding),
@nospecialize(runtime)
)
B = Bijection

K = dict_key(T)
if !isnothing(K)
K = traced_type_inner(K, seen, mode, track_numbers, sharding, runtime)
B = B{K}
else
B = (B{Kt} where {Kt})
end

V = dict_value(T)
if !isnothing(V)
V = traced_type_inner(V, seen, mode, track_numbers, sharding, runtime)
B = B{V}
else
B = (B{Vt} where {Vt})
end

F = bijection_fwd_type(T)
if !isnothing(F)
F = traced_type_inner(F, seen, mode, track_numbers, sharding, runtime)
B = B{F}
else
B = (B{Ft} where {Ft})
end

Finv = bijection_bwd_type(T)
if !isnothing(Finv)
Finv = traced_type_inner(Finv, seen, mode, track_numbers, sharding, runtime)
B = B{Finv}
else
B = (B{Finvt} where {Finvt})
end

return B
end

Base.@nospecializeinfer function traced_type_inner(
@nospecialize(T0::Type{<:ConcretePJRTNumber}),
seen,
Expand Down Expand Up @@ -1581,14 +1654,14 @@ end

Base.@nospecializeinfer function make_tracer(
seen,
@nospecialize(prev::Dict{Key,Value}),
prev::D,
@nospecialize(path),
mode;
@nospecialize(track_numbers::Type = Union{}),
@nospecialize(sharding = Sharding.NoSharding()),
@nospecialize(runtime = nothing),
kwargs...,
) where {Key,Value}
) where {D<:AbstractDict}
RT = Core.Typeof(prev)
if mode != NoStopTracedTrack && haskey(seen, prev)
if mode == TracedToTypes
Expand Down Expand Up @@ -1619,31 +1692,41 @@ Base.@nospecializeinfer function make_tracer(
end
return nothing
end
Value2 = traced_type(Value, Val(mode), track_numbers, sharding, runtime)
newa = Dict{Key,Value2}()
seen[prev] = newa
Dt = traced_type(D, Val(mode), track_numbers, sharding, runtime)
dict_traced = Dt()
seen[prev] = dict_traced
same = true
for (k, v) in prev
nv = make_tracer(
for (i, (k, v)) in enumerate(prev)
kt = make_tracer(
seen,
k,
append_path(append_path(path, i), 1),
mode;
track_numbers,
sharding=Base.getproperty(sharding, k),
runtime,
kwargs...,
)
vt = make_tracer(
seen,
v,
append_path(path, k),
append_path(append_path(path, i), 2),
mode;
track_numbers,
sharding=Base.getproperty(sharding, k),
runtime,
kwargs...,
)
if v !== nv
if k !== kt || v !== vt
same = false
end
newa[k] = nv
dict_traced[kt] = vt
end
if same
seen[prev] = prev
return prev
end
return newa
return dict_traced
end

Base.@nospecializeinfer function make_tracer(
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Bijections = "e2ed5e7c-b2de-5872-ae92-c73ca462fb04"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Expand Down
Loading
Loading