Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
## [UNRELEASED] - YYYY-MM-DD
## Added
- Add support for activity counts feature ([#262](https://github.com/cbrnr/sleepecg/pull/262) by [Simon Pusterhofer](https://github.com/simon-p-2000))
- Added support for Scikit-Learn and PyTorch models ([#263](https://github.com/cbrnr/sleepecg/pull/263) by [Simon Pusterhofer](https://github.com/simon-p-2000))

## [0.5.9] - 2025-02-01
### Added
Expand Down
1 change: 1 addition & 0 deletions docs/classification.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ A weaker weighting approach is likely required to find the optimal middle ground

![wrn-gru-mesa confusion matrix](./img/wrn-gru-mesa.svg)     ![wrn-gru-mesa-weighted confusion matrix](./img/wrn-gru-mesa-weighted.svg)

Additional classifiers using [Scikit-Learn](https://scikit-learn.org/stable/) can be trained with the example scripts available in the `examples/classifiers` folder.

## Usage examples
The example [`try_ws_gru_mesa.py`](https://github.com/cbrnr/sleepecg/blob/main/examples/try_ws_gru_mesa.py) demonstrates how to use the WAKE–SLEEP classifier `ws-gru-mesa`, a [GRU](https://en.wikipedia.org/wiki/Gated_recurrent_unit)-based classifier bundled with SleepECG which was trained on 1971 nights of the [MESA](https://sleepdata.org/datasets/mesa/) dataset.
Expand Down
236 changes: 236 additions & 0 deletions examples/classifiers/wrn_gru_mesa_actigraphy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
import warnings

from tensorflow.keras import layers, models
from tqdm import tqdm

from sleepecg import (
evaluate,
extract_features,
load_classifier,
prepare_data_keras,
print_class_balance,
read_mesa,
save_classifier,
set_nsrr_token,
)

set_nsrr_token("your-token-here")

TRAIN = False # set to False to skip training and load classifier from disk

# silence warnings (which might pop up during feature extraction)
warnings.filterwarnings(
"ignore", category=RuntimeWarning, message="HR analysis window too short"
)

if TRAIN:
print("‣ Starting training...")
print("‣‣ Extracting features...")
records_train = (
list(
read_mesa(
offline=False,
activity_source="actigraphy",
records_pattern="0*",
)
)
+ list(
read_mesa(
offline=False,
activity_source="actigraphy",
records_pattern="1*",
)
)
+ list(
read_mesa(
offline=False,
activity_source="actigraphy",
records_pattern="2*",
)
)
+ list(
read_mesa(
offline=False,
activity_source="actigraphy",
records_pattern="3*",
)
)
+ list(
read_mesa(
offline=False,
activity_source="actigraphy",
records_pattern="4*",
)
)
+ list(
read_mesa(
offline=False,
activity_source="actigraphy",
records_pattern="50*",
)
)
+ list(
read_mesa(
offline=False,
activity_source="actigraphy",
records_pattern="51*",
)
)
+ list(
read_mesa(
offline=False,
activity_source="actigraphy",
records_pattern="52*",
)
)
+ list(
read_mesa(
offline=False,
activity_source="actigraphy",
records_pattern="53*",
)
)
+ list(
read_mesa(
offline=False,
activity_source="actigraphy",
records_pattern="54*",
)
)
)

feature_extraction_params = {
"lookback": 120,
"lookforward": 150,
"feature_selection": [
"hrv-time",
"hrv-frequency",
"recording_start_time",
"age",
"gender",
"activity_counts",
],
"min_rri": 0.3,
"max_rri": 2,
"max_nans": 0.5,
}

features_train, stages_train, feature_ids = extract_features(
tqdm(records_train),
**feature_extraction_params,
n_jobs=-1,
)

print("‣‣ Preparing data for Keras...")
stages_mode = "wake-rem-nrem"

features_train_pad, stages_train_pad, _ = prepare_data_keras(
features_train,
stages_train,
stages_mode,
)
print_class_balance(stages_train_pad, stages_mode)

print("‣‣ Defining model...")
model = models.Sequential(
[
layers.Input((None, features_train_pad.shape[2])),
layers.Masking(-1),
layers.BatchNormalization(),
layers.Dense(64),
layers.ReLU(),
layers.Bidirectional(layers.GRU(8, return_sequences=True)),
layers.Bidirectional(layers.GRU(8, return_sequences=True)),
layers.Dense(stages_train_pad.shape[-1], activation="softmax"),
]
)

model.compile(
optimizer="rmsprop",
loss="categorical_crossentropy",
metrics=["accuracy"],
)
model.build()
model.summary()

print("‣‣ Training model...")
model.fit(
features_train_pad,
stages_train_pad,
epochs=25,
)

print("‣‣ Saving classifier...")
save_classifier(
name="wrn-gru-mesa",
model=model,
stages_mode=stages_mode,
feature_extraction_params=feature_extraction_params,
mask_value=-1,
classifiers_dir="./classifiers",
)

print("‣ Starting testing...")
print("‣‣ Loading classifier...")
clf = load_classifier("wrn-gru-mesa", "./classifiers")

print("‣‣ Extracting features...")
records_test = (
list(
read_mesa(
offline=False,
activity_source="actigraphy",
records_pattern="55*",
)
)
+ list(
read_mesa(
offline=False,
activity_source="actigraphy",
records_pattern="56*",
)
)
+ list(
read_mesa(
offline=False,
activity_source="actigraphy",
records_pattern="57*",
)
)
+ list(
read_mesa(
offline=False,
activity_source="actigraphy",
records_pattern="58*",
)
)
+ list(
read_mesa(
offline=False,
activity_source="actigraphy",
records_pattern="59*",
)
)
+ list(
read_mesa(
offline=False,
activity_source="actigraphy",
records_pattern="6*",
)
)
)

features_test, stages_test, feature_ids = extract_features(
tqdm(records_test),
**clf.feature_extraction_params,
n_jobs=-2,
)

print("‣‣ Evaluating classifier...")
features_test_pad, stages_test_pad, _ = prepare_data_keras(
features_test,
stages_test,
clf.stages_mode,
)
y_pred = clf.model.predict(features_test_pad)
evaluate(stages_test_pad, y_pred, clf.stages_mode)
Loading