Skip to content

Commit 3046dcb

Browse files
authored
Merge pull request #102 from SymbolicML/release-v1
Release v1.0.0 / StructuredExpression change
2 parents d931c69 + c77cf7a commit 3046dcb

File tree

6 files changed

+84
-76
lines changed

6 files changed

+84
-76
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <[email protected]>"]
4-
version = "0.19.3"
4+
version = "1.0.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/Expression.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -274,23 +274,23 @@ copy_node(ex::AbstractExpression; kws...) = copy(ex)
274274
count_nodes(ex::AbstractExpression; kws...) = count_nodes(get_tree(ex); kws...)
275275

276276
function tree_mapreduce(
277-
f::Function,
278-
op::Function,
277+
f::F,
278+
op::G,
279279
ex::AbstractExpression,
280-
result_type::Type=Undefined;
280+
result_type::Type{RT}=Undefined;
281281
kws...,
282-
)
283-
return tree_mapreduce(f, op, get_tree(ex), result_type; kws...)
282+
) where {F<:Function,G<:Function,RT}
283+
return tree_mapreduce(f, op, get_tree(ex), RT; kws...)
284284
end
285285
function tree_mapreduce(
286-
f_leaf::Function,
287-
f_branch::Function,
288-
op::Function,
286+
f_leaf::F,
287+
f_branch::G,
288+
op::H,
289289
ex::AbstractExpression,
290-
result_type::Type=Undefined;
290+
result_type::Type{RT}=Undefined;
291291
kws...,
292-
)
293-
return tree_mapreduce(f_leaf, f_branch, op, get_tree(ex), result_type; kws...)
292+
) where {F<:Function,G<:Function,H<:Function,RT}
293+
return tree_mapreduce(f_leaf, f_branch, op, get_tree(ex), RT; kws...)
294294
end
295295

296296
count_constant_nodes(ex::AbstractExpression) = count_constant_nodes(get_tree(ex))

src/Node.jl

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -218,16 +218,13 @@ include("base.jl")
218218
end
219219
return node_factory(N, T1, val, feature, op, l, r, allocator)
220220
end
221-
function validate_not_all_defaults(::Type{N}, val, feature, op, l, r, children) where {N<:AbstractExpressionNode}
222-
return nothing
223-
end
224-
function validate_not_all_defaults(::Type{N}, val, feature, op, l, r, children) where {T,N<:AbstractExpressionNode{T}}
225-
if val === nothing && feature === nothing && op === nothing && l === nothing && r === nothing && children === nothing
226-
error(
227-
"Encountered the call for $N() inside the generic constructor. "
228-
* "Did you forget to define `$(Base.typename(N).wrapper){T}() where {T} = new{T}()`?"
229-
)
230-
end
221+
validate_not_all_defaults(::Type{<:AbstractExpressionNode}, val, feature, op, l, r, children) = nothing
222+
validate_not_all_defaults(::Type{<:AbstractExpressionNode{T}}, val, feature, op, l, r, children) where {T} = nothing
223+
function validate_not_all_defaults(::Type{N}, ::Nothing, ::Nothing, ::Nothing, ::Nothing, ::Nothing, ::Nothing) where {T,N<:AbstractExpressionNode{T}}
224+
error(
225+
"Encountered the call for $N() inside the generic constructor. "
226+
* "Did you forget to define `$(Base.typename(N).wrapper){T}() where {T} = new{T}()`?"
227+
)
231228
return nothing
232229
end
233230
"""Create a constant leaf."""

src/StructuredExpression.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ kws = (;
3737
f = parse_expression(:(x * x - cos(2.5f0 * y + -0.5f0)); kws...)
3838
g = parse_expression(:(exp(-(y * y))); kws...)
3939
40-
f_plus_g = StructuredExpression((; f, g), nt -> nt.f + nt.g)
40+
f_plus_g = StructuredExpression((; f, g); structure=nt -> nt.f + nt.g)
4141
```
4242
4343
Now, when evaluating `f_plus_g`, this expression type will
@@ -83,8 +83,8 @@ struct StructuredExpression{
8383
end
8484

8585
function StructuredExpression(
86-
trees::NamedTuple,
87-
structure::F;
86+
trees::NamedTuple;
87+
structure::F,
8888
operators::Union{AbstractOperatorEnum,Nothing}=nothing,
8989
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
9090
extra...,

src/base.jl

Lines changed: 53 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,9 @@ function tree_mapreduce(
8181
tree::AbstractNode,
8282
result_type::Type{RT}=Undefined;
8383
f_on_shared::H=(result, is_shared) -> result,
84-
break_sharing=Val(false),
85-
) where {RT,F<:Function,G<:Function,H<:Function}
86-
return tree_mapreduce(f, f, op, tree, RT; f_on_shared, break_sharing)
84+
break_sharing::Val{BS}=Val(false),
85+
) where {RT,F<:Function,G<:Function,H<:Function,BS}
86+
return tree_mapreduce(f, f, op, tree, RT; f_on_shared, break_sharing=Val(BS))
8787
end
8888
function tree_mapreduce(
8989
f_leaf::F1,
@@ -92,8 +92,8 @@ function tree_mapreduce(
9292
tree::AbstractNode,
9393
result_type::Type{RT}=Undefined;
9494
f_on_shared::H=(result, is_shared) -> result,
95-
break_sharing::Val=Val(false),
96-
) where {F1<:Function,F2<:Function,G<:Function,H<:Function,RT}
95+
break_sharing::Val{BS}=Val(false),
96+
) where {F1<:Function,F2<:Function,G<:Function,H<:Function,RT,BS}
9797

9898
# Trick taken from here:
9999
# https://discourse.julialang.org/t/recursive-inner-functions-a-thousand-times-slower/85604/5
@@ -108,7 +108,7 @@ function tree_mapreduce(
108108
end
109109
end
110110

111-
sharing = preserve_sharing(typeof(tree)) && break_sharing === Val(false)
111+
sharing = preserve_sharing(typeof(tree)) && !BS
112112

113113
RT == Undefined &&
114114
sharing &&
@@ -222,14 +222,14 @@ end
222222
223223
Count the number of nodes in the tree.
224224
"""
225-
function count_nodes(tree::AbstractNode; break_sharing=Val(false))
225+
function count_nodes(tree::AbstractNode; break_sharing::Val{BS}=Val(false)) where {BS}
226226
return tree_mapreduce(
227227
_ -> 1,
228228
+,
229229
tree,
230230
Int64;
231231
f_on_shared=(c, is_shared) -> is_shared ? 0 : c,
232-
break_sharing,
232+
break_sharing=Val(BS),
233233
)
234234
end
235235

@@ -239,10 +239,14 @@ end
239239
Apply a function to each node in a tree without returning the results.
240240
"""
241241
function foreach(
242-
f::F, tree::AbstractNode; break_sharing::Val=Val(false)
243-
) where {F<:Function}
242+
f::F, tree::AbstractNode; break_sharing::Val{BS}=Val(false)
243+
) where {F<:Function,BS}
244244
tree_mapreduce(
245-
t -> (@inline(f(t)); nothing), Returns(nothing), tree, Nothing; break_sharing
245+
t -> (@inline(f(t)); nothing),
246+
Returns(nothing),
247+
tree,
248+
Nothing;
249+
break_sharing=Val(BS),
246250
)
247251
return nothing
248252
end
@@ -260,10 +264,10 @@ function filter_map(
260264
map_fnc::G,
261265
tree::AbstractNode,
262266
result_type::Type{GT};
263-
break_sharing::Val=Val(false),
264-
) where {F<:Function,G<:Function,GT}
265-
stack = Array{GT}(undef, count(filter_fnc, tree; init=0, break_sharing))
266-
filter_map!(filter_fnc, map_fnc, stack, tree; break_sharing)
267+
break_sharing::Val{BS}=Val(false),
268+
) where {F<:Function,G<:Function,GT,BS}
269+
stack = Array{GT}(undef, count(filter_fnc, tree; init=0, break_sharing=Val(BS)))
270+
filter_map!(filter_fnc, map_fnc, stack, tree; break_sharing=Val(BS))
267271
return stack::Vector{GT}
268272
end
269273

@@ -277,10 +281,10 @@ function filter_map!(
277281
map_fnc::G,
278282
destination::Vector{GT},
279283
tree::AbstractNode;
280-
break_sharing::Val=Val(false),
281-
) where {GT,F<:Function,G<:Function}
284+
break_sharing::Val{BS}=Val(false),
285+
) where {GT,F<:Function,G<:Function,BS}
282286
pointer = Ref(0)
283-
foreach(tree; break_sharing) do t
287+
foreach(tree; break_sharing=Val(BS)) do t
284288
if @inline(filter_fnc(t))
285289
map_result = @inline(map_fnc(t))::GT
286290
@inbounds destination[pointer.x += 1] = map_result
@@ -294,55 +298,60 @@ end
294298
295299
Filter nodes of a tree, returning a flat array of the nodes for which the function returns `true`.
296300
"""
297-
function filter(f::F, tree::AbstractNode; break_sharing::Val=Val(false)) where {F<:Function}
298-
return filter_map(f, identity, tree, typeof(tree); break_sharing)
301+
function filter(
302+
f::F, tree::AbstractNode; break_sharing::Val{BS}=Val(false)
303+
) where {F<:Function,BS}
304+
return filter_map(f, identity, tree, typeof(tree); break_sharing=Val(BS))
299305
end
300306

301307
"""
302308
collect(tree::AbstractNode; break_sharing::Val=Val(false))
303309
304310
Collect all nodes in a tree into a flat array in depth-first order.
305311
"""
306-
function collect(tree::AbstractNode; break_sharing::Val=Val(false))
307-
return filter(Returns(true), tree; break_sharing)
312+
function collect(tree::AbstractNode; break_sharing::Val{BS}=Val(false)) where {BS}
313+
return filter(Returns(true), tree; break_sharing=Val(BS))
308314
end
309315

310316
"""
311-
map(f::F, tree::AbstractNode, result_type::Type{RT}=Nothing; break_sharing::Val=Val(false)) where {F<:Function,RT}
317+
map(f::F, tree::AbstractNode, result_type::Type{RT}=Nothing; break_sharing::Val{BS}=Val(false)) where {F<:Function,RT,BS}
312318
313319
Map a function over a tree and return a flat array of the results in depth-first order.
314320
Pre-specifying the `result_type` of the function can be used to avoid extra allocations.
315321
"""
316322
function map(
317-
f::F, tree::AbstractNode, result_type::Type{RT}=Nothing; break_sharing::Val=Val(false)
318-
) where {F<:Function,RT}
323+
f::F,
324+
tree::AbstractNode,
325+
result_type::Type{RT}=Nothing;
326+
break_sharing::Val{BS}=Val(false),
327+
) where {F<:Function,RT,BS}
319328
if RT == Nothing
320-
return map(f, collect(tree; break_sharing))
329+
return map(f, collect(tree; break_sharing=Val(BS)))
321330
else
322-
return filter_map(Returns(true), f, tree, result_type; break_sharing)
331+
return filter_map(Returns(true), f, tree, result_type; break_sharing=Val(BS))
323332
end
324333
end
325334

326335
"""
327-
count(f::F, tree::AbstractNode; init=0, break_sharing::Val=Val(false)) where {F<:Function}
336+
count(f::F, tree::AbstractNode; init=0, break_sharing::Val{BS}=Val(false)) where {F<:Function,BS}
328337
329338
Count the number of nodes in a tree for which the function returns `true`.
330339
"""
331340
function count(
332-
f::F, tree::AbstractNode; init=0, break_sharing::Val=Val(false)
333-
) where {F<:Function}
341+
f::F, tree::AbstractNode; init=0, break_sharing::Val{BS}=Val(false)
342+
) where {F<:Function,BS}
334343
return tree_mapreduce(
335344
t -> @inline(f(t)) ? 1 : 0,
336345
+,
337346
tree,
338347
Int64;
339348
f_on_shared=(c, is_shared) -> is_shared ? 0 : c,
340-
break_sharing,
349+
break_sharing=Val(BS),
341350
) + init
342351
end
343352

344353
"""
345-
sum(f::Function, tree::AbstractNode; result_type=Undefined, f_on_shared=_default_shared_aggregation, break_sharing::Val=Val(false)) where {F<:Function}
354+
sum(f::Function, tree::AbstractNode; result_type=Undefined, f_on_shared=_default_shared_aggregation, break_sharing::Val{BS}=Val(false)) where {F<:Function,BS}
346355
347356
Sum the results of a function over a tree. For graphs with shared nodes
348357
such as [`GraphNode`](@ref), the function `f_on_shared` is called on the result
@@ -386,7 +395,7 @@ function mapreduce(
386395
"Must specify `result_type` as a keyword argument to `mapreduce` if `preserve_sharing` is true."
387396
)
388397
end
389-
return tree_mapreduce(f, op, tree, RT; f_on_shared, break_sharing)
398+
return tree_mapreduce(f, op, tree, RT; f_on_shared, break_sharing=Val(BS))
390399
end
391400

392401
isempty(::AbstractNode) = false
@@ -396,8 +405,8 @@ end
396405
@unstable iterate(::AbstractNode, stack) =
397406
isempty(stack) ? nothing : (popfirst!(stack), stack)
398407
in(item, tree::AbstractNode) = any(t -> t == item, tree)
399-
function length(tree::AbstractNode; break_sharing::Val=Val(false))
400-
return count_nodes(tree; break_sharing)
408+
function length(tree::AbstractNode; break_sharing::Val{BS}=Val(false)) where {BS}
409+
return count_nodes(tree; break_sharing=Val(BS))
401410
end
402411

403412
"""
@@ -407,8 +416,8 @@ Compute a hash of a tree. This will compute a hash differently
407416
if nodes are shared in a tree. This is ignored if `break_sharing` is set to `Val(true)`.
408417
"""
409418
function hash(
410-
tree::AbstractExpressionNode{T}, h::UInt=zero(UInt); break_sharing::Val=Val(false)
411-
) where {T}
419+
tree::AbstractExpressionNode{T}, h::UInt=zero(UInt); break_sharing::Val{BS}=Val(false)
420+
) where {T,BS}
412421
return tree_mapreduce(
413422
t -> leaf_hash(h, t),
414423
identity,
@@ -417,7 +426,7 @@ function hash(
417426
UInt;
418427
f_on_shared=(cur_hash, is_shared) ->
419428
is_shared ? hash((:shared, cur_hash), h) : cur_hash,
420-
break_sharing,
429+
break_sharing=Val(BS),
421430
)
422431
end
423432
function leaf_hash(h::UInt, t::AbstractExpressionNode)
@@ -428,17 +437,17 @@ function branch_hash(h::UInt, t::AbstractExpressionNode, children::Vararg{Any,M}
428437
end
429438

430439
"""
431-
copy_node(tree::AbstractExpressionNode; break_sharing::Val=Val(false))
440+
copy_node(tree::AbstractExpressionNode; break_sharing::Val{BS}=Val(false)) where {BS}
432441
433442
Copy a node, recursively copying all children nodes.
434443
This is more efficient than the built-in copy.
435444
436445
If `break_sharing` is set to `Val(true)`, sharing in a tree will be ignored.
437446
"""
438447
function copy_node(
439-
tree::N; break_sharing::Val=Val(false)
440-
) where {T,N<:AbstractExpressionNode{T}}
441-
return tree_mapreduce(leaf_copy, identity, branch_copy, tree, N; break_sharing)
448+
tree::N; break_sharing::Val{BS}=Val(false)
449+
) where {T,N<:AbstractExpressionNode{T},BS}
450+
return tree_mapreduce(leaf_copy, identity, branch_copy, tree, N; break_sharing=Val(BS))
442451
end
443452
function leaf_copy(t::N) where {T,N<:AbstractExpressionNode{T}}
444453
if t.constant
@@ -459,8 +468,8 @@ This is more efficient than the built-in copy.
459468
460469
If `break_sharing` is set to `Val(true)`, sharing in a tree will be ignored.
461470
"""
462-
function copy(tree::AbstractExpressionNode; break_sharing::Val=Val(false))
463-
return copy_node(tree; break_sharing)
471+
function copy(tree::AbstractExpressionNode; break_sharing::Val{BS}=Val(false)) where {BS}
472+
return copy_node(tree; break_sharing=Val(BS))
464473
end
465474

466475
"""

test/test_structured_expression.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111

1212
shower(ex) = sprint((io, e) -> show(io, MIME"text/plain"(), e), ex)
1313

14-
f_plus_g = StructuredExpression((; f, g), nt -> nt.f + nt.g)
15-
f_div_g = StructuredExpression((; f, g), nt -> nt.f / nt.g)
16-
cos_f = StructuredExpression((; f), nt -> cos(nt.f))
17-
exp_g = StructuredExpression((; g), nt -> exp(nt.g))
14+
f_plus_g = StructuredExpression((; f, g); structure=nt -> nt.f + nt.g)
15+
f_div_g = StructuredExpression((; f, g); structure=nt -> nt.f / nt.g)
16+
cos_f = StructuredExpression((; f); structure=nt -> cos(nt.f))
17+
exp_g = StructuredExpression((; g); structure=nt -> exp(nt.g))
1818

1919
@test shower(f_plus_g) == "((x * x) - cos((2.5 * y) + -0.5)) + exp(-(y * y))"
2020
@test shower(f_div_g) == "((x * x) - cos((2.5 * y) + -0.5)) / exp(-(y * y))"
@@ -43,7 +43,7 @@ end
4343
f = parse_expression(:(x * x - cos(2.5f0 * y + -0.5f0)); kws...)
4444
g = parse_expression(:(exp(-(y * y))); kws...)
4545

46-
ex = StructuredExpression((; f, g), nt -> nt.f + nt.g)
46+
ex = StructuredExpression((; f, g); structure=nt -> nt.f + nt.g)
4747

4848
@test test(ExpressionInterface, StructuredExpression, [ex])
4949
end
@@ -64,7 +64,7 @@ end
6464
g = parse_expression(:(exp(-(y * y))); kws...)
6565

6666
c = [1]
67-
ex = StructuredExpression((; f, g), my_factory; a=c)
67+
ex = StructuredExpression((; f, g); structure=my_factory, a=c)
6868

6969
@test ex.metadata.extra.a[] == 1
7070
@test ex.metadata.extra.a === c
@@ -114,7 +114,9 @@ end
114114
This is a composite `AbstractExpression` object that composes multiple
115115
expressions during evaluation.
116116
=#
117-
ex = StructuredExpression((; f, g), nt -> nt.f + nt.g; operators, variable_names)
117+
ex = StructuredExpression(
118+
(; f, g); structure=nt -> nt.f + nt.g, operators, variable_names
119+
)
118120
ex
119121
@test typeof(ex) <: AbstractExpression{Float64,<:Node{Float64}} #src
120122
#=

0 commit comments

Comments
 (0)