Skip to content

Commit 9c5e4c4

Browse files
authored
Merge pull request #104 from SymbolicML/abstract-structured-expressions
Create `AbstractStructuredExpression`
2 parents 6b9fb86 + 818e7d2 commit 9c5e4c4

File tree

4 files changed

+56
-47
lines changed

4 files changed

+56
-47
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <[email protected]>"]
4-
version = "1.1.0"
4+
version = "1.2.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/DynamicExpressions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ import .ExpressionModule:
9393
import .ParseModule: parse_leaf
9494
@reexport import .ParametricExpressionModule: ParametricExpression, ParametricNode
9595
@reexport import .StructuredExpressionModule: StructuredExpression
96+
import .StructuredExpressionModule: AbstractStructuredExpression
9697

9798
@stable default_mode = "disable" begin
9899
include("Interfaces.jl")

src/StructuredExpression.jl

Lines changed: 53 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce
55
using ..ExpressionModule: AbstractExpression, Metadata, node_type
66
using ..ChainRulesModule: NodeTangent
77

8+
import ..NodeModule: constructorof
89
import ..ExpressionModule:
910
get_contents,
1011
get_metadata,
@@ -13,17 +14,33 @@ import ..ExpressionModule:
1314
get_variable_names,
1415
Metadata,
1516
_copy,
17+
_data,
1618
default_node_type,
1719
node_type,
1820
get_scalar_constants,
1921
set_scalar_constants!
2022

23+
abstract type AbstractStructuredExpression{
24+
T,F<:Function,N<:AbstractExpressionNode{T},E<:AbstractExpression{T,N},D<:NamedTuple
25+
} <: AbstractExpression{T,N} end
26+
2127
"""
22-
StructuredExpression
28+
StructuredExpression{T,F,N,E,TS,D} <: AbstractStructuredExpression{T,F,N,E,D} <: AbstractExpression{T,N}
2329
2430
This expression type allows you to combine multiple expressions
2531
together in a predefined way.
2632
33+
# Parameters
34+
35+
- `T`: The numeric value type of the expressions.
36+
- `F`: The type of the structure function, which combines each expression into a single expression.
37+
- `N`: The type of the nodes inside expressions.
38+
- `E`: The type of the expressions.
39+
- `TS`: The type of the named tuple containing those inner expressions.
40+
- `D`: The type of the metadata, another named tuple.
41+
42+
# Usage
43+
2744
For example, we can create two expressions, `f`, and `g`,
2845
and then combine them together in a new expression, `f_plus_g`,
2946
using a constructor function that simply adds them together:
@@ -56,29 +73,25 @@ which will create a new method particular to this expression type defined on tha
5673
"""
5774
struct StructuredExpression{
5875
T,
59-
F,
60-
EX<:NamedTuple,
76+
F<:Function,
6177
N<:AbstractExpressionNode{T},
6278
E<:AbstractExpression{T,N},
6379
TS<:NamedTuple{<:Any,<:NTuple{<:Any,E}},
64-
D<:@NamedTuple{structure::F, operators::O, variable_names::V, extra::EX} where {O,V},
65-
} <: AbstractExpression{T,N}
80+
D<:@NamedTuple{structure::F, operators::O, variable_names::V} where {O,V},
81+
} <: AbstractStructuredExpression{T,F,N,E,D}
6682
trees::TS
6783
metadata::Metadata{D}
6884

6985
function StructuredExpression(
7086
trees::TS, metadata::Metadata{D}
7187
) where {
7288
TS,
73-
F,
74-
EX,
75-
D<:@NamedTuple{
76-
structure::F, operators::O, variable_names::V, extra::EX
77-
} where {O,V},
89+
F<:Function,
90+
D<:@NamedTuple{structure::F, operators::O, variable_names::V} where {O,V},
7891
}
7992
E = typeof(first(values(trees)))
8093
N = node_type(E)
81-
return new{eltype(N),F,EX,N,E,TS,D}(trees, metadata)
94+
return new{eltype(N),F,N,E,TS,D}(trees, metadata)
8295
end
8396
end
8497

@@ -87,65 +100,67 @@ function StructuredExpression(
87100
structure::F,
88101
operators::Union{AbstractOperatorEnum,Nothing}=nothing,
89102
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
90-
extra...,
91103
) where {F<:Function}
92104
example_tree = first(values(trees))
93105
operators = get_operators(example_tree, operators)
94106
variable_names = get_variable_names(example_tree, variable_names)
95-
metadata = (; structure, operators, variable_names, extra=(; extra...))
107+
metadata = (; structure, operators, variable_names)
96108
return StructuredExpression(trees, Metadata(metadata))
97109
end
98-
99-
function Base.copy(e::StructuredExpression)
110+
constructorof(::Type{<:StructuredExpression}) = StructuredExpression
111+
function Base.copy(e::AbstractStructuredExpression)
100112
ts = get_contents(e)
101113
meta = get_metadata(e)
114+
meta_inner = _data(meta)
102115
copy_ts = NamedTuple{keys(ts)}(map(copy, values(ts)))
103-
return StructuredExpression(
104-
copy_ts,
105-
Metadata((;
106-
meta.structure,
107-
operators=_copy(meta.operators),
108-
variable_names=_copy(meta.variable_names),
109-
extra=_copy(meta.extra),
110-
)),
116+
keys_except_structure = filter(!=(:structure), keys(meta_inner))
117+
copy_metadata = (;
118+
meta_inner.structure,
119+
NamedTuple{keys_except_structure}(
120+
map(_copy, values(meta_inner[keys_except_structure]))
121+
)...,
111122
)
123+
return constructorof(typeof(e))(copy_ts, Metadata(copy_metadata))
112124
end
113-
#! format: off
114-
function get_contents(e::StructuredExpression)
125+
function get_contents(e::AbstractStructuredExpression)
115126
return e.trees
116127
end
117-
function get_metadata(e::StructuredExpression)
128+
function get_metadata(e::AbstractStructuredExpression)
118129
return e.metadata
119130
end
120-
function get_tree(e::StructuredExpression)
121-
return get_tree(e.metadata.structure(e.trees))
131+
function get_tree(e::AbstractStructuredExpression)
132+
return get_tree(get_metadata(e).structure(get_contents(e)))
122133
end
123-
function get_operators(e::StructuredExpression, operators::Union{AbstractOperatorEnum,Nothing}=nothing)
124-
return operators === nothing ? e.metadata.operators : operators
134+
function get_operators(
135+
e::AbstractStructuredExpression, operators::Union{AbstractOperatorEnum,Nothing}=nothing
136+
)
137+
return operators === nothing ? get_metadata(e).operators : operators
125138
end
126-
function get_variable_names(e::StructuredExpression, variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing)
127-
return variable_names === nothing ? e.metadata.variable_names : variable_names
139+
function get_variable_names(
140+
e::AbstractStructuredExpression,
141+
variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing,
142+
)
143+
return variable_names === nothing ? get_metadata(e).variable_names : variable_names
128144
end
129-
function get_scalar_constants(e::StructuredExpression)
145+
function get_scalar_constants(e::AbstractStructuredExpression)
130146
# Get constants for each inner expression
131-
consts_and_refs = map(get_scalar_constants, values(e.trees))
147+
consts_and_refs = map(get_scalar_constants, values(get_contents(e)))
132148
flat_constants = vcat(map(first, consts_and_refs)...)
133149
# Collect info so we can put them back in the right place,
134150
# like the indexes of the constants in the flattened array
135151
refs = map(c_ref -> (; n=length(first(c_ref)), ref=last(c_ref)), consts_and_refs)
136152
return flat_constants, refs
137153
end
138-
function set_scalar_constants!(e::StructuredExpression, constants, refs)
154+
function set_scalar_constants!(e::AbstractStructuredExpression, constants, refs)
139155
cursor = Ref(1)
140-
foreach(values(e.trees), refs) do tree, r
156+
foreach(values(get_contents(e)), refs) do tree, r
141157
n = r.n
142158
i = cursor[]
143-
c = constants[i:(i+n-1)]
159+
c = constants[i:(i + n - 1)]
144160
set_scalar_constants!(tree, c, r.ref)
145161
cursor[] += n
146162
end
147163
return e
148164
end
149-
#! format: on
150165

151166
end

test/test_structured_expression.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,7 @@ end
6363
f = parse_expression(:(x * x - cos(2.5f0 * y + -0.5f0)); kws...)
6464
g = parse_expression(:(exp(-(y * y))); kws...)
6565

66-
c = [1]
67-
ex = StructuredExpression((; f, g); structure=my_factory, a=c)
68-
69-
@test ex.metadata.extra.a[] == 1
70-
@test ex.metadata.extra.a === c
71-
72-
# Should copy everything down to the metadata:
73-
@test copy(ex).metadata.extra.a !== c
66+
ex = StructuredExpression((; f, g); structure=my_factory)
7467

7568
h(_) = 1
7669
h(::StructuredExpression{<:Any,typeof(my_factory)}) = 2

0 commit comments

Comments
 (0)