diff --git a/src/intro.jl b/src/intro.jl index 2d3748d..443d0cc 100644 --- a/src/intro.jl +++ b/src/intro.jl @@ -138,7 +138,7 @@ function derive(ex, x) @capture(ex, a_ + b_) ? :($(derive(a, x)) + $(derive(b, x))) : @capture(ex, a_ * b_) ? :($a * $(derive(b, x)) + $b * $(derive(a, x))) : @capture(ex, a_^n_Number) ? :($(derive(a, x)) * ($n * $a^$(n-1))) : - @capture(ex, a_ / b_) ? :($b * $(derive(a, x)) - $a * $(derive(b, x)) / $b^2) : + @capture(ex, a_ / b_) ? :(($b * $(derive(a, x)) - $a * $(derive(b, x))) / $b^2) : error("$ex is not differentiable") end @@ -170,7 +170,7 @@ function derive(ex, x) @capture(ex, a_ + b_) ? addm(derive(a, x), derive(b, x)) : @capture(ex, a_ * b_) ? addm(mulm(a, derive(b, x)), mulm(b, derive(a, x))) : @capture(ex, a_^n_Number) ? mulm(derive(a, x),n,:($a^$(n-1))) : - @capture(ex, a_ / b_) ? :($(mulm(b, derive(a, x))) - $(mulm(a, derive(b, x))) / $b^2) : + @capture(ex, a_ / b_) ? :(($(mulm(b, derive(a, x))) - $(mulm(a, derive(b, x)))) / $b^2) : error("$ex is not differentiable") end @@ -315,7 +315,7 @@ function derive(ex, x, w) @capture(ex, a_ + b_) ? push!(w, addm(derive(a, x, w), derive(b, x, w))) : @capture(ex, a_ * b_) ? push!(w, addm(mulm(a, derive(b, x, w)), mulm(b, derive(a, x, w)))) : @capture(ex, a_^n_Number) ? push!(w, mulm(derive(a, x, w),n,:($a^$(n-1)))) : - @capture(ex, a_ / b_) ? push!(w, :($(mulm(b, derive(a, x, w))) - $(mulm(a, derive(b, x, w))) / $b^2)) : + @capture(ex, a_ / b_) ? push!(w, :(($(mulm(b, derive(a, x, w))) - $(mulm(a, derive(b, x, w)))) / $b^2)) : error("$ex is not differentiable") end @@ -349,7 +349,7 @@ function derive(w::Wengert, x) Δ = @capture(ex, a_ + b_) ? addm(d(a), d(b)) : @capture(ex, a_ * b_) ? addm(mulm(a, d(b)), mulm(b, d(a))) : @capture(ex, a_^n_Number) ? mulm(d(a),n,:($a^$(n-1))) : - @capture(ex, a_ / b_) ? :($(mulm(b, d(a))) - $(mulm(a, d(b))) / $b^2) : + @capture(ex, a_ / b_) ? :(($(mulm(b, d(a))) - $(mulm(a, d(b)))) / $b^2) : error("$ex is not differentiable") ds[v] = push!(w, Δ) end @@ -381,7 +381,7 @@ function derive_r(w::Wengert, x) elseif @capture(ex, a_^n_Number) d(a, mulm(Δ, n, :($a^$(n-1)))) elseif @capture(ex, a_ / b_) - d(a, push!(w, mulm(Δ, b))) + d(a, push!(w, :($(mulm(Δ, b))/$b^2))) d(b, push!(w, :(-$(mulm(Δ, a))/$b^2))) else error("$ex is not differentiable") diff --git a/src/tracing.jl b/src/tracing.jl index e475934..c7ed824 100644 --- a/src/tracing.jl +++ b/src/tracing.jl @@ -225,8 +225,8 @@ function derive(w::Tape, xs...) elseif @capture(ex, a_^n_Number) d(a, Δ * n * val(a) ^ (n-1)) elseif @capture(ex, a_ / b_) - d(a, Δ * val(b)) - d(b, -Δ*val(a)/val(b)^2) + d(a, Δ * val(b)/val(b)^2) + d(b, -Δ * val(a)/val(b)^2) else error("$ex is not differentiable") end diff --git a/src/utils.jl b/src/utils.jl index d7c0aa6..9b3257a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -99,7 +99,7 @@ function derive(w::Wengert, x; out = w) Δ = @capture(ex, a_ + b_) ? addm(d(a), d(b)) : @capture(ex, a_ * b_) ? addm(mulm(a, d(b)), mulm(b, d(a))) : @capture(ex, a_^n_Number) ? mulm(d(a),n,:($a^$(n-1))) : - @capture(ex, a_ / b_) ? :($(mulm(b, d(a))) - $(mulm(a, d(b))) / $b^2) : + @capture(ex, a_ / b_) ? :(($(mulm(b, d(a))) - $(mulm(a, d(b)))) / $b^2) : @capture(ex, sin(a_)) ? mulm(:(cos($a)), d(a)) : @capture(ex, cos(a_)) ? mulm(:(-sin($a)), d(a)) : @capture(ex, exp(a_)) ? mulm(v, d(a)) : @@ -127,7 +127,7 @@ function derive_r(w::Wengert, x) elseif @capture(ex, a_^n_Number) d(a, mulm(Δ, n, :($a^$(n-1)))) elseif @capture(ex, a_ / b_) - d(a, push!(w, mulm(Δ, b))) + d(a, push!(w, :($(mulm(Δ, b))/$b^2))) d(b, push!(w, :(-$(mulm(Δ, a))/$b^2))) else error("$ex is not differentiable")