22import pathlib
33from typing import Literal , Optional , Union
44
5- import joblib
65import numpy as np
76import pandas as pd
87import xgboost
98from meshparty import meshwork
109from scipy import sparse
1110
11+ from . import models
12+
1213__all__ = [
1314 "make_skel_prop_df" ,
1415 "process_neuron" ,
@@ -312,6 +313,8 @@ def save_model(
312313 modelpath
313314 / f"{ model_name } _ds{ downstream_hops } _us{ upstream_hops } _bd{ bidirectional_hops } .joblib"
314315 )
316+ model .save_model (model_filename )
317+
315318 outdata = {
316319 "downstream_hops" : downstream_hops ,
317320 "upstream_hops" : upstream_hops ,
@@ -320,7 +323,6 @@ def save_model(
320323 "spread_alpha" : spread_alpha ,
321324 "model_file" : str (model_filename .absolute ()),
322325 }
323- joblib .dump (model , outdata ["model_file" ])
324326 description_file = (
325327 filepath
326328 / f"{ model_name } _ds{ downstream_hops } _us{ upstream_hops } _bd{ bidirectional_hops } .json"
@@ -343,13 +345,28 @@ def __init__(
343345 ):
344346 if config is not None :
345347 if not isinstance (config , dict ):
346- with open (config , "r" ) as f :
347- config = json .load (f )
348+ config = models .load_model_config (config , dir = "" )
349+ self ._downstream_hops = config .get ("downstream_hops" )
350+ self ._upstream_hops = config .get ("upstream_hops" )
351+ self ._bidirectional_hops = config .get ("bidirectional_hops" )
352+ self ._spread_alpha = config .get ("spread_alpha" )
353+ self ._load_model (config .get ("model_file" ))
354+ self ._feature_columns = config .get ("feature_columns" )
355+
356+ elif (
357+ config is None
358+ and downstream_hops is None
359+ and upstream_hops is None
360+ and bidirectional_hops is None
361+ and model is None
362+ and spread_alpha is None
363+ ):
364+ config = models .load_model_config ()
348365 self ._downstream_hops = config .get ("downstream_hops" )
349366 self ._upstream_hops = config .get ("upstream_hops" )
350367 self ._bidirectional_hops = config .get ("bidirectional_hops" )
351368 self ._spread_alpha = config .get ("spread_alpha" )
352- self ._model = joblib . load (config .get ("model_file" ))
369+ self ._load_model (config .get ("model_file" ))
353370 self ._feature_columns = config .get ("feature_columns" )
354371 else :
355372 if (
@@ -603,6 +620,7 @@ def predict_axon_mask(
603620 root_is_soma : bool = False ,
604621 is_axon_seg : Optional [np .ndarray ] = None ,
605622 evaluate_isolated_dendrites : bool = False ,
623+ to_mesh_index : bool = False ,
606624 ):
607625 """Predict a dendrite mask based on the smoothed label spreading and segment-level vote.
608626
@@ -647,4 +665,13 @@ def predict_axon_mask(
647665 in_root_comp = comp_labels == comp_labels [nrn .skeleton .root ]
648666 else :
649667 in_root_comp = np .full (len (nrn .skeleton .vertices ), True )
650- return ~ np .logical_and (~ is_axon_base , in_root_comp )
668+ is_axon = ~ np .logical_and (~ is_axon_base , in_root_comp )
669+ if to_mesh_index :
670+ is_axon_sk = nrn .SkeletonIndex (np .flatnonzero (is_axon ))
671+ is_axon = is_axon_sk .to_mesh_mask
672+ return is_axon
673+
674+ def _load_model (self , filepath : str ):
675+ """Load a model from a file."""
676+ self ._model = xgboost .XGBClassifier ()
677+ self ._model .load_model (filepath )
0 commit comments