Skip to content

Commit e98a686

Browse files
committed
docs: add docstring for more AbstractExpression methods
1 parent 5f5d767 commit e98a686

File tree

4 files changed

+101
-5
lines changed

4 files changed

+101
-5
lines changed

src/Evaluate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ and triplets of operations for lower memory usage.
100100
101101
# Arguments
102102
- `tree::AbstractExpressionNode`: The root node of the tree to evaluate.
103-
- `cX::AbstractMatrix{T}`: The input data to evaluate the tree on.
103+
- `cX::AbstractMatrix{T}`: The input data to evaluate the tree on, with shape `[num_features, num_rows]`.
104104
- `operators::OperatorEnum`: The operators used in the tree.
105105
- `eval_options::Union{EvalOptions,Nothing}`: See [`EvalOptions`](@ref) for documentation
106106
on the different evaluation modes.

src/EvaluateDerivative.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,13 @@ import ..EvaluateModule: deg0_eval, get_nuna, get_nbin, OPERATOR_LIMIT_BEFORE_SL
99
import ..ExtensionInterfaceModule: _zygote_gradient
1010

1111
"""
12-
eval_diff_tree_array(tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Integer; turbo::Union{Bool,Val}=Val(false))
12+
eval_diff_tree_array(
13+
tree::AbstractExpressionNode{T},
14+
cX::AbstractMatrix{T},
15+
operators::OperatorEnum,
16+
direction::Integer;
17+
turbo::Union{Bool,Val}=Val(false)
18+
) where {T<:Number}
1319
1420
Compute the forward derivative of an expression, using a similar
1521
structure and optimization to eval_tree_array. `direction` is the index of a particular
@@ -19,7 +25,7 @@ respect to `x1`.
1925
# Arguments
2026
2127
- `tree::AbstractExpressionNode`: The expression tree to evaluate.
22-
- `cX::AbstractMatrix{T}`: The data matrix, with each column being a data point.
28+
- `cX::AbstractMatrix{T}`: The data matrix, with shape `[num_features, num_rows]`.
2329
- `operators::OperatorEnum`: The operators used to create the `tree`.
2430
- `direction::Integer`: The index of the variable to take the derivative with respect to.
2531
- `turbo::Union{Bool,Val}`: Use LoopVectorization.jl for faster evaluation. Currently this does not have

src/EvaluationHelpers.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,10 @@ and triplets of operations for lower memory usage.
1616
1717
# Arguments
1818
- `tree::AbstractExpressionNode`: The root node of the tree to evaluate.
19-
- `X::AbstractMatrix{T}`: The input data to evaluate the tree on.
19+
- `X::AbstractMatrix{T}`: The input data to evaluate the tree on, with shape `[num_features, num_rows]`.
2020
- `operators::OperatorEnum`: The operators used in the tree.
2121
- `kws...`: Passed to `eval_tree_array`.
2222
23-
2423
# Returns
2524
- `output::AbstractVector{T}`: the result, which is a 1D array.
2625
Any NaN, Inf, or other failure during the evaluation will result in the entire

src/Expression.jl

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,30 @@ function extract_gradient(
317317
return extract_gradient(gradient.tree, get_tree(ex))
318318
end
319319

320+
"""
321+
string_tree(
322+
ex::AbstractExpression,
323+
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
324+
variable_names=nothing,
325+
kws...
326+
)
327+
328+
Convert an expression to a string representation.
329+
330+
This method unpacks the operators and variable names from the expression and calls [`string_tree`](@ref StringsModule.string_tree) for `AbstractExpressionNode`.
331+
332+
# Arguments
333+
334+
- `ex`: The expression to convert to a string.
335+
- `operators`: (Optional) Operators to use. If `nothing`, operators are obtained from the expression.
336+
- `variable_names`: (Optional) Variable names to use in the string representation. If `nothing`, variable names are obtained from the expression.
337+
- `kws...`: Additional keyword arguments.
338+
339+
# Returns
340+
341+
- A string representation of the expression.
342+
343+
"""
320344
function string_tree(
321345
ex::AbstractExpression,
322346
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
@@ -367,6 +391,30 @@ function _validate_input(
367391
return nothing
368392
end
369393

394+
"""
395+
eval_tree_array(
396+
ex::AbstractExpression,
397+
cX::AbstractMatrix,
398+
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
399+
kws...
400+
)
401+
402+
Evaluate an expression over a given input data matrix.
403+
404+
This method unpacks the operators from the expression and calls [`eval_tree_array`](@ref EvaluateModule.eval_tree_array) for `AbstractExpressionNode`.
405+
406+
# Arguments
407+
408+
- `ex`: The expression to evaluate.
409+
- `cX`: The input data matrix.
410+
- `operators`: (Optional) Operators to use. If `nothing`, operators are obtained from the expression.
411+
- `kws...`: Additional keyword arguments.
412+
413+
# Returns
414+
415+
- A tuple `(output, complete)` indicating the result and success of the evaluation.
416+
417+
"""
370418
function eval_tree_array(
371419
ex::AbstractExpression,
372420
cX::AbstractMatrix,
@@ -381,6 +429,30 @@ end
381429
# - eval_diff_tree_array
382430
# - differentiable_eval_tree_array
383431

432+
"""
433+
eval_grad_tree_array(
434+
ex::AbstractExpression,
435+
cX::AbstractMatrix,
436+
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
437+
kws...
438+
)
439+
440+
Compute the forward-mode derivative of an expression.
441+
442+
This method unpacks the operators from the expression and calls [`eval_grad_tree_array`](@ref EvaluateDerivativeModule.eval_grad_tree_array) for `AbstractExpressionNode`.
443+
444+
# Arguments
445+
446+
- `ex`: The expression to evaluate.
447+
- `cX`: The data matrix.
448+
- `operators`: (Optional) Operators to use. If `nothing`, operators are obtained from the expression.
449+
- `kws...`: Additional keyword arguments.
450+
451+
# Returns
452+
453+
- A tuple `(output, gradient, complete)` indicating the result, gradient, and success of the evaluation.
454+
455+
"""
384456
function eval_grad_tree_array(
385457
ex::AbstractExpression,
386458
cX::AbstractMatrix,
@@ -404,6 +476,25 @@ function _grad_evaluator(
404476
_validate_input(ex, cX, operators)
405477
return _grad_evaluator(get_tree(ex), cX, get_operators(ex, operators); variable, kws...)
406478
end
479+
480+
"""
481+
(ex::AbstractExpression)(X, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws...)
482+
483+
Evaluate the expression `ex` over the input data `X`.
484+
485+
This method unpacks the operators from the expression and calls the corresponding method for `AbstractExpressionNode`.
486+
487+
# Arguments
488+
489+
- `X`: The input data to evaluate the expression on.
490+
- `operators`: (Optional) Operators to use. If `nothing`, operators are obtained from the expression.
491+
- `kws...`: Additional keyword arguments.
492+
493+
# Returns
494+
495+
- The result of evaluating the expression over the input data `X`.
496+
497+
"""
407498
function (ex::AbstractExpression)(
408499
X, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws...
409500
)

0 commit comments

Comments
 (0)