Skip to content
Draft
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions benchmarks/MOABB/Commands_run_experiment.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python run_experiments.py --hparams=hparams/MotorImagery/BNCI2014001/EEGNet.yaml --data_folder=./eeg_data --output_folder=./results/test_run --nsbj=9 --nsess=2 --seed=12346 --nruns=1 --train_mode=leave-one-session-out
111 changes: 77 additions & 34 deletions benchmarks/MOABB/dataio/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,39 +24,82 @@ def to_tensor(epoch):
cached_create_filter = cache(mne.filter.create_filter)


@takes("epoch", "info", "target_sfreq", "fmin", "fmax")
@provides("epoch", "sfreq", "target_sfreq", "fmin", "fmax")
def bandpass_resample(epoch, info, target_sfreq, fmin, fmax):
"""Bandpass filter and resample an epoch."""

bandpass = cached_create_filter(
None,
info["sfreq"],
l_freq=fmin,
h_freq=fmax,
method="fir",
fir_design="firwin",
verbose=False,
)

# Check that filter length is reasonable
filter_length = len(bandpass)
len_x = epoch.shape[-1]
if filter_length > len_x:
# TODO: These long filters result in massive performance degradation... Do we
# want to throw an error instead? This usually happens when fmin is used
logging.warning(
"filter_length (%i) is longer than the signal (%i), "
"distortion is likely. Reduce filter length or filter a longer signal.",
filter_length,
len_x,
def bandpass_resample():
@takes("epoch", "info", "target_sfreq", "fmin", "fmax")
@provides("epoch", "sfreq", "target_sfreq", "fmin", "fmax")
def _bandpass_resample(epoch, info, target_sfreq, fmin, fmax):
"""Bandpass filter and resample an epoch."""

bandpass = cached_create_filter(
None,
info["sfreq"],
l_freq=fmin,
h_freq=fmax,
method="fir",
fir_design="firwin",
verbose=False,
)

# Check that filter length is reasonable
filter_length = len(bandpass)
len_x = epoch.shape[-1]
if filter_length > len_x:
# TODO: These long filters result in massive performance degradation... Do we
# want to throw an error instead? This usually happens when fmin is used
logging.warning(
"filter_length (%i) is longer than the signal (%i), "
"distortion is likely. Reduce filter length or filter a longer signal.",
filter_length,
len_x,
)

yield mne.filter.resample(
epoch,
up=target_sfreq,
down=info["sfreq"],
method="polyphase",
window=bandpass,
)
yield target_sfreq

return _bandpass_resample


'''def bandpass_resample(target_sfreq, fmin, fmax):
"""Create a dynamic item that bandpass filters and resamples an epoch."""

@takes("epoch", "info")
@provides("epoch")
def _bandpass_resample(epoch, info):
bandpass = cached_create_filter(
None,
info["sfreq"],
l_freq=fmin,
h_freq=fmax,
method="fir",
fir_design="firwin",
verbose=False,
)
breakpoint()
# Check that filter length is reasonable
filter_length = len(bandpass)
len_x = epoch.shape[-1]
if filter_length > len_x:
logging.warning(
"filter_length (%i) is longer than the signal (%i), "
"distortion is likely. Reduce filter length or filter a longer signal.",
filter_length,
len_x,
)

filtered = mne.filter.resample(
epoch,
up=target_sfreq,
down=info["sfreq"],
method="polyphase",
window=bandpass,
)
yield filtered

yield mne.filter.resample(
epoch,
up=target_sfreq,
down=info["sfreq"],
method="polyphase",
window=bandpass,
)
yield target_sfreq
return _bandpass_resample
'''
33 changes: 32 additions & 1 deletion benchmarks/MOABB/hparams/MotorImagery/BNCI2014001/EEGNet.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
seed: 1234
__set_torchseed: !apply:torch.manual_seed [!ref <seed>]

#OVERRIDES


# DIRECTORIES
data_folder: !PLACEHOLDER #'/path/to/dataset'. The dataset will be automatically downloaded in this folder
cached_data_folder: !PLACEHOLDER #'path/to/pickled/dataset'
Expand Down Expand Up @@ -39,6 +42,34 @@ C: 22
test_with: 'last' # 'last' or 'best'
test_key: "acc" # Possible opts: "loss", "f1", "auc", "acc"

# DATASET

# Get target subject
#target_subject: !ref <dataset>.subject_list[!ref <target_subject_idx>]#

# Create the subjects list
#subjects: [!ref <target_subject>]

# Create dataset using EpochedEEGDataset
#dataset_class: !new:dataio.datasets.EpochedEEGDataset

#json_path: !apply:os.path.join [!ref <cached_data_folder>, "index.json"]
#save_path: !ref <data_folder>
#dynamic_items:
# - !name:dataio.preprocessing.to_tensor
#output_keys: ["label", "subject", "session", "epoch"]
#preload: True

#from_moabb_datset: !apply: !ref <dataset_class>.from_moabb
# - !ref <dataset>
# - !ref <json_path>
# - !ref <subjects>
# - !ref <save_path>
# - !ref <dynamic_items>
# - !ref <output_keys>
# - !ref <preload>
# - !ref <tmin>
# - !ref <tmax>
# METRICS
f1: !name:sklearn.metrics.f1_score
average: 'macro'
Expand All @@ -52,7 +83,7 @@ metrics:
n_train_examples: 100 # it will be replaced in the train script
# checkpoints to average
avg_models: 10 # @orion_step1: --avg_models~"uniform(1, 15,discrete=True)"
number_of_epochs: 862 # @orion_step1: --number_of_epochs~"uniform(250, 1000, discrete=True)"
number_of_epochs: 800 # @orion_step1: --number_of_epochs~"uniform(250, 1000, discrete=True)"
lr: 0.0001 # @orion_step1: --lr~"choices([0.01, 0.005, 0.001, 0.0005, 0.0001])"
# Learning rate scheduling (cyclic learning rate is used here)
max_lr: !ref <lr> # Upper bound of the cycle (max value of the lr)
Expand Down
1 change: 1 addition & 0 deletions benchmarks/MOABB/models/EEGNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ def forward(self, x):
x : torch.Tensor (batch, time, EEG channel, channel)
Input to convolve. 4d tensors are expected.
"""
x = x.transpose(1, 2)
x = self.conv_module(x)
x = self.dense_module(x)
return x
Loading
Loading