@@ -5,6 +5,7 @@ using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce
5
5
using .. ExpressionModule: AbstractExpression, Metadata, node_type
6
6
using .. ChainRulesModule: NodeTangent
7
7
8
+ import .. NodeModule: constructorof
8
9
import .. ExpressionModule:
9
10
get_contents,
10
11
get_metadata,
@@ -13,17 +14,33 @@ import ..ExpressionModule:
13
14
get_variable_names,
14
15
Metadata,
15
16
_copy,
17
+ _data,
16
18
default_node_type,
17
19
node_type,
18
20
get_scalar_constants,
19
21
set_scalar_constants!
20
22
23
+ abstract type AbstractStructuredExpression{
24
+ T,F<: Function ,N<: AbstractExpressionNode{T} ,E<: AbstractExpression{T,N} ,D<: NamedTuple
25
+ } <: AbstractExpression{T,N} end
26
+
21
27
"""
22
- StructuredExpression
28
+ StructuredExpression{T,F,N,E,TS,D} <: AbstractStructuredExpression{T,F,N,E,D} <: AbstractExpression{T,N}
23
29
24
30
This expression type allows you to combine multiple expressions
25
31
together in a predefined way.
26
32
33
+ # Parameters
34
+
35
+ - `T`: The numeric value type of the expressions.
36
+ - `F`: The type of the structure function, which combines each expression into a single expression.
37
+ - `N`: The type of the nodes inside expressions.
38
+ - `E`: The type of the expressions.
39
+ - `TS`: The type of the named tuple containing those inner expressions.
40
+ - `D`: The type of the metadata, another named tuple.
41
+
42
+ # Usage
43
+
27
44
For example, we can create two expressions, `f`, and `g`,
28
45
and then combine them together in a new expression, `f_plus_g`,
29
46
using a constructor function that simply adds them together:
@@ -56,29 +73,25 @@ which will create a new method particular to this expression type defined on tha
56
73
"""
57
74
struct StructuredExpression{
58
75
T,
59
- F,
60
- EX<: NamedTuple ,
76
+ F<: Function ,
61
77
N<: AbstractExpressionNode{T} ,
62
78
E<: AbstractExpression{T,N} ,
63
79
TS<: NamedTuple{<:Any,<:NTuple{<:Any,E}} ,
64
- D< :@NamedTuple {structure:: F , operators:: O , variable_names:: V , extra :: EX } where {O,V},
65
- } <: AbstractExpression {T,N }
80
+ D< :@NamedTuple {structure:: F , operators:: O , variable_names:: V } where {O,V},
81
+ } <: AbstractStructuredExpression {T,F,N,E,D }
66
82
trees:: TS
67
83
metadata:: Metadata{D}
68
84
69
85
function StructuredExpression (
70
86
trees:: TS , metadata:: Metadata{D}
71
87
) where {
72
88
TS,
73
- F,
74
- EX,
75
- D< :@NamedTuple {
76
- structure:: F , operators:: O , variable_names:: V , extra:: EX
77
- } where {O,V},
89
+ F<: Function ,
90
+ D< :@NamedTuple {structure:: F , operators:: O , variable_names:: V } where {O,V},
78
91
}
79
92
E = typeof (first (values (trees)))
80
93
N = node_type (E)
81
- return new {eltype(N),F,EX, N,E,TS,D} (trees, metadata)
94
+ return new {eltype(N),F,N,E,TS,D} (trees, metadata)
82
95
end
83
96
end
84
97
@@ -87,65 +100,67 @@ function StructuredExpression(
87
100
structure:: F ,
88
101
operators:: Union{AbstractOperatorEnum,Nothing} = nothing ,
89
102
variable_names:: Union{AbstractVector{<:AbstractString},Nothing} = nothing ,
90
- extra... ,
91
103
) where {F<: Function }
92
104
example_tree = first (values (trees))
93
105
operators = get_operators (example_tree, operators)
94
106
variable_names = get_variable_names (example_tree, variable_names)
95
- metadata = (; structure, operators, variable_names, extra = (; extra ... ) )
107
+ metadata = (; structure, operators, variable_names)
96
108
return StructuredExpression (trees, Metadata (metadata))
97
109
end
98
-
99
- function Base. copy (e:: StructuredExpression )
110
+ constructorof ( :: Type{<:StructuredExpression} ) = StructuredExpression
111
+ function Base. copy (e:: AbstractStructuredExpression )
100
112
ts = get_contents (e)
101
113
meta = get_metadata (e)
114
+ meta_inner = _data (meta)
102
115
copy_ts = NamedTuple {keys(ts)} (map (copy, values (ts)))
103
- return StructuredExpression (
104
- copy_ts,
105
- Metadata ((;
106
- meta. structure,
107
- operators= _copy (meta. operators),
108
- variable_names= _copy (meta. variable_names),
109
- extra= _copy (meta. extra),
110
- )),
116
+ keys_except_structure = filter (!= (:structure ), keys (meta_inner))
117
+ copy_metadata = (;
118
+ meta_inner. structure,
119
+ NamedTuple {keys_except_structure} (
120
+ map (_copy, values (meta_inner[keys_except_structure]))
121
+ )... ,
111
122
)
123
+ return constructorof (typeof (e))(copy_ts, Metadata (copy_metadata))
112
124
end
113
- # ! format: off
114
- function get_contents (e:: StructuredExpression )
125
+ function get_contents (e:: AbstractStructuredExpression )
115
126
return e. trees
116
127
end
117
- function get_metadata (e:: StructuredExpression )
128
+ function get_metadata (e:: AbstractStructuredExpression )
118
129
return e. metadata
119
130
end
120
- function get_tree (e:: StructuredExpression )
121
- return get_tree (e . metadata . structure (e . trees ))
131
+ function get_tree (e:: AbstractStructuredExpression )
132
+ return get_tree (get_metadata (e) . structure (get_contents (e) ))
122
133
end
123
- function get_operators (e:: StructuredExpression , operators:: Union{AbstractOperatorEnum,Nothing} = nothing )
124
- return operators === nothing ? e. metadata. operators : operators
134
+ function get_operators (
135
+ e:: AbstractStructuredExpression , operators:: Union{AbstractOperatorEnum,Nothing} = nothing
136
+ )
137
+ return operators === nothing ? get_metadata (e). operators : operators
125
138
end
126
- function get_variable_names (e:: StructuredExpression , variable_names:: Union{AbstractVector{<:AbstractString},Nothing} = nothing )
127
- return variable_names === nothing ? e. metadata. variable_names : variable_names
139
+ function get_variable_names (
140
+ e:: AbstractStructuredExpression ,
141
+ variable_names:: Union{AbstractVector{<:AbstractString},Nothing} = nothing ,
142
+ )
143
+ return variable_names === nothing ? get_metadata (e). variable_names : variable_names
128
144
end
129
- function get_scalar_constants (e:: StructuredExpression )
145
+ function get_scalar_constants (e:: AbstractStructuredExpression )
130
146
# Get constants for each inner expression
131
- consts_and_refs = map (get_scalar_constants, values (e . trees ))
147
+ consts_and_refs = map (get_scalar_constants, values (get_contents (e) ))
132
148
flat_constants = vcat (map (first, consts_and_refs)... )
133
149
# Collect info so we can put them back in the right place,
134
150
# like the indexes of the constants in the flattened array
135
151
refs = map (c_ref -> (; n= length (first (c_ref)), ref= last (c_ref)), consts_and_refs)
136
152
return flat_constants, refs
137
153
end
138
- function set_scalar_constants! (e:: StructuredExpression , constants, refs)
154
+ function set_scalar_constants! (e:: AbstractStructuredExpression , constants, refs)
139
155
cursor = Ref (1 )
140
- foreach (values (e . trees ), refs) do tree, r
156
+ foreach (values (get_contents (e) ), refs) do tree, r
141
157
n = r. n
142
158
i = cursor[]
143
- c = constants[i: (i+ n - 1 )]
159
+ c = constants[i: (i + n - 1 )]
144
160
set_scalar_constants! (tree, c, r. ref)
145
161
cursor[] += n
146
162
end
147
163
return e
148
164
end
149
- # ! format: on
150
165
151
166
end
0 commit comments