diff --git a/CHANGELOG.md b/CHANGELOG.md index b92b107c..5ad32cd0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ ## [UNRELEASED] - YYYY-MM-DD ## Added - Add support for activity counts feature ([#262](https://github.com/cbrnr/sleepecg/pull/262) by [Simon Pusterhofer](https://github.com/simon-p-2000)) +- Added support for Scikit-Learn and PyTorch models ([#263](https://github.com/cbrnr/sleepecg/pull/263) by [Simon Pusterhofer](https://github.com/simon-p-2000)) ## [0.5.9] - 2025-02-01 ### Added diff --git a/docs/classification.md b/docs/classification.md index 5ffab9f1..dce30cae 100644 --- a/docs/classification.md +++ b/docs/classification.md @@ -18,6 +18,7 @@ A weaker weighting approach is likely required to find the optimal middle ground ![wrn-gru-mesa confusion matrix](./img/wrn-gru-mesa.svg)     ![wrn-gru-mesa-weighted confusion matrix](./img/wrn-gru-mesa-weighted.svg) +Additional classifiers using [Scikit-Learn](https://scikit-learn.org/stable/) can be trained with the example scripts available in the `examples/classifiers` folder. ## Usage examples The example [`try_ws_gru_mesa.py`](https://github.com/cbrnr/sleepecg/blob/main/examples/try_ws_gru_mesa.py) demonstrates how to use the WAKE–SLEEP classifier `ws-gru-mesa`, a [GRU](https://en.wikipedia.org/wiki/Gated_recurrent_unit)-based classifier bundled with SleepECG which was trained on 1971 nights of the [MESA](https://sleepdata.org/datasets/mesa/) dataset. 
diff --git a/examples/classifiers/wrn_gru_mesa_actigraphy.py b/examples/classifiers/wrn_gru_mesa_actigraphy.py new file mode 100644 index 00000000..65a2c032 --- /dev/null +++ b/examples/classifiers/wrn_gru_mesa_actigraphy.py @@ -0,0 +1,236 @@ +import warnings + +from tensorflow.keras import layers, models +from tqdm import tqdm + +from sleepecg import ( + evaluate, + extract_features, + load_classifier, + prepare_data_keras, + print_class_balance, + read_mesa, + save_classifier, + set_nsrr_token, +) + +set_nsrr_token("your-token-here") + +TRAIN = False # set to False to skip training and load classifier from disk + +# silence warnings (which might pop up during feature extraction) +warnings.filterwarnings( + "ignore", category=RuntimeWarning, message="HR analysis window too short" +) + +if TRAIN: + print("‣ Starting training...") + print("‣‣ Extracting features...") + records_train = ( + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="0*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="1*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="2*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="3*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="4*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="50*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="51*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="52*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="53*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="54*", + ) + ) + ) + + feature_extraction_params = { + "lookback": 120, 
+ "lookforward": 150, + "feature_selection": [ + "hrv-time", + "hrv-frequency", + "recording_start_time", + "age", + "gender", + "activity_counts", + ], + "min_rri": 0.3, + "max_rri": 2, + "max_nans": 0.5, + } + + features_train, stages_train, feature_ids = extract_features( + tqdm(records_train), + **feature_extraction_params, + n_jobs=-1, + ) + + print("‣‣ Preparing data for Keras...") + stages_mode = "wake-rem-nrem" + + features_train_pad, stages_train_pad, _ = prepare_data_keras( + features_train, + stages_train, + stages_mode, + ) + print_class_balance(stages_train_pad, stages_mode) + + print("‣‣ Defining model...") + model = models.Sequential( + [ + layers.Input((None, features_train_pad.shape[2])), + layers.Masking(-1), + layers.BatchNormalization(), + layers.Dense(64), + layers.ReLU(), + layers.Bidirectional(layers.GRU(8, return_sequences=True)), + layers.Bidirectional(layers.GRU(8, return_sequences=True)), + layers.Dense(stages_train_pad.shape[-1], activation="softmax"), + ] + ) + + model.compile( + optimizer="rmsprop", + loss="categorical_crossentropy", + metrics=["accuracy"], + ) + model.build() + model.summary() + + print("‣‣ Training model...") + model.fit( + features_train_pad, + stages_train_pad, + epochs=25, + ) + + print("‣‣ Saving classifier...") + save_classifier( + name="wrn-gru-mesa", + model=model, + stages_mode=stages_mode, + feature_extraction_params=feature_extraction_params, + mask_value=-1, + classifiers_dir="./classifiers", + ) + +print("‣ Starting testing...") +print("‣‣ Loading classifier...") +clf = load_classifier("wrn-gru-mesa", "./classifiers") + +print("‣‣ Extracting features...") +records_test = ( + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="55*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="56*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="57*", + ) + ) + + list( + read_mesa( + 
offline=False, + activity_source="actigraphy", + records_pattern="58*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="59*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="6*", + ) + ) +) + +features_test, stages_test, feature_ids = extract_features( + tqdm(records_test), + **clf.feature_extraction_params, + n_jobs=-2, +) + +print("‣‣ Evaluating classifier...") +features_test_pad, stages_test_pad, _ = prepare_data_keras( + features_test, + stages_test, + clf.stages_mode, +) +y_pred = clf.model.predict(features_test_pad) +evaluate(stages_test_pad, y_pred, clf.stages_mode) diff --git a/examples/classifiers/wrn_gru_mesa_weighted_actigraphy.py b/examples/classifiers/wrn_gru_mesa_weighted_actigraphy.py new file mode 100644 index 00000000..48eae4c2 --- /dev/null +++ b/examples/classifiers/wrn_gru_mesa_weighted_actigraphy.py @@ -0,0 +1,237 @@ +import warnings + +from tensorflow.keras import layers, models +from tqdm import tqdm + +from sleepecg import ( + evaluate, + extract_features, + load_classifier, + prepare_data_keras, + print_class_balance, + read_mesa, + save_classifier, + set_nsrr_token, +) + +set_nsrr_token("your-token-here") + +TRAIN = False # set to False to skip training and load classifier from disk + +# silence warnings (which might pop up during feature extraction) +warnings.filterwarnings( + "ignore", category=RuntimeWarning, message="HR analysis window too short" +) + +if TRAIN: + print("‣ Starting training...") + print("‣‣ Extracting features...") + records_train = ( + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="0*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="1*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="2*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + 
records_pattern="3*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="4*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="50*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="51*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="52*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="53*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="54*", + ) + ) + ) + + feature_extraction_params = { + "lookback": 120, + "lookforward": 150, + "feature_selection": [ + "hrv-time", + "hrv-frequency", + "recording_start_time", + "age", + "gender", + "activity_counts", + ], + "min_rri": 0.3, + "max_rri": 2, + "max_nans": 0.5, + } + + features_train, stages_train, feature_ids = extract_features( + tqdm(records_train), + **feature_extraction_params, + n_jobs=-1, + ) + + print("‣‣ Preparing data for Keras...") + stages_mode = "wake-rem-nrem" + + features_train_pad, stages_train_pad, sample_weight = prepare_data_keras( + features_train, + stages_train, + stages_mode, + ) + print_class_balance(stages_train_pad, stages_mode) + + print("‣‣ Defining model...") + model = models.Sequential( + [ + layers.Input((None, features_train_pad.shape[2])), + layers.Masking(-1), + layers.BatchNormalization(), + layers.Dense(64), + layers.ReLU(), + layers.Bidirectional(layers.GRU(8, return_sequences=True)), + layers.Bidirectional(layers.GRU(8, return_sequences=True)), + layers.Dense(stages_train_pad.shape[-1], activation="softmax"), + ] + ) + + model.compile( + optimizer="rmsprop", + loss="categorical_crossentropy", + metrics=["accuracy"], + ) + model.build() + model.summary() + + print("‣‣ Training model...") + model.fit( + features_train_pad, + stages_train_pad, + epochs=25, + 
sample_weight=sample_weight, + ) + + print("‣‣ Saving classifier...") + save_classifier( + name="wrn-gru-mesa-weighted", + model=model, + stages_mode=stages_mode, + feature_extraction_params=feature_extraction_params, + mask_value=-1, + classifiers_dir="./classifiers", + ) + +print("‣ Starting testing...") +print("‣‣ Loading classifier...") +clf = load_classifier("wrn-gru-mesa-weighted", "./classifiers") + +print("‣‣ Extracting features...") +records_test = ( + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="55*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="56*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="57*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="58*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="59*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="6*", + ) + ) +) + +features_test, stages_test, feature_ids = extract_features( + tqdm(records_test), + **clf.feature_extraction_params, + n_jobs=-1, +) + +print("‣‣ Evaluating classifier...") +features_test_pad, stages_test_pad, _ = prepare_data_keras( + features_test, + stages_test, + clf.stages_mode, +) +y_pred = clf.model.predict(features_test_pad) +evaluate(stages_test_pad, y_pred, clf.stages_mode) diff --git a/examples/classifiers/wrn_sklearn_mesa.py b/examples/classifiers/wrn_sklearn_mesa.py new file mode 100644 index 00000000..401c4e98 --- /dev/null +++ b/examples/classifiers/wrn_sklearn_mesa.py @@ -0,0 +1,112 @@ +import warnings + +import sklearn +from sklearn.impute import SimpleImputer +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler +from tqdm import tqdm + +from sleepecg import ( + evaluate, + extract_features, + load_classifier, + 
prepare_data_sklearn, + print_class_balance, + read_mesa, + read_shhs, + save_classifier, + set_nsrr_token, +) + +set_nsrr_token("your-token-here") + +TRAIN = True # set to False to skip training and load classifier from disk + +# silence warnings (which might pop up during feature extraction) +warnings.filterwarnings( + "ignore", category=RuntimeWarning, message="HR analysis window too short" +) + +if TRAIN: + print("‣ Starting training...") + print("‣‣ Extracting features...") + records = list(read_mesa(offline=False)) + + feature_extraction_params = { + "lookback": 120, + "lookforward": 150, + "feature_selection": [ + "hrv-time", + "hrv-frequency", + "recording_start_time", + "age", + "gender", + ], + "min_rri": 0.3, + "max_rri": 2, + "max_nans": 0.5, + } + + features_train, stages_train, feature_ids = extract_features( + tqdm(records), + **feature_extraction_params, + n_jobs=-1, + ) + + print("‣‣ Preparing data for Sklearn...") + stages_mode = "wake-rem-nrem" + + features_train_pad, stages_train_pad, record_ids = prepare_data_sklearn( + features_train, + stages_train, + stages_mode, + ) + print_class_balance(stages_train_pad, stages_mode) + + print("‣‣ Defining model...") + pipe = make_pipeline( + SimpleImputer(), + StandardScaler(), + sklearn.discriminant_analysis.LinearDiscriminantAnalysis(), + verbose=False, + ) + + print("‣‣ Training model...") + + pipe.fit( + X=features_train_pad, + y=stages_train_pad, + ) + + print("‣‣ Saving model...") + save_classifier( + name="wrn-sklearn-mesa", + model=pipe, + stages_mode=stages_mode, + feature_extraction_params=feature_extraction_params, + mask_value=-1, + classifiers_dir="./classifiers", + ) + +print("‣ Starting testing...") +print("‣‣ Loading classifier...") +clf = load_classifier("wrn-sklearn-mesa", "./classifiers") + +print("‣‣ Extracting features...") +shhs = list(read_shhs(offline=False)) + +features_test, stages_test, feature_ids = extract_features( + tqdm(shhs), + **clf.feature_extraction_params, + n_jobs=-2, 
+) + +print("‣‣ Evaluating classifier...") +features_test_pad, stages_test_pad, record_ids_test = prepare_data_sklearn( + features_test, + stages_test, + clf.stages_mode, +) + +y_pred = clf.model.predict(features_test_pad) +evaluate(stages_test_pad, y_pred, clf.stages_mode) diff --git a/examples/classifiers/wrn_sklearn_mesa_actigraphy.py b/examples/classifiers/wrn_sklearn_mesa_actigraphy.py new file mode 100644 index 00000000..f3b93665 --- /dev/null +++ b/examples/classifiers/wrn_sklearn_mesa_actigraphy.py @@ -0,0 +1,227 @@ +import warnings + +import sklearn +from sklearn.impute import SimpleImputer +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler +from tqdm import tqdm + +from sleepecg import ( + evaluate, + extract_features, + load_classifier, + prepare_data_sklearn, + print_class_balance, + read_mesa, + save_classifier, + set_nsrr_token, +) + +set_nsrr_token("your-token-here") + +TRAIN = False # set to False to skip training and load classifier from disk + +# silence warnings (which might pop up during feature extraction) +warnings.filterwarnings( + "ignore", category=RuntimeWarning, message="HR analysis window too short" +) + +if TRAIN: + print("‣ Starting training...") + print("‣‣ Extracting features...") + + feature_extraction_params = { + "lookback": 240, + "lookforward": 270, + "feature_selection": [ + "hrv-time", + "hrv-frequency", + "recording_start_time", + "age", + "gender", + "activity_counts", + ], + "min_rri": 0.3, + "max_rri": 2, + "max_nans": 0.5, + } + + records_train = ( + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="0*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="1*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="2*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="3*", + ) + ) + + list( + read_mesa( + 
offline=False, + activity_source="actigraphy", + records_pattern="4*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="50*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="51*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="52*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="53*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="54*", + ) + ) + ) + + features_train, stages_train, feature_ids_train = extract_features( + tqdm(records_train), + **feature_extraction_params, + n_jobs=-1, + ) + + print("‣‣ Preparing data for Sklearn...") + stages_mode = "wake-rem-nrem" + + features_train_pad, stages_train_pad, record_ids = prepare_data_sklearn( + features_train, + stages_train, + stages_mode, + ) + print_class_balance(stages_train_pad, stages_mode) + + print("‣‣ Defining model...") + pipe = make_pipeline( + SimpleImputer(), + StandardScaler(), + sklearn.ensemble.RandomForestClassifier(), + verbose=False, + ) + + print("‣‣ Training model...") + + pipe.fit( + X=features_train_pad, + y=stages_train_pad, + ) + + print("‣‣ Saving model...") + save_classifier( + name="wrn-sklearn-mesa-actigraphy", + model=pipe, + stages_mode=stages_mode, + feature_extraction_params=feature_extraction_params, + mask_value=-1, + classifiers_dir="./classifiers", + ) + +print("‣ Starting testing...") +print("‣‣ Loading classifier...") +clf = load_classifier("wrn-sklearn-mesa-actigraphy", "./classifiers") +stages_mode = clf.stages_mode + +print("‣‣ Extracting features...") +records_test = ( + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="55*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="56*", + ) + ) + + list( + read_mesa( + offline=False, + 
activity_source="actigraphy", + records_pattern="57*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="58*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="59*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="6*", + ) + ) +) + +features_test, stages_test, feature_ids_test = extract_features( + tqdm(records_test), + **clf.feature_extraction_params, + n_jobs=-1, +) + +features_test_pad, stages_test_pad, record_ids_test = prepare_data_sklearn( + features_test, + stages_test, + stages_mode, +) + +y_pred = clf.model.predict(features_test_pad) +evaluate(stages_test_pad, y_pred, stages_mode) diff --git a/examples/classifiers/wrn_torch_mesa.py b/examples/classifiers/wrn_torch_mesa.py new file mode 100644 index 00000000..d2331070 --- /dev/null +++ b/examples/classifiers/wrn_torch_mesa.py @@ -0,0 +1,193 @@ +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader, TensorDataset +from tqdm import tqdm + +from sleepecg import ( + evaluate, + extract_features, + load_classifier, + prepare_data_pytorch, + print_class_balance, + read_mesa, + read_shhs, + save_classifier, + set_nsrr_token, +) + + +class Torch_mesa(nn.Module): + """ + + Neural network intended to mimic the existing keras models. + + The model consists of: + A normalization layer + A fully connected linear layer (fc) + Two gru layers (gru1, gru2) + Another fully connected linear layer as the output layer (output). + + In addition, features with the specified mask_value are omitted. During training, the + cross-entropy loss is measured and the model is trained using a RMS propagation + optimizer. 
+ """ + + def __init__(self, input_dim, hidden_dim, output_dim, mask_value=-1.0): + super().__init__() + self.layer_norm = nn.LayerNorm(input_dim) + self.fc = nn.Linear(input_dim, 64) + self.gru1 = nn.GRU(64, hidden_dim, batch_first=True, bidirectional=True) + self.gru2 = nn.GRU(hidden_dim * 2, hidden_dim, batch_first=True, bidirectional=True) + self.output = nn.Linear(hidden_dim * 2, output_dim) + self.mask_value = mask_value + + def forward(self, x): + """ + Forward pass of the model. + + Parameters + ---------- + x: torch.Tensor + Input tensor with shape (batch_size, seq_len, input_dim) + + Returns + ------- + torch.Tensor + Output tensor after forward pass with shape (batch_size, seq_len, output_dim) + """ + x = torch.where(x == self.mask_value, torch.zeros_like(x), x) + + x = self.layer_norm(x) + + x = F.relu(self.fc(x)) + + x, _ = self.gru1(x) + x, _ = self.gru2(x) + + x = self.output(x) + return x + + +set_nsrr_token("your-token-here") + +TRAIN = True # set to False to skip training and load classifier from disk + +# silence warnings (which might pop up during feature extraction) +warnings.filterwarnings( + "ignore", category=RuntimeWarning, message="HR analysis window too short" +) + +if TRAIN: + print("‣ Starting training...") + print("‣‣ Extracting features...") + records = list(read_mesa(offline=False, data_dir=r"D:\SleepData")) + + feature_extraction_params = { + "lookback": 240, + "lookforward": 270, + "feature_selection": [ + "hrv-time", + "hrv-frequency", + "recording_start_time", + "age", + "gender", + ], + "min_rri": 0.3, + "max_rri": 2, + "max_nans": 0.5, + } + + features_train, stages_train, feature_ids = extract_features( + tqdm(records), + **feature_extraction_params, + n_jobs=-1, + ) + + print("‣‣ Preparing data for Pytorch...") + stages_mode = "wake-rem-nrem" + features_train_pad, stages_train_pad = prepare_data_pytorch( + features_train, + stages_train, + stages_mode, + ) + print_class_balance(stages_train_pad, stages_mode) + + print("‣‣ 
Defining model...") + model = Torch_mesa( + input_dim=features_train_pad.shape[2], + hidden_dim=8, + output_dim=stages_train_pad.shape[-1], + mask_value=-1.0, + ) + + criterion = nn.CrossEntropyLoss() + optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3) + + print("‣‣ Training model...") + train_dataset = TensorDataset(features_train_pad, stages_train_pad) + train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True) + + model.train() + for epoch in range(25): + epoch_loss = 0.0 + for batch_x, batch_y in train_loader: + optimizer.zero_grad() + outputs = model(batch_x) + + N, T, C = outputs.size() + batch_y = batch_y.argmax(dim=-1) + loss = criterion(outputs.view(N * T, C), batch_y.view(N * T)) + loss.backward() + optimizer.step() + epoch_loss += loss.item() + + print(f"Epoch {epoch + 1}/25, Loss: {epoch_loss / len(train_loader):.4f}") + + print("‣‣ Saving classifier...") + save_classifier( + name="wrn-pytorch-mesa", + model=model, + stages_mode=stages_mode, + feature_extraction_params=feature_extraction_params, + mask_value=-1, + classifiers_dir="./classifiers", + ) + +print("‣ Starting testing...") +print("‣‣ Loading classifier...") +clf = load_classifier("wrn-pytorch-mesa", "./classifiers") +model = clf.model + +print("‣‣ Extracting features...") +shhs = list(read_shhs(offline=False)) + +features_test, stages_test, feature_ids = extract_features( + tqdm(shhs), + **clf.feature_extraction_params, + n_jobs=-1, +) + +print("‣‣ Evaluating classifier...") +features_test_pad, stages_test_pad, sample_weight = prepare_data_pytorch( + features_test, + stages_test, + clf.stages_mode, +) + +model.eval() + +test_dataset = TensorDataset(features_test_pad, stages_test_pad) +test_loader = DataLoader(test_dataset, batch_size=16, shuffle=True) +y_pred = [] + +with torch.no_grad(): + for batch_x, batch_y in test_loader: + test_output = model(batch_x) + predictions = torch.argmax(test_output, dim=-1) + y_pred.append(predictions) + +y_pred = torch.cat(y_pred, 
dim=0) +evaluate(stages_test_pad, y_pred, clf.stages_mode) diff --git a/examples/classifiers/wrn_torch_mesa_actigraphy.py b/examples/classifiers/wrn_torch_mesa_actigraphy.py new file mode 100644 index 00000000..28cbacd2 --- /dev/null +++ b/examples/classifiers/wrn_torch_mesa_actigraphy.py @@ -0,0 +1,308 @@ +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader, TensorDataset +from tqdm import tqdm + +from sleepecg import ( + evaluate, + extract_features, + load_classifier, + prepare_data_pytorch, + print_class_balance, + read_mesa, + save_classifier, + set_nsrr_token, +) + + +class Torch_mesa(nn.Module): + """ + + Neural network intended to mimic the existing keras models. + + The model consists of: + A normalization layer + A fully connected linear layer (fc) + Two gru layers (gru1, gru2) + Another fully connected linear layer as the output layer (output). + + In addition, features with the specified mask_value are omitted. During training, the + cross-entropy loss is measured and the model is trained using a RMS propagation + optimizer. + """ + + def __init__(self, input_dim, hidden_dim, output_dim, mask_value=-1.0): + super().__init__() + self.layer_norm = nn.LayerNorm(input_dim) + self.fc = nn.Linear(input_dim, 64) + self.gru1 = nn.GRU(64, hidden_dim, batch_first=True, bidirectional=True) + self.gru2 = nn.GRU(hidden_dim * 2, hidden_dim, batch_first=True, bidirectional=True) + self.output = nn.Linear(hidden_dim * 2, output_dim) + self.mask_value = mask_value + self.dropout = nn.Dropout(0.3) + + def forward(self, x): + """ + Forward pass of the model. 
+ + Parameters + ---------- + x: torch.Tensor + Input tensor with shape (batch_size, seq_len, input_dim) + + Returns + ------- + torch.Tensor + Output tensor after forward pass with shape (batch_size, seq_len, output_dim) + """ + x = torch.where(x == self.mask_value, torch.zeros_like(x), x) + + x = self.layer_norm(x) + + x = F.relu(self.fc(x)) + + x, _ = self.gru1(x) + x, _ = self.gru2(x) + + x = self.output(x) + return x + + +set_nsrr_token("your-token-here") + +TRAIN = True # set to False to skip training and load classifier from disk + +# silence warnings (which might pop up during feature extraction) +warnings.filterwarnings( + "ignore", category=RuntimeWarning, message="HR analysis window too short" +) + +if TRAIN: + print("‣ Starting training...") + print("‣‣ Extracting features...") + records_train = ( + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="0*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="1*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="2*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="3*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="4*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="50*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="51*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="52*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="53*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="54*", + ) + ) + ) + + feature_extraction_params = { + "lookback": 240, + "lookforward": 270, + "feature_selection": [ + "hrv-time", + 
"hrv-frequency", + "recording_start_time", + "age", + "gender", + "activity_counts", + ], + "min_rri": 0.3, + "max_rri": 2, + "max_nans": 0.5, + } + + features_train, stages_train, feature_ids = extract_features( + tqdm(records_train), + **feature_extraction_params, + n_jobs=-1, + ) + + print("‣‣ Preparing data for Pytorch...") + stages_mode = "wake-rem-nrem" + features_train_pad, stages_train_pad, sample_weights = prepare_data_pytorch( + features_train, + stages_train, + stages_mode, + ) + print_class_balance(stages_train_pad, stages_mode) + + print("‣‣ Defining model...") + model = Torch_mesa( + input_dim=features_train_pad.shape[2], + hidden_dim=8, + output_dim=stages_train_pad.shape[-1], + mask_value=-1.0, + ) + + criterion = nn.CrossEntropyLoss() + optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3) + + print("‣‣ Training model...") + train_dataset = TensorDataset(features_train_pad, stages_train_pad) + train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True) + + model.train() + for epoch in range(50): + epoch_loss = 0.0 + for batch_x, batch_y in train_loader: + optimizer.zero_grad() + outputs = model(batch_x) + + N, T, C = outputs.size() + batch_y = batch_y.argmax(dim=-1) + loss = criterion(outputs.view(N * T, C), batch_y.view(N * T)) + loss.backward() + optimizer.step() + epoch_loss += loss.item() + + print(f"Epoch {epoch + 1}/50, Loss: {epoch_loss / len(train_loader):.4f}") + + print("‣‣ Saving classifier...") + save_classifier( + name="wrn-pytorch-mesa-actigraphy", + model=model, + stages_mode=stages_mode, + feature_extraction_params=feature_extraction_params, + mask_value=-1, + classifiers_dir="./classifiers", + ) + +print("‣ Starting testing...") +print("‣‣ Loading classifier...") +clf = load_classifier("wrn-pytorch-mesa-actigraphy", "./classifiers") +model = clf.model + +print("‣‣ Extracting features...") +records_test = ( + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="55*", + ) + ) + + 
list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="56*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="57*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="58*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="59*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="6*", + ) + ) +) + +features_test, stages_test, feature_ids = extract_features( + tqdm(records_test), + **clf.feature_extraction_params, + n_jobs=-1, +) + +print("‣‣ Evaluating classifier...") +features_test_pad, stages_test_pad, _ = prepare_data_pytorch( + features_test, + stages_test, + clf.stages_mode, +) + +model.eval() + +test_dataset = TensorDataset(features_test_pad, stages_test_pad) +test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False) +y_pred = [] + +with torch.no_grad(): + for batch_x, batch_y in test_loader: + test_output = model(batch_x) + predictions = torch.argmax(test_output, dim=-1) + y_pred.append(predictions) + +y_pred = torch.cat(y_pred, dim=0) +evaluate(stages_test_pad, y_pred, clf.stages_mode) diff --git a/examples/classifiers/ws_gru_mesa_actigraphy.py b/examples/classifiers/ws_gru_mesa_actigraphy.py new file mode 100644 index 00000000..8ba5d825 --- /dev/null +++ b/examples/classifiers/ws_gru_mesa_actigraphy.py @@ -0,0 +1,236 @@ +import warnings + +from tensorflow.keras import layers, models +from tqdm import tqdm + +from sleepecg import ( + evaluate, + extract_features, + load_classifier, + prepare_data_keras, + print_class_balance, + read_mesa, + save_classifier, + set_nsrr_token, +) + +set_nsrr_token("your-token-here") + +TRAIN = False  # set to False to skip training and load classifier from disk + +# silence warnings (which might pop up during feature extraction) +warnings.filterwarnings( + "ignore", category=RuntimeWarning, 
message="HR analysis window too short" +) + +if TRAIN: + print("‣ Starting training...") + print("‣‣ Extracting features...") + records_train = ( + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="0*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="1*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="2*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="3*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="4*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="50*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="51*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="52*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="53*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="54*", + ) + ) + ) + + feature_extraction_params = { + "lookback": 120, + "lookforward": 150, + "feature_selection": [ + "hrv-time", + "hrv-frequency", + "recording_start_time", + "age", + "gender", + "activity_counts", + ], + "min_rri": 0.3, + "max_rri": 2, + "max_nans": 0.5, + } + + features_train, stages_train, feature_ids = extract_features( + tqdm(records_train), + **feature_extraction_params, + n_jobs=-1, + ) + + print("‣‣ Preparing data for Keras...") + stages_mode = "wake-sleep" + + features_train_pad, stages_train_pad, _ = prepare_data_keras( + features_train, + stages_train, + stages_mode, + ) + print_class_balance(stages_train_pad, stages_mode) + + print("‣‣ Defining model...") + model = models.Sequential( + [ + layers.Input((None, features_train_pad.shape[2])), + layers.Masking(-1), + 
layers.BatchNormalization(), + layers.Dense(64), + layers.ReLU(), + layers.Bidirectional(layers.GRU(8, return_sequences=True)), + layers.Bidirectional(layers.GRU(8, return_sequences=True)), + layers.Dense(stages_train_pad.shape[-1], activation="softmax"), + ] + ) + + model.compile( + optimizer="rmsprop", + loss="categorical_crossentropy", + metrics=["accuracy"], + ) + model.build() + model.summary() + + print("‣‣ Training model...") + model.fit( + features_train_pad, + stages_train_pad, + epochs=5, + ) + + print("‣‣ Saving classifier...") + save_classifier( + name="wrn-gru-mesa", + model=model, + stages_mode=stages_mode, + feature_extraction_params=feature_extraction_params, + mask_value=-1, + classifiers_dir="./classifiers", + ) + +print("‣ Starting testing...") +print("‣‣ Loading classifier...") +clf = load_classifier("wrn-gru-mesa", "./classifiers") + +print("‣‣ Extracting features...") +records_test = ( + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="55*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="56*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="57*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="58*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="59*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="6*", + ) + ) +) + +features_test, stages_test, feature_ids = extract_features( + tqdm(records_test), + **clf.feature_extraction_params, + n_jobs=-2, +) + +print("‣‣ Evaluating classifier...") +features_test_pad, stages_test_pad, _ = prepare_data_keras( + features_test, + stages_test, + clf.stages_mode, +) +y_pred = clf.model.predict(features_test_pad) +evaluate(stages_test_pad, y_pred, clf.stages_mode) diff --git a/examples/classifiers/ws_sklearn_mesa.py 
b/examples/classifiers/ws_sklearn_mesa.py new file mode 100644 index 00000000..ad58f4c3 --- /dev/null +++ b/examples/classifiers/ws_sklearn_mesa.py @@ -0,0 +1,112 @@ +import warnings + +import sklearn +from sklearn.impute import SimpleImputer +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler +from tqdm import tqdm + +from sleepecg import ( + evaluate, + extract_features, + load_classifier, + prepare_data_sklearn, + print_class_balance, + read_mesa, + read_shhs, + save_classifier, + set_nsrr_token, +) + +set_nsrr_token("your-token-here") + +TRAIN = False # set to False to skip training and load classifier from disk + +# silence warnings (which might pop up during feature extraction) +warnings.filterwarnings( + "ignore", category=RuntimeWarning, message="HR analysis window too short" +) + +if TRAIN: + print("‣ Starting training...") + print("‣‣ Extracting features...") + records = list(read_mesa(offline=False)) + + feature_extraction_params = { + "lookback": 120, + "lookforward": 150, + "feature_selection": [ + "hrv-time", + "hrv-frequency", + "recording_start_time", + "age", + "gender", + ], + "min_rri": 0.3, + "max_rri": 2, + "max_nans": 0.5, + } + + features_train, stages_train, feature_ids = extract_features( + tqdm(records), + **feature_extraction_params, + n_jobs=-1, + ) + + print("‣‣ Preparing data for Sklearn...") + stages_mode = "wake-sleep" + + features_train_pad, stages_train_pad, record_ids = prepare_data_sklearn( + features_train, + stages_train, + stages_mode, + ) + print_class_balance(stages_train_pad, stages_mode) + + print("‣‣ Defining model...") + pipe = make_pipeline( + SimpleImputer(), + StandardScaler(), + sklearn.ensemble.RandomForestClassifier(), + verbose=False, + ) + + print("‣‣ Training model...") + + pipe.fit( + X=features_train_pad, + y=stages_train_pad, + ) + + print("‣‣ Saving model...") + save_classifier( + name="ws-sklearn-mesa", + model=pipe, + stages_mode=stages_mode, + 
feature_extraction_params=feature_extraction_params, + mask_value=-1, + classifiers_dir="./classifiers", + ) + +print("‣ Starting testing...") +print("‣‣ Loading classifier...") +clf = load_classifier("ws-sklearn-mesa", "./classifiers") + +print("‣‣ Extracting features...") +shhs = list(read_shhs(offline=False)) + +features_test, stages_test, feature_ids = extract_features( + tqdm(shhs), + **clf.feature_extraction_params, + n_jobs=-2, +) + +print("‣‣ Evaluating classifier...") +features_test_pad, stages_test_pad, record_ids_test = prepare_data_sklearn( + features_test, + stages_test, + clf.stages_mode, +) + +y_pred = clf.model.predict(features_test_pad) +evaluate(stages_test_pad, y_pred, clf.stages_mode) diff --git a/examples/classifiers/ws_sklearn_mesa_actigraphy.py b/examples/classifiers/ws_sklearn_mesa_actigraphy.py new file mode 100644 index 00000000..f3da9fe7 --- /dev/null +++ b/examples/classifiers/ws_sklearn_mesa_actigraphy.py @@ -0,0 +1,226 @@ +import warnings + +import sklearn +from sklearn.impute import SimpleImputer +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler +from tqdm import tqdm + +from sleepecg import ( + evaluate, + extract_features, + load_classifier, + prepare_data_sklearn, + print_class_balance, + read_mesa, + save_classifier, + set_nsrr_token, +) + +set_nsrr_token("your-token-here") + +TRAIN = False # set to False to skip training and load classifier from disk + +# silence warnings (which might pop up during feature extraction) +warnings.filterwarnings( + "ignore", category=RuntimeWarning, message="HR analysis window too short" +) + +if TRAIN: + print("‣ Starting training...") + print("‣‣ Extracting features...") + + feature_extraction_params = { + "lookback": 240, + "lookforward": 270, + "feature_selection": [ + "hrv-time", + "hrv-frequency", + "recording_start_time", + "age", + "gender", + "activity_counts", + ], + "min_rri": 0.3, + "max_rri": 2, + "max_nans": 0.5, + } + records_train = ( + 
list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="0*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="1*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="2*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="3*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="4*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="50*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="51*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="52*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="53*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="54*", + ) + ) + ) + + features_train, stages_train, feature_ids_train = extract_features( + tqdm(records_train), + **feature_extraction_params, + n_jobs=-1, + ) + + print("‣‣ Preparing data for Sklearn...") + stages_mode = "wake-sleep" + + features_train_pad, stages_train_pad, record_ids = prepare_data_sklearn( + features_train, + stages_train, + stages_mode, + ) + print_class_balance(stages_train_pad, stages_mode) + + print("‣‣ Defining model...") + pipe = make_pipeline( + SimpleImputer(), + StandardScaler(), + sklearn.ensemble.RandomForestClassifier(), + verbose=False, + ) + + print("‣‣ Training model...") + + pipe.fit( + X=features_train_pad, + y=stages_train_pad, + ) + + print("‣‣ Saving model...") + save_classifier( + name="ws-sklearn-mesa-actigraphy", + model=pipe, + stages_mode=stages_mode, + feature_extraction_params=feature_extraction_params, + mask_value=-1, + classifiers_dir="./classifiers", + ) + +print("‣ Starting testing...") +print("‣‣ Loading 
classifier...") +clf = load_classifier("ws-sklearn-mesa-actigraphy", "./classifiers") +stages_mode = clf.stages_mode + +print("‣‣ Extracting features...") +records_test = ( + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="55*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="56*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="57*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="58*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="59*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="6*", + ) + ) +) + +features_test, stages_test, feature_ids_test = extract_features( + tqdm(records_test), + **clf.feature_extraction_params, + n_jobs=-1, +) + +features_test_pad, stages_test_pad, record_ids_test = prepare_data_sklearn( + features_test, + stages_test, + stages_mode, +) + +y_pred = clf.model.predict(features_test_pad) +evaluate(stages_test_pad, y_pred, stages_mode) diff --git a/examples/classifiers/ws_torch_mesa.py b/examples/classifiers/ws_torch_mesa.py new file mode 100644 index 00000000..3e503ff0 --- /dev/null +++ b/examples/classifiers/ws_torch_mesa.py @@ -0,0 +1,193 @@ +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader, TensorDataset +from tqdm import tqdm + +from sleepecg import ( + evaluate, + extract_features, + load_classifier, + prepare_data_pytorch, + print_class_balance, + read_mesa, + read_shhs, + save_classifier, + set_nsrr_token, +) + + +class Torch_mesa(nn.Module): + """ + + Neural network intended to mimic the existing keras models. 
+
+    The model consists of:
+    A normalization layer
+    A fully connected linear layer (fc)
+    Two gru layers (gru1, gru2)
+    Another fully connected linear layer as the output layer (output).
+    In addition, features with the specified mask_value are omitted. During training, the
+    cross-entropy loss is measured and the model is trained using a RMS propagation
+    optimizer.
+    """
+
+    def __init__(self, input_dim, hidden_dim, output_dim, mask_value=-1.0):
+        super().__init__()
+        self.layer_norm = nn.LayerNorm(input_dim)
+        self.fc = nn.Linear(input_dim, 64)
+        self.gru1 = nn.GRU(64, hidden_dim, batch_first=True, bidirectional=True)
+        self.gru2 = nn.GRU(hidden_dim * 2, hidden_dim, batch_first=True, bidirectional=True)
+        self.output = nn.Linear(hidden_dim * 2, output_dim)
+        self.mask_value = mask_value
+
+    def forward(self, x):
+        """
+        Forward pass of the model.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor with shape (batch_size, seq_len, input_dim).
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor after forward pass with shape (batch_size, seq_len, output_dim).
+        """
+        x = torch.where(x == self.mask_value, torch.zeros_like(x), x)
+
+        x = self.layer_norm(x)
+
+        x = F.relu(self.fc(x))
+
+        x, _ = self.gru1(x)
+        x, _ = self.gru2(x)
+
+        x = self.output(x)
+
+        return x
+
+
+set_nsrr_token("your-token-here")
+
+TRAIN = True  # set to False to skip training and load classifier from disk
+
+# silence warnings (which might pop up during feature extraction)
+warnings.filterwarnings(
+    "ignore", category=RuntimeWarning, message="HR analysis window too short"
+)
+
+if TRAIN:
+    print("‣ Starting training...")
+    print("‣‣ Extracting features...")
+    # the previous revision passed data_dir=r"D:\SleepData", a machine-specific
+    # path that breaks this example on any other machine
+    records = list(read_mesa(offline=False))
+
+    feature_extraction_params = {
+        "lookback": 240,
+        "lookforward": 270,
+        "feature_selection": [
+            "hrv-time",
+            "hrv-frequency",
+            "recording_start_time",
+            "age",
+            "gender",
+        ],
+        "min_rri": 0.3,
+        "max_rri": 2,
+        "max_nans": 0.5,
+    }
+
+    features_train,
stages_train, feature_ids = extract_features(
+        tqdm(records),
+        **feature_extraction_params,
+        n_jobs=-1,
+    )
+
+    print("‣‣ Preparing data for Pytorch...")
+    stages_mode = "wake-sleep"
+    features_train_pad, stages_train_pad = prepare_data_pytorch(
+        features_train,
+        stages_train,
+        stages_mode,
+    )
+    print_class_balance(stages_train_pad, stages_mode)
+
+    print("‣‣ Defining model...")
+    model = Torch_mesa(
+        input_dim=features_train_pad.shape[2],
+        hidden_dim=8,
+        output_dim=stages_train_pad.shape[-1],
+        mask_value=-1.0,
+    )
+
+    criterion = nn.CrossEntropyLoss()
+    optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3)
+
+    print("‣‣ Training model...")
+    train_dataset = TensorDataset(features_train_pad, stages_train_pad)
+    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
+
+    model.train()
+    for epoch in range(25):
+        epoch_loss = 0.0
+        for batch_x, batch_y in train_loader:
+            optimizer.zero_grad()
+            outputs = model(batch_x)
+
+            N, T, C = outputs.size()
+            # NOTE(review): padded timesteps carry stage UNDEFINED (0) and still
+            # contribute to the loss — consider CrossEntropyLoss(ignore_index=0).
+            batch_y = batch_y.argmax(dim=-1)
+            loss = criterion(outputs.view(N * T, C), batch_y.view(N * T))
+            loss.backward()
+            optimizer.step()
+            epoch_loss += loss.item()
+
+        print(f"Epoch {epoch + 1}/25, Loss: {epoch_loss / len(train_loader):.4f}")
+
+    print("‣‣ Saving classifier...")
+    save_classifier(
+        name="ws-pytorch-mesa",
+        model=model,
+        stages_mode=stages_mode,
+        feature_extraction_params=feature_extraction_params,
+        mask_value=-1,
+        classifiers_dir="./classifiers",
+    )
+
+print("‣ Starting testing...")
+print("‣‣ Loading classifier...")
+clf = load_classifier("ws-pytorch-mesa", "./classifiers")
+model = clf.model
+
+print("‣‣ Extracting features...")
+shhs = list(read_shhs(offline=False))
+
+features_test, stages_test, feature_ids = extract_features(
+    tqdm(shhs),
+    **clf.feature_extraction_params,
+    n_jobs=-1,
+)
+
+print("‣‣ Evaluating classifier...")
+# prepare_data_pytorch returns two values; the previous 3-tuple unpacking
+# (..., sample_weight) raised ValueError at runtime.
+features_test_pad, stages_test_pad = prepare_data_pytorch(
+    features_test,
+    stages_test,
+    clf.stages_mode,
+)
+
+model.eval()
+
+test_dataset = TensorDataset(features_test_pad, stages_test_pad)
+# shuffle must stay off here: evaluate() compares y_pred with stages_test_pad
+# in its original record order, so shuffled batches would misalign the pairs.
+test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)
+y_pred = []
+
+with torch.no_grad():
+    for batch_x, batch_y in test_loader:
+        test_output = model(batch_x)
+        predictions = torch.argmax(test_output, dim=-1)
+        y_pred.append(predictions)
+
+y_pred = torch.cat(y_pred, dim=0)
+# NOTE(review): y_pred holds class indices while stages_test_pad is one-hot;
+# confirm evaluate() accepts this mix (the Keras examples pass probabilities).
+evaluate(stages_test_pad, y_pred, clf.stages_mode)
diff --git a/examples/classifiers/ws_torch_mesa_actigraphy.py b/examples/classifiers/ws_torch_mesa_actigraphy.py
new file mode 100644
index 00000000..8e51fea5
--- /dev/null
+++ b/examples/classifiers/ws_torch_mesa_actigraphy.py
@@ -0,0 +1,306 @@
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, TensorDataset
+from tqdm import tqdm
+
+from sleepecg import (
+    evaluate,
+    extract_features,
+    load_classifier,
+    prepare_data_pytorch,
+    print_class_balance,
+    read_mesa,
+    save_classifier,
+    set_nsrr_token,
+)
+
+
+class Torch_mesa(nn.Module):
+    """
+
+    Neural network intended to mimic the existing keras models.
+
+    The model consists of:
+    A normalization layer
+    A fully connected linear layer (fc)
+    Two gru layers (gru1, gru2)
+    Another fully connected linear layer as the output layer (output).
+    In addition, features with the specified mask_value are omitted. During training, the
+    cross-entropy loss is measured and the model is trained using a RMS propagation
+    optimizer.
+    """
+
+    def __init__(self, input_dim, hidden_dim, output_dim, mask_value=-1.0):
+        super().__init__()
+        self.layer_norm = nn.LayerNorm(input_dim)
+        self.fc = nn.Linear(input_dim, 64)
+        self.gru1 = nn.GRU(64, hidden_dim, batch_first=True, bidirectional=True)
+        self.gru2 = nn.GRU(hidden_dim * 2, hidden_dim, batch_first=True, bidirectional=True)
+        self.output = nn.Linear(hidden_dim * 2, output_dim)
+        self.mask_value = mask_value
+
+    def forward(self, x):
+        """
+        Forward pass of the model.
+ + Parameters + ---------- + x: torch.Tensor + Input tensor with shape (batch_size, seq_len, input_dim + + Returns + ------- + torch.Tensor + Output tensor after forward pass with shape (batch_size, seq_len, output_dim) + """ + x = torch.where(x == self.mask_value, torch.zeros_like(x), x) + + x = self.layer_norm(x) + + x = F.relu(self.fc(x)) + + x, _ = self.gru1(x) + x, _ = self.gru2(x) + + x = self.output(x) + return x + + +set_nsrr_token("your-token-here") + +TRAIN = True # set to False to skip training and load classifier from disk + +# silence warnings (which might pop up during feature extraction) +warnings.filterwarnings( + "ignore", category=RuntimeWarning, message="HR analysis window too short" +) + +if TRAIN: + print("‣ Starting training...") + print("‣‣ Extracting features...") + records_train = ( + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="0*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="1*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="2*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="3*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="4*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="50*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="51*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="52*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="53*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="54*", + ) + ) + ) + + feature_extraction_params = { + "lookback": 240, + "lookforward": 270, + "feature_selection": [ + "hrv-time", + "hrv-frequency", 
+ "recording_start_time", + "age", + "gender", + "activity_counts", + ], + "min_rri": 0.3, + "max_rri": 2, + "max_nans": 0.5, + } + + features_train, stages_train, feature_ids = extract_features( + tqdm(records_train), + **feature_extraction_params, + n_jobs=-1, + ) + + print("‣‣ Preparing data for Pytorch...") + stages_mode = "wake-sleep" + features_train_pad, stages_train_pad = prepare_data_pytorch( + features_train, + stages_train, + stages_mode, + ) + print_class_balance(stages_train_pad, stages_mode) + + print("‣‣ Defining model...") + model = Torch_mesa( + input_dim=features_train_pad.shape[2], + hidden_dim=8, + output_dim=stages_train_pad.shape[-1], + mask_value=-1.0, + ) + + criterion = nn.CrossEntropyLoss() + optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3) + + print("‣‣ Training model...") + train_dataset = TensorDataset(features_train_pad, stages_train_pad) + train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True) + + model.train() + for epoch in range(25): + epoch_loss = 0.0 + for batch_x, batch_y in train_loader: + optimizer.zero_grad() + outputs = model(batch_x) + + N, T, C = outputs.size() + batch_y = batch_y.argmax(dim=-1) + loss = criterion(outputs.view(N * T, C), batch_y.view(N * T)) + loss.backward() + optimizer.step() + epoch_loss += loss.item() + + print(f"Epoch {epoch + 1}/25, Loss: {epoch_loss / len(train_loader):.4f}") + + print("‣‣ Saving classifier...") + save_classifier( + name="ws-pytorch-mesa-actigraphy", + model=model, + stages_mode=stages_mode, + feature_extraction_params=feature_extraction_params, + mask_value=-1, + classifiers_dir="./classifiers", + ) + +print("‣ Starting testing...") +print("‣‣ Loading classifier...") +clf = load_classifier("ws-pytorch-mesa-actigraphy", "./classifiers") +model = clf.model + +print("‣‣ Extracting features...") +records_test = ( + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="55*", + ) + ) + + list( + read_mesa( + offline=False, + 
activity_source="actigraphy", + records_pattern="56*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="57*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="58*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="59*", + ) + ) + + list( + read_mesa( + offline=False, + activity_source="actigraphy", + records_pattern="6*", + ) + ) +) + +features_test, stages_test, feature_ids = extract_features( + tqdm(records_test), + **clf.feature_extraction_params, + n_jobs=-1, +) + +print("‣‣ Evaluating classifier...") +features_test_pad, stages_test_pad, sample_weight = prepare_data_pytorch( + features_test, + stages_test, + clf.stages_mode, +) + +model.eval() + +test_dataset = TensorDataset(features_test_pad, stages_test_pad) +test_loader = DataLoader(test_dataset, batch_size=16, shuffle=True) +y_pred = [] + +with torch.no_grad(): + for batch_x, batch_y in test_loader: + test_output = model(batch_x) + predictions = torch.argmax(test_output, dim=-1) + y_pred.append(predictions) + +y_pred = torch.cat(y_pred, dim=0) +evaluate(stages_test_pad, y_pred, clf.stages_mode) diff --git a/pyproject.toml b/pyproject.toml index cce90a67..5f7171a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ full = [ # complete package functionality "matplotlib >= 3.9.2", "numba >= 0.61.0", "tensorflow >= 2.17.0; python_version < '3.13'", + "torch >= 2.6.0", "wfdb >= 4.2.0", ] diff --git a/src/sleepecg/__init__.py b/src/sleepecg/__init__.py index ce45fb42..9c642df8 100644 --- a/src/sleepecg/__init__.py +++ b/src/sleepecg/__init__.py @@ -6,6 +6,8 @@ list_classifiers, load_classifier, prepare_data_keras, + prepare_data_pytorch, + prepare_data_sklearn, print_class_balance, save_classifier, stage, diff --git a/src/sleepecg/classification.py b/src/sleepecg/classification.py index 8048e297..35fc3a94 100644 --- a/src/sleepecg/classification.py 
+++ b/src/sleepecg/classification.py @@ -6,13 +6,17 @@ from __future__ import annotations +import pickle import shutil from dataclasses import dataclass from pathlib import Path from tempfile import TemporaryDirectory -from typing import Any, Protocol +from typing import TYPE_CHECKING, Any, Protocol from zipfile import ZipFile +if TYPE_CHECKING: + import torch + import numpy as np import yaml @@ -22,6 +26,51 @@ from sleepecg.utils import _STAGE_NAMES, _merge_sleep_stages +def prepare_data_sklearn( + features: list[np.ndarray], + stages: list[np.ndarray], + stages_mode: str, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Prepare sleep records for a sklearn model. + + The following steps are performed: + + - Merge sleep stages in `stages` according to `stage_mode`. + - Set feature values of infinity to `-1`. + - Mark features corresponding to `SleepStage.UNDEFINED` as invalid. + + Parameters + ---------- + features : list[np.ndarray] + Each 2D array in this list is a feature matrix of shape `(n_samples, n_features)` + corresponding to a single record as returned by `extract_features()`. + stages : list[np.ndarray] + Each 1D array in this list contains the sleep stages of a single record as returned + by `extract_features()`. + stages_mode : str + Identifier of the grouping mode. Can be any of `'wake-sleep'`, `'wake-rem-nrem'`, + `'wake-rem-light-n3'`, `'wake-rem-n1-n2-n3'`. + + Returns + ------- + features_stacked : np.ndarray + A 2D array of shape `(n_samples, n_features)`. + stages_stacked : np.ndarray + A 1D array containing the annotated sleep stage for each sample. The sleep stages + are merged based on the `stages_mode` parameter. + record_ids : np.ndarray + A 1D index array for each valid sample that is returned. 
+ """ + record_ids = np.hstack([i * np.ones(len(X)) for i, X in enumerate(features)]) + features_stacked = np.vstack(features) + stages_stacked = np.hstack(_merge_sleep_stages(stages, stages_mode)) + features_stacked[np.isinf(features_stacked)] = -1 + valid = stages_stacked != SleepStage.UNDEFINED + + return features_stacked[valid], stages_stacked[valid], record_ids[valid] + + def prepare_data_keras( features: list[np.ndarray], stages: list[np.ndarray], @@ -90,6 +139,73 @@ def prepare_data_keras( return features_padded, stages_padded_onehot, sample_weight +def prepare_data_pytorch( + features: list[np.ndarray], + stages: list[np.ndarray], + stages_mode: str, + mask_value: int = -1, +) -> tuple[torch.float32, torch.int64]: + """ + Mask and pad data and calculate sample weights for a PyTorch model. + + The following steps are performed: + + - Merge sleep stages in `stages` according to `stage_mode`. + - Set features corresponding to `SleepStage.UNDEFINED` to `mask_value`. + - Replace `np.nan` and `np.inf` in `features` with `mask_value`. + - Pad to a common length, where `mask_value` is used for `features` and + `SleepStage.UNDEFINED` (i.e `0`) is used for stages. + - One-hot encode stages. + + Parameters + ---------- + features : list[np.ndarray] + Each 2D array in this list is a feature matrix of shape `(n_samples, n_features)` + corresponding to a single record as returned by `extract_features()`. + stages : list[np.ndarray] + Each 1D array in this list contains the sleep stages of a single record as returned + by `extract_features()`. + stages_mode : str + Identifier of the grouping mode. Can be any of `'wake-sleep'`, `'wake-rem-nrem'`, + `'wake-rem-light-n3'`, `'wake-rem-n1-n2-n3'`. + mask_value : int, optional + Value used to pad features and replace `np.nan` and `np.inf`, by default `-1`. + Remember to pass the same value to `layers.Masking` in your model. 
+ + Returns + ------- + features_padded : torch.float32 + A PyTorch tensor of shape `(n_records, max_n_samples, n_features)`, + where `n_records` is the length of `features`/`stages` and `max_n_samples` is the + maximum number of rows of all feature matrices in `features`. + stages_padded_onehot : torch.int64 + A PyTorch tensor of shape `(n_records, max_n_samples, n_classes+1)`, where + `n_classes` is the number of classes remaining after merging sleep stages (excluding + `SleepStage.UNDEFINED`). + """ + import torch + import torch.nn.functional as F + from torch.nn.utils.rnn import pad_sequence + + stages_merged = _merge_sleep_stages(stages, stages_mode) + stages_merged_tensor = [ + torch.tensor(stage, dtype=torch.long) for stage in stages_merged + ] + stages_padded = pad_sequence( + stages_merged_tensor, padding_value=SleepStage.UNDEFINED, batch_first=True + ) + stages_padded_onehot = F.one_hot(stages_padded) + + features_tensor = [torch.tensor(feature, dtype=torch.float32) for feature in features] + features_padded = pad_sequence( + features_tensor, padding_value=mask_value, batch_first=True + ) + features_padded[stages_padded == SleepStage.UNDEFINED, :] = mask_value + features_padded[~torch.isfinite(features_padded)] = mask_value + + return features_padded, stages_padded_onehot + + def print_class_balance(stages: np.ndarray, stages_mode: str | None = None) -> None: """ Print the number of samples and percentages of each class in `stages`. 
@@ -184,6 +300,13 @@ def save_classifier(
 
     if model_type == "keras":
         model.save(f"{tmpdir}/classifier.keras")
+    elif model_type == "sklearn":
+        # NOTE(review): pickle files are version-sensitive and unsafe to load
+        # from untrusted sources.
+        with open(f"{tmpdir}/classifier.pkl", "wb") as classifier_file:
+            pickle.dump(model, classifier_file)
+    elif any("torch.nn." in str(base) for base in type(model).__mro__):
+        # Any torch.nn.Module subclass has torch.nn.modules.module.Module in
+        # its MRO; the previous substring check on str(type(model)) only
+        # matched because the example class happens to be named Torch_mesa.
+        import torch
+
+        torch.save(model, f"{tmpdir}/classifier.pth")
     else:
         raise ValueError(f"Saving model of type {type(model)} is not supported")
@@ -307,7 +430,13 @@ def load_classifier(
         finally:
             os.environ.clear()
             os.environ.update(environ_orig)
+    elif classifier_info["model_type"] == "sklearn":
+        with open(f"{tmpdir}/classifier.pkl", "rb") as classifier_file:
+            classifier = pickle.load(classifier_file)
+    elif (Path(tmpdir) / "classifier.pth").exists():
+        # A PyTorch model's recorded model_type is its class's defining module
+        # (e.g. "__main__" for script-defined classes), which is not a stable
+        # identifier — dispatch on the saved artifact instead.
+        import torch
+
+        # NOTE(review): weights_only=False unpickles arbitrary objects — only
+        # load classifier files from trusted sources.
+        classifier = torch.load(f"{tmpdir}/classifier.pth", weights_only=False)
     else:
         raise ValueError(
             f"Loading model of type {classifier_info['model_type']} is not supported"