From a78f79afd4834f0616edeee31db53f713ccd7949 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Fri, 9 Jul 2021 19:42:31 +0900 Subject: [PATCH] apply minor corrections and improvements - fixed implementations of quotient rule - added simplification rule for symbolic representation of exponentiation - avoid using the deprecated quoted constructor for `:block` expressions - corrected seemingly wrong statements --- src/forward.jl | 10 +++++----- src/intro.jl | 42 +++++++++++++++++++++++++++++------------- src/tracing.jl | 4 ++-- src/utils.jl | 12 +++++++----- 4 files changed, 43 insertions(+), 25 deletions(-) diff --git a/src/forward.jl b/src/forward.jl index 165e66c..8c71913 100644 --- a/src/forward.jl +++ b/src/forward.jl @@ -50,8 +50,8 @@ Expr(dy) # dy = 1 # y1 = log(y0) # dy = dy/y0 -# y2 = cos(y1) -# dy = dy*sin(y1) +# y2 = sin(y1) +# dy = dy*cos(y1) # ... # ``` @@ -186,11 +186,11 @@ D(x -> D(sin, x), 0.5), -sin(0.5) # The issue comes about when we close over a variable that *is itself* being # differentiated. D(x -> x*D(y -> x+y, 1), 1) # == 1 -# The derivative $\frac{d}{dy} (x + y) = 1$, so this is equivalent to -# $\frac{d}{dx}x$, which should also be $1$. So where did this go wrong? The -# problem is that when we closed over $x$, we didn't just get a numeric value +# The inner derivative $\frac{d}{dy} (x + y) = 1$ whatever the value of $x$, so this +# is equivalent to `D(x -> x*1, 1)` and its derivative should be `1`. So where did this go wrong? +# The problem is that when we closed over $x$, we didn't just get a numeric value # but a dual number with $\epsilon = 1$. When we then calculated $x + y$, both # epsilons were added as if $\frac{dx}{dy} = 1$ (effectively $x = y$). If we had # written this down, the answer would be correct. diff --git a/src/intro.jl b/src/intro.jl index 2d3748d..6e0685e 100644 --- a/src/intro.jl +++ b/src/intro.jl @@ -138,7 +138,7 @@ function derive(ex, x) @capture(ex, a_ + b_) ? 
:($(derive(a, x)) + $(derive(b, x))) : @capture(ex, a_ * b_) ? :($a * $(derive(b, x)) + $b * $(derive(a, x))) : @capture(ex, a_^n_Number) ? :($(derive(a, x)) * ($n * $a^$(n-1))) : - @capture(ex, a_ / b_) ? :($b * $(derive(a, x)) - $a * $(derive(b, x)) / $b^2) : + @capture(ex, a_ / b_) ? :(($b * $(derive(a, x)) - $a * $(derive(b, x))) / $b^2) : error("$ex is not differentiable") end @@ -155,12 +155,16 @@ dy = derive(y, :x) addm(a, b) = a == 0 ? b : b == 0 ? a : :($a + $b) mulm(a, b) = 0 in (a, b) ? 0 : a == 1 ? b : b == 1 ? a : :($a * $b) mulm(a, b, c...) = mulm(mulm(a, b), c...) +powm(a, b) = b == 0 ? 1 : b == 1 ? a : :($a ^ $b) + #- addm(:a, :b) #- addm(:a, 0) #- mulm(:b, 1) +#- +powm(:a, 1) # Our tweaked `derive` function: @@ -169,8 +173,8 @@ function derive(ex, x) ex isa Union{Number,Symbol} ? 0 : @capture(ex, a_ + b_) ? addm(derive(a, x), derive(b, x)) : @capture(ex, a_ * b_) ? addm(mulm(a, derive(b, x)), mulm(b, derive(a, x))) : - @capture(ex, a_^n_Number) ? mulm(derive(a, x),n,:($a^$(n-1))) : - @capture(ex, a_ / b_) ? :($(mulm(b, derive(a, x))) - $(mulm(a, derive(b, x))) / $b^2) : + @capture(ex, a_^n_Number) ? mulm(derive(a, x), n, powm(a, n-1)) : + @capture(ex, a_ / b_) ? :(($(mulm(b, derive(a, x))) - $(mulm(a, derive(b, x)))) / $(powm(b, 2))) : error("$ex is not differentiable") end @@ -233,7 +237,9 @@ printstructure(y2); # Note that this is *not* the same as running common subexpression elimination # to simplify the tree, which would have an $O(n^2)$ computational cost. If -# there is real duplication in the expression, it'll show up. +# there is real duplication in the expression, it'll show up +# (technically, that is because `IdDict` hashes by object-id and each `Expr` +# has different identity and object-id accordingly). 
:(1*2 + 1*2) |> printstructure; @@ -264,7 +270,7 @@ derive(:(x / (1 + x^2) * x), :x) |> printstructure; # Calculator notation – expressions without variable bindings – is a terrible # format for anything, and will tend to blow up in size whether you # differentiate it or not. Symbolic differentiation is commonly criticised for -# its susceptability to "expression swell", but in fact has nothing to do with +# its susceptibility to "expression swell", but in fact has nothing to do with # the differentiation algorithm itself, and we need not change it to get better # results. # @@ -314,8 +320,8 @@ function derive(ex, x, w) ex isa Union{Number,Symbol} ? 0 : @capture(ex, a_ + b_) ? push!(w, addm(derive(a, x, w), derive(b, x, w))) : @capture(ex, a_ * b_) ? push!(w, addm(mulm(a, derive(b, x, w)), mulm(b, derive(a, x, w)))) : - @capture(ex, a_^n_Number) ? push!(w, mulm(derive(a, x, w),n,:($a^$(n-1)))) : - @capture(ex, a_ / b_) ? push!(w, :($(mulm(b, derive(a, x, w))) - $(mulm(a, derive(b, x, w))) / $b^2)) : + @capture(ex, a_^n_Number) ? push!(w, mulm(derive(a, x, w), n, powm(a, n-1))) : + @capture(ex, a_ / b_) ? push!(w, :(($(mulm(b, derive(a, x, w))) - $(mulm(a, derive(b, x, w)))) / $(powm(b, 2)))) : error("$ex is not differentiable") end @@ -329,7 +335,7 @@ derive(Wengert(:(3x^2 + (2x + 1))), :x) |> Expr # In fact, we can compare them directly using the `printstructure` function we # wrote earlier. -derive(:(x / (1 + x^2)), :x) |> printstructure +derive(:(x / (1 + x^2)), :x) |> printstructure; #- derive(Wengert(:(x / (1 + x^2))), :x) @@ -339,6 +345,7 @@ derive(Wengert(:(x / (1 + x^2))), :x) # we convert the Wengert list back into an `Expr`. derive(Wengert(:(x / (1 + x^2))), :x) |> Expr +#- function derive(w::Wengert, x) ds = Dict() @@ -348,8 +355,8 @@ function derive(w::Wengert, x) ex = w[v] Δ = @capture(ex, a_ + b_) ? addm(d(a), d(b)) : @capture(ex, a_ * b_) ? addm(mulm(a, d(b)), mulm(b, d(a))) : - @capture(ex, a_^n_Number) ? 
mulm(d(a),n,:($a^$(n-1))) : - @capture(ex, a_ / b_) ? :($(mulm(b, d(a))) - $(mulm(a, d(b))) / $b^2) : + @capture(ex, a_^n_Number) ? mulm(d(a), n, powm(a,n-1)) : + @capture(ex, a_ / b_) ? :(($(mulm(b, d(a))) - $(mulm(a, d(b)))) / $(powm(b, 2))) : error("$ex is not differentiable") ds[v] = push!(w, Δ) end @@ -379,10 +386,10 @@ function derive_r(w::Wengert, x) d(a, push!(w, mulm(Δ, b))) d(b, push!(w, mulm(Δ, a))) elseif @capture(ex, a_^n_Number) - d(a, mulm(Δ, n, :($a^$(n-1)))) + d(a, mulm(Δ, n, :($(powm(a, n-1))))) elseif @capture(ex, a_ / b_) - d(a, push!(w, mulm(Δ, b))) - d(b, push!(w, :(-$(mulm(Δ, a))/$b^2))) + d(a, push!(w, :($(mulm(Δ, b)) / $(powm(b, 2))))) + d(b, push!(w, :(-$(mulm(Δ, a)) / $(powm(b, 2))))) else error("$ex is not differentiable") end @@ -401,3 +408,12 @@ derive_r(Wengert(:(x / (1 + x^2))), :x) |> Expr # For now, the output looks pretty similar to that of forward mode; we'll # explain why the [distinction makes a difference](./backandforth.ipynb) in future # notebooks. + +# Lastly, let's assert the differentiators we wrote so far are all correct. +y = :(x / (1 + x^2)) + +x = 0.5 +dy = (1-x^2) / (1+x^2)^2 # hand-written derivative +@assert @show(derive(y, :x) |> eval) == dy +@assert @show(derive(Wengert(y), :x) |> Expr |> eval) == dy +@assert @show(derive_r(Wengert(y), :x) |> Expr |> eval) ≈ dy diff --git a/src/tracing.jl b/src/tracing.jl index e475934..67ad71e 100644 --- a/src/tracing.jl +++ b/src/tracing.jl @@ -194,7 +194,7 @@ x = track(t, 5) y = pow(x, 3) y[] - +#- y.w.instructions |> Expr # Finally, we need to alter how we derive this list. The key insight is that @@ -225,7 +225,7 @@ function derive(w::Tape, xs...) 
elseif @capture(ex, a_^n_Number) d(a, Δ * n * val(a) ^ (n-1)) elseif @capture(ex, a_ / b_) - d(a, Δ * val(b)) + d(a, Δ*val(b)/val(b)^2) d(b, -Δ*val(a)/val(b)^2) else error("$ex is not differentiable") diff --git a/src/utils.jl b/src/utils.jl index d7c0aa6..491145e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -73,7 +73,7 @@ function Expr(w::Wengert) bs = Dict() rename(ex::Expr) = Expr(ex.head, map(x -> get(bs, x, x), ex.args)...) rename(x) = x - ex = :(;) + ex = Expr(:block) for v in keys(w) if get(cs, v, 0) > 1 push!(ex.args, :($(Symbol(v)) = $(rename(w[v])))) @@ -89,6 +89,7 @@ end addm(a, b) = a == 0 ? b : b == 0 ? a : :($a + $b) mulm(a, b) = 0 in (a, b) ? 0 : a == 1 ? b : b == 1 ? a : :($a * $b) mulm(a, b, c...) = mulm(mulm(a, b), c...) +powm(a, b) = b == 0 ? 1 : b == 1 ? a : :($a ^ $b) function derive(w::Wengert, x; out = w) ds = Dict() @@ -99,7 +100,7 @@ function derive(w::Wengert, x; out = w) Δ = @capture(ex, a_ + b_) ? addm(d(a), d(b)) : @capture(ex, a_ * b_) ? addm(mulm(a, d(b)), mulm(b, d(a))) : - @capture(ex, a_^n_Number) ? mulm(d(a),n,:($a^$(n-1))) : - @capture(ex, a_ / b_) ? :($(mulm(b, d(a))) - $(mulm(a, d(b))) / $b^2) : + @capture(ex, a_^n_Number) ? mulm(d(a), n, powm(a,n-1)) : + @capture(ex, a_ / b_) ? :(($(mulm(b, d(a))) - $(mulm(a, d(b)))) / $(powm(b, 2))) : @capture(ex, sin(a_)) ? mulm(:(cos($a)), d(a)) : @capture(ex, cos(a_)) ? mulm(:(-sin($a)), d(a)) : @capture(ex, exp(a_)) ? mulm(v, d(a)) : @@ -125,10 +126,10 @@ function derive_r(w::Wengert, x) d(a, push!(w, mulm(Δ, b))) d(b, push!(w, mulm(Δ, a))) elseif @capture(ex, a_^n_Number) - d(a, mulm(Δ, n, :($a^$(n-1)))) + d(a, mulm(Δ, n, :($(powm(a, n-1))))) elseif @capture(ex, a_ / b_) - d(a, push!(w, mulm(Δ, b))) - d(b, push!(w, :(-$(mulm(Δ, a))/$b^2))) + d(a, push!(w, :($(mulm(Δ, b)) / $(powm(b, 2))))) + d(b, push!(w, :(-$(mulm(Δ, a)) / $(powm(b, 2))))) else error("$ex is not differentiable") end