Freezing layers at model construction time #1931

Open
@jondeuce

Description

There have been several issues/PRs related to freezing model parameters:

  1. freeze parameters #1022
  2. How to keep weights of parts of a model fixed under Flux.train! #1001
  3. Implement APIs of freeze parameters and freeze layers #1101
  4. delete! for Params Zygote.jl#505
  5. Per-leaf freezing Optimisers.jl#49

Right now, the documentation recommends manually specifying which parameters should not be trained, using some combination of Flux.params and Zygote.delete!.
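
For reference, that workflow looks roughly like this (a minimal sketch, assuming a recent Flux where Dense layers carry weight/bias fields):

using Flux

m = Chain(Dense(2, 3, tanh), Dense(3, 4, tanh), Dense(4, 2))

# Collect all parameters, then delete those of the layer to freeze
# (here the middle layer); `delete!` on `Params` is provided by Zygote.
ps = Flux.params(m)
delete!(ps, m[2].weight)
delete!(ps, m[2].bias)

# `ps` now holds only the free parameters, so passing it to `gradient`
# or `Flux.train!` leaves the middle layer untouched.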

While this works, it is somewhat inflexible in several respects:

  1. Training routines must be aware of the model architecture in order to select which parameters to freeze
  2. It is often much more convenient to mark a layer as frozen at model construction time, particularly when the layer is nested deeply inside the model
  3. It is not clear how the current approach would fit into the functional-style API coming in v0.13, since Params would no longer be used at all (would one need to e.g. fmap over a model and somehow mark specific layers as frozen before passing it to gradient?)

For these reasons, I often find myself defining a Frozen layer (similar to #1001) which looks something like this:

using Flux
using Flux: @adjoint

struct Frozen{F}
    f::F
end
Flux.@functor Frozen # need functor for e.g. `fmap`
Flux.trainable(::Frozen) = NamedTuple() # no trainable parameters

# Something like `whitebox_apply` is required to explicitly treat `f` as a "white box":
# propagate gradients through `f`, but treat `f` itself as a constant functor
(l::Frozen)(xs...) = whitebox_apply(l.f, xs...)

whitebox_apply(f, xs...) = f(xs...)

@adjoint function whitebox_apply(f, xs...)
    y, J = Flux.pullback(f, xs...)
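    # Return `nothing` as the gradient w.r.t. `f` itself (its internal
    # parameters are treated as constants), while gradients w.r.t. the
    # inputs `xs` are propagated through the pullback `J`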
    y, Δ -> (nothing, J(Δ)...)
end

A frozen layer l::Frozen wraps a functor f and has two properties:

  1. l(x) = f(x) is differentiable with respect to x (as opposed to e.g. l(x) = dropgrad(f(x)), which would treat f(x) as a constant; see the short contrast sketch after this list)
  2. f is treated as a constant functor: gradients of l(x) with respect to parameters internal to f return zero
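
To make the contrast in point 1 concrete, here is a small sketch (assuming the Frozen definition above; Zygote.dropgrad discards gradients entirely):

using Zygote: dropgrad

f = Dense(2, 2)
x = rand(Float32, 2)

# `Frozen` still differentiates through `f` with respect to the input `x`:
@assert gradient(x -> sum(Frozen(f)(x)), x)[1] !== nothing

# `dropgrad` treats the whole output `f(x)` as a constant, so the
# gradient with respect to `x` vanishes as well:
@assert gradient(x -> sum(dropgrad(f(x))), x)[1] === nothing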

Below is some test code to illustrate how this layer should behave:

Examples/tests
x = rand(Float32, 2)
l1 = Dense(2, 3, tanh)
l2 = Dense(3, 4, tanh)
l3 = Dense(4, 2, identity)

m0 = Chain(l1, l2, l3)
m1 = Chain(l1, Frozen(l2), l3) # identical to `m0` but with the middle layer frozen

p0 = Flux.params(m0)
p1 = Flux.params(m1)
pfree = Flux.params(l1, l3)
pfrozen = Flux.params(l2)

# Basics
@assert all(p ∈ p1 for p in pfree) # free params are present
@assert all(p ∉ p1 for p in pfrozen) # frozen params are not

∇p1 = gradient(() -> sum(m1(x)), pfrozen)
@assert all(∇p1[p] === nothing for p in pfrozen) # frozen params get no gradients (`nothing`), even if passed to `gradient` explicitly

∇p1 = gradient(() -> sum(m1(x)), p1)
@assert all(haskey(∇p1, p) for p in pfree) # free params have gradients
@assert !any(haskey(∇p1, p) for p in pfrozen) # frozen params do not have gradients

∇p0 = gradient(() -> sum(m0(x)), p0)
@assert all(∇p0[p] ≈ ∇p1[p] for p in pfree) # gradients are equal for free params

# This loss is constant as a function of `pfree`: `m0` and `m1` co-vary exactly as `pfree` changes,
# and therefore the difference `m0(x) - m1(x)` is zero with zero gradient w.r.t. `pfree`.
# However, since `m1` is treated as a constant function of `pfrozen` but `m0` is not,
# the gradient of `m0(x) - m1(x)` is nonzero w.r.t. `pfrozen`.
loss = () -> sum(m0(x) - m1(x))

∇p0 = gradient(loss, p0)
@assert all(iszero(∇p0[p]) for p in pfree) # gradient == 0 for free parameters
@assert !any(iszero(∇p0[p]) for p in pfrozen) # gradient != 0 for frozen parameters

∇p1 = gradient(loss, p1)
@assert all(iszero(∇p1[p]) for p in pfree) # gradient == 0 for free parameters
@assert !any(haskey(∇p1, p) for p in pfrozen) # gradients not present for frozen parameters
@assert all(∇p0[p] ≈ ∇p1[p] for p in pfree) # gradients are equal for free params
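
For completeness, here is a sketch of how this interacts with an optimiser step, reusing the definitions above (Descent and Flux.Optimise.update! are the standard Flux optimiser API):

opt = Descent(0.1)
w_free   = deepcopy(l1.weight)
w_frozen = deepcopy(l2.weight)

gs = gradient(() -> sum(m1(x)), p1)
Flux.Optimise.update!(opt, p1, gs)

@assert l1.weight != w_free   # free parameters were updated
@assert l2.weight == w_frozen # frozen parameters are untouched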

If there is interest in including a layer like Frozen in Flux, I would be happy to make a PR. Of course, if there is an easy way to do what I'm describing that I have overlooked, please let me know and I'll close this issue.
