Description
Package Version
v0.13.7
Julia Version
1.8.2
OS / Environment
Windows 11
Describe the bug
I followed the example of fitting a straight line from Flux's docs and used JET to analyze train!. It found 19 possible runtime dispatch errors.
Steps to Reproduce
```julia
using Flux, JET

actual(x) = 4x + 2
x_train, x_test = hcat(0:5...), hcat(6:10...)
y_train, y_test = actual.(x_train), actual.(x_test)

predict = Dense(1 => 1)
loss_(x, y) = Flux.Losses.mse(predict(x), y);
opt = Descent()
data = [(x_train, y_train)]
parameters = Flux.params(predict)

Flux.train!(loss_, parameters, data, opt)  # [edit: qualify train!]
@report_opt Flux.train!(loss_, parameters, data, opt)  # ═════ 19 possible errors found ═════
# runtime dispatch detected: isequal(%1::Any, v::Task)::Bool
```
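For comparison, here is a sketch of the same single pass over data written as an explicit loop. The log below shows train! calling withgradient and then update! for each data item, so this loop assumes the Flux 0.13 implicit-parameter API (Flux.gradient over Flux.params, then Flux.Optimise.update!). It leaves out train!'s callback handling and its check that stops training on a non-finite loss. The variables loss_before and loss_after are hypothetical additions for illustration.

```julia
using Flux

actual(x) = 4x + 2
x_train = hcat(0:5...)            # 1×6 Matrix{Int64}
y_train = actual.(x_train)

predict = Dense(1 => 1)
loss_(x, y) = Flux.Losses.mse(predict(x), y)
opt = Descent()                   # default learning rate 0.1
data = [(x_train, y_train)]
parameters = Flux.params(predict)

loss_before = loss_(x_train, y_train)   # hypothetical: loss before the update

# Sketch of one train! pass (assumed Flux 0.13 implicit-params API):
for (x, y) in data
    # Gradients of the loss with respect to every array in `parameters`
    gs = Flux.gradient(() -> loss_(x, y), parameters)
    # One in-place Descent step on each parameter array
    Flux.Optimise.update!(opt, parameters, gs)
end

loss_after = loss_(x_train, y_train)    # hypothetical: loss after the update
```

Because the loop body is ordinary user code, it can be wrapped in a function and passed to @report_opt on its own, which makes it easier to see which reports come from the model and loss and which come from train!'s logging and callback machinery.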
Expected Results
I expected JET to report no errors.
Observed Results
I found runtime dispatch errors.
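JET reports a runtime dispatch when inference cannot determine the concrete argument types at a call site, so the method to call has to be looked up while the program runs instead of at compile time. A minimal Base-only illustration (the function names are hypothetical and unrelated to Flux): accumulating the elements of a Vector{Any} makes every + call dynamic and the result type Any, while the same loop over a Vector{Float64} infers Float64.

```julia
# Hypothetical example functions illustrating what JET calls "runtime dispatch".

# Element type Any: the compiler cannot know which `+` method applies,
# so each `s += x` is dispatched at runtime and `s` is inferred as Any.
function sum_unstable(xs::Vector{Any})
    s = 0.0
    for x in xs
        s += x
    end
    return s
end

# Concrete element type: `+(::Float64, ::Float64)` is resolved at compile time.
function sum_stable(xs::Vector{Float64})
    s = 0.0
    for x in xs
        s += x
    end
    return s
end

# Return types as inferred by the compiler for these argument types.
rt_unstable = Base.return_types(sum_unstable, (Vector{Any},))     # inferred: Any
rt_stable   = Base.return_types(sum_stable, (Vector{Float64},))   # inferred: Float64
```

Running @report_opt sum_unstable(Any[1, 2.5]) with JET should flag the + call inside the loop, whereas sum_stable([1.0, 2.5]) should produce no reports. Many of the reports in the log below are of this kind, arising where a value is only known as Any.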
Relevant log output
julia> @report_opt train!(loss_, parameters, data, opt)
═════ 19 possible errors found ═════
┌ @ C:\Users\Math User\.julia\packages\Flux\nJ0IB\src\optimise\train.jl:125 Flux.Optimise.:(var"#train!#36")(#38, #self#, loss, ps, data, opt)
│┌ @ logging.jl:376 logger = Base.CoreLogging.current_logger_for_env(std_level, group, _module)
││┌ @ logging.jl:499 Base.CoreLogging.env_override_minlevel(group, _module)
│││┌ @ logging.jl:565 Base.moduleroot(_module)
││││┌ @ reflection.jl:45 Base.is_root_module(m)
│││││┌ @ lock.jl:221 lock(temp)
││││││┌ @ lock.jl:103 slowlock(rl)
│││││││┌ @ lock.jl:112 wait(c)
││││││││┌ @ condition.jl:126 Base.list_deletefirst!(ct.queue, ct)
│││││││││┌ @ linked_list.jl:145 isequal(h.value, val)
││││││││││┌ @ gcutils.jl:4 isequal(%1, v)
│││││││││││ runtime dispatch detected: isequal(%1::Any, v::Task)::Bool
││││││││││└────────────────
││││││││┌ @ condition.jl:126 Base.list_deletefirst!(%45, %39)
│││││││││ runtime dispatch detected: Base.list_deletefirst!(%45::Any, %39::Task)::Any
││││││││└────────────────────
│││││┌ @ lock.jl:225 unlock(temp)
││││││┌ @ lock.jl:133 _unlock(rl)
│││││││┌ @ lock.jl:139 notifywaiters(rl)
││││││││┌ @ lock.jl:143 = notify(cond_wait)
│││││││││┌ @ condition.jl:142 #self#(c, Base.nothing)
││││││││││┌ @ condition.jl:142 Base.:(var"#notify#586")(true, false, #self#, c, arg)
│││││││││││┌ @ condition.jl:142 notify(c, arg, all, error)
││││││││││││┌ @ condition.jl:148 Core.kwfunc(schedule)(NamedTuple{(:error,)}(tuple(error)), schedule, t, arg)
│││││││││││││┌ @ task.jl:789 Base.:(var"#schedule#613")(error, _3, t, arg)
││││││││││││││┌ @ task.jl:793 %10(%11, t)
│││││││││││││││ runtime dispatch detected: %10::typeof(Base.list_deletefirst!)(%11::Any, t::Task)::Any
││││││││││││││└───────────────
│┌ @ logging.jl:364 Base.CoreLogging.logging_error(logger, level, _module, group, id, file, line, err, true)
││┌ @ logging.jl:463 (%51)
│││ runtime dispatch detected: ::NamedTuple{(:exception,)}(%51::Tuple{Tuple{Any, Vector{Union{Ptr{Nothing}, Base.InterpreterIP}}}})::NamedTuple{(:exception,), _A} where _A<:Tuple{Tuple{Any, Vector{Union{Ptr{Nothing}, Base.InterpreterIP}}}}
││└──────────────────
││┌ @ logging.jl:463 handle_message##kw(%52, Base.CoreLogging.handle_message, %37, Base.CoreLogging.Error, %44, %38, :logevent_error, %39, %40, %41)
│││ runtime dispatch detected: handle_message##kw(%52::NamedTuple{(:exception,), _A} where _A<:Tuple{Tuple{Any, Vector{Union{Ptr{Nothing}, Base.InterpreterIP}}}}, Base.CoreLogging.handle_message, %37::Base.CoreLogging.AbstractLogger, Base.CoreLogging.Error, %44::Union{LazyString, String}, %38::Module, :logevent_error, %39::Base.UUID, %40::String, %41::Int64)::Any
││└──────────────────
│┌ @ C:\Users\Math User\.julia\packages\Flux\nJ0IB\src\optimise\train.jl:131 Flux.Optimise.withgradient(#37, ps)
││┌ @ C:\Users\Math User\.julia\dev\Zygote\src\compiler\interface.jl:132 pullback(tuple(f), args...)
│││┌ @ C:\Users\Math User\.julia\dev\Zygote\src\compiler\interface.jl:384 Zygote._pullback(cx, f)
││││┌ @ C:\Users\Math User\.julia\packages\Flux\nJ0IB\src\optimise\train.jl:132 Zygote._pullback(ctx, Core._apply_iterate, iterate, Zygote._pullback(ctx, Zygote.literal_getfield, f, Val{:loss}())[1], Zygote._pullback(ctx, Flux.Optimise.batchmemaybe, Zygote._pullback(ctx, Zygote.literal_getfield, f, Val{:d}())[1])[1])
│││││┌ @ C:\Users\Math User\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:65 ZygoteRules.adjoint(tuple(__context__, 550, 551, f), args...)
││││││┌ @ C:\Users\Math User\.julia\dev\Zygote\src\lib\lib.jl:203 Core._apply(tuple(Zygote._pullback, tuple(__context__, f)), args...)
│││││││┌ @ boot.jl:816 Core._apply_iterate(tuple(Main.Base.iterate), x...)
││││││││┌ @ c:\Users\Math User\Github\Math\src\fokker_planck_probability_flow.jl:170 Zygote._pullback(ctx, Zygote.literal_getproperty, Flux.Losses, Val{:mse}())
│││││││││┌ @ C:\Users\Math User\.julia\dev\Zygote\src\lib\literal_getproperty.jl:83 Zygote._pullback(cx, Zygote.getproperty, x, :mse)
││││││││││┌ @ Base.jl:31 Zygote._pullback(ctx, Base.getfield, Base.getfield(args, 1), Base.getfield(args, 2))
│││││││││││┌ @ C:\Users\Math User\.julia\dev\Zygote\src\lib\lib.jl:244 Zygote.Val(field_name)
││││││││││││┌ @ essentials.jl:714 %1()
│││││││││││││ runtime dispatch detected: %1::Type{Val{_A}} where _A()::Val
││││││││││││└─────────────────────
│││││││││││┌ @ C:\Users\Math User\.julia\dev\Zygote\src\lib\lib.jl:244 Zygote._pullback(cx, Zygote.literal_getfield, x, %1)
││││││││││││ runtime dispatch detected: Zygote._pullback(cx::Zygote.Context{true}, Zygote.literal_getfield, x::Module, %1::Val)::Tuple{Any, Zygote.var"#2077#back#218"}
│││││││││││└───────────────────────────────────────────────────────────
││││││││┌ @ c:\Users\Math User\Github\Math\src\fokker_planck_probability_flow.jl:170 Zygote._pullback(ctx, %12, %1)
│││││││││ runtime dispatch detected: Zygote._pullback(ctx::Zygote.Context{true}, %12::Any, %1::Matrix{Int64})::Any
││││││││└────────────────────────────────────────────────────────────────────────────
││││││││┌ @ c:\Users\Math User\Github\Math\src\fokker_planck_probability_flow.jl:170 %21[1]
│││││││││ runtime dispatch detected: (%21::Any)[1]::Any
││││││││└────────────────────────────────────────────────────────────────────────────
││││││││┌ @ c:\Users\Math User\Github\Math\src\fokker_planck_probability_flow.jl:170 %21[2]
│││││││││ runtime dispatch detected: (%21::Any)[2]::Any
││││││││└────────────────────────────────────────────────────────────────────────────
││││││││┌ @ c:\Users\Math User\Github\Math\src\fokker_planck_probability_flow.jl:170 Zygote._pullback(ctx, %5, %22, %2)
│││││││││ runtime dispatch detected: Zygote._pullback(ctx::Zygote.Context{true}, %5::Any, %22::Any, %2::Matrix{Int64})::Any
││││││││└────────────────────────────────────────────────────────────────────────────
││││││││┌ @ c:\Users\Math User\Github\Math\src\fokker_planck_probability_flow.jl:170 %24[1]
│││││││││ runtime dispatch detected: (%24::Any)[1]::Any
││││││││└────────────────────────────────────────────────────────────────────────────
││││││││┌ @ c:\Users\Math User\Github\Math\src\fokker_planck_probability_flow.jl:170 %24[2]
│││││││││ runtime dispatch detected: (%24::Any)[2]::Any
││││││││└────────────────────────────────────────────────────────────────────────────
││┌ @ C:\Users\Math User\.julia\dev\Zygote\src\compiler\interface.jl:133 grad = back(Zygote.sensitivity(y))
│││┌ @ C:\Users\Math User\.julia\dev\Zygote\src\compiler\interface.jl:390 Zygote.Grads(getfield(#self#, :cx).cache, getfield(#self#, :ps))
││││┌ @ C:\Users\Math User\.julia\dev\Zygote\src\compiler\interface.jl:281 convert(, grads)
│││││ runtime dispatch detected: convert(::IdDict{Any, Any}, grads::Nothing)
││││└──────────────────────────────────────────────────────────────────────
││┌ @ C:\Users\Math User\.julia\dev\Zygote\src\compiler\interface.jl:133 Zygote.sensitivity(%5)
│││ runtime dispatch detected: Zygote.sensitivity(%5::Any)::Any
││└──────────────────────────────────────────────────────────────────────
││┌ @ C:\Users\Math User\.julia\dev\Zygote\src\compiler\interface.jl:133 %10(%11)
│││ runtime dispatch detected: %10::Zygote.var"#99#100"{Zygote.Params{Zygote.Buffer{Any, Vector{Any}}}, _A, Zygote.Context{true}} where _A(%11::Any)::Zygote.Grads
││└──────────────────────────────────────────────────────────────────────
││┌ @ C:\Users\Math User\.julia\dev\Zygote\src\compiler\interface.jl:135 (%13)
│││ runtime dispatch detected: ::NamedTuple{(:val, :grad)}(%13::Tuple{Any, Zygote.Grads})::NamedTuple{(:val, :grad), _A} where _A<:Tuple{Any, Zygote.Grads}
││└──────────────────────────────────────────────────────────────────────
│┌ @ C:\Users\Math User\.julia\packages\Flux\nJ0IB\src\optimise\train.jl:135 string("Loss is ", l, " on data item ", i, ", stopping training")
││┌ @ strings/io.jl:185 Base.print_to_string(xs...)
│││┌ @ strings/io.jl:144 print(%83, %86)
││││ runtime dispatch detected: print(%83::IOBuffer, %86::Any)::Any
│││└─────────────────────
│┌ @ C:\Users\Math User\.julia\packages\Flux\nJ0IB\src\optimise\train.jl:137 update!(opt, ps, gs)
││┌ @ C:\Users\Math User\.julia\packages\Flux\nJ0IB\src\optimise\train.jl:24 update!(opt, %32, %65)
│││ runtime dispatch detected: update!(opt::Descent, %32::Any, %65::Any)::Any
││└──────────────────────────────────────────────────────────────────────────
ToucheSir commented on Nov 18, 2022
Flux.train! was written at a time before JET and type stability were big discussion points in the community. Thus, it relies heavily on the dynamic nature of the language and doesn't exactly play well with type inference. That said, any performance or latency hit you get from this should be so small that it doesn't even register in practice. You may be interested in the ongoing PRs (#2082, #2083) to revamp train! for the modern day. More pragmatically, though, I'd recommend just using a custom training loop for now: it's not many more lines of code, it's more flexible, and it can be as type stable as you want.
mcabbott commented on Nov 29, 2022
For what it's worth, here's the result on master:
Or without train!: