diff --git a/src/interfaces/MLJ.jl b/src/interfaces/MLJ.jl index 5566edf..1351b63 100644 --- a/src/interfaces/MLJ.jl +++ b/src/interfaces/MLJ.jl @@ -100,6 +100,10 @@ function MMI.fit(m::SymbolicModel, verbosity::Integer, X, y, var_grouping, class # syntaxstring_kwargs = (; hidemodality = (length(var_grouping) == 1), variable_names_map = var_grouping) )) + translate_model = (m, preds)->ModalDecisionTrees.translate(m, + (; supporting_predictions=preds) + ) + rawmodel_full = model rawmodel = MDT.prune(model; simplify = true) @@ -124,7 +128,7 @@ function MMI.fit(m::SymbolicModel, verbosity::Integer, X, y, var_grouping, class if simplify sprinkledmodel = MDT.prune(sprinkledmodel; simplify = true) end - preds, translate_function(sprinkledmodel) + preds, translate_model(sprinkledmodel, preds) end, # TODO remove redundancy? model = solemodel, diff --git a/src/interfaces/Sole/main.jl b/src/interfaces/Sole/main.jl index 6904b91..3ded541 100644 --- a/src/interfaces/Sole/main.jl +++ b/src/interfaces/Sole/main.jl @@ -60,7 +60,8 @@ function translate( ) pure_root = translate(ModalDecisionTrees.root(tree), ModalDecisionTrees.initconditions(tree); kwargs...) - info = merge(info, SoleModels.info(pure_root)) + # info = merge(info, SoleModels.info(pure_root)) + info = merge(SoleModels.info(pure_root), info) info = merge(info, (;)) return SoleModels.DecisionTree(pure_root, info)