@@ -6,82 +6,135 @@ import ..EvaluateEquationModule: eval_tree_array
6
6
import .. EvaluateEquationDerivativeModule: eval_grad_tree_array, _zygote_gradient
7
7
import .. EvaluationHelpersModule: _grad_evaluator
8
8
9
- function create_evaluation_helpers! (operators:: OperatorEnum )
10
- @eval begin
11
- Base. print (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
12
- Base. show (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
13
- function (tree:: Node )(X; kws... )
14
- Base. depwarn (
15
- " The `tree(X; kws...)` syntax is deprecated. Use `tree(X, operators; kws...)` instead." ,
16
- :Node ,
17
- )
18
- return tree (X, $ operators; kws... )
19
- end
20
- # Gradients:
21
- function _grad_evaluator (tree:: Node , X; kws... )
22
- Base. depwarn (
23
- " The `tree'(X; kws...)` syntax is deprecated. Use `tree'(X, operators; kws...)` instead." ,
24
- :Node ,
25
- )
26
- return _grad_evaluator (tree, X, $ operators; kws... )
27
- end
9
+ """ Used to set a default value for `operators` for ease of use."""
10
+ @enum AvailableOperatorTypes begin
11
+ IsNothing
12
+ IsOperatorEnum
13
+ IsGenericOperatorEnum
14
+ end
15
+
16
+ # These constants are purely for convenience. Internal code
17
+ # should make use of `Node`, `string_tree`, `eval_tree_array`,
18
+ # and `eval_grad_tree_array` directly.
19
+
20
+ const LATEST_OPERATORS = Ref {Union{Nothing,AbstractOperatorEnum}} (nothing )
21
+ const LATEST_OPERATORS_TYPE = Ref {AvailableOperatorTypes} (IsNothing)
22
+ const LATEST_UNARY_OPERATOR_MAPPING = Dict {Function,Int} ()
23
+ const LATEST_BINARY_OPERATOR_MAPPING = Dict {Function,Int} ()
24
+ const ALREADY_DEFINED_UNARY_OPERATORS = (;
25
+ operator_enum= Dict {Function,Bool} (), generic_operator_enum= Dict {Function,Bool} ()
26
+ )
27
+ const ALREADY_DEFINED_BINARY_OPERATORS = (;
28
+ operator_enum= Dict {Function,Bool} (), generic_operator_enum= Dict {Function,Bool} ()
29
+ )
30
+
31
+ function Base. show (io:: IO , tree:: Node )
32
+ latest_operators_type = LATEST_OPERATORS_TYPE. x
33
+ if latest_operators_type == IsNothing
34
+ return print (io, string_tree (tree))
35
+ elseif latest_operators_type == IsOperatorEnum
36
+ latest_operators = LATEST_OPERATORS. x:: OperatorEnum
37
+ return print (io, string_tree (tree, latest_operators))
38
+ else
39
+ latest_operators = LATEST_OPERATORS. x:: GenericOperatorEnum
40
+ return print (io, string_tree (tree, latest_operators))
28
41
end
29
42
end
43
+ function (tree:: Node )(X; kws... )
44
+ Base. depwarn (
45
+ " The `tree(X; kws...)` syntax is deprecated. Use `tree(X, operators; kws...)` instead." ,
46
+ :Node ,
47
+ )
48
+ latest_operators_type = LATEST_OPERATORS_TYPE. x
49
+ if latest_operators_type == IsNothing
50
+ error (" Please use the `tree(X, operators; kws...)` syntax instead." )
51
+ elseif latest_operators_type == IsOperatorEnum
52
+ latest_operators = LATEST_OPERATORS. x:: OperatorEnum
53
+ return tree (X, latest_operators; kws... )
54
+ else
55
+ latest_operators = LATEST_OPERATORS. x:: GenericOperatorEnum
56
+ return tree (X, latest_operators; kws... )
57
+ end
58
+ end
59
+
60
+ function _grad_evaluator (tree:: Node , X; kws... )
61
+ Base. depwarn (
62
+ " The `tree'(X; kws...)` syntax is deprecated. Use `tree'(X, operators; kws...)` instead." ,
63
+ :Node ,
64
+ )
65
+ latest_operators_type = LATEST_OPERATORS_TYPE. x
66
+ # return _grad_evaluator(tree, X, $operators; kws...)
67
+ if latest_operators_type == IsNothing
68
+ error (" Please use the `tree'(X, operators; kws...)` syntax instead." )
69
+ elseif latest_operators_type == IsOperatorEnum
70
+ latest_operators = LATEST_OPERATORS. x:: OperatorEnum
71
+ return _grad_evaluator (tree, X, latest_operators; kws... )
72
+ else
73
+ error (" Gradients are not implemented for `GenericOperatorEnum`." )
74
+ end
75
+ end
76
+
77
+ function create_evaluation_helpers! (operators:: OperatorEnum )
78
+ LATEST_OPERATORS. x = operators
79
+ return LATEST_OPERATORS_TYPE. x = IsOperatorEnum
80
+ end
30
81
31
82
function create_evaluation_helpers! (operators:: GenericOperatorEnum )
32
- @eval begin
33
- Base. print (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
34
- Base. show (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
35
-
36
- function (tree:: Node )(X; kws... )
37
- Base. depwarn (
38
- " The `tree(X; kws...)` syntax is deprecated. Use `tree(X, operators; kws...)` instead." ,
39
- :Node ,
40
- )
41
- return tree (X, $ operators; kws... )
42
- end
43
- function _grad_evaluator (tree:: Node , X; kws... )
44
- return error (" Gradients are not implemented for `GenericOperatorEnum`." )
45
- end
83
+ LATEST_OPERATORS. x = operators
84
+ return LATEST_OPERATORS_TYPE. x = IsGenericOperatorEnum
85
+ end
86
+ function lookup_op (@nospecialize (f), :: Val{degree} ) where {degree}
87
+ mapping = degree == 1 ? LATEST_UNARY_OPERATOR_MAPPING : LATEST_BINARY_OPERATOR_MAPPING
88
+ if ! haskey (mapping, f)
89
+ error (
90
+ " Convenience constructor for `Node` using operator `$(f) ` is out-of-date. " *
91
+ " Please create an `OperatorEnum` (or `GenericOperatorEnum`) with " *
92
+ " `define_helper_functions=true` and pass `$(f) `." ,
93
+ )
46
94
end
95
+ return mapping[f]
47
96
end
48
97
49
- function _extend_unary_operator (f:: Symbol , op, type_requirements)
98
+ function _extend_unary_operator (f:: Symbol , type_requirements)
50
99
quote
51
100
quote
52
101
function $ ($ f)(l:: Node{T} ):: Node{T} where {T<: $ ($ type_requirements)}
53
102
return if (l. degree == 0 && l. constant)
54
103
Node (T; val= $ ($ f)(l. val:: T ))
55
104
else
56
- Node ($ ($ op), l)
105
+ latest_op_idx = $ ($ lookup_op)($ ($ f), Val (1 ))
106
+ Node (latest_op_idx, l)
57
107
end
58
108
end
59
109
end
60
110
end
61
111
end
62
112
63
- function _extend_binary_operator (f:: Symbol , op, type_requirements, build_converters)
113
+ function _extend_binary_operator (f:: Symbol , type_requirements, build_converters)
64
114
quote
65
115
quote
66
116
function $ ($ f)(l:: Node{T} , r:: Node{T} ) where {T<: $ ($ type_requirements)}
67
117
if (l. degree == 0 && l. constant && r. degree == 0 && r. constant)
68
118
Node (T; val= $ ($ f)(l. val:: T , r. val:: T ))
69
119
else
70
- Node ($ ($ op), l, r)
120
+ latest_op_idx = $ ($ lookup_op)($ ($ f), Val (2 ))
121
+ Node (latest_op_idx, l, r)
71
122
end
72
123
end
73
124
function $ ($ f)(l:: Node{T} , r:: T ) where {T<: $ ($ type_requirements)}
74
125
if l. degree == 0 && l. constant
75
126
Node (T; val= $ ($ f)(l. val:: T , r))
76
127
else
77
- Node ($ ($ op), l, Node (T; val= r))
128
+ latest_op_idx = $ ($ lookup_op)($ ($ f), Val (2 ))
129
+ Node (latest_op_idx, l, Node (T; val= r))
78
130
end
79
131
end
80
132
function $ ($ f)(l:: T , r:: Node{T} ) where {T<: $ ($ type_requirements)}
81
133
if r. degree == 0 && r. constant
82
134
Node (T; val= $ ($ f)(l, r. val:: T ))
83
135
else
84
- Node ($ ($ op), Node (T; val= l), r)
136
+ latest_op_idx = $ ($ lookup_op)($ ($ f), Val (2 ))
137
+ Node (latest_op_idx, Node (T; val= l), r)
85
138
end
86
139
end
87
140
if $ ($ build_converters)
@@ -116,37 +169,62 @@ function _extend_binary_operator(f::Symbol, op, type_requirements, build_convert
116
169
end
117
170
118
171
function _extend_operators (operators, skip_user_operators, __module__:: Module )
119
- binary_ex = _extend_binary_operator (:f , :op , : type_requirements , :build_converters )
120
- unary_ex = _extend_unary_operator (:f , :op , : type_requirements )
172
+ binary_ex = _extend_binary_operator (:f , :type_requirements , :build_converters )
173
+ unary_ex = _extend_unary_operator (:f , :type_requirements )
121
174
return quote
122
175
local type_requirements
123
176
local build_converters
177
+ local binary_exists
178
+ local unary_exists
124
179
if isa ($ operators, OperatorEnum)
125
180
type_requirements = Number
126
181
build_converters = true
182
+ binary_exists = $ (ALREADY_DEFINED_BINARY_OPERATORS). operator_enum
183
+ unary_exists = $ (ALREADY_DEFINED_UNARY_OPERATORS). operator_enum
127
184
else
128
185
type_requirements = Any
129
186
build_converters = false
187
+ binary_exists = $ (ALREADY_DEFINED_BINARY_OPERATORS). generic_operator_enum
188
+ unary_exists = $ (ALREADY_DEFINED_UNARY_OPERATORS). generic_operator_enum
130
189
end
131
- for (op, f) in enumerate (map (Symbol, $ (operators). binops))
190
+ # Trigger errors if operators are not yet defined:
191
+ empty! ($ (LATEST_BINARY_OPERATOR_MAPPING))
192
+ empty! ($ (LATEST_UNARY_OPERATOR_MAPPING))
193
+ for (op, func) in enumerate ($ (operators). binops)
194
+ local f = Symbol (func)
195
+ local skip = false
132
196
if isdefined (Base, f)
133
197
f = :(Base.$ (f))
134
198
elseif $ (skip_user_operators)
135
- continue
199
+ skip = true
136
200
else
137
201
f = :($ ($ __module__). $ (f))
138
202
end
139
- eval ($ binary_ex)
203
+ $ (LATEST_BINARY_OPERATOR_MAPPING)[func] = op
204
+ skip && continue
205
+ # Avoid redefining methods:
206
+ if ! haskey (unary_exists, func)
207
+ eval ($ binary_ex)
208
+ unary_exists[func] = true
209
+ end
140
210
end
141
- for (op, f) in enumerate (map (Symbol, $ (operators). unaops))
211
+ for (op, func) in enumerate ($ (operators). unaops)
212
+ local f = Symbol (func)
213
+ local skip = false
142
214
if isdefined (Base, f)
143
215
f = :(Base.$ (f))
144
216
elseif $ (skip_user_operators)
145
- continue
217
+ skip = true
146
218
else
147
219
f = :($ ($ __module__). $ (f))
148
220
end
149
- eval ($ unary_ex)
221
+ $ (LATEST_UNARY_OPERATOR_MAPPING)[func] = op
222
+ skip && continue
223
+ # Avoid redefining methods:
224
+ if ! haskey (binary_exists, func)
225
+ eval ($ unary_ex)
226
+ binary_exists[func] = true
227
+ end
150
228
end
151
229
end
152
230
end
@@ -162,14 +240,16 @@ apply this macro to the operator enum in the same module you have the operators
162
240
defined.
163
241
"""
164
242
macro extend_operators (operators)
165
- ex = _extend_operators (esc ( operators) , false , __module__)
243
+ ex = _extend_operators (operators, false , __module__)
166
244
expected_type = AbstractOperatorEnum
167
- quote
168
- if ! isa ($ (esc (operators)), $ expected_type)
169
- error (" You must pass an operator enum to `@extend_operators`." )
170
- end
171
- $ ex
172
- end
245
+ return esc (
246
+ quote
247
+ if ! isa ($ (operators), $ expected_type)
248
+ error (" You must pass an operator enum to `@extend_operators`." )
249
+ end
250
+ $ ex
251
+ end ,
252
+ )
173
253
end
174
254
175
255
"""
@@ -179,14 +259,16 @@ Similar to `@extend_operators`, but only extends operators already
179
259
defined in `Base`.
180
260
"""
181
261
macro extend_operators_base (operators)
182
- ex = _extend_operators (esc ( operators) , true , __module__)
262
+ ex = _extend_operators (operators, true , __module__)
183
263
expected_type = AbstractOperatorEnum
184
- quote
185
- if ! isa ($ (esc (operators)), $ expected_type)
186
- error (" You must pass an operator enum to `@extend_operators_base`." )
187
- end
188
- $ ex
189
- end
264
+ return esc (
265
+ quote
266
+ if ! isa ($ (operators), $ expected_type)
267
+ error (" You must pass an operator enum to `@extend_operators_base`." )
268
+ end
269
+ $ ex
270
+ end ,
271
+ )
190
272
end
191
273
192
274
"""
0 commit comments