diff --git a/src/calculus.jl b/src/calculus.jl index 60b542b..837188d 100644 --- a/src/calculus.jl +++ b/src/calculus.jl @@ -5,30 +5,64 @@ import Base: diff ## what is the rest of the interface. This does: ## diff(ex, x, n) f^(n) ## diff(ex, x, y, ...) f_{xy...} # also diff(ex, (x,y)) -## no support for diff(ex, x,n1, y,n2, ...), but can do diff(ex, (x,y), (n1, n2)) +## Support for diff(ex, x,n1, y,n2, ...), +## but can also do diff(ex, (x,y), (n1, n2)) -function diff(b1::SymbolicType, b2::BasicType{Val{:Symbol}}) - a = Basic() + +function diff!(a::Basic, b1::SymbolicType, b2::Basic) + is_symbol(b2) || throw(ArgumentError("Must differentiate with respect to a symbol")) ret = ccall((:basic_diff, libsymengine), Int, (Ref{Basic}, Ref{Basic}, Ref{Basic}), a, b1, b2) return a end -diff(b1::SymbolicType, b2::BasicType) = - throw(ArgumentError("Second argument must be of Symbol type")) +function diff(b1::SymbolicType, b2::Basic) + a = Basic() + diff!(a, b1, b2) + a +end -function diff(b1::SymbolicType, b2::SymbolicType, n::Integer=1) +function diff(b1::SymbolicType, b2::SymbolicType, n::Integer) n < 0 && throw(DomainError("n must be non-negative integer")) - n==0 && return b1 - n==1 && return diff(b1, BasicType(b2)) - n > 1 && return diff(diff(b1, BasicType(b2)), BasicType(b2), n-1) + n == 0 && return b1 + x = Basic(b2) + out = Basic() + diff!(out, b1, x) + for _ in (n-1):-1:1 + diff!(out, out, x) + end + out +end + +function diff(b1::SymbolicType, b2::SymbolicType, n::Integer, xs...) + diff(diff(b1,b2,n), xs...) end function diff(b1::SymbolicType, b2::SymbolicType, b3::SymbolicType) - isa(BasicType(b3), BasicType{Val{:Integer}}) ? diff(b1, b2, N(b3)) : diff(b1, (b2, b3)) + if isinteger(b3) + n = N(b3)::Int + diff(b1, b2, n) + else + ex = diff(b1, b2) + diff(ex, b3) + end +end + +function diff(b1::SymbolicType, b2::SymbolicType, b3::SymbolicType, bs...) + diff(diff(b1,b2,b3), bs...) end -diff(b1::SymbolicType, b2::SymbolicType, b3::SymbolicType, b4::SymbolicType, b5...) = - diff(b1, (b2,b3,b4,b5...)) +function diff(b1::SymbolicType) + xs = free_symbols(b1) + n = length(xs) + n == 0 && return zero(b1) + n > 1 && throw(ArgumentError("More than one variable; one must be specified")) + diff(b1, only(xs)) +end + +## deprecate +diff(b1::SymbolicType, b2::BasicType{Val{:Symbol}}) = diff(b1, Basic(b2)) +diff(b1::SymbolicType, b2::BasicType) = + throw(ArgumentError("Second argument must be of Symbol type")) ## mixed partials diff(ex::SymbolicType, bs::Tuple) = reduce((ex, x) -> diff(ex, x), bs, init=ex) diff --git a/test/runtests.jl b/test/runtests.jl index a3ba7b8..01c0f52 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -101,14 +101,22 @@ u,v,w = x(2.1), x(1), x(0) ## calculus x,y = symbols("x y") +@test diff(log(x)) == 1/x +@test diff(log(x),x) == 1/x +@test_throws ArgumentError diff(log(x), x^2) + n = Basic(2) ex = sin(x*y) -@test diff(log(x),x) == 1/x +@test_throws ArgumentError diff(ex) @test diff(ex, x) == y * cos(x*y) @test diff(ex, x, 2) == diff(diff(ex,x), x) @test diff(ex, x, n) == diff(diff(ex,x), x) @test diff(ex, x, y) == diff(diff(ex,x), y) -@test diff(ex, x, y,x) == diff(diff(diff(ex,x), y), x) +@test diff(ex, x, y, x) == diff(diff(diff(ex,x), y), x) +@test diff(ex, x, 2, y, 3) == diff(ex, x,x,y,y,y) +@test diff(ex, x, n, y, 3) == diff(ex, x,x,y,y,y) +@test diff(ex, x, 2, y, x) == diff(ex, x,x,x,y) + @test series(sin(x), x, 0, 2) == x @test series(sin(x), x, 0, 3) == x - x^3/6