Commit 651f053

Merge pull request #22 from SymbolicML/deprecate-implicit

Deprecate implicit call API

2 parents 76f2467 + 8ff8158 · commit 651f053

12 files changed: +276 -56 lines changed

README.md

Lines changed: 11 additions & 11 deletions
@@ -36,7 +36,7 @@ x2 = Node(; feature=2)
 expression = x1 * cos(x2 - 3.2)
 
 X = randn(Float64, 2, 100);
-expression(X) # 100-element Vector{Float64}
+expression(X, operators) # 100-element Vector{Float64}
 ```
 
 (We can construct this expression with normal operators, since calling `OperatorEnum()` will `@eval` new functions on `Node` that use the specified enum.)
@@ -53,18 +53,18 @@ First, what happens if we naively use Julia symbols to define and then evaluate
 This is quite slow, meaning it will be hard to quickly search over the space of expressions. Let's see how DynamicExpressions.jl compares:
 
 ```julia
-@btime expression(X)
+@btime expression(X, operators)
 # 693 ns
 ```
 
-Much faster! And we didn't even need to compile it. (Internally, this is calling `eval_tree_array(expression, X, operators)` - where `operators` has been pre-defined when we called `OperatorEnum()`).
+Much faster! And we didn't even need to compile it. (Internally, this is calling `eval_tree_array(expression, X, operators)`).
 
 If we change `expression` dynamically with a random number generator, it will have the same performance:
 
 ```julia
 @btime begin
     expression.op = rand(1:3) # random operator in [+, -, *]
-    expression(X)
+    expression(X, operators)
 end
 # 842 ns
 ```
@@ -113,13 +113,13 @@ expression = x1 * cos(x2 - 3.2)
 We can take the gradient with respect to inputs with simply the `'` character:
 
 ```julia
-grad = expression'(X)
+grad = expression'(X, operators)
 ```
 
 This is quite fast:
 
 ```julia
-@btime expression'(X)
+@btime expression'(X, operators)
 # 2894 ns
 ```
 
@@ -128,7 +128,7 @@ and again, we can change this expression at runtime, without loss in performance
 ```julia
 @btime begin
     expression.op = rand(1:3)
-    expression'(X)
+    expression'(X, operators)
 end
 # 3198 ns
 ```
@@ -180,14 +180,14 @@ Now, let's create an expression:
 tree = "H" * my_string_func(x1)
 # ^ `(H * my_string_func(x1))`
 
-tree(["World!", "Me?"])
+tree(["World!", "Me?"], operators)
 # Hello World!
 ```
 
 So indeed it works for arbitrary types. It is a bit slower due to the potential for type instability, but it's not too bad:
 
 ```julia
-@btime tree(["Hello", "Me?"])
+@btime tree(["Hello", "Me?"], operators)
 # 1738 ns
 ```
 
@@ -220,15 +220,15 @@ tree = vec_add(vec_add(vec_square(x1), c2), c1)
 X = [[-1.0, 5.2, 0.1], [0.0, 0.0, 0.0]]
 
 # Evaluate!
-tree(X) # [2.0, 29.04, 3.01]
+tree(X, operators) # [2.0, 29.04, 3.01]
 ```
 
 Note that if an operator is not defined for the particular input, `nothing` will be returned instead.
 
 This is all still pretty fast, too:
 
 ```julia
-@btime tree(X)
+@btime tree(X, operators)
 # 2,949 ns
 @btime eval(:(vec_add(vec_add(vec_square(X[1]), [1.0, 2.0, 3.0]), 0.0)))
 # 115,000 ns
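
To make the change concrete, here is a minimal end-to-end sketch of the explicit call API this PR settles on, assembled from the README snippets above:

```julia
using DynamicExpressions

operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos])
x1 = Node(; feature=1)
x2 = Node(; feature=2)
expression = x1 * cos(x2 - 3.2)

X = randn(Float64, 2, 100)  # two features, one data point per column
expression(X, operators)    # 100-element Vector{Float64}
```

Passing `operators` explicitly at the call site is the point of the PR: the evaluation no longer depends on whichever enum most recently ran `@eval` behind the scenes.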

docs/src/eval.md

Lines changed: 75 additions & 5 deletions
@@ -8,13 +8,34 @@ eval_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum) w
 ```
 
 Assuming you are only using a single `OperatorEnum`, you can also use
-the following short-hand by using the expression as a function:
+the following shorthand by using the expression as a function:
+
+```
+    (tree::Node)(X::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false)
+
+Evaluate a binary tree (equation) over a given data matrix. The
+operators contain all of the operators used in the tree.
+
+# Arguments
+- `X::AbstractMatrix{T}`: The input data to evaluate the tree on.
+- `operators::OperatorEnum`: The operators used in the tree.
+- `turbo::Bool`: Use `LoopVectorization.@turbo` for faster evaluation.
+
+# Returns
+- `output::AbstractVector{T}`: the result, which is a 1D array.
+    Any NaN, Inf, or other failure during the evaluation will result in the entire
+    output array being set to NaN.
+```
+
+For example,
+
+```@example
+using DynamicExpressions
 
-```julia
 operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos])
 tree = Node(; feature=1) * cos(Node(; feature=2) - 3.2)
 
-tree(X)
+tree([1 2 3; 4 5 6.], operators)
 ```
 
 This is possible because when you call `OperatorEnum`, it automatically re-defines
@@ -32,7 +53,31 @@ The notation is the same for `eval_tree_array`, though it will return `nothing`
 when it can't find a method, and not do any NaN checks:
 
 ```@docs
-eval_tree_array(tree, cX::AbstractArray, operators::GenericOperatorEnum; throw_errors::Bool=true)
+eval_tree_array(tree::Node, cX::AbstractMatrix, operators::GenericOperatorEnum; throw_errors::Bool=true)
+```
+
+Likewise for the shorthand notation:
+
+```
+    (tree::Node)(X::AbstractMatrix, operators::GenericOperatorEnum; throw_errors::Bool=true)
+
+# Arguments
+- `X::AbstractArray`: The input data to evaluate the tree on.
+- `operators::GenericOperatorEnum`: The operators used in the tree.
+- `throw_errors::Bool=true`: Whether to throw errors
+    if they occur during evaluation. Otherwise,
+    MethodErrors will be caught before they happen and
+    evaluation will return `nothing`,
+    rather than throwing an error. This is useful in cases
+    where you are unsure if a particular tree is valid or not,
+    and would prefer to work with `nothing` as an output.
+
+# Returns
+- `output`: the result of the evaluation.
+    If evaluation failed, `nothing` will be returned for the first argument.
+    A `false` complete means an operator was called on input types
+    that it was not defined for. You can change this behavior by
+    setting `throw_errors=false`.
 ```
 
 ## Derivatives
@@ -46,7 +91,32 @@ all variables (or, all constants). Both use forward-mode automatic, but use
 
 ```@docs
 eval_diff_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Int) where {T<:Number}
-eval_grad_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; variable::Bool=false) where {T<:Number}
+eval_grad_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false, variable::Bool=false) where {T<:Number}
+```
+
+You can compute gradients with this shorthand notation as well (which by default computes
+gradients with respect to the input matrix, rather than the constants).
+
+```
+    (tree::Node{T})'(X::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false, variable::Bool=true)
+
+Compute the forward-mode derivative of an expression, using a similar
+structure and optimization to eval_tree_array. `variable` specifies whether
+we should take derivatives with respect to features (i.e., X), or with respect
+to every constant in the expression.
+
+# Arguments
+- `X::AbstractMatrix{T}`: The data matrix, with each column being a data point.
+- `operators::OperatorEnum`: The operators used to create the `tree`. Note that `operators.enable_autodiff`
+    must be `true`. This is needed to create the derivative operations.
+- `variable::Bool`: Whether to take derivatives with respect to features (i.e., `X` - with `variable=true`),
+    or with respect to every constant in the expression (`variable=false`).
+- `turbo::Bool`: Use `LoopVectorization.@turbo` for faster evaluation.
+
+# Returns
+
+- `(evaluation, gradient, complete)::Tuple{AbstractVector{T}, AbstractMatrix{T}, Bool}`: the normal evaluation,
+    the gradient, and whether the evaluation completed as normal (or encountered a nan or inf).
 ```
 
 Alternatively, you can compute higher-order derivatives by using `ForwardDiff` on
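
As a concrete sketch of the gradient shorthand documented above (assuming an enum built with `enable_autodiff=true`, which the docstring requires for derivative operations):

```julia
using DynamicExpressions

operators = OperatorEnum(;
    binary_operators=[+, -, *], unary_operators=[cos], enable_autodiff=true
)
x1 = Node(; feature=1)
x2 = Node(; feature=2)
expression = x1 * cos(x2 - 3.2)

X = randn(Float64, 2, 100)

grad = expression'(X, operators)
# ^ gradient with respect to each feature (the default, `variable=true`)

grad_const = expression'(X, operators; variable=false)
# ^ gradient with respect to each constant in the tree instead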

src/DynamicExpressions.jl

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,7 @@ include("Equation.jl")
 include("EquationUtils.jl")
 include("EvaluateEquation.jl")
 include("EvaluateEquationDerivative.jl")
+include("EvaluationHelpers.jl")
 include("InterfaceSymbolicUtils.jl")
 include("SimplifyEquation.jl")
 include("OperatorEnumConstruction.jl")
@@ -31,6 +32,7 @@ using Reexport
     eval_diff_tree_array, eval_grad_tree_array
 @reexport import .InterfaceSymbolicUtilsModule: node_to_symbolic, symbolic_to_node
 @reexport import .SimplifyEquationModule: combine_operators, simplify_tree
+@reexport import .EvaluationHelpersModule
 
 import TOML: parsefile
 
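A side note on the wiring above: `EvaluationHelpers.jl` must be included after the evaluation modules, because Julia's relative imports (`import ..EvaluateEquationModule: ...`) can only see modules that are already defined. A toy sketch of that same pattern, with hypothetical modules `A` and `B` that are not part of the package:

```julia
module Demo

module A
f(x) = 2x
end

module B
import ..A: f  # resolves only because module A is already defined above
g(x) = f(x) + 1
end

end

Demo.B.g(3)  # 7
```
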
src/EvaluationHelpers.jl

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
+module EvaluationHelpersModule
+
+import Base: adjoint
+import ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum, GenericOperatorEnum
+import ..EquationModule: Node
+import ..EvaluateEquationModule: eval_tree_array
+import ..EvaluateEquationDerivativeModule: eval_grad_tree_array
+
+# Evaluation:
+"""
+    (tree::Node)(X::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false)
+
+Evaluate a binary tree (equation) over a given data matrix. The
+operators contain all of the operators used in the tree.
+
+# Arguments
+- `X::AbstractMatrix{T}`: The input data to evaluate the tree on.
+- `operators::OperatorEnum`: The operators used in the tree.
+- `turbo::Bool`: Use `LoopVectorization.@turbo` for faster evaluation.
+
+# Returns
+- `output::AbstractVector{T}`: the result, which is a 1D array.
+    Any NaN, Inf, or other failure during the evaluation will result in the entire
+    output array being set to NaN.
+"""
+function (tree::Node)(X, operators::OperatorEnum; kws...)
+    out, did_finish = eval_tree_array(tree, X, operators; kws...)
+    !did_finish && (out .= convert(eltype(out), NaN))
+    return out
+end
+"""
+    (tree::Node)(X::AbstractMatrix, operators::GenericOperatorEnum; throw_errors::Bool=true)
+
+# Arguments
+- `X::AbstractArray`: The input data to evaluate the tree on.
+- `operators::GenericOperatorEnum`: The operators used in the tree.
+- `throw_errors::Bool=true`: Whether to throw errors
+    if they occur during evaluation. Otherwise,
+    MethodErrors will be caught before they happen and
+    evaluation will return `nothing`,
+    rather than throwing an error. This is useful in cases
+    where you are unsure if a particular tree is valid or not,
+    and would prefer to work with `nothing` as an output.
+
+# Returns
+- `output`: the result of the evaluation.
+    If evaluation failed, `nothing` will be returned for the first argument.
+    A `false` complete means an operator was called on input types
+    that it was not defined for. You can change this behavior by
+    setting `throw_errors=false`.
+"""
+function (tree::Node)(X, operators::GenericOperatorEnum; kws...)
+    out, did_finish = eval_tree_array(tree, X, operators; kws...)
+    !did_finish && return nothing
+    return out
+end
+function (tree::Node)(X; kws...)
+    ## This will be overwritten by OperatorEnumConstructionModule, and turned
+    ## into a depwarn.
+    return error(
+        "The `tree(X; kws...)` syntax is deprecated. Use `tree(X, operators; kws...)` instead.",
+    )
+end
+
+# Gradients:
+function _grad_evaluator(tree::Node, X, operators::OperatorEnum; variable=true, kws...)
+    _, grad, did_complete = eval_grad_tree_array(
+        tree, X, operators; variable=variable, kws...
+    )
+    !did_complete && (grad .= convert(eltype(grad), NaN))
+    return grad
+end
+function _grad_evaluator(tree::Node, X, operators::GenericOperatorEnum; kws...)
+    return error("Gradients are not implemented for `GenericOperatorEnum`.")
+end
+function _grad_evaluator(tree::Node, X; kws...)
+    ## This will be overwritten by OperatorEnumConstructionModule, and turned
+    ## into a depwarn.
+    return error(
+        "The `tree'(X; kws...)` syntax is deprecated. Use `tree'(X, operators; kws...)` instead.",
+    )
+end
+
+"""
+    (tree::Node{T})'(X::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false, variable::Bool=true)
+
+Compute the forward-mode derivative of an expression, using a similar
+structure and optimization to eval_tree_array. `variable` specifies whether
+we should take derivatives with respect to features (i.e., X), or with respect
+to every constant in the expression.
+
+# Arguments
+- `X::AbstractMatrix{T}`: The data matrix, with each column being a data point.
+- `operators::OperatorEnum`: The operators used to create the `tree`. Note that `operators.enable_autodiff`
+    must be `true`. This is needed to create the derivative operations.
+- `variable::Bool`: Whether to take derivatives with respect to features (i.e., `X` - with `variable=true`),
+    or with respect to every constant in the expression (`variable=false`).
+- `turbo::Bool`: Use `LoopVectorization.@turbo` for faster evaluation.
+
+# Returns
+
+- `(evaluation, gradient, complete)::Tuple{AbstractVector{T}, AbstractMatrix{T}, Bool}`: the normal evaluation,
+    the gradient, and whether the evaluation completed as normal (or encountered a nan or inf).
+"""
+Base.adjoint(tree::Node) = ((args...; kws...) -> _grad_evaluator(tree, args...; kws...))
+
+end
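
The `Base.adjoint` overload on the last line is the trick behind the `tree'(X, operators)` syntax: Julia lowers `tree'` to `adjoint(tree)`, and returning a closure makes the result callable, with all arguments forwarded to `_grad_evaluator` for dispatch on the operator-enum type. A standalone sketch of the pattern, using a hypothetical `Quad` type and stub derivative that are not part of the package:

```julia
# Hypothetical stand-ins to illustrate the adjoint-closure pattern:
struct Quad
    a::Float64
end

# Stub "derivative": d/dx of a * x^2 is 2 * a * x.
_stub_grad(q::Quad, x::Real; scale::Real=1.0) = scale * 2 * q.a * x

# `q'` lowers to `Base.adjoint(q)`; returning a closure makes `q'(x; kws...)` callable:
Base.adjoint(q::Quad) = ((args...; kws...) -> _stub_grad(q, args...; kws...))

q = Quad(3.0)
q'(2.0)             # 12.0
q'(2.0; scale=0.5)  # 6.0
```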

src/OperatorEnumConstruction.jl

Lines changed: 19 additions & 22 deletions
@@ -5,27 +5,26 @@ import ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum, GenericOperator
 import ..EquationModule: string_tree, Node
 import ..EvaluateEquationModule: eval_tree_array
 import ..EvaluateEquationDerivativeModule: eval_grad_tree_array
+import ..EvaluationHelpersModule: _grad_evaluator
 
 function create_evaluation_helpers!(operators::OperatorEnum)
     @eval begin
         Base.print(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
         Base.show(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
         function (tree::Node)(X; kws...)
-            out, did_finish = eval_tree_array(tree, X, $operators; kws...)
-            if !did_finish
-                out .= convert(eltype(out), NaN)
-            end
-            return out
+            Base.depwarn(
+                "The `tree(X; kws...)` syntax is deprecated. Use `tree(X, operators; kws...)` instead.",
+                :Node,
+            )
+            return tree(X, $operators; kws...)
         end
         # Gradients:
-        function Base.adjoint(tree::Node{T}) where {T}
-            return (X; kws...) -> begin
-                _, grad, did_complete = eval_grad_tree_array(
-                    tree, X, $operators; variable=true, kws...
-                )
-                !did_complete && (grad .= T(NaN))
-                grad
-            end
+        function _grad_evaluator(tree::Node, X; kws...)
+            Base.depwarn(
+                "The `tree'(X; kws...)` syntax is deprecated. Use `tree'(X, operators; kws...)` instead.",
+                :Node,
+            )
+            return _grad_evaluator(tree, X, $operators; kws...)
         end
     end
 end
@@ -36,16 +35,14 @@ function create_evaluation_helpers!(operators::GenericOperatorEnum)
         Base.show(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
 
         function (tree::Node)(X; kws...)
-            out, did_finish = eval_tree_array(tree, X, $operators; kws...)
-            if !did_finish
-                return nothing
-            end
-            return out
+            Base.depwarn(
+                "The `tree(X; kws...)` syntax is deprecated. Use `tree(X, operators; kws...)` instead.",
+                :Node,
+            )
+            return tree(X, $operators; kws...)
         end
-        function Base.adjoint(::Node{T}) where {T}
-            return _ -> begin
-                error("Gradients are not implemented for `GenericOperatorEnum`.")
-            end
+        function _grad_evaluator(tree::Node, X; kws...)
+            return error("Gradients are not implemented for `GenericOperatorEnum`.")
         end
     end
 end
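
With this refactor, the implicit form keeps working but routes through `Base.depwarn` and forwards to the new explicit method. A sketch of what a caller sees (note that deprecation warnings are hidden by default on recent Julia versions unless started with `--depwarn=yes`; the warning text is taken from the code above, its display format is illustrative):

```julia
using DynamicExpressions

operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos])
x1 = Node(; feature=1)
tree = cos(x1 - 3.2)
X = randn(Float64, 2, 10)

tree(X)             # still works, but now warns:
# ┌ Warning: The `tree(X; kws...)` syntax is deprecated. Use `tree(X, operators; kws...)` instead.

tree(X, operators)  # the preferred, explicit form; no warning
```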
