Skip to content

Commit ced4263

Browse files
committed
improve saved models
1 parent 3d4ecbb commit ced4263

File tree

6 files changed

+95
-9
lines changed

6 files changed

+95
-9
lines changed

models/model_ds15_us0_bd0.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"downstream_hops": 15, "upstream_hops": 0, "bidirectional_hops": 0, "feature_columns": ["area_um2", "vol_um3", "max_dt_um", "vol_to_area", "syn_in", "syn_out", "down_area_um2", "down_vol_um3", "down_max_dt_um", "down_vol_to_area", "down_syn_in", "down_syn_out"], "spread_alpha": 0.5, "model_file": "xgb/model_ds15_us0_bd0.ubj"}

models/xgb/model_ds15_us0_bd0.ubj

775 KB
Binary file not shown.

pyproject.toml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ description = "Using level-2 features for skeleton classification"
99
readme = "README.md"
1010
requires-python = ">=3.9"
1111
dependencies = [
12-
"joblib>=1.4.2",
1312
"meshparty>=1.18.2",
1413
"numpy>=2.0.2",
1514
"pandas>=2.2.3",
@@ -103,3 +102,12 @@ help = "Profile cpu and memory of task with scalene"
103102
[tool.poe.tasks.profile]
104103
cmd = "uv run pyinstrument -r html"
105104
help = "Profile cpu of task with pyinstrument"
105+
106+
[tool.hatch.build.targets.sdist]
107+
only-include = ["src/l2label", "models"]
108+
109+
[tool.hatch.build.targets.wheel]
110+
packages = ["src/l2label"]
111+
112+
[tool.hatch.build.targets.wheel.force-include]
113+
"models" = "l2label/models"

src/l2label/compartments.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22
import pathlib
33
from typing import Literal, Optional, Union
44

5-
import joblib
65
import numpy as np
76
import pandas as pd
87
import xgboost
98
from meshparty import meshwork
109
from scipy import sparse
1110

11+
from . import models
12+
1213
__all__ = [
1314
"make_skel_prop_df",
1415
"process_neuron",
@@ -312,6 +313,8 @@ def save_model(
312313
modelpath
313314
/ f"{model_name}_ds{downstream_hops}_us{upstream_hops}_bd{bidirectional_hops}.joblib"
314315
)
316+
model.save_model(model_filename)
317+
315318
outdata = {
316319
"downstream_hops": downstream_hops,
317320
"upstream_hops": upstream_hops,
@@ -320,7 +323,6 @@ def save_model(
320323
"spread_alpha": spread_alpha,
321324
"model_file": str(model_filename.absolute()),
322325
}
323-
joblib.dump(model, outdata["model_file"])
324326
description_file = (
325327
filepath
326328
/ f"{model_name}_ds{downstream_hops}_us{upstream_hops}_bd{bidirectional_hops}.json"
@@ -343,13 +345,28 @@ def __init__(
343345
):
344346
if config is not None:
345347
if not isinstance(config, dict):
346-
with open(config, "r") as f:
347-
config = json.load(f)
348+
config = models.load_model_config(config, dir="")
349+
self._downstream_hops = config.get("downstream_hops")
350+
self._upstream_hops = config.get("upstream_hops")
351+
self._bidirectional_hops = config.get("bidirectional_hops")
352+
self._spread_alpha = config.get("spread_alpha")
353+
self._load_model(config.get("model_file"))
354+
self._feature_columns = config.get("feature_columns")
355+
356+
elif (
357+
config is None
358+
and downstream_hops is None
359+
and upstream_hops is None
360+
and bidirectional_hops is None
361+
and model is None
362+
and spread_alpha is None
363+
):
364+
config = models.load_model_config()
348365
self._downstream_hops = config.get("downstream_hops")
349366
self._upstream_hops = config.get("upstream_hops")
350367
self._bidirectional_hops = config.get("bidirectional_hops")
351368
self._spread_alpha = config.get("spread_alpha")
352-
self._model = joblib.load(config.get("model_file"))
369+
self._load_model(config.get("model_file"))
353370
self._feature_columns = config.get("feature_columns")
354371
else:
355372
if (
@@ -603,6 +620,7 @@ def predict_axon_mask(
603620
root_is_soma: bool = False,
604621
is_axon_seg: Optional[np.ndarray] = None,
605622
evaluate_isolated_dendrites: bool = False,
623+
to_mesh_index: bool = False,
606624
):
607625
"""Predict a dendrite mask based on the smoothed label spreading and segment-level vote.
608626
@@ -647,4 +665,13 @@ def predict_axon_mask(
647665
in_root_comp = comp_labels == comp_labels[nrn.skeleton.root]
648666
else:
649667
in_root_comp = np.full(len(nrn.skeleton.vertices), True)
650-
return ~np.logical_and(~is_axon_base, in_root_comp)
668+
is_axon = ~np.logical_and(~is_axon_base, in_root_comp)
669+
if to_mesh_index:
670+
is_axon_sk = nrn.SkeletonIndex(np.flatnonzero(is_axon))
671+
is_axon = is_axon_sk.to_mesh_mask
672+
return is_axon
673+
674+
def _load_model(self, filepath: str):
    """Create a fresh XGBClassifier and populate it from the model file at `filepath`."""
    classifier = xgboost.XGBClassifier()
    # Bind the attribute before loading, as the original did, so the
    # instance state is identical even if load_model raises.
    self._model = classifier
    classifier.load_model(filepath)

src/l2label/models.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import glob
2+
import json
3+
import pathlib
4+
from typing import Optional
5+
6+
import xgboost
7+
8+
# Location of the bundled model configs. Wheels ship them inside the package
# (pyproject force-includes "models" as "l2label/models"), while a source
# checkout keeps them at the repository root, two levels above src/l2label.
_packaged_model_dir = pathlib.Path(__file__).parent / "models"
model_dir = (
    _packaged_model_dir
    if _packaged_model_dir.is_dir()
    else pathlib.Path(__file__).parents[2] / "models"
)

# File name of the config that is loaded when no model name is given.
current_model = "model_ds15_us0_bd0.json"
11+
12+
13+
def get_models(
    dir: Optional[str] = None,
):
    """
    Get a list of all models in the model directory.

    Parameters
    ----------
    dir : str, optional
        Directory to search for ``*.json`` model configuration files.
        If None, the default model directory is used.

    Returns
    -------
    list of str
        Sorted file names (without the directory part) of the matching
        configuration files. Empty if nothing matches.
    """
    if dir is None:
        dir = model_dir

    # glob.glob yields plain strings, which have no ``.name`` attribute;
    # wrap each match in a Path to extract the file name.
    return sorted(pathlib.Path(f).name for f in glob.glob(f"{dir}/*.json"))
23+
24+
25+
def load_model_config(
    model_name: Optional[str] = None,
    dir: Optional[str] = None,
):
    """
    Load a model configuration file.

    Parameters
    ----------
    model_name : str, optional
        The file name of the model configuration to load, with or without
        the ``.json`` suffix. A path-like object is also accepted.
        If None, the current default model is used.
    dir : str, optional
        The directory to load the model from. If None, the default model directory is used.

    Returns
    -------
    dict
        The parsed configuration. Its ``"model_file"`` entry is replaced by a
        ``pathlib.Path`` resolved against ``dir`` (not against the location
        of the configuration file itself).

    Raises
    ------
    ValueError
        If no model name is given and no default model is configured.
    """
    if dir is None:
        dir = model_dir
    if model_name is None:
        model_name = current_model

    if model_name is None:
        # Only reachable when the module-level default has been unset.
        raise ValueError(
            "model_name must be specified when no default model is configured"
        )

    base_dir = pathlib.Path(dir)
    # Callers may hand in a pathlib.Path; normalise so the suffix check works.
    config_name = str(model_name)
    if not config_name.endswith(".json"):
        config_name += ".json"

    with open(base_dir / config_name, "r", encoding="utf-8") as f:
        model_config = json.load(f)

    model_config["model_file"] = base_dir / str(model_config["model_file"])
    return model_config

uv.lock

Lines changed: 0 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)