Skip to content

Commit fbdceaf

Browse files
authored
Merge pull request #44 from SymbolicML/dont-overwrite
Avoid method invalidation in helper functions
2 parents c84014f + 3f01acd commit fbdceaf

12 files changed

+275
-141
lines changed

src/Equation.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ const OP_NAMES = Dict(
190190
"safe_pow" => "^",
191191
)
192192

193+
get_op_name(op::String) = op
193194
@generated function get_op_name(op::F) where {F}
194195
try
195196
# Bit faster to just cache the name of the operator:
@@ -266,7 +267,7 @@ Convert an equation to a string.
266267
"""
267268
function string_tree(
268269
tree::Node{T},
269-
operators::AbstractOperatorEnum;
270+
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
270271
bracketed::Bool=false,
271272
f_variable::F1=string_variable,
272273
f_constant::F2=string_constant,
@@ -284,7 +285,11 @@ function string_tree(
284285
elseif tree.degree == 1
285286
return string_op(
286287
Val(1),
287-
operators.unaops[tree.op],
288+
if operators === nothing
289+
"unary_operator[" * string(tree.op) * "]"
290+
else
291+
operators.unaops[tree.op]
292+
end,
288293
tree,
289294
operators;
290295
bracketed,
@@ -295,7 +300,11 @@ function string_tree(
295300
else
296301
return string_op(
297302
Val(2),
298-
operators.binops[tree.op],
303+
if operators === nothing
304+
"binary_operator[" * string(tree.op) * "]"
305+
else
306+
operators.binops[tree.op]
307+
end,
299308
tree,
300309
operators;
301310
bracketed,

src/EvaluationHelpers.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,6 @@ function (tree::Node)(X, operators::GenericOperatorEnum; kws...)
5454
!did_finish && return nothing
5555
return out
5656
end
57-
function (tree::Node)(X; kws...)
58-
## This will be overwritten by OperatorEnumConstructionModule, and turned
59-
## into a depwarn.
60-
return error(
61-
"The `tree(X; kws...)` syntax is deprecated. Use `tree(X, operators; kws...)` instead.",
62-
)
63-
end
6457

6558
# Gradients:
6659
function _grad_evaluator(tree::Node, X, operators::OperatorEnum; variable=true, kws...)
@@ -73,13 +66,6 @@ end
7366
function _grad_evaluator(tree::Node, X, operators::GenericOperatorEnum; kws...)
7467
return error("Gradients are not implemented for `GenericOperatorEnum`.")
7568
end
76-
function _grad_evaluator(tree::Node, X; kws...)
77-
## This will be overwritten by OperatorEnumConstructionModule, and turned
78-
## into a depwarn
79-
return error(
80-
"The `tree'(X; kws...)` syntax is deprecated. Use `tree'(X, operators; kws...)` instead.",
81-
)
82-
end
8369

8470
"""
8571
(tree::Node{T})'(X::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false, variable::Bool=true)

src/OperatorEnumConstruction.jl

Lines changed: 143 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -6,82 +6,135 @@ import ..EvaluateEquationModule: eval_tree_array
66
import ..EvaluateEquationDerivativeModule: eval_grad_tree_array, _zygote_gradient
77
import ..EvaluationHelpersModule: _grad_evaluator
88

9-
function create_evaluation_helpers!(operators::OperatorEnum)
10-
@eval begin
11-
Base.print(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
12-
Base.show(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
13-
function (tree::Node)(X; kws...)
14-
Base.depwarn(
15-
"The `tree(X; kws...)` syntax is deprecated. Use `tree(X, operators; kws...)` instead.",
16-
:Node,
17-
)
18-
return tree(X, $operators; kws...)
19-
end
20-
# Gradients:
21-
function _grad_evaluator(tree::Node, X; kws...)
22-
Base.depwarn(
23-
"The `tree'(X; kws...)` syntax is deprecated. Use `tree'(X, operators; kws...)` instead.",
24-
:Node,
25-
)
26-
return _grad_evaluator(tree, X, $operators; kws...)
27-
end
9+
"""Used to set a default value for `operators` for ease of use."""
10+
@enum AvailableOperatorTypes begin
11+
IsNothing
12+
IsOperatorEnum
13+
IsGenericOperatorEnum
14+
end
15+
16+
# These constants are purely for convenience. Internal code
17+
# should make use of `Node`, `string_tree`, `eval_tree_array`,
18+
# and `eval_grad_tree_array` directly.
19+
20+
const LATEST_OPERATORS = Ref{Union{Nothing,AbstractOperatorEnum}}(nothing)
21+
const LATEST_OPERATORS_TYPE = Ref{AvailableOperatorTypes}(IsNothing)
22+
const LATEST_UNARY_OPERATOR_MAPPING = Dict{Function,Int}()
23+
const LATEST_BINARY_OPERATOR_MAPPING = Dict{Function,Int}()
24+
const ALREADY_DEFINED_UNARY_OPERATORS = (;
25+
operator_enum=Dict{Function,Bool}(), generic_operator_enum=Dict{Function,Bool}()
26+
)
27+
const ALREADY_DEFINED_BINARY_OPERATORS = (;
28+
operator_enum=Dict{Function,Bool}(), generic_operator_enum=Dict{Function,Bool}()
29+
)
30+
31+
function Base.show(io::IO, tree::Node)
32+
latest_operators_type = LATEST_OPERATORS_TYPE.x
33+
if latest_operators_type == IsNothing
34+
return print(io, string_tree(tree))
35+
elseif latest_operators_type == IsOperatorEnum
36+
latest_operators = LATEST_OPERATORS.x::OperatorEnum
37+
return print(io, string_tree(tree, latest_operators))
38+
else
39+
latest_operators = LATEST_OPERATORS.x::GenericOperatorEnum
40+
return print(io, string_tree(tree, latest_operators))
2841
end
2942
end
43+
function (tree::Node)(X; kws...)
44+
Base.depwarn(
45+
"The `tree(X; kws...)` syntax is deprecated. Use `tree(X, operators; kws...)` instead.",
46+
:Node,
47+
)
48+
latest_operators_type = LATEST_OPERATORS_TYPE.x
49+
if latest_operators_type == IsNothing
50+
error("Please use the `tree(X, operators; kws...)` syntax instead.")
51+
elseif latest_operators_type == IsOperatorEnum
52+
latest_operators = LATEST_OPERATORS.x::OperatorEnum
53+
return tree(X, latest_operators; kws...)
54+
else
55+
latest_operators = LATEST_OPERATORS.x::GenericOperatorEnum
56+
return tree(X, latest_operators; kws...)
57+
end
58+
end
59+
60+
function _grad_evaluator(tree::Node, X; kws...)
61+
Base.depwarn(
62+
"The `tree'(X; kws...)` syntax is deprecated. Use `tree'(X, operators; kws...)` instead.",
63+
:Node,
64+
)
65+
latest_operators_type = LATEST_OPERATORS_TYPE.x
66+
# return _grad_evaluator(tree, X, $operators; kws...)
67+
if latest_operators_type == IsNothing
68+
error("Please use the `tree'(X, operators; kws...)` syntax instead.")
69+
elseif latest_operators_type == IsOperatorEnum
70+
latest_operators = LATEST_OPERATORS.x::OperatorEnum
71+
return _grad_evaluator(tree, X, latest_operators; kws...)
72+
else
73+
error("Gradients are not implemented for `GenericOperatorEnum`.")
74+
end
75+
end
76+
77+
function create_evaluation_helpers!(operators::OperatorEnum)
78+
LATEST_OPERATORS.x = operators
79+
return LATEST_OPERATORS_TYPE.x = IsOperatorEnum
80+
end
3081

3182
function create_evaluation_helpers!(operators::GenericOperatorEnum)
32-
@eval begin
33-
Base.print(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
34-
Base.show(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
35-
36-
function (tree::Node)(X; kws...)
37-
Base.depwarn(
38-
"The `tree(X; kws...)` syntax is deprecated. Use `tree(X, operators; kws...)` instead.",
39-
:Node,
40-
)
41-
return tree(X, $operators; kws...)
42-
end
43-
function _grad_evaluator(tree::Node, X; kws...)
44-
return error("Gradients are not implemented for `GenericOperatorEnum`.")
45-
end
83+
LATEST_OPERATORS.x = operators
84+
return LATEST_OPERATORS_TYPE.x = IsGenericOperatorEnum
85+
end
86+
function lookup_op(@nospecialize(f), ::Val{degree}) where {degree}
87+
mapping = degree == 1 ? LATEST_UNARY_OPERATOR_MAPPING : LATEST_BINARY_OPERATOR_MAPPING
88+
if !haskey(mapping, f)
89+
error(
90+
"Convenience constructor for `Node` using operator `$(f)` is out-of-date. " *
91+
"Please create an `OperatorEnum` (or `GenericOperatorEnum`) with " *
92+
"`define_helper_functions=true` and pass `$(f)`.",
93+
)
4694
end
95+
return mapping[f]
4796
end
4897

49-
function _extend_unary_operator(f::Symbol, op, type_requirements)
98+
function _extend_unary_operator(f::Symbol, type_requirements)
5099
quote
51100
quote
52101
function $($f)(l::Node{T})::Node{T} where {T<:$($type_requirements)}
53102
return if (l.degree == 0 && l.constant)
54103
Node(T; val=$($f)(l.val::T))
55104
else
56-
Node($($op), l)
105+
latest_op_idx = $($lookup_op)($($f), Val(1))
106+
Node(latest_op_idx, l)
57107
end
58108
end
59109
end
60110
end
61111
end
62112

63-
function _extend_binary_operator(f::Symbol, op, type_requirements, build_converters)
113+
function _extend_binary_operator(f::Symbol, type_requirements, build_converters)
64114
quote
65115
quote
66116
function $($f)(l::Node{T}, r::Node{T}) where {T<:$($type_requirements)}
67117
if (l.degree == 0 && l.constant && r.degree == 0 && r.constant)
68118
Node(T; val=$($f)(l.val::T, r.val::T))
69119
else
70-
Node($($op), l, r)
120+
latest_op_idx = $($lookup_op)($($f), Val(2))
121+
Node(latest_op_idx, l, r)
71122
end
72123
end
73124
function $($f)(l::Node{T}, r::T) where {T<:$($type_requirements)}
74125
if l.degree == 0 && l.constant
75126
Node(T; val=$($f)(l.val::T, r))
76127
else
77-
Node($($op), l, Node(T; val=r))
128+
latest_op_idx = $($lookup_op)($($f), Val(2))
129+
Node(latest_op_idx, l, Node(T; val=r))
78130
end
79131
end
80132
function $($f)(l::T, r::Node{T}) where {T<:$($type_requirements)}
81133
if r.degree == 0 && r.constant
82134
Node(T; val=$($f)(l, r.val::T))
83135
else
84-
Node($($op), Node(T; val=l), r)
136+
latest_op_idx = $($lookup_op)($($f), Val(2))
137+
Node(latest_op_idx, Node(T; val=l), r)
85138
end
86139
end
87140
if $($build_converters)
@@ -116,37 +169,62 @@ function _extend_binary_operator(f::Symbol, op, type_requirements, build_convert
116169
end
117170

118171
function _extend_operators(operators, skip_user_operators, __module__::Module)
119-
binary_ex = _extend_binary_operator(:f, :op, :type_requirements, :build_converters)
120-
unary_ex = _extend_unary_operator(:f, :op, :type_requirements)
172+
binary_ex = _extend_binary_operator(:f, :type_requirements, :build_converters)
173+
unary_ex = _extend_unary_operator(:f, :type_requirements)
121174
return quote
122175
local type_requirements
123176
local build_converters
177+
local binary_exists
178+
local unary_exists
124179
if isa($operators, OperatorEnum)
125180
type_requirements = Number
126181
build_converters = true
182+
binary_exists = $(ALREADY_DEFINED_BINARY_OPERATORS).operator_enum
183+
unary_exists = $(ALREADY_DEFINED_UNARY_OPERATORS).operator_enum
127184
else
128185
type_requirements = Any
129186
build_converters = false
187+
binary_exists = $(ALREADY_DEFINED_BINARY_OPERATORS).generic_operator_enum
188+
unary_exists = $(ALREADY_DEFINED_UNARY_OPERATORS).generic_operator_enum
130189
end
131-
for (op, f) in enumerate(map(Symbol, $(operators).binops))
190+
# Trigger errors if operators are not yet defined:
191+
empty!($(LATEST_BINARY_OPERATOR_MAPPING))
192+
empty!($(LATEST_UNARY_OPERATOR_MAPPING))
193+
for (op, func) in enumerate($(operators).binops)
194+
local f = Symbol(func)
195+
local skip = false
132196
if isdefined(Base, f)
133197
f = :(Base.$(f))
134198
elseif $(skip_user_operators)
135-
continue
199+
skip = true
136200
else
137201
f = :($($__module__).$(f))
138202
end
139-
eval($binary_ex)
203+
$(LATEST_BINARY_OPERATOR_MAPPING)[func] = op
204+
skip && continue
205+
# Avoid redefining methods:
206+
if !haskey(unary_exists, func)
207+
eval($binary_ex)
208+
unary_exists[func] = true
209+
end
140210
end
141-
for (op, f) in enumerate(map(Symbol, $(operators).unaops))
211+
for (op, func) in enumerate($(operators).unaops)
212+
local f = Symbol(func)
213+
local skip = false
142214
if isdefined(Base, f)
143215
f = :(Base.$(f))
144216
elseif $(skip_user_operators)
145-
continue
217+
skip = true
146218
else
147219
f = :($($__module__).$(f))
148220
end
149-
eval($unary_ex)
221+
$(LATEST_UNARY_OPERATOR_MAPPING)[func] = op
222+
skip && continue
223+
# Avoid redefining methods:
224+
if !haskey(binary_exists, func)
225+
eval($unary_ex)
226+
binary_exists[func] = true
227+
end
150228
end
151229
end
152230
end
@@ -162,14 +240,16 @@ apply this macro to the operator enum in the same module you have the operators
162240
defined.
163241
"""
164242
macro extend_operators(operators)
165-
ex = _extend_operators(esc(operators), false, __module__)
243+
ex = _extend_operators(operators, false, __module__)
166244
expected_type = AbstractOperatorEnum
167-
quote
168-
if !isa($(esc(operators)), $expected_type)
169-
error("You must pass an operator enum to `@extend_operators`.")
170-
end
171-
$ex
172-
end
245+
return esc(
246+
quote
247+
if !isa($(operators), $expected_type)
248+
error("You must pass an operator enum to `@extend_operators`.")
249+
end
250+
$ex
251+
end,
252+
)
173253
end
174254

175255
"""
@@ -179,14 +259,16 @@ Similar to `@extend_operators`, but only extends operators already
179259
defined in `Base`.
180260
"""
181261
macro extend_operators_base(operators)
182-
ex = _extend_operators(esc(operators), true, __module__)
262+
ex = _extend_operators(operators, true, __module__)
183263
expected_type = AbstractOperatorEnum
184-
quote
185-
if !isa($(esc(operators)), $expected_type)
186-
error("You must pass an operator enum to `@extend_operators_base`.")
187-
end
188-
$ex
189-
end
264+
return esc(
265+
quote
266+
if !isa($(operators), $expected_type)
267+
error("You must pass an operator enum to `@extend_operators_base`.")
268+
end
269+
$ex
270+
end,
271+
)
190272
end
191273

192274
"""

test/test_derivatives.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ using DynamicExpressions: eval_diff_tree_array, eval_grad_tree_array
44
using Random
55
using Zygote
66
using LinearAlgebra
7+
include("test_params.jl")
78

89
seed = 0
910
# SIMD doesn't like abs(x) ^ y for some reason.
1011
pow_abs2(x, y) = exp(y * log(abs(x)))
11-
custom_cos(x) = cos(x)^2
1212

1313
equation1(x1, x2, x3) = x1 + x2 + x3 + 3.2
1414
equation2(x1, x2, x3) = pow_abs2(x1, x2) + x3 + custom_cos(1.0 + x3) + 3.0 / x1

0 commit comments

Comments
 (0)