Skip to content

Commit 6b9fb86

Browse files
authored
Merge pull request #103 from SymbolicML/cleaner-treemapreduce
Avoid closure in `tree_mapreduce`
2 parents 7d6c11c + 1cf33ed commit 6b9fb86

14 files changed

+65
-226
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,7 @@ jobs:
3434
- macOS-latest
3535
include:
3636
- os: ubuntu-latest
37-
julia-version: '1.7'
38-
- os: ubuntu-latest
39-
julia-version: '1.6'
37+
julia-version: '1.10'
4038

4139
steps:
4240
- uses: actions/checkout@v2

.github/workflows/benchmark_pr.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
- uses: actions/checkout@v2
1717
- uses: julia-actions/setup-julia@v1
1818
with:
19-
version: "1.9"
19+
version: "1"
2020
- uses: julia-actions/cache@v1
2121
- name: Extract Package Name from Project.toml
2222
id: extract-package-name

Project.toml

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <[email protected]>"]
4-
version = "1.0.1"
4+
version = "1.1.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
8-
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
98
DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
109
Interfaces = "85a1e053-f937-4924-92a5-1367d23b7b87"
1110
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
12-
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
1311
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1412
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1513
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -32,18 +30,16 @@ DynamicExpressionsZygoteExt = "Zygote"
3230
[compat]
3331
Bumper = "0.6"
3432
ChainRulesCore = "1"
35-
Compat = "3.37, 4"
3633
DispatchDoctor = "0.4"
3734
Interfaces = "0.3"
3835
LoopVectorization = "0.12"
3936
MacroTools = "0.4, 0.5"
4037
Optim = "0.19, 1"
41-
PackageExtensionCompat = "1"
4238
PrecompileTools = "1"
4339
Reexport = "1"
4440
SymbolicUtils = "0.19, ^1.0.5, 2, 3"
4541
Zygote = "0.6"
46-
julia = "1.6"
42+
julia = "1.10"
4743

4844
[extras]
4945
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"

ext/DynamicExpressionsOptimExt.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ using DynamicExpressions:
88
get_scalar_constants,
99
set_scalar_constants!,
1010
get_number_type
11-
using Compat: @inline
1211

1312
import Optim: Optim, OptimizationResults, NLSolversBase
1413

src/DynamicExpressions.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ using DispatchDoctor: @stable, @unstable
2424
include("StructuredExpression.jl")
2525
end
2626

27-
import PackageExtensionCompat: @require_extensions
2827
import Reexport: @reexport
2928
macro ignore(args...) end
3029

@@ -104,10 +103,6 @@ end
104103
import .InterfacesModule:
105104
ExpressionInterface, NodeInterface, all_ei_methods_except, all_ni_methods_except
106105

107-
function __init__()
108-
@require_extensions
109-
end
110-
111106
include("deprecated.jl")
112107

113108
import TOML: parsefile

src/Node.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module NodeModule
33
using DispatchDoctor: @unstable
44

55
import ..OperatorEnumModule: AbstractOperatorEnum
6-
import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap, Undefined
6+
import ..UtilsModule: deprecate_varmap, Undefined
77

88
const DEFAULT_NODE_TYPE = Float32
99

src/NodeUtils.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
module NodeUtilsModule
22

3-
import Compat: Returns
43
import ..NodeModule:
54
AbstractNode,
65
AbstractExpressionNode,

src/Random.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
module RandomModule
22

3-
using Compat: Returns, @inline
43
using Random: AbstractRNG
54
using ..NodeModule: AbstractNode, tree_mapreduce, filter_map
65
using ..ExpressionModule: AbstractExpression, get_tree

src/Utils.jl

Lines changed: 0 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -13,103 +13,6 @@ macro return_on_false2(flag, retval, retval2)
1313
)
1414
end
1515

16-
"""
17-
@memoize_on tree [postprocess] function my_function_on_tree(tree::AbstractExpressionNode)
18-
...
19-
end
20-
21-
This macro takes a function definition and creates a second version of the
22-
function with an additional `id_map` argument. When passed this argument (an
23-
IdDict()), it will use the `id_map` to avoid recomputing the same value
24-
for the same node in a tree. Use this to automatically create functions that
25-
work with trees that have shared child nodes.
26-
27-
Can optionally take a `postprocess` function, which will be applied to the
28-
result of the function before returning it, taking the result as the
29-
first argument and a boolean for whether the result was memoized as the
30-
second argument. This is useful for functions that need to count the number
31-
of unique nodes in a tree, for example.
32-
"""
33-
macro memoize_on(tree, args...)
34-
if length(args) ∉ (1, 2)
35-
error("Expected 2 or 3 arguments to @memoize_on")
36-
end
37-
postprocess = length(args) == 1 ? :((r, _) -> r) : args[1]
38-
def = length(args) == 1 ? args[1] : args[2]
39-
idmap_def = _memoize_on(tree, postprocess, def)
40-
41-
return quote
42-
$(esc(def)) # The normal function
43-
$(esc(idmap_def)) # The function with an id_map argument
44-
end
45-
end
46-
function _memoize_on(tree::Symbol, postprocess, def)
47-
sdef = splitdef(def)
48-
49-
# Add an id_map argument
50-
push!(sdef[:args], :(id_map::AbstractDict))
51-
52-
f_name = sdef[:name]
53-
54-
# Forward id_map argument to all calls of the same function
55-
# within the function body:
56-
sdef[:body] = postwalk(sdef[:body]) do ex
57-
if @capture(ex, f_(args__))
58-
if f == f_name
59-
return Expr(:call, f, args..., :id_map)
60-
end
61-
end
62-
return ex
63-
end
64-
65-
# Wrap the function body in a get!(id_map, tree) do ... end block:
66-
@gensym key is_memoized result body
67-
sdef[:body] = quote
68-
$key = objectid($tree)
69-
$is_memoized = haskey(id_map, $key)
70-
function $body()
71-
return $(sdef[:body])
72-
end
73-
$result = if $is_memoized
74-
@inbounds(id_map[$key])
75-
else
76-
id_map[$key] = $body()
77-
end
78-
return $postprocess($result, $is_memoized)
79-
end
80-
81-
return combinedef(sdef)
82-
end
83-
84-
"""
85-
@with_memoize(call, id_map)
86-
87-
This macro simply puts the `id_map`
88-
into the call, to be consistent with the `@memoize_on` macro.
89-
90-
```
91-
@with_memoize(_copy_node(tree), IdDict{Any,Any}())
92-
```
93-
94-
is converted to
95-
96-
```
97-
_copy_node(tree, IdDict{Any,Any}())
98-
```
99-
100-
"""
101-
macro with_memoize(def, id_map)
102-
idmap_def = _add_idmap_to_call(def, id_map)
103-
return quote
104-
$(esc(idmap_def))
105-
end
106-
end
107-
108-
function _add_idmap_to_call(def::Expr, id_map::Union{Symbol,Expr})
109-
@assert def.head == :call
110-
return Expr(:call, def.args[1], def.args[2:end]..., id_map)
111-
end
112-
11316
@inline function fill_similar(value::T, array, args...) where {T}
11417
out_array = similar(array, args...)
11518
fill!(out_array, value)

src/base.jl

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ import Base:
2424
sum
2525

2626
using DispatchDoctor: @unstable
27-
using Compat: @inline, Returns
28-
using ..UtilsModule: @memoize_on, @with_memoize, Undefined
27+
using ..UtilsModule: Undefined
2928

3029
"""
3130
tree_mapreduce(
@@ -94,41 +93,66 @@ function tree_mapreduce(
9493
f_on_shared::H=(result, is_shared) -> result,
9594
break_sharing::Val{BS}=Val(false),
9695
) where {F1<:Function,F2<:Function,G<:Function,H<:Function,RT,BS}
97-
98-
# Trick taken from here:
99-
# https://discourse.julialang.org/t/recursive-inner-functions-a-thousand-times-slower/85604/5
100-
# to speed up recursive closure
101-
@memoize_on t f_on_shared function inner(inner, t)
102-
if t.degree == 0
103-
return @inline(f_leaf(t))
104-
elseif t.degree == 1
105-
return @inline(op(@inline(f_branch(t)), inner(inner, t.l)))
106-
else
107-
return @inline(op(@inline(f_branch(t)), inner(inner, t.l), inner(inner, t.r)))
108-
end
109-
end
110-
11196
sharing = preserve_sharing(typeof(tree)) && !BS
11297

11398
RT == Undefined &&
11499
sharing &&
115100
throw(ArgumentError("Need to specify `result_type` if nodes are shared.."))
116101

117102
if sharing && RT != Undefined
118-
d = allocate_id_map(tree, RT)
119-
return @with_memoize inner(inner, tree) d
103+
id_map = allocate_id_map(tree, RT)
104+
reducer = TreeMapreducer(Val(2), id_map, f_leaf, f_branch, op, f_on_shared)
105+
return call_mapreducer(reducer, tree)
106+
else
107+
reducer = TreeMapreducer(Val(2), nothing, f_leaf, f_branch, op, f_on_shared)
108+
return call_mapreducer(reducer, tree)
109+
end
110+
end
111+
112+
struct TreeMapreducer{
113+
D,ID<:Union{Nothing,Dict},F1<:Function,F2<:Function,G<:Function,H<:Function
114+
}
115+
max_degree::Val{D}
116+
id_map::ID
117+
f_leaf::F1
118+
f_branch::F2
119+
op::G
120+
f_on_shared::H
121+
end
122+
123+
function call_mapreducer(mapreducer::TreeMapreducer{2,ID}, tree::AbstractNode) where {ID}
124+
key = ID <: Dict ? objectid(tree) : nothing
125+
if ID <: Dict && haskey(mapreducer.id_map, key)
126+
result = @inbounds(mapreducer.id_map[key])
127+
return mapreducer.f_on_shared(result, true)
120128
else
121-
return inner(inner, tree)
129+
result = if tree.degree == 0
130+
mapreducer.f_leaf(tree)
131+
elseif tree.degree == 1
132+
mapreducer.op(mapreducer.f_branch(tree), call_mapreducer(mapreducer, tree.l))
133+
else
134+
mapreducer.op(
135+
mapreducer.f_branch(tree),
136+
call_mapreducer(mapreducer, tree.l),
137+
call_mapreducer(mapreducer, tree.r),
138+
)
139+
end
140+
if ID <: Dict
141+
mapreducer.id_map[key] = result
142+
return mapreducer.f_on_shared(result, false)
143+
else
144+
return result
145+
end
122146
end
123147
end
148+
124149
function allocate_id_map(tree::AbstractNode, ::Type{RT}) where {RT}
125150
d = Dict{UInt,RT}()
126151
# Preallocate maximum storage (counting with duplicates is fast)
127152
N = length(tree; break_sharing=Val(true))
128153
sizehint!(d, N)
129154
return d
130155
end
131-
132156
# TODO: Raise Julia issue for this.
133157
# Surprisingly Dict{UInt,RT} is faster than IdDict{Node{T},RT} here!
134158
# I think it's because `setindex!` is declared with `@nospecialize` in IdDict.

0 commit comments

Comments
 (0)