Skip to content

Commit 132e97a

Browse files
authored
Merge pull request #40 from SymbolicML/custom-printing
Allow user-defined printing
2 parents da22e1a + e5494ea commit 132e97a

File tree

3 files changed

+133
-50
lines changed

3 files changed

+133
-50
lines changed

benchmark/benchmarks.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,15 @@ function benchmark_utilities()
9898
:is_constant,
9999
:get_set_constants!,
100100
:index_constants,
101+
:string_tree,
101102
)
102103

103104
operators = OperatorEnum(; binary_operators=[+, -, /, *], unary_operators=[cos, exp])
104-
105105
for func_k in all_funcs
106106
suite[func_k] = let s = BenchmarkGroup()
107107
for k in (:break_sharing, :preserve_sharing)
108-
k == :preserve_sharing && !(func_k in (:copy, :convert)) && continue
108+
has_both_modes = func_k in (:copy, :convert)
109+
k == :preserve_sharing && !has_both_modes && continue
109110

110111
f = if func_k == :copy
111112
tree -> _copy_node(tree; preserve_sharing=(k == :preserve_sharing))
@@ -115,7 +116,7 @@ function benchmark_utilities()
115116
tree;
116117
preserve_sharing=(k == :preserve_sharing),
117118
)
118-
elseif func_k in (:simplify_tree, :combine_operators)
119+
elseif func_k in (:simplify_tree, :combine_operators, :string_tree)
119120
g = getfield(@__MODULE__, func_k)
120121
tree -> f_tree_op(g, tree, operators)
121122
else
@@ -133,6 +134,9 @@ function benchmark_utilities()
133134
trees=[gen_random_tree_fixed_size(n, $operators, 5, Float32) for _ in 1:ntrees]
134135
)
135136
)
137+
if !has_both_modes
138+
s = s[k]
139+
end
136140
#! format: on
137141
end
138142
s

src/Equation.jl

Lines changed: 85 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -190,110 +190,148 @@ const OP_NAMES = Dict(
190190
"safe_pow" => "^",
191191
)
192192

193-
function get_op_name(op::String)
194-
return get(OP_NAMES, op, op)
193+
@generated function get_op_name(op::F) where {F}
194+
try
195+
# Bit faster to just cache the name of the operator:
196+
op_s = string(F.instance)
197+
out = get(OP_NAMES, op_s, op_s)
198+
return :($out)
199+
catch
200+
end
201+
return quote
202+
op_s = string(op)
203+
out = get(OP_NAMES, op_s, op_s)
204+
return out
205+
end
195206
end
196207

197208
function string_op(
198-
op::F,
199-
tree::Node,
200-
operators::AbstractOperatorEnum;
201-
bracketed::Bool=false,
202-
variable_names::Union{Array{String,1},Nothing}=nothing,
203-
# Deprecated
204-
varMap=nothing,
209+
::Val{2}, op::F, tree::Node, args...; bracketed, kws...
205210
)::String where {F}
206-
variable_names = deprecate_varmap(variable_names, varMap, :string_op)
207-
op_name = get_op_name(string(op))
211+
op_name = get_op_name(op)
208212
if op_name in ["+", "-", "*", "/", "^"]
209-
l = string_tree(tree.l, operators; bracketed=false, variable_names=variable_names)
210-
r = string_tree(tree.r, operators; bracketed=false, variable_names=variable_names)
213+
l = string_tree(tree.l, args...; bracketed=false, kws...)
214+
r = string_tree(tree.r, args...; bracketed=false, kws...)
211215
if bracketed
212-
return "$l $op_name $r"
216+
return l * " " * op_name * " " * r
213217
else
214-
return "($l $op_name $r)"
218+
return "(" * l * " " * op_name * " " * r * ")"
215219
end
216220
else
217-
l = string_tree(tree.l, operators; bracketed=true, variable_names=variable_names)
218-
r = string_tree(tree.r, operators; bracketed=true, variable_names=variable_names)
219-
return "$op_name($l, $r)"
221+
l = string_tree(tree.l, args...; bracketed=true, kws...)
222+
r = string_tree(tree.r, args...; bracketed=true, kws...)
223+
# return "$op_name($l, $r)"
224+
return op_name * "(" * l * ", " * r * ")"
225+
end
226+
end
227+
function string_op(
228+
::Val{1}, op::F, tree::Node, args...; bracketed, kws...
229+
)::String where {F}
230+
op_name = get_op_name(op)
231+
l = string_tree(tree.l, args...; bracketed=true, kws...)
232+
return op_name * "(" * l * ")"
233+
end
234+
235+
function string_constant(val, bracketed::Bool)
236+
does_not_need_brackets = (typeof(val) <: Union{Real,AbstractArray})
237+
if does_not_need_brackets || bracketed
238+
string(val)
239+
else
240+
"(" * string(val) * ")"
241+
end
242+
end
243+
244+
function string_variable(feature, variable_names)
245+
if variable_names === nothing
246+
return "x" * string(feature)
247+
else
248+
return variable_names[feature]
220249
end
221250
end
222251

223252
"""
224-
string_tree(tree::Node, operators::AbstractOperatorEnum; kws...)
253+
string_tree(tree::Node, operators::AbstractOperatorEnum[; bracketed, variable_names, f_variable, f_constant])
225254
226255
Convert an equation to a string.
227256
228257
# Arguments
229-
230-
- `variable_names::Union{Array{String, 1}, Nothing}=nothing`: what variables
231-
to print for each feature.
258+
- `tree`: the tree to convert to a string
259+
- `operators`: the operators used to define the tree
260+
261+
# Keyword Arguments
262+
- `bracketed`: (optional) whether to put brackets around the outside.
263+
- `f_variable`: (optional) function to convert a variable to a string, of the form `(feature::Int, variable_names)`.
264+
- `f_constant`: (optional) function to convert a constant to a string, of the form `(val, bracketed::Bool)`
265+
- `variable_names::Union{Array{String, 1}, Nothing}=nothing`: (optional) what variables to print for each feature.
232266
"""
233267
function string_tree(
234268
tree::Node{T},
235269
operators::AbstractOperatorEnum;
236270
bracketed::Bool=false,
271+
f_variable::F1=string_variable,
272+
f_constant::F2=string_constant,
237273
variable_names::Union{Array{String,1},Nothing}=nothing,
238274
# Deprecated
239275
varMap=nothing,
240-
)::String where {T}
276+
)::String where {T,F1<:Function,F2<:Function}
241277
variable_names = deprecate_varmap(variable_names, varMap, :string_tree)
242278
if tree.degree == 0
243-
if tree.constant
244-
return string_constant(tree.val::T; bracketed=bracketed)
279+
if !tree.constant
280+
return f_variable(tree.feature, variable_names)
245281
else
246-
if variable_names === nothing
247-
return "x$(tree.feature)"
248-
else
249-
return variable_names[tree.feature]
250-
end
282+
return f_constant(tree.val::T, bracketed)
251283
end
252284
elseif tree.degree == 1
253-
op_name = get_op_name(string(operators.unaops[tree.op]))
254-
return "$(op_name)($(string_tree(tree.l, operators, bracketed=true, variable_names=variable_names)))"
285+
return string_op(
286+
Val(1),
287+
operators.unaops[tree.op],
288+
tree,
289+
operators;
290+
bracketed,
291+
f_variable,
292+
f_constant,
293+
variable_names,
294+
)
255295
else
256296
return string_op(
297+
Val(2),
257298
operators.binops[tree.op],
258299
tree,
259300
operators;
260-
bracketed=bracketed,
261-
variable_names=variable_names,
301+
bracketed,
302+
f_variable,
303+
f_constant,
304+
variable_names,
262305
)
263306
end
264307
end
265308

266-
string_constant(val::T; bracketed::Bool) where {T<:Union{Real,AbstractArray}} = string(val)
267-
function string_constant(val; bracketed::Bool)
268-
if bracketed
269-
string(val)
270-
else
271-
"(" * string(val) * ")"
272-
end
273-
end
274-
275309
# Print an equation
276310
function print_tree(
277311
io::IO,
278312
tree::Node,
279313
operators::AbstractOperatorEnum;
314+
f_variable::F1=string_variable,
315+
f_constant::F2=string_constant,
280316
variable_names::Union{Array{String,1},Nothing}=nothing,
281317
# Deprecated
282318
varMap=nothing,
283-
)
319+
) where {F1<:Function,F2<:Function}
284320
variable_names = deprecate_varmap(variable_names, varMap, :print_tree)
285-
return println(io, string_tree(tree, operators; variable_names=variable_names))
321+
return println(io, string_tree(tree, operators; f_variable, f_constant, variable_names))
286322
end
287323

288324
function print_tree(
289325
tree::Node,
290326
operators::AbstractOperatorEnum;
327+
f_variable::F1=string_variable,
328+
f_constant::F2=string_constant,
291329
variable_names::Union{Array{String,1},Nothing}=nothing,
292330
# Deprecated
293331
varMap=nothing,
294-
)
332+
) where {F1<:Function,F2<:Function}
295333
variable_names = deprecate_varmap(variable_names, varMap, :print_tree)
296-
return println(string_tree(tree, operators; variable_names=variable_names))
334+
return println(string_tree(tree, operators; f_variable, f_constant, variable_names))
297335
end
298336

299337
end

test/test_print.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Test
22
using DynamicExpressions
3+
import Compat: Returns
34

45
include("test_params.jl")
56

@@ -39,6 +40,24 @@ for binop in [safe_pow, ^]
3940
@test string_tree(minitree, opts) == "(x1 ^ x2)"
4041
end
4142

43+
@testset "Test print_tree function" begin
44+
if VERSION > v"1.8"
45+
operators = OperatorEnum(;
46+
binary_operators=(+, *, /, -), unary_operators=(cos, sin)
47+
)
48+
x1, x2, x3 = [Node(Float64; feature=i) for i in 1:3]
49+
tree = x1 * x1 + 0.5
50+
# Capture stdout to variable:
51+
pipe = Pipe()
52+
redirect_stdout(pipe) do
53+
print_tree(tree, operators)
54+
end
55+
close(pipe.in)
56+
s = read(pipe.out, String)
57+
@test s == "((x1 * x1) + 0.5)\n"
58+
end
59+
end
60+
4261
@testset "Test printing of complex numbers" begin
4362
@eval my_custom_op(x, y) = x + y
4463
operators = OperatorEnum(;
@@ -55,3 +74,25 @@ end
5574
tree = my_custom_op(x1, 1.0 + 2.0im)
5675
@test string_tree(tree, operators) == "my_custom_op(x1, 1.0 + 2.0im)"
5776
end
77+
78+
@testset "Test user-define printing" begin
79+
operators = OperatorEnum(;
80+
default_params..., binary_operators=(+, *, /, -), unary_operators=(cos, sin)
81+
)
82+
@extend_operators operators
83+
x1, x2, x3 = [Node(Float64; feature=i) for i in 1:3]
84+
tree = x1 * x1 + 0.5
85+
@test string_tree(tree, operators; f_constant=Returns("TEST")) == "((x1 * x1) + TEST)"
86+
@test string_tree(tree, operators; f_variable=Returns("TEST")) ==
87+
"((TEST * TEST) + 0.5)"
88+
@test string_tree(
89+
tree, operators; f_variable=Returns("TEST"), f_constant=Returns("TEST2")
90+
) == "((TEST * TEST) + TEST2)"
91+
92+
# Try printing with a precision:
93+
tree = x1 * x1 + π
94+
f_constant(val::Float64, args...) = string(round(val; digits=2))
95+
@test string_tree(tree, operators; f_constant=f_constant) == "((x1 * x1) + 3.14)"
96+
f_constant(val::Float64, args...) = string(round(val; digits=4))
97+
@test string_tree(tree, operators; f_constant=f_constant) == "((x1 * x1) + 3.1416)"
98+
end

0 commit comments

Comments
 (0)