Skip to content
58 changes: 46 additions & 12 deletions src/calculus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,64 @@ import Base: diff
## what is the rest of the interface. This does:
## diff(ex, x, n) f^(n)
## diff(ex, x, y, ...) f_{xy...} # also diff(ex, (x,y))
## no support for diff(ex, x,n1, y,n2, ...), but can do diff(ex, (x,y), (n1, n2))
## Support for diff(ex, x,n1, y,n2, ...),
## but can also do diff(ex, (x,y), (n1, n2))

function diff(b1::SymbolicType, b2::BasicType{Val{:Symbol}})
a = Basic()

function diff!(a::Basic, b1::SymbolicType, b2::Basic)
is_symbol(b2) || throw(ArgumentError("Must differentiate with respect to a symbol"))
ret = ccall((:basic_diff, libsymengine), Int, (Ref{Basic}, Ref{Basic}, Ref{Basic}), a, b1, b2)
return a
end

diff(b1::SymbolicType, b2::BasicType) =
throw(ArgumentError("Second argument must be of Symbol type"))
function diff(b1::SymbolicType, b2::Basic)
a = Basic()
diff!(a, b1, b2)
a
end

function diff(b1::SymbolicType, b2::SymbolicType, n::Integer=1)
function diff(b1::SymbolicType, b2::SymbolicType, n::Integer)
n < 0 && throw(DomainError("n must be non-negative integer"))
n==0 && return b1
n==1 && return diff(b1, BasicType(b2))
n > 1 && return diff(diff(b1, BasicType(b2)), BasicType(b2), n-1)
n == 0 && return b1
x = Basic(b2)
out = Basic()
diff!(out, b1, x)
for _ in (n-1):-1:1
diff!(out, out, x)
end
out
end

function diff(b1::SymbolicType, b2::SymbolicType, n::Integer, xs...)
diff(diff(b1,b2,n), xs...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably use a mutating diff as well, but not necessary for this PR

end

function diff(b1::SymbolicType, b2::SymbolicType, b3::SymbolicType)
isa(BasicType(b3), BasicType{Val{:Integer}}) ? diff(b1, b2, N(b3)) : diff(b1, (b2, b3))
if isinteger(b3)
n = N(b3)::Int
diff(b1, b2, n)
else
ex = diff(b1, b2)
diff(ex, b3)
end
end

function diff(b1::SymbolicType, b2::SymbolicType, b3::SymbolicType, bs...)
diff(diff(b1,b2,b3), bs...)
end

diff(b1::SymbolicType, b2::SymbolicType, b3::SymbolicType, b4::SymbolicType, b5...) =
diff(b1, (b2,b3,b4,b5...))
function diff(b1::SymbolicType)
xs = free_symbols(b1)
n = length(xs)
n == 0 && return zero(b1)
n > 1 && throw(ArgumentError("More than one variable; one must be specified"))
diff(b1, only(xs))
end

## deprecate
diff(b1::SymbolicType, b2::BasicType{Val{:Symbol}}) = diff(b1, Basic(b2))
diff(b1::SymbolicType, b2::BasicType) =
throw(ArgumentError("Second argument must be of Symbol type"))

## mixed partials
diff(ex::SymbolicType, bs::Tuple) = reduce((ex, x) -> diff(ex, x), bs, init=ex)
Expand Down
12 changes: 10 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,22 @@ u,v,w = x(2.1), x(1), x(0)

## calculus
x,y = symbols("x y")
@test diff(log(x)) == 1/x
@test diff(log(x),x) == 1/x
@test_throws ArgumentError diff(log(x), x^2)

n = Basic(2)
ex = sin(x*y)
@test diff(log(x),x) == 1/x
@test_throws ArgumentError diff(ex)
@test diff(ex, x) == y * cos(x*y)
@test diff(ex, x, 2) == diff(diff(ex,x), x)
@test diff(ex, x, n) == diff(diff(ex,x), x)
@test diff(ex, x, y) == diff(diff(ex,x), y)
@test diff(ex, x, y,x) == diff(diff(diff(ex,x), y), x)
@test diff(ex, x, y, x) == diff(diff(diff(ex,x), y), x)
@test diff(ex, x, 2, y, 3) == diff(ex, x,x,y,y,y)
@test diff(ex, x, n, y, 3) == diff(ex, x,x,y,y,y)
@test diff(ex, x, 2, y, x) == diff(ex, x,x,x,y)

@test series(sin(x), x, 0, 2) == x
@test series(sin(x), x, 0, 3) == x - x^3/6

Expand Down
Loading