Description
There have been several issues/PRs related to freezing model parameters:
- freeze parameters #1022
- How to keep weights of parts of a model fixed under Flux.train! #1001
- Implement APIs of freeze parameters and freeze layers #1101
- delete! for Params Zygote.jl#505
- Per-leaf freezing Optimisers.jl#49
Right now, the recommendation made in the documentation is to manually specify which parameters should not be trained, using some combination of `Flux.params` and `Zygote.delete!`.
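For concreteness, a minimal sketch of that approach (the model and the choice of which layer to freeze are arbitrary, for illustration only):

```julia
using Flux

m  = Chain(Dense(2, 3, tanh), Dense(3, 4, tanh), Dense(4, 2))
ps = Flux.params(m)

# Exclude the middle layer's parameters so that e.g. `Flux.train!(loss, ps, data, opt)`
# never updates them; `delete!` on `Params` is what Zygote.jl#505 added.
for p in Flux.params(m[2])
    delete!(ps, p)
end
```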
While this works, it is somewhat inflexible in several respects:
- Training routines must be aware of the model architecture in order to select which parameters to freeze
- Specifying that a layer is frozen is often much more convenient to do at model construction time, particularly if the frozen layer is nested deeply inside the model
- It is not clear how the current approach would fit into the functional-style approach which is coming in v0.13, since `Params` would no longer be used at all (would one need to e.g. `fmap` over a model and somehow mark specific layers as frozen before passing it to `gradient`? See the sketch after this list.)
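To make the last point concrete, here is a rough sketch of the explicit (functional-style) workflow; there is no `Params` collection left to `delete!` from, so freezing would have to be expressed structurally:

```julia
using Flux

m = Chain(Dense(2, 3, tanh), Dense(3, 2))
x = rand(Float32, 2)

# Explicit-style gradients come back as a nested (named) tuple mirroring the model,
# not as a `Grads` dictionary keyed by parameter arrays.
∇m = gradient(model -> sum(model(x)), m)[1]
```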
For these reasons, I often find myself defining a `Frozen` layer (similar to #1001) which looks something like this:
```julia
using Flux
using Flux: @adjoint

struct Frozen{F}
    f::F
end

Flux.@functor Frozen                      # need functor for e.g. `fmap`
Flux.trainable(::Frozen) = NamedTuple()   # no trainable parameters

# Something like `whitebox_apply` is required to explicitly treat `f` as a "white box":
# propagate gradients through `f`, but treat `f` itself as a constant functor
(l::Frozen)(xs...) = whitebox_apply(l.f, xs...)

whitebox_apply(f, xs...) = f(xs...)

@adjoint function whitebox_apply(f, xs...)
    y, J = Flux.pullback(f, xs...)
    y, Δ -> (nothing, J(Δ)...)
end
```
A frozen layer `l::Frozen` wraps a functor `f` and has two properties:
- `l(x) = f(x)` is differentiable with respect to `x` (as opposed to e.g. `l(x) = dropgrad(f(x))`, which would treat `f(x)` as constant)
- `f` is treated as a constant functor: gradients of `l(x)` with respect to parameters internal to `f` return zero
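The test code below exercises the second property; as a quick sanity check of the first one, here is a small sketch reusing the `Frozen` definition above (names are arbitrary):

```julia
x  = rand(Float32, 3)
fl = Frozen(Dense(3, 3, tanh))

∇x = gradient(x -> sum(fl(x)), x)[1]
@assert ∇x !== nothing   # gradients still flow through `f` to the input...

∇f = gradient(() -> sum(fl(x)), Flux.params(fl.f))
@assert all(∇f[p] === nothing for p in Flux.params(fl.f))   # ...but `f`'s own parameters get none
```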
Below is some test code to illustrate how this layer should behave:
Examples/tests
```julia
x  = rand(Float32, 2)
l1 = Dense(2, 3, tanh)
l2 = Dense(3, 4, tanh)
l3 = Dense(4, 2, identity)

m0 = Chain(l1, l2, l3)
m1 = Chain(l1, Frozen(l2), l3) # identical to `m0` but with the middle layer frozen

p0 = Flux.params(m0)
p1 = Flux.params(m1)
pfree   = Flux.params(l1, l3)
pfrozen = Flux.params(l2)

# Basics
@assert all(p ∈ p1 for p in pfree)   # free params are present
@assert all(p ∉ p1 for p in pfrozen) # frozen params are not

∇p1 = gradient(() -> sum(m1(x)), pfrozen)
@assert all(∇p1[p] === nothing for p in pfrozen) # frozen params get no gradient (`nothing`), even if passed to `gradient` explicitly

∇p1 = gradient(() -> sum(m1(x)), p1)
@assert all(haskey(∇p1, p) for p in pfree)    # free params have gradients
@assert !any(haskey(∇p1, p) for p in pfrozen) # frozen params do not have gradients

∇p0 = gradient(() -> sum(m0(x)), p0)
@assert all(∇p0[p] ≈ ∇p1[p] for p in pfree) # gradients are equal for free params

# This loss is constant as a function of `pfree`: `m0` and `m1` co-vary exactly as `pfree` changes,
# and therefore the difference `m0(x) - m1(x)` is zero with zero gradient w.r.t. `pfree`.
# However, since `m1` is treated as a constant function of `pfrozen` but `m0` is not,
# the gradient of `m0(x) - m1(x)` is nonzero w.r.t. `pfrozen`.
loss = () -> sum(m0(x) - m1(x))

∇p0 = gradient(loss, p0)
@assert all(iszero(∇p0[p]) for p in pfree)    # gradient == 0 for free parameters
@assert !any(iszero(∇p0[p]) for p in pfrozen) # gradient != 0 for frozen parameters

∇p1 = gradient(loss, p1)
@assert all(iszero(∇p1[p]) for p in pfree)    # gradient == 0 for free parameters
@assert !any(haskey(∇p1, p) for p in pfrozen) # gradients not present for frozen parameters
@assert all(∇p0[p] ≈ ∇p1[p] for p in pfree)   # gradients are equal for free params
```
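Finally, a hedged sketch of the training-loop behaviour one would expect, continuing with the variables above (the optimiser and loss are arbitrary choices):

```julia
opt    = Descent(0.1)
frozen = [copy(p) for p in Flux.params(l2)]   # snapshot of the frozen layer's parameters

Flux.train!((x, y) -> sum(abs2, m1(x) .- y), Flux.params(m1), [(x, rand(Float32, 2))], opt)

# The wrapped layer should be untouched, while `l1`/`l3` are free to move.
@assert all(p == q for (p, q) in zip(Flux.params(l2), frozen))
```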
If there is interest in including a layer like `Frozen` in Flux, I would be happy to make a PR. Of course, if there is an easy way to do what I'm describing which I have overlooked, please do let me know and I'll close this issue.