Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ Expr(dy)
# dy = 1
# y1 = log(y0)
# dy = dy/y0
# y2 = cos(y1)
# dy = dy*sin(y1)
# y2 = sin(y1)
# dy = dy*cos(y1)
# ...
# ```

Expand Down Expand Up @@ -186,11 +186,11 @@ D(x -> D(sin, x), 0.5), -sin(0.5)
# The issue comes about when we close over a variable that *is itself* being
# differentiated.

D(x -> x*D(y -> x+y, 1), 1) # == 1
D(x -> x*D(y -> x+y, 1), 1) # == 3

# The derivative $\frac{d}{dy} (x + y) = 1$, so this is equivalent to
# $\frac{d}{dx}x$, which should also be $1$. So where did this go wrong? The
# problem is that when we closed over $x$, we didn't just get a numeric value
# If we simplify the inner closure, this should be equivalent to `D(x -> x*(x+1), 1)`
# and so its derivative should be `3`. So where did this go wrong?
# The problem is that when we closed over $x$, we didn't just get a numeric value
# but a dual number with $\epsilon = 1$. When we then calculated $x + y$, both
# epsilons were added as if $\frac{dx}{dy} = 1$ (effectively $x = y$). If we had
# written this down, the answer would be correct.
Expand Down
42 changes: 29 additions & 13 deletions src/intro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ function derive(ex, x)
@capture(ex, a_ + b_) ? :($(derive(a, x)) + $(derive(b, x))) :
@capture(ex, a_ * b_) ? :($a * $(derive(b, x)) + $b * $(derive(a, x))) :
@capture(ex, a_^n_Number) ? :($(derive(a, x)) * ($n * $a^$(n-1))) :
@capture(ex, a_ / b_) ? :($b * $(derive(a, x)) - $a * $(derive(b, x)) / $b^2) :
@capture(ex, a_ / b_) ? :(($b * $(derive(a, x)) - $a * $(derive(b, x))) / $b^2) :
error("$ex is not differentiable")
end

Expand All @@ -155,12 +155,16 @@ dy = derive(y, :x)
addm(a, b) = a == 0 ? b : b == 0 ? a : :($a + $b)
mulm(a, b) = 0 in (a, b) ? 0 : a == 1 ? b : b == 1 ? a : :($a * $b)
mulm(a, b, c...) = mulm(mulm(a, b), c...)
powm(a, b) = b == 0 ? 1 : b == 1 ? a : :($a ^ $b)

#-
addm(:a, :b)
#-
addm(:a, 0)
#-
mulm(:b, 1)
#-
powm(:a, 1)

# Our tweaked `derive` function:

Expand All @@ -169,8 +173,8 @@ function derive(ex, x)
ex isa Union{Number,Symbol} ? 0 :
@capture(ex, a_ + b_) ? addm(derive(a, x), derive(b, x)) :
@capture(ex, a_ * b_) ? addm(mulm(a, derive(b, x)), mulm(b, derive(a, x))) :
@capture(ex, a_^n_Number) ? mulm(derive(a, x),n,:($a^$(n-1))) :
@capture(ex, a_ / b_) ? :($(mulm(b, derive(a, x))) - $(mulm(a, derive(b, x))) / $b^2) :
@capture(ex, a_^n_Number) ? mulm(derive(a, x), n, powm(a, n-1)) :
@capture(ex, a_ / b_) ? :(($(mulm(b, derive(a, x))) - $(mulm(a, derive(b, x)))) / $(powm(b, 2))) :
error("$ex is not differentiable")
end

Expand Down Expand Up @@ -233,7 +237,9 @@ printstructure(y2);

# Note that this is *not* the same as running common subexpression elimination
# to simplify the tree, which would have an $O(n^2)$ computational cost. If
# there is real duplication in the expression, it'll show up.
# there is real duplication in the expression, it'll show up
# (technically, that is because `IdDict` hashes by object-id and each `Expr`
# has different identity and object-id accordingly).

:(1*2 + 1*2) |> printstructure;

Expand Down Expand Up @@ -264,7 +270,7 @@ derive(:(x / (1 + x^2) * x), :x) |> printstructure;
# Calculator notation – expressions without variable bindings – is a terrible
# format for anything, and will tend to blow up in size whether you
# differentiate it or not. Symbolic differentiation is commonly criticised for
# its susceptability to "expression swell", but in fact has nothing to do with
# its susceptibility to "expression swell", but in fact has nothing to do with
# the differentiation algorithm itself, and we need not change it to get better
# results.
#
Expand Down Expand Up @@ -314,8 +320,8 @@ function derive(ex, x, w)
ex isa Union{Number,Symbol} ? 0 :
@capture(ex, a_ + b_) ? push!(w, addm(derive(a, x, w), derive(b, x, w))) :
@capture(ex, a_ * b_) ? push!(w, addm(mulm(a, derive(b, x, w)), mulm(b, derive(a, x, w)))) :
@capture(ex, a_^n_Number) ? push!(w, mulm(derive(a, x, w),n,:($a^$(n-1)))) :
@capture(ex, a_ / b_) ? push!(w, :($(mulm(b, derive(a, x, w))) - $(mulm(a, derive(b, x, w))) / $b^2)) :
@capture(ex, a_^n_Number) ? push!(w, mulm(derive(a, x, w), n, powm(a, n-1))) :
@capture(ex, a_ / b_) ? push!(w, :(($(mulm(b, derive(a, x, w))) - $(mulm(a, derive(b, x, w)))) / $(powm(b, 2)))) :
error("$ex is not differentiable")
end

Expand All @@ -329,7 +335,7 @@ derive(Wengert(:(3x^2 + (2x + 1))), :x) |> Expr
# In fact, we can compare them directly using the `printstructure` function we
# wrote earlier.

derive(:(x / (1 + x^2)), :x) |> printstructure
derive(:(x / (1 + x^2)), :x) |> printstructure;
#-
derive(Wengert(:(x / (1 + x^2))), :x)

Expand All @@ -339,6 +345,7 @@ derive(Wengert(:(x / (1 + x^2))), :x)
# we convert the Wengert list back into an `Expr`.

derive(Wengert(:(x / (1 + x^2))), :x) |> Expr
#-

function derive(w::Wengert, x)
ds = Dict()
Expand All @@ -348,8 +355,8 @@ function derive(w::Wengert, x)
ex = w[v]
Δ = @capture(ex, a_ + b_) ? addm(d(a), d(b)) :
@capture(ex, a_ * b_) ? addm(mulm(a, d(b)), mulm(b, d(a))) :
@capture(ex, a_^n_Number) ? mulm(d(a),n,:($a^$(n-1))) :
@capture(ex, a_ / b_) ? :($(mulm(b, d(a))) - $(mulm(a, d(b))) / $b^2) :
@capture(ex, a_^n_Number) ? mulm(d(a), n, powm(a,n-1)) :
@capture(ex, a_ / b_) ? :(($(mulm(b, d(a))) - $(mulm(a, d(b)))) / $(powm(b, 2))) :
error("$ex is not differentiable")
ds[v] = push!(w, Δ)
end
Expand Down Expand Up @@ -379,10 +386,10 @@ function derive_r(w::Wengert, x)
d(a, push!(w, mulm(Δ, b)))
d(b, push!(w, mulm(Δ, a)))
elseif @capture(ex, a_^n_Number)
d(a, mulm(Δ, n, :($a^$(n-1))))
d(a, mulm(Δ, n, :($(powm(a, n-1)))))
elseif @capture(ex, a_ / b_)
d(a, push!(w, mulm(Δ, b)))
d(b, push!(w, :(-$(mulm(Δ, a))/$b^2)))
d(a, push!(w, :($(mulm(Δ, b)) / $(powm(b, 2)))))
d(b, push!(w, :(-$(mulm(Δ, a)) / $(powm(b, 2)))))
else
error("$ex is not differentiable")
end
Expand All @@ -401,3 +408,12 @@ derive_r(Wengert(:(x / (1 + x^2))), :x) |> Expr
# For now, the output looks pretty similar to that of forward mode; we'll
# explain why the [distinction makes a difference](./backandforth.ipynb) in future
# notebooks.

# Lastly, let's assert the differentiators we wrote so far are all correct.
y = :(x / (1 + x^2))

x = 0.5
dy = (1-x^2) / (1+x^2)^2 # hand-written derivative
@assert @show(derive(y, :x) |> eval) == dy
@assert @show(derive(Wengert(y), :x) |> Expr |> eval) == dy
@assert @show(derive_r(Wengert(y), :x) |> Expr |> eval) ≈ dy
4 changes: 2 additions & 2 deletions src/tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ x = track(t, 5)

y = pow(x, 3)
y[]

#-
y.w.instructions |> Expr

# Finally, we need to alter how we derive this list. The key insight is that
Expand Down Expand Up @@ -225,7 +225,7 @@ function derive(w::Tape, xs...)
elseif @capture(ex, a_^n_Number)
d(a, Δ * n * val(a) ^ (n-1))
elseif @capture(ex, a_ / b_)
d(a, Δ * val(b))
d(a, Δ*val(b)/val(b)^2)
d(b, -Δ*val(a)/val(b)^2)
else
error("$ex is not differentiable")
Expand Down
12 changes: 7 additions & 5 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ function Expr(w::Wengert)
bs = Dict()
rename(ex::Expr) = Expr(ex.head, map(x -> get(bs, x, x), ex.args)...)
rename(x) = x
ex = :(;)
ex = Expr(:block)
for v in keys(w)
if get(cs, v, 0) > 1
push!(ex.args, :($(Symbol(v)) = $(rename(w[v]))))
Expand All @@ -89,6 +89,7 @@ end
addm(a, b) = a == 0 ? b : b == 0 ? a : :($a + $b)
mulm(a, b) = 0 in (a, b) ? 0 : a == 1 ? b : b == 1 ? a : :($a * $b)
mulm(a, b, c...) = mulm(mulm(a, b), c...)
powm(a, b) = b == 0 ? 1 : b == 1 ? a : :($a ^ $b)

function derive(w::Wengert, x; out = w)
ds = Dict()
Expand All @@ -99,7 +100,8 @@ function derive(w::Wengert, x; out = w)
Δ = @capture(ex, a_ + b_) ? addm(d(a), d(b)) :
@capture(ex, a_ * b_) ? addm(mulm(a, d(b)), mulm(b, d(a))) :
@capture(ex, a_^n_Number) ? mulm(d(a),n,:($a^$(n-1))) :
@capture(ex, a_ / b_) ? :($(mulm(b, d(a))) - $(mulm(a, d(b))) / $b^2) :
@capture(ex, a_^n_Number) ? mulm(d(a), n, powm(a,n-1)) :
@capture(ex, a_ / b_) ? :(($(mulm(b, d(a))) - $(mulm(a, d(b)))) / $(powm(b, 2))) :
@capture(ex, sin(a_)) ? mulm(:(cos($a)), d(a)) :
@capture(ex, cos(a_)) ? mulm(:(-sin($a)), d(a)) :
@capture(ex, exp(a_)) ? mulm(v, d(a)) :
Expand All @@ -125,10 +127,10 @@ function derive_r(w::Wengert, x)
d(a, push!(w, mulm(Δ, b)))
d(b, push!(w, mulm(Δ, a)))
elseif @capture(ex, a_^n_Number)
d(a, mulm(Δ, n, :($a^$(n-1))))
d(a, mulm(Δ, n, :($(powm(a, n-1)))))
elseif @capture(ex, a_ / b_)
d(a, push!(w, mulm(Δ, b)))
d(b, push!(w, :(-$(mulm(Δ, a))/$b^2)))
d(a, push!(w, :($(mulm(Δ, b)) / $(powm(b, 2)))))
d(b, push!(w, :(-$(mulm(Δ, a)) / $(powm(b, 2)))))
else
error("$ex is not differentiable")
end
Expand Down